"""
Flow matching generative model for ultrasound image generation and posterior sampling.
Replaces the cosine diffusion schedule and noise-prediction objective of
:class:`~zea.models.diffusion.DiffusionModel` with a linear flow-matching
schedule and a velocity-field prediction objective.
.. seealso::
- :class:`~zea.models.diffusion.DiffusionModel`: DDIM-based counterpart.
- Liu et al., *Flow Straight and Fast*, 2022. https://arxiv.org/abs/2209.03003
- Lipman et al., *Flow Matching for Generative Modeling*, 2022. https://arxiv.org/abs/2210.02747
- Esser et al., *Scaling Rectified Flow Transformers for High-Resolution Image Synthesis*, 2024.
https://arxiv.org/abs/2403.03206
"""
from __future__ import annotations
import keras
from keras import ops
from zea.backend import _import_tf
from zea.internal.registry import model_registry
from zea.models.diffusion import DiffusionModel
from zea.models.preset_utils import register_presets
from zea.models.presets import flow_matching_presets
from zea.models.utils import LossTrackerWrapper
# TensorFlow module handle used by ``FlowMatchingModel.train_step`` (for
# ``tf.GradientTape``); that method raises NotImplementedError when this is None.
# NOTE(review): presumably ``_import_tf`` returns None when TensorFlow is not
# available — confirm in ``zea.backend``.
tf = _import_tf()
@model_registry(name="flow_matching")
class FlowMatchingModel(DiffusionModel):
    r"""Flow matching generative model.

    Implements conditional flow matching (CFM) with straight-line (linear)
    interpolation paths between data and noise. The forward process is:

    .. math::

        x_t = (1 - t)\, x_0 + t\, \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I)

    The network is trained to predict the **velocity field**

    .. math::

        v_\theta(x_t, t) \approx v = \varepsilon - x_0

    from which the clean-image estimate follows as

    .. math::

        \hat{x}_0 = x_t - t\, v_\theta(x_t, t)

    At inference, images are generated by integrating the probability flow ODE

    .. math::

        \frac{dx}{dt} = v_\theta(x_t, t)

    backwards from :math:`t = 1` (pure noise) to :math:`t = 0` (clean data)
    using a simple Euler discretisation (identical to the DDIM update rule under
    this linear schedule).

    Noise samples are drawn independently from :math:`\mathcal{N}(0, I)` and
    paired with data samples via **independent (random) coupling**, i.e. vanilla
    CFM / Rectified Flow (Liu et al. 2022). Minibatch Optimal Transport
    coupling (OT-CFM, Tong et al. 2023) is not currently implemented.

    The sampling loop, guidance (DPS/DDS), and posterior-sampling machinery are
    inherited from :class:`~zea.models.diffusion.DiffusionModel`; this class
    overrides only the schedule, the denoiser, the single reverse step, the
    metrics, and the train/test steps.
    """

    def __init__(
        self,
        input_shape,
        input_range=(0, 1),
        network_name="unet_time_conditional",
        network_kwargs=None,
        name="flow_matching_model",
        guidance="dps",
        operator="inpainting",
        ema_val=0.999,
        min_t=0.0,
        max_t=1.0,
        **kwargs,
    ):
        """Initialize a flow matching model.

        Args:
            input_shape: Shape of the input data, typically
                ``(height, width, channels)`` for images.
            input_range: Range of the input data. Default ``(0, 1)``.
            network_name: Network architecture. One of
                ``"unet_time_conditional"`` or ``"dense_time_conditional"``.
            network_kwargs: Extra keyword arguments forwarded to the network
                constructor.
            name: Model name. Default ``"flow_matching_model"``.
            guidance: Guidance method. Can be a string (e.g. ``"dps"``),
                a dict with ``"name"`` and optional ``"params"`` keys, or a
                :class:`~zea.models.diffusion.DiffusionGuidance` instance.
            operator: Forward operator. Same format as ``guidance``.
            ema_val: Exponential moving average coefficient for the inference
                network weights. Default ``0.999``.
            min_t: Lower bound of the flow time interval. Default ``0.0``.
            max_t: Upper bound of the flow time interval. Default ``1.0``.
            **kwargs: Additional arguments forwarded to
                :class:`~zea.models.diffusion.DiffusionModel`.
        """
        # min_signal_rate / max_signal_rate belong to the cosine schedule and
        # are never read by this class (diffusion_schedule is overridden);
        # neutral values are passed so the parent __init__ does not fail.
        super().__init__(
            input_shape=input_shape,
            input_range=input_range,
            min_signal_rate=0.0,
            max_signal_rate=1.0,
            network_name=network_name,
            network_kwargs=network_kwargs,
            name=name,
            guidance=guidance,
            operator=operator,
            ema_val=ema_val,
            min_t=min_t,
            max_t=max_t,
            **kwargs,
        )
        # Tracks the velocity loss; reported through ``metrics`` in place of a
        # noise loss.
        self.velocity_loss_tracker = LossTrackerWrapper("v_loss")

    def get_config(self):
        """Return the parent config without the cosine-schedule entries.

        ``min_signal_rate`` and ``max_signal_rate`` are fixed inside
        ``__init__`` and carry no meaning for this model, so they are removed
        from the dictionary returned by the parent.
        """
        config = super().get_config()
        for unused_key in ("min_signal_rate", "max_signal_rate"):
            config.pop(unused_key, None)
        return config

    def diffusion_schedule(self, diffusion_times):
        r"""Linear flow-matching schedule.

        .. math::

            \text{noise\_rates} = t, \qquad \text{signal\_rates} = 1 - t

        Args:
            diffusion_times: Tensor of flow times in ``[min_t, max_t]``.

        Returns:
            A ``(noise_rates, signal_rates)`` tuple with the same shape as
            ``diffusion_times``; ``noise_rates`` is cast to ``float32``.
        """
        noise_rates = ops.cast(diffusion_times, "float32")
        signal_rates = 1.0 - noise_rates
        return noise_rates, signal_rates

    def denoise(
        self,
        noisy_images,
        noise_rates,
        signal_rates,
        training: bool,
        network=None,
    ):
        r"""Predict the velocity field and derive the clean-image estimate.

        The network predicts the velocity :math:`v_\theta(x_t, t)`. The
        clean-image estimate follows as

        .. math::

            \hat{x}_0 = x_t - t\, v_\theta(x_t, t)

        To stay compatible with the parent's sampling and guidance machinery
        (which expects a ``(pred_noises, pred_images)`` return value), the
        method also returns the corresponding noise estimate

        .. math::

            \hat{\varepsilon} = \hat{x}_0 + v_\theta = x_t + (1 - t)\, v_\theta

        With :math:`\alpha = 1 - t` and :math:`\sigma = t`, the update

        .. math::

            x_{t - \Delta t}
            = \alpha_{t-\Delta t}\,\hat{x}_0
            + \sigma_{t-\Delta t}\,\hat{\varepsilon}

        is algebraically equivalent to the Euler step
        :math:`x_{t-\Delta t} = x_t - \Delta t\, v_\theta`.

        Args:
            noisy_images: Noisy images ``x_t`` of shape
                ``(n_images, *input_shape)``.
            noise_rates: Flow times ``t``, broadcastable to ``noisy_images``.
            signal_rates: ``1 - t``, broadcastable to ``noisy_images``. Not
                read by this implementation; accepted so the signature
                matches the parent's.
            training: Whether to call the network in training mode.
            network: Explicit network to use. If ``None``, chosen based on
                ``training`` (see :meth:`~DiffusionModel.call`).

        Returns:
            A ``(pred_noises_est, pred_images)`` tuple where
            ``pred_noises_est`` is :math:`\hat{\varepsilon}` and
            ``pred_images`` is :math:`\hat{x}_0`.
        """
        pred_velocities = self([noisy_images, noise_rates], training=training, network=network)
        # x0_hat = x_t - t * v
        pred_images = noisy_images - noise_rates * pred_velocities
        # eps_hat = x0_hat + v   (because v = eps - x0)
        pred_noises_est = pred_images + pred_velocities
        return pred_noises_est, pred_images

    def reverse_diffusion_step(
        self,
        shape,
        pred_images,
        pred_noises,
        signal_rates,
        next_signal_rates,
        next_noise_rates,
        seed=None,
        stochastic_sampling=False,
    ):
        r"""A single reverse flow-matching step.

        The deterministic (ODE) step is
        :math:`\alpha_{t-\Delta t}\,\hat{x}_0 + \sigma_{t-\Delta t}\,\hat{\varepsilon}`.
        The stochastic step adds isotropic Gaussian noise on top of it:

        .. math::

            x_{t - \Delta t}
            = x_t - \Delta t\, v_\theta(x_t, t)
            + \sqrt{2\,\Delta t}\; \mathbf{z},
            \qquad \mathbf{z} \sim \mathcal{N}(0, I)

        Under the linear schedule :math:`\alpha_t = 1 - t`, the time step is
        recovered as :math:`\Delta t = \alpha_{t-\Delta t} - \alpha_t`
        (i.e. ``next_signal_rates - signal_rates``), which is non-negative
        during reverse sampling. It is clamped at zero before the square
        root so that a non-decreasing time pair cannot produce NaN.

        Args:
            shape: Shape of the output tensor (used to draw the noise).
            pred_images: Clean-image estimate :math:`\hat{x}_0`.
            pred_noises: Noise estimate :math:`\hat{\varepsilon}`
                (equal to :math:`\hat{x}_0 + v_\theta`).
            signal_rates: Current signal rates :math:`\alpha_t = 1 - t`.
            next_signal_rates: Next signal rates
                :math:`\alpha_{t - \Delta t} = 1 - (t - \Delta t)`.
            next_noise_rates: Next noise rates :math:`t - \Delta t`.
            seed: Random seed generator.
            stochastic_sampling: Whether to add the noise term. Default
                ``False`` (deterministic Euler step).

        Returns:
            Updated noisy images :math:`x_{t - \Delta t}`.
        """
        # Deterministic Euler step written in (signal, noise) form.
        next_noisy_images = next_signal_rates * pred_images + next_noise_rates * pred_noises
        if not stochastic_sampling:
            return next_noisy_images

        # dt = alpha_{t-dt} - alpha_t; clamp so sqrt never sees a negative value.
        dt = ops.maximum(next_signal_rates - signal_rates, 0.0)
        z = keras.random.normal(shape=shape, seed=seed)
        return next_noisy_images + ops.sqrt(2.0 * dt) * z

    @property
    def metrics(self):
        """Metrics for training: velocity-loss trackers, then image-loss trackers."""
        return [*self.velocity_loss_tracker, *self.image_loss_tracker]

    def _sample_diffusion_times(self, batch_size, n_dims):
        r"""Sample flow times for training using logit-normal sampling.

        Returns a tensor of shape ``(batch_size, 1, ..., 1)`` (``n_dims``
        trailing singleton axes) with values in ``[min_t, max_t]``.

        Times are drawn as :math:`t = \sigma(z)` where
        :math:`z \sim \mathcal{N}(0, 1)` (Esser et al., *Scaling Rectified
        Flow Transformers*, 2024), then linearly rescaled to
        ``[min_t, max_t]``. This concentrates training mass on the
        intermediate timesteps near :math:`t = 0.5`.
        """
        shape = ops.stack([batch_size, *([1] * n_dims)])
        z = keras.random.normal(shape=shape)
        t01 = ops.sigmoid(z)  # maps the real line onto (0, 1)
        return self.min_t + t01 * (self.max_t - self.min_t)

    def _flow_matching_batch(self, data):
        """Build one noisy batch and its regression target.

        Draws the noise first and the flow times second (the same order for
        training and evaluation), then interpolates linearly.

        Returns:
            ``(noisy_data, noise_rates, target_velocities)`` where
            ``noisy_data = (1 - t) * data + t * noise`` and
            ``target_velocities = noise - data``.
        """
        batch_size, *input_shape = ops.shape(data)
        noises = keras.random.normal(shape=ops.shape(data))
        diffusion_times = self._sample_diffusion_times(batch_size, len(input_shape))
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy_data = signal_rates * data + noise_rates * noises
        target_velocities = noises - data
        return noisy_data, noise_rates, target_velocities

    def _flow_matching_losses(self, data, noisy_data, noise_rates, target_velocities, training):
        """Run the network once and return ``(velocity_loss, image_loss)``.

        The image loss compares ``data`` with the clean-image estimate
        ``noisy_data - noise_rates * predicted_velocity``.
        """
        pred_velocities = self([noisy_data, noise_rates], training=training)
        pred_images = noisy_data - noise_rates * pred_velocities
        velocity_loss = self.loss(target_velocities, pred_velocities)
        image_loss = self.loss(data, pred_images)
        return velocity_loss, image_loss

    def train_step(self, data):
        r"""Custom train step for Rectified Flow (independent coupling).

        Trains the network to predict the velocity field
        :math:`v = \varepsilon - x_0` from noisy observations
        :math:`x_t = (1 - t)\,x_0 + t\,\varepsilon`, where
        :math:`\varepsilon \sim \mathcal{N}(0, I)` is sampled independently
        of :math:`x_0`. Only the velocity loss is differentiated; the image
        loss is tracked for monitoring.

        Note:
            Only implemented for the TensorFlow backend.

        Raises:
            NotImplementedError: If the module-level ``tf`` handle is ``None``.

        Returns:
            Dict mapping each metric name to its current result.
        """
        if tf is None:
            raise NotImplementedError(
                "FlowMatchingModel.train_step is only implemented for the TensorFlow backend."
            )

        noisy_data, noise_rates, target_velocities = self._flow_matching_batch(data)

        # The forward pass must happen under the tape so gradients can flow.
        with tf.GradientTape() as tape:
            velocity_loss, image_loss = self._flow_matching_losses(
                data, noisy_data, noise_rates, target_velocities, training=True
            )

        trainable_weights = self.network.trainable_weights
        gradients = tape.gradient(velocity_loss, trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, trainable_weights))

        self.velocity_loss_tracker.update_state(velocity_loss)
        self.image_loss_tracker.update_state(image_loss)

        # Update EMA weights: ema <- ema_val * ema + (1 - ema_val) * weight
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(self.ema_val * ema_weight + (1 - self.ema_val) * weight)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Custom test step for Rectified Flow (independent coupling).

        Builds a noisy batch exactly as in training, runs the network with
        ``training=False`` and updates the loss trackers; no weights are
        modified.

        Returns:
            Dict mapping each metric name to its current result.
        """
        noisy_data, noise_rates, target_velocities = self._flow_matching_batch(data)
        velocity_loss, image_loss = self._flow_matching_losses(
            data, noisy_data, noise_rates, target_velocities, training=False
        )

        self.velocity_loss_tracker.update_state(velocity_loss)
        self.image_loss_tracker.update_state(image_loss)

        return {m.name: m.result() for m in self.metrics}
# Associate the flow-matching presets with FlowMatchingModel at import time.
# NOTE(review): presumably this is what lets the model be loaded by preset
# name — confirm in ``zea.models.preset_utils``.
register_presets(flow_matching_presets, FlowMatchingModel)