Skip to content

Commit 58c8e01

Browse files
committed
changes for countrate and multiple files
1 parent fdeacbf commit 58c8e01

5 files changed

Lines changed: 249 additions & 69 deletions

File tree

src/sed/loader/cfel/buffer_handler.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import time
44
from pathlib import Path
55

6+
import h5py
7+
import numpy as np
68
import dask.dataframe as dd
79
from joblib import delayed
810
from joblib import Parallel
@@ -168,7 +170,48 @@ def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None:
168170
f"Could not extract base timestamp: {e}. "
169171
"Processing files independently."
170172
)
173+
174+
# -------------------------------------------------------
175+
# Calculate index offsets
176+
# We need to read the 'index' channel (usually countId/NumOfEvents) to know the count.
177+
# This requires a quick scan of files.
178+
# -------------------------------------------------------
179+
index_offsets = {}
180+
current_offset = 0
181+
182+
index_alias = self._config.get("index", ["countId"])[0]
183+
try:
184+
channel_config = self._config["channels"][index_alias]
185+
dataset_key = channel_config["dataset_key"]
171186

187+
# Prefer serial scan for safety and simplicity, though could be parallelized
188+
# For 200 files it might take a few seconds.
189+
logger.info("Calculating index offsets...")
190+
for file_set in file_sets:
191+
try:
192+
with h5py.File(file_set["raw"], "r") as h5_file:
193+
if dataset_key in h5_file:
194+
195+
dset = h5_file[dataset_key]
196+
# sum of all events in this file
197+
# Use simple read if small enough
198+
n_events = np.sum(dset)
199+
200+
index_offsets[file_set["raw"].name] = int(current_offset)
201+
current_offset += int(n_events)
202+
else:
203+
index_offsets[file_set["raw"].name] = int(current_offset)
204+
except Exception as e:
205+
logger.warning(f"Failed to read index offset from {file_set['raw'].name}: {e}")
206+
index_offsets[file_set["raw"].name] = int(current_offset)
207+
208+
logger.debug(f"Total events calculated: {current_offset}")
209+
210+
except Exception as e:
211+
logger.warning(f"Failed to calculate index offsets: {e}. Indices may reset.")
212+
for fs in file_sets:
213+
index_offsets[fs["raw"].name] = 0
214+
172215
# -------------------------------------------------------
173216

174217
n_cores = min(len(file_sets), self.n_cores)
@@ -187,6 +230,7 @@ def is_first_file(file_set) -> bool:
187230
file_set,
188231
is_first_file(file_set),
189232
base_timestamp,
233+
index_offset=index_offsets.get(file_set["raw"].name, 0),
190234
)
191235
else:
192236
# For parallel processing, we need to be careful about the order
@@ -198,18 +242,20 @@ def is_first_file(file_set) -> bool:
198242
file_set,
199243
is_first_file(file_set),
200244
base_timestamp,
245+
index_offset=index_offsets.get(file_set["raw"].name, 0),
201246
)
202247
for file_set in file_sets
203248
)
204249

205-
def _save_buffer_file(self, file_set, is_first_file=True, base_timestamp=None):
250+
def _save_buffer_file(self, file_set, is_first_file=True, base_timestamp=None, index_offset=0):
206251
"""
207252
Saves an HDF5 file to a Parquet file using the DataFrameCreator class.
208253
209254
Args:
210255
file_set: Dictionary containing file paths
211256
is_first_file: Whether this is the first file in a multi-file run
212257
base_timestamp: Base timestamp from the first file (for subsequent files)
258+
index_offset: Offset to apply to the index
213259
"""
214260
start_time = time.time() # Add this line
215261
paths = file_set
@@ -218,9 +264,11 @@ def _save_buffer_file(self, file_set, is_first_file=True, base_timestamp=None):
218264
config_dataframe=self._config,
219265
h5_path=paths["raw"],
220266
is_first_file=is_first_file,
221-
base_timestamp=base_timestamp
267+
base_timestamp=base_timestamp,
268+
index_offset=index_offset
222269
)
223270
df = dfc.df
271+
224272
df_timed = dfc.df_timed
225273

226274
# Save electron resolved dataframe

src/sed/loader/cfel/dataframe.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ class DataFrameCreator:
3030
"""
3131

3232
def __init__(self, config_dataframe: dict, h5_path: Path,
33-
is_first_file: bool = True, base_timestamp: pd.Timestamp = None) -> None:
33+
is_first_file: bool = True, base_timestamp: pd.Timestamp = None,
34+
index_offset: int = 0) -> None:
3435
"""
3536
Initializes the DataFrameCreator class.
3637
@@ -39,22 +40,20 @@ def __init__(self, config_dataframe: dict, h5_path: Path,
3940
h5_path (Path): Path to the h5 file.
4041
is_first_file (bool): Whether this is the first file in a multi-file run.
4142
base_timestamp (pd.Timestamp): Base timestamp from the first file (for subsequent files).
43+
index_offset (int): Offset to apply to the index (countId) for multi-file runs.
4244
"""
4345
self.h5_file = h5py.File(h5_path, "r")
4446
self._config = config_dataframe
4547
self.is_first_file = is_first_file
4648
self.base_timestamp = base_timestamp
49+
self.index_offset = index_offset
4750

4851
index_alias = self._config.get("index", ["countId"])[0]
49-
# # all values except the last as slow data starts from start of file
50-
# somehow written something else as this line is doing
51-
# self.index = np.cumsum([0, *self.get_dataset_array(index_alias)])
52+
5253
# get cumulative counts, but drop last because slow data only covers N-1 intervals
53-
self.index = np.cumsum([0, *self.get_dataset_array(index_alias)])[:-1]
54-
# cumulative sum starting from the first acquisition count, No artificial 0 at the start
55-
# makes identical len of TimeStamp and index, but cuts last TimeStamp
56-
# self.index = np.cumsum(self.get_dataset_array(index_alias))
57-
print(f"len of self.index: {len(self.index)}")
54+
# Add index_offset
55+
self.index = np.cumsum([0, *self.get_dataset_array(index_alias)])[:-1] + index_offset
56+
5857

5958
def get_dataset_key(self, channel: str) -> str:
6059
"""
@@ -121,7 +120,16 @@ def df_electron(self) -> pd.DataFrame:
121120
if channels == []:
122121
return pd.DataFrame()
123122

124-
series = {channel: pd.Series(self.get_dataset_array(channel)) for channel in channels}
123+
series = {
124+
channel: pd.Series(
125+
self.get_dataset_array(channel),
126+
index=pd.RangeIndex(
127+
self.index_offset,
128+
self.index_offset + len(self.get_dataset_array(channel)),
129+
),
130+
)
131+
for channel in channels
132+
}
125133
dataframe = pd.concat(series, axis=1)
126134
return dataframe.dropna()
127135

@@ -241,9 +249,6 @@ def df_timestamp(self) -> pd.DataFrame:
241249
# ------------------------------------------------------------
242250
ts_alias = self._config["columns"].get("timestamp", "timeStamp")
243251
df = pd.DataFrame({ts_alias: unix_seconds}, index=self.index)
244-
print(f"Len of TimeStamps: {len(unix_seconds)}, len of Index: {len(self.index)}")
245-
pd.set_option("display.float_format", "{:.6f}".format)
246-
print(df)
247252

248253
# # # Suppose df is your timestamp DataFrame
249254
# print("DEBUG of df")

src/sed/loader/cfel/loader.py

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,22 @@ def get_count_rate_ms(
451451
# 2) Compute point-resolved rates
452452
# -------------------------------
453453
if mode == "point":
454+
bin_size = kwds.pop("bin_size", 1)
454455
dt = np.diff(ms_concat) * 1e-3
455456
if np.any(dt <= 0):
456-
raise ValueError("Non-positive time step detected in millisecCounter")
457+
# Handle potential duplicate timestamps or jump back (should not happen with sort)
458+
dt[dt <= 0] = 1e-6 # small epsilon
457459
rates_point = counts_concat[1:] / dt
460+
461+
if bin_size > 1:
462+
# Apply rolling average for smoothing
463+
rates_point = (
464+
pd.Series(rates_point)
465+
.rolling(window=bin_size, center=True, min_periods=1)
466+
.mean()
467+
.values
468+
)
469+
458470
times_point = ms_concat[1:] * 1e-3
459471
return rates_point, times_point
460472

@@ -463,16 +475,24 @@ def get_count_rate_ms(
463475
# -------------------------------
464476
rates_file = []
465477
times_file = []
466-
prev_ms_max = 0.0 # global start
467-
468478
for idx, (ms_min, ms_max) in enumerate(file_ms_min_max):
469-
# Duration = internal file window + gap since previous file
470-
file_duration = (ms_max - ms_min) + (ms_min - prev_ms_max)
479+
# Duration = internal file window
480+
file_duration = ms_max - ms_min
471481
if file_duration <= 0:
472-
raise ValueError(f"Non-positive duration for file {fids_resolved[idx]}")
473-
474-
print(f"Total counts: {file_counts_total[idx]}")
475-
print(f"File duration: {file_duration}")
482+
# If single point or overlapping min/max, fallback or raise?
483+
# For single point (duration 0), rate is undefined (inf).
484+
# Start/End timestamps usually imply a range.
485+
# If strictly 0, we can't calculate rate.
486+
logger.warning(
487+
f"[get_count_rate_ms] File {fids_resolved[idx]} has duration <= 0 ({file_duration}). "
488+
"Skipping rate calculation for this file (set to NaN).",
489+
)
490+
rates_file.append(np.nan)
491+
times_file.append((ms_min + ms_max) / 2 * 1e-3)
492+
continue
493+
494+
# print(f"Total counts: {file_counts_total[idx]}")
495+
# print(f"File duration: {file_duration}")
476496
rate = file_counts_total[idx] / (file_duration * 1e-3)
477497
rates_file.append(rate)
478498
# times_file.append(ms_max * 1e-3) # last time in file
@@ -484,8 +504,6 @@ def get_count_rate_ms(
484504
f"counts={file_counts_total[idx]}, duration={file_duration} ms, rate={rate:.2f} Hz"
485505
)
486506

487-
prev_ms_max = ms_max
488-
489507
return np.array(rates_file), np.array(times_file)
490508

491509

@@ -529,58 +547,28 @@ def get_count_rate(
529547
self,
530548
fids: Sequence[int] | None = None,
531549
runs: Sequence[int] | None = None,
550+
**kwds,
532551
) -> tuple[np.ndarray, np.ndarray]:
533552
"""
534-
Returns the count rate per file using the total number of detected events
535-
and the file acquisition duration.
536-
537-
This method computes:
538-
- one count-rate value per file (Hz)
539-
- one global time value per file, given by the midpoint of the file
540-
acquisition window, measured in seconds since the scan start
541-
542-
The calculation is based on metadata produced by `read_dataframe`
543-
and therefore does not require loading raw event data.
544-
This makes the method fast but limited to file-level resolution.
553+
Returns the count rate. By default, returns high-resolution
554+
point-resolved rates using the millisecond counter.
545555
546556
Args:
547557
fids (Sequence[int], optional):
548558
File IDs to include. Defaults to all files.
549559
runs (Sequence[int], optional):
550560
Run IDs to include. If provided, overrides `fids`.
561+
**kwds:
562+
Additional arguments passed to `get_count_rate_ms`.
563+
- mode: "point" (default) or "file".
551564
552565
Returns:
553566
tuple[np.ndarray, np.ndarray]:
554-
- count_rate : array of count rates in Hz (one per file)
567+
- count_rate : array of count rates in Hz
555568
- time : array of global times in seconds since scan start
556-
(file midpoint)
557-
558-
Raises:
559-
KeyError:
560-
If required file statistics are missing. Call `read_dataframe` first.
561569
"""
562-
563-
fids_resolved = self._resolve_fids(fids=fids, runs=runs)
564-
565-
ts_alias = self._config["dataframe"]["columns"].get("timestamp", "timeStamp")
566-
t0 = self.metadata["file_statistics"]["timed"]["0"]["columns"][ts_alias]["min"]
567-
568-
rates = []
569-
times = []
570-
571-
for fid in fids_resolved:
572-
counts = self.metadata["file_statistics"]["electron"][str(fid)]["num_rows"]
573-
ts = self.metadata["file_statistics"]["timed"][str(fid)]["columns"][ts_alias]
574-
575-
dt = ts["max"] - ts["min"]
576-
print(f"File duration: {dt} seconds")
577-
if dt <= 0:
578-
raise ValueError(f"Non-positive elapsed time for file {fid}")
579-
580-
rates.append(counts / dt)
581-
times.append(0.5 * (ts["min"] + ts["max"]) - t0)
582-
583-
return np.asarray(rates), np.asarray(times)
570+
mode = kwds.pop("mode", "point")
571+
return self.get_count_rate_ms(fids=fids, mode=mode, runs=runs, **kwds)
584572

585573
# -------------------------------
586574
# Time-resolved count rate (binned)

src/sed/loader/flash/loader.py

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,85 @@ def parse_metadata(self, token: str = None) -> dict:
223223

224224
def get_count_rate(
    self,
    fids: Sequence[int] = None,
    **kwds,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Calculates the per-train count rate for the specified files.

    For every selected file the dataset stored under the ``index_key`` of the
    first per-electron channel is read, the number of entries per unique train
    ID is counted, and that count is divided by the train period. Times are
    derived from the train IDs relative to the first train ID encountered in
    the first non-empty file.

    Args:
        fids (Sequence[int]): A sequence of file IDs (indices into
            ``self.files``). Defaults to all files.
        **kwds: Keyword arguments.

            - train_period (float): Time in seconds between two consecutive
              train IDs. Defaults to 0.1 (i.e. a 10 Hz repetition rate).

    Returns:
        tuple[np.ndarray, np.ndarray]: The count rate array (Hz) and the time
        array (seconds). ``(None, None)`` is returned when no files are
        selected, no per-electron channel is configured, or none of the
        selected files contains any entries.

    Raises:
        ValueError: If ``train_period`` is not strictly positive.
    """
    train_period = kwds.get("train_period", 0.1)
    if train_period <= 0:
        raise ValueError(f"train_period must be positive, got {train_period}")

    if fids is None:
        fids = range(len(self.files))
    if len(fids) == 0:
        # Nothing selected: avoid touching the configuration or any file.
        return None, None

    # Get the electron channel configuration
    per_electron_channels = get_channels(self._config["dataframe"], "per_electron")
    if not per_electron_channels:
        return None, None

    # Imported lazily so that the cheap early exits above do not require h5py.
    import h5py
    import numpy as np

    # We need the 'index_key' (trainId) for an electron channel
    channel_config = self._config["dataframe"]["channels"][per_electron_channels[0]]
    index_key = channel_config["index_key"]

    all_rates = []
    all_times = []
    # Reference train ID. Taken from the first file that actually contains
    # data, so an empty leading file cannot leave it undefined.
    t_start_id = None

    for fid in fids:
        with h5py.File(self.files[fid], "r") as h5:
            # NOTE(review): assumes this dataset holds one train ID per
            # electron event - confirm against the FLASH file layout.
            train_ids = np.asarray(h5[index_key])

        if len(train_ids) == 0:
            continue

        # Sorted unique train IDs and the number of entries for each of them.
        u_train_ids, counts = np.unique(train_ids, return_counts=True)

        if t_start_id is None:
            t_start_id = u_train_ids[0]

        # Convert train IDs to relative seconds. This is an approximation
        # based on a fixed train period; the actual train timestamps are not
        # used here.
        all_times.append((u_train_ids - t_start_id) * train_period)
        # Rate per train interval.
        all_rates.append(counts / train_period)

    if not all_rates:
        return None, None

    return np.concatenate(all_rates), np.concatenate(all_times)
230305

231306
def get_elapsed_time(self, fids: Sequence[int] = None, **kwds) -> float | list[float]: # type: ignore[override]
232307
"""

0 commit comments

Comments
 (0)