Models

Collection of (generative) models for ultrasound imaging.

zea contains a collection of models for various tasks, all located in the zea.models package.

See the following dropdown for a list of available models:

Available models

Presets for these models can be found in zea.models.presets.

To use these models, import the model class from its submodule in zea.models and load the pretrained weights with the from_preset() class method. For example:

>>> from zea.models.unet import UNet

>>> model = UNet.from_preset("unet-echonet-inpainter")

You can list all available presets using the presets attribute:

>>> from zea.models.unet import UNet
>>> presets = list(UNet.presets.keys())
>>> print(f"Available built-in zea presets for UNet: {presets}")
Available built-in zea presets for UNet: ['unet-echonet-inpainter']

zea generative models

In addition to the models above, zea provides both classical and deep generative models for tasks such as image generation, inpainting, and denoising. These models inherit from zea.models.generative.GenerativeModel or zea.models.deepgenerative.DeepGenerativeModel. They typically provide additional methods, such as:

  • fit() for training the model on data

  • sample() for generating new samples from the learned distribution

  • posterior_sample() for drawing samples from the posterior given measurements

  • log_density() for computing the log-probability of data under the model
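To make the roles of these four methods concrete, here is a minimal sketch using a one-dimensional Gaussian. ToyGaussian is a hypothetical class written for this illustration only: it does not inherit from GenerativeModel, and the argument names (measurement, noise_std) are assumptions rather than zea's actual signatures.

```python
import math
import random


class ToyGaussian:
    """Hypothetical 1-D Gaussian model; it does not inherit from any zea class."""

    def __init__(self, seed=0):
        self.mean = 0.0
        self.std = 1.0
        self._rng = random.Random(seed)

    def fit(self, data):
        # Maximum-likelihood estimates of the mean and standard deviation.
        n = len(data)
        self.mean = sum(data) / n
        self.std = math.sqrt(sum((x - self.mean) ** 2 for x in data) / n)
        return self

    def sample(self, n_samples=1):
        # Draw from the learned distribution p(x).
        return [self._rng.gauss(self.mean, self.std) for _ in range(n_samples)]

    def posterior_sample(self, measurement, noise_std, n_samples=1):
        # Assumed measurement model: y = x + noise, noise ~ N(0, noise_std^2).
        # Prior and likelihood are both Gaussian, so p(x | y) is Gaussian too.
        prior_prec = 1.0 / self.std**2
        noise_prec = 1.0 / noise_std**2
        post_var = 1.0 / (prior_prec + noise_prec)
        post_mean = post_var * (prior_prec * self.mean + noise_prec * measurement)
        post_std = math.sqrt(post_var)
        return [self._rng.gauss(post_mean, post_std) for _ in range(n_samples)]

    def log_density(self, x):
        # Log of the Gaussian probability density evaluated at x.
        z = (x - self.mean) / self.std
        return -0.5 * z * z - math.log(self.std) - 0.5 * math.log(2.0 * math.pi)


model = ToyGaussian(seed=0).fit([1.0, 2.0, 3.0, 4.0, 5.0])
samples = model.sample(n_samples=4)
posterior = model.posterior_sample(measurement=10.0, noise_std=0.5, n_samples=3)
```

Deep generative models replace the closed-form expressions with learned networks, but the division of labour between the methods stays the same.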

See the following dropdown for a list of available generative models:

Available models

An example of how to use the zea.models.diffusion.DiffusionModel is shown below:

>>> from zea.models.diffusion import DiffusionModel

>>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
>>> samples = model.sample(n_samples=4)

Contributing and adding new models

Please follow the guidelines in the Contributing page if you would like to contribute a new model to zea.

The following steps are recommended when adding a new model:

  1. Create a new module in the zea.models package for your model: zea.models.mymodel.

  2. Add a model class that inherits from zea.models.base.Model. For generative models, inherit from zea.models.generative.GenerativeModel or zea.models.deepgenerative.DeepGenerativeModel as appropriate. Make sure you implement the call() method.

  3. Upload the pretrained model weights to our Hugging Face. The upload should consist of a config.json file and a model.weights.h5 file; see the Keras documentation for how to save these from your model. To upload the files, drag and drop them onto the Hugging Face website.

    Tip

    It is recommended to use the saving procedure described above. Alternative saving methods are also possible; see the zea.models.echonet.EchoNet module for an example. In that case, you have to implement a custom_load_weights() method in your model class.

  4. Add a preset for the model in zea.models.presets. Presets let you provide multiple sets of pretrained weights for a single model architecture.

  5. Make sure to register the presets in your model module by importing the presets module and calling register_presets with the model class as an argument.

  6. Lastly, add the model to the zea.models package by importing it in the __init__.py file.

Adding non-Keras (custom) models

The recommended approach for any model is to implement it as a native Keras 3 model. This gives you backend-agnostic execution (JAX, TensorFlow, PyTorch) and the full preset/weight-loading infrastructure for free.

For models originally trained in PyTorch, the typical workflow is:

  1. Vendor the architecture — copy the PyTorch network code into your module (e.g. inside a _build_torch_classes() helper that imports torch lazily so that PyTorch is only required for weight conversion, not inference).

  2. Implement the Keras architecture — write keras.layers.Layer subclasses that replicate each block. Key API differences to handle:

    • Padding for stride-2 Conv2D: Keras padding='same' is asymmetric for stride > 1; use ZeroPadding2D(1) + Conv2D(padding='valid') to match PyTorch’s symmetric padding=1.

    • ConvTranspose: Keras Conv2DTranspose(padding='valid') gives the full output; crop x[:, 1:, 1:, :] (NHWC) to reproduce PyTorch’s ConvTranspose2d(padding=1, output_padding=1) alignment.

    • InstanceNorm: use GroupNormalization(groups=C, scale=False, center=False, epsilon=1e-5) for InstanceNorm2d(affine=False).

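    • Weight axes: transpose PyTorch kernels with axes (2,3,1,0) for both Conv2D and Conv2DTranspose. The permutation is the same but the semantics differ: it maps (out, in, kH, kW) to (kH, kW, in, out) for Conv2D, and (in, out, kH, kW) to (kH, kW, out, in) for Conv2DTranspose.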

    • Input format: Keras defaults to channels-last (NHWC); transpose NCHW → NHWC in call() and back before returning.

  3. Write a weight-loading helper — a function that maps PyTorch state-dict keys to the Keras layer tree and calls layer.set_weights([...]).

  4. Add a from_pth(path) classmethod that wraps the weight loader for convenient local testing.

  5. Optionally add an ONNX fallback — for environments that have onnxruntime but not torch, you can keep a from_onnx(path) classmethod and an _onnx_sess attribute; override call() to dispatch to the ONNX path when the session is set.

  6. Follow steps 3-6 from the standard guide above for the Hugging Face upload, presets, and registration.
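The two padding notes in step 2 are easier to check in one dimension. The sketch below uses plain NumPy reference loops rather than Keras or PyTorch calls. It assumes a kernel of size 3 with stride 2, the TensorFlow-style rule for 'same' padding (the extra zero goes on the right), and PyTorch's documented index relation for transposed convolutions. All function names are made up for this example.

```python
import numpy as np


def conv1d_valid(x, kernel, stride):
    # Strided cross-correlation without any padding.
    k = len(kernel)
    return np.array([np.dot(x[i:i + k], kernel) for i in range(0, len(x) - k + 1, stride)])


def same_padding(length, k, stride):
    # TensorFlow-style 'same' rule (assumed here): any odd leftover pads the right side.
    out = -(-length // stride)  # ceil(length / stride)
    total = max((out - 1) * stride + k - length, 0)
    left = total // 2
    return left, total - left


def conv_transpose1d_full(x, kernel, stride):
    # Full (valid-style) transposed convolution: each input scatters a kernel copy.
    k = len(kernel)
    full = np.zeros((len(x) - 1) * stride + k)
    for i, value in enumerate(x):
        full[i * stride:i * stride + k] += value * kernel
    return full


def conv_transpose1d_torch_style(x, kernel, stride, padding, output_padding):
    # Direct evaluation of out[o] = sum_i x[i] * kernel[o + padding - i * stride].
    k = len(kernel)
    length = (len(x) - 1) * stride - 2 * padding + k + output_padding
    out = np.zeros(length)
    for o in range(length):
        for i, value in enumerate(x):
            j = o + padding - i * stride
            if 0 <= j < k:
                out[o] += value * kernel[j]
    return out


x = np.array([1.0, 2.0, 3.0, 4.0])
kernel = np.array([1.0, 10.0, 100.0])

# Stride-2 convolution: symmetric padding of 1 versus asymmetric 'same' padding.
symmetric = conv1d_valid(np.pad(x, 1), kernel, stride=2)
asymmetric = conv1d_valid(np.pad(x, same_padding(len(x), 3, 2)), kernel, stride=2)

# Transposed convolution: dropping the first element of the full output
# matches padding=1, output_padding=1.
full = conv_transpose1d_full(x, kernel, stride=2)
torch_style = conv_transpose1d_torch_style(x, kernel, stride=2, padding=1, output_padding=1)
```

Here symmetric and asymmetric have the same length but different values, which is why the explicit ZeroPadding2D(1) is needed, while full[1:] equals torch_style element for element.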

See zea.models.speckle2self for a complete worked example of this pattern (native Keras + PyTorch weight loading + optional ONNX fallback).
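As a smaller illustration of the layout conversions in steps 2 and 3, the following NumPy sketch applies the (2,3,1,0) permutation, converts between NCHW and NHWC, and normalizes each channel in the way InstanceNorm2d(affine=False) does. It is a sketch under stated assumptions: the state-dict keys, the key map, and the helper names are invented for this example, and random arrays stand in for real checkpoint tensors.

```python
import numpy as np

rng = np.random.default_rng(0)

# PyTorch Conv2d weights are (out, in, kH, kW); Keras Conv2D kernels are (kH, kW, in, out).
torch_conv_w = rng.normal(size=(8, 2, 3, 5))
keras_conv_w = np.transpose(torch_conv_w, (2, 3, 1, 0))

# PyTorch ConvTranspose2d weights are (in, out, kH, kW);
# Keras Conv2DTranspose kernels are (kH, kW, out, in).
torch_deconv_w = rng.normal(size=(8, 4, 3, 5))
keras_deconv_w = np.transpose(torch_deconv_w, (2, 3, 1, 0))

# Hypothetical state dict and key map (all names invented for this example).
state_dict = {"down.0.weight": torch_conv_w, "down.0.bias": rng.normal(size=8)}
key_map = {"down_conv": ("down.0.weight", "down.0.bias")}


def keras_weights_for(layer_name):
    """Return [kernel, bias] in the order a Keras Conv2D layer stores them."""
    weight_key, bias_key = key_map[layer_name]
    return [np.transpose(state_dict[weight_key], (2, 3, 1, 0)), state_dict[bias_key]]


def nchw_to_nhwc(x):
    return np.transpose(x, (0, 2, 3, 1))


def nhwc_to_nchw(x):
    return np.transpose(x, (0, 3, 1, 2))


def instance_norm_nhwc(x, eps=1e-5):
    # Normalize every (sample, channel) pair over its spatial axes, using the
    # biased variance; no learned scale or offset.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


x_nchw = rng.normal(size=(2, 4, 6, 6))
x_nhwc = nchw_to_nhwc(x_nchw)
normed = instance_norm_nhwc(x_nhwc)
kernel, bias = keras_weights_for("down_conv")
```

In an actual loader, the list returned by keras_weights_for would be passed to layer.set_weights([...]) on the matching Keras layer.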