# Source code for zea.data.process

"""CLI for beamforming a zea dataset with a pipeline defined in a YAML config file.

Usage:
    python -m zea.data.process --dataset <path> --config <config.yaml>
"""

import argparse
import re
from concurrent.futures import ThreadPoolExecutor
from dataclasses import fields as dataclass_fields
from pathlib import Path

import keras
import numpy as np
import tyro
from keras import ops

from zea import display, io_lib, log
from zea.cli_args import ProcessArgs
from zea.config import Config
from zea.data.dataloader import Dataloader
from zea.data.datasets import Dataset
from zea.data.file import File
from zea.data.spec import ScanSpec
from zea.internal.device import init_device
from zea.ops.pipeline import Pipeline
from zea.utils import FunctionTimer

# Output formats accepted for ``save_as``; extended below when SimpleITK is present.
SUPPORTED_FORMATS = ["gif", "mp4", "hdf5"]

# SimpleITK is an optional dependency: "nii.gz" is only listed as a supported
# format when the import succeeds.
try:
    import SimpleITK as sitk

    SUPPORTED_FORMATS += ["nii.gz"]
except ImportError:
    # ``sitk is None`` is what run_processing checks before writing nii.gz output.
    sitk = None


def get_parser(add_help: bool = True) -> argparse.ArgumentParser:
    """Build the plain-argparse counterpart of :class:`ProcessArgs`.

    The parser is deliberately a regular ``argparse`` object so it stays usable
    with ``sphinxcontrib-autoprogram`` and as a ``parents`` entry of another
    argparse parser.

    Args:
        add_help: Forwarded unchanged to :class:`argparse.ArgumentParser`.

    Returns:
        A parser with every processing option registered.
    """
    # Each entry is (option strings, add_argument keyword arguments); the order
    # below is the order in which the options are registered.
    option_table = [
        (
            ("--dataset", "-d"),
            {
                "required": True,
                "type": str,
                "help": (
                    "Path/URI to the zea dataset (folder of HDF5 files or a single HDF5 file)."
                ),
            },
        ),
        (
            ("--config", "-c"),
            {
                "required": True,
                "type": str,
                "help": "Path to config.yaml for the beamforming pipeline.",
            },
        ),
        (
            ("--save-dir",),
            {
                "type": Path,
                "default": Path("output"),
                "help": "Directory where output files are written. Default: output/",
            },
        ),
        (
            ("--key",),
            {
                "type": str,
                "default": "data/raw_data",
                "help": (
                    "Data key to load from each file (e.g. data/raw_data, data/image/values)."
                ),
            },
        ),
        (
            ("--n-frames",),
            {
                "type": int,
                "default": None,
                "dest": "n_frames",
                "help": (
                    "Maximum number of frames to process per file (all frames when omitted)."
                ),
            },
        ),
        (
            ("--save-as",),
            {
                "type": str,
                "default": "gif",
                "dest": "save_as",
                "help": f"Output format. One of: {', '.join(SUPPORTED_FORMATS)}.",
            },
        ),
        (
            ("--keep-keys",),
            {
                "nargs": "+",
                "default": ["maxval"],
                "dest": "keep_keys",
                "help": "Pipeline output keys to forward to the next frame iteration.",
            },
        ),
        (
            ("--timings",),
            {
                "action": "store_true",
                "help": (
                    "Record dataloader and pipeline timings and save to YAML files in save_dir."
                ),
            },
        ),
        (
            ("--num-threads",),
            {
                "type": int,
                "default": 16,
                "dest": "num_threads",
                "help": "Number of threads used by the dataloader. Default is 16.",
            },
        ),
        (
            ("--revision",),
            {
                "type": str,
                "default": None,
                "help": (
                    "HuggingFace revision for the dataset (branch, tag, or commit hash). "
                    "Only used for hf:// paths."
                ),
            },
        ),
        (
            ("--config-revision",),
            {
                "type": str,
                "default": None,
                "dest": "config_revision",
                "help": (
                    "HuggingFace revision for the config (branch, tag, or commit hash). "
                    "Defaults to --revision if omitted."
                ),
            },
        ),
        (
            ("--overwrite",),
            {
                "action": "store_true",
                "help": "Overwrite existing output files. Default is False.",
            },
        ),
        (
            ("--keep-dynamic-range",),
            {
                "action": "store_true",
                "dest": "keep_dynamic_range",
                "help": (
                    "Store pipeline output as-is (float32 dB) instead of converting to uint8. "
                    "Only valid when --save-as hdf5."
                ),
            },
        ),
        (
            ("--device",),
            {
                "type": str,
                "default": "auto:1",
                "help": (
                    "Compute device ('cuda:0', 'cpu', 'auto:1', …). "
                    "Only relevant when running the beamformer pipeline."
                ),
            },
        ),
    ]

    cli = argparse.ArgumentParser(
        description=(
            "Beamform a zea dataset using a pipeline defined in a config YAML file. "
            "Processes frames sequentially to support temporal algorithms."
        ),
        add_help=add_help,
    )
    for option_strings, spec in option_table:
        cli.add_argument(*option_strings, **spec)
    return cli
def _get_config_parameters(config: Config) -> dict: """Return the config parameters dict, handling missing or empty sections.""" params = getattr(config, "parameters", None) if params is None: return {} return params.as_dict() if hasattr(params, "as_dict") else dict(params) # Keys that carry raw RF / pre-beamformed data and always require a pipeline. _PIPELINE_REQUIRED_KEYS = frozenset({"data/raw_data", "data/aligned_data/values"}) def _key_requires_pipeline(key: str) -> bool: """Return True if ``key`` holds raw RF/pre-beamformed data that needs a pipeline. Normalizes the key the same way :meth:`File.format_key` does (strip a ``tracks/track_N/`` prefix and add a leading ``data/``) so aliases like ``raw_data`` are classified the same as ``data/raw_data``. """ normalized = (key or "").strip() normalized = re.sub(r"^tracks/track_\d+/", "", normalized) if normalized and not normalized.startswith("data/"): normalized = "data/" + normalized return normalized in _PIPELINE_REQUIRED_KEYS def _build_probe_dict(probe) -> dict: """Build a minimal probe dict for File.create() from a Probe object.""" probe_dict = {} if getattr(probe, "name", None): probe_dict["name"] = probe.name if getattr(probe, "probe_geometry", None) is not None: probe_dict["probe_geometry"] = probe.probe_geometry for attr in ( "type", "probe_center_frequency", "probe_bandwidth_percent", "element_width", "element_height", "lens_sound_speed", "lens_thickness", ): val = getattr(probe, attr, None) if val is not None: probe_dict[attr] = val return probe_dict def _run_passthrough( dataset_path: str, key: str, n_frames: int | None, save_dir: Path, save_as: str, overwrite: bool, **hf_kwargs, ) -> None: """Save data frames directly without a beamforming pipeline.""" if save_as not in ("gif", "mp4", "hdf5"): raise ValueError(f"Passthrough mode only supports gif/mp4/hdf5, got {save_as!r}") save_dir.mkdir(parents=True, exist_ok=True) ds = Dataset(dataset_path, validate=False, **hf_kwargs) file_paths = 
list(ds.file_paths) ds.close() pbar = keras.utils.Progbar(len(file_paths)) for file_path in file_paths: with File(file_path) as f: data_key = f.format_key(key) arr = np.asarray(f[data_key][:n_frames] if n_frames is not None else f[data_key][:]) filestem = f.stem # Ensure (N, H, W) — squeeze any leading single-element dims while arr.ndim > 3 and arr.shape[0] == 1: arr = arr[0] if arr.ndim == 2: arr = arr[np.newaxis] # add frame axis if arr.dtype != np.uint8: lo, hi = float(arr.min()), float(arr.max()) arr = ( ((arr - lo) / (hi - lo) * 255).astype(np.uint8) if hi > lo else np.zeros_like(arr, dtype=np.uint8) ) save_path = save_dir / f"{filestem}.{save_as}" if save_path.exists() and not overwrite: log.warning(f"File {save_path} already exists. Use --overwrite to replace it.") else: if save_as in ("gif", "mp4"): io_lib.save_video(arr, save_path, fps=20) elif save_as == "hdf5": File.create(save_path, data={"image": {"values": arr}}, overwrite=overwrite) log.info(f"Saved {log.yellow(save_path)}") pbar.add(1)
[docs] def run_processing( dataset_path: str, config_path: str, key: str, n_frames: int | None, save_dir: Path, save_as: str = "gif", keep_keys=("maxval",), timings=False, num_threads=16, overwrite=False, keep_dynamic_range=False, revision: str | None = None, config_revision: str | None = None, ) -> None: if keep_dynamic_range and save_as != "hdf5": raise ValueError("--keep_dynamic_range is only supported with --save_as hdf5.") if save_as == "nii.gz" and sitk is None: raise ValueError("SimpleITK is not installed; cannot save as nii.gz.") if save_as not in SUPPORTED_FORMATS: raise ValueError(f"save_as must be one of {SUPPORTED_FORMATS}, got {save_as!r}") dataset_hf_kwargs = {"revision": revision} if revision is not None else {} config_hf_kwargs = ( {"revision": config_revision if config_revision is not None else revision} if (config_revision or revision) else {} ) config = Config.from_path(config_path, **config_hf_kwargs) config_params = _get_config_parameters(config) try: pipeline = Pipeline.from_path(config_path, with_batch_dim=False, **config_hf_kwargs) except (ValueError, KeyError) as exc: if _key_requires_pipeline(key): raise log.warning( f"No pipeline found in config ({exc}). " f"Key '{key}' does not require beamforming — saving data as-is." 
) save_dir.mkdir(parents=True, exist_ok=True) _run_passthrough( dataset_path, key, n_frames, save_dir, save_as, overwrite, **dataset_hf_kwargs ) return save_dir.mkdir(parents=True, exist_ok=True) dataset_files = Dataset(dataset_path, validate=False, **dataset_hf_kwargs) dataloader = Dataloader( dataset_path, key=key, batch_size=None, shuffle=False, return_filename=True, limit_n_frames=n_frames, n_frames=1, num_threads=num_threads, insert_frame_axis=False, sort_files=True, **dataset_hf_kwargs, ) dataset_files.close() iterator = iter(dataloader) total_batches = len(dataloader) get_data = lambda: next(iterator) prepare_parameters = pipeline.prepare_parameters pipeline_call = pipeline.__call__ if timings: timer = FunctionTimer() get_data = timer(get_data, name="dataloader") prepare_parameters = timer(prepare_parameters, name="prepare_parameters") pipeline_call = timer(pipeline_call, name="pipeline") _DEFAULT_FPS = 20 _scan_spec_fields = {f.name for f in dataclass_fields(ScanSpec)} prev_file_path = None data_output = [] filestem = None parameters = None selected_transmits = None params = None fps = _DEFAULT_FPS def save_video_worker( video: np.ndarray, save_path: Path, src_file_path: str, fps: int, ): if save_path.exists() and not overwrite: log.warning(f"File {save_path} already exists. 
Use --overwrite to replace it.") return if save_as in ["mp4", "gif"]: io_lib.save_video(video, save_path, fps=fps) elif save_as == "hdf5": with File(src_file_path) as src: scan_dict = { k: v for k, v in src.get_scan_parameters().items() if k in _scan_spec_fields } probe_dict = _build_probe_dict(src.probe) File.create( save_path, data={"image": {"values": video}}, scan=scan_dict if scan_dict else None, probe=probe_dict if probe_dict else None, overwrite=overwrite, ) elif save_as == "nii.gz": sitk.WriteImage(sitk.GetImageFromArray(video), str(save_path)) log.info(f"Saved NIfTI to {log.yellow(save_path)}") pbar = keras.utils.Progbar(total_batches) with ThreadPoolExecutor(max_workers=1) as executor: save_future = None for i in range(total_batches + 1): if i < total_batches: frame, metadata = get_data() file_path = metadata["fullpath"] else: file_path = None # sentinel to flush the last file if file_path != prev_file_path: if prev_file_path is not None: video = np.stack([ops.convert_to_numpy(f) for f in data_output], axis=0) save_path = save_dir / f"{filestem}.{save_as}" if save_future is not None: save_future.result() save_future = executor.submit( save_video_worker, video, save_path, prev_file_path, fps ) data_output = [] if file_path is None: break prev_file_path = file_path with File(file_path) as f: filestem = f.stem parameters = f.load_parameters() parameters.update(config_params) selected_transmits = np.array([int(t) for t in parameters.selected_transmits]) try: fps = int(round(parameters.frames_per_second)) except (ValueError, AttributeError): fps = _DEFAULT_FPS params = prepare_parameters(parameters, **config_params) # Sentinel iteration (no more data — also covers an empty dataset # where total_batches == 0); nothing to process, so stop here. 
if file_path is None: break # slice to selected transmits (transmit axis = 0 when insert_frame_axis=False) frame = frame[selected_transmits] output = pipeline_call(data=frame, **params) processed_frame = output["data"] if not keep_dynamic_range: dr = getattr(parameters, "dynamic_range", None) dynamic_range = tuple(dr) if dr is not None else (-60, 0) processed_frame = display.to_8bit(processed_frame, dynamic_range, pillow=False) data_output.append(processed_frame) pbar.add(1) for key in keep_keys: if key in output: params[key] = output[key] if timings: for tname in timer.timings.keys(): timer.append_to_yaml(save_dir / f"timings_{tname}.yaml", tname) if timings: timer.print()
def main() -> None:
    """Command-line entry point.

    Parses :class:`ProcessArgs` with ``tyro``, passes the requested device to
    ``init_device`` and forwards the remaining options to :func:`run_processing`.
    """
    cli_args = tyro.cli(ProcessArgs)
    init_device(cli_args.device)
    run_processing(
        dataset_path=cli_args.dataset,
        config_path=cli_args.config,
        key=cli_args.key,
        n_frames=cli_args.n_frames,
        save_dir=cli_args.save_dir,
        save_as=cli_args.save_as,
        keep_keys=cli_args.keep_keys,
        timings=cli_args.timings,
        num_threads=cli_args.num_threads,
        overwrite=cli_args.overwrite,
        keep_dynamic_range=cli_args.keep_dynamic_range,
        revision=cli_args.revision,
        config_revision=cli_args.config_revision,
    )
# Run the CLI when the module is executed as a script (``python -m zea.data.process``).
if __name__ == "__main__":
    main()