import json
from typing import TYPE_CHECKING, Any, Dict, List, Union
import keras
import numpy as np
import yaml
from keras import ops
from zea import backend, log
from zea.backend import func_on_device, jit
from zea.config import Config
from zea.func.tensor import vmap
from zea.func.ultrasound import channels_to_complex, complex_to_channels
from zea.internal.core import DataTypes, ZEADecoderJSON, ZEAEncoderJSON, dict_to_tensor
from zea.internal.core import Object as ZEAObject
from zea.internal.registry import beamformer_registry, ops_registry
from zea.internal.utils import deprecated
from zea.ops.base import Operation, get_ops
from zea.ops.keras_ops import Cast
from zea.ops.tensor import Normalize
from zea.ops.ultrasound import (
ApplyWindow,
Demodulate,
EnvelopeDetect,
LogCompress,
PfieldWeighting,
ReshapeGrid,
TOFCorrection,
)
from zea.utils import FunctionTimer
if TYPE_CHECKING:
# Imported lazily at runtime (inside prepare_parameters) to avoid a circular
# import: zea.scan imports the data specs, which can pull in this module.
from zea.scan import Parameters
[docs]
@ops_registry("pipeline")
class Pipeline:
"""Pipeline class for processing ultrasound data through a series of operations."""
def __init__(
self,
operations: List[Operation],
with_batch_dim: bool = True,
jit_options: Union[str, None] = "ops",
jit_kwargs: dict | None = None,
name="pipeline",
validate=True,
timed: bool = False,
device: Union[str, None] = None,
):
"""
Initialize a pipeline.
Args:
operations (list): A list of Operation instances representing the operations
to be performed.
with_batch_dim (bool, optional): Whether operations should expect a batch dimension.
Defaults to True.
jit_options (str, optional): The JIT options to use. Must be "pipeline", "ops", or None.
- "pipeline": compiles the entire pipeline as a single function.
This may be faster but does not preserve python control flow, such as caching.
- "ops": compiles each operation separately. This preserves python control flow and
caching functionality, but speeds up the operations.
- None: disables JIT compilation.
Defaults to "ops".
jit_kwargs (dict, optional): Additional keyword arguments for the JIT compiler.
name (str, optional): The name of the pipeline. Defaults to "pipeline".
validate (bool, optional): Whether to validate the pipeline. Defaults to True.
timed (bool, optional): Whether to time each operation. Defaults to False.
device (str, optional): Default device for all pipeline calls, e.g.
``'cpu'``, ``'gpu:0'``, ``'cuda:1'``. Can be overridden per-call
by passing ``device=`` to ``__call__``. Uses
:func:`zea.backend.func_on_device` under the hood, which moves
input tensors to the device for the ``torch`` backend and wraps
the call in a device context for JAX / TensorFlow. Defaults to
``None`` (no device placement).
"""
self._call_pipeline = self.call
self.name = name
self._pipeline_layers = operations
if jit_options not in ["pipeline", "ops", None]:
raise ValueError("jit_options must be 'pipeline', 'ops', or None")
self.with_batch_dim = with_batch_dim
self._validate_flag = validate
# Setup timer
if jit_options == "pipeline" and timed:
raise ValueError(
"timed=True cannot be used with jit_options='pipeline' as the entire "
"pipeline is compiled into a single function. Try setting jit_options to "
"'ops' or None."
)
if timed:
log.warning(
"Timer has been initialized for the pipeline. To get an accurate timing estimate, "
"the `block_until_ready()` is used, which will slow down the execution, so "
"do not use for regular processing!"
)
self._callable_layers = self._get_timed_operations()
else:
self._callable_layers = self._pipeline_layers
self._timed = timed
if validate:
self.validate()
else:
log.warning("Pipeline validation is disabled, make sure to validate manually.")
if jit_kwargs is None:
jit_kwargs = {}
self._user_jit_kwargs = jit_kwargs.copy()
if keras.backend.backend() == "jax" and self.static_params != []:
existing = jit_kwargs.get("static_argnames", [])
if isinstance(existing, str):
existing = [existing]
jit_kwargs = {
**jit_kwargs,
"static_argnames": list(set(existing) | set(self.static_params)),
}
self.jit_kwargs = jit_kwargs
self.jit_options = jit_options # will handle the jit compilation
self.device = device
self._logged_difference_keys = False
# Do not log again for nested pipelines
for nested_pipeline in self._nested_pipelines:
nested_pipeline._logged_difference_keys = True
[docs]
def needs(self, key) -> bool:
"""Check if the pipeline needs a specific key at the input."""
return key in self.needs_keys
@property
def _nested_pipelines(self):
return [operation for operation in self.operations if isinstance(operation, Pipeline)]
@property
def output_keys(self) -> set:
"""All output keys the pipeline guarantees to produce."""
output_keys = set()
for operation in self.operations:
output_keys.update(operation.output_keys)
return output_keys
@property
def valid_keys(self) -> set:
"""Get a set of valid keys for the pipeline.
This is all keys that can be passed to the pipeline as input.
"""
valid_keys = set()
for operation in self.operations:
valid_keys.update(operation.valid_keys)
return valid_keys
@property
def static_params(self) -> List[str]:
"""Get a list of static parameters for the pipeline."""
static_params = []
for operation in self.operations:
static_params.extend(operation.static_params)
return list(set(static_params))
@property
def needs_keys(self) -> set:
"""Get a set of all input keys needed by the pipeline.
Will keep track of keys that are already provided by previous operations.
"""
needs = set()
has_so_far = set()
previous_operation = None
for operation in self.operations:
if previous_operation is not None:
has_so_far.update(previous_operation.output_keys)
needs.update(operation.needs_keys - has_so_far)
previous_operation = operation
return needs
[docs]
@classmethod
def from_default(
cls,
beamformer="delay_and_sum",
num_patches=100,
baseband=False,
enable_pfield=False,
timed=False,
**kwargs,
) -> "Pipeline":
"""Create a default pipeline.
Args:
beamformer (str): Type of beamformer to use.
Currently supporting:
- "delay_and_sum"
- "delay_multiply_and_sum"
- "coherence_factor"
- "generalized_coherence_factor"
Defaults to "delay_and_sum".
num_patches (int): Number of patches for the PatchedGrid operation.
Defaults to 100. If you get an out of memory error, try to increase this number.
baseband (bool): If True, assume the input data is baseband (I/Q) data,
which has 2 channels (last dim). Defaults to False, which assumes RF data,
so input signal has a single channel dim and is still on carrier frequency.
enable_pfield (bool): If True, apply PfieldWeighting. Defaults to False.
This will calculate pressure field and only beamform the data to those locations.
timed (bool, optional): Whether to time each operation. Defaults to False.
**kwargs: Additional keyword arguments to be passed to the Pipeline constructor.
"""
operations = [Cast(dtype="float32")]
# Add the demodulate operation
if not baseband:
operations += [
ApplyWindow(),
Demodulate(),
]
# Add beamforming ops
operations.append(
Beamform(
beamformer=beamformer,
num_patches=num_patches,
enable_pfield=enable_pfield,
),
)
# Add display ops
operations += [
EnvelopeDetect(),
Normalize(),
LogCompress(),
]
return cls(operations, timed=timed, **kwargs)
[docs]
def copy(self) -> "Pipeline":
"""Create a copy of the pipeline."""
return Pipeline(
self._pipeline_layers.copy(),
with_batch_dim=self.with_batch_dim,
jit_options=self.jit_options,
jit_kwargs=self.jit_kwargs,
name=self.name,
validate=self._validate_flag,
timed=self._timed,
device=self.device,
)
[docs]
def reinitialize(self):
"""Reinitialize the pipeline in place."""
self.__init__(
self._pipeline_layers,
with_batch_dim=self.with_batch_dim,
jit_options=self.jit_options,
jit_kwargs=self.jit_kwargs,
name=self.name,
validate=self._validate_flag,
timed=self._timed,
device=self.device,
)
[docs]
def prepend(self, operation: Operation):
"""Prepend an operation to the pipeline."""
self._pipeline_layers.insert(0, operation)
self.reinitialize()
[docs]
def append(self, operation: Operation):
"""Append an operation to the pipeline."""
self._pipeline_layers.append(operation)
self.reinitialize()
[docs]
def insert(self, index: int, operation: Operation):
"""Insert an operation at a specific index in the pipeline."""
if index < 0 or index > len(self._pipeline_layers):
raise IndexError("Index out of bounds for inserting operation.")
self._pipeline_layers.insert(index, operation)
self.reinitialize()
@property
def operations(self) -> List[Union[Operation, "Pipeline"]]:
"""Alias for self.layers to match the zea naming convention"""
return self._pipeline_layers
[docs]
def reset_timer(self):
"""Reset the timer for timed operations."""
if self._timed:
self._callable_layers = self._get_timed_operations()
else:
log.warning(
"Timer has not been initialized. Set timed=True when initializing the pipeline."
)
def _get_timed_operations(self):
"""Get a list of timed operations."""
self.timer = FunctionTimer()
return [self.timer(op, name=op.__class__.__name__) for op in self._pipeline_layers]
[docs]
def call(self, **inputs) -> Dict[str, Any]:
"""Process input data through the pipeline."""
for operation in self._callable_layers:
try:
outputs = operation(**inputs)
except KeyError as exc:
raise KeyError(
f"[zea.Pipeline] Operation '{operation.__class__.__name__}' "
f"requires input key '{exc.args[0]}', "
"but it was not provided in the inputs.\n"
"Check whether the objects (such as `zea.Parameters`) passed to "
"`pipeline.prepare_parameters()` contain all required keys.\n"
f"Current list of all passed keys: {list(inputs.keys())}\n"
f"Valid keys for this pipeline: {self.valid_keys}"
) from exc
except Exception as exc:
raise RuntimeError(
f"[zea.Pipeline] Error in operation '{operation.__class__.__name__}': {exc}"
)
inputs = outputs
return outputs
[docs]
def __call__(
self, return_numpy=False, device: Union[str, None] = None, **inputs
) -> Dict[str, Any]:
"""Process input data through the pipeline.
Args:
return_numpy (bool): If ``True``, convert output tensors to NumPy
arrays before returning.
device (str, optional): Device to run this call on, e.g.
``'cpu'``, ``'gpu:0'``, or ``'cuda:1'``. Overrides the
pipeline-level ``device`` set at construction time for this
single invocation. When ``None`` (default), the pipeline-level
``device`` attribute is used (which is also ``None`` by
default, meaning no explicit device placement).
**inputs: Tensor inputs forwarded to the operations.
"""
if any(key in inputs for key in ["probe", "scan", "config", "parameters"]) or any(
isinstance(arg, ZEAObject) for arg in inputs.values()
):
raise ValueError(
"Parameters (and Probe/Config) objects should be first processed with "
"`Pipeline.prepare_parameters` before calling the pipeline. "
"e.g. inputs = pipeline.prepare_parameters(parameters, **overrides)"
)
if any(isinstance(arg, str) for arg in inputs.values()):
raise ValueError(
"Pipeline does not support string inputs. "
"Please ensure all inputs are convertible to tensors."
)
if not self._logged_difference_keys:
difference_keys = set(inputs.keys()) - self.valid_keys
if difference_keys:
log.debug(
f"[zea.Pipeline] The following input keys are not used by the pipeline: "
f"{difference_keys}. Make sure this is intended. "
"This warning will only be shown once."
)
self._logged_difference_keys = True
## PROCESSING
_device = device if device is not None else self.device
if _device is not None:
outputs = func_on_device(self._call_pipeline, _device, **inputs)
else:
outputs = self._call_pipeline(**inputs)
## PREPARE OUTPUT
if return_numpy:
# Convert tensors to numpy arrays but preserve None values
outputs = {
k: ops.convert_to_numpy(v) if ops.is_tensor(v) else v for k, v in outputs.items()
}
return outputs
@property
def jit_options(self):
"""Get the jit_options property of the pipeline."""
return self._jit_options
[docs]
def set_jit(self, value: bool):
"""Set the JIT compilation for the pipeline."""
if value:
self._jit()
else:
self._unjit()
@jit_options.setter
def jit_options(self, value: Union[str, None]):
"""Set the jit_options property of the pipeline."""
assert value in ["pipeline", "ops", None], "jit_options must be 'pipeline', 'ops', or None"
self._jit_options = value
self.set_jit(value == "pipeline")
for operation in self.operations:
if isinstance(operation, Pipeline):
operation.jit_options = value
else:
operation.set_jit(value == "ops")
def _jit(self):
"""JIT compile the pipeline."""
if not self.jittable:
raise ValueError(
"Cannot JIT compile the pipeline because not all operations are jittable. "
f"The following operations are not jittable: {self.unjitable_ops}"
"Try setting jit_options to 'ops' or None."
)
self._call_pipeline = jit(self.call, **self.jit_kwargs)
def _unjit(self):
"""Un-JIT compile the pipeline."""
self._call_pipeline = self.call
@property
def jittable(self):
"""Check if all operations in the pipeline are jittable."""
return all(operation.jittable for operation in self.operations)
@property
def unjitable_ops(self):
"""Get a list of operations that are not jittable."""
return [operation for operation in self.operations if not operation.jittable]
@property
def with_batch_dim(self):
"""Get the with_batch_dim property of the pipeline."""
return self._with_batch_dim
@with_batch_dim.setter
def with_batch_dim(self, value):
"""Set the with_batch_dim property of the pipeline."""
self._with_batch_dim = value
for operation in self.operations:
operation.with_batch_dim = value
@property
def input_data_type(self):
"""Get the input_data_type property of the pipeline."""
return self.operations[0].input_data_type
@property
def output_data_type(self):
"""Get the output_data_type property of the pipeline."""
return self.operations[-1].output_data_type
[docs]
def validate(self):
"""Validate the pipeline by checking the compatibility of the operations."""
operations = self.operations
for i in range(len(operations) - 1):
if operations[i].output_data_type is None:
continue
if operations[i + 1].input_data_type is None:
continue
if operations[i].output_data_type != operations[i + 1].input_data_type:
raise ValueError(
f"Operation {operations[i].__class__.__name__} output data type "
f"({operations[i].output_data_type}) is not compatible "
f"with the input data type ({operations[i + 1].input_data_type}) "
f"of operation {operations[i + 1].__class__.__name__}"
)
[docs]
def set_params(self, **params):
"""Set parameters for the operations in the pipeline by adding them to the cache."""
for operation in self.operations:
operation_params = {
key: value for key, value in params.items() if key in operation.valid_keys
}
if operation_params:
operation.set_input_cache(operation_params)
[docs]
def get_params(self, per_operation: bool = False):
"""Get a snapshot of the current parameters of the operations in the pipeline.
Args:
per_operation (bool): If True, return a list of dictionaries for each operation.
If False, return a single dictionary with all parameters combined.
"""
if per_operation:
return [operation._input_cache.copy() for operation in self.operations]
else:
params = {}
for operation in self.operations:
params.update(operation._input_cache)
return params
def __str__(self):
"""String representation of the pipeline."""
operations = []
for operation in self.operations:
if isinstance(operation, Pipeline):
operations.append(f"{operation.__class__.__name__}({str(operation)})")
else:
operations.append(operation.__class__.__name__)
string = " -> ".join(operations)
return string
def __repr__(self):
"""String representation of the pipeline."""
operations = []
for operation in self.operations:
if isinstance(operation, Pipeline):
operations.append(repr(operation))
else:
operations.append(operation.__class__.__name__)
return f"Pipeline(name={self.name!r}, operations=[{', '.join(operations)}])"
[docs]
@classmethod
def load(cls, file_path: str, **kwargs) -> "Pipeline":
"""Load a pipeline from a JSON or YAML file."""
if file_path.endswith(".json"):
with open(file_path, "r", encoding="utf-8") as f:
json_str = f.read()
return pipeline_from_json(json_str, **kwargs)
elif file_path.endswith(".yaml") or file_path.endswith(".yml"):
return cls.from_path(file_path, **kwargs)
else:
raise ValueError("File must have extension .json, .yaml, or .yml")
[docs]
def get_dict(self, compact=True) -> dict:
"""Convert the pipeline to a dictionary.
Args:
compact (bool): If True (default), only include
parameters that differ from their defaults.
If False, include all parameters for full reproducibility.
"""
config = {"name": ops_registry.get_name(self)}
config["operations"] = self._pipeline_to_list(self, compact=compact)
if compact:
params = {}
if not self.with_batch_dim:
params["with_batch_dim"] = self.with_batch_dim
if self.jit_options != "ops":
params["jit_options"] = self.jit_options
if self._user_jit_kwargs:
params["jit_kwargs"] = self._user_jit_kwargs
if self.device is not None:
params["device"] = self.device
if params:
config["params"] = params
else:
config["params"] = {
"with_batch_dim": self.with_batch_dim,
"jit_options": self.jit_options,
"jit_kwargs": self._user_jit_kwargs,
"device": self.device,
}
return config
@staticmethod
def _pipeline_to_list(pipeline: "Pipeline", compact=True):
"""Convert the pipeline to a list of operations."""
ops_list = []
for op in pipeline.operations:
if isinstance(op, Pipeline):
ops_list.append(op.get_dict(compact=compact))
else:
d = op.get_dict(compact=compact)
if compact:
params = d.get("params", {})
# Strip with_batch_dim when it is merely inherited from the pipeline
if params.get("with_batch_dim") == pipeline.with_batch_dim:
params.pop("with_batch_dim", None)
# Strip jit_compile=False when it is implied by pipeline-level JIT
if not params.get("jit_compile", True) and pipeline.jit_options in (
None,
"pipeline",
):
params.pop("jit_compile", None)
if not params:
d.pop("params", None)
# Name-only dict → bare string shorthand
if list(d.keys()) == ["name"]:
d = d["name"]
ops_list.append(d)
return ops_list
[docs]
@classmethod
def from_config(cls, config: Dict, **kwargs) -> "Pipeline":
"""Create a pipeline from a dictionary or ``zea.Config`` object.
Args:
config (dict or Config): Configuration dictionary or ``zea.Config`` object.
Must have a ``pipeline`` key with a subkey ``operations``.
**kwargs: Additional keyword arguments to be passed to the pipeline.
Example:
.. doctest::
>>> from zea import Config, Pipeline
>>> config = Config(
... {
... "pipeline": {
... "operations": [
... "identity",
... ],
... }
... }
... )
>>> pipeline = Pipeline.from_config(config)
"""
return pipeline_from_config(Config(config), **kwargs)
[docs]
@classmethod
def from_path(cls, file_path: str, revision: str = None, **kwargs) -> "Pipeline":
"""Create a pipeline from a YAML/config file path.
Args:
file_path (str): Path to the config file (local or ``hf://`` URI).
Must have a ``pipeline`` key with a subkey ``operations``.
revision (str, optional): Revision of the config file (for Hugging Face ``hf://`` URIs).
**kwargs: Additional keyword arguments to be passed to the pipeline.
Example:
.. doctest::
>>> from zea import Config, Pipeline
>>> config = Config(
... {
... "pipeline": {
... "operations": [
... "identity",
... ],
... }
... }
... )
>>> config.to_yaml("pipeline.yaml")
>>> pipeline = Pipeline.from_path("pipeline.yaml")
"""
config = Config.from_path(file_path, revision=revision)
return pipeline_from_config(config, **kwargs)
[docs]
@classmethod
@deprecated(replacement="Pipeline.from_path")
def from_yaml(cls, file_path: str, **kwargs) -> "Pipeline":
"""Deprecated. Use :meth:`from_path` instead."""
return pipeline_from_yaml(file_path, **kwargs)
[docs]
@classmethod
def from_json(cls, json_string: str, **kwargs) -> "Pipeline":
"""Create a pipeline from a JSON string.
Args:
json_string (str): JSON string representing the pipeline.
Must have a ``pipeline`` key with a subkey ``operations``.
**kwargs: Additional keyword arguments to be passed to the pipeline.
Example:
```python
json_string = '{"pipeline": {"operations": ["identity"]}}'
pipeline = Pipeline.from_json(json_string)
```
"""
return pipeline_from_json(json_string, **kwargs)
[docs]
def to_config(self, compact=True) -> Config:
"""Convert the pipeline to a `zea.Config` object."""
return pipeline_to_config(self, compact=compact)
[docs]
def to_json(self, compact=True) -> str:
"""Convert the pipeline to a JSON string."""
return pipeline_to_json(self, compact=compact)
[docs]
def to_yaml(self, file_path: str, compact=True) -> None:
"""Convert the pipeline to a YAML file."""
pipeline_to_yaml(self, file_path, compact=compact)
@property
def key(self) -> str:
"""Input key of the pipeline."""
return self.operations[0].key
@property
def output_key(self) -> str:
"""Output key of the pipeline."""
return self.operations[-1].output_key
def __eq__(self, other):
"""Check if two pipelines are equal."""
if not isinstance(other, Pipeline):
return False
# Compare the operations in both pipelines
if len(self.operations) != len(other.operations):
return False
for op1, op2 in zip(self.operations, other.operations):
if not op1 == op2:
return False
return True
[docs]
def prepare_parameters(
self,
parameters: "Parameters" = None,
device: Union[str, None] = None,
**overrides,
):
"""Prepare a :class:`~zea.Parameters` object for the pipeline.
Converts the (validated and derived) parameters needed by this
pipeline's operations into a dictionary of tensors, then overlays any
manually supplied overrides (e.g. ``config.parameters`` or ad-hoc
keyword arguments). Overrides take priority over the values in
``parameters``.
Args:
parameters: :class:`~zea.Parameters` object. Only the keys
this pipeline ``needs`` (and that are not overridden) are
converted, so derivation is lazy and minimal.
device: Device to place the tensors on. Defaults to the pipeline
device.
**overrides: Additional parameters to include in the inputs
(converted to tensors). These overwrite values taken from
``parameters``.
Returns:
dict: Dictionary of inputs with all values as tensors.
Example:
.. code-block:: python
inputs = pipeline.prepare_parameters(parameters, **config.parameters)
outputs = pipeline(data=raw_data, **inputs)
"""
from zea.scan import Parameters
_device = device if device is not None else self.device
params_dict = {}
override_keys = set(overrides.keys())
if parameters is not None:
assert isinstance(parameters, Parameters), (
f"Expected an instance of `zea.Parameters`, got {type(parameters)}"
)
# Only convert keys the pipeline needs and that are not overridden,
# so we avoid deriving unnecessary parameters.
needs_keys = self.needs_keys - override_keys
with backend.device(_device):
params_dict = parameters.to_tensor(
include=needs_keys, keep_as_is=self.static_params
)
# Convert all overrides to tensors
with backend.device(_device):
tensor_overrides = dict_to_tensor(overrides, keep_as_is=self.static_params)
# Overrides overwrite values taken from the parameters object.
return {**params_dict, **tensor_overrides}
[docs]
@ops_registry("map")
class Map(Pipeline):
"""
A pipeline that maps its operations over specified input arguments.
This can be used to reduce memory usage by processing data in chunks.
Notes
-----
- When `chunks` and `batch_size` are both None (default), this behaves like a normal Pipeline.
- Changing anything other than ``self.output_key`` in the dict will not be propagated.
- Will be jitted as a single operation, not the individual operations.
- This class handles the batching.
For more information on how to use ``in_axes``, ``out_axes``, `see the documentation for
jax.vmap <https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html>`_.
Example
-------
.. doctest::
>>> from zea.ops import Map, Pipeline, Demodulate, TOFCorrection
>>> # apply operations in batches of 8
>>> # in this case, over the first axis of "data"
>>> # or more specifically, process 8 transmits at a time
>>> pipeline_mapped = Map(
... [
... Demodulate(),
... TOFCorrection(),
... ],
... argnames="data",
... batch_size=8,
... )
>>> # you can also map a subset of the operations
>>> # for example, demodulate in 4 chunks
>>> # or more specifically, split the transmit axis into 4 parts
>>> pipeline_mapped = Pipeline(
... [
... Map([Demodulate()], argnames="data", chunks=4),
... TOFCorrection(),
... ],
... )
"""
def __init__(
self,
operations: List[Operation],
argnames: List[str] | str,
in_axes: List[Union[int, None]] | int = 0,
out_axes: List[Union[int, None]] | int = 0,
chunks: int | None = None,
batch_size: int | None = None,
**kwargs,
):
"""
Args:
operations (list): List of operations to be performed.
argnames (str or list): List of argument names (or keys) to map over.
Can also be a single string if only one argument is mapped over.
in_axes (int or list): Axes to map over for each argument.
If a single int is provided, it is used for all arguments.
out_axes (int or list): Axes to map over for each output.
If a single int is provided, it is used for all outputs.
chunks (int, optional): Number of chunks to split the input data into.
If None, no chunking is performed. Mutually exclusive with ``batch_size``.
batch_size (int, optional): Size of batches to process at once.
If None, no batching is performed. Mutually exclusive with ``chunks``.
"""
super().__init__(operations, **kwargs)
if batch_size is not None and chunks is not None:
raise ValueError(
"batch_size and chunks are mutually exclusive. Please specify only one."
)
if batch_size is not None and batch_size <= 0:
raise ValueError("batch_size must be a positive integer.")
if chunks is not None and chunks <= 0:
raise ValueError("chunks must be a positive integer.")
if isinstance(argnames, str):
argnames = [argnames]
self.argnames = argnames
self.in_axes = in_axes
self.out_axes = out_axes
self.chunks = chunks
self.batch_size = batch_size
if chunks is None and batch_size is None:
log.warning(
"[zea.ops.Map] Both `chunks` and `batch_size` are None. "
"This will behave like a normal Pipeline. "
"Consider setting one of them to process data in chunks or batches."
)
[docs]
def call_item(self, **inputs):
"""Process data in patches."""
mapped_args = []
for argname in self.argnames:
mapped_args.append(inputs.pop(argname, None))
def patched_call(*args):
mapped_kwargs = [(k, v) for k, v in zip(self.argnames, args)]
out = super(Map, self).call(**dict(mapped_kwargs), **inputs)
# TODO: maybe it is possible to output everything?
# e.g. prepend a empty dimension to all inputs and just map over everything?
return out[self.output_key]
out = vmap(
patched_call,
in_axes=self.in_axes,
out_axes=self.out_axes,
chunks=self.chunks,
batch_size=self.batch_size,
fn_supports_batch=True,
disable_jit=not bool(self.jit_options),
)(*mapped_args)
return out
@property
def jit_options(self):
"""Get the jit_options property of the pipeline."""
return self._jit_options
@jit_options.setter
def jit_options(self, value: Union[str, None]):
"""Set the jit_options property of the pipeline."""
assert value in ["pipeline", "ops", None], "jit_options must be 'pipeline', 'ops', or None"
self._jit_options = value
self.set_jit(value == "pipeline" or value == "ops")
for operation in self.operations:
if isinstance(operation, Pipeline):
operation.jit_options = None
else:
operation.set_jit(False)
def _jit(self):
"""JIT compile the pipeline."""
self._jittable_call = jit(self.jittable_call, **self.jit_kwargs)
def _unjit(self):
"""Un-JIT compile the pipeline."""
self._jittable_call = self.jittable_call
@property
def with_batch_dim(self):
"""Get the with_batch_dim property of the pipeline."""
return self._with_batch_dim
@with_batch_dim.setter
def with_batch_dim(self, value):
"""Set the with_batch_dim property of the pipeline.
The class handles the batching so the operations have to be set to False."""
self._with_batch_dim = value
for operation in self.operations:
operation.with_batch_dim = False
[docs]
def jittable_call(self, **inputs):
"""Process input data through the pipeline."""
if self._with_batch_dim:
input_data = inputs.pop(self.key)
output = ops.map(
lambda x: self.call_item(**{self.key: x, **inputs}),
input_data,
)
else:
output = self.call_item(**inputs)
return {self.output_key: output}
[docs]
def call(self, **inputs):
"""Process input data through the pipeline."""
output = self._jittable_call(**inputs)
inputs.update(output)
return inputs
[docs]
def get_dict(self, compact=True):
"""Get the configuration of the pipeline."""
config = super().get_dict(compact=compact)
config["name"] = "map"
params = config.get("params", {})
params["argnames"] = self.argnames
if not compact or self.in_axes != 0:
params["in_axes"] = self.in_axes
if not compact or self.out_axes != 0:
params["out_axes"] = self.out_axes
if not compact or self.chunks is not None:
params["chunks"] = self.chunks
if not compact or self.batch_size is not None:
params["batch_size"] = self.batch_size
config["params"] = params
return config
[docs]
@ops_registry("patched_grid")
class PatchedGrid(Map):
"""
A pipeline that maps its operations over `flatgrid` and `flat_pfield` keys.
This can be used to reduce memory usage by processing data in chunks.
For more information and flexibility, see :class:`zea.ops.Map`.
"""
def __init__(self, *args, num_patches=10, **kwargs):
super().__init__(*args, argnames=["flatgrid", "flat_pfield"], chunks=num_patches, **kwargs)
self.num_patches = num_patches
[docs]
def get_dict(self, compact=True):
"""Get the configuration of the pipeline."""
config = super().get_dict(compact=compact)
config["name"] = "patched_grid"
params = config.get("params", {})
params.pop("argnames", None)
params.pop("chunks", None)
params["num_patches"] = self.num_patches
config["params"] = params
return config
[docs]
@beamformer_registry("delay_and_sum")
@ops_registry("delay_and_sum")
class DelayAndSum(Operation):
"""Sums time-delayed signals along channels and transmits."""
def __init__(self, **kwargs):
super().__init__(
input_data_type=DataTypes.ALIGNED_DATA,
output_data_type=DataTypes.BEAMFORMED_DATA,
**kwargs,
)
[docs]
def call(self, **kwargs):
"""Performs DAS beamforming on tof-corrected input.
Args:
tof_corrected_data (ops.Tensor): The TOF corrected input of shape
`(n_tx, prod(grid.shape), n_el, n_ch)` with optional batch dimension.
Returns:
dict: Dictionary containing beamformed_data
of shape `(prod(grid.shape), n_ch)`
with optional batch dimension.
"""
data = kwargs[self.key]
# Sum over the channels (n_el), i.e. DAS
beamformed_data = ops.sum(data, -2)
# Sum over transmits (n_tx), i.e. Compounding
beamformed_data = ops.sum(beamformed_data, -3)
return {self.output_key: beamformed_data}
[docs]
@beamformer_registry("delay_multiply_and_sum")
@ops_registry("delay_multiply_and_sum")
class DelayMultiplyAndSum(Operation):
"""Performs the operations for the Delay-Multiply-and-Sum beamformer except the delay.
The delay should be performed by the TOF correction operation.
"""
def __init__(self, **kwargs):
super().__init__(
input_data_type=DataTypes.ALIGNED_DATA,
output_data_type=DataTypes.BEAMFORMED_DATA,
**kwargs,
)
[docs]
def process_image(self, data):
"""Performs DMAS beamforming on tof-corrected input.
Args:
data (ops.Tensor): The TOF corrected input of shape `(n_tx, n_pix, n_el, n_ch)`
Returns:
ops.Tensor: The beamformed data of shape `(n_pix, n_ch)`
"""
if not data.shape[-1] == 2:
raise ValueError(
"MultiplyAndSum operation requires IQ data with 2 channels. "
f"Got data with shape {data.shape}."
)
# Compute the correlation matrix
data = channels_to_complex(data)
data = self._multiply(data)
data = self._select_lower_triangle(data)
data = ops.sum(data, axis=(0, 2, 3))
data = complex_to_channels(data)
return data
def _select_lower_triangle(self, data):
"""Select only the lower triangle of the correlation matrix."""
n_el = data.shape[3]
mask = ops.ones((n_el, n_el), dtype=data.dtype) - ops.eye(n_el, dtype=data.dtype)
data = data * mask[None, None, :, :] / 2
return data
def _multiply(self, data):
"""Apply the DMAS multiplication step."""
channel_products = data[:, :, :, None] * data[:, :, None, :]
# Signed square root: sign(z) * sqrt(|z|) == z / sqrt(|z|).
# Written this way to avoid ops.sign on complex data (torch.sign
# does not support complex numbers; use torch.sgn instead).
eps = keras.backend.epsilon()
safe_sqrt = ops.cast(ops.sqrt(ops.abs(channel_products) + eps**2), channel_products.dtype)
data = channel_products / safe_sqrt
return data
[docs]
def call(self, **kwargs):
"""Performs DMAS beamforming on tof-corrected input.
Args:
tof_corrected_data (ops.Tensor): The TOF corrected input of shape
`(n_tx, prod(grid.shape), n_el, n_ch)` with optional batch dimension.
Returns:
dict: Dictionary containing beamformed_data
of shape `(grid_size_z*grid_size_x, n_ch)`
with optional batch dimension.
"""
data = kwargs[self.key]
if not self.with_batch_dim:
beamformed_data = self.process_image(data)
else:
# Apply process_image to each item in the batch
beamformed_data = ops.map(self.process_image, data)
return {self.output_key: beamformed_data}
[docs]
@beamformer_registry("coherence_factor")
@ops_registry("coherence_factor")
class CoherenceFactor(Operation):
r"""Coherence Factor (CF) Beamformer.
The Coherence Factor is a pixel-dependent weight used to quantify the focus
quality of the beamformed signal. It is the ratio of the coherent power to
the incoherent power of the signals received across the transducer aperture.
For a set of delayed signals :math:`x_i` across :math:`N` elements:
.. math::
\mathrm{CF} = \frac{\left|\sum_{i=1}^{N} x_i\right|^2}
{N \sum_{i=1}^{N} \left|x_i\right|^2}
The CF ranges from 0 (completely incoherent) to 1 (perfectly coherent).
The beamformed output is the standard DAS sum weighted by CF per transmit,
then compounded across transmits.
.. admonition:: Reference
Hollman, K. W., Rigby, K. W., & O'Donnell, M. (1999).
Coherence factor of speckle from a multi-row probe. IEEE Ultrasonics Symposium.
Args:
**kwargs: Additional arguments passed to the Operation base class.
"""
def __init__(self, **kwargs):
super().__init__(
input_data_type=DataTypes.ALIGNED_DATA,
output_data_type=DataTypes.BEAMFORMED_DATA,
**kwargs,
)
[docs]
def process_image(self, data):
"""Applies CF weighting and compounding on tof-corrected input.
Args:
data (ops.Tensor): TOF-corrected input of shape
``(n_tx, n_pix, n_el, n_ch)``, with optional batch dimension.
Returns:
ops.Tensor: Beamformed image of shape ``(n_pix, n_ch)``,
with optional batch dimension.
"""
n_el = ops.shape(data)[-2]
# DAS per transmit: sum over elements
das_per_tx = ops.sum(data, axis=-2)
# Coherent power: |sum_i(x_i)|^2, works for both RF (n_ch=1) and IQ (n_ch=2)
coherent_power = ops.sum(ops.square(das_per_tx), axis=-1, keepdims=True)
# Incoherent power: N * sum_i(|x_i|^2)
incoherent_power = n_el * ops.sum(
ops.sum(ops.square(data), axis=-1), axis=-1, keepdims=True
)
# CF weight, clipped to [0, 1] by construction when incoherent_power > 0
cf_weight = coherent_power / (incoherent_power + keras.backend.epsilon())
# Apply weight per transmit, then compound
return ops.sum(das_per_tx * cf_weight, axis=-3)
[docs]
def call(self, **kwargs):
"""Performs CF beamforming on tof-corrected input.
Args:
tof_corrected_data (ops.Tensor): TOF-corrected input of shape
``(n_tx, n_pix, n_el, n_ch)``, with optional batch dimension.
Returns:
dict: Dictionary containing beamformed data of shape
``(n_pix, n_ch)``, with optional batch dimension.
"""
data = kwargs[self.key]
return {self.output_key: self.process_image(data)}
[docs]
@beamformer_registry("generalized_coherence_factor")
@ops_registry("generalized_coherence_factor")
class GeneralizedCoherenceFactor(Operation):
r"""Generalized Coherence Factor (GCF) Beamformer.
The GCF is a coherence-based adaptive weighting technique used to improve
the quality of ultrasound images by suppressing sidelobes and clutter.
It is defined as the ratio of the power within a low-frequency region of the
spatial spectrum to the total power across the aperture.
For a given pixel, let :math:`A(k)` be the spatial Fourier transform of the
delayed channel data across :math:`N` elements. The GCF is:
.. math::
\mathrm{GCF} = \frac{\sum_{k \in \mathcal{M}_0} \left|A(k)\right|^2}
{\sum_{k=0}^{N-1} \left|A(k)\right|^2}
where :math:`\mathcal{M}_0 = \{k : k \leq m_0\} \cup \{k : k \geq N - m_0\}`
is the low spatial-frequency region controlled by :math:`m_0`.
.. admonition:: Reference
Li, P. C., & Li, M. L. (2003).
"Adaptive imaging using the generalized coherence factor."
IEEE Transactions on Ultrasonics, Ferroelectrics, and Frequency Control,
50(2), 128-141.
Args:
m_zero (int): Cutoff frequency index for the low-frequency spatial region.
Higher values increase tolerance to phase aberrations. Defaults to ``4``.
**kwargs: Additional arguments passed to the Operation base class.
"""
def __init__(self, m_zero=4, **kwargs):
if not isinstance(m_zero, int) or m_zero < 0:
raise ValueError(f"m_zero must be a non-negative integer, got {m_zero!r}.")
super().__init__(
input_data_type=DataTypes.ALIGNED_DATA,
output_data_type=DataTypes.BEAMFORMED_DATA,
**kwargs,
)
self.m_zero = m_zero
[docs]
def process_image(self, data, m_zero=None):
"""Applies GCF weighting and compounding on tof-corrected input.
Args:
data (ops.Tensor): TOF-corrected input of shape
``(n_tx, n_pix, n_el, n_ch)``, with optional batch dimension.
m_zero (int, optional): Overrides the instance ``m_zero`` for this call.
Returns:
ops.Tensor: Beamformed image of shape ``(n_pix, n_ch)``,
with optional batch dimension.
"""
if m_zero is None:
m_zero = self.m_zero
n_el = ops.shape(data)[-2]
n_ch = data.shape[-1] # static Python int — safe for branching
# Move n_el to last axis for spatial FFT: (..., n_tx, n_pix, n_ch, n_el)
spatial_data = ops.moveaxis(data, -2, -1)
real_in = spatial_data[..., 0, :]
imag_in = spatial_data[..., 1, :] if n_ch == 2 else ops.zeros_like(real_in)
# Spatial FFT power spectrum across elements
real_fft, imag_fft = ops.fft((real_in, imag_in))
power_spectrum = ops.square(real_fft) + ops.square(imag_fft)
# Total energy and low-frequency energy
total_energy = ops.sum(power_spectrum, axis=-1, keepdims=True)
indices = ops.arange(n_el)
low_freq_mask = ops.logical_or(
ops.less_equal(indices, m_zero),
ops.greater_equal(indices, n_el - m_zero),
)
low_freq_energy = ops.sum(
ops.where(low_freq_mask, power_spectrum, 0.0),
axis=-1,
keepdims=True,
)
# GCF weight
gcf_weight = low_freq_energy / (total_energy + keras.backend.epsilon())
# DAS per transmit, apply weight, then compound
das_per_tx = ops.sum(data, axis=-2)
return ops.sum(das_per_tx * gcf_weight, axis=-3)
[docs]
def call(self, m_zero=None, **kwargs):
"""Performs GCF beamforming on tof-corrected input.
Args:
m_zero (int, optional): Cutoff frequency index, overrides the
instance default when provided via pipeline parameters.
tof_corrected_data (ops.Tensor): TOF-corrected input of shape
``(n_tx, n_pix, n_el, n_ch)``, with optional batch dimension.
Returns:
dict: Dictionary containing beamformed data of shape
``(n_pix, n_ch)``, with optional batch dimension.
"""
data = kwargs[self.key]
return {self.output_key: self.process_image(data, m_zero=m_zero)}
[docs]
@ops_registry("refocus")
class Refocus(Operation):
r"""REFoCUS (Retrospective Encoding For Conventional Ultrasound Sequences).
Decodes any transmit data into synthetic aperture
(multistatic / full-matrix capture) data by inverting the transmit
encoding model in the frequency domain.
The transmit encoding is modelled as a matrix :math:`H` whose entry
:math:`H_{t,e}` describes the complex phase shift applied to element
:math:`e` during transmit event :math:`t`:
.. math::
H_{t,e}(f) = a_{t,e} \exp(-j 2\pi f \tau_{t,e})
where :math:`\tau_{t,e}` is the transmit delay in samples and
:math:`a_{t,e}` is the apodization.
At each temporal frequency the received RF spectrum is decoded by
multiplying with the pseudo-inverse :math:`H^{-1}`:
.. math::
\hat{U}(f) = H^{-1}(f) \, S(f)
producing a synthetic aperture dataset where each decoded channel
corresponds to a virtual single-element transmission.
The **input** data has shape ``(n_tx, n_ax, n_el, n_ch)`` and the
**output** has shape ``(n_el, n_ax, n_el, n_ch)``, where the new first
axis indexes the decoded virtual transmit elements.
.. admonition:: References
Bottenus, N. (2018).
"Recovery of the complete data set from focused transmit beams."
*IEEE Transactions on Ultrasonics, Ferroelectrics, and Frequency
Control*, 65(1), 30–38.
Ali, R., Dahl, J., & Bottenus, N. (2019).
"Extending Retrospective Encoding for Robust Recovery of the Multistatic Dataset."
*IEEE Transactions on Ultrasonics, Ferroelectrics, and Frequency
Control*, 67(5), 943–956.
https://github.com/nbottenus/REFoCUS
Args:
method (str): Inversion method. One of:
- ``'adjoint'``: Adjoint (matched-filter) pseudo-inverse with
an optional ramp filter in frequency. Default.
- ``'tikhonov'``: Tikhonov-regularized inverse.
- ``'rsvd'``: Regularized SVD-based inverse.
- ``'tsvd'``: Truncated SVD-based inverse.
param (float or None): Regularization / filter parameter.
- ``'adjoint'``: ``None`` applies a ramp filter (multiply by
:math:`f`). Set to ``0`` to disable the ramp filter. Defaults to ``None``.
- ``'tikhonov'``, ``'rsvd'``, ``'tsvd'``: Relative regularization
strength. Defaults to ``1e-2`` when ``None``.
**kwargs: Additional arguments forwarded to
:class:`~zea.ops.Operation`.
"""
_VALID_METHODS = ("adjoint", "tikhonov", "rsvd", "tsvd")
def __init__(self, method="adjoint", param=None, **kwargs):
if method not in self._VALID_METHODS:
raise ValueError(f"method must be one of {self._VALID_METHODS}, got '{method}'")
# SVD is not supported by TF XLA; disable JIT for SVD-based methods.
if method != "adjoint":
if kwargs.get("jit_compile", True):
log.warning(
f"Refocus method='{method}' uses SVD, which is not supported by the XLA "
"JIT compiler. Setting jit_compile=False."
)
kwargs["jit_compile"] = False
super().__init__(
input_data_type=DataTypes.RAW_DATA,
output_data_type=DataTypes.RAW_DATA,
**kwargs,
)
self.method = method
self.param = param
def _get_hinv(self, delays, f_vec, apod):
"""Compute batched Hinv for all normalised frequencies at once.
Args:
delays: ``(n_tx, n_el)`` delays in samples.
f_vec: ``(n_freq,)`` normalised frequencies (cycles/sample).
apod: ``(n_tx, n_el)`` apodization.
Returns:
Hinv: ``(n_freq, n_el, n_tx)`` complex64 tensor.
"""
# H: (n_freq, n_tx, n_el)
f_c = ops.cast(f_vec[:, None, None], "complex64")
d_c = ops.cast(delays[None], "complex64")
a_c = ops.cast(apod[None], "complex64")
H = a_c * ops.exp(ops.cast(-1j * 2 * np.pi, "complex64") * f_c * d_c)
if self.method == "adjoint":
# param=None → ramp filter (multiply by f)
# param=0 → no ramp (multiply by 1, plain adjoint)
Hinv = ops.conj(ops.transpose(H, (0, 2, 1)))
ramp_vals = f_vec if self.param is None else ops.ones_like(f_vec)
ramp = ops.cast(ramp_vals, "complex64")[:, None, None]
return ramp * Hinv
# SVD-based methods
U, s, VH = ops.svd(H, full_matrices=False)
lam = self.param if self.param is not None else 1e-2
if self.method in ("tikhonov", "rsvd"):
sinv = s / (s**2 + (lam * s[:, 0:1]) ** 2)
else: # tsvd
threshold = lam * s[:, 0:1]
safe_s = ops.where(s >= threshold, s, ops.ones_like(s))
sinv = ops.where(s >= threshold, 1.0 / safe_s, ops.zeros_like(s))
VHT = ops.conj(ops.transpose(VH, (0, 2, 1))) # (n_freq, n_el, k)
UT = ops.conj(ops.transpose(U, (0, 2, 1))) # (n_freq, k, n_tx)
sinv_c = ops.cast(sinv, "complex64")
return ops.matmul(VHT * sinv_c[:, None, :], UT)
def _decode(self, data, delays_samples, apod):
"""REFoCUS decoding for a single (unbatched) volume.
All channels and all frequency bins are processed in parallel via
batched tensor operations.
Args:
data: ``(n_tx, n_ax, n_el, n_ch)`` float32 RF array.
delays_samples: ``(n_tx, n_el)`` transmit delays in samples.
apod: ``(n_tx, n_el)`` transmit apodization.
Returns:
decoded: ``(n_el, n_ax, n_el, n_ch)`` float32 array.
"""
n_tx, n_ax, n_el, n_ch = data.shape
n_elements = delays_samples.shape[1]
# --- FFT over all channels at once ---
# data: (n_tx, n_ax, n_el, n_ch) -> (n_ch, n_el, n_tx, n_ax)
rf = ops.cast(ops.transpose(data, (3, 2, 0, 1)), "float32")
# (n_ch, n_el_recv, n_tx, n_freq)
RF_enc_r, RF_enc_i = ops.rfft(rf)
RF_enc = ops.cast(RF_enc_r, "complex64") + 1j * ops.cast(RF_enc_i, "complex64")
n_freq = RF_enc.shape[-1]
# Rearrange to (n_freq, n_tx, n_el_recv * n_ch) for batched matmul.
# (n_ch, n_el_recv, n_tx, n_freq) -> (n_freq, n_tx, n_el_recv, n_ch)
RF_enc = ops.transpose(RF_enc, (3, 2, 1, 0))
# -> (n_freq, n_tx, n_el_recv * n_ch)
RF_enc = ops.reshape(RF_enc, (n_freq, n_tx, n_el * n_ch))
# --- Batched inverse encoding matrices (skip DC at index 0) ---
frequency = ops.cast(ops.arange(n_freq), "float32") / n_ax
freq_noDC = frequency[1:] # (n_freq - 1,)
# Hinv: (n_freq - 1, n_elements, n_tx)
Hinv = self._get_hinv(delays_samples, freq_noDC, apod)
# --- Single batched matmul over all frequencies and channels ---
# (n_freq-1, n_elements, n_tx) @ (n_freq-1, n_tx, n_el_recv * n_ch)
# -> (n_freq-1, n_elements, n_el_recv * n_ch)
RF_dec = ops.matmul(Hinv, RF_enc[1:])
# Prepend zeros for the DC bin: (n_freq, n_elements, n_el_recv * n_ch)
dc = ops.zeros((1, n_elements, n_el * n_ch), dtype="complex64")
RF_decoded = ops.concatenate([dc, RF_dec], axis=0)
# --- IFFT back to time domain ---
# Reshape to (n_freq, n_elements, n_el_recv, n_ch)
RF_decoded = ops.reshape(RF_decoded, (n_freq, n_elements, n_el, n_ch))
# irfft acts on the last axis: move n_freq last
# -> (n_elements, n_el_recv, n_ch, n_freq)
RF_decoded = ops.transpose(RF_decoded, (1, 2, 3, 0))
# -> (n_elements, n_el_recv, n_ch, n_ax)
rf_decoded = ops.irfft((ops.real(RF_decoded), ops.imag(RF_decoded)), fft_length=n_ax)
# -> (n_elements, n_ax, n_el_recv, n_ch)
rf_decoded = ops.transpose(rf_decoded, (0, 3, 1, 2))
return ops.cast(rf_decoded, "float32")
# ------------------------------------------------------------------
# Operation interface
# ------------------------------------------------------------------
[docs]
def call(
self,
t0_delays,
sampling_frequency,
probe_geometry,
initial_times,
tx_apodizations=None,
**kwargs,
):
"""Decode plane-wave / focused transmit data into multistatic data.
After decoding the output is a synthetic-aperture (SA) dataset where
each virtual transmit corresponds to a single element firing. The
pipeline parameters that describe the transmit sequence are updated
accordingly so that downstream operations (TOF correction, pfield
weighting, etc.) remain consistent with the new data shape.
Args:
t0_delays: ``(n_tx, n_el)`` transmit delays in **seconds**.
sampling_frequency: Sampling frequency in Hz.
probe_geometry: ``(n_el, 3)`` element positions in metres.
tx_apodizations: ``(n_tx, n_el)`` transmit apodization weights.
Defaults to all-ones (uniform apodization).
**kwargs: Must contain the input data tensor under ``self.key``.
Returns:
dict with keys:
* ``self.output_key`` — decoded data ``(n_el, n_ax, n_el, n_ch)``
(or batched variant).
* ``"t0_delays"`` — zeros ``(n_el, n_el)`` (SA: no extra delay).
* ``"tx_apodizations"`` — identity ``(n_el, n_el)`` (one element
per virtual transmit).
* ``"polar_angles"`` — zeros ``(n_el,)`` (no steering).
* ``"focus_distances"`` — zeros ``(n_el,)`` (no focus).
* ``"transmit_origins"`` — element positions ``(n_el, 3)``.
* ``"initial_times"`` — zeros ``(n_el,)``.
* ``"t_peak"`` — shared transmit-waveform peak time ``(n_el,)``.
* ``"flat_pfield"`` — ``None`` (resets pfield so downstream
:class:`PfieldWeighting` becomes a no-op).
"""
data = kwargs[self.key]
delays_samples = (t0_delays - initial_times[..., None]) * ops.cast(
sampling_frequency, t0_delays.dtype
)
if tx_apodizations is None:
apod = ops.ones_like(delays_samples)
else:
apod = tx_apodizations
if self.with_batch_dim:
decoded = vmap(self._decode, in_axes=(0, None, None))(data, delays_samples, apod)
else:
decoded = self._decode(data, delays_samples, apod)
# Number of virtual SA transmits = number of elements
n_el = ops.shape(probe_geometry)[0]
dtype = t0_delays.dtype
sa_t0_delays = ops.zeros((n_el, n_el), dtype=dtype)
sa_tx_apodizations = ops.eye(n_el, dtype=dtype)
sa_polar_angles = ops.zeros((n_el,), dtype=dtype)
sa_focus_distances = ops.zeros((n_el,), dtype=dtype)
sa_initial_times = ops.zeros((n_el,), dtype=dtype)
t_peak = kwargs.get("t_peak")
if t_peak is not None:
t_peak_flat = ops.reshape(ops.cast(t_peak, dtype), (-1,))
sa_t_peak = ops.broadcast_to(t_peak_flat[:1], (n_el,))
else:
sa_t_peak = ops.zeros((n_el,), dtype=dtype)
return {
self.output_key: decoded,
"t0_delays": sa_t0_delays,
"tx_apodizations": sa_tx_apodizations,
"polar_angles": sa_polar_angles,
"focus_distances": sa_focus_distances,
"transmit_origins": probe_geometry,
"initial_times": sa_initial_times,
"t_peak": sa_t_peak,
"flat_pfield": None,
}
def make_operation_chain(
operation_chain: List[Union[str, Dict, Config, Operation, Pipeline]],
) -> List[Operation]:
"""Make an operation chain from a custom list of operations.
Args:
operation_chain (list): List of operations to be performed.
Each operation can be:
- A string: operation initialized with default parameters
- A dictionary: operation initialized with parameters in the dictionary
- A Config object: converted to a dictionary and initialized
- An Operation/Pipeline instance: used as-is
Returns:
list: List of operations to be performed.
Example:
.. doctest::
>>> from zea.ops import make_operation_chain, LogCompress
>>> SomeCustomOperation = LogCompress # just for demonstration
>>> chain = make_operation_chain(
... [
... "envelope_detect",
... {"name": "normalize", "params": {"output_range": (0, 1)}},
... SomeCustomOperation(),
... ]
... )
"""
chain = []
for operation in operation_chain:
# Handle already instantiated Operation or Pipeline objects
if isinstance(operation, (Operation, Pipeline)):
chain.append(operation)
continue
assert isinstance(operation, (str, dict, Config)), (
f"Operation {operation} should be a string, dict, Config object, Operation, or Pipeline"
)
if isinstance(operation, str):
operation_instance = get_ops(operation)()
else:
if isinstance(operation, Config):
operation = operation.serialize()
params = operation.get("params", {})
op_name = operation.get("name")
operation_cls = get_ops(op_name)
# Check for nested operations at the same level as params
if "operations" in operation:
nested_operations = make_operation_chain(operation["operations"])
# Instantiate pipeline-type operations with nested operations
if issubclass(operation_cls, Beamform):
# some pipelines, such as `zea.ops.Beamformer`, are initialized
# not with a list of operations but with other parameters that then
# internally create a list of operations
operation_instance = operation_cls(**params)
elif issubclass(operation_cls, Pipeline):
# in most cases we want to pass an operations list to
# initialize a pipeline
operation_instance = operation_cls(operations=nested_operations, **params)
else:
operation_instance = operation_cls(operations=nested_operations, **params)
else:
operation_instance = operation_cls(**params)
chain.append(operation_instance)
return chain
def pipeline_from_config(config: Config, **kwargs) -> Pipeline:
"""
Create a Pipeline instance from a Config object.
The config must have a top-level ``pipeline`` key containing an ``operations`` list.
"""
if "pipeline" not in config:
top_keys = list(config.keys()) if hasattr(config, "keys") else []
raise ValueError(
f"Cannot build Pipeline: missing top-level 'pipeline' key.\n"
f"Expected a config with the format:\n"
f" pipeline:\n"
f" operations:\n"
f" - <operation_name>\n"
f" - ...\n"
f"Found top-level keys: {top_keys}"
)
# Unwrap the pipeline subsection from a full config
config = Config(config["pipeline"])
if "operations" not in config:
top_keys = list(config.keys()) if hasattr(config, "keys") else []
raise ValueError(
f"Cannot build Pipeline: missing 'operations' key.\n"
f"Expected a config with the format:\n"
f" pipeline:\n"
f" operations:\n"
f" - <operation_name>\n"
f" - ...\n"
f"Found top-level keys: {top_keys}"
)
if not isinstance(config.operations, (list, np.ndarray)):
raise ValueError(
f"Cannot build Pipeline: 'operations' must be a list, "
f"got {type(config.operations).__name__}."
)
operations = make_operation_chain(config.operations)
# merge pipeline config without operations with kwargs
pipeline_config = config.copy()
pipeline_config.pop("operations")
kwargs = {**pipeline_config, **kwargs}
return Pipeline(operations=operations, **kwargs)
def pipeline_from_json(json_string: str, **kwargs) -> Pipeline:
"""
Create a Pipeline instance from a JSON string.
"""
pipeline_config = Config(json.loads(json_string, cls=ZEADecoderJSON))
return pipeline_from_config(pipeline_config, **kwargs)
@deprecated(replacement="Pipeline.from_path")
def pipeline_from_yaml(yaml_path: str, **kwargs) -> Pipeline: # pragma: no cover
"""
Create a Pipeline instance from a YAML file.
.. deprecated::
Use :meth:`Pipeline.from_path` instead.
"""
with open(yaml_path, "r", encoding="utf-8") as f:
pipeline_config = yaml.safe_load(f)
return pipeline_from_config(Config(pipeline_config), **kwargs)
def _pipeline_to_serializable_dict(pipeline: Pipeline, compact=True) -> dict:
"""Convert a Pipeline to a dict suitable for serialization.
The output format is ``{"pipeline": {"operations": [...], ...pipeline_kwargs}}``
which can be loaded back via ``pipeline_from_config``.
"""
pipeline_dict = {
"operations": Pipeline._pipeline_to_list(pipeline, compact=compact),
}
if compact:
if not pipeline.with_batch_dim:
pipeline_dict["with_batch_dim"] = pipeline.with_batch_dim
if pipeline.jit_options != "ops":
pipeline_dict["jit_options"] = pipeline.jit_options
if pipeline._user_jit_kwargs:
pipeline_dict["jit_kwargs"] = pipeline._user_jit_kwargs
if pipeline.name != "pipeline":
pipeline_dict["name"] = pipeline.name
else:
pipeline_dict.update(
{
"with_batch_dim": pipeline.with_batch_dim,
"jit_options": pipeline.jit_options,
"jit_kwargs": pipeline._user_jit_kwargs,
"name": pipeline.name,
}
)
return {"pipeline": pipeline_dict}
def pipeline_to_config(pipeline: Pipeline, compact=True) -> Config:
"""
Convert a Pipeline instance into a Config object.
"""
return Config(_pipeline_to_serializable_dict(pipeline, compact=compact))
def pipeline_to_json(pipeline: Pipeline, compact=True) -> str:
"""
Convert a Pipeline instance into a JSON string.
"""
return json.dumps(
_pipeline_to_serializable_dict(pipeline, compact=compact),
cls=ZEAEncoderJSON,
indent=4,
)
def pipeline_to_yaml(pipeline: Pipeline, file_path: str, compact=True) -> None:
"""
Convert a Pipeline instance into a YAML file.
"""
with open(file_path, "w", encoding="utf-8") as f:
yaml.safe_dump(
_pipeline_to_serializable_dict(pipeline, compact=compact),
f,
indent=4,
)