Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions aodn_cloud_optimised/lib/CommonHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,13 @@ def __init__(self, **kwargs):

self.schema = self.dataset_config.get("schema")

logger_name = self.dataset_config.get("logger_name", "generic")
self.logger = get_logger(logger_name, raise_error=self.raise_error)
self.logger = kwargs.get("logger", None)
if not self.logger:
logger_name = self.dataset_config.get("logger_name", "generic")
self.logger = get_logger(logger_name, raise_error=self.raise_error)
self.logger_override = False
else:
self.logger_override = True

cloud_optimised_format = self.dataset_config.get("cloud_optimised_format")

Expand All @@ -115,17 +120,21 @@ def __init__(self, **kwargs):
"clear_existing_data", None
) # setting to True will recreate the zarr from scratch at every run!

# TODO: if there is a schedular then we probably don't need this
self.coiled_cluster_options = self.dataset_config.get("run_settings", {}).get(
"coiled_cluster_options", None
)

# TODO: if there is a schedular then we probably don't need this
self.cluster_manager = ClusterManager(
cluster_mode=self.cluster_mode,
dataset_name=self.dataset_name,
dataset_config=self.dataset_config,
logger=self.logger,
)

self.schedular = kwargs.get("schedular", None)

self.s3_client_opts_common = kwargs.get("s3_client_opts_common", None)

self.s3_fs_common_opts = self.dataset_config["run_settings"].get(
Expand Down Expand Up @@ -678,8 +687,14 @@ def cloud_optimised_creation(
# Filter out None values
filtered_kwargs = {k: v for k, v in kwargs_handler_class.items() if v is not None}
kwargs_handler_class = filtered_kwargs
logger_name = dataset_config.get("logger_name", "generic")
logger = get_logger(logger_name, raise_error=kwargs.get("raise_error", False))

# Replace the logger if one is provided
logger = kwargs.get("logger", None)
if not logger:
logger_name = dataset_config.get("logger_name", "generic")
logger = get_logger(logger_name, raise_error=kwargs.get("raise_error", False))
else:
kwargs_handler_class["logger"] = logger
Comment on lines +693 to +697

kwargs_handler_class["dataset_config"] = dataset_config
kwargs_handler_class["clear_existing_data"] = handler_clear_existing_data_arg
Expand All @@ -691,6 +706,8 @@ def cloud_optimised_creation(
run_summary = RunSummary()
kwargs_handler_class["run_summary"] = run_summary

kwargs_handler_class["schedular"] = kwargs.get("schedular", None)

# Creating an instance of the specified class with the provided arguments
start_whole_processing = timeit.default_timer()
with handler_class(**kwargs_handler_class) as handler_instance:
Expand Down
58 changes: 36 additions & 22 deletions aodn_cloud_optimised/lib/GenericParquetHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,7 @@ def publish_cloud_optimised(
raise ValueError

metadata_collector = []
self.logger.debug(self.cloud_optimised_output_path)
pq.write_to_dataset(
pdf,
root_path=self.cloud_optimised_output_path,
Expand Down Expand Up @@ -1530,31 +1531,38 @@ def to_cloud_optimised(self, s3_file_uri_list) -> None:
bucket_name, prefix, self.s3_client_opts_output
)

# Capture only the count — NOT the full list — to avoid cloudpickle serializing
# the list into every Dask task closure. The real leak is via `self`: task()
# captures `self` because it calls self.to_cloud_optimised_single(), and
# cloudpickle serialises the entire handler instance with every client.submit().
# self.s3_file_uri_list is set AFTER the batch loop so that self is lean
# (~7 KB) rather than carrying the 19k-path list (~694 KB × batch_size).
total_files = len(s3_file_uri_list)
if not self.schedular:

def task(f, i):
try:
self.to_cloud_optimised_single(f)
except Exception as e:
self.logger.error(
f"Issue {i}/{total_files} with {f}: {type(e).__name__}: {e}"
)
# If a schedular is not provided then use the aodn CO provided schedulars

client, cluster = self.create_cluster()
# Capture only the count — NOT the full list — to avoid cloudpickle serializing
# the list into every Dask task closure. The real leak is via `self`: task()
# captures `self` because it calls self.to_cloud_optimised_single(), and
# cloudpickle serialises the entire handler instance with every client.submit().
# self.s3_file_uri_list is set AFTER the batch loop so that self is lean
# (~7 KB) rather than carrying the 19k-path list (~694 KB × batch_size).
total_files = len(s3_file_uri_list)

if self.cluster_mode:
if self.cluster_mode == "coiled":
self.cluster_id = cluster.cluster_id
def task(f, i):
try:
self.to_cloud_optimised_single(f)
except Exception as e:
self.logger.error(
f"Issue {i}/{total_files} with {f}: {type(e).__name__}: {e}"
)

client, cluster = self.create_cluster()

if self.cluster_mode:
if self.cluster_mode == "coiled":
self.cluster_id = cluster.cluster_id
else:
self.cluster_id = cluster.name
else:
self.cluster_id = cluster.name
self.cluster_id = "local_execution"
else:
self.cluster_id = "local_execution"
client = None
cluster = None

batch_size = self.get_batch_size(client=client)

Expand Down Expand Up @@ -1623,6 +1631,8 @@ def task(f, i):
self.logger.info(
f"{self.uuid_log}: New cluster created. Retrying batch {ii + 1}."
)
elif self.schedular:
self.schedular.schedule(handler=self, files=batch)
else:
# Fall back to local processing with ThreadPoolExecutor
self.logger.info(
Expand All @@ -1643,7 +1653,8 @@ def task(f, i):
ii += 1

# Cleanup memory
del batch_tasks
if 'batch_tasks' in locals():
del batch_tasks

# Trigger garbage collection
gc.collect()
Expand All @@ -1656,4 +1667,7 @@ def task(f, i):
# Set only after all tasks are submitted so self is not carrying the full list
# during cloudpickle serialisation of each Dask task closure (saves ~3 GB/batch).
self.s3_file_uri_list = s3_file_uri_list
self.logger.handlers.clear()
try:
self.logger.handlers.clear()
except AttributeError:
pass
Comment on lines +1670 to +1673
42 changes: 29 additions & 13 deletions aodn_cloud_optimised/lib/GenericZarrHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from distributed.scheduler import KilledWorker
from distributed.shuffle._exceptions import P2PConsistencyError
from xarray.coding.times import CFDatetimeCoder
from xarray.structure.merge import MergeError
from xarray import MergeError

from aodn_cloud_optimised.lib.CommonHandler import CommonHandler
from aodn_cloud_optimised.lib.logging import get_logger
Expand All @@ -38,7 +38,7 @@ class GridSizeMismatchError(ValueError):


def check_variable_values_dask(
file_path, reference_values, variable_name, dataset_config, uuid_log
file_path, reference_values, variable_name, dataset_config, uuid_log, logger_override=False
):
"""Checks if variable values in a file match reference values.

Expand All @@ -59,7 +59,10 @@ def check_variable_values_dask(
indicating if the file is problematic (True) or consistent (False).
Returns (file_path, True) if any exception occurs during processing.
"""
logger_name = dataset_config.get("logger_name", "generic")
if logger_override:
logger_name = __name__
else:
logger_name = dataset_config.get("logger_name", "generic")
logger = get_logger(logger_name)
try:
ds = xr.open_dataset(file_path)
Expand Down Expand Up @@ -89,7 +92,7 @@ def check_variable_values_dask(
return file_path, True


def check_append_dim_range_dask(file_path, dim_name, dataset_config, uuid_log):
def check_append_dim_range_dask(file_path, dim_name, dataset_config, uuid_log, logger_override=False):
"""Returns the min and max values of a dimension in a single file.

Designed to be run in parallel (e.g., with Dask) to identify files whose
Expand All @@ -106,7 +109,10 @@ def check_append_dim_range_dask(file_path, dim_name, dataset_config, uuid_log):
``(file_path, None, None)`` if the file cannot be opened or the
dimension is missing.
"""
logger_name = dataset_config.get("logger_name", "generic")
if logger_override:
logger_name = __name__
else:
logger_name = dataset_config.get("logger_name", "generic")
logger = get_logger(logger_name)
try:
with xr.open_dataset(file_path) as ds:
Expand All @@ -130,7 +136,7 @@ def get_var_template_shape(ds, var_template_shape):
return None


def preprocess_xarray(ds, dataset_config):
def preprocess_xarray(ds, dataset_config, logger_override=False):
"""Performs preprocessing steps on an xarray Dataset.

This function applies preprocessing logic defined in the dataset
Expand Down Expand Up @@ -164,11 +170,14 @@ def preprocess_xarray(ds, dataset_config):
# https://github.com/fsspec/filesystem_spec/issues/1747
# https://discourse.pangeo.io/t/remote-cluster-with-dask-distributed-uses-the-deployment-machines-memory-and-internet-bandwitch/4637
# https://github.com/dask/distributed/discussions/8913
logger_name = dataset_config.get("logger_name", "generic")
dimensions = dataset_config["schema_transformation"].get("dimensions")
schema = dataset_config.get("schema")

if logger_override:
logger_name = __name__
else:
logger_name = dataset_config.get("logger_name", "generic")
logger = get_logger(logger_name)
dimensions = dataset_config["schema_transformation"].get("dimensions")
schema = dataset_config.get("schema")

# Drop variables not in the list
vars_to_drop = set(ds.data_vars) - set(schema)
Expand Down Expand Up @@ -1084,11 +1093,13 @@ def publish_cloud_optimised_fileset_batch(self, s3_file_uri_list):
]

partial_preprocess = partial(
preprocess_xarray, dataset_config=self.dataset_config
preprocess_xarray, dataset_config=self.dataset_config, logger_override=self.logger_override
)

if self.cluster_mode:
batch_size = self.get_batch_size(client=self.client)
elif self.schedular:
batch_size = self.get_batch_size()
else:
batch_size = 1

Expand Down Expand Up @@ -1137,7 +1148,7 @@ def publish_cloud_optimised_fileset_batch(self, s3_file_uri_list):
self.logger.debug(
f"{self.uuid_log}: 'partial_preprocess_already_run' is False. Applying preprocess_xarray."
)
ds = preprocess_xarray(ds, self.dataset_config)
ds = preprocess_xarray(ds, self.dataset_config, logger_override=self.logger_override)

self._write_ds(ds, idx)
self.logger.info(
Expand Down Expand Up @@ -1801,12 +1812,13 @@ def handle_append_dim_overlap(
dim_name=dim_name,
dataset_config=self.dataset_config,
uuid_log=self.uuid_log,
logger_override=self.logger_override,
)
ranges = self.client.gather(futures)
else:
ranges = [
check_append_dim_range_dask(
f, dim_name, self.dataset_config, self.uuid_log
f, dim_name, self.dataset_config, self.uuid_log, self.logger_override
)
for f in batch_files
]
Expand Down Expand Up @@ -1960,6 +1972,7 @@ def check_variable_values_parallel(self, file_paths, variable_name):
variable_name=variable_name,
dataset_config=self.dataset_config,
uuid_log=self.uuid_log,
override_logger=self.logger_override
)
results = self.client.gather(futures)
else:
Expand All @@ -1975,6 +1988,7 @@ def check_variable_values_parallel(self, file_paths, variable_name):
variable_name,
self.dataset_config,
self.uuid_log,
self.logger_override
)
for file_path in file_paths[1:]
]
Expand Down Expand Up @@ -2487,7 +2501,7 @@ def _open_ds(self, file, partial_preprocess, drop_vars_list, engine="h5netcdf"):
ds = ds.chunk(chunks=self.chunks)
ds = ds.unify_chunks()

ds = preprocess_xarray(ds, self.dataset_config)
ds = preprocess_xarray(ds, self.dataset_config, logger_override=self.logger_override)

return ds

Expand Down Expand Up @@ -2799,6 +2813,8 @@ def to_cloud_optimised(self, s3_file_uri_list=None):
self.cluster_id = self.cluster.cluster_id
else:
self.cluster_id = self.cluster.name
elif self.schedular:
self.client = self.schedular.schedule(handler=self)
else:
self.cluster_id = "local_execution"

Expand Down
Loading