Source code for zea.data.spec

from collections import defaultdict
from dataclasses import MISSING, dataclass, field, fields
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _get_pkg_version
from pathlib import Path
from typing import Any, List, Tuple

import h5py
import numpy as np

from zea import log

# Named dimensions whose size must agree across every field (including fields
# of nested specs) that uses them. Other string dimension names in a SCHEMA
# shape act as wildcards and are not cross-checked.
CONSISTENCY_DIMENSIONS = {"n_frames", "n_tx", "n_ax", "n_el", "n_ch", "n_spatial_ch"}

# Unit symbol -> human-readable unit name.
UNITS = {
    "m/s": "meters per second",
    "m": "meters",
    "Hz": "Hertz",
    "s": "seconds",
    "-": "unitless",
    "rad": "radians",
    "dB": "decibels",
    "#": "count",
    "%": "percent",
}

# Default value of the ``compression`` argument of ``Spec.create_dataset`` and
# ``Spec.store_in_group``.
DEFAULT_COMPRESSION = "lzf"

# Default unit/description for every SCHEMA leaf field.  Subclasses may
# override by defining their own FIELD_METADATA dict.
_DEFAULT_FIELD_UNIT = "-"
_DEFAULT_FIELD_DESCRIPTION = ""


def check_dtype(value: Any, expected_dtype: List[type]) -> None:
    """Check that the dtype of ``value`` is compatible with an expected dtype.

    Works for numpy arrays, numpy scalars, and Python native types.

    Args:
        value: The value to check.
        expected_dtype: A list/tuple of accepted types (numpy scalar types or
            native Python types). A single bare type is also accepted and is
            treated as a one-element list.

    Raises:
        TypeError: If ``value`` is not compatible with any of the expected types.
    """
    # ``Spec.get_dtype`` may return a bare type; accept it instead of failing
    # with "'type' object is not iterable".
    if isinstance(expected_dtype, type):
        expected_dtype = [expected_dtype]

    for dt in expected_dtype:
        if isinstance(dt, type) and issubclass(dt, np.generic):
            expected_np_dtype = np.dtype(dt)
            if hasattr(value, "dtype"):
                if np.issubdtype(value.dtype, expected_np_dtype):
                    return
            elif np.issubdtype(expected_np_dtype, np.character) and isinstance(
                value, (str, bytes)
            ):
                # Native str/bytes satisfy a numpy string dtype requirement.
                return
        elif isinstance(value, dt):
            return

    actual_type = (
        f"dtype {value.dtype}" if hasattr(value, "dtype") else f"Python {type(value).__name__}"
    )
    expected_dtypes_str = ", ".join(str(dt) for dt in expected_dtype)
    raise TypeError(
        f"Expected dtype compatible with one of ({expected_dtypes_str}), got {actual_type}. "
        f"Hint: wrap the value with the appropriate numpy type, "
        f"e.g. np.float32(...), np.str_(...), np.uint8(...)."
    )
def value_shape(value: Any) -> tuple:
    """Return ``value.shape`` for numpy arrays and ``()`` for anything else."""
    return value.shape if isinstance(value, np.ndarray) else ()
def match_shape(value: Any, expected_shape: tuple) -> bool:
    """Tell whether the shape of ``value`` satisfies ``expected_shape``.

    Entries of ``expected_shape`` that are strings are wildcards for a single
    axis; integer entries must match exactly. A single ``"..."`` entry matches
    any number of axes (including zero).

    Raises:
        ValueError: If ``expected_shape`` contains more than one ``"..."``.
    """
    actual = value_shape(value)

    if expected_shape.count("...") > 1:
        raise ValueError("Expected shape can contain at most one '...' wildcard")

    if "..." not in expected_shape:
        if len(actual) != len(expected_shape):
            return False
        pairs = zip(actual, expected_shape)
    else:
        split_at = expected_shape.index("...")
        head_spec = expected_shape[:split_at]
        tail_spec = expected_shape[split_at + 1 :]

        # '...' matches any number of dimensions (including zero).
        if len(actual) < len(head_spec) + len(tail_spec):
            return False

        actual_head = actual[: len(head_spec)]
        actual_tail = actual[len(actual) - len(tail_spec) :] if tail_spec else ()
        pairs = zip(actual_head + actual_tail, head_spec + tail_spec)

    # Named (string) dimensions accept any size; everything else must be equal.
    return all(isinstance(wanted, str) or size == wanted for size, wanted in pairs)
def find_matched_shape(value: Any, expected_shapes: List[tuple]) -> tuple | None:
    """Return the first entry of ``expected_shapes`` that ``value`` matches, or None."""
    return next(
        (candidate for candidate in expected_shapes if match_shape(value, candidate)),
        None,
    )
class Spec:
    """Base class for data specifications with schema validation.

    Subclasses should define a SCHEMA class variable that specifies the expected dtype and
    shape for each field. The __post_init__ method will validate that the actual fields match
    the schema, including checking that dimensions with the same name have consistent sizes
    across fields.
    """

    SCHEMA: dict

    @staticmethod
    def _is_optional_dataclass_field(field_def: Any) -> bool:
        """Return True if ``field_def`` is a dataclass field that has a default."""
        if field_def is None:
            return False
        return field_def.default is not MISSING or field_def.default_factory is not MISSING

    @classmethod
    def required_fields(cls) -> tuple[str, ...]:
        """Return the names of fields that have no default value."""
        return tuple(f.name for f in fields(cls) if not cls._is_optional_dataclass_field(f))

    # NOTE: this classmethod shares its name with ``dataclasses.fields``. Method bodies
    # resolve the bare name ``fields`` in module scope, so they still call the
    # dataclasses helper.
    @classmethod
    def fields(cls) -> tuple[str, ...]:
        """Return the names of all fields."""
        return tuple(f.name for f in fields(cls))

    @classmethod
    def optional_fields(cls) -> tuple[str, ...]:
        """Return the names of fields that have a default value."""
        return tuple(f.name for f in fields(cls) if cls._is_optional_dataclass_field(f))

    def warn_missing_optional_fields(self):
        """Warn about optional fields that were not provided."""
        _optional_fields = self.optional_fields()
        for field_name in self.SCHEMA:
            if field_name in _optional_fields and getattr(self, field_name) is None:
                if hasattr(self, "FIELD_METADATA"):
                    meta = self.FIELD_METADATA.get(field_name, {})
                    description = meta.get("description", _DEFAULT_FIELD_DESCRIPTION)
                else:
                    description = _DEFAULT_FIELD_DESCRIPTION
                log.warning(
                    f"Optional {self.__class__.__name__} field '{field_name}' is not set. "
                    f"Description: {description} "
                    "Defaulted to None."
                )

    @staticmethod
    def _expected_shapes(shape_spec: Any) -> tuple[tuple, ...]:
        """Normalise a SCHEMA ``shape`` entry to a tuple of alternative shapes."""
        if shape_spec and isinstance(shape_spec[0], tuple):
            return tuple(shape_spec)
        return (shape_spec,)

    @staticmethod
    def _merge_dimension_info(
        dim_to_fields: defaultdict[str, set[str]],
        dim_to_sizes: defaultdict[str, set[int]],
        nested_dim_to_fields: defaultdict[str, set[str]],
        nested_dim_to_sizes: defaultdict[str, set[int]],
    ) -> None:
        """Merge the dimension bookkeeping of a nested spec into the parent's (in place)."""
        for dim_name, nested_fields in nested_dim_to_fields.items():
            dim_to_fields[dim_name].update(nested_fields)
        for dim_name, nested_sizes in nested_dim_to_sizes.items():
            dim_to_sizes[dim_name].update(nested_sizes)

    @staticmethod
    def _track_named_dimensions(
        dim_to_fields: defaultdict[str, set[str]],
        dim_to_sizes: defaultdict[str, set[int]],
        field_path: str,
        matched_shape: tuple,
        shape: tuple,
    ) -> None:
        """Record the observed size of every consistency dimension of one field.

        ``matched_shape`` is the schema shape that ``shape`` was matched against. It may
        contain one ``"..."`` wildcard that spans any number of axes.
        """
        n_spec = len(matched_shape)
        ellipsis_pos = matched_shape.index("...") if "..." in matched_shape else None
        for i, dim_name in enumerate(matched_shape):
            if isinstance(dim_name, str) and dim_name in CONSISTENCY_DIMENSIONS:
                # Entries after '...' are anchored to the end of the actual shape, so
                # index them from the back; entries before it are anchored to the front.
                axis = i - n_spec if ellipsis_pos is not None and i > ellipsis_pos else i
                dim_to_fields[dim_name].add(field_path)
                dim_to_sizes[dim_name].add(shape[axis])

    @staticmethod
    def _raise_if_shape_mismatch(
        field_name: str, value: Any, expected_shapes: tuple[tuple, ...]
    ) -> None:
        """Raise the shape-mismatch ValueError for ``field_name``.

        Always raises; the caller decides whether there is a mismatch.
        """
        allowed_shapes = ", ".join(str(shape) for shape in expected_shapes)
        raise ValueError(
            f"{field_name} has shape {value_shape(value)}, expected one of: {allowed_shapes}"
        )

    def _validate_nested_field(
        self, field_name: str, nested_spec: "Spec", field_value: Any
    ) -> "Spec":
        """Validate a nested spec field, recursively validating its contents."""
        if isinstance(field_value, dict):
            # Building the nested spec runs its own __post_init__ validation.
            field_value = nested_spec(**field_value)
            setattr(self, field_name, field_value)

        # Check that the nested spec field is now an instance of the expected Spec subclass
        # E.g. Segmentation if nested_spec is Map
        if not issubclass(type(field_value), nested_spec):
            raise TypeError(
                f"Expected field '{field_name}' to be {nested_spec}, got {type(field_value)}"
            )
        return field_value

    @staticmethod
    def _cast_native_to_numpy(value: Any, expected_dtype: list) -> Any:
        """Cast values to expected numpy dtypes when possible.

        For fields that expect a floating dtype, all floating-point inputs are accepted and
        normalized to the first floating dtype in ``expected_dtype`` (typically ``np.float32``).
        """
        expected_np_dtypes = []
        for dt in expected_dtype:
            try:
                expected_np_dtypes.append(np.dtype(dt))
            except TypeError:
                continue

        expected_float_dtype = next(
            (dt for dt in expected_np_dtypes if np.issubdtype(dt, np.floating)),
            None,
        )

        # Keep native string/bytes values as-is instead of converting to numpy string scalars.
        if isinstance(value, (str, bytes)):
            return value

        if hasattr(value, "dtype"):
            value_dtype = np.dtype(value.dtype)
            if (
                expected_float_dtype is not None
                and np.issubdtype(value_dtype, np.floating)
                and value_dtype != expected_float_dtype
            ):
                return value.astype(expected_float_dtype, copy=False)
            return value

        # If the spec expects a native Python type and the value already matches it,
        # keep it as-is instead of converting to a numpy scalar.
        for dt in expected_dtype:
            if isinstance(dt, type) and not issubclass(dt, np.generic) and isinstance(value, dt):
                return value

        # Otherwise try each expected dtype in order and keep the first that converts.
        for dt in expected_dtype:
            try:
                target_dtype = np.dtype(dt)
                return target_dtype.type(value)
            except (TypeError, ValueError, OverflowError):
                continue

        return value

    def _validate_and_track_primitive_field(
        self,
        field_name: str,
        field_info: dict,
        field_value: Any,
        dim_to_fields: defaultdict[str, set[str]],
        dim_to_sizes: defaultdict[str, set[int]],
    ) -> None:
        """Cast, dtype-check and shape-check one leaf field and record its named dimensions.

        Raises:
            TypeError: If the (cast) value has an incompatible dtype.
            ValueError: If the value matches none of the allowed shapes.
        """
        expected_dtype = field_info["dtype"]
        if not isinstance(expected_dtype, (list, tuple)):
            expected_dtype = [expected_dtype]
        expected_shapes = self._expected_shapes(field_info["shape"])

        # Auto-cast Python native types (str, int, float) to numpy equivalents
        field_value = self._cast_native_to_numpy(field_value, expected_dtype)
        setattr(self, field_name, field_value)

        try:
            check_dtype(field_value, expected_dtype)
        except TypeError as e:
            raise TypeError(
                f"{type(self).__name__}: field '{field_name}' has invalid dtype: {e}"
            ) from e

        matched_shape = find_matched_shape(field_value, expected_shapes)
        if matched_shape is None:
            self._raise_if_shape_mismatch(field_name, field_value, expected_shapes)

        self._track_named_dimensions(
            dim_to_fields=dim_to_fields,
            dim_to_sizes=dim_to_sizes,
            field_path=field_name,
            matched_shape=matched_shape,
            shape=value_shape(field_value),
        )

    @staticmethod
    def _raise_if_inconsistent_dimensions(
        dim_to_fields: defaultdict[str, set[str]],
        dim_to_sizes: defaultdict[str, set[int]],
    ) -> None:
        """Raise ValueError if any named dimension was observed with more than one size."""
        for dim_name, sizes in dim_to_sizes.items():
            if len(sizes) > 1:
                field_names = sorted(dim_to_fields[dim_name])
                raise ValueError(
                    f"Dimension '{dim_name}' has inconsistent sizes across "
                    f"fields {field_names}: {sorted(sizes)}"
                )

    def _collect_dimension_info(
        self, prefix: str = ""
    ) -> tuple[defaultdict[str, set[str]], defaultdict[str, set[int]]]:
        """Collect named dimension usage and observed sizes for this spec subtree."""
        dim_to_fields = defaultdict(set)
        dim_to_sizes = defaultdict(set)

        for field_name, field_info in self.SCHEMA.items():
            field_value = getattr(self, field_name)
            if field_value is None:
                continue

            nested_spec = field_info.get("spec")
            if nested_spec is not None:
                nested_dim_to_fields, nested_dim_to_sizes = field_value._collect_dimension_info(
                    prefix=f"{prefix}{field_name}."
                )
                self._merge_dimension_info(
                    dim_to_fields,
                    dim_to_sizes,
                    nested_dim_to_fields,
                    nested_dim_to_sizes,
                )
                continue

            expected_shapes = self._expected_shapes(field_info["shape"])
            matched_shape = find_matched_shape(field_value, expected_shapes)
            if matched_shape is None:
                # Child specs are already validated; skip defensively if no shape can be matched.
                continue

            self._track_named_dimensions(
                dim_to_fields=dim_to_fields,
                dim_to_sizes=dim_to_sizes,
                field_path=f"{prefix}{field_name}",
                matched_shape=matched_shape,
                shape=value_shape(field_value),
            )

        return dim_to_fields, dim_to_sizes

    def __post_init__(self):
        """Validate every SCHEMA field and check named dimensions for consistency.

        Raises:
            ValueError: If a required field is None, a shape does not match, or a named
                dimension has inconsistent sizes across fields.
            TypeError: If a field has an incompatible dtype or nested spec type.
        """
        dim_to_fields = defaultdict(set)
        dim_to_sizes = defaultdict(set)
        dataclass_fields = {f.name: f for f in fields(self)}

        for field_name, field_info in self.SCHEMA.items():
            field_value = getattr(self, field_name)
            field_def = dataclass_fields.get(field_name)
            is_optional = self._is_optional_dataclass_field(field_def)

            if field_value is None:
                if not is_optional:
                    raise ValueError(f"Missing required field '{field_name}'")
                continue

            nested_spec = field_info.get("spec")
            if nested_spec is not None:
                try:
                    field_value = self._validate_nested_field(field_name, nested_spec, field_value)
                except (TypeError, ValueError) as e:
                    # Re-raise with the same exception type, prefixed with the field path.
                    raise type(e)(f"In field '{field_name}': {e}") from e
                nested_dim_to_fields, nested_dim_to_sizes = field_value._collect_dimension_info(
                    prefix=f"{field_name}."
                )
                self._merge_dimension_info(
                    dim_to_fields,
                    dim_to_sizes,
                    nested_dim_to_fields,
                    nested_dim_to_sizes,
                )
                continue

            self._validate_and_track_primitive_field(
                field_name=field_name,
                field_info=field_info,
                field_value=field_value,
                dim_to_fields=dim_to_fields,
                dim_to_sizes=dim_to_sizes,
            )

        self._raise_if_inconsistent_dimensions(dim_to_fields, dim_to_sizes)

    @staticmethod
    def _is_string_value(value: Any) -> bool:
        """Return True for scalar/array values that should be stored as HDF5 strings."""
        if isinstance(value, (str, np.str_, bytes, np.bytes_)):
            return True
        if isinstance(value, np.ndarray):
            return value.dtype.kind in {"U", "S", "O"}
        return False

    @staticmethod
    def create_dataset(
        group: h5py.Group,
        field_name: str,
        value: Any,
        compression: str = DEFAULT_COMPRESSION,
        chunk_frames: bool = False,
    ) -> None:
        """Create a dataset in the given group for the specified field and value,
        handling string and scalar values appropriately.

        Args:
            group: The group in which the dataset is created.
            field_name: Name of the dataset.
            value: The data to store.
            compression: Compression passed to ``create_dataset``; ignored for scalars.
            chunk_frames: If True, non-string arrays with at least two dimensions are
                chunked with a chunk length of 1 along the first axis.
        """
        # np.ndim also handles values without an ``ndim`` attribute (e.g. lists).
        dataset_is_scalar = np.isscalar(value) or np.ndim(value) == 0
        compression = None if dataset_is_scalar else compression
        chunks = None
        if (
            chunk_frames
            and not dataset_is_scalar
            and isinstance(value, np.ndarray)
            and value.ndim >= 2
        ):
            chunks = (1,) + value.shape[1:]

        if Spec._is_string_value(value):
            string_dtype = h5py.string_dtype(encoding="utf-8")
            string_value = np.asarray(value, dtype=object)
            group.create_dataset(
                field_name,
                data=string_value,
                dtype=string_dtype,
                compression=compression,
            )
        else:
            group.create_dataset(field_name, data=value, compression=compression, chunks=chunks)

    def store_in_group(
        self,
        group: h5py.Group,
        compression: str = DEFAULT_COMPRESSION,
        chunk_frames: bool = False,
    ) -> None:
        """Store the data in the given group (e.g. hdf5 group).

        Nested specs are stored recursively in sub-groups. Fields whose value is None are
        not written. Every stored field gets ``unit`` and ``description`` attributes.

        Raises:
            AssertionError: If ``group`` is not an ``h5py.Group``.
        """
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if not isinstance(group, h5py.Group):
            raise AssertionError("group must be an h5py Group")

        # Optional fields should only warn when persisting to disk, not on load.
        self.warn_missing_optional_fields()

        field_metadata = getattr(self, "FIELD_METADATA", {})

        for field_name, field_info in self.SCHEMA.items():
            value = getattr(self, field_name)

            # We do not store fields with value None in the file.
            if value is None:
                continue

            nested_spec = field_info.get("spec")
            if nested_spec is not None:
                subgroup = group.create_group(field_name)
                value.store_in_group(
                    subgroup,
                    compression=compression,
                    chunk_frames=chunk_frames,
                )
            else:
                self.create_dataset(
                    group,
                    field_name,
                    value,
                    compression=compression,
                    chunk_frames=chunk_frames,
                )

            meta = field_metadata.get(field_name, {})
            group[field_name].attrs["unit"] = meta.get("unit", _DEFAULT_FIELD_UNIT)
            group[field_name].attrs["description"] = meta.get(
                "description", _DEFAULT_FIELD_DESCRIPTION
            )

    def to_dict(self) -> dict[str, Any]:
        """Return this spec as a nested dictionary based on ``SCHEMA`` fields.

        Nested specs are converted recursively.
        """
        result = {}
        for field_name, field_info in self.SCHEMA.items():
            value = getattr(self, field_name)
            nested_spec = field_info.get("spec")

            if nested_spec is not None and value is not None:
                if isinstance(value, Spec):
                    result[field_name] = value.to_dict()
                elif isinstance(value, dict):
                    result[field_name] = {
                        k: v.to_dict() if isinstance(v, Spec) else v for k, v in value.items()
                    }
                else:
                    result[field_name] = value
            else:
                result[field_name] = value

        return result

    @classmethod
    def get_dtype(cls, field_name) -> Tuple[type, ...] | type:
        """Get the dtype of a field."""
        return cls.SCHEMA[field_name]["dtype"]

    def __repr__(self) -> str:
        """Return a compact repr: None fields are skipped and arrays are summarised."""
        parts = []
        for field_name, field_info in self.SCHEMA.items():
            value = getattr(self, field_name, None)
            if value is None:
                continue
            nested_spec = field_info.get("spec")
            if nested_spec is not None:
                parts.append(f"{field_name}={value!r}")
            elif isinstance(value, np.ndarray):
                parts.append(f"{field_name}=array({value.dtype} {value.shape})")
            else:
                parts.append(f"{field_name}={value!r}")
        return f"{self.__class__.__name__}({', '.join(parts)})"
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the compact, array-summarising Spec.__repr__.
@dataclass(repr=False)
class Map(Spec):
    """Map data with per-pixel Cartesian coordinates.

    A map is a function from Cartesian space to some real values: every pixel at
    spatial index ``[f, i, j, ...]`` is assigned a 3-D position
    ``coordinates[f, i, j, ..., :]`` = ``[x, y, z]`` in metres. The most flexible map spec,
    which can be used for any spatially aligned data product.

    See, for example, :func:`~zea.beamform.pixelgrid.cartesian_pixel_grid` or
    :func:`~zea.beamform.pixelgrid.polar_pixel_grid` to create a suitable coordinate array
    from your scan geometry.

    Args:
        values: The map values of shape ``(n_frames, z, x, y, n_ch)`` or
            ``(n_frames, z, x, y)`` or ``(n_frames, z, x, n_ch)`` or ``(n_frames, z, x)``
            and type uint8, float32, int16, or complex64.
        coordinates: Per-pixel Cartesian positions in metres, shape
            ``(*spatial_dims, 3)`` where ``spatial_dims`` matches the spatial (non-channel)
            dimensions of ``values``. For non-channeled values the shape is
            ``(*values.shape, 3)``; for channeled values the shape is
            ``(*values.shape[:-1], 3)``. The last axis holds ``[x, y, z]``.
            The leading ``n_frames`` axis may be omitted to broadcast one coordinate grid
            across all frames.
        labels: The labels corresponding to the ``n_ch`` channels in the values. This is
            required when values have an n_ch dimension, and should be None otherwise.
            For IQ data, this would typically be ``["I", "Q"]``.
        description: An optional free-text description of the map.
        unit: An optional string specifying the physical unit of the map values,
            e.g. ``"m/s"``, ``"%"``, etc.
        min: The minimum value of the map.
        max: The maximum value of the map.
    """

    values: np.ndarray
    coordinates: np.ndarray | None = None
    labels: np.ndarray | None = None
    description: str | None = None
    unit: str | None = None
    min: float | None = None
    max: float | None = None

    SCHEMA = {
        "values": {
            "dtype": (np.uint8, np.float32, np.int16, np.complex64),
            "shape": (
                ("n_frames", "z", "x", "y", "n_spatial_ch"),
                ("n_frames", "z", "x", "y"),
                ("n_frames", "z", "x", "n_spatial_ch"),
                ("n_frames", "z", "x"),
            ),
        },
        "coordinates": {"dtype": np.float32, "shape": ("...", 3)},
        "labels": {"dtype": np.str_, "shape": ("n_spatial_ch",)},
        "description": {"dtype": str, "shape": ()},
        "unit": {"dtype": str, "shape": ()},
        "min": {"dtype": np.float32, "shape": ()},
        "max": {"dtype": np.float32, "shape": ()},
    }

    def __post_init__(self):
        """Run schema validation, then check labels and coordinates against ``values``.

        Raises:
            AssertionError: If ``values`` is 5-D and ``labels`` is None.
            ValueError: If ``coordinates`` is incompatible with the shape of ``values``.
        """
        super().__post_init__()

        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if self.values.ndim == 5 and self.labels is None:
            raise AssertionError("labels must be provided when values have n_ch dimension")

        if self.coordinates is not None:
            # coordinates.shape[-1] is guaranteed == 3 by the SCHEMA check above.
            # Validate that the spatial axes match values (with or without a trailing channel axis).
            coords_spatial = self.coordinates.shape[:-1]
            valid_spatial_shapes = {
                self.values.shape,
                self.values.shape[:-1],
            }

            # Also accept coordinates that omit the leading frame axis and
            # therefore broadcast across frames.
            if len(self.values.shape) > 1:
                valid_spatial_shapes.add(self.values.shape[1:])
            if len(self.values.shape[:-1]) > 1:
                valid_spatial_shapes.add(self.values.shape[1:-1])

            if coords_spatial not in valid_spatial_shapes:
                raise ValueError(
                    f"{type(self).__name__}: coordinates shape {self.coordinates.shape} is "
                    f"incompatible with values shape {self.values.shape}. "
                    f"coordinates.shape[:-1] must equal values.shape "
                    f"({self.values.shape}) for non-channeled data, or "
                    f"values.shape[:-1] ({self.values.shape[:-1]}) for channeled data, "
                    "with optional frame-axis broadcasting (leading n_frames omitted)."
                )

            # Sanity-check units: clinical ultrasound scan regions are at most a few tens of
            # centimetres across, so any finite coordinate magnitude above 1 m almost certainly
            # indicates the array was supplied in millimetres rather than metres.
            max_abs = np.max(np.abs(self.coordinates[np.isfinite(self.coordinates)]), initial=0.0)
            if max_abs > 1.0:
                log.warning(
                    f"{type(self).__name__}: coordinates have a maximum absolute value of "
                    f"{max_abs:.4g}, which exceeds 1 m. Ultrasound scan regions are "
                    "typically a few centimetres across. Please verify that coordinates "
                    "are in metres, not millimetres."
                )
        else:
            log.warning(
                f"{type(self).__name__}: coordinates are not provided, please consider adding "
                "a coordinates field to ensure the map can be correctly displayed."
            )
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class FloatMap(Map):
    """Map data with float32 pixel values and per-pixel Cartesian coordinates."""

    # Same schema as Map, with ``values`` restricted to float32.
    SCHEMA = {
        **Map.SCHEMA,
        "values": {
            **Map.SCHEMA["values"],
            "dtype": np.float32,
        },
    }
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class BooleanMap(Map):
    """Map data with bool pixel values and per-pixel Cartesian coordinates."""

    # Same schema as Map, with ``values`` restricted to bool.
    SCHEMA = {
        **Map.SCHEMA,
        "values": {
            **Map.SCHEMA["values"],
            "dtype": np.bool_,
        },
    }
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class UnsignedIntMap(Map):
    """Map data with uint8 pixel values and per-pixel Cartesian coordinates."""

    # Same schema as Map, with ``values`` restricted to uint8.
    SCHEMA = {
        **Map.SCHEMA,
        "values": {
            **Map.SCHEMA["values"],
            "dtype": np.uint8,
        },
    }
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class Segmentation(BooleanMap):
    """Segmentation data with per-pixel Cartesian coordinates.

    Args:
        values: The segmentation values of shape ``(n_frames, z, x, y, n_labels)`` for 3D
            (volumetric) data or ``(n_frames, z, x, n_labels)`` for 2D data, with type bool.
        coordinates: Per-pixel Cartesian positions in metres, shape
            ``(*spatial_dims, 3)`` where ``spatial_dims`` matches the spatial (non-label)
            dimensions of ``values``. The leading frame axis may be omitted to broadcast
            one coordinate grid across all frames.
        labels: The labels corresponding to the segmentation values, where each unique value
            in the values corresponds to a label in this list of shape ``(n_labels,)``
            and type str.

    .. note::
        To indicate that certain frames have no segmentation, add an explicit
        ``"unannotated"`` entry to ``labels`` and set ``values[..., unannotated_idx]`` to
        ``True`` for those frames (with all other label channels set to ``False``).
        This keeps the shape uniform across frames while clearly distinguishing genuinely
        annotated frames from frames that were not labelled. For example::

            labels = np.array(["unannotated", "LV_endo", "LV_myo", "LA"])
            values = np.zeros((n_frames, H, W, 4), dtype=bool)
            # mark all frames as unannotated by default
            values[:, :, :, 0] = True
            # for annotated frames, set unannotated channel to False
            # and the appropriate label channel to True
            values[ed_idx, :, :, 0] = False
            values[ed_idx, :, :, 1:] = segmentation_mask  # shape (H, W, 3)
    """

    SCHEMA = {
        **BooleanMap.SCHEMA,
        "values": {
            **BooleanMap.SCHEMA["values"],
            "shape": (
                ("n_frames", "z", "x", "y", "n_spatial_ch"),
                ("n_frames", "z", "x", "n_spatial_ch"),
            ),
        },
    }

    def __post_init__(self):
        """Check dimensionality and presence of labels, then run the Map validation.

        Raises:
            AssertionError: If ``values`` is not 4-D or 5-D, or ``labels`` is None.
        """
        # Explicit raises (not ``assert``) so the checks survive ``python -O``.
        if self.values.ndim not in (4, 5):
            raise AssertionError(
                "Segmentation values must have 4 or 5 dimensions: "
                "(n_frames, z, x, n_labels) for 2D or (n_frames, z, x, y, n_labels) for 3D, "
                f"got shape {self.values.shape}"
            )
        if self.labels is None:
            raise AssertionError("Segmentation requires labels to be provided")
        super().__post_init__()
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class Image(Map):
    """Reconstructed (log-compressed) image data with per-pixel Cartesian coordinates.

    Args:
        values: The image values of shape ``(n_frames, z, x, y)`` or ``(n_frames, z, x)``
            and type uint8 or float32. For float32 values, the values should be in dB
            (between -inf and 0).
        coordinates: Per-pixel Cartesian positions in metres, shape
            ``(*values.shape, 3)``. The leading frame axis may be omitted to broadcast
            one coordinate grid across all frames.
    """

    SCHEMA = {
        **Map.SCHEMA,
        "values": {
            "dtype": (np.float32, np.uint8),
            # Axis names follow the (n_frames, z, x[, y]) order documented above and used
            # by the other Map specs. "z"/"x"/"y" are wildcard names, so this only
            # affects the text of shape-mismatch errors.
            "shape": (
                ("n_frames", "z", "x", "y"),
                ("n_frames", "z", "x"),
            ),
        },
    }

    def __post_init__(self):
        """Run the Map validation, then check float32 values are on a dB scale.

        Raises:
            ValueError: If float32 values contain NaN/+inf or any value above 0.
        """
        super().__post_init__()

        # Check that image values are in dB scale (finite or -inf, and <= 0)
        if self.values.dtype == np.float32:
            if not np.all(np.isfinite(self.values) | np.isneginf(self.values)):
                raise ValueError("Image values must be finite or -inf (dB scale).")
            if not np.all(self.values <= 0):
                raise ValueError("Image values must be in dB scale <= 0 when using float32 dtype.")
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the compact, array-summarising Spec.__repr__.
@dataclass(repr=False)
class AlignedData(Spec):
    """Time-of-flight corrected data.

    Args:
        values: The aligned data of shape ``(n_frames, n_tx, n_ax, n_el, n_ch)``
            and type float32 or int16. n_ch is 1 for RF data or 2 for IQ data.
        labels: The labels for the channel dimension, e.g. ``["RF"]`` or ``["I", "Q"]``.
            Auto-generated from n_ch if not provided.
    """

    values: np.ndarray
    labels: np.ndarray | None = None

    SCHEMA = {
        "values": {
            "dtype": (np.float32, np.int16),
            "shape": ("n_frames", "n_tx", "n_ax", "n_el", "n_ch"),
        },
        "labels": {"dtype": np.str_, "shape": ("n_ch",)},
    }

    def __post_init__(self):
        """Check n_ch, fill in default labels, then run schema validation.

        Raises:
            ValueError: If the last axis of ``values`` does not have size 1 or 2.
        """
        n_ch = self.values.shape[-1]
        if n_ch not in (1, 2):
            raise ValueError(
                f"Aligned data must have n_ch ∈ {{1, 2}} (RF or IQ), "
                f"got n_ch={n_ch} (shape {self.values.shape})."
            )
        # Default labels are set before validation so they are schema-checked too.
        if self.labels is None:
            self.labels = (
                np.array(["RF"], dtype=np.str_)
                if n_ch == 1
                else np.array(["I", "Q"], dtype=np.str_)
            )
        super().__post_init__()
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class BeamformedData(FloatMap):
    """Beamformed (beamsummed) data with per-pixel Cartesian coordinates.

    Args:
        values: The beamformed data of shape ``(n_frames, z, x, n_ch)`` or
            ``(n_frames, z, x, y, n_ch)`` and type float32.
            n_ch is 1 for RF data or 2 for IQ data.
        coordinates: Per-pixel Cartesian positions in metres, shape
            ``(n_frames, z, x, 3)`` or ``(n_frames, z, x, y, 3)``.
            The leading frame axis may be omitted to broadcast one coordinate grid
            across all frames.
        labels: The labels for the channel dimension, e.g. ``["RF"]`` or ``["I", "Q"]``.
            Auto-generated from n_ch if not provided.
    """

    SCHEMA = {
        **FloatMap.SCHEMA,
        "values": {
            "dtype": np.float32,
            "shape": (
                ("n_frames", "z", "x", "y", "n_ch"),
                ("n_frames", "z", "x", "n_ch"),
            ),
        },
        "labels": {"dtype": np.str_, "shape": ("n_ch",)},
    }

    def __post_init__(self):
        """Check n_ch, fill in default labels, then run the Map validation.

        Raises:
            ValueError: If the last axis of ``values`` does not have size 1 or 2.
        """
        n_ch = self.values.shape[-1]
        if n_ch not in (1, 2):
            raise ValueError(
                f"Beamformed data must have n_ch ∈ {{1, 2}} (RF or IQ), "
                f"got n_ch={n_ch} (shape {self.values.shape})."
            )
        # Default labels are set before validation so they are schema-checked too.
        if self.labels is None:
            self.labels = (
                np.array(["RF"], dtype=np.str_)
                if n_ch == 1
                else np.array(["I", "Q"], dtype=np.str_)
            )
        super().__post_init__()
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class EnvelopeData(FloatMap):
    """Envelope-detected data with per-pixel Cartesian coordinates.

    Args:
        values: The envelope data of shape ``(n_frames, z, x)`` or
            ``(n_frames, z, x, y)`` and type float32.
        coordinates: Per-pixel Cartesian positions in metres, shape
            ``(*values.shape, 3)``. The leading frame axis may be omitted to broadcast
            one coordinate grid across all frames.
    """

    SCHEMA = {
        **FloatMap.SCHEMA,
        "values": {
            "dtype": np.float32,
            "shape": (
                ("n_frames", "z", "x", "y"),
                ("n_frames", "z", "x"),
            ),
        },
    }
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class SosMap(FloatMap):
    """Speed-of-sound map data with per-pixel Cartesian coordinates.

    Args:
        values: The speed-of-sound map values in m/s of shape ``(n_frames, z, x, y)``
            and type float32.
        coordinates: Per-pixel Cartesian positions in metres, shape
            ``(n_frames, z, x, 3)`` or ``(n_frames, z, x, y, 3)``.
            The leading frame axis may be omitted to broadcast one coordinate grid
            across all frames.
    """

    def __post_init__(self):
        """Run the Map validation, then check the unit and warn on implausible values.

        Raises:
            ValueError: If ``unit`` is set to anything other than ``"m/s"``.
        """
        super().__post_init__()
        if self.unit is not None and self.unit != "m/s":
            raise ValueError(f"Speed-of-sound map unit should be 'm/s', got '{self.unit}'")

        # Check sensible values for speed of sound
        if np.any(self.values < 300):
            log.warning(
                "Speed-of-sound map contains values below 300 m/s, which is unusually low. "
                "Please verify that the speed-of-sound values are correct and in m/s."
            )
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class StrainPercentageMap(FloatMap):
    """Strain map data with per-pixel Cartesian coordinates.

    Args:
        values: The strain values in % of shape ``(n_frames, z, x, y)`` and type float32.
        coordinates: Per-pixel Cartesian positions in metres.
    """

    def __post_init__(self):
        """Run the Map validation, then check the unit.

        Raises:
            ValueError: If ``unit`` is set to anything other than ``"%"``.
        """
        super().__post_init__()
        if self.unit is not None and self.unit != "%":
            raise ValueError(f"Strain map unit should be '%', got '{self.unit}'")
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class ShearWaveElastographyMap(FloatMap):
    """Shear-wave elastography data with per-pixel Cartesian coordinates.

    Args:
        values: The shear-wave elastography values in m/s of shape ``(n_frames, z, x, y)``
            and type float32.
        coordinates: Per-pixel Cartesian positions in metres.
    """

    def __post_init__(self):
        """Run the Map validation, then check the unit.

        Raises:
            ValueError: If ``unit`` is set to anything other than ``"m/s"``.
        """
        super().__post_init__()
        if self.unit is not None and self.unit != "m/s":
            raise ValueError(f"SWE map unit should be 'm/s', got '{self.unit}'")
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class TissueDopplerMap(FloatMap):
    """Tissue Doppler data with per-pixel Cartesian coordinates.

    Args:
        values: The tissue Doppler values in m/s of shape ``(n_frames, z, x, y)``
            and type float32.
        coordinates: Per-pixel Cartesian positions in metres.
    """

    def __post_init__(self):
        """Run the Map validation, then check the unit.

        Raises:
            ValueError: If ``unit`` is set to anything other than ``"m/s"``.
        """
        super().__post_init__()
        if self.unit is not None and self.unit != "m/s":
            raise ValueError(f"Tissue Doppler map unit should be 'm/s', got '{self.unit}'")
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the one it would inherit.
@dataclass(repr=False)
class ColorDopplerMap(FloatMap):
    """Color Doppler (velocity) data with per-pixel Cartesian coordinates.

    Args:
        values: The color Doppler velocity values in m/s of shape ``(n_frames, z, x, y)``
            and type float32. Positive values indicate flow towards the transducer,
            negative values indicate flow away from the transducer.
        coordinates: Per-pixel Cartesian positions in metres.
    """

    def __post_init__(self):
        """Run the Map validation, then check the unit.

        Raises:
            ValueError: If ``unit`` is set to anything other than ``"m/s"``.
        """
        super().__post_init__()
        if self.unit is not None and self.unit != "m/s":
            raise ValueError(f"Color Doppler map unit should be 'm/s', got '{self.unit}'")
# repr=False: otherwise @dataclass generates a __repr__ on this class that replaces
# the compact Spec.__repr__ (which also lists custom extra maps via SCHEMA).
@dataclass(init=False, repr=False)
class DataSpec(Spec):
    """Data group containing raw channels, derived pipeline products, and optional spatial maps.

    Plain-array data products:
        raw_data: Raw channel data of shape (n_frames, n_tx, n_ax, n_el, n_ch)
            and type float32 or int16.

    Grouped data products (values + optional metadata):
        - aligned_data: Time-of-flight corrected data and optional labels.
        - beamformed_data: Beamformed (beamsummed) data and per-pixel coordinates.
        - envelope_data: Envelope-detected data and per-pixel coordinates.
        - image: Reconstructed image data and per-pixel coordinates.
        - segmentation: Segmentation data and per-pixel coordinates.
        - sos_map: Speed-of-sound map data and per-pixel coordinates.
        - strain_percentage_map: Strain map data and per-pixel coordinates.
        - shear_wave_elastography_map: Shear-wave elastography data and per-pixel coordinates.
        - tissue_doppler: Tissue Doppler data and per-pixel coordinates.
        - color_doppler: Color Doppler velocity data and per-pixel coordinates.
        - \\*\\*kwargs: Any other spatially aligned map data and per-pixel coordinates.

    At least one data field (plain-array or grouped) must be provided.
    """

    # Plain-array data products
    raw_data: np.ndarray | None = None

    # Grouped data products
    aligned_data: AlignedData | dict | None = None
    beamformed_data: BeamformedData | dict | None = None
    envelope_data: EnvelopeData | dict | None = None
    image: Image | dict | None = None
    segmentation: Segmentation | dict | None = None
    sos_map: SosMap | dict | None = None
    strain_percentage_map: StrainPercentageMap | dict | None = None
    shear_wave_elastography_map: ShearWaveElastographyMap | dict | None = None
    tissue_doppler: TissueDopplerMap | dict | None = None
    color_doppler: ColorDopplerMap | dict | None = None

    SCHEMA = {
        # Plain-array data products
        "raw_data": {
            "dtype": (np.float32, np.int16),
            "shape": ("n_frames", "n_tx", "n_ax", "n_el", "n_ch"),
        },
        # Grouped data products
        "aligned_data": {"spec": AlignedData},
        "beamformed_data": {"spec": BeamformedData},
        "envelope_data": {"spec": EnvelopeData},
        "image": {"spec": Image},
        "segmentation": {"spec": Segmentation},
        "sos_map": {"spec": SosMap},
        "strain_percentage_map": {"spec": StrainPercentageMap},
        "shear_wave_elastography_map": {"spec": ShearWaveElastographyMap},
        "tissue_doppler": {"spec": TissueDopplerMap},
        "color_doppler": {"spec": ColorDopplerMap},
    }

    FIELD_METADATA = {
        "raw_data": {"unit": "-", "description": "Raw channel data."},
    }

    def __init__(
        self,
        raw_data: np.ndarray | None = None,
        aligned_data: AlignedData | dict | None = None,
        beamformed_data: BeamformedData | dict | None = None,
        envelope_data: EnvelopeData | dict | None = None,
        image: Image | dict | None = None,
        segmentation: Segmentation | dict | None = None,
        sos_map: SosMap | dict | None = None,
        strain_percentage_map: StrainPercentageMap | dict | None = None,
        shear_wave_elastography_map: ShearWaveElastographyMap | dict | None = None,
        tissue_doppler: TissueDopplerMap | dict | None = None,
        color_doppler: ColorDopplerMap | dict | None = None,
        **extra_maps,
    ):
        """Store the known data products plus any custom maps, then validate.

        Args:
            **extra_maps: Custom spatial maps, each a ``Map`` instance or a dict of
                ``Map`` constructor arguments. They are validated as generic ``Map`` specs.

        Raises:
            TypeError: If a custom key collides with a reserved name or its value is a
                plain numpy array.
        """
        self.raw_data = raw_data
        self.aligned_data = aligned_data
        self.beamformed_data = beamformed_data
        self.envelope_data = envelope_data
        self.image = image
        self.segmentation = segmentation
        self.sos_map = sos_map
        self.strain_percentage_map = strain_percentage_map
        self.shear_wave_elastography_map = shear_wave_elastography_map
        self.tissue_doppler = tissue_doppler
        self.color_doppler = color_doppler

        # Custom maps are attached with setattr, so any name that already resolves on this
        # object is off limits. dir(self) also covers SCHEMA and FIELD_METADATA, which
        # dir(Spec) misses; a custom map with one of those names would otherwise overwrite
        # the validation schema or metadata.
        reserved_keys = (
            set(self.SCHEMA)
            | set(self.__dataclass_fields__)
            | set(dir(self))
            | {"_extra_map_keys"}
        )
        for key, value in extra_maps.items():
            if key in reserved_keys:
                raise TypeError(f"Invalid custom data key '{key}': reserved name")
            if isinstance(value, np.ndarray):
                raise TypeError(
                    f"Custom data key '{key}' must be a spatial map "
                    f"(a dict with at least a 'values' key), not a flat array. "
                    f"Only 'raw_data' is accepted as a flat array. "
                    f"Wrap your data: {{'values': array, 'coordinates': coordinates_array}}."
                )
            setattr(self, key, value)

        # Add custom extra maps to the schema as generic Map specs, so they get validated.
        # The extended SCHEMA is set on the instance, leaving the class-level SCHEMA untouched.
        self._extra_map_keys = tuple(extra_maps)
        if self._extra_map_keys:
            self.SCHEMA = {
                **self.SCHEMA,
                **{key: {"spec": Map} for key in self._extra_map_keys},
            }

        # init=False means the dataclass machinery will not call __post_init__ for us.
        self.__post_init__()

    def __post_init__(self):
        """Require at least one data field, validate the schema, and check raw_data n_ch.

        Raises:
            ValueError: If no data field is provided, or raw_data has n_ch not in {1, 2}.
        """
        # Ensure at least one data field is present
        all_data_keys = list(self.SCHEMA)
        has_any = any(getattr(self, k, None) is not None for k in all_data_keys)
        if not has_any:
            raise ValueError(
                "At least one data field must be provided. "
                f"Available fields: {', '.join(all_data_keys)}"
            )

        super().__post_init__()

        # n_ch must be 1 (RF) or 2 (IQ) for raw_data (checked for aligned_data by AlignedData).
        arr = getattr(self, "raw_data", None)
        if arr is not None and isinstance(arr, np.ndarray):
            n_ch = arr.shape[-1]
            if n_ch not in (1, 2):
                raise ValueError(
                    f"'raw_data' must have n_ch ∈ {{1, 2}} (RF or IQ), "
                    f"got n_ch={n_ch} (shape {arr.shape})."
                )
@dataclass
class ScanSpec(Spec):
    """Scan group with acquisition and transmit metadata.

    All fields are aligned with the data format specification.

    Args:
        sampling_frequency: The sampling frequency in Hz.
        center_frequency: The center frequency in Hz of the transmit pulse.
            Single scalar if all transmits share the same center frequency;
            otherwise an array of shape (n_tx,) with one frequency per transmit.
        demodulation_frequency: The frequency in Hz at which the data should be
            demodulated. Usually the same as center_frequency, but different when
            doing harmonic imaging. Single scalar if all transmits share the same
            demodulation frequency; otherwise an array of shape (n_tx,) with one
            frequency per transmit.
        initial_times: The times in seconds when the A/D converter starts sampling
            of shape (n_tx,). This is the time between the first element firing and
            the first recorded sample.
        t0_delays: The transmit delays in seconds for each element of shape
            (n_tx, n_el). This is the time at which each element fires, shifted
            such that the first element fires at t=0.
        tx_apodizations: The apodization values that were applied to each element
            during transmit of shape (n_tx, n_el). This is a value between -1 and 1
            that indicates how much each element contributed to the transmit beam,
            with 0 meaning no contribution and 1 meaning full contribution.
            Negative values indicate that the element was fired with opposite
            polarity.
        focus_distances: The transmit focus distances in meters of shape (n_tx,).
            This is the distance from the origin point on the transducer to where
            the beam comes to focus. For planewaves this is set to infinity or zero.
        transmit_origins: The transmit origins of the transmit beams in meters of
            shape (n_tx, 3). This is the (x, y, z) position from which the beam is
            transmitted.
        polar_angles: The polar angles in radians of the transmit beams of shape
            (n_tx,).
        time_to_next_transmit: The time in seconds between subsequent transmit
            events of shape (n_frames, n_tx).
        azimuth_angles: The azimuthal angles in radians of the transmit beams of
            shape (n_tx,).
        sound_speed: The speed of sound in meters per second.
        tgc_gain_curve: The time-gain-compensation that was applied to every sample
            in the raw_data of shape (n_ax,). Divide by this curve to undo the TGC.
        waveforms_one_way: One-way waveforms of shape (n_tx, n_samples_one_way) as
            simulated by the Verasonics system. This is the waveform after being
            filtered by the transducer bandwidth once.
        waveforms_two_way: Two-way waveforms of shape (n_tx, n_samples_two_way) as
            simulated by the Verasonics system. This is the waveform after being
            filtered by the transducer bandwidth twice.

    Raises:
        ValueError: If a frequency, delay, angle, sound-speed or TGC value is
            outside the range checked in ``__post_init__``.
    """

    sampling_frequency: np.ndarray | float
    center_frequency: np.ndarray | float
    demodulation_frequency: np.ndarray | float
    initial_times: np.ndarray
    t0_delays: np.ndarray
    tx_apodizations: np.ndarray
    focus_distances: np.ndarray
    transmit_origins: np.ndarray
    polar_angles: np.ndarray
    # Optional fields: the annotations now say so explicitly (they default to None).
    time_to_next_transmit: np.ndarray | None = None
    azimuth_angles: np.ndarray | None = None
    sound_speed: np.ndarray | float | None = None
    tgc_gain_curve: np.ndarray | None = None
    waveforms_one_way: np.ndarray | None = None
    waveforms_two_way: np.ndarray | None = None

    SCHEMA = {
        "sampling_frequency": {"dtype": np.float32, "shape": ()},
        "center_frequency": {"dtype": np.float32, "shape": ((), ("n_tx",))},
        "demodulation_frequency": {"dtype": np.float32, "shape": ((), ("n_tx",))},
        "initial_times": {"dtype": np.float32, "shape": ("n_tx",)},
        "t0_delays": {"dtype": np.float32, "shape": ("n_tx", "n_el")},
        "tx_apodizations": {"dtype": np.float32, "shape": ("n_tx", "n_el")},
        "focus_distances": {"dtype": np.float32, "shape": ("n_tx",)},
        "transmit_origins": {"dtype": np.float32, "shape": ("n_tx", 3)},
        "polar_angles": {"dtype": np.float32, "shape": ("n_tx",)},
        "time_to_next_transmit": {"dtype": np.float32, "shape": ("n_frames", "n_tx")},
        "azimuth_angles": {"dtype": np.float32, "shape": ("n_tx",)},
        "sound_speed": {"dtype": np.float32, "shape": ()},
        "tgc_gain_curve": {"dtype": np.float32, "shape": ("n_ax",)},
        "waveforms_one_way": {
            "dtype": np.float32,
            "shape": ("n_tx", "n_samples_one_way"),
        },
        "waveforms_two_way": {
            "dtype": np.float32,
            "shape": ("n_tx", "n_samples_two_way"),
        },
    }

    # NOTE(review): the waveform entries use unit "V", which is not a key of the
    # module-level UNITS table — confirm it is accepted wherever units are checked.
    FIELD_METADATA = {
        "sampling_frequency": {"unit": "Hz", "description": "Sampling frequency."},
        "center_frequency": {
            "unit": "Hz",
            "description": "Center frequency of the transmit pulse.",
        },
        "demodulation_frequency": {"unit": "Hz", "description": "Demodulation frequency."},
        "initial_times": {"unit": "s", "description": "A/D converter start times per transmit."},
        "t0_delays": {"unit": "s", "description": "Transmit delays per element."},
        "tx_apodizations": {"unit": "-", "description": "Transmit apodization per element."},
        "focus_distances": {"unit": "m", "description": "Transmit focus distances."},
        "transmit_origins": {"unit": "m", "description": "Transmit beam origins (x, y, z)."},
        "polar_angles": {"unit": "rad", "description": "Polar angles of transmit beams."},
        "time_to_next_transmit": {"unit": "s", "description": "Time between transmit events."},
        "azimuth_angles": {"unit": "rad", "description": "Azimuthal angles of transmit beams."},
        "sound_speed": {"unit": "m/s", "description": "Speed of sound."},
        "tgc_gain_curve": {"unit": "-", "description": "Time-gain-compensation curve."},
        "waveforms_one_way": {"unit": "V", "description": "One-way transmit waveforms."},
        "waveforms_two_way": {"unit": "V", "description": "Two-way transmit waveforms."},
    }

    @property
    def n_tx(self) -> int:
        """Number of transmits."""
        return self.t0_delays.shape[0]

    @property
    def n_el(self) -> int:
        """Number of elements."""
        return self.t0_delays.shape[1]

    def __post_init__(self):
        """Range-check the scan parameters and collapse constant frequency arrays.

        Runs after the base-class ``__post_init__``. Hard violations raise
        ``ValueError``; values that merely look like a unit mix-up only log a
        warning.
        """
        super().__post_init__()

        if self.sampling_frequency <= 0:
            raise ValueError(f"Sampling frequency must be positive, got {self.sampling_frequency}")

        if np.any(self.center_frequency < 0):
            raise ValueError(f"Center frequency cannot be negative, got {self.center_frequency}")

        if np.any(self.demodulation_frequency < 0):
            raise ValueError(
                f"Demodulation frequency cannot be negative, got {self.demodulation_frequency}"
            )

        if np.any(self.t0_delays < 0):
            raise ValueError(f"Transmit delays cannot be negative, got {self.t0_delays}")

        # Infinite focus is a legitimate value, so it is excluded from the
        # "suspiciously large" check.
        if np.any(np.logical_and(self.focus_distances >= 1, self.focus_distances != np.inf)):
            log.warning(
                "Focus distances greater than or equal to 1 meter may be unusually large. "
                "Maybe you have to convert to meters?"
            )

        if np.any(self.transmit_origins > 1.0) or np.any(self.transmit_origins < -1.0):
            log.warning(
                "Transmit origin values are unusually large, extending beyond +/- 1.0 meters. "
                "Please verify that the transmit origin values are correct and in meters."
            )

        if np.any(self.polar_angles < -np.pi) or np.any(self.polar_angles > np.pi):
            raise ValueError(
                f"Polar angles should be between -pi and pi radians, got values between "
                f"{np.min(self.polar_angles)} and {np.max(self.polar_angles)}"
            )

        if self.azimuth_angles is not None and (
            np.any(self.azimuth_angles < -np.pi) or np.any(self.azimuth_angles > np.pi)
        ):
            raise ValueError(
                f"Azimuth angles should be between -pi and pi radians, got values between "
                f"{np.min(self.azimuth_angles)} and {np.max(self.azimuth_angles)}"
            )

        if self.sound_speed is not None and self.sound_speed <= 0:
            raise ValueError(f"Sound speed must be positive, got {self.sound_speed}")

        if self.tgc_gain_curve is not None and np.any(self.tgc_gain_curve < 0):
            raise ValueError(
                f"TGC gain curve values must be non-negative, got values between "
                f"{np.min(self.tgc_gain_curve)} and {np.max(self.tgc_gain_curve)}"
            )

        # Simplify per-transmit frequency arrays whose entries are all equal by
        # replacing them with that single value. Empty arrays are left untouched
        # so that indexing element 0 cannot fail.
        for name in ("center_frequency", "demodulation_frequency"):
            value = getattr(self, name)
            if (
                isinstance(value, np.ndarray)
                and value.ndim == 1
                and value.size > 0
                and np.all(value == value[0])
            ):
                setattr(self, name, value[0])
@dataclass
class ProbeSpec(Spec):
    """Probe hardware specification.

    Stores static, physical characteristics of the transducer that are not
    captured by the per-acquisition :class:`ScanSpec`. All fields are optional
    so that partial information can be recorded.

    Args:
        name: Probe model identifier (e.g. ``"verasonics_l11_4v"``).
        type: Probe geometry type: ``"linear"``, ``"phased"``, ``"curved"``, etc.
        probe_center_frequency: Probe nominal centre frequency in Hz. Named
            distinctly from :attr:`ScanSpec.center_frequency` (the per-acquisition
            transmit frequency) so the two never collide when merged into a single
            :class:`zea.Parameters` object.
        probe_bandwidth_percent: Fractional bandwidth as a percentage.
        probe_geometry: Element positions in metres, shape (n_el, 3) with columns
            (x, y, z). :attr:`n_el` is derived from this array as a read-only
            property.
        element_width: Width of a single transducer element in metres.
        element_height: Height (elevation aperture) of a single element in metres.
        lens_sound_speed: Speed of sound in the acoustic lens in m/s.
        lens_thickness: Thickness of the acoustic lens in metres.

    Raises:
        ValueError: If ``probe_geometry`` is not of shape (n_el, 3), if any of the
            strictly-positive scalars is <= 0, or if ``lens_thickness`` is negative.
    """

    name: str | None = None
    type: str | None = None
    probe_center_frequency: np.float32 | None = None
    probe_bandwidth_percent: np.float32 | None = None
    probe_geometry: np.ndarray | None = None
    element_width: np.float32 | None = None
    element_height: np.float32 | None = None
    lens_sound_speed: np.float32 | None = None
    lens_thickness: np.float32 | None = None

    SCHEMA = {
        "name": {"dtype": str, "shape": ()},
        "type": {"dtype": str, "shape": ()},
        "probe_center_frequency": {"dtype": np.float32, "shape": ()},
        "probe_bandwidth_percent": {"dtype": np.float32, "shape": ()},
        "probe_geometry": {"dtype": np.float32, "shape": ("n_el", 3)},
        "element_width": {"dtype": np.float32, "shape": ()},
        "element_height": {"dtype": np.float32, "shape": ()},
        "lens_sound_speed": {"dtype": np.float32, "shape": ()},
        "lens_thickness": {"dtype": np.float32, "shape": ()},
    }

    FIELD_METADATA = {
        "name": {"description": "Probe model name/identifier."},
        "type": {"description": "Probe geometry type (linear, phased, curved, ...)."},
        "probe_center_frequency": {
            "unit": "Hz",
            "description": "Probe nominal centre frequency.",
        },
        "probe_bandwidth_percent": {
            "unit": "%",
            "description": "Fractional bandwidth as a percentage.",
        },
        "probe_geometry": {
            "unit": "m",
            "description": "Element positions (x, y, z) per element, shape (n_el, 3).",
        },
        "element_width": {
            "unit": "m",
            "description": "Width of a single transducer element.",
        },
        "element_height": {
            "unit": "m",
            "description": "Height (elevation aperture) of a single transducer element.",
        },
        "lens_sound_speed": {
            "unit": "m/s",
            "description": "Speed of sound in the acoustic lens.",
        },
        "lens_thickness": {
            "unit": "m",
            "description": "Thickness of the acoustic lens.",
        },
    }

    @property
    def n_el(self) -> int | None:
        """Number of transducer elements, derived from :attr:`probe_geometry`.

        Returns ``None`` when no geometry has been provided.
        """
        if self.probe_geometry is not None:
            return int(self.probe_geometry.shape[0])
        return None

    def __post_init__(self):
        """Validate the geometry shape and the sign of every scalar that is set."""
        super().__post_init__()

        geometry = self.probe_geometry
        if geometry is not None:
            if geometry.ndim != 2 or geometry.shape[1] != 3:
                raise ValueError(
                    f"ProbeSpec: probe_geometry must have shape (n_el, 3), "
                    f"got {geometry.shape}"
                )
            # Coordinates beyond one metre are far more likely a unit mix-up
            # than a real probe, so warn instead of rejecting.
            if np.any(geometry > 1.0) or np.any(geometry < -1.0):
                log.warning(
                    "ProbeSpec probe_geometry values extend beyond \u00b11.0 m. "
                    "Please verify the values are in metres."
                )

        # These scalars must be strictly positive whenever they are provided.
        # The order matches the order in which violations were reported before.
        for field_name in (
            "probe_center_frequency",
            "probe_bandwidth_percent",
            "element_width",
            "element_height",
            "lens_sound_speed",
        ):
            value = getattr(self, field_name)
            if value is not None and value <= 0:
                raise ValueError(f"ProbeSpec: {field_name} must be positive, got {value}")

        # A lens thickness of exactly zero is allowed, so only negatives are rejected.
        if self.lens_thickness is not None and self.lens_thickness < 0:
            raise ValueError(
                f"ProbeSpec: lens_thickness must be non-negative, got {self.lens_thickness}"
            )
@dataclass
class Subject(Spec):
    """Subject metadata associated with the study.

    Args:
        id: Subject ID.
        type: Subject type, e.g. human, phantom, animal.
        age: Subject age in years.
        sex: Subject sex.
        fat_percentage: Subject fat percentage, between 0 and 100 inclusive.

    Raises:
        ValueError: If ``id`` is empty or whitespace-only, or if
            ``fat_percentage`` lies outside [0, 100].
    """

    id: str | None = None
    type: str | None = None
    age: np.uint8 | None = None
    sex: str | None = None
    fat_percentage: np.float32 | None = None

    SCHEMA = {
        "id": {"dtype": str, "shape": ()},
        "type": {"dtype": str, "shape": ()},
        "age": {"dtype": np.uint8, "shape": ()},
        "sex": {"dtype": str, "shape": ()},
        "fat_percentage": {"dtype": np.float32, "shape": ()},
    }

    FIELD_METADATA = {
        "id": {"description": "Subject ID. Needed for subject-wise splits."},
    }

    def __post_init__(self):
        """Reject blank subject IDs and out-of-range fat percentages."""
        super().__post_init__()

        # An ID that is present but blank is treated as an error.
        if self.id is not None and not self.id.strip():
            raise ValueError("Subject ID cannot be an empty string")

        if self.fat_percentage is not None and (
            self.fat_percentage < 0 or self.fat_percentage > 100
        ):
            raise ValueError(
                f"Subject fat percentage must be between 0 and 100, got {self.fat_percentage}"
            )
@dataclass
class Signal(Spec):
    """Common base for auxiliary signals that carry their own time base.

    Args:
        start_time_offset: Offset in seconds between the first transmit event of
            the ultrasound acquisition and sample 0 of this signal. A negative
            value means the signal starts before the first transmit event, a
            positive value means it starts after.
        sampling_frequency: Sampling frequency of the signal in Hz.

    Raises:
        ValueError: If ``sampling_frequency`` is not strictly positive.
    """

    start_time_offset: np.ndarray | float
    sampling_frequency: np.ndarray | float

    SCHEMA = {
        "start_time_offset": {"dtype": np.float32, "shape": ()},
        "sampling_frequency": {"dtype": np.float32, "shape": ()},
    }

    FIELD_METADATA = {
        "start_time_offset": {
            "unit": "s",
            "description": (
                "Time offset between the first transmit event of the ultrasound "
                "acquisition and sample 0 of this data. Negative means this data "
                "starts before the first transmit event; positive means it starts "
                "after."
            ),
        },
        "sampling_frequency": {
            "unit": "Hz",
            "description": "Sampling frequency.",
        },
    }

    def __post_init__(self):
        """Check that the sampling frequency is strictly positive."""
        super().__post_init__()
        rate = self.sampling_frequency
        if rate <= 0:
            raise ValueError(f"Sampling frequency must be positive, got {rate}")
@dataclass
class ProbePose(Signal):
    """Time series of the probe pose, measured at the tip of the transducer.

    Coordinate convention: x is lateral along the transducer, y is elevation
    (out of plane) and z is axial (depth).

    Args:
        translation: Position of the transducer tip in meters of shape (T, 3),
            ordered as (x, y, z).
        rotation: Orientation of the transducer tip of shape (T, 3) or (T, 4),
            interpreted according to ``rotation_representation``.
        rotation_representation: Rotation parameterization. Supported values are
            ``"euler_xyz"``, ``"quaternion_wxyz"``, and ``"quaternion_xyzw"``.
        start_time_offset: Time offset in seconds between the first transmit
            event of the ultrasound acquisition and sample 0 of this data.
        sampling_frequency: Sampling frequency in Hz for probe pose samples.

    Raises:
        ValueError: If translation and rotation differ in their number of time
            samples, if the representation is unknown, or if the rotation width
            does not match the representation.
    """

    translation: np.ndarray
    rotation: np.ndarray
    rotation_representation: str

    SCHEMA = {
        "translation": {"dtype": np.float32, "shape": ("T", 3)},
        "rotation": {"dtype": np.float32, "shape": (("T", 3), ("T", 4))},
        "rotation_representation": {"dtype": str, "shape": ()},
        **Signal.SCHEMA,
    }

    FIELD_METADATA = {
        "translation": {
            "unit": "m",
            "description": (
                "Position of the transducer tip, ordered as (x, y, z), where x is "
                "lateral along the transducer, y is elevation (out of plane), and "
                "z is axial (depth)."
            ),
        },
        "rotation": {
            "unit": "-",
            "description": (
                "Orientation associated with the transducer-tip pose in the "
                "x-lateral, y-elevation, z-axial coordinate convention, interpreted "
                "according to rotation_representation."
            ),
        },
        "rotation_representation": {
            "unit": "-",
            "description": (
                "Rotation parameterization: one of euler_xyz, quaternion_wxyz, or quaternion_xyzw."
            ),
        },
        **Signal.FIELD_METADATA,
    }

    def __post_init__(self):
        """Cross-check translation, rotation and the declared representation."""
        super().__post_init__()

        # Number of rotation components each supported representation requires.
        components_per_representation = {
            "euler_xyz": 3,
            "quaternion_wxyz": 4,
            "quaternion_xyzw": 4,
        }

        n_translation = self.translation.shape[0]
        n_rotation = self.rotation.shape[0]
        if n_translation != n_rotation:
            raise ValueError(
                "translation and rotation must have the same number of time samples, "
                f"got {n_translation} and {n_rotation}"
            )

        representation = self.rotation_representation
        if representation not in components_per_representation:
            valid = ", ".join(sorted(components_per_representation))
            raise ValueError(
                f"rotation_representation must be one of {{{valid}}}, "
                f"got {representation!r}"
            )

        if self.rotation.shape[1] != components_per_representation[representation]:
            raise ValueError(
                "rotation shape does not match rotation_representation: "
                f"got {self.rotation.shape} for {representation!r}"
            )
@dataclass
class Signal1D(Signal):
    """A one-dimensional sampled signal together with its timing metadata.

    Args:
        samples: Signal samples of shape (T,), with dtype uint8, float32, int16
            or complex64.
        start_time_offset: Time offset in seconds between the first transmit
            event of the ultrasound acquisition and sample 0 of this data.
        sampling_frequency: Sampling frequency in Hz for signal samples.
    """

    samples: np.ndarray

    SCHEMA = {
        "samples": {
            "dtype": (np.uint8, np.float32, np.int16, np.complex64),
            "shape": ("T",),
        },
        **Signal.SCHEMA,
    }

    FIELD_METADATA = {
        "samples": {
            "unit": "-",
            "description": "Signal samples.",
        },
        **Signal.FIELD_METADATA,
    }
@dataclass
class SignalND(Signal):
    """An N-dimensional sampled signal together with its timing metadata.

    Args:
        samples: Signal samples of shape (T, ...), with dtype uint8, float32,
            int16 or complex64. The leading axis is the time axis; the schema
            leaves all trailing axes unconstrained via the ``"..."`` wildcard.
        start_time_offset: Time offset in seconds between the first transmit
            event of the ultrasound acquisition and sample 0 of this data.
        sampling_frequency: Sampling frequency in Hz for signal samples.
    """

    samples: np.ndarray

    SCHEMA = {
        "samples": {
            "dtype": (np.uint8, np.float32, np.int16, np.complex64),
            "shape": ("T", "..."),
        },
        **Signal.SCHEMA,
    }

    FIELD_METADATA = {
        "samples": {
            "unit": "-",
            "description": "Signal samples.",
        },
        **Signal.FIELD_METADATA,
    }
@dataclass
class Annotations(Spec):
    """Frame-level annotations, given per frame or as a single broadcast label.

    Args:
        anatomy: Anatomy label; either one label per frame of shape (n_frames,)
            or a single scalar label.
        view: View label of shape (n_frames,).
        label: Pathology or classification label of shape (n_frames,).
        image_quality: Image quality label, e.g. low, mid, high; either one
            label per frame of shape (n_frames,) or a single scalar label.
    """

    anatomy: np.ndarray | str | None = None
    view: np.ndarray | None = None
    label: np.ndarray | None = None
    image_quality: np.ndarray | str | None = None

    SCHEMA = {
        "anatomy": {
            "dtype": np.str_,
            "shape": (("n_frames",), ()),
        },
        "view": {
            "dtype": np.str_,
            "shape": ("n_frames",),
        },
        "label": {
            "dtype": np.str_,
            "shape": ("n_frames",),
        },
        "image_quality": {
            "dtype": np.str_,
            "shape": (("n_frames",), ()),
        },
    }
@dataclass(init=False)
class MetadataSpec(Spec):
    """Metadata group with subject, acquisition context, annotations, and extra signals.

    Besides the named fields, arbitrary additional keyword arguments are accepted
    as custom signals. Each one is stored as an attribute and registered in an
    instance-level copy of ``SCHEMA`` as a :class:`SignalND`.

    Args:
        subject: Subject metadata.
        credit: Credit or attribution for the dataset.
        probe_pose: Sampled probe pose at the transducer tip.
        voice_narration: Voice narration signal.
        ecg: Electrocardiogram signal.
        text_report: Free-text report associated with the study.
        annotations: Frame-level annotations.
        **extra_signals: Custom signals, keyed by name.

    Raises:
        TypeError: If a custom key collides with a reserved name, or if a custom
            value is passed as a bare ``np.ndarray``.
    """

    # NOTE(review): the field default is ``Subject()`` but the hand-written
    # ``__init__`` below stores ``None`` when no subject is passed — confirm
    # that this is intended.
    subject: Subject | dict = field(default_factory=Subject)
    credit: str | None = None
    probe_pose: ProbePose | dict | None = None
    voice_narration: Signal1D | dict | None = None
    ecg: Signal1D | dict | None = None
    text_report: str | None = None
    annotations: Annotations | dict | None = None

    SCHEMA = {
        "subject": {"spec": Subject},
        "credit": {"dtype": str, "shape": ()},
        "probe_pose": {"spec": ProbePose},
        "voice_narration": {"spec": Signal1D},
        "ecg": {"spec": Signal1D},
        "text_report": {"dtype": str, "shape": ()},
        "annotations": {"spec": Annotations},
    }

    FIELD_METADATA = {
        "credit": {"unit": "-", "description": "Credit or attribution for the dataset."},
        "probe_pose": {"unit": "-", "description": "Sampled probe pose at the transducer tip."},
        "voice_narration": {"unit": "-", "description": "Voice narration signal."},
        "ecg": {"unit": "-", "description": "Electrocardiogram signal."},
        "text_report": {"unit": "-", "description": "Free-text report associated with the study."},
        "annotations": {"unit": "-", "description": "Frame-level annotations."},
    }

    def __init__(
        self,
        subject: Subject | dict | None = None,
        credit: str | None = None,
        probe_pose: ProbePose | dict | None = None,
        voice_narration: Signal1D | dict | None = None,
        ecg: Signal1D | dict | None = None,
        text_report: str | None = None,
        annotations: Annotations | dict | None = None,
        **extra_signals,
    ):
        """Store the named fields, register custom signals, then validate."""
        self.subject = subject
        self.credit = credit
        self.probe_pose = probe_pose
        self.voice_narration = voice_narration
        self.ecg = ecg
        self.text_report = text_report
        self.annotations = annotations

        # Custom names must not shadow schema entries, dataclass fields, or
        # anything defined on the Spec base class.
        forbidden = set(self.SCHEMA) | set(self.__dataclass_fields__) | set(dir(Spec))
        for name, signal in extra_signals.items():
            if name in forbidden:
                raise TypeError(f"Invalid custom metadata key '{name}': reserved name")
            if isinstance(signal, np.ndarray):
                raise TypeError(
                    f"Custom metadata key '{name}' must be a SignalND "
                    f"(a dict with 'samples', 'start_time_offset', and 'sampling_frequency'), "
                    f"not a flat array. "
                    f"Wrap your data: {{'samples': array, 'start_time_offset': 0.0, "
                    f"'sampling_frequency': fs}}."
                )
            setattr(self, name, signal)

        # Add custom extra signals to the schema as generic SignalND specs, so they
        # get validated. Assigning to self.SCHEMA creates an instance attribute;
        # the class-level SCHEMA is left untouched.
        self._extra_signal_keys = tuple(extra_signals)
        if self._extra_signal_keys:
            custom_entries = {name: {"spec": SignalND} for name in self._extra_signal_keys}
            self.SCHEMA = {**self.SCHEMA, **custom_entries}

        # init=False means no generated __init__ calls this for us.
        self.__post_init__()

    def __post_init__(self):
        """Delegate to the base-class ``__post_init__``."""
        super().__post_init__()
@dataclass
class MetricsSpec(Spec):
    """Metrics group for acquisition-level quality/performance metrics.

    Both metrics are optional.

    Args:
        common_midpoint_phase_error: Common midpoint phase error in radians, of
            shape (n_frames,) and type float32.
        coherence_factor: Coherence factor, of shape (n_frames,) and type float32.
    """

    common_midpoint_phase_error: np.ndarray | None = None
    coherence_factor: np.ndarray | None = None

    SCHEMA = {
        "common_midpoint_phase_error": {"dtype": np.float32, "shape": ("n_frames",)},
        "coherence_factor": {
            "dtype": np.float32,
            "shape": ("n_frames",),
        },
    }
@dataclass
class TrackSpec(Spec):
    """A single acquisition track with its own data and scan parameters.

    Used inside a multi-track :class:`FileSpec` where different transmit
    sequences coexist in the same acquisition. The ``track_schedule`` on
    ``FileSpec`` specifies the global ordering of transmits across all tracks.

    For multi-track files a human-readable ``label`` is required on every track
    so that users can identify which track is which (e.g. ``"focused"`` vs
    ``"planewave"``). Further information can be provided in the ``description``
    field of the parent :class:`FileSpec`, if necessary. Single-track files may
    omit the label.

    Args:
        data (DataSpec | dict): The data for this track.
        scan (ScanSpec | dict | None): The scan parameters for this track.
            Required when raw_data is present in *data*.
        label (str | None): Short human-readable name for this track (e.g.
            ``"focused"`` or ``"planewave"``). Required when the parent
            :class:`FileSpec` contains more than one track.

    Raises:
        ValueError: If raw data is present without ``scan``, or if ``label`` is
            empty or whitespace-only.
        TypeError: If ``label`` is set but is not a ``str``.
    """

    data: DataSpec | dict
    scan: ScanSpec | dict | None = None
    label: str | None = None

    SCHEMA = {
        "data": {"spec": DataSpec},
        "scan": {"spec": ScanSpec},
        "label": {"dtype": str, "shape": ()},
    }

    def __post_init__(self):
        """Enforce the raw-data/scan dependency and validate the label."""
        super().__post_init__()

        # The track data may be either a DataSpec or a plain dict; look for raw
        # data in whichever form is present.
        payload = self.data
        raw_in_spec = isinstance(payload, DataSpec) and payload.raw_data is not None
        raw_in_dict = isinstance(payload, dict) and payload.get("raw_data") is not None
        if (raw_in_spec or raw_in_dict) and self.scan is None:
            raise ValueError("'scan' is required when 'raw_data' is provided in track data.")

        label = self.label
        if label is None:
            return
        if not isinstance(label, str):
            raise TypeError(f"'label' must be a str, got {type(label)}")
        if not label.strip():
            raise ValueError("'label' must not be an empty or whitespace-only string.")

    def store_in_group(
        self,
        group: "h5py.Group",
        compression: str = DEFAULT_COMPRESSION,
        chunk_frames: bool = False,
    ) -> None:
        """Store data, scan, and label in the HDF5 group.

        Forwards all arguments unchanged to the base-class implementation.
        """
        super().store_in_group(
            group,
            compression=compression,
            chunk_frames=chunk_frames,
        )
[docs] @dataclass class FileSpec(Spec): """A dataset containing all the data, scan parameters, metadata, and metrics for a single acquisition. A ``FileSpec`` always contains at least one track. When ``data`` and ``scan`` are supplied at construction time they are transparently wrapped into a single :class:`TrackSpec`, so all existing call-sites continue to work unchanged. For multi-track files pass ``tracks`` directly. Args: data: Data for a single-track acquisition (wrapped into ``tracks[0]``). scan: Scan parameters for a single-track acquisition. tracks: Explicit list of :class:`TrackSpec` objects (multi-track mode). Mutually exclusive with ``data``/``scan``. track_schedule: 1-D int32 array of length ``n_total_tx`` giving the track index for each global transmit event. metadata: Additional metadata about the acquisition. metrics: Metrics computed from the acquisition. probe: Physical probe specification (see :class:`ProbeSpec`). The probe name is stored as ``probe.name``; use :attr:`zea.File.probe_name` to read it back from an HDF5 file. us_machine: The ultrasound machine used to acquire the data. description: Free-text description. Example: .. doctest:: >>> from zea.data.spec import FileSpec >>> import numpy as np >>> dataset = FileSpec( ... data={ ... "raw_data": np.zeros((2, 4, 64, 8, 1), dtype=np.float32), ... }, ... scan={ ... "sampling_frequency": np.float32(40e6), ... "center_frequency": np.float32(5e6), ... "demodulation_frequency": np.float32(5e6), ... "initial_times": np.zeros(4, dtype=np.float32), ... "t0_delays": np.zeros((4, 8), dtype=np.float32), ... "tx_apodizations": np.ones((4, 8), dtype=np.float32), ... "focus_distances": np.full(4, np.inf, dtype=np.float32), ... "transmit_origins": np.zeros((4, 3), dtype=np.float32), ... "polar_angles": np.zeros(4, dtype=np.float32), ... }, ... 
) >>> dataset.data.raw_data.shape (2, 4, 64, 8, 1) """ # NOTE: data and scan are intentionally NOT dataclass fields — they are # accepted as constructor kwargs and folded into tracks[0] at init time. # @property accessors below provide backwards-compatible single-track access. tracks: list = field(default_factory=list) track_schedule: np.ndarray | None = None metadata: MetadataSpec | dict = field(default_factory=MetadataSpec) metrics: MetricsSpec | dict = field(default_factory=MetricsSpec) probe: ProbeSpec | dict | None = None us_machine: str | None = None description: str | None = None # tells the SCHEMA ↔ fields consistency test that 'tracks' is intentionally # absent from SCHEMA (list[TrackSpec] doesn't fit the standard SCHEMA patterns) _SCHEMA_EXCLUDED_FIELDS = frozenset({"tracks"}) SCHEMA = { "track_schedule": {"dtype": np.int32, "shape": ("n_total_tx",)}, "metadata": {"spec": MetadataSpec}, "metrics": {"spec": MetricsSpec}, "probe": {"spec": ProbeSpec}, "us_machine": {"dtype": str, "shape": ()}, "description": {"dtype": str, "shape": ()}, } def __init__( self, data: "DataSpec | dict | None" = None, scan: "ScanSpec | dict | None" = None, tracks: "list | None" = None, track_schedule: "np.ndarray | None" = None, metadata: "MetadataSpec | dict | None" = None, metrics: "MetricsSpec | dict | None" = None, probe_name: "str | None" = None, probe: "ProbeSpec | dict | None" = None, us_machine: "str | None" = None, description: "str | None" = None, ): if data is not None or scan is not None: if tracks: raise ValueError( "Provide either 'data'/'scan' (single-track shorthand) " "or 'tracks' (multi-track), not both." ) _implicit_track: "dict | None" = {"data": data, "scan": scan} else: _implicit_track = None if probe_name is not None: raise TypeError( "probe_name is not a FileSpec parameter. " "Use probe={'name': ...} to specify the probe name." 
) self.tracks = list(tracks) if tracks is not None else [] self.track_schedule = track_schedule self.metadata = metadata if metadata is not None else MetadataSpec() self.metrics = metrics if metrics is not None else MetricsSpec() self.probe = probe self.us_machine = us_machine self.description = description self.__post_init__(_implicit_track) # ------------------------------------------------------------------ # Backwards-compat read properties (single-track files only) # ------------------------------------------------------------------ @property def data(self) -> "DataSpec": """Return the :class:`DataSpec` of the single track. Raises :exc:`AttributeError` when the file has more than one track — use ``spec.tracks[i].data`` instead. """ if len(self.tracks) != 1: raise AttributeError( f"'data' is only available for single-track FileSpecs " f"({len(self.tracks)} tracks present). Use spec.tracks[i].data." ) return self.tracks[0].data @property def scan(self) -> "ScanSpec | None": """Return the :class:`ScanSpec` of the single track. Raises :exc:`AttributeError` when the file has more than one track — use ``spec.tracks[i].scan`` instead. """ if len(self.tracks) != 1: raise AttributeError( f"'scan' is only available for single-track FileSpecs " f"({len(self.tracks)} tracks present). Use spec.tracks[i].scan." 
) return self.tracks[0].scan def __post_init__(self, _implicit_track: "dict | None" = None): # Fold implicit data/scan into a TrackSpec if provided if _implicit_track is not None: self.tracks = [TrackSpec(**_implicit_track)] if not self.tracks: raise ValueError("A FileSpec must contain at least one track.") # Create TrackSpecs from dictionaries in the tracks list, if needed, and validate all tracks track_specs = [] for i, t in enumerate(self.tracks): if isinstance(t, dict): try: t = TrackSpec(**t) except (TypeError, ValueError) as e: raise type(e)(f"In tracks[{i}]: {e}") from e elif not isinstance(t, TrackSpec): raise TypeError(f"tracks[{i}] must be a TrackSpec or dict, got {type(t)}") track_specs.append(t) self.tracks = track_specs # For multi-track files every track must have a label so users can # identify tracks by name rather than relying on numeric indices. if len(self.tracks) > 1: missing = [i for i, t in enumerate(self.tracks) if not t.label] if missing: raise ValueError( f"All tracks in a multi-track file must have a 'label'. " f"Missing label for track(s) at index: {missing}. " f"Provide a short descriptive name for each track, e.g. " f"'focused' or 'planewave', so that " f"File.get_track(label) and File.track_labels work correctly." 
) # Validate track_schedule indices are in range if self.track_schedule is not None: n_tracks = len(self.tracks) if not np.all((self.track_schedule >= 0) & (self.track_schedule < n_tracks)): raise ValueError( f"All track_schedule indices must be in [0, {n_tracks - 1}], " f"got min={self.track_schedule.min()}, max={self.track_schedule.max()}" ) # Warn if multi-track frame counts differ without a schedule if len(self.tracks) > 1 and self.track_schedule is None: frame_counts = [] for track in self.tracks: rd = ( track.data.raw_data if isinstance(track.data, DataSpec) else (track.data.get("raw_data") if isinstance(track.data, dict) else None) ) if rd is not None and hasattr(rd, "shape"): frame_counts.append(rd.shape[0]) if len(set(frame_counts)) > 1: log.warning( "Tracks have different numbers of frames " f"({frame_counts}). Without a 'track_schedule' it is " "ambiguous how frames correspond across tracks. Consider " "passing 'track_schedule' to make the relationship explicit." ) # Run base SCHEMA validation (metadata, metrics, scalars, track_schedule) super().__post_init__() # Validate that dimensions which are present in both metadata and tracks # are consistent across all tracks. if isinstance(self.metadata, MetadataSpec): _, meta_dim_sizes = self.metadata._collect_dimension_info("metadata.") for i, track in enumerate(self.tracks): _, track_dim_sizes = track._collect_dimension_info(f"tracks[{i}].") for dim in CONSISTENCY_DIMENSIONS: if dim in meta_dim_sizes and dim in track_dim_sizes: all_sizes = meta_dim_sizes[dim] | track_dim_sizes[dim] if len(all_sizes) > 1: meta_fields, _ = self.metadata._collect_dimension_info("metadata.") track_fields, _ = track._collect_dimension_info(f"tracks[{i}].") raise ValueError( f"Dimension '{dim}' has inconsistent sizes across " f"fields {sorted(meta_fields[dim] | track_fields[dim])}: " f"{sorted(all_sizes)}" )
def to_dict(self) -> dict:
    """Serialise this spec, tracks included, as a nested dictionary.

    Returns:
        dict: The dictionary produced by the parent class's ``to_dict`` with
        an added ``"tracks"`` key holding one ``to_dict()`` result per track,
        in track order.
    """
    as_dict = super().to_dict()
    as_dict["tracks"] = [track.to_dict() for track in self.tracks]
    return as_dict
def save(
    self,
    path: str,
    compression: str = DEFAULT_COMPRESSION,
    chunk_frames: bool = False,
) -> None:
    """Write this spec to a file at ``path``.

    Missing parent directories are created, and the file is opened through
    ``zea.File`` in ``"w"`` mode. Fields listed in ``SCHEMA`` whose value is
    ``None`` are not written. Tracks go to ``tracks/track_<i>``.

    Args:
        path: Destination file path.
        compression: Forwarded to every ``store_in_group`` call and to
            ``create_dataset``.
        chunk_frames: Forwarded to each track's ``store_in_group``.
    """
    # Imported lazily: file.py imports spec.py, so a module-level import
    # would be circular.
    from zea import File

    try:
        package_version = _get_pkg_version("zea")
    except PackageNotFoundError:
        package_version = "dev"

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    with File(str(path), "w") as h5_file:
        h5_file.attrs["zea_version"] = package_version

        # Root-level SCHEMA fields: nested specs become groups,
        # track_schedule becomes a dataset, everything else an attribute.
        for field_name, field_schema in self.SCHEMA.items():
            field_value = getattr(self, field_name)
            if field_value is None:
                continue
            if "spec" in field_schema:
                spec_group = h5_file.create_group(field_name)
                field_value.store_in_group(spec_group, compression=compression)
            elif field_name == "track_schedule":
                # Array field — store as dataset, not attr
                self.create_dataset(h5_file, field_name, field_value, compression=compression)
            else:
                h5_file.attrs[field_name] = field_value

        # One sub-group per track, named by its position in self.tracks.
        all_tracks = h5_file.create_group("tracks")
        for index, track in enumerate(self.tracks):
            track.store_in_group(
                all_tracks.create_group(f"track_{index}"),
                compression=compression,
                chunk_frames=chunk_frames,
            )

    log.info(f"File saved to {log.yellow(path)}")
@classmethod
def from_hdf5(cls, file: h5py.File) -> "FileSpec":
    """Build a :class:`FileSpec` from an open HDF5 file.

    Two layouts are read: the multi-track ``tracks/track_N/`` layout and the
    legacy flat layout with ``data/`` and ``scan/`` groups at the root. Keys
    in a scan group that are not in ``ScanSpec.SCHEMA`` are dropped. If no
    ``probe`` group was loaded and ``file.probe_name`` exists and is not
    ``None``, it is passed on as ``probe={"name": ...}``. Plain arrays under
    legacy ``data/`` that are not plain-array fields of ``DataSpec.SCHEMA``
    are wrapped as ``{"values": array}`` and a warning is logged.

    Args:
        file: An open ``h5py.File`` (or :class:`zea.File`).

    Returns:
        FileSpec: The result of ``cls(**kwargs)`` for the loaded content.
    """

    def _load_group_as_dict(group: h5py.Group) -> dict:
        """Recursively read the groups and datasets of ``group`` into a dict."""
        loaded = {}
        for name in group.keys():
            node = group[name]
            if isinstance(node, h5py.Group):
                loaded[name] = _load_group_as_dict(node)
                continue
            if not isinstance(node, h5py.Dataset):
                continue
            if h5py.check_string_dtype(node.dtype) is None:
                loaded[name] = node[()]
                continue
            text = node.asstr()[()]
            # h5py returns object-dtype arrays for strings; convert back to
            # np.str_ so spec dtype checks pass.
            if isinstance(text, np.ndarray) and text.dtype == object:
                text = text.astype(np.str_)
            loaded[name] = text
        return loaded

    kwargs: dict[str, Any] = {}

    # Root-level SCHEMA fields: nested specs are groups, track_schedule is a
    # dataset, and the remaining fields are file attributes.
    for field_name, field_schema in cls.SCHEMA.items():
        is_spec = "spec" in field_schema
        if not is_spec and field_name != "track_schedule":
            if field_name in file.attrs:
                kwargs[field_name] = file.attrs[field_name]
        elif field_name not in file:
            continue
        elif is_spec:
            kwargs[field_name] = _load_group_as_dict(file[field_name])
        else:
            kwargs[field_name] = file[field_name][()].astype(np.int32)

    if "tracks" in file:
        # Multi-track layout: consecutive tracks/track_0, track_1, ...
        tracks_group = file["tracks"]
        known_scan_keys = set(ScanSpec.SCHEMA)
        loaded_tracks = []
        index = 0
        while True:
            track_name = f"track_{index}"
            if track_name not in tracks_group:
                break
            track_dict = _load_group_as_dict(tracks_group[track_name])
            # Drop scan keys unknown to ScanSpec, as done for the root-level
            # scan group further down.
            scan_entry = track_dict.get("scan")
            if isinstance(scan_entry, dict):
                track_dict["scan"] = {
                    k: v for k, v in scan_entry.items() if k in known_scan_keys
                }
            loaded_tracks.append(track_dict)
            index += 1
        kwargs["tracks"] = loaded_tracks
    elif "data" in file or "scan" in file:
        # Legacy flat layout: data/ and scan/ groups at the root.
        kwargs["data"] = _load_group_as_dict(file["data"]) if "data" in file else {}
        if "scan" in file:
            kwargs["scan"] = _load_group_as_dict(file["scan"])

    # 1. Without a loaded probe group, fall back to the file object's
    #    ``probe_name`` attribute when it exists and is not None.
    if "probe" not in kwargs:
        legacy_name = getattr(file, "probe_name", None)
        if legacy_name is not None:
            kwargs["probe"] = {"name": legacy_name}

    # 2. Keep only the scan keys recognised by ScanSpec.SCHEMA, so legacy
    #    scalar fields (n_frames, n_ax, n_el, n_tx, n_ch, ...) are dropped.
    if "scan" in kwargs:
        known_scan_keys = set(ScanSpec.SCHEMA)
        kwargs["scan"] = {k: v for k, v in kwargs["scan"].items() if k in known_scan_keys}

    # 3. Wrap legacy plain arrays under data/ as {"values": array}. Fields
    #    whose DataSpec.SCHEMA entry has no "spec" key (plain-array fields
    #    such as raw_data) are left untouched.
    data_entry = kwargs.get("data")
    if isinstance(data_entry, dict):
        for key, stored in list(data_entry.items()):
            if not isinstance(stored, np.ndarray):
                continue
            entry_schema = DataSpec.SCHEMA.get(key)
            if entry_schema is not None and "spec" not in entry_schema:
                continue
            log.warning(
                "Legacy flat dataset 'data/%s' has no spatial coordinates. "
                "The array has been loaded as 'values'; coordinates information "
                "was not stored in this file and will be None.",
                key,
            )
            data_entry[key] = {"values": stored}

    return cls(**kwargs)