Source code for zea.ops.pipeline

import json
from typing import TYPE_CHECKING, Any, Dict, List, Union

import keras
import numpy as np
import yaml
from keras import ops

from zea import backend, log
from zea.backend import func_on_device, jit
from zea.config import Config
from zea.func.tensor import vmap
from zea.func.ultrasound import channels_to_complex, complex_to_channels
from zea.internal.core import DataTypes, ZEADecoderJSON, ZEAEncoderJSON, dict_to_tensor
from zea.internal.core import Object as ZEAObject
from zea.internal.registry import beamformer_registry, ops_registry
from zea.internal.utils import deprecated
from zea.ops.base import Operation, get_ops
from zea.ops.keras_ops import Cast
from zea.ops.tensor import Normalize
from zea.ops.ultrasound import (
    ApplyWindow,
    Demodulate,
    EnvelopeDetect,
    LogCompress,
    PfieldWeighting,
    ReshapeGrid,
    TOFCorrection,
)
from zea.utils import FunctionTimer

if TYPE_CHECKING:
    # Imported lazily at runtime (inside prepare_parameters) to avoid a circular
    # import: zea.scan imports the data specs, which can pull in this module.
    from zea.scan import Parameters


[docs] @ops_registry("pipeline") class Pipeline: """Pipeline class for processing ultrasound data through a series of operations.""" def __init__( self, operations: List[Operation], with_batch_dim: bool = True, jit_options: Union[str, None] = "ops", jit_kwargs: dict | None = None, name="pipeline", validate=True, timed: bool = False, device: Union[str, None] = None, ): """ Initialize a pipeline. Args: operations (list): A list of Operation instances representing the operations to be performed. with_batch_dim (bool, optional): Whether operations should expect a batch dimension. Defaults to True. jit_options (str, optional): The JIT options to use. Must be "pipeline", "ops", or None. - "pipeline": compiles the entire pipeline as a single function. This may be faster but does not preserve python control flow, such as caching. - "ops": compiles each operation separately. This preserves python control flow and caching functionality, but speeds up the operations. - None: disables JIT compilation. Defaults to "ops". jit_kwargs (dict, optional): Additional keyword arguments for the JIT compiler. name (str, optional): The name of the pipeline. Defaults to "pipeline". validate (bool, optional): Whether to validate the pipeline. Defaults to True. timed (bool, optional): Whether to time each operation. Defaults to False. device (str, optional): Default device for all pipeline calls, e.g. ``'cpu'``, ``'gpu:0'``, ``'cuda:1'``. Can be overridden per-call by passing ``device=`` to ``__call__``. Uses :func:`zea.backend.func_on_device` under the hood, which moves input tensors to the device for the ``torch`` backend and wraps the call in a device context for JAX / TensorFlow. Defaults to ``None`` (no device placement). """ self._call_pipeline = self.call self.name = name self._pipeline_layers = operations if jit_options not in ["pipeline", "ops", None]: raise ValueError("jit_options must be 'pipeline', 'ops', or None") self.with_batch_dim = with_batch_dim self._validate_flag = validate # Setup timer if jit_options == "pipeline" and timed: raise ValueError( "timed=True cannot be used with jit_options='pipeline' as the entire " "pipeline is compiled into a single function. Try setting jit_options to " "'ops' or None." ) if timed: log.warning( "Timer has been initialized for the pipeline. To get an accurate timing estimate, " "the `block_until_ready()` is used, which will slow down the execution, so " "do not use for regular processing!" ) self._callable_layers = self._get_timed_operations() else: self._callable_layers = self._pipeline_layers self._timed = timed if validate: self.validate() else: log.warning("Pipeline validation is disabled, make sure to validate manually.") if jit_kwargs is None: jit_kwargs = {} self._user_jit_kwargs = jit_kwargs.copy() if keras.backend.backend() == "jax" and self.static_params != []: existing = jit_kwargs.get("static_argnames", []) if isinstance(existing, str): existing = [existing] jit_kwargs = { **jit_kwargs, "static_argnames": list(set(existing) | set(self.static_params)), } self.jit_kwargs = jit_kwargs self.jit_options = jit_options # will handle the jit compilation self.device = device self._logged_difference_keys = False # Do not log again for nested pipelines for nested_pipeline in self._nested_pipelines: nested_pipeline._logged_difference_keys = True
[docs] def needs(self, key) -> bool: """Check if the pipeline needs a specific key at the input.""" return key in self.needs_keys
@property def _nested_pipelines(self): return [operation for operation in self.operations if isinstance(operation, Pipeline)] @property def output_keys(self) -> set: """All output keys the pipeline guarantees to produce.""" output_keys = set() for operation in self.operations: output_keys.update(operation.output_keys) return output_keys @property def valid_keys(self) -> set: """Get a set of valid keys for the pipeline. This is all keys that can be passed to the pipeline as input. """ valid_keys = set() for operation in self.operations: valid_keys.update(operation.valid_keys) return valid_keys @property def static_params(self) -> List[str]: """Get a list of static parameters for the pipeline.""" static_params = [] for operation in self.operations: static_params.extend(operation.static_params) return list(set(static_params)) @property def needs_keys(self) -> set: """Get a set of all input keys needed by the pipeline. Will keep track of keys that are already provided by previous operations. """ needs = set() has_so_far = set() previous_operation = None for operation in self.operations: if previous_operation is not None: has_so_far.update(previous_operation.output_keys) needs.update(operation.needs_keys - has_so_far) previous_operation = operation return needs
[docs] @classmethod def from_default( cls, beamformer="delay_and_sum", num_patches=100, baseband=False, enable_pfield=False, timed=False, **kwargs, ) -> "Pipeline": """Create a default pipeline. Args: beamformer (str): Type of beamformer to use. Currently supporting: - "delay_and_sum" - "delay_multiply_and_sum" - "coherence_factor" - "generalized_coherence_factor" Defaults to "delay_and_sum". num_patches (int): Number of patches for the PatchedGrid operation. Defaults to 100. If you get an out of memory error, try to increase this number. baseband (bool): If True, assume the input data is baseband (I/Q) data, which has 2 channels (last dim). Defaults to False, which assumes RF data, so input signal has a single channel dim and is still on carrier frequency. enable_pfield (bool): If True, apply PfieldWeighting. Defaults to False. This will calculate pressure field and only beamform the data to those locations. timed (bool, optional): Whether to time each operation. Defaults to False. **kwargs: Additional keyword arguments to be passed to the Pipeline constructor. """ operations = [Cast(dtype="float32")] # Add the demodulate operation if not baseband: operations += [ ApplyWindow(), Demodulate(), ] # Add beamforming ops operations.append( Beamform( beamformer=beamformer, num_patches=num_patches, enable_pfield=enable_pfield, ), ) # Add display ops operations += [ EnvelopeDetect(), Normalize(), LogCompress(), ] return cls(operations, timed=timed, **kwargs)
[docs] def copy(self) -> "Pipeline": """Create a copy of the pipeline.""" return Pipeline( self._pipeline_layers.copy(), with_batch_dim=self.with_batch_dim, jit_options=self.jit_options, jit_kwargs=self.jit_kwargs, name=self.name, validate=self._validate_flag, timed=self._timed, device=self.device, )
[docs] def reinitialize(self): """Reinitialize the pipeline in place.""" self.__init__( self._pipeline_layers, with_batch_dim=self.with_batch_dim, jit_options=self.jit_options, jit_kwargs=self.jit_kwargs, name=self.name, validate=self._validate_flag, timed=self._timed, device=self.device, )
[docs] def prepend(self, operation: Operation): """Prepend an operation to the pipeline.""" self._pipeline_layers.insert(0, operation) self.reinitialize()
[docs] def append(self, operation: Operation): """Append an operation to the pipeline.""" self._pipeline_layers.append(operation) self.reinitialize()
[docs] def insert(self, index: int, operation: Operation): """Insert an operation at a specific index in the pipeline.""" if index < 0 or index > len(self._pipeline_layers): raise IndexError("Index out of bounds for inserting operation.") self._pipeline_layers.insert(index, operation) self.reinitialize()
@property def operations(self) -> List[Union[Operation, "Pipeline"]]: """Alias for self.layers to match the zea naming convention""" return self._pipeline_layers
[docs] def reset_timer(self): """Reset the timer for timed operations.""" if self._timed: self._callable_layers = self._get_timed_operations() else: log.warning( "Timer has not been initialized. Set timed=True when initializing the pipeline." )
def _get_timed_operations(self): """Get a list of timed operations.""" self.timer = FunctionTimer() return [self.timer(op, name=op.__class__.__name__) for op in self._pipeline_layers]
[docs] def call(self, **inputs) -> Dict[str, Any]: """Process input data through the pipeline.""" for operation in self._callable_layers: try: outputs = operation(**inputs) except KeyError as exc: raise KeyError( f"[zea.Pipeline] Operation '{operation.__class__.__name__}' " f"requires input key '{exc.args[0]}', " "but it was not provided in the inputs.\n" "Check whether the objects (such as `zea.Parameters`) passed to " "`pipeline.prepare_parameters()` contain all required keys.\n" f"Current list of all passed keys: {list(inputs.keys())}\n" f"Valid keys for this pipeline: {self.valid_keys}" ) from exc except Exception as exc: raise RuntimeError( f"[zea.Pipeline] Error in operation '{operation.__class__.__name__}': {exc}" ) inputs = outputs return outputs
[docs] def __call__( self, return_numpy=False, device: Union[str, None] = None, **inputs ) -> Dict[str, Any]: """Process input data through the pipeline. Args: return_numpy (bool): If ``True``, convert output tensors to NumPy arrays before returning. device (str, optional): Device to run this call on, e.g. ``'cpu'``, ``'gpu:0'``, or ``'cuda:1'``. Overrides the pipeline-level ``device`` set at construction time for this single invocation. When ``None`` (default), the pipeline-level ``device`` attribute is used (which is also ``None`` by default, meaning no explicit device placement). **inputs: Tensor inputs forwarded to the operations. """ if any(key in inputs for key in ["probe", "scan", "config", "parameters"]) or any( isinstance(arg, ZEAObject) for arg in inputs.values() ): raise ValueError( "Parameters (and Probe/Config) objects should be first processed with " "`Pipeline.prepare_parameters` before calling the pipeline. " "e.g. inputs = pipeline.prepare_parameters(parameters, **overrides)" ) if any(isinstance(arg, str) for arg in inputs.values()): raise ValueError( "Pipeline does not support string inputs. " "Please ensure all inputs are convertible to tensors." ) if not self._logged_difference_keys: difference_keys = set(inputs.keys()) - self.valid_keys if difference_keys: log.debug( f"[zea.Pipeline] The following input keys are not used by the pipeline: " f"{difference_keys}. Make sure this is intended. " "This warning will only be shown once." ) self._logged_difference_keys = True ## PROCESSING _device = device if device is not None else self.device if _device is not None: outputs = func_on_device(self._call_pipeline, _device, **inputs) else: outputs = self._call_pipeline(**inputs) ## PREPARE OUTPUT if return_numpy: # Convert tensors to numpy arrays but preserve None values outputs = { k: ops.convert_to_numpy(v) if ops.is_tensor(v) else v for k, v in outputs.items() } return outputs
@property def jit_options(self): """Get the jit_options property of the pipeline.""" return self._jit_options
[docs] def set_jit(self, value: bool): """Set the JIT compilation for the pipeline.""" if value: self._jit() else: self._unjit()
@jit_options.setter def jit_options(self, value: Union[str, None]): """Set the jit_options property of the pipeline.""" assert value in ["pipeline", "ops", None], "jit_options must be 'pipeline', 'ops', or None" self._jit_options = value self.set_jit(value == "pipeline") for operation in self.operations: if isinstance(operation, Pipeline): operation.jit_options = value else: operation.set_jit(value == "ops") def _jit(self): """JIT compile the pipeline.""" if not self.jittable: raise ValueError( "Cannot JIT compile the pipeline because not all operations are jittable. " f"The following operations are not jittable: {self.unjitable_ops}" "Try setting jit_options to 'ops' or None." ) self._call_pipeline = jit(self.call, **self.jit_kwargs) def _unjit(self): """Un-JIT compile the pipeline.""" self._call_pipeline = self.call @property def jittable(self): """Check if all operations in the pipeline are jittable.""" return all(operation.jittable for operation in self.operations) @property def unjitable_ops(self): """Get a list of operations that are not jittable.""" return [operation for operation in self.operations if not operation.jittable] @property def with_batch_dim(self): """Get the with_batch_dim property of the pipeline.""" return self._with_batch_dim @with_batch_dim.setter def with_batch_dim(self, value): """Set the with_batch_dim property of the pipeline.""" self._with_batch_dim = value for operation in self.operations: operation.with_batch_dim = value @property def input_data_type(self): """Get the input_data_type property of the pipeline.""" return self.operations[0].input_data_type @property def output_data_type(self): """Get the output_data_type property of the pipeline.""" return self.operations[-1].output_data_type
[docs] def validate(self): """Validate the pipeline by checking the compatibility of the operations.""" operations = self.operations for i in range(len(operations) - 1): if operations[i].output_data_type is None: continue if operations[i + 1].input_data_type is None: continue if operations[i].output_data_type != operations[i + 1].input_data_type: raise ValueError( f"Operation {operations[i].__class__.__name__} output data type " f"({operations[i].output_data_type}) is not compatible " f"with the input data type ({operations[i + 1].input_data_type}) " f"of operation {operations[i + 1].__class__.__name__}" )
[docs] def set_params(self, **params): """Set parameters for the operations in the pipeline by adding them to the cache.""" for operation in self.operations: operation_params = { key: value for key, value in params.items() if key in operation.valid_keys } if operation_params: operation.set_input_cache(operation_params)
[docs] def get_params(self, per_operation: bool = False): """Get a snapshot of the current parameters of the operations in the pipeline. Args: per_operation (bool): If True, return a list of dictionaries for each operation. If False, return a single dictionary with all parameters combined. """ if per_operation: return [operation._input_cache.copy() for operation in self.operations] else: params = {} for operation in self.operations: params.update(operation._input_cache) return params
def __str__(self): """String representation of the pipeline.""" operations = [] for operation in self.operations: if isinstance(operation, Pipeline): operations.append(f"{operation.__class__.__name__}({str(operation)})") else: operations.append(operation.__class__.__name__) string = " -> ".join(operations) return string def __repr__(self): """String representation of the pipeline.""" operations = [] for operation in self.operations: if isinstance(operation, Pipeline): operations.append(repr(operation)) else: operations.append(operation.__class__.__name__) return f"Pipeline(name={self.name!r}, operations=[{', '.join(operations)}])"
[docs] @classmethod def load(cls, file_path: str, **kwargs) -> "Pipeline": """Load a pipeline from a JSON or YAML file.""" if file_path.endswith(".json"): with open(file_path, "r", encoding="utf-8") as f: json_str = f.read() return pipeline_from_json(json_str, **kwargs) elif file_path.endswith(".yaml") or file_path.endswith(".yml"): return cls.from_path(file_path, **kwargs) else: raise ValueError("File must have extension .json, .yaml, or .yml")
[docs] def get_dict(self, compact=True) -> dict: """Convert the pipeline to a dictionary. Args: compact (bool): If True (default), only include parameters that differ from their defaults. If False, include all parameters for full reproducibility. """ config = {"name": ops_registry.get_name(self)} config["operations"] = self._pipeline_to_list(self, compact=compact) if compact: params = {} if not self.with_batch_dim: params["with_batch_dim"] = self.with_batch_dim if self.jit_options != "ops": params["jit_options"] = self.jit_options if self._user_jit_kwargs: params["jit_kwargs"] = self._user_jit_kwargs if self.device is not None: params["device"] = self.device if params: config["params"] = params else: config["params"] = { "with_batch_dim": self.with_batch_dim, "jit_options": self.jit_options, "jit_kwargs": self._user_jit_kwargs, "device": self.device, } return config
@staticmethod def _pipeline_to_list(pipeline: "Pipeline", compact=True): """Convert the pipeline to a list of operations.""" ops_list = [] for op in pipeline.operations: if isinstance(op, Pipeline): ops_list.append(op.get_dict(compact=compact)) else: d = op.get_dict(compact=compact) if compact: params = d.get("params", {}) # Strip with_batch_dim when it is merely inherited from the pipeline if params.get("with_batch_dim") == pipeline.with_batch_dim: params.pop("with_batch_dim", None) # Strip jit_compile=False when it is implied by pipeline-level JIT if not params.get("jit_compile", True) and pipeline.jit_options in ( None, "pipeline", ): params.pop("jit_compile", None) if not params: d.pop("params", None) # Name-only dict → bare string shorthand if list(d.keys()) == ["name"]: d = d["name"] ops_list.append(d) return ops_list
[docs] @classmethod def from_config(cls, config: Dict, **kwargs) -> "Pipeline": """Create a pipeline from a dictionary or ``zea.Config`` object. Args: config (dict or Config): Configuration dictionary or ``zea.Config`` object. Must have a ``pipeline`` key with a subkey ``operations``. **kwargs: Additional keyword arguments to be passed to the pipeline. Example: .. doctest:: >>> from zea import Config, Pipeline >>> config = Config( ... { ... "pipeline": { ... "operations": [ ... "identity", ... ], ... } ... } ... ) >>> pipeline = Pipeline.from_config(config) """ return pipeline_from_config(Config(config), **kwargs)
[docs] @classmethod def from_path(cls, file_path: str, revision: str = None, **kwargs) -> "Pipeline": """Create a pipeline from a YAML/config file path. Args: file_path (str): Path to the config file (local or ``hf://`` URI). Must have a ``pipeline`` key with a subkey ``operations``. revision (str, optional): Revision of the config file (for Hugging Face ``hf://`` URIs). **kwargs: Additional keyword arguments to be passed to the pipeline. Example: .. doctest:: >>> from zea import Config, Pipeline >>> config = Config( ... { ... "pipeline": { ... "operations": [ ... "identity", ... ], ... } ... } ... ) >>> config.to_yaml("pipeline.yaml") >>> pipeline = Pipeline.from_path("pipeline.yaml") """ config = Config.from_path(file_path, revision=revision) return pipeline_from_config(config, **kwargs)
[docs] @classmethod @deprecated(replacement="Pipeline.from_path") def from_yaml(cls, file_path: str, **kwargs) -> "Pipeline": """Deprecated. Use :meth:`from_path` instead.""" return pipeline_from_yaml(file_path, **kwargs)
[docs] @classmethod def from_json(cls, json_string: str, **kwargs) -> "Pipeline": """Create a pipeline from a JSON string. Args: json_string (str): JSON string representing the pipeline. Must have a ``pipeline`` key with a subkey ``operations``. **kwargs: Additional keyword arguments to be passed to the pipeline. Example: ```python json_string = '{"pipeline": {"operations": ["identity"]}}' pipeline = Pipeline.from_json(json_string) ``` """ return pipeline_from_json(json_string, **kwargs)
[docs] def to_config(self, compact=True) -> Config: """Convert the pipeline to a `zea.Config` object.""" return pipeline_to_config(self, compact=compact)
[docs] def to_json(self, compact=True) -> str: """Convert the pipeline to a JSON string.""" return pipeline_to_json(self, compact=compact)
[docs] def to_yaml(self, file_path: str, compact=True) -> None: """Convert the pipeline to a YAML file.""" pipeline_to_yaml(self, file_path, compact=compact)
@property def key(self) -> str: """Input key of the pipeline.""" return self.operations[0].key @property def output_key(self) -> str: """Output key of the pipeline.""" return self.operations[-1].output_key def __eq__(self, other): """Check if two pipelines are equal.""" if not isinstance(other, Pipeline): return False # Compare the operations in both pipelines if len(self.operations) != len(other.operations): return False for op1, op2 in zip(self.operations, other.operations): if not op1 == op2: return False return True
[docs] def prepare_parameters( self, parameters: "Parameters" = None, device: Union[str, None] = None, **overrides, ): """Prepare a :class:`~zea.Parameters` object for the pipeline. Converts the (validated and derived) parameters needed by this pipeline's operations into a dictionary of tensors, then overlays any manually supplied overrides (e.g. ``config.parameters`` or ad-hoc keyword arguments). Overrides take priority over the values in ``parameters``. Args: parameters: :class:`~zea.Parameters` object. Only the keys this pipeline ``needs`` (and that are not overridden) are converted, so derivation is lazy and minimal. device: Device to place the tensors on. Defaults to the pipeline device. **overrides: Additional parameters to include in the inputs (converted to tensors). These overwrite values taken from ``parameters``. Returns: dict: Dictionary of inputs with all values as tensors. Example: .. code-block:: python inputs = pipeline.prepare_parameters(parameters, **config.parameters) outputs = pipeline(data=raw_data, **inputs) """ from zea.scan import Parameters _device = device if device is not None else self.device params_dict = {} override_keys = set(overrides.keys()) if parameters is not None: assert isinstance(parameters, Parameters), ( f"Expected an instance of `zea.Parameters`, got {type(parameters)}" ) # Only convert keys the pipeline needs and that are not overridden, # so we avoid deriving unnecessary parameters. needs_keys = self.needs_keys - override_keys with backend.device(_device): params_dict = parameters.to_tensor( include=needs_keys, keep_as_is=self.static_params ) # Convert all overrides to tensors with backend.device(_device): tensor_overrides = dict_to_tensor(overrides, keep_as_is=self.static_params) # Overrides overwrite values taken from the parameters object. return {**params_dict, **tensor_overrides}
[docs] @ops_registry("map") class Map(Pipeline): """ A pipeline that maps its operations over specified input arguments. This can be used to reduce memory usage by processing data in chunks. Notes ----- - When `chunks` and `batch_size` are both None (default), this behaves like a normal Pipeline. - Changing anything other than ``self.output_key`` in the dict will not be propagated. - Will be jitted as a single operation, not the individual operations. - This class handles the batching. For more information on how to use ``in_axes``, ``out_axes``, `see the documentation for jax.vmap <https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html>`_. Example ------- .. doctest:: >>> from zea.ops import Map, Pipeline, Demodulate, TOFCorrection >>> # apply operations in batches of 8 >>> # in this case, over the first axis of "data" >>> # or more specifically, process 8 transmits at a time >>> pipeline_mapped = Map( ... [ ... Demodulate(), ... TOFCorrection(), ... ], ... argnames="data", ... batch_size=8, ... ) >>> # you can also map a subset of the operations >>> # for example, demodulate in 4 chunks >>> # or more specifically, split the transmit axis into 4 parts >>> pipeline_mapped = Pipeline( ... [ ... Map([Demodulate()], argnames="data", chunks=4), ... TOFCorrection(), ... ], ... ) """ def __init__( self, operations: List[Operation], argnames: List[str] | str, in_axes: List[Union[int, None]] | int = 0, out_axes: List[Union[int, None]] | int = 0, chunks: int | None = None, batch_size: int | None = None, **kwargs, ): """ Args: operations (list): List of operations to be performed. argnames (str or list): List of argument names (or keys) to map over. Can also be a single string if only one argument is mapped over. in_axes (int or list): Axes to map over for each argument. If a single int is provided, it is used for all arguments. out_axes (int or list): Axes to map over for each output. If a single int is provided, it is used for all outputs. chunks (int, optional): Number of chunks to split the input data into. If None, no chunking is performed. Mutually exclusive with ``batch_size``. batch_size (int, optional): Size of batches to process at once. If None, no batching is performed. Mutually exclusive with ``chunks``. """ super().__init__(operations, **kwargs) if batch_size is not None and chunks is not None: raise ValueError( "batch_size and chunks are mutually exclusive. Please specify only one." ) if batch_size is not None and batch_size <= 0: raise ValueError("batch_size must be a positive integer.") if chunks is not None and chunks <= 0: raise ValueError("chunks must be a positive integer.") if isinstance(argnames, str): argnames = [argnames] self.argnames = argnames self.in_axes = in_axes self.out_axes = out_axes self.chunks = chunks self.batch_size = batch_size if chunks is None and batch_size is None: log.warning( "[zea.ops.Map] Both `chunks` and `batch_size` are None. " "This will behave like a normal Pipeline. " "Consider setting one of them to process data in chunks or batches." )
[docs] def call_item(self, **inputs): """Process data in patches.""" mapped_args = [] for argname in self.argnames: mapped_args.append(inputs.pop(argname, None)) def patched_call(*args): mapped_kwargs = [(k, v) for k, v in zip(self.argnames, args)] out = super(Map, self).call(**dict(mapped_kwargs), **inputs) # TODO: maybe it is possible to output everything? # e.g. prepend a empty dimension to all inputs and just map over everything? return out[self.output_key] out = vmap( patched_call, in_axes=self.in_axes, out_axes=self.out_axes, chunks=self.chunks, batch_size=self.batch_size, fn_supports_batch=True, disable_jit=not bool(self.jit_options), )(*mapped_args) return out
@property def jit_options(self): """Get the jit_options property of the pipeline.""" return self._jit_options @jit_options.setter def jit_options(self, value: Union[str, None]): """Set the jit_options property of the pipeline.""" assert value in ["pipeline", "ops", None], "jit_options must be 'pipeline', 'ops', or None" self._jit_options = value self.set_jit(value == "pipeline" or value == "ops") for operation in self.operations: if isinstance(operation, Pipeline): operation.jit_options = None else: operation.set_jit(False) def _jit(self): """JIT compile the pipeline.""" self._jittable_call = jit(self.jittable_call, **self.jit_kwargs) def _unjit(self): """Un-JIT compile the pipeline.""" self._jittable_call = self.jittable_call @property def with_batch_dim(self): """Get the with_batch_dim property of the pipeline.""" return self._with_batch_dim @with_batch_dim.setter def with_batch_dim(self, value): """Set the with_batch_dim property of the pipeline. The class handles the batching so the operations have to be set to False.""" self._with_batch_dim = value for operation in self.operations: operation.with_batch_dim = False
[docs] def jittable_call(self, **inputs): """Process input data through the pipeline.""" if self._with_batch_dim: input_data = inputs.pop(self.key) output = ops.map( lambda x: self.call_item(**{self.key: x, **inputs}), input_data, ) else: output = self.call_item(**inputs) return {self.output_key: output}
[docs] def call(self, **inputs): """Process input data through the pipeline.""" output = self._jittable_call(**inputs) inputs.update(output) return inputs
[docs] def get_dict(self, compact=True): """Get the configuration of the pipeline.""" config = super().get_dict(compact=compact) config["name"] = "map" params = config.get("params", {}) params["argnames"] = self.argnames if not compact or self.in_axes != 0: params["in_axes"] = self.in_axes if not compact or self.out_axes != 0: params["out_axes"] = self.out_axes if not compact or self.chunks is not None: params["chunks"] = self.chunks if not compact or self.batch_size is not None: params["batch_size"] = self.batch_size config["params"] = params return config
[docs] @ops_registry("patched_grid") class PatchedGrid(Map): """ A pipeline that maps its operations over `flatgrid` and `flat_pfield` keys. This can be used to reduce memory usage by processing data in chunks. For more information and flexibility, see :class:`zea.ops.Map`. """ def __init__(self, *args, num_patches=10, **kwargs): super().__init__(*args, argnames=["flatgrid", "flat_pfield"], chunks=num_patches, **kwargs) self.num_patches = num_patches
[docs] def get_dict(self, compact=True): """Get the configuration of the pipeline.""" config = super().get_dict(compact=compact) config["name"] = "patched_grid" params = config.get("params", {}) params.pop("argnames", None) params.pop("chunks", None) params["num_patches"] = self.num_patches config["params"] = params return config
[docs] @ops_registry("beamform") class Beamform(Pipeline): """Classical beamforming pipeline for ultrasound image formation. Expected input data type is `DataTypes.RF_DATA` which has shape `(n_tx, n_ax, n_el, n_ch)`. Will run the following operations in sequence: - TOFCorrection (output type `DataTypes.ALIGNED_DATA`: `(n_tx, n_ax, n_el, n_ch)`) - PfieldWeighting (optional, output type `DataTypes.ALIGNED_DATA`: `(n_tx, n_ax, n_el, n_ch)`) - Sum over channels (DAS) - Sum over transmits (Compounding) (output type `DataTypes.BEAMFORMED_DATA`: `(grid_size_z, grid_size_x, n_ch)`) - ReshapeGrid (flattened grid is also reshaped to `(grid_size_z, grid_size_x)`) """ # noqa: E501 def __init__(self, beamformer="delay_and_sum", num_patches=100, enable_pfield=False, **kwargs): """Initialize a Delay-and-Sum beamforming `zea.Pipeline`. Args: beamformer (str): Type of beamformer to use. Currently supporting: - "delay_and_sum" - "delay_multiply_and_sum" - "coherence_factor" - "generalized_coherence_factor" Defaults to "delay_and_sum". num_patches (int): Number of patches to split the grid into for patch-wise beamforming. If 1, no patching is performed. enable_pfield (bool): Whether to include pressure field weighting in the beamforming. """ self.beamformer_type = beamformer self.num_patches = num_patches self.enable_pfield = enable_pfield # for backwards compatibility name_mapping = { "das": "delay_and_sum", "dmas": "delay_multiply_and_sum", } if beamformer in name_mapping: log.deprecated( f"Beamformer name '{beamformer}' is deprecated. " f"Please use '{name_mapping[beamformer]}' instead." ) self.beamformer_type = name_mapping[beamformer] if self.beamformer_type not in beamformer_registry: raise ValueError( f"Unsupported beamformer type: '{self.beamformer_type}'. " f"Supported types are: {beamformer_registry.registered_names()}." ) # Get beamforming ops beamforming = [ TOFCorrection(), # PfieldWeighting(), # Inserted conditionally beamformer_registry[self.beamformer_type](), ] if self.enable_pfield: beamforming.insert(1, PfieldWeighting()) # Optionally add patching if self.num_patches > 1: beamforming = [ PatchedGrid( operations=beamforming, num_patches=self.num_patches, **kwargs, ) ] # Reshape the grid to image shape beamforming.append(ReshapeGrid()) # Set the output data type of the last operation # which also defines the pipeline output type beamforming[-1].output_data_type = DataTypes.BEAMFORMED_DATA super().__init__(operations=beamforming, **kwargs) def __repr__(self): """String representation of the pipeline.""" operations = [] for operation in self.operations: if isinstance(operation, Pipeline): operations.append(repr(operation)) else: operations.append(operation.__class__.__name__) return f"Beamform(name={self.name!r}, operations=[{', '.join(operations)}])"
[docs] def get_dict(self, compact=True) -> dict: """Convert the pipeline to a dictionary. Unlike Pipeline.get_dict(), this does NOT include the internal operations list, since Beamform auto-generates its operations from ``beamformer``, ``num_patches``, and ``enable_pfield``. """ config = super().get_dict(compact=compact) config.pop("operations", None) params = {} if not compact or self.beamformer_type != "delay_and_sum": params["beamformer"] = self.beamformer_type if not compact or self.num_patches != 100: params["num_patches"] = self.num_patches if not compact or self.enable_pfield: params["enable_pfield"] = self.enable_pfield # Merge in the pipeline-level params from super(). params.update(config.get("params", {})) if params: config["params"] = params else: config.pop("params", None) return config
[docs] @beamformer_registry("delay_and_sum") @ops_registry("delay_and_sum") class DelayAndSum(Operation): """Sums time-delayed signals along channels and transmits.""" def __init__(self, **kwargs): super().__init__( input_data_type=DataTypes.ALIGNED_DATA, output_data_type=DataTypes.BEAMFORMED_DATA, **kwargs, )
[docs] def call(self, **kwargs): """Performs DAS beamforming on tof-corrected input. Args: tof_corrected_data (ops.Tensor): The TOF corrected input of shape `(n_tx, prod(grid.shape), n_el, n_ch)` with optional batch dimension. Returns: dict: Dictionary containing beamformed_data of shape `(prod(grid.shape), n_ch)` with optional batch dimension. """ data = kwargs[self.key] # Sum over the channels (n_el), i.e. DAS beamformed_data = ops.sum(data, -2) # Sum over transmits (n_tx), i.e. Compounding beamformed_data = ops.sum(beamformed_data, -3) return {self.output_key: beamformed_data}
[docs] @beamformer_registry("delay_multiply_and_sum") @ops_registry("delay_multiply_and_sum") class DelayMultiplyAndSum(Operation): """Performs the operations for the Delay-Multiply-and-Sum beamformer except the delay. The delay should be performed by the TOF correction operation. """ def __init__(self, **kwargs): super().__init__( input_data_type=DataTypes.ALIGNED_DATA, output_data_type=DataTypes.BEAMFORMED_DATA, **kwargs, )
[docs] def process_image(self, data): """Performs DMAS beamforming on tof-corrected input. Args: data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)` Returns: ops.Tensor: The beamformed data of shape `(n_pix, n_ch)` """ if not data.shape[-1] == 2: raise ValueError( "MultiplyAndSum operation requires IQ data with 2 channels. " f"Got data with shape {data.shape}." ) # Compute the correlation matrix data = channels_to_complex(data) data = self._multiply(data) data = self._select_lower_triangle(data) data = ops.sum(data, axis=(0, 2, 3)) data = complex_to_channels(data) return data
def _select_lower_triangle(self, data): """Select only the lower triangle of the correlation matrix.""" n_el = data.shape[3] mask = ops.ones((n_el, n_el), dtype=data.dtype) - ops.eye(n_el, dtype=data.dtype) data = data * mask[None, None, :, :] / 2 return data def _multiply(self, data): """Apply the DMAS multiplication step.""" channel_products = data[:, :, :, None] * data[:, :, None, :] # Signed square root: sign(z) * sqrt(|z|) == z / sqrt(|z|). # Written this way to avoid ops.sign on complex data (torch.sign # does not support complex numbers; use torch.sgn instead). eps = keras.backend.epsilon() safe_sqrt = ops.cast(ops.sqrt(ops.abs(channel_products) + eps**2), channel_products.dtype) data = channel_products / safe_sqrt return data
[docs] def call(self, **kwargs): """Performs DMAS beamforming on tof-corrected input. Args: tof_corrected_data (ops.Tensor): The TOF corrected input of shape `(n_tx, prod(grid.shape), n_el, n_ch)` with optional batch dimension. Returns: dict: Dictionary containing beamformed_data of shape `(grid_size_z*grid_size_x, n_ch)` with optional batch dimension. """ data = kwargs[self.key] if not self.with_batch_dim: beamformed_data = self.process_image(data) else: # Apply process_image to each item in the batch beamformed_data = ops.map(self.process_image, data) return {self.output_key: beamformed_data}
[docs] @beamformer_registry("coherence_factor") @ops_registry("coherence_factor") class CoherenceFactor(Operation): r"""Coherence Factor (CF) Beamformer. The Coherence Factor is a pixel-dependent weight used to quantify the focus quality of the beamformed signal. It is the ratio of the coherent power to the incoherent power of the signals received across the transducer aperture. For a set of delayed signals :math:`x_i` across :math:`N` elements: .. math:: \mathrm{CF} = \frac{\left|\sum_{i=1}^{N} x_i\right|^2} {N \sum_{i=1}^{N} \left|x_i\right|^2} The CF ranges from 0 (completely incoherent) to 1 (perfectly coherent). The beamformed output is the standard DAS sum weighted by CF per transmit, then compounded across transmits. .. admonition:: Reference Hollman, K. W., Rigby, K. W., & O'Donnell, M. (1999). Coherence factor of speckle from a multi-row probe. IEEE Ultrasonics Symposium. Args: **kwargs: Additional arguments passed to the Operation base class. """ def __init__(self, **kwargs): super().__init__( input_data_type=DataTypes.ALIGNED_DATA, output_data_type=DataTypes.BEAMFORMED_DATA, **kwargs, )
[docs] def process_image(self, data): """Applies CF weighting and compounding on tof-corrected input. Args: data (ops.Tensor): TOF-corrected input of shape ``(n_tx, n_pix, n_el, n_ch)``, with optional batch dimension. Returns: ops.Tensor: Beamformed image of shape ``(n_pix, n_ch)``, with optional batch dimension. """ n_el = ops.shape(data)[-2] # DAS per transmit: sum over elements das_per_tx = ops.sum(data, axis=-2) # Coherent power: |sum_i(x_i)|^2, works for both RF (n_ch=1) and IQ (n_ch=2) coherent_power = ops.sum(ops.square(das_per_tx), axis=-1, keepdims=True) # Incoherent power: N * sum_i(|x_i|^2) incoherent_power = n_el * ops.sum( ops.sum(ops.square(data), axis=-1), axis=-1, keepdims=True ) # CF weight, clipped to [0, 1] by construction when incoherent_power > 0 cf_weight = coherent_power / (incoherent_power + keras.backend.epsilon()) # Apply weight per transmit, then compound return ops.sum(das_per_tx * cf_weight, axis=-3)
[docs] def call(self, **kwargs): """Performs CF beamforming on tof-corrected input. Args: tof_corrected_data (ops.Tensor): TOF-corrected input of shape ``(n_tx, n_pix, n_el, n_ch)``, with optional batch dimension. Returns: dict: Dictionary containing beamformed data of shape ``(n_pix, n_ch)``, with optional batch dimension. """ data = kwargs[self.key] return {self.output_key: self.process_image(data)}
[docs] @beamformer_registry("generalized_coherence_factor") @ops_registry("generalized_coherence_factor") class GeneralizedCoherenceFactor(Operation): r"""Generalized Coherence Factor (GCF) Beamformer. The GCF is a coherence-based adaptive weighting technique used to improve the quality of ultrasound images by suppressing sidelobes and clutter. It is defined as the ratio of the power within a low-frequency region of the spatial spectrum to the total power across the aperture. For a given pixel, let :math:`A(k)` be the spatial Fourier transform of the delayed channel data across :math:`N` elements. The GCF is: .. math:: \mathrm{GCF} = \frac{\sum_{k \in \mathcal{M}_0} \left|A(k)\right|^2} {\sum_{k=0}^{N-1} \left|A(k)\right|^2} where :math:`\mathcal{M}_0 = \{k : k \leq m_0\} \cup \{k : k \geq N - m_0\}` is the low spatial-frequency region controlled by :math:`m_0`. .. admonition:: Reference Li, P. C., & Li, M. L. (2003). "Adaptive imaging using the generalized coherence factor." IEEE Transactions on Ultrasonics, Ferroelectrics, and Frequency Control, 50(2), 128-141. Args: m_zero (int): Cutoff frequency index for the low-frequency spatial region. Higher values increase tolerance to phase aberrations. Defaults to ``4``. **kwargs: Additional arguments passed to the Operation base class. """ def __init__(self, m_zero=4, **kwargs): if not isinstance(m_zero, int) or m_zero < 0: raise ValueError(f"m_zero must be a non-negative integer, got {m_zero!r}.") super().__init__( input_data_type=DataTypes.ALIGNED_DATA, output_data_type=DataTypes.BEAMFORMED_DATA, **kwargs, ) self.m_zero = m_zero
[docs] def process_image(self, data, m_zero=None): """Applies GCF weighting and compounding on tof-corrected input. Args: data (ops.Tensor): TOF-corrected input of shape ``(n_tx, n_pix, n_el, n_ch)``, with optional batch dimension. m_zero (int, optional): Overrides the instance ``m_zero`` for this call. Returns: ops.Tensor: Beamformed image of shape ``(n_pix, n_ch)``, with optional batch dimension. """ if m_zero is None: m_zero = self.m_zero n_el = ops.shape(data)[-2] n_ch = data.shape[-1] # static Python int — safe for branching # Move n_el to last axis for spatial FFT: (..., n_tx, n_pix, n_ch, n_el) spatial_data = ops.moveaxis(data, -2, -1) real_in = spatial_data[..., 0, :] imag_in = spatial_data[..., 1, :] if n_ch == 2 else ops.zeros_like(real_in) # Spatial FFT power spectrum across elements real_fft, imag_fft = ops.fft((real_in, imag_in)) power_spectrum = ops.square(real_fft) + ops.square(imag_fft) # Total energy and low-frequency energy total_energy = ops.sum(power_spectrum, axis=-1, keepdims=True) indices = ops.arange(n_el) low_freq_mask = ops.logical_or( ops.less_equal(indices, m_zero), ops.greater_equal(indices, n_el - m_zero), ) low_freq_energy = ops.sum( ops.where(low_freq_mask, power_spectrum, 0.0), axis=-1, keepdims=True, ) # GCF weight gcf_weight = low_freq_energy / (total_energy + keras.backend.epsilon()) # DAS per transmit, apply weight, then compound das_per_tx = ops.sum(data, axis=-2) return ops.sum(das_per_tx * gcf_weight, axis=-3)
[docs] def call(self, m_zero=None, **kwargs): """Performs GCF beamforming on tof-corrected input. Args: m_zero (int, optional): Cutoff frequency index, overrides the instance default when provided via pipeline parameters. tof_corrected_data (ops.Tensor): TOF-corrected input of shape ``(n_tx, n_pix, n_el, n_ch)``, with optional batch dimension. Returns: dict: Dictionary containing beamformed data of shape ``(n_pix, n_ch)``, with optional batch dimension. """ data = kwargs[self.key] return {self.output_key: self.process_image(data, m_zero=m_zero)}
[docs] @ops_registry("refocus") class Refocus(Operation): r"""REFoCUS (Retrospective Encoding For Conventional Ultrasound Sequences). Decodes any transmit data into synthetic aperture (multistatic / full-matrix capture) data by inverting the transmit encoding model in the frequency domain. The transmit encoding is modelled as a matrix :math:`H` whose entry :math:`H_{t,e}` describes the complex phase shift applied to element :math:`e` during transmit event :math:`t`: .. math:: H_{t,e}(f) = a_{t,e} \exp(-j 2\pi f \tau_{t,e}) where :math:`\tau_{t,e}` is the transmit delay in samples and :math:`a_{t,e}` is the apodization. At each temporal frequency the received RF spectrum is decoded by multiplying with the pseudo-inverse :math:`H^{-1}`: .. math:: \hat{U}(f) = H^{-1}(f) \, S(f) producing a synthetic aperture dataset where each decoded channel corresponds to a virtual single-element transmission. The **input** data has shape ``(n_tx, n_ax, n_el, n_ch)`` and the **output** has shape ``(n_el, n_ax, n_el, n_ch)``, where the new first axis indexes the decoded virtual transmit elements. .. admonition:: References Bottenus, N. (2018). "Recovery of the complete data set from focused transmit beams." *IEEE Transactions on Ultrasonics, Ferroelectrics, and Frequency Control*, 65(1), 30–38. Ali, R., Dahl, J., & Bottenus, N. (2019). "Extending Retrospective Encoding for Robust Recovery of the Multistatic Dataset." *IEEE Transactions on Ultrasonics, Ferroelectrics, and Frequency Control*, 67(5), 943–956. https://github.com/nbottenus/REFoCUS Args: method (str): Inversion method. One of: - ``'adjoint'``: Adjoint (matched-filter) pseudo-inverse with an optional ramp filter in frequency. Default. - ``'tikhonov'``: Tikhonov-regularized inverse. - ``'rsvd'``: Regularized SVD-based inverse. - ``'tsvd'``: Truncated SVD-based inverse. param (float or None): Regularization / filter parameter. - ``'adjoint'``: ``None`` applies a ramp filter (multiply by :math:`f`). Set to ``0`` to disable the ramp filter. Defaults to ``None``. - ``'tikhonov'``, ``'rsvd'``, ``'tsvd'``: Relative regularization strength. Defaults to ``1e-2`` when ``None``. **kwargs: Additional arguments forwarded to :class:`~zea.ops.Operation`. """ _VALID_METHODS = ("adjoint", "tikhonov", "rsvd", "tsvd") def __init__(self, method="adjoint", param=None, **kwargs): if method not in self._VALID_METHODS: raise ValueError(f"method must be one of {self._VALID_METHODS}, got '{method}'") # SVD is not supported by TF XLA; disable JIT for SVD-based methods. if method != "adjoint": if kwargs.get("jit_compile", True): log.warning( f"Refocus method='{method}' uses SVD, which is not supported by the XLA " "JIT compiler. Setting jit_compile=False." ) kwargs["jit_compile"] = False super().__init__( input_data_type=DataTypes.RAW_DATA, output_data_type=DataTypes.RAW_DATA, **kwargs, ) self.method = method self.param = param def _get_hinv(self, delays, f_vec, apod): """Compute batched Hinv for all normalised frequencies at once. Args: delays: ``(n_tx, n_el)`` delays in samples. f_vec: ``(n_freq,)`` normalised frequencies (cycles/sample). apod: ``(n_tx, n_el)`` apodization. Returns: Hinv: ``(n_freq, n_el, n_tx)`` complex64 tensor. """ # H: (n_freq, n_tx, n_el) f_c = ops.cast(f_vec[:, None, None], "complex64") d_c = ops.cast(delays[None], "complex64") a_c = ops.cast(apod[None], "complex64") H = a_c * ops.exp(ops.cast(-1j * 2 * np.pi, "complex64") * f_c * d_c) if self.method == "adjoint": # param=None → ramp filter (multiply by f) # param=0 → no ramp (multiply by 1, plain adjoint) Hinv = ops.conj(ops.transpose(H, (0, 2, 1))) ramp_vals = f_vec if self.param is None else ops.ones_like(f_vec) ramp = ops.cast(ramp_vals, "complex64")[:, None, None] return ramp * Hinv # SVD-based methods U, s, VH = ops.svd(H, full_matrices=False) lam = self.param if self.param is not None else 1e-2 if self.method in ("tikhonov", "rsvd"): sinv = s / (s**2 + (lam * s[:, 0:1]) ** 2) else: # tsvd threshold = lam * s[:, 0:1] safe_s = ops.where(s >= threshold, s, ops.ones_like(s)) sinv = ops.where(s >= threshold, 1.0 / safe_s, ops.zeros_like(s)) VHT = ops.conj(ops.transpose(VH, (0, 2, 1))) # (n_freq, n_el, k) UT = ops.conj(ops.transpose(U, (0, 2, 1))) # (n_freq, k, n_tx) sinv_c = ops.cast(sinv, "complex64") return ops.matmul(VHT * sinv_c[:, None, :], UT) def _decode(self, data, delays_samples, apod): """REFoCUS decoding for a single (unbatched) volume. All channels and all frequency bins are processed in parallel via batched tensor operations. Args: data: ``(n_tx, n_ax, n_el, n_ch)`` float32 RF array. delays_samples: ``(n_tx, n_el)`` transmit delays in samples. apod: ``(n_tx, n_el)`` transmit apodization. Returns: decoded: ``(n_el, n_ax, n_el, n_ch)`` float32 array. """ n_tx, n_ax, n_el, n_ch = data.shape n_elements = delays_samples.shape[1] # --- FFT over all channels at once --- # data: (n_tx, n_ax, n_el, n_ch) -> (n_ch, n_el, n_tx, n_ax) rf = ops.cast(ops.transpose(data, (3, 2, 0, 1)), "float32") # (n_ch, n_el_recv, n_tx, n_freq) RF_enc_r, RF_enc_i = ops.rfft(rf) RF_enc = ops.cast(RF_enc_r, "complex64") + 1j * ops.cast(RF_enc_i, "complex64") n_freq = RF_enc.shape[-1] # Rearrange to (n_freq, n_tx, n_el_recv * n_ch) for batched matmul. # (n_ch, n_el_recv, n_tx, n_freq) -> (n_freq, n_tx, n_el_recv, n_ch) RF_enc = ops.transpose(RF_enc, (3, 2, 1, 0)) # -> (n_freq, n_tx, n_el_recv * n_ch) RF_enc = ops.reshape(RF_enc, (n_freq, n_tx, n_el * n_ch)) # --- Batched inverse encoding matrices (skip DC at index 0) --- frequency = ops.cast(ops.arange(n_freq), "float32") / n_ax freq_noDC = frequency[1:] # (n_freq - 1,) # Hinv: (n_freq - 1, n_elements, n_tx) Hinv = self._get_hinv(delays_samples, freq_noDC, apod) # --- Single batched matmul over all frequencies and channels --- # (n_freq-1, n_elements, n_tx) @ (n_freq-1, n_tx, n_el_recv * n_ch) # -> (n_freq-1, n_elements, n_el_recv * n_ch) RF_dec = ops.matmul(Hinv, RF_enc[1:]) # Prepend zeros for the DC bin: (n_freq, n_elements, n_el_recv * n_ch) dc = ops.zeros((1, n_elements, n_el * n_ch), dtype="complex64") RF_decoded = ops.concatenate([dc, RF_dec], axis=0) # --- IFFT back to time domain --- # Reshape to (n_freq, n_elements, n_el_recv, n_ch) RF_decoded = ops.reshape(RF_decoded, (n_freq, n_elements, n_el, n_ch)) # irfft acts on the last axis: move n_freq last # -> (n_elements, n_el_recv, n_ch, n_freq) RF_decoded = ops.transpose(RF_decoded, (1, 2, 3, 0)) # -> (n_elements, n_el_recv, n_ch, n_ax) rf_decoded = ops.irfft((ops.real(RF_decoded), ops.imag(RF_decoded)), fft_length=n_ax) # -> (n_elements, n_ax, n_el_recv, n_ch) rf_decoded = ops.transpose(rf_decoded, (0, 3, 1, 2)) return ops.cast(rf_decoded, "float32") # ------------------------------------------------------------------ # Operation interface # ------------------------------------------------------------------
[docs] def call( self, t0_delays, sampling_frequency, probe_geometry, initial_times, tx_apodizations=None, **kwargs, ): """Decode plane-wave / focused transmit data into multistatic data. After decoding the output is a synthetic-aperture (SA) dataset where each virtual transmit corresponds to a single element firing. The pipeline parameters that describe the transmit sequence are updated accordingly so that downstream operations (TOF correction, pfield weighting, etc.) remain consistent with the new data shape. Args: t0_delays: ``(n_tx, n_el)`` transmit delays in **seconds**. sampling_frequency: Sampling frequency in Hz. probe_geometry: ``(n_el, 3)`` element positions in metres. tx_apodizations: ``(n_tx, n_el)`` transmit apodization weights. Defaults to all-ones (uniform apodization). **kwargs: Must contain the input data tensor under ``self.key``. Returns: dict with keys: * ``self.output_key`` — decoded data ``(n_el, n_ax, n_el, n_ch)`` (or batched variant). * ``"t0_delays"`` — zeros ``(n_el, n_el)`` (SA: no extra delay). * ``"tx_apodizations"`` — identity ``(n_el, n_el)`` (one element per virtual transmit). * ``"polar_angles"`` — zeros ``(n_el,)`` (no steering). * ``"focus_distances"`` — zeros ``(n_el,)`` (no focus). * ``"transmit_origins"`` — element positions ``(n_el, 3)``. * ``"initial_times"`` — zeros ``(n_el,)``. * ``"t_peak"`` — shared transmit-waveform peak time ``(n_el,)``. * ``"flat_pfield"`` — ``None`` (resets pfield so downstream :class:`PfieldWeighting` becomes a no-op). """ data = kwargs[self.key] delays_samples = (t0_delays - initial_times[..., None]) * ops.cast( sampling_frequency, t0_delays.dtype ) if tx_apodizations is None: apod = ops.ones_like(delays_samples) else: apod = tx_apodizations if self.with_batch_dim: decoded = vmap(self._decode, in_axes=(0, None, None))(data, delays_samples, apod) else: decoded = self._decode(data, delays_samples, apod) # Number of virtual SA transmits = number of elements n_el = ops.shape(probe_geometry)[0] dtype = t0_delays.dtype sa_t0_delays = ops.zeros((n_el, n_el), dtype=dtype) sa_tx_apodizations = ops.eye(n_el, dtype=dtype) sa_polar_angles = ops.zeros((n_el,), dtype=dtype) sa_focus_distances = ops.zeros((n_el,), dtype=dtype) sa_initial_times = ops.zeros((n_el,), dtype=dtype) t_peak = kwargs.get("t_peak") if t_peak is not None: t_peak_flat = ops.reshape(ops.cast(t_peak, dtype), (-1,)) sa_t_peak = ops.broadcast_to(t_peak_flat[:1], (n_el,)) else: sa_t_peak = ops.zeros((n_el,), dtype=dtype) return { self.output_key: decoded, "t0_delays": sa_t0_delays, "tx_apodizations": sa_tx_apodizations, "polar_angles": sa_polar_angles, "focus_distances": sa_focus_distances, "transmit_origins": probe_geometry, "initial_times": sa_initial_times, "t_peak": sa_t_peak, "flat_pfield": None, }
def make_operation_chain( operation_chain: List[Union[str, Dict, Config, Operation, Pipeline]], ) -> List[Operation]: """Make an operation chain from a custom list of operations. Args: operation_chain (list): List of operations to be performed. Each operation can be: - A string: operation initialized with default parameters - A dictionary: operation initialized with parameters in the dictionary - A Config object: converted to a dictionary and initialized - An Operation/Pipeline instance: used as-is Returns: list: List of operations to be performed. Example: .. doctest:: >>> from zea.ops import make_operation_chain, LogCompress >>> SomeCustomOperation = LogCompress # just for demonstration >>> chain = make_operation_chain( ... [ ... "envelope_detect", ... {"name": "normalize", "params": {"output_range": (0, 1)}}, ... SomeCustomOperation(), ... ] ... ) """ chain = [] for operation in operation_chain: # Handle already instantiated Operation or Pipeline objects if isinstance(operation, (Operation, Pipeline)): chain.append(operation) continue assert isinstance(operation, (str, dict, Config)), ( f"Operation {operation} should be a string, dict, Config object, Operation, or Pipeline" ) if isinstance(operation, str): operation_instance = get_ops(operation)() else: if isinstance(operation, Config): operation = operation.serialize() params = operation.get("params", {}) op_name = operation.get("name") operation_cls = get_ops(op_name) # Check for nested operations at the same level as params if "operations" in operation: nested_operations = make_operation_chain(operation["operations"]) # Instantiate pipeline-type operations with nested operations if issubclass(operation_cls, Beamform): # some pipelines, such as `zea.ops.Beamformer`, are initialized # not with a list of operations but with other parameters that then # internally create a list of operations operation_instance = operation_cls(**params) elif issubclass(operation_cls, Pipeline): # in most cases we want to pass an operations list to # initialize a pipeline operation_instance = operation_cls(operations=nested_operations, **params) else: operation_instance = operation_cls(operations=nested_operations, **params) else: operation_instance = operation_cls(**params) chain.append(operation_instance) return chain def pipeline_from_config(config: Config, **kwargs) -> Pipeline: """ Create a Pipeline instance from a Config object. The config must have a top-level ``pipeline`` key containing an ``operations`` list. """ if "pipeline" not in config: top_keys = list(config.keys()) if hasattr(config, "keys") else [] raise ValueError( f"Cannot build Pipeline: missing top-level 'pipeline' key.\n" f"Expected a config with the format:\n" f" pipeline:\n" f" operations:\n" f" - <operation_name>\n" f" - ...\n" f"Found top-level keys: {top_keys}" ) # Unwrap the pipeline subsection from a full config config = Config(config["pipeline"]) if "operations" not in config: top_keys = list(config.keys()) if hasattr(config, "keys") else [] raise ValueError( f"Cannot build Pipeline: missing 'operations' key.\n" f"Expected a config with the format:\n" f" pipeline:\n" f" operations:\n" f" - <operation_name>\n" f" - ...\n" f"Found top-level keys: {top_keys}" ) if not isinstance(config.operations, (list, np.ndarray)): raise ValueError( f"Cannot build Pipeline: 'operations' must be a list, " f"got {type(config.operations).__name__}." ) operations = make_operation_chain(config.operations) # merge pipeline config without operations with kwargs pipeline_config = config.copy() pipeline_config.pop("operations") kwargs = {**pipeline_config, **kwargs} return Pipeline(operations=operations, **kwargs) def pipeline_from_json(json_string: str, **kwargs) -> Pipeline: """ Create a Pipeline instance from a JSON string. """ pipeline_config = Config(json.loads(json_string, cls=ZEADecoderJSON)) return pipeline_from_config(pipeline_config, **kwargs) @deprecated(replacement="Pipeline.from_path") def pipeline_from_yaml(yaml_path: str, **kwargs) -> Pipeline: # pragma: no cover """ Create a Pipeline instance from a YAML file. .. deprecated:: Use :meth:`Pipeline.from_path` instead. """ with open(yaml_path, "r", encoding="utf-8") as f: pipeline_config = yaml.safe_load(f) return pipeline_from_config(Config(pipeline_config), **kwargs) def _pipeline_to_serializable_dict(pipeline: Pipeline, compact=True) -> dict: """Convert a Pipeline to a dict suitable for serialization. The output format is ``{"pipeline": {"operations": [...], ...pipeline_kwargs}}`` which can be loaded back via ``pipeline_from_config``. """ pipeline_dict = { "operations": Pipeline._pipeline_to_list(pipeline, compact=compact), } if compact: if not pipeline.with_batch_dim: pipeline_dict["with_batch_dim"] = pipeline.with_batch_dim if pipeline.jit_options != "ops": pipeline_dict["jit_options"] = pipeline.jit_options if pipeline._user_jit_kwargs: pipeline_dict["jit_kwargs"] = pipeline._user_jit_kwargs if pipeline.name != "pipeline": pipeline_dict["name"] = pipeline.name else: pipeline_dict.update( { "with_batch_dim": pipeline.with_batch_dim, "jit_options": pipeline.jit_options, "jit_kwargs": pipeline._user_jit_kwargs, "name": pipeline.name, } ) return {"pipeline": pipeline_dict} def pipeline_to_config(pipeline: Pipeline, compact=True) -> Config: """ Convert a Pipeline instance into a Config object. """ return Config(_pipeline_to_serializable_dict(pipeline, compact=compact)) def pipeline_to_json(pipeline: Pipeline, compact=True) -> str: """ Convert a Pipeline instance into a JSON string. """ return json.dumps( _pipeline_to_serializable_dict(pipeline, compact=compact), cls=ZEAEncoderJSON, indent=4, ) def pipeline_to_yaml(pipeline: Pipeline, file_path: str, compact=True) -> None: """ Convert a Pipeline instance into a YAML file. """ with open(file_path, "w", encoding="utf-8") as f: yaml.safe_dump( _pipeline_to_serializable_dict(pipeline, compact=compact), f, indent=4, )