Skip to content

Commit 1110cc4

Browse files
Remove spawned subprocess for ref burst computation in S1.transform()
1 parent 068501f commit 1110cc4

1 file changed

Lines changed: 47 additions & 74 deletions

File tree

insardev_pygmtsar/insardev_pygmtsar/S1_transform.py

Lines changed: 47 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -848,11 +848,9 @@ def process_burst_sequential(bursts, target, debug=False):
848848
self.consolidate_metadata(target, record_id=all_dates[-1][-1])
849849

850850
def process_burst_dates_parallel(bursts, target, n_jobs_inner, scheduler_inner=None, debug=False):
851-
"""Process a single burst with dates parallelized across n_jobs_inner processes."""
852-
import multiprocessing as mp
851+
"""Process a single burst with dates parallelized across n_jobs_inner workers."""
853852
import numpy as np
854853
import xarray as xr
855-
from .PRM import PRM
856854

857855
burst_refs = bursts[0]
858856
burst_reps = bursts[1]
@@ -869,122 +867,97 @@ def process_burst_dates_parallel(bursts, target, n_jobs_inner, scheduler_inner=N
869867
print(f'NOTE: {fullBurstId} directory exists but metadata file is missing. Removing...')
870868
shutil.rmtree(outdir)
871869

872-
ref_burst_name = burst_refs[0][-1]
870+
# Phase 1: Compute transform inline (same as process_burst_sequential)
871+
prm_cache = {}
872+
for burst_ref in burst_refs:
873+
prm, _, _ = self.align_ref(burst_ref[-1], debug=debug, return_slc=False)
874+
prm_cache[burst_ref[-1]] = prm
873875

874-
# Phase 1: Compute prm_ref with doppler correction (shared with subprocess and date workers)
875-
prm_ref, _, _ = self.align_ref(ref_burst_name, debug=debug, return_slc=False)
876-
prm_ref_df = prm_ref.df
877-
prm_ref_orbit_df = prm_ref.orbit_df
876+
ref_burst_name = burst_refs[0][-1]
877+
prm_ref_main = prm_cache[ref_burst_name]
878878

879-
# Get geometry for transform subprocess
879+
from .utils_satellite import compute_transform_inverse, get_dem_wgs84ellipsoid, save_transform
880880
record = self.get_record(ref_burst_name)
881-
geometry_wkt = record.geometry.iloc[0].wkt
882-
883-
# Phase 2: Compute transform in spawned subprocess (memory released on exit)
884-
# Use spawn context for true memory isolation (not fork which shares pages)
885-
ctx = mp.get_context('spawn')
886-
result_queue = ctx.Queue()
887-
p = ctx.Process(target=_compute_transform_inverse_worker, args=(
888-
prm_ref_df, prm_ref_orbit_df, self.DEM, geometry_wkt, outdir,
889-
1/dem_vertical_accuracy, epsg, resolution, bbox, 8, debug, self.netcdf_engine_read, result_queue
890-
))
891-
p.start()
892-
result = result_queue.get() # wait for result
893-
p.join()
894-
895-
# Reconstruct topo from queue result
896-
if remove_topo_phase:
897-
topo = xr.DataArray(result['topo_values'],
898-
coords={'a': result['topo_a_coords'], 'r': result['topo_r_coords']},
899-
dims=['a', 'r'])
900-
elif reference_height != 0:
901-
topo = xr.DataArray(
902-
np.full_like(result['topo_values'], reference_height),
903-
coords={'a': result['topo_a_coords'], 'r': result['topo_r_coords']},
904-
dims=['a', 'r'])
905-
else:
906-
topo = None
907-
908-
# Reconstruct transform from queue result (without ele - not needed for geocoding)
909-
transform = xr.Dataset({
910-
'rng': xr.DataArray(result['transform_rng'],
911-
coords={'y': result['transform_y'], 'x': result['transform_x']},
912-
dims=['y', 'x']),
913-
'azi': xr.DataArray(result['transform_azi'],
914-
coords={'y': result['transform_y'], 'x': result['transform_x']},
915-
dims=['y', 'x']),
916-
}, attrs=result['transform_attrs'])
917-
918-
# Get SC_height from prm_ref (already computed before subprocess)
919-
sc_height_result = prm_ref.SAT_baseline(prm_ref)
920-
sc_height = {
921-
'SC_height': sc_height_result.get('SC_height'),
922-
'SC_height_start': sc_height_result.get('SC_height_start'),
923-
'SC_height_end': sc_height_result.get('SC_height_end')
924-
}
925-
926-
# Phase 3: Process dates in parallel using spawned subprocesses
927-
# Each worker processes one date then exits (max_tasks_per_child=1), releasing memory
928-
all_dates = burst_reps + burst_refs
881+
dem = get_dem_wgs84ellipsoid(self.DEM, record.geometry.iloc[0], netcdf_engine=self.netcdf_engine_read)
882+
topo, transform = compute_transform_inverse(prm_ref_main, dem, scale_factor=1/dem_vertical_accuracy, epsg=epsg, resolution=resolution, bbox=bbox, debug=debug)
883+
del dem
929884

930-
# prm_ref_df and prm_ref_orbit_df already serialized before subprocess
885+
save_transform(transform, outdir, scale_factor=1/dem_vertical_accuracy)
931886

932-
# Get topo_llt once (small, shared by all workers for alignment computation)
933-
topo_llt = self._get_topo_llt(ref_burst_name, degrees=alignment_spacing)
887+
if not remove_topo_phase:
888+
if reference_height != 0:
889+
topo = xr.full_like(topo, reference_height)
890+
else:
891+
topo = None
892+
transform = transform.drop_vars('ele')
893+
894+
# Pre-compute SC_height and topo_llt caches
895+
sc_height_cache = {}
896+
for burst_ref in burst_refs:
897+
burst_ref_name = burst_ref[-1]
898+
prm_ref = prm_cache[burst_ref_name]
899+
sc_height_result = prm_ref.SAT_baseline(prm_ref)
900+
sc_height_cache[burst_ref_name] = {
901+
'SC_height': sc_height_result.get('SC_height'),
902+
'SC_height_start': sc_height_result.get('SC_height_start'),
903+
'SC_height_end': sc_height_result.get('SC_height_end')
904+
}
905+
906+
topo_llt_cache = {}
907+
for burst_ref in burst_refs:
908+
topo_llt_cache[burst_ref[-1]] = self._get_topo_llt(burst_ref[-1], degrees=alignment_spacing)
909+
910+
# Phase 2: Build worker args and process dates in parallel
911+
all_dates = burst_reps + burst_refs
912+
prm_ref_df = prm_cache[ref_burst_name].df
913+
prm_ref_orbit_df = prm_cache[ref_burst_name].orbit_df
914+
topo_llt = topo_llt_cache[ref_burst_name]
934915

935-
# Build argument tuples for each date (no S1 instance needed in workers)
936916
worker_args = []
937917
for burst_item in all_dates:
938918
is_ref = burst_item in burst_refs
939919
burst_name = burst_item[-1]
920+
burst_ref = [b for b in burst_refs if b[:2] == burst_item[:2]][0]
921+
burst_ref_name = burst_ref[-1]
940922

941-
# Get file paths for this burst
942923
prefix = self.fullBurstId(burst_name)
943924
record = self.get_record(burst_name)
944925
xml_file = os.path.join(self.datadir, prefix, 'annotation', f'{burst_name}.xml')
945926
tiff_file = os.path.join(self.datadir, prefix, 'measurement', f'{burst_name}.tiff')
946927
orbit_file = os.path.join(self.datadir, record['orbit'].iloc[0])
947-
948-
# Derive calibration and noise XML paths (ASF burst format: {burst_name}.xml)
949928
calibration_xml = os.path.join(self.datadir, prefix, 'calibration', f'{burst_name}.xml')
950929
noise_xml = os.path.join(self.datadir, prefix, 'noise', f'{burst_name}.xml')
951-
# Raise error if files don't exist when corrections are requested
930+
952931
if radiometric_calibration and not os.path.exists(calibration_xml):
953932
raise FileNotFoundError(f"Calibration XML not found: {calibration_xml}")
954933
if remove_thermal_noise and not os.path.exists(noise_xml):
955934
raise FileNotFoundError(f"Noise XML not found: {noise_xml}")
956935

957-
# Build record dict for worker (use reset_index to include index values like polarization)
958936
record_dict = {}
959937
record_reset = record.reset_index()
960-
if debug and burst_item == all_dates[0]:
961-
print(f'DEBUG: record_reset.columns = {list(record_reset.columns)}')
962938
for col in record_reset.columns:
963939
val = record_reset[col].iloc[0]
964-
if hasattr(val, 'wkt'): # geometry
940+
if hasattr(val, 'wkt'):
965941
record_dict[col] = val.wkt
966942
else:
967943
record_dict[col] = val
968-
if debug and burst_item == all_dates[0]:
969-
print(f'DEBUG: record_dict keys = {list(record_dict.keys())}')
970944

971945
worker_args.append((
972946
outdir, burst_item, burst_refs, is_ref,
973947
xml_file, tiff_file, orbit_file, record_dict,
974948
topo, transform,
975-
prm_ref_df, prm_ref_orbit_df, sc_height,
949+
prm_ref_df, prm_ref_orbit_df, sc_height_cache[burst_ref_name],
976950
topo_llt, epsg, remove_tidal_phase,
977951
remove_thermal_noise, radiometric_calibration,
978952
calibration_xml, noise_xml, reference_height, debug
979953
))
980954

981-
# Process dates in parallel using joblib
982955
joblib.Parallel(n_jobs=n_jobs_inner, backend=scheduler_inner)(
983956
joblib.delayed(_process_date_worker)(args) for args in worker_args
984957
)
985958

986959
# Cleanup and consolidate
987-
del topo, transform, prm_ref
960+
del topo, transform, prm_cache
988961
self.consolidate_metadata(target, record_id=all_dates[-1][-1])
989962

990963
# Get reference and repeat bursts as groups

0 commit comments

Comments
 (0)