from collections import defaultdict
from dataclasses import MISSING, dataclass, field, fields
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _get_pkg_version
from pathlib import Path
from typing import Any, List, Tuple
import h5py
import numpy as np
from zea import log
# Named dimensions whose sizes must agree across every field of a spec
# (see Spec._track_named_dimensions). Any other string in a SCHEMA shape is
# treated as a wildcard by match_shape.
CONSISTENCY_DIMENSIONS = {"n_frames", "n_tx", "n_ax", "n_el", "n_ch", "n_spatial_ch"}
# Unit symbols mapped to their spelled-out names.
UNITS = {
    "m/s": "meters per second",
    "m": "meters",
    "Hz": "Hertz",
    "s": "seconds",
    "-": "unitless",
    "rad": "radians",
    "dB": "decibels",
    "#": "count",
    "%": "percent",
}
# Compression filter passed to h5py when the caller does not override it.
DEFAULT_COMPRESSION = "lzf"
# Default unit/description for every SCHEMA leaf field. Subclasses may
# override by defining their own FIELD_METADATA dict.
_DEFAULT_FIELD_UNIT = "-"
_DEFAULT_FIELD_DESCRIPTION = ""
[docs]
def check_dtype(value: Any, expected_dtype: List[type]) -> None:
    """Raise ``TypeError`` unless ``value`` is compatible with one of ``expected_dtype``.

    Works for numpy arrays, numpy scalars, and Python native types:

    * For a numpy scalar type, any value exposing ``.dtype`` is accepted when that dtype
      is a sub-dtype of the expected one; a native ``str``/``bytes`` is accepted when the
      expected numpy type is a character type.
    * For any other entry, a plain ``isinstance`` check is used.

    Args:
        value: The value to check.
        expected_dtype: The accepted types (numpy scalar types and/or Python types).

    Raises:
        TypeError: If none of the expected types accepts ``value``.
    """
    carries_dtype = hasattr(value, "dtype")

    def _accepts(candidate) -> bool:
        # Non-numpy entries (e.g. ``str``) are checked with isinstance only.
        if not (isinstance(candidate, type) and issubclass(candidate, np.generic)):
            return isinstance(value, candidate)
        np_candidate = np.dtype(candidate)
        if carries_dtype:
            return bool(np.issubdtype(value.dtype, np_candidate))
        # Native text values are compatible with numpy character dtypes.
        return bool(np.issubdtype(np_candidate, np.character)) and isinstance(
            value, (str, bytes)
        )

    if any(_accepts(candidate) for candidate in expected_dtype):
        return

    if carries_dtype:
        actual_type = f"dtype {value.dtype}"
    else:
        actual_type = f"Python {type(value).__name__}"
    expected_dtypes_str = ", ".join(str(dt) for dt in expected_dtype)
    raise TypeError(
        f"Expected dtype compatible with one of ({expected_dtypes_str}), got {actual_type}. "
        f"Hint: wrap the value with the appropriate numpy type, "
        f"e.g. np.float32(...), np.str_(...), np.uint8(...)."
    )
[docs]
def value_shape(value: Any) -> tuple:
    """Return ``value.shape`` for numpy arrays and ``()`` for everything else."""
    return value.shape if isinstance(value, np.ndarray) else ()
[docs]
def match_shape(value: Any, expected_shape: tuple) -> bool:
    """Return True when the shape of ``value`` satisfies ``expected_shape``.

    Entries of ``expected_shape`` are interpreted as follows:

    * an ``int`` must equal the size of the corresponding axis,
    * a string names a dimension and matches any size,
    * a single ``'...'`` matches any number of axes (including zero).

    Raises:
        ValueError: If ``expected_shape`` contains more than one ``'...'``.
    """
    actual = value_shape(value)

    wildcard_at = [pos for pos, dim in enumerate(expected_shape) if dim == "..."]
    if len(wildcard_at) > 1:
        raise ValueError("Expected shape can contain at most one '...' wildcard")

    if wildcard_at:
        split = wildcard_at[0]
        head_spec = expected_shape[:split]
        tail_spec = expected_shape[split + 1 :]
        # The wildcard may absorb zero axes, so only head + tail are mandatory.
        if len(actual) < len(head_spec) + len(tail_spec):
            return False
        head_actual = actual[: len(head_spec)]
        tail_actual = actual[len(actual) - len(tail_spec) :] if tail_spec else ()
        pairs = zip(head_actual + tail_actual, head_spec + tail_spec)
    else:
        if len(actual) != len(expected_shape):
            return False
        pairs = zip(actual, expected_shape)

    # Named (string) dimensions accept any size; integers must match exactly.
    return all(isinstance(wanted, str) or size == wanted for size, wanted in pairs)
[docs]
def find_matched_shape(value: Any, expected_shapes: List[tuple]) -> tuple | None:
    """Return the first entry of ``expected_shapes`` that ``value`` matches, else ``None``."""
    return next(
        (candidate for candidate in expected_shapes if match_shape(value, candidate)),
        None,
    )
[docs]
class Spec:
    """Base class for data specifications with schema validation.

    Subclasses are dataclasses that define a ``SCHEMA`` class variable mapping every
    field name either to ``{"dtype": ..., "shape": ...}`` (a leaf field) or to
    ``{"spec": SpecSubclass}`` (a nested spec). ``__post_init__`` validates that the
    actual fields match the schema, including checking that dimensions listed in
    ``CONSISTENCY_DIMENSIONS`` have consistent sizes across fields (nested specs included).

    An optional ``FIELD_METADATA`` class variable may provide a ``unit`` and
    ``description`` per field; these are written as HDF5 attributes when storing.
    """

    SCHEMA: dict

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # ``@dataclass`` generates a ``__repr__`` for every decorated class unless the
        # class body itself defines one, which would shadow the compact repr below.
        # Pinning the inherited implementation into the subclass namespace keeps it active.
        if "__repr__" not in cls.__dict__:
            inherited_repr = cls.__repr__
            cls.__repr__ = inherited_repr

    @staticmethod
    def _is_optional_dataclass_field(field_def: Any) -> bool:
        """Return True when the dataclass field declares a default or default factory."""
        if field_def is None:
            return False
        return field_def.default is not MISSING or field_def.default_factory is not MISSING

    @classmethod
    def required_fields(cls) -> tuple[str, ...]:
        """Return the names of fields that have no default value."""
        return tuple(f.name for f in fields(cls) if not cls._is_optional_dataclass_field(f))

    @classmethod
    def fields(cls) -> tuple[str, ...]:
        """Return the names of all fields."""
        return tuple(f.name for f in fields(cls))

    @classmethod
    def optional_fields(cls) -> tuple[str, ...]:
        """Return the names of fields that have a default value."""
        return tuple(f.name for f in fields(cls) if cls._is_optional_dataclass_field(f))

    def warn_missing_optional_fields(self):
        """Warn about optional fields that were not provided."""
        _optional_fields = self.optional_fields()
        field_metadata = getattr(self, "FIELD_METADATA", {})
        for field_name in self.SCHEMA:
            if field_name not in _optional_fields or getattr(self, field_name) is not None:
                continue
            description = field_metadata.get(field_name, {}).get(
                "description", _DEFAULT_FIELD_DESCRIPTION
            )
            log.warning(
                f"Optional {self.__class__.__name__} field '{field_name}' is not set. "
                f"Description: {description} "
                "Defaulted to None."
            )

    @staticmethod
    def _expected_shapes(shape_spec: Any) -> tuple[tuple, ...]:
        """Normalize a schema ``shape`` entry to a tuple of alternative shape tuples."""
        if shape_spec and isinstance(shape_spec[0], tuple):
            return tuple(shape_spec)
        return (shape_spec,)

    @staticmethod
    def _merge_dimension_info(
        dim_to_fields: defaultdict[str, set[str]],
        dim_to_sizes: defaultdict[str, set[int]],
        nested_dim_to_fields: defaultdict[str, set[str]],
        nested_dim_to_sizes: defaultdict[str, set[int]],
    ) -> None:
        """Merge the dimension bookkeeping of a nested spec into the parent's (in place)."""
        for dim_name, nested_fields in nested_dim_to_fields.items():
            dim_to_fields[dim_name].update(nested_fields)
        for dim_name, nested_sizes in nested_dim_to_sizes.items():
            dim_to_sizes[dim_name].update(nested_sizes)

    @staticmethod
    def _track_named_dimensions(
        dim_to_fields: defaultdict[str, set[str]],
        dim_to_sizes: defaultdict[str, set[int]],
        field_path: str,
        matched_shape: tuple,
        shape: tuple,
    ) -> None:
        """Record the observed size of every consistency dimension of one field.

        Args:
            dim_to_fields: Maps dimension name to the field paths using it (updated in place).
            dim_to_sizes: Maps dimension name to the sizes observed for it (updated in place).
            field_path: Path of the field, used in error messages.
            matched_shape: The schema shape that ``shape`` was matched against.
            shape: The actual shape of the field value.
        """
        n_spec_dims = len(matched_shape)
        ellipsis_pos = next((i for i, dim in enumerate(matched_shape) if dim == "..."), None)
        for i, dim_name in enumerate(matched_shape):
            if not (isinstance(dim_name, str) and dim_name in CONSISTENCY_DIMENSIONS):
                continue
            # '...' may absorb any number of axes, so dimensions that follow it are
            # anchored to the end of the actual shape (negative index) instead of
            # sharing the spec's positional index.
            if ellipsis_pos is not None and i > ellipsis_pos:
                axis = i - n_spec_dims
            else:
                axis = i
            dim_to_fields[dim_name].add(field_path)
            dim_to_sizes[dim_name].add(shape[axis])

    @staticmethod
    def _raise_if_shape_mismatch(
        field_name: str, value: Any, expected_shapes: tuple[tuple, ...]
    ) -> None:
        """Raise the shape-mismatch ``ValueError`` for ``field_name`` (always raises)."""
        allowed_shapes = ", ".join(str(shape) for shape in expected_shapes)
        raise ValueError(
            f"{field_name} has shape {value_shape(value)}, expected one of: {allowed_shapes}"
        )

    def _validate_nested_field(
        self, field_name: str, nested_spec: "type[Spec]", field_value: Any
    ) -> "Spec":
        """Validate a nested spec field, recursively validating its contents.

        A ``dict`` value is converted to ``nested_spec(**value)`` (which runs the nested
        validation) and stored back on ``self``.

        Raises:
            TypeError: If the resulting value is not an instance of ``nested_spec``.
        """
        if isinstance(field_value, dict):
            field_value = nested_spec(**field_value)
            setattr(self, field_name, field_value)
        # Check that the nested spec field is now an instance of the expected Spec subclass
        # E.g. Segmentation if nested_spec is Map
        if not isinstance(field_value, nested_spec):
            raise TypeError(
                f"Expected field '{field_name}' to be {nested_spec}, got {type(field_value)}"
            )
        return field_value

    @staticmethod
    def _cast_native_to_numpy(value: Any, expected_dtype: list) -> Any:
        """Cast values to expected numpy dtypes when possible.

        For fields that expect a floating dtype, all floating-point inputs are
        accepted and normalized to the first floating dtype in ``expected_dtype``
        (typically ``np.float32``). Values that cannot be cast are returned unchanged
        so that the subsequent dtype check reports the problem.
        """
        expected_np_dtypes = []
        for dt in expected_dtype:
            try:
                expected_np_dtypes.append(np.dtype(dt))
            except TypeError:
                continue
        expected_float_dtype = next(
            (dt for dt in expected_np_dtypes if np.issubdtype(dt, np.floating)),
            None,
        )
        # Keep native string/bytes values as-is instead of converting to numpy string scalars.
        if isinstance(value, (str, bytes)):
            return value
        if hasattr(value, "dtype"):
            value_dtype = np.dtype(value.dtype)
            if (
                expected_float_dtype is not None
                and np.issubdtype(value_dtype, np.floating)
                and value_dtype != expected_float_dtype
            ):
                return value.astype(expected_float_dtype, copy=False)
            return value
        # If the spec expects a native Python type and the value already matches it,
        # keep it as-is instead of converting to a numpy scalar.
        for dt in expected_dtype:
            if isinstance(dt, type) and not issubclass(dt, np.generic) and isinstance(value, dt):
                return value
        # Otherwise use the first expected dtype whose constructor accepts the value.
        for dt in expected_dtype:
            try:
                target_dtype = np.dtype(dt)
                return target_dtype.type(value)
            except (TypeError, ValueError, OverflowError):
                continue
        return value

    def _validate_and_track_primitive_field(
        self,
        field_name: str,
        field_info: dict,
        field_value: Any,
        dim_to_fields: defaultdict[str, set[str]],
        dim_to_sizes: defaultdict[str, set[int]],
    ) -> None:
        """Cast, validate and record the dimensions of a leaf (non-nested) field.

        The (possibly cast) value is stored back on ``self``.

        Raises:
            TypeError: If the value's dtype is not accepted by the schema.
            ValueError: If the value's shape matches none of the schema shapes.
        """
        expected_dtype = field_info["dtype"]
        if not isinstance(expected_dtype, (list, tuple)):
            expected_dtype = [expected_dtype]
        expected_shapes = self._expected_shapes(field_info["shape"])
        # Auto-cast Python native types (str, int, float) to numpy equivalents
        field_value = self._cast_native_to_numpy(field_value, expected_dtype)
        setattr(self, field_name, field_value)
        try:
            check_dtype(field_value, expected_dtype)
        except TypeError as e:
            # Chain explicitly so the traceback shows the original dtype failure as the cause.
            raise TypeError(
                f"{type(self).__name__}: field '{field_name}' has invalid dtype: {e}"
            ) from e
        matched_shape = find_matched_shape(field_value, expected_shapes)
        if matched_shape is None:
            self._raise_if_shape_mismatch(field_name, field_value, expected_shapes)
        self._track_named_dimensions(
            dim_to_fields=dim_to_fields,
            dim_to_sizes=dim_to_sizes,
            field_path=field_name,
            matched_shape=matched_shape,
            shape=value_shape(field_value),
        )

    @staticmethod
    def _raise_if_inconsistent_dimensions(
        dim_to_fields: defaultdict[str, set[str]],
        dim_to_sizes: defaultdict[str, set[int]],
    ) -> None:
        """Raise ``ValueError`` if any tracked dimension was observed with more than one size."""
        for dim_name, sizes in dim_to_sizes.items():
            if len(sizes) > 1:
                field_names = sorted(dim_to_fields[dim_name])
                raise ValueError(
                    f"Dimension '{dim_name}' has inconsistent sizes across "
                    f"fields {field_names}: {sorted(sizes)}"
                )

    def _collect_dimension_info(
        self, prefix: str = ""
    ) -> tuple[defaultdict[str, set[str]], defaultdict[str, set[int]]]:
        """Collect named dimension usage and observed sizes for this spec subtree.

        Args:
            prefix: Dotted path prepended to field names so that nested fields are
                reported with their full path.

        Returns:
            ``(dim_to_fields, dim_to_sizes)`` for this spec and all nested specs.
        """
        dim_to_fields = defaultdict(set)
        dim_to_sizes = defaultdict(set)
        for field_name, field_info in self.SCHEMA.items():
            field_value = getattr(self, field_name)
            if field_value is None:
                continue
            nested_spec = field_info.get("spec")
            if nested_spec is not None:
                nested_dim_to_fields, nested_dim_to_sizes = field_value._collect_dimension_info(
                    prefix=f"{prefix}{field_name}."
                )
                self._merge_dimension_info(
                    dim_to_fields,
                    dim_to_sizes,
                    nested_dim_to_fields,
                    nested_dim_to_sizes,
                )
                continue
            expected_shapes = self._expected_shapes(field_info["shape"])
            matched_shape = find_matched_shape(field_value, expected_shapes)
            if matched_shape is None:
                # Child specs are already validated; skip defensively if no shape can be matched.
                continue
            self._track_named_dimensions(
                dim_to_fields=dim_to_fields,
                dim_to_sizes=dim_to_sizes,
                field_path=f"{prefix}{field_name}",
                matched_shape=matched_shape,
                shape=value_shape(field_value),
            )
        return dim_to_fields, dim_to_sizes

    def __post_init__(self):
        """Validate every ``SCHEMA`` field and the cross-field dimension consistency.

        Raises:
            ValueError: If a required field is ``None``, a shape does not match, or a
                consistency dimension has different sizes in different fields.
            TypeError: If a dtype or nested spec type does not match the schema.
        """
        dim_to_fields = defaultdict(set)
        dim_to_sizes = defaultdict(set)
        dataclass_fields = {f.name: f for f in fields(self)}
        for field_name, field_info in self.SCHEMA.items():
            field_value = getattr(self, field_name)
            field_def = dataclass_fields.get(field_name)
            is_optional = self._is_optional_dataclass_field(field_def)
            if field_value is None:
                if not is_optional:
                    raise ValueError(f"Missing required field '{field_name}'")
                continue
            nested_spec = field_info.get("spec")
            if nested_spec is not None:
                try:
                    field_value = self._validate_nested_field(field_name, nested_spec, field_value)
                except (TypeError, ValueError) as e:
                    # Re-raise the same exception type with the field name as context.
                    raise type(e)(f"In field '{field_name}': {e}") from e
                nested_dim_to_fields, nested_dim_to_sizes = field_value._collect_dimension_info(
                    prefix=f"{field_name}."
                )
                self._merge_dimension_info(
                    dim_to_fields,
                    dim_to_sizes,
                    nested_dim_to_fields,
                    nested_dim_to_sizes,
                )
                continue
            self._validate_and_track_primitive_field(
                field_name=field_name,
                field_info=field_info,
                field_value=field_value,
                dim_to_fields=dim_to_fields,
                dim_to_sizes=dim_to_sizes,
            )
        self._raise_if_inconsistent_dimensions(dim_to_fields, dim_to_sizes)

    @staticmethod
    def _is_string_value(value: Any) -> bool:
        """Return True for scalar/array values that should be stored as HDF5 strings."""
        if isinstance(value, (str, np.str_, bytes, np.bytes_)):
            return True
        if isinstance(value, np.ndarray):
            return value.dtype.kind in {"U", "S", "O"}
        return False

    @staticmethod
    def create_dataset(
        group: h5py.Group,
        field_name: str,
        value: Any,
        compression: str = DEFAULT_COMPRESSION,
        chunk_frames: bool = False,
    ) -> None:
        """Create a dataset in the given group for the specified field and value,
        handling string and scalar values appropriately.

        Args:
            group: The group in which the dataset is created.
            field_name: Name of the dataset.
            value: The data to store.
            compression: Compression filter; ignored (set to ``None``) for scalar data.
            chunk_frames: If True, numeric arrays with at least two dimensions are
                chunked with one chunk per index along the first axis.
        """
        # np.ndim also handles values without an ``ndim`` attribute (e.g. lists).
        dataset_is_scalar = np.isscalar(value) or np.ndim(value) == 0
        compression = None if dataset_is_scalar else compression
        chunks = None
        if (
            chunk_frames
            and not dataset_is_scalar
            and isinstance(value, np.ndarray)
            and value.ndim >= 2
        ):
            chunks = (1,) + value.shape[1:]
        if Spec._is_string_value(value):
            # Strings are stored as variable-length UTF-8; ``chunks`` is not applied here.
            string_dtype = h5py.string_dtype(encoding="utf-8")
            string_value = np.asarray(value, dtype=object)
            group.create_dataset(
                field_name,
                data=string_value,
                dtype=string_dtype,
                compression=compression,
            )
        else:
            group.create_dataset(field_name, data=value, compression=compression, chunks=chunks)

    def store_in_group(
        self,
        group: h5py.Group,
        compression: str = DEFAULT_COMPRESSION,
        chunk_frames: bool = False,
    ) -> None:
        """Store the data in the given group (e.g. hdf5 group).

        Nested specs are stored recursively in sub-groups. Fields whose value is
        ``None`` are skipped. Every stored entry receives ``unit`` and ``description``
        attributes taken from ``FIELD_METADATA`` (or the module defaults).
        """
        assert isinstance(group, h5py.Group), "group must be an h5py Group"
        # Optional fields should only warn when persisting to disk, not on load.
        self.warn_missing_optional_fields()
        field_metadata = getattr(self, "FIELD_METADATA", {})
        for field_name, field_info in self.SCHEMA.items():
            value = getattr(self, field_name)
            # We do not store fields with value None in the file.
            if value is None:
                continue
            nested_spec = field_info.get("spec")
            if nested_spec is not None:
                subgroup = group.create_group(field_name)
                value.store_in_group(
                    subgroup,
                    compression=compression,
                    chunk_frames=chunk_frames,
                )
            else:
                self.create_dataset(
                    group,
                    field_name,
                    value,
                    compression=compression,
                    chunk_frames=chunk_frames,
                )
            meta = field_metadata.get(field_name, {})
            group[field_name].attrs["unit"] = meta.get("unit", _DEFAULT_FIELD_UNIT)
            group[field_name].attrs["description"] = meta.get(
                "description", _DEFAULT_FIELD_DESCRIPTION
            )

    def to_dict(self) -> dict[str, Any]:
        """Return this spec as a nested dictionary based on ``SCHEMA`` fields.

        Nested specs are converted recursively.
        """
        result = {}
        for field_name, field_info in self.SCHEMA.items():
            value = getattr(self, field_name)
            nested_spec = field_info.get("spec")
            if nested_spec is not None and value is not None:
                if isinstance(value, Spec):
                    result[field_name] = value.to_dict()
                elif isinstance(value, dict):
                    result[field_name] = {
                        k: v.to_dict() if isinstance(v, Spec) else v for k, v in value.items()
                    }
                else:
                    result[field_name] = value
            else:
                result[field_name] = value
        return result

    @classmethod
    def get_dtype(cls, field_name) -> Tuple[type, ...] | type:
        """Get the dtype of a field."""
        return cls.SCHEMA[field_name]["dtype"]

    def __repr__(self) -> str:
        """Return a compact repr: unset fields are omitted and arrays show dtype and shape."""
        parts = []
        for field_name, field_info in self.SCHEMA.items():
            value = getattr(self, field_name, None)
            if value is None:
                continue
            nested_spec = field_info.get("spec")
            if nested_spec is not None:
                parts.append(f"{field_name}={value!r}")
            elif isinstance(value, np.ndarray):
                parts.append(f"{field_name}=array({value.dtype} {value.shape})")
            else:
                parts.append(f"{field_name}={value!r}")
        return f"{self.__class__.__name__}({', '.join(parts)})"
[docs]
@dataclass
class Map(Spec):
    """Map data with per-pixel Cartesian coordinates.

    A map assigns a value to points in Cartesian space: the pixel at spatial index
    ``[f, i, j, ...]`` is located at ``coordinates[f, i, j, ..., :]`` = ``[x, y, z]``
    in metres.

    This is the most flexible map spec and can be used for any spatially aligned data
    product. See, for example, :func:`~zea.beamform.pixelgrid.cartesian_pixel_grid` or
    :func:`~zea.beamform.pixelgrid.polar_pixel_grid` to create a suitable coordinate
    array from your scan geometry.

    Args:
        values: The map values of shape ``(n_frames, z, x, y, n_ch)``,
            ``(n_frames, z, x, y)``, ``(n_frames, z, x, n_ch)`` or ``(n_frames, z, x)``
            and type uint8, float32, int16, or complex64.
        coordinates: Per-pixel Cartesian positions in metres, shape ``(*spatial_dims, 3)``
            where ``spatial_dims`` matches the spatial (non-channel) dimensions of
            ``values``: ``(*values.shape, 3)`` for non-channeled values and
            ``(*values.shape[:-1], 3)`` for channeled values. The last axis holds
            ``[x, y, z]``. The leading ``n_frames`` axis may be omitted to broadcast one
            coordinate grid across all frames.
        labels: The labels corresponding to the ``n_ch`` channels in the values.
            Required when values have an ``n_ch`` dimension, and should be None
            otherwise. For IQ data, this would typically be ``["I", "Q"]``.
        description: An optional free-text description of the map.
        unit: An optional string specifying the physical unit of the map values,
            e.g. ``"m/s"``, ``"%"``, etc.
        min: The minimum value of the map.
        max: The maximum value of the map.
    """

    values: np.ndarray
    coordinates: np.ndarray | None = None
    labels: np.ndarray | None = None
    description: str | None = None
    unit: str | None = None
    min: float | None = None
    max: float | None = None

    SCHEMA = {
        "values": {
            "dtype": (np.uint8, np.float32, np.int16, np.complex64),
            "shape": (
                ("n_frames", "z", "x", "y", "n_spatial_ch"),
                ("n_frames", "z", "x", "y"),
                ("n_frames", "z", "x", "n_spatial_ch"),
                ("n_frames", "z", "x"),
            ),
        },
        "coordinates": {"dtype": np.float32, "shape": ("...", 3)},
        "labels": {"dtype": np.str_, "shape": ("n_spatial_ch",)},
        "description": {"dtype": str, "shape": ()},
        "unit": {"dtype": str, "shape": ()},
        "min": {"dtype": np.float32, "shape": ()},
        "max": {"dtype": np.float32, "shape": ()},
    }

    def __post_init__(self):
        """Run schema validation, then check labels and coordinates against ``values``."""
        super().__post_init__()

        if self.values.ndim == 5:
            assert self.labels is not None, (
                "labels must be provided when values have n_ch dimension"
            )

        cls_name = type(self).__name__

        # Guard clause: without coordinates there is nothing further to validate.
        if self.coordinates is None:
            log.warning(
                f"{cls_name}: coordinates are not provided, please consider adding "
                "a coordinates field to ensure the map can be correctly displayed."
            )
            return

        values_shape = self.values.shape
        channelless_shape = values_shape[:-1]

        # The trailing axis of coordinates has size 3 (enforced by SCHEMA); the remaining
        # axes must equal the values shape, with or without a trailing channel axis.
        accepted = {values_shape, channelless_shape}
        # The leading frame axis may be dropped, in which case one grid is shared by
        # all frames.
        if len(values_shape) > 1:
            accepted.add(values_shape[1:])
        if len(channelless_shape) > 1:
            accepted.add(values_shape[1:-1])

        if self.coordinates.shape[:-1] not in accepted:
            raise ValueError(
                f"{cls_name}: coordinates shape {self.coordinates.shape} is "
                f"incompatible with values shape {values_shape}. "
                f"coordinates.shape[:-1] must equal values.shape "
                f"({values_shape}) for non-channeled data, or "
                f"values.shape[:-1] ({channelless_shape}) for channeled data, "
                "with optional frame-axis broadcasting (leading n_frames omitted)."
            )

        # Unit sanity check: a finite coordinate magnitude above 1 m suggests the array
        # was supplied in millimetres rather than metres.
        finite_coordinates = self.coordinates[np.isfinite(self.coordinates)]
        max_abs = np.max(np.abs(finite_coordinates), initial=0.0)
        if max_abs > 1.0:
            log.warning(
                f"{cls_name}: coordinates have a maximum absolute value of "
                f"{max_abs:.4g}, which exceeds 1 m. Ultrasound scan regions are "
                "typically a few centimetres across. Please verify that coordinates "
                "are in metres, not millimetres."
            )
[docs]
@dataclass
class FloatMap(Map):
    """A :class:`Map` whose ``values`` are restricted to dtype float32."""

    # Same schema as Map; only the accepted dtype of ``values`` is narrowed.
    SCHEMA = Map.SCHEMA | {"values": Map.SCHEMA["values"] | {"dtype": np.float32}}
[docs]
@dataclass
class BooleanMap(Map):
    """A :class:`Map` whose ``values`` are restricted to dtype bool."""

    # Same schema as Map; only the accepted dtype of ``values`` is narrowed.
    SCHEMA = Map.SCHEMA | {"values": Map.SCHEMA["values"] | {"dtype": np.bool_}}
[docs]
@dataclass
class UnsignedIntMap(Map):
    """A :class:`Map` whose ``values`` are restricted to dtype uint8."""

    # Same schema as Map; only the accepted dtype of ``values`` is narrowed.
    SCHEMA = Map.SCHEMA | {"values": Map.SCHEMA["values"] | {"dtype": np.uint8}}
[docs]
@dataclass
class Segmentation(BooleanMap):
    """Segmentation masks with per-pixel Cartesian coordinates.

    Args:
        values: Boolean masks of shape ``(n_frames, z, x, y, n_labels)`` for 3D
            (volumetric) data or ``(n_frames, z, x, n_labels)`` for 2D data.
        coordinates: Per-pixel Cartesian positions in metres, shape ``(*spatial_dims, 3)``
            where ``spatial_dims`` matches the spatial (non-label) dimensions of
            ``values``. The leading frame axis may be omitted to broadcast one
            coordinate grid across all frames.
        labels: The names of the label channels, shape ``(n_labels,)`` and type str.
            Always required for a segmentation.

    .. note::
        To indicate that certain frames have no segmentation, add an explicit
        ``"unannotated"`` entry to ``labels`` and set ``values[..., unannotated_idx]`` to
        ``True`` for those frames (with all other label channels set to ``False``). This
        keeps the shape uniform across frames while clearly distinguishing genuinely
        annotated frames from frames that were not labelled. For example::

            labels = np.array(["unannotated", "LV_endo", "LV_myo", "LA"])
            values = np.zeros((n_frames, H, W, 4), dtype=bool)
            # mark all frames as unannotated by default
            values[:, :, :, 0] = True
            # for annotated frames, set unannotated channel to False
            # and the appropriate label channel to True
            values[ed_idx, :, :, 0] = False
            values[ed_idx, :, :, 1:] = segmentation_mask  # shape (H, W, 3)
    """

    # Only the channeled layouts are allowed: a segmentation always has a label axis.
    SCHEMA = BooleanMap.SCHEMA | {
        "values": BooleanMap.SCHEMA["values"]
        | {
            "shape": (
                ("n_frames", "z", "x", "y", "n_spatial_ch"),
                ("n_frames", "z", "x", "n_spatial_ch"),
            ),
        },
    }

    def __post_init__(self):
        """Check rank and presence of labels, then run the inherited validation."""
        n_dims = self.values.ndim
        assert n_dims in (4, 5), (
            "Segmentation values must have 4 or 5 dimensions: "
            "(n_frames, z, x, n_labels) for 2D or (n_frames, z, x, y, n_labels) for 3D, "
            f"got shape {self.values.shape}"
        )
        assert self.labels is not None, "Segmentation requires labels to be provided"
        super().__post_init__()
[docs]
@dataclass
class Image(Map):
    """Reconstructed (log-compressed) image data with per-pixel Cartesian coordinates.

    Args:
        values: The image values of shape ``(n_frames, z, x, y)`` or ``(n_frames, z, x)``
            and type uint8 or float32. For float32 values, the values should be in dB
            (between -inf and 0).
        coordinates: Per-pixel Cartesian positions in metres, shape ``(*values.shape, 3)``.
            The leading frame axis may be omitted to broadcast one coordinate grid
            across all frames.

    Raises:
        ValueError: For float32 values that contain NaN/+inf or values above 0.
    """

    SCHEMA = {
        **Map.SCHEMA,
        "values": {
            "dtype": (np.float32, np.uint8),
            # Axis labels follow the (z, x[, y]) order used by Map and documented above.
            # "z"/"x"/"y" are not consistency dimensions, so within this module the
            # labels only show up in shape-mismatch error messages.
            "shape": (
                ("n_frames", "z", "x", "y"),
                ("n_frames", "z", "x"),
            ),
        },
    }

    def __post_init__(self):
        """Run the inherited validation, then enforce the dB range for float32 images."""
        super().__post_init__()
        # Check that image values are in dB scale (finite or -inf, and <= 0)
        if self.values.dtype == np.float32:
            if not np.all(np.isfinite(self.values) | np.isneginf(self.values)):
                raise ValueError("Image values must be finite or -inf (dB scale).")
            if not np.all(self.values <= 0):
                raise ValueError("Image values must be in dB scale <= 0 when using float32 dtype.")
[docs]
@dataclass
class AlignedData(Spec):
    """Time-of-flight corrected data.

    Args:
        values: The aligned data of shape ``(n_frames, n_tx, n_ax, n_el, n_ch)``
            and type float32 or int16. n_ch is 1 for RF data or 2 for IQ data.
        labels: The labels for the channel dimension, e.g. ``["RF"]`` or ``["I", "Q"]``.
            Auto-generated from n_ch if not provided.

    Raises:
        ValueError: If ``values`` is missing, has the wrong shape, or n_ch is not 1 or 2.
        TypeError: If a field has an unsupported dtype.
    """

    values: np.ndarray
    labels: np.ndarray | None = None

    SCHEMA = {
        "values": {
            "dtype": (np.float32, np.int16),
            "shape": ("n_frames", "n_tx", "n_ax", "n_el", "n_ch"),
        },
        "labels": {"dtype": np.str_, "shape": ("n_ch",)},
    }

    def __post_init__(self):
        """Validate the schema, enforce n_ch in {1, 2} and fill in default labels."""
        # Validate first: this guarantees ``values`` is a 5-D array before its shape is
        # inspected, so missing or malformed input is reported through the regular
        # spec errors rather than an AttributeError/IndexError raised here.
        super().__post_init__()
        n_ch = self.values.shape[-1]
        if n_ch not in (1, 2):
            raise ValueError(
                f"Aligned data must have n_ch ∈ {{1, 2}} (RF or IQ), "
                f"got n_ch={n_ch} (shape {self.values.shape})."
            )
        # Default labels are valid by construction (str array of length n_ch), so they
        # do not need to pass through validation again.
        if self.labels is None:
            self.labels = (
                np.array(["RF"], dtype=np.str_)
                if n_ch == 1
                else np.array(["I", "Q"], dtype=np.str_)
            )
[docs]
@dataclass
class EnvelopeData(FloatMap):
    """Envelope-detected data with per-pixel Cartesian coordinates.

    Args:
        values: The envelope data of shape ``(n_frames, z, x)`` or
            ``(n_frames, z, x, y)`` and type float32.
        coordinates: Per-pixel Cartesian positions in metres, shape ``(*values.shape, 3)``.
            The leading frame axis may be omitted to broadcast one coordinate grid
            across all frames.
    """

    # Envelope data has no channel axis, so only the non-channeled layouts are accepted.
    SCHEMA = FloatMap.SCHEMA | {
        "values": {
            "dtype": np.float32,
            "shape": (
                ("n_frames", "z", "x", "y"),
                ("n_frames", "z", "x"),
            ),
        },
    }
[docs]
@dataclass
class SosMap(FloatMap):
    """Speed-of-sound map data with per-pixel Cartesian coordinates.

    Args:
        values: The speed-of-sound map values in m/s of shape ``(n_frames, z, x, y)``
            and type float32.
        coordinates: Per-pixel Cartesian positions in metres, shape
            ``(n_frames, z, x, 3)`` or ``(n_frames, z, x, y, 3)``.
            The leading frame axis may be omitted to broadcast one coordinate grid
            across all frames.

    Raises:
        ValueError: If ``unit`` is set to anything other than ``"m/s"``.
    """

    def __post_init__(self):
        """Run the inherited validation, then check the unit and warn on low values."""
        super().__post_init__()
        if self.unit not in (None, "m/s"):
            raise ValueError(f"Speed-of-sound map unit should be 'm/s', got '{self.unit}'")
        # Plausibility check only: low values trigger a warning, not an error.
        has_low_values = np.any(self.values < 300)
        if has_low_values:
            log.warning(
                "Speed-of-sound map contains values below 300 m/s, which is unusually low. "
                "Please verify that the speed-of-sound values are correct and in m/s."
            )
[docs]
@dataclass
class StrainPercentageMap(FloatMap):
    """Strain map data with per-pixel Cartesian coordinates.

    Args:
        values: The strain values in % of shape ``(n_frames, z, x, y)`` and type float32.
        coordinates: Per-pixel Cartesian positions in metres.

    Raises:
        ValueError: If ``unit`` is set to anything other than ``"%"``.
    """

    def __post_init__(self):
        """Run the inherited validation, then check that the unit (if set) is ``"%"``."""
        super().__post_init__()
        if self.unit not in (None, "%"):
            raise ValueError(f"Strain map unit should be '%', got '{self.unit}'")
[docs]
@dataclass
class ShearWaveElastographyMap(FloatMap):
    """Shear-wave elastography data with per-pixel Cartesian coordinates.

    Args:
        values: The shear-wave elastography values in m/s of shape
            ``(n_frames, z, x, y)`` and type float32.
        coordinates: Per-pixel Cartesian positions in metres.

    Raises:
        ValueError: If ``unit`` is set to anything other than ``"m/s"``.
    """

    def __post_init__(self):
        """Run the inherited validation, then check that the unit (if set) is ``"m/s"``."""
        super().__post_init__()
        if self.unit not in (None, "m/s"):
            raise ValueError(f"SWE map unit should be 'm/s', got '{self.unit}'")
[docs]
@dataclass
class TissueDopplerMap(FloatMap):
    """Tissue Doppler data with per-pixel Cartesian coordinates.

    Args:
        values: The tissue Doppler values in m/s of shape ``(n_frames, z, x, y)``
            and type float32.
        coordinates: Per-pixel Cartesian positions in metres.

    Raises:
        ValueError: If ``unit`` is set to anything other than ``"m/s"``.
    """

    def __post_init__(self):
        """Run the inherited validation, then check that the unit (if set) is ``"m/s"``."""
        super().__post_init__()
        if self.unit is not None and self.unit != "m/s":
            # The message names this map type (it previously said "SWE map").
            raise ValueError(f"Tissue Doppler map unit should be 'm/s', got '{self.unit}'")
[docs]
@dataclass
class ColorDopplerMap(FloatMap):
    """Color Doppler (velocity) data with per-pixel Cartesian coordinates.

    Args:
        values: The color Doppler velocity values in m/s of shape
            ``(n_frames, z, x, y)`` and type float32. Positive values
            indicate flow towards the transducer, negative values
            indicate flow away from the transducer.
        coordinates: Per-pixel Cartesian positions in metres.

    Raises:
        ValueError: If ``unit`` is set to anything other than ``"m/s"``.
    """

    def __post_init__(self):
        """Run the inherited validation, then check that the unit (if set) is ``"m/s"``."""
        super().__post_init__()
        if self.unit is not None and self.unit != "m/s":
            # The message names this map type (it previously said "SWE map").
            raise ValueError(f"Color Doppler map unit should be 'm/s', got '{self.unit}'")
[docs]
@dataclass(init=False)
class DataSpec(Spec):
"""Data group containing raw channels, derived pipeline products, and optional spatial maps.
Plain-array data products:
raw_data: Raw channel data of shape (n_frames, n_tx, n_ax, n_el, n_ch)
and type float32 or int16.
Grouped data products (values + optional metadata):
- aligned_data: Time-of-flight corrected data and optional labels.
- beamformed_data: Beamformed (beamsummed) data and per-pixel coordinates.
- envelope_data: Envelope-detected data and per-pixel coordinates.
- image: Reconstructed image data and per-pixel coordinates.
- segmentation: Segmentation data and per-pixel coordinates.
- sos_map: Speed-of-sound map data and per-pixel coordinates.
- strain_percentage_map: Strain map data and per-pixel coordinates.
- shear_wave_elastography_map: Shear-wave elastography data and per-pixel coordinates.
- tissue_doppler: Tissue Doppler data and per-pixel coordinates.
- color_doppler: Color Doppler velocity data and per-pixel coordinates.
- \\*\\*kwargs: Any other spatially aligned map data and per-pixel coordinates.
At least one data field (plain-array or grouped) must be provided.
"""
# Plain-array data products
raw_data: np.ndarray | None = None
# Grouped data products
aligned_data: AlignedData | dict | None = None
beamformed_data: BeamformedData | dict | None = None
envelope_data: EnvelopeData | dict | None = None
image: Image | dict | None = None
segmentation: Segmentation | dict | None = None
sos_map: SosMap | dict | None = None
strain_percentage_map: StrainPercentageMap | dict | None = None
shear_wave_elastography_map: ShearWaveElastographyMap | dict | None = None
tissue_doppler: TissueDopplerMap | dict | None = None
color_doppler: ColorDopplerMap | dict | None = None
SCHEMA = {
# Plain-array data products
"raw_data": {
"dtype": (np.float32, np.int16),
"shape": ("n_frames", "n_tx", "n_ax", "n_el", "n_ch"),
},
# Grouped data products
"aligned_data": {"spec": AlignedData},
"beamformed_data": {"spec": BeamformedData},
"envelope_data": {"spec": EnvelopeData},
"image": {"spec": Image},
"segmentation": {"spec": Segmentation},
"sos_map": {"spec": SosMap},
"strain_percentage_map": {"spec": StrainPercentageMap},
"shear_wave_elastography_map": {"spec": ShearWaveElastographyMap},
"tissue_doppler": {"spec": TissueDopplerMap},
"color_doppler": {"spec": ColorDopplerMap},
}
FIELD_METADATA = {
"raw_data": {"unit": "-", "description": "Raw channel data."},
}
def __init__(
self,
raw_data: np.ndarray | None = None,
aligned_data: AlignedData | dict | None = None,
beamformed_data: BeamformedData | dict | None = None,
envelope_data: EnvelopeData | dict | None = None,
image: Image | dict | None = None,
segmentation: Segmentation | dict | None = None,
sos_map: SosMap | dict | None = None,
strain_percentage_map: StrainPercentageMap | dict | None = None,
shear_wave_elastography_map: ShearWaveElastographyMap | dict | None = None,
tissue_doppler: TissueDopplerMap | dict | None = None,
color_doppler: ColorDopplerMap | dict | None = None,
**extra_maps,
):
self.raw_data = raw_data
self.aligned_data = aligned_data
self.beamformed_data = beamformed_data
self.envelope_data = envelope_data
self.image = image
self.segmentation = segmentation
self.sos_map = sos_map
self.strain_percentage_map = strain_percentage_map
self.shear_wave_elastography_map = shear_wave_elastography_map
self.tissue_doppler = tissue_doppler
self.color_doppler = color_doppler
reserved_keys = set(self.SCHEMA) | set(self.__dataclass_fields__) | set(dir(Spec))
for key, value in extra_maps.items():
if key in reserved_keys:
raise TypeError(f"Invalid custom data key '{key}': reserved name")
if isinstance(value, np.ndarray):
raise TypeError(
f"Custom data key '{key}' must be a spatial map "
f"(a dict with at least a 'values' key), not a flat array. "
f"Only 'raw_data' is accepted as a flat array. "
f"Wrap your data: {{'values': array, 'coordinates': coordinates_array}}."
)
setattr(self, key, value)
# Add custom extra maps to the schema as generic Map specs, so they get validated.
self._extra_map_keys = tuple(extra_maps.keys())
if getattr(self, "_extra_map_keys", ()):
self.SCHEMA = {
**self.SCHEMA,
**{key: {"spec": Map} for key in self._extra_map_keys},
}
self.__post_init__()
def __post_init__(self):
    """Require at least one data entry, run base validation, check ``n_ch``.

    Raises:
        ValueError: If every schema entry is ``None``, or if ``raw_data`` is an
            array whose last axis is not of size 1 or 2.
    """
    schema_keys = list(self.SCHEMA)
    if all(getattr(self, key, None) is None for key in schema_keys):
        raise ValueError(
            "At least one data field must be provided. "
            f"Available fields: {', '.join(schema_keys)}"
        )

    super().__post_init__()

    # n_ch must be 1 (RF) or 2 (IQ) for raw_data (checked for aligned_data by AlignedData).
    raw = getattr(self, "raw_data", None)
    if not isinstance(raw, np.ndarray):
        return
    n_ch = raw.shape[-1]
    if n_ch not in (1, 2):
        raise ValueError(
            f"'raw_data' must have n_ch ∈ {{1, 2}} (RF or IQ), "
            f"got n_ch={n_ch} (shape {raw.shape})."
        )
[docs]
@dataclass
class ScanSpec(Spec):
    """Scan group with acquisition and transmit metadata.

    All fields are aligned with the data format specification.

    Args:
        sampling_frequency: The sampling frequency in Hz.
        center_frequency: The center frequency in Hz of the transmit pulse.
            Single scalar if all transmits share the same center frequency;
            otherwise an array of shape (n_tx,) with one frequency per transmit.
        demodulation_frequency: The frequency in Hz at which the data should
            be demodulated. Usually the same as center_frequency, but different
            when doing harmonic imaging. Single scalar if all transmits share
            the same demodulation frequency; otherwise an array of shape
            (n_tx,) with one frequency per transmit.
        initial_times: The times in seconds when the A/D converter starts sampling
            of shape (n_tx,). This is the time between the first element firing
            and the first recorded sample.
        t0_delays: The transmit delays in seconds for each element of shape
            (n_tx, n_el). This is the time at which each element fires, shifted
            such that the first element fires at t=0.
        tx_apodizations: The apodization values that were applied to each
            element during transmit of shape (n_tx, n_el). This is a value
            between -1 and 1 that indicates how much each element contributed
            to the transmit beam, with 0 meaning no contribution and 1 meaning
            full contribution. Negative values indicate that the element was
            fired with opposite polarity.
        focus_distances: The transmit focus distances in meters of shape (n_tx,).
            This is the distance from the origin point on the transducer to
            where the beam comes to focus. For planewaves this is set to
            infinity or zero.
        transmit_origins: The transmit origins of the transmit beams in meters of
            shape (n_tx, 3). This is the (x, y, z) position from which the beam
            is transmitted.
        polar_angles: The polar angles in radians of the transmit beams of shape (n_tx,).
        time_to_next_transmit: Optional. The time in s between subsequent
            transmit events of shape (n_frames, n_tx).
        azimuth_angles: Optional. The azimuthal angles in radians of the
            transmit beams of shape (n_tx,).
        sound_speed: Optional. The speed of sound in meters per second.
        tgc_gain_curve: Optional. The time-gain-compensation that was applied
            to every sample in the raw_data of shape (n_ax,). Divide by this
            curve to undo the TGC.
        waveforms_one_way: Optional. One-way waveforms of shape
            (n_tx, n_samples_one_way) as simulated by the Verasonics system.
            This is the waveform after being filtered by the transducer
            bandwidth once.
        waveforms_two_way: Optional. Two-way waveforms of shape
            (n_tx, n_samples_two_way) as simulated by the Verasonics system.
            This is the waveform after being filtered by the transducer
            bandwidth twice.

    Raises:
        ValueError: If the sampling frequency or sound speed is not positive,
            a frequency, transmit delay or TGC value is negative, or an angle
            lies outside [-pi, pi].
    """

    sampling_frequency: np.ndarray | float
    center_frequency: np.ndarray | float
    demodulation_frequency: np.ndarray | float
    initial_times: np.ndarray
    t0_delays: np.ndarray
    tx_apodizations: np.ndarray
    focus_distances: np.ndarray
    transmit_origins: np.ndarray
    polar_angles: np.ndarray
    # Everything below is optional and defaults to None.
    time_to_next_transmit: np.ndarray | None = None
    azimuth_angles: np.ndarray | None = None
    sound_speed: np.ndarray | float | None = None
    tgc_gain_curve: np.ndarray | None = None
    waveforms_one_way: np.ndarray | None = None
    waveforms_two_way: np.ndarray | None = None

    SCHEMA = {
        "sampling_frequency": {"dtype": np.float32, "shape": ()},
        "center_frequency": {"dtype": np.float32, "shape": ((), ("n_tx",))},
        "demodulation_frequency": {"dtype": np.float32, "shape": ((), ("n_tx",))},
        "initial_times": {"dtype": np.float32, "shape": ("n_tx",)},
        "t0_delays": {"dtype": np.float32, "shape": ("n_tx", "n_el")},
        "tx_apodizations": {"dtype": np.float32, "shape": ("n_tx", "n_el")},
        "focus_distances": {"dtype": np.float32, "shape": ("n_tx",)},
        "transmit_origins": {"dtype": np.float32, "shape": ("n_tx", 3)},
        "polar_angles": {"dtype": np.float32, "shape": ("n_tx",)},
        "time_to_next_transmit": {"dtype": np.float32, "shape": ("n_frames", "n_tx")},
        "azimuth_angles": {"dtype": np.float32, "shape": ("n_tx",)},
        "sound_speed": {"dtype": np.float32, "shape": ()},
        "tgc_gain_curve": {"dtype": np.float32, "shape": ("n_ax",)},
        "waveforms_one_way": {
            "dtype": np.float32,
            "shape": ("n_tx", "n_samples_one_way"),
        },
        "waveforms_two_way": {
            "dtype": np.float32,
            "shape": ("n_tx", "n_samples_two_way"),
        },
    }

    FIELD_METADATA = {
        "sampling_frequency": {"unit": "Hz", "description": "Sampling frequency."},
        "center_frequency": {
            "unit": "Hz",
            "description": "Center frequency of the transmit pulse.",
        },
        "demodulation_frequency": {"unit": "Hz", "description": "Demodulation frequency."},
        "initial_times": {"unit": "s", "description": "A/D converter start times per transmit."},
        "t0_delays": {"unit": "s", "description": "Transmit delays per element."},
        "tx_apodizations": {"unit": "-", "description": "Transmit apodization per element."},
        "focus_distances": {"unit": "m", "description": "Transmit focus distances."},
        "transmit_origins": {"unit": "m", "description": "Transmit beam origins (x, y, z)."},
        "polar_angles": {"unit": "rad", "description": "Polar angles of transmit beams."},
        "time_to_next_transmit": {"unit": "s", "description": "Time between transmit events."},
        "azimuth_angles": {"unit": "rad", "description": "Azimuthal angles of transmit beams."},
        "sound_speed": {"unit": "m/s", "description": "Speed of sound."},
        "tgc_gain_curve": {"unit": "-", "description": "Time-gain-compensation curve."},
        # NOTE(review): "V" is not a key of the module-level UNITS table;
        # confirm whether units are validated against that table anywhere.
        "waveforms_one_way": {"unit": "V", "description": "One-way transmit waveforms."},
        "waveforms_two_way": {"unit": "V", "description": "Two-way transmit waveforms."},
    }

    @property
    def n_tx(self) -> int:
        """Number of transmits."""
        return self.t0_delays.shape[0]

    @property
    def n_el(self) -> int:
        """Number of elements."""
        return self.t0_delays.shape[1]

    def __post_init__(self):
        """Run base validation, range-check the values, then simplify frequencies.

        Raises:
            ValueError: If any of the range checks below fails.
        """
        super().__post_init__()

        if self.sampling_frequency <= 0:
            raise ValueError(f"Sampling frequency must be positive, got {self.sampling_frequency}")
        if np.any(self.center_frequency < 0):
            raise ValueError(f"Center frequency cannot be negative, got {self.center_frequency}")
        if np.any(self.demodulation_frequency < 0):
            raise ValueError(
                f"Demodulation frequency cannot be negative, got {self.demodulation_frequency}"
            )
        if np.any(self.t0_delays < 0):
            raise ValueError(f"Transmit delays cannot be negative, got {self.t0_delays}")

        # Infinite focus distances are excluded from the plausibility warning.
        if np.any(np.logical_and(self.focus_distances >= 1, self.focus_distances != np.inf)):
            log.warning(
                "Focus distances greater than or equal to 1 meter may be unusually large. "
                "Maybe you have to convert to meters?"
            )
        if np.any(self.transmit_origins > 1.0) or np.any(self.transmit_origins < -1.0):
            log.warning(
                "Transmit origin values are unusually large, extending beyond +/- 1.0 meters. "
                "Please verify that the transmit origin values are correct and in meters."
            )

        if np.any(self.polar_angles < -np.pi) or np.any(self.polar_angles > np.pi):
            raise ValueError(
                f"Polar angles should be between -pi and pi radians, got values between "
                f"{np.min(self.polar_angles)} and {np.max(self.polar_angles)}"
            )
        if self.azimuth_angles is not None and (
            np.any(self.azimuth_angles < -np.pi) or np.any(self.azimuth_angles > np.pi)
        ):
            raise ValueError(
                f"Azimuth angles should be between -pi and pi radians, got values between "
                f"{np.min(self.azimuth_angles)} and {np.max(self.azimuth_angles)}"
            )
        if self.sound_speed is not None and self.sound_speed <= 0:
            raise ValueError(f"Sound speed must be positive, got {self.sound_speed}")
        if self.tgc_gain_curve is not None and np.any(self.tgc_gain_curve < 0):
            raise ValueError(
                f"TGC gain curve values must be non-negative, got values between "
                f"{np.min(self.tgc_gain_curve)} and {np.max(self.tgc_gain_curve)}"
            )

        # Simplify per-transmit frequencies: a 1-D array whose entries are all
        # equal is replaced by its first element. The size guard keeps an empty
        # array from raising IndexError on the [0] lookup.
        for name in ("center_frequency", "demodulation_frequency"):
            values = getattr(self, name)
            if (
                isinstance(values, np.ndarray)
                and values.ndim == 1
                and values.size > 0
                and np.all(values == values[0])
            ):
                setattr(self, name, values[0])
[docs]
@dataclass
class ProbeSpec(Spec):
    """Probe hardware specification.

    Stores static, physical characteristics of the transducer that are not
    captured by the per-acquisition :class:`ScanSpec`. All fields are
    optional so that partial information can be recorded.

    Args:
        name: Probe model identifier (e.g. ``"verasonics_l11_4v"``).
        type: Probe geometry type: ``"linear"``, ``"phased"``, ``"curved"``, etc.
        probe_center_frequency: Probe nominal centre frequency in Hz. Named
            distinctly from :attr:`ScanSpec.center_frequency` (the per-acquisition
            transmit frequency) so the two never collide when merged into a single
            :class:`zea.Parameters` object.
        probe_bandwidth_percent: Fractional bandwidth as a percentage.
        probe_geometry: Element positions in metres, shape (n_el, 3) with columns
            (x, y, z). :attr:`n_el` is derived from this array as a read-only
            property.
        element_width: Width of a single transducer element in metres.
        element_height: Height (elevation aperture) of a single element in metres.
        lens_sound_speed: Speed of sound in the acoustic lens in m/s.
        lens_thickness: Thickness of the acoustic lens in metres.

    Raises:
        ValueError: If ``probe_geometry`` is not of shape (n_el, 3), a quantity
            that must be positive is not, or ``lens_thickness`` is negative.
    """

    name: str | None = None
    type: str | None = None
    probe_center_frequency: np.float32 | None = None
    probe_bandwidth_percent: np.float32 | None = None
    probe_geometry: np.ndarray | None = None
    element_width: np.float32 | None = None
    element_height: np.float32 | None = None
    lens_sound_speed: np.float32 | None = None
    lens_thickness: np.float32 | None = None

    SCHEMA = {
        "name": {"dtype": str, "shape": ()},
        "type": {"dtype": str, "shape": ()},
        "probe_center_frequency": {"dtype": np.float32, "shape": ()},
        "probe_bandwidth_percent": {"dtype": np.float32, "shape": ()},
        "probe_geometry": {"dtype": np.float32, "shape": ("n_el", 3)},
        "element_width": {"dtype": np.float32, "shape": ()},
        "element_height": {"dtype": np.float32, "shape": ()},
        "lens_sound_speed": {"dtype": np.float32, "shape": ()},
        "lens_thickness": {"dtype": np.float32, "shape": ()},
    }

    FIELD_METADATA = {
        "name": {"description": "Probe model name/identifier."},
        "type": {"description": "Probe geometry type (linear, phased, curved, ...)."},
        "probe_center_frequency": {
            "unit": "Hz",
            "description": "Probe nominal centre frequency.",
        },
        "probe_bandwidth_percent": {
            "unit": "%",
            "description": "Fractional bandwidth as a percentage.",
        },
        "probe_geometry": {
            "unit": "m",
            "description": "Element positions (x, y, z) per element, shape (n_el, 3).",
        },
        "element_width": {
            "unit": "m",
            "description": "Width of a single transducer element.",
        },
        "element_height": {
            "unit": "m",
            "description": "Height (elevation aperture) of a single transducer element.",
        },
        "lens_sound_speed": {
            "unit": "m/s",
            "description": "Speed of sound in the acoustic lens.",
        },
        "lens_thickness": {
            "unit": "m",
            "description": "Thickness of the acoustic lens.",
        },
    }

    @property
    def n_el(self) -> int | None:
        """Number of transducer elements, derived from :attr:`probe_geometry`."""
        if self.probe_geometry is not None:
            return int(self.probe_geometry.shape[0])
        return None

    def __post_init__(self):
        """Run base validation, then range-check every field that is set.

        Raises:
            ValueError: See the class docstring.
        """
        super().__post_init__()

        geometry = self.probe_geometry
        if geometry is not None:
            if geometry.ndim != 2 or geometry.shape[1] != 3:
                raise ValueError(
                    f"ProbeSpec: probe_geometry must have shape (n_el, 3), "
                    f"got {geometry.shape}"
                )
            # Plausibility check only: values this large usually mean wrong units.
            if np.any(geometry > 1.0) or np.any(geometry < -1.0):
                log.warning(
                    "ProbeSpec probe_geometry values extend beyond \u00b11.0 m. "
                    "Please verify the values are in metres."
                )

        # Scalars that must be strictly positive when provided, checked in
        # declaration order so the first offending field is the one reported.
        strictly_positive = (
            "probe_center_frequency",
            "probe_bandwidth_percent",
            "element_width",
            "element_height",
            "lens_sound_speed",
        )
        for field_name in strictly_positive:
            value = getattr(self, field_name)
            if value is not None and value <= 0:
                raise ValueError(f"ProbeSpec: {field_name} must be positive, got {value}")

        # A lens thickness of exactly zero is allowed.
        if self.lens_thickness is not None and self.lens_thickness < 0:
            raise ValueError(
                f"ProbeSpec: lens_thickness must be non-negative, got {self.lens_thickness}"
            )
[docs]
@dataclass
class Subject(Spec):
    """Subject metadata associated with the study.

    All fields are optional.

    Args:
        id: Subject ID.
        type: Subject type, e.g. human, phantom, animal.
        age: Subject age in years.
        sex: Subject sex.
        fat_percentage: Subject fat percentage, between 0 and 100.

    Raises:
        ValueError: If ``id`` is empty or whitespace-only, or
            ``fat_percentage`` lies outside [0, 100].
    """

    id: str | None = None
    type: str | None = None
    age: np.uint8 | None = None
    sex: str | None = None
    fat_percentage: np.float32 | None = None

    SCHEMA = {
        "id": {"dtype": str, "shape": ()},
        "type": {"dtype": str, "shape": ()},
        "age": {"dtype": np.uint8, "shape": ()},
        "sex": {"dtype": str, "shape": ()},
        "fat_percentage": {"dtype": np.float32, "shape": ()},
    }

    # NOTE(review): only 'id' carries explicit metadata here; 'fat_percentage'
    # has no "%" unit entry, unlike ProbeSpec.probe_bandwidth_percent —
    # confirm whether that is intentional.
    FIELD_METADATA = {
        "id": {"description": "Subject ID. Needed for subject-wise splits."},
    }

    def __post_init__(self):
        """Run base validation, then check ``id`` and ``fat_percentage``."""
        super().__post_init__()
        if self.id is not None and not self.id.strip():
            raise ValueError("Subject ID cannot be an empty string")
        if self.fat_percentage is not None and (
            self.fat_percentage < 0 or self.fat_percentage > 100
        ):
            raise ValueError(
                f"Subject fat percentage must be between 0 and 100, got {self.fat_percentage}"
            )
[docs]
@dataclass
class Signal(Spec):
    """Common base for auxiliary signals that carry timing information.

    Args:
        start_time_offset: Time offset in seconds between the first transmit event
            of the ultrasound acquisition and sample 0 of this data. Negative
            means this data starts before the first transmit event; positive
            means it starts after.
        sampling_frequency: Sampling frequency in Hz for the additional signal samples.

    Raises:
        ValueError: If ``sampling_frequency`` is not strictly positive.
    """

    start_time_offset: np.ndarray | float
    sampling_frequency: np.ndarray | float

    SCHEMA = {
        "start_time_offset": {"dtype": np.float32, "shape": ()},
        "sampling_frequency": {"dtype": np.float32, "shape": ()},
    }

    FIELD_METADATA = {
        "start_time_offset": {
            "unit": "s",
            "description": (
                "Time offset between the first transmit event of the ultrasound "
                "acquisition and sample 0 of this data. Negative means this data "
                "starts before the first transmit event; positive means it starts "
                "after."
            ),
        },
        "sampling_frequency": {"unit": "Hz", "description": "Sampling frequency."},
    }

    def __post_init__(self):
        """Run base validation, then require a positive sampling frequency."""
        super().__post_init__()
        rate = self.sampling_frequency
        if rate <= 0:
            raise ValueError(f"Sampling frequency must be positive, got {rate}")
[docs]
@dataclass
class ProbePose(Signal):
    """Sampled pose of the transducer tip over time.

    The pose uses the coordinate convention x = lateral along the transducer,
    y = elevation (out of plane), and z = axial (depth).

    Args:
        translation: Position of the transducer tip in meters of shape (T, 3),
            ordered as (x, y, z).
        rotation: Orientation of the transducer tip of shape (T, 3) or (T, 4),
            interpreted according to ``rotation_representation``.
        rotation_representation: Rotation parameterization. Supported values are
            ``"euler_xyz"``, ``"quaternion_wxyz"``, and ``"quaternion_xyzw"``.
        start_time_offset: Time offset in seconds between the first transmit event
            of the ultrasound acquisition and sample 0 of this data.
        sampling_frequency: Sampling frequency in Hz for probe pose samples.

    Raises:
        ValueError: If translation and rotation differ in their number of time
            samples, the representation is unknown, or the rotation width does
            not match the representation.
    """

    translation: np.ndarray
    rotation: np.ndarray
    rotation_representation: str

    SCHEMA = {
        "translation": {"dtype": np.float32, "shape": ("T", 3)},
        "rotation": {"dtype": np.float32, "shape": (("T", 3), ("T", 4))},
        "rotation_representation": {"dtype": str, "shape": ()},
        **Signal.SCHEMA,
    }

    FIELD_METADATA = {
        "translation": {
            "unit": "m",
            "description": (
                "Position of the transducer tip, ordered as (x, y, z), where x is "
                "lateral along the transducer, y is elevation (out of plane), and "
                "z is axial (depth)."
            ),
        },
        "rotation": {
            "unit": "-",
            "description": (
                "Orientation associated with the transducer-tip pose in the "
                "x-lateral, y-elevation, z-axial coordinate convention, interpreted "
                "according to rotation_representation."
            ),
        },
        "rotation_representation": {
            "unit": "-",
            "description": (
                "Rotation parameterization: one of euler_xyz, quaternion_wxyz, or quaternion_xyzw."
            ),
        },
        **Signal.FIELD_METADATA,
    }

    def __post_init__(self):
        """Run the Signal checks, then cross-check translation and rotation."""
        super().__post_init__()

        # Number of rotation components expected for each parameterization.
        components_by_representation = {
            "euler_xyz": 3,
            "quaternion_wxyz": 4,
            "quaternion_xyzw": 4,
        }

        n_translation = self.translation.shape[0]
        n_rotation = self.rotation.shape[0]
        if n_translation != n_rotation:
            raise ValueError(
                "translation and rotation must have the same number of time samples, "
                f"got {n_translation} and {n_rotation}"
            )

        representation = self.rotation_representation
        if representation not in components_by_representation:
            valid = ", ".join(sorted(components_by_representation))
            raise ValueError(
                f"rotation_representation must be one of {{{valid}}}, "
                f"got {representation!r}"
            )

        if self.rotation.shape[1] != components_by_representation[representation]:
            raise ValueError(
                "rotation shape does not match rotation_representation: "
                f"got {self.rotation.shape} for {representation!r}"
            )
[docs]
@dataclass
class Signal1D(Signal):
    """One-dimensional sampled signal with timing metadata.

    Extends :class:`Signal` with a single ``samples`` field; the timing fields
    and their schema/metadata entries are reused from :class:`Signal`.

    Args:
        samples: Signal samples of shape (T,) and type uint8, float32, int16
            or complex64.
        start_time_offset: Time offset in seconds between the first transmit event
            of the ultrasound acquisition and sample 0 of this data.
        sampling_frequency: Sampling frequency in Hz for signal samples.
    """

    samples: np.ndarray
    # 'samples' is listed first, followed by the entries inherited from Signal.
    SCHEMA = {
        "samples": {"dtype": (np.uint8, np.float32, np.int16, np.complex64), "shape": ("T",)},
        **Signal.SCHEMA,
    }
    FIELD_METADATA = {
        "samples": {"unit": "-", "description": "Signal samples."},
        **Signal.FIELD_METADATA,
    }
[docs]
@dataclass
class SignalND(Signal):
    """N-dimensional sampled signal with timing metadata.

    Extends :class:`Signal` with a single ``samples`` field; the timing fields
    and their schema/metadata entries are reused from :class:`Signal`.

    Args:
        samples: Signal samples of shape (T, ...) and type uint8, float32,
            int16 or complex64.
        start_time_offset: Time offset in seconds between the first transmit event
            of the ultrasound acquisition and sample 0 of this data.
        sampling_frequency: Sampling frequency in Hz for signal samples.
    """

    samples: np.ndarray
    # The "..." shape entry presumably allows any number of trailing axes
    # after T — confirm against match_shape.
    SCHEMA = {
        "samples": {"dtype": (np.uint8, np.float32, np.int16, np.complex64), "shape": ("T", "...")},
        **Signal.SCHEMA,
    }
    FIELD_METADATA = {
        "samples": {"unit": "-", "description": "Signal samples."},
        **Signal.FIELD_METADATA,
    }
[docs]
@dataclass
class Annotations(Spec):
    """Frame-level annotations, either per frame or broadcast labels.

    All fields are optional and stored as strings.

    Args:
        anatomy: Anatomy label; the schema accepts either shape (n_frames,)
            or a scalar.
        view: View label of shape (n_frames,).
        label: Pathology or classification label of shape (n_frames,).
        image_quality: Image quality label, e.g. low, mid, high; the schema
            accepts either shape (n_frames,) or a scalar.
    """

    anatomy: np.ndarray | str | None = None
    view: np.ndarray | None = None
    label: np.ndarray | None = None
    image_quality: np.ndarray | str | None = None
    # A tuple of shapes lists the accepted alternatives for that field.
    SCHEMA = {
        "anatomy": {"dtype": np.str_, "shape": (("n_frames",), ())},
        "view": {"dtype": np.str_, "shape": ("n_frames",)},
        "label": {"dtype": np.str_, "shape": ("n_frames",)},
        "image_quality": {"dtype": np.str_, "shape": (("n_frames",), ())},
    }
[docs]
@dataclass
class MetricsSpec(Spec):
    """Metrics group for acquisition-level quality/performance metrics.

    Both fields are optional.

    Args:
        common_midpoint_phase_error: Common midpoint phase error in radians of
            shape (n_frames,) and type float32.
        coherence_factor: Coherence factor of shape (n_frames,) and type float32.
    """

    common_midpoint_phase_error: np.ndarray | None = None
    coherence_factor: np.ndarray | None = None
    # One float32 value per frame for each metric.
    SCHEMA = {
        "common_midpoint_phase_error": {
            "dtype": np.float32,
            "shape": ("n_frames",),
        },
        "coherence_factor": {"dtype": np.float32, "shape": ("n_frames",)},
    }
[docs]
@dataclass
class TrackSpec(Spec):
    """One acquisition track: its data, its scan parameters and an optional label.

    Used inside a multi-track :class:`FileSpec` where different transmit
    sequences coexist in the same acquisition. The ``track_schedule`` on
    ``FileSpec`` specifies the global ordering of transmits across all tracks.

    For multi-track files a human-readable ``label`` is required on every
    track so that users can identify which track is which (e.g. ``"focused"``
    vs ``"planewave"``). Further information can be provided in the ``description``
    field of the parent :class:`FileSpec`, if necessary.
    Single-track files may omit the label.

    Args:
        data (DataSpec | dict): The data for this track.
        scan (ScanSpec | dict | None): The scan parameters for this track.
            Required when raw_data is present in *data*.
        label (str | None): Short human-readable name for this track (e.g.
            ``"focused"`` or ``"planewave"``). Required when the parent
            :class:`FileSpec` contains more than one track.

    Raises:
        ValueError: If raw data is present without ``scan``, or ``label`` is an
            empty or whitespace-only string.
        TypeError: If ``label`` is neither ``None`` nor a ``str``.
    """

    data: DataSpec | dict
    scan: ScanSpec | dict | None = None
    label: str | None = None

    SCHEMA = {
        "data": {"spec": DataSpec},
        "scan": {"spec": ScanSpec},
        "label": {"dtype": str, "shape": ()},
    }

    def __post_init__(self):
        """Run base validation, then check the scan requirement and the label."""
        super().__post_init__()

        # 'data' may be a DataSpec or a plain dict; look up raw_data accordingly.
        payload = self.data
        if isinstance(payload, DataSpec):
            raw = payload.raw_data
        elif isinstance(payload, dict):
            raw = payload.get("raw_data")
        else:
            raw = None
        if raw is not None and self.scan is None:
            raise ValueError("'scan' is required when 'raw_data' is provided in track data.")

        if self.label is None:
            return
        if not isinstance(self.label, str):
            raise TypeError(f"'label' must be a str, got {type(self.label)}")
        if not self.label.strip():
            raise ValueError("'label' must not be an empty or whitespace-only string.")

    def store_in_group(
        self,
        group: "h5py.Group",
        compression: str = DEFAULT_COMPRESSION,
        chunk_frames: bool = False,
    ) -> None:
        """Store data, scan, and label in the HDF5 group.

        Forwards to the base implementation with the same arguments.
        """
        super().store_in_group(group, compression=compression, chunk_frames=chunk_frames)
[docs]
@dataclass
class FileSpec(Spec):
"""A dataset containing all the data, scan parameters, metadata,
and metrics for a single acquisition.
A ``FileSpec`` always contains at least one track. When ``data`` and
``scan`` are supplied at construction time they are transparently wrapped
into a single :class:`TrackSpec`, so all existing call-sites continue to
work unchanged. For multi-track files pass ``tracks`` directly.
Args:
data: Data for a single-track acquisition (wrapped into ``tracks[0]``).
scan: Scan parameters for a single-track acquisition.
tracks: Explicit list of :class:`TrackSpec` objects (multi-track mode).
Mutually exclusive with ``data``/``scan``.
track_schedule: 1-D int32 array of length ``n_total_tx`` giving the
track index for each global transmit event.
metadata: Additional metadata about the acquisition.
metrics: Metrics computed from the acquisition.
probe: Physical probe specification (see :class:`ProbeSpec`). The probe
name is stored as ``probe.name``; use :attr:`zea.File.probe_name`
to read it back from an HDF5 file.
us_machine: The ultrasound machine used to acquire the data.
description: Free-text description.
Example:
.. doctest::
>>> from zea.data.spec import FileSpec
>>> import numpy as np
>>> dataset = FileSpec(
... data={
... "raw_data": np.zeros((2, 4, 64, 8, 1), dtype=np.float32),
... },
... scan={
... "sampling_frequency": np.float32(40e6),
... "center_frequency": np.float32(5e6),
... "demodulation_frequency": np.float32(5e6),
... "initial_times": np.zeros(4, dtype=np.float32),
... "t0_delays": np.zeros((4, 8), dtype=np.float32),
... "tx_apodizations": np.ones((4, 8), dtype=np.float32),
... "focus_distances": np.full(4, np.inf, dtype=np.float32),
... "transmit_origins": np.zeros((4, 3), dtype=np.float32),
... "polar_angles": np.zeros(4, dtype=np.float32),
... },
... )
>>> dataset.data.raw_data.shape
(2, 4, 64, 8, 1)
"""
# NOTE: data and scan are intentionally NOT dataclass fields — they are
# accepted as constructor kwargs and folded into tracks[0] at init time.
# @property accessors below provide backwards-compatible single-track access.
tracks: list = field(default_factory=list)
track_schedule: np.ndarray | None = None
metadata: MetadataSpec | dict = field(default_factory=MetadataSpec)
metrics: MetricsSpec | dict = field(default_factory=MetricsSpec)
probe: ProbeSpec | dict | None = None
us_machine: str | None = None
description: str | None = None
# tells the SCHEMA ↔ fields consistency test that 'tracks' is intentionally
# absent from SCHEMA (list[TrackSpec] doesn't fit the standard SCHEMA patterns)
_SCHEMA_EXCLUDED_FIELDS = frozenset({"tracks"})
SCHEMA = {
"track_schedule": {"dtype": np.int32, "shape": ("n_total_tx",)},
"metadata": {"spec": MetadataSpec},
"metrics": {"spec": MetricsSpec},
"probe": {"spec": ProbeSpec},
"us_machine": {"dtype": str, "shape": ()},
"description": {"dtype": str, "shape": ()},
}
def __init__(
    self,
    data: "DataSpec | dict | None" = None,
    scan: "ScanSpec | dict | None" = None,
    tracks: "list | None" = None,
    track_schedule: "np.ndarray | None" = None,
    metadata: "MetadataSpec | dict | None" = None,
    metrics: "MetricsSpec | dict | None" = None,
    probe_name: "str | None" = None,
    probe: "ProbeSpec | dict | None" = None,
    us_machine: "str | None" = None,
    description: "str | None" = None,
):
    """Build a file spec from single-track shorthand or an explicit track list.

    Args:
        data: Data for a single-track acquisition (becomes ``tracks[0]``).
        scan: Scan parameters for a single-track acquisition.
        tracks: Explicit list of tracks; cannot be combined with
            ``data``/``scan``.
        track_schedule: Track index for each global transmit event.
        metadata: Additional metadata; an empty ``MetadataSpec`` when omitted.
        metrics: Acquisition metrics; an empty ``MetricsSpec`` when omitted.
        probe_name: Not supported; any non-None value raises ``TypeError``.
        probe: Physical probe specification.
        us_machine: The ultrasound machine used to acquire the data.
        description: Free-text description.

    Raises:
        ValueError: If ``data``/``scan`` and a non-empty ``tracks`` are both given.
        TypeError: If ``probe_name`` is given.
    """
    uses_shorthand = data is not None or scan is not None
    if uses_shorthand and tracks:
        raise ValueError(
            "Provide either 'data'/'scan' (single-track shorthand) "
            "or 'tracks' (multi-track), not both."
        )
    if probe_name is not None:
        raise TypeError(
            "probe_name is not a FileSpec parameter. "
            "Use probe={'name': ...} to specify the probe name."
        )

    self.tracks = [] if tracks is None else list(tracks)
    self.track_schedule = track_schedule
    self.metadata = MetadataSpec() if metadata is None else metadata
    self.metrics = MetricsSpec() if metrics is None else metrics
    self.probe = probe
    self.us_machine = us_machine
    self.description = description

    # The shorthand pair is handed to __post_init__, which wraps it in a TrackSpec.
    implicit_track = {"data": data, "scan": scan} if uses_shorthand else None
    self.__post_init__(implicit_track)
# ------------------------------------------------------------------
# Backwards-compat read properties (single-track files only)
# ------------------------------------------------------------------
@property
def data(self) -> "DataSpec":
    """Return the :class:`DataSpec` of the only track.

    Raises:
        AttributeError: If the file does not have exactly one track; use
            ``spec.tracks[i].data`` in that case.
    """
    n_tracks = len(self.tracks)
    if n_tracks == 1:
        return self.tracks[0].data
    raise AttributeError(
        f"'data' is only available for single-track FileSpecs "
        f"({n_tracks} tracks present). Use spec.tracks[i].data."
    )
@property
def scan(self) -> "ScanSpec | None":
    """Return the :class:`ScanSpec` of the only track.

    Raises:
        AttributeError: If the file does not have exactly one track; use
            ``spec.tracks[i].scan`` in that case.
    """
    n_tracks = len(self.tracks)
    if n_tracks == 1:
        return self.tracks[0].scan
    raise AttributeError(
        f"'scan' is only available for single-track FileSpecs "
        f"({n_tracks} tracks present). Use spec.tracks[i].scan."
    )
def __post_init__(self, _implicit_track: "dict | None" = None):
    """Normalize the track list and validate the spec as a whole.

    Args:
        _implicit_track: Keyword arguments for a single :class:`TrackSpec`
            built from the ``data``/``scan`` shorthand; when given it replaces
            ``self.tracks``.

    Raises:
        ValueError: If there are no tracks, a multi-track file has an
            unlabeled track, ``track_schedule`` holds an out-of-range index,
            or a shared dimension differs between metadata and a track.
        TypeError: If an entry of ``tracks`` is neither a ``TrackSpec`` nor a dict.
    """
    # Fold implicit data/scan into a TrackSpec if provided
    if _implicit_track is not None:
        self.tracks = [TrackSpec(**_implicit_track)]
    if not self.tracks:
        raise ValueError("A FileSpec must contain at least one track.")

    # Create TrackSpecs from dictionaries in the tracks list, if needed, and
    # validate all tracks. Errors are re-raised with the track index prepended.
    track_specs = []
    for i, t in enumerate(self.tracks):
        if isinstance(t, dict):
            try:
                t = TrackSpec(**t)
            except (TypeError, ValueError) as e:
                raise type(e)(f"In tracks[{i}]: {e}") from e
        elif not isinstance(t, TrackSpec):
            raise TypeError(f"tracks[{i}] must be a TrackSpec or dict, got {type(t)}")
        track_specs.append(t)
    self.tracks = track_specs
    n_tracks = len(self.tracks)

    # For multi-track files every track must have a label so users can
    # identify tracks by name rather than relying on numeric indices.
    if n_tracks > 1:
        missing = [i for i, t in enumerate(self.tracks) if not t.label]
        if missing:
            raise ValueError(
                f"All tracks in a multi-track file must have a 'label'. "
                f"Missing label for track(s) at index: {missing}. "
                f"Provide a short descriptive name for each track, e.g. "
                f"'focused' or 'planewave', so that "
                f"File.get_track(label) and File.track_labels work correctly."
            )

    # Validate track_schedule indices are in range
    if self.track_schedule is not None:
        if not np.all((self.track_schedule >= 0) & (self.track_schedule < n_tracks)):
            raise ValueError(
                f"All track_schedule indices must be in [0, {n_tracks - 1}], "
                f"got min={self.track_schedule.min()}, max={self.track_schedule.max()}"
            )

    # Warn if multi-track frame counts differ without a schedule
    if n_tracks > 1 and self.track_schedule is None:
        frame_counts = []
        for track in self.tracks:
            if isinstance(track.data, DataSpec):
                rd = track.data.raw_data
            elif isinstance(track.data, dict):
                rd = track.data.get("raw_data")
            else:
                rd = None
            if rd is not None and hasattr(rd, "shape"):
                frame_counts.append(rd.shape[0])
        if len(set(frame_counts)) > 1:
            log.warning(
                "Tracks have different numbers of frames "
                f"({frame_counts}). Without a 'track_schedule' it is "
                "ambiguous how frames correspond across tracks. Consider "
                "passing 'track_schedule' to make the relationship explicit."
            )

    # Run base SCHEMA validation (metadata, metrics, scalars, track_schedule)
    super().__post_init__()

    # Validate that dimensions which are present in both metadata and tracks
    # are consistent across all tracks.
    if isinstance(self.metadata, MetadataSpec):
        # Field names and sizes come from the same call, so the error path
        # below does not need to collect them again.
        meta_fields, meta_dim_sizes = self.metadata._collect_dimension_info("metadata.")
        for i, track in enumerate(self.tracks):
            track_fields, track_dim_sizes = track._collect_dimension_info(f"tracks[{i}].")
            # Sorted so that the dimension reported first does not depend on
            # set iteration order.
            for dim in sorted(CONSISTENCY_DIMENSIONS):
                if dim not in meta_dim_sizes or dim not in track_dim_sizes:
                    continue
                all_sizes = meta_dim_sizes[dim] | track_dim_sizes[dim]
                if len(all_sizes) > 1:
                    raise ValueError(
                        f"Dimension '{dim}' has inconsistent sizes across "
                        f"fields {sorted(meta_fields[dim] | track_fields[dim])}: "
                        f"{sorted(all_sizes)}"
                    )
[docs]
def to_dict(self) -> dict:
    """Return this spec as a nested dictionary.

    Starts from the base-class dictionary and adds a ``tracks`` entry holding
    the dictionary form of every track.
    """
    payload = super().to_dict()
    payload["tracks"] = [track.to_dict() for track in self.tracks]
    return payload
[docs]
def save(
    self,
    path: str,
    compression: str = DEFAULT_COMPRESSION,
    chunk_frames: bool = False,
) -> None:
    """Save the dataset to the specified path.

    Args:
        path: Destination file; missing parent directories are created.
        compression: Compression passed on when storing groups and datasets.
        chunk_frames: Passed on to each track's ``store_in_group``.
    """
    # Lazy import to avoid circular dependency (spec.py is imported by file.py)
    from zea import File

    try:
        _zea_version = _get_pkg_version("zea")
    except PackageNotFoundError:
        _zea_version = "dev"

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    with File(str(path), "w") as f:
        f.attrs["zea_version"] = _zea_version

        # Write the SCHEMA entries: sub-specs become groups, track_schedule
        # becomes a dataset, and the remaining values become root attributes.
        for entry_name, entry in self.SCHEMA.items():
            value = getattr(self, entry_name)
            if value is None:
                continue
            if "spec" in entry:
                value.store_in_group(f.create_group(entry_name), compression=compression)
            elif entry_name == "track_schedule":
                # Array field — store as dataset, not attr
                self.create_dataset(f, entry_name, value, compression=compression)
            else:
                f.attrs[entry_name] = value

        # Write tracks (always at least one)
        tracks_group = f.create_group("tracks")
        for i, track in enumerate(self.tracks):
            track.store_in_group(
                tracks_group.create_group(f"track_{i}"),
                compression=compression,
                chunk_frames=chunk_frames,
            )

    log.info(f"File saved to {log.yellow(path)}")
[docs]
@classmethod
def from_hdf5(cls, file: h5py.File) -> "FileSpec":
    """Load and validate a :class:`FileSpec` from an open HDF5 file.

    Both the new ``tracks/track_N/`` format and the legacy flat
    ``data/`` + ``scan/`` format are supported. Extra scalar fields in
    legacy scan groups (``n_frames``, ``n_tx``, etc.) are ignored.

    When the file has no ``probe`` group, the legacy probe name is mapped
    to ``probe.name``. It is taken from ``file.probe_name`` when the file
    object provides that attribute, otherwise from the root attribute
    ``probe_name`` or ``probe``, so a plain ``h5py.File`` behaves the same
    way as a :class:`zea.File`.

    Args:
        file: An open ``h5py.File`` (or :class:`zea.File`).

    Returns:
        FileSpec: A fully validated spec object.
    """

    def _load_group_as_dict(group: h5py.Group) -> dict:
        """Recursively read an HDF5 group into a nested dict of values."""
        result = {}
        for key in group:
            item = group[key]
            if isinstance(item, h5py.Group):
                result[key] = _load_group_as_dict(item)
            elif isinstance(item, h5py.Dataset):
                if h5py.check_string_dtype(item.dtype) is None:
                    result[key] = item[()]
                    continue
                val = item.asstr()[()]
                # h5py returns object-dtype arrays for strings;
                # convert back to np.str_ so spec dtype checks pass.
                if isinstance(val, np.ndarray) and val.dtype == object:
                    val = val.astype(np.str_)
                result[key] = val
        return result

    def _legacy_probe_name() -> Any:
        """Return the root-level legacy probe name, or None when absent."""
        try:
            # Only this attribute access is guarded: a plain h5py.File has
            # no ``probe_name`` attribute at all.
            return file.probe_name
        except AttributeError:
            pass
        # Fallback for file objects without ``probe_name``: read the root
        # attributes directly instead of silently dropping the probe.
        for attr_name in ("probe_name", "probe"):
            if attr_name in file.attrs:
                name = file.attrs[attr_name]
                # Decode byte strings so the name is handled like a str.
                return name.decode("utf-8") if isinstance(name, bytes) else name
        return None

    # Keys recognised by ScanSpec; anything else found in a scan group
    # (legacy scalars such as n_frames, n_ax, n_el, n_tx, n_ch, ...) is
    # silently dropped.
    scan_schema_keys = set(ScanSpec.SCHEMA)

    kwargs: dict[str, Any] = {}

    # Load top-level SCHEMA fields: nested specs are stored as groups,
    # ``track_schedule`` as a dataset, and the remaining fields as root
    # attributes. Missing entries are left out of kwargs.
    for group_name, schema in cls.SCHEMA.items():
        if "spec" in schema:
            if group_name in file:
                kwargs[group_name] = _load_group_as_dict(file[group_name])
        elif group_name == "track_schedule":
            if group_name in file:
                kwargs[group_name] = file[group_name][()].astype(np.int32)
        elif group_name in file.attrs:
            kwargs[group_name] = file.attrs[group_name]

    if "tracks" in file:
        # New multi-track format: tracks/track_N/, read in index order
        # until the first missing index.
        tracks_group = file["tracks"]
        tracks = []
        index = 0
        while f"track_{index}" in tracks_group:
            track_dict = _load_group_as_dict(tracks_group[f"track_{index}"])
            # Filter legacy scalar fields from per-track scan dicts, the
            # same treatment the flat legacy scan group gets below.
            scan = track_dict.get("scan")
            if isinstance(scan, dict):
                track_dict["scan"] = {k: v for k, v in scan.items() if k in scan_schema_keys}
            tracks.append(track_dict)
            index += 1
        kwargs["tracks"] = tracks
    elif "data" in file or "scan" in file:
        # Legacy flat format: data/ + scan/ at root.
        kwargs["data"] = _load_group_as_dict(file["data"]) if "data" in file else {}
        if "scan" in file:
            kwargs["scan"] = _load_group_as_dict(file["scan"])

    # 1. Map the legacy root probe name into probe.name so that old files
    #    with a named probe but no probe group still round-trip.
    if "probe" not in kwargs:
        legacy_name = _legacy_probe_name()
        if legacy_name is not None:
            kwargs["probe"] = {"name": legacy_name}

    # 2. Filter the scan dict down to keys recognised by ScanSpec.SCHEMA.
    if "scan" in kwargs:
        kwargs["scan"] = {k: v for k, v in kwargs["scan"].items() if k in scan_schema_keys}

    # 3. Handle legacy flat `data/<key>` datasets. In old files spatial
    #    maps (image, image_sc, envelope_data, ...) were stored as plain
    #    arrays rather than groups with values + coordinates. Wrap them as
    #    {"values": array} so DataSpec accepts them. raw_data and
    #    aligned_data are valid as flat arrays and are left untouched.
    if "data" in kwargs and isinstance(kwargs["data"], dict):
        data_dict = kwargs["data"]
        for key in list(data_dict):
            if not isinstance(data_dict[key], np.ndarray):
                continue
            schema_entry = DataSpec.SCHEMA.get(key)
            # Plain-array schema fields (no nested spec) are kept as-is.
            if schema_entry is not None and "spec" not in schema_entry:
                continue
            log.warning(
                "Legacy flat dataset 'data/%s' has no spatial coordinates. "
                "The array has been loaded as 'values'; coordinates information "
                "was not stored in this file and will be None.",
                key,
            )
            data_dict[key] = {"values": data_dict[key]}

    return cls(**kwargs)