Source code for zea.models.flow_matching

"""
Flow matching generative model for ultrasound image generation and posterior sampling.

Replaces the cosine diffusion schedule and noise-prediction objective of
:class:`~zea.models.diffusion.DiffusionModel` with a linear flow-matching
schedule and a velocity-field prediction objective.

.. seealso::

    - :class:`~zea.models.diffusion.DiffusionModel`: DDIM-based counterpart.
    - Liu et al., *Flow Straight and Fast*, 2022. https://arxiv.org/abs/2209.03003
    - Lipman et al., *Flow Matching for Generative Modeling*, 2022. https://arxiv.org/abs/2210.02747
    - Esser et al., *Scaling Rectified Flow Transformers for High-Resolution Image Synthesis*, 2024.
      https://arxiv.org/abs/2403.03206

"""

from __future__ import annotations

import keras
from keras import ops

from zea.backend import _import_tf
from zea.internal.registry import model_registry
from zea.models.diffusion import DiffusionModel
from zea.models.preset_utils import register_presets
from zea.models.presets import flow_matching_presets
from zea.models.utils import LossTrackerWrapper

# TensorFlow module handle obtained through the zea backend helper.
# ``FlowMatchingModel.train_step`` compares it against ``None`` before using
# ``tf.GradientTape``, so it is presumably ``None`` when TensorFlow is not
# available -- TODO confirm in ``zea.backend._import_tf``.
tf = _import_tf()


@model_registry(name="flow_matching")
class FlowMatchingModel(DiffusionModel):
    """Flow matching generative model.

    Implements conditional flow matching (CFM) with straight-line (linear)
    interpolation paths between data and noise. The forward process is:

    .. math::

        x_t = (1 - t)\\, x_0 + t\\, \\varepsilon,
        \\qquad \\varepsilon \\sim \\mathcal{N}(0, I)

    The network is trained to predict the **velocity field**

    .. math::

        v_\\theta(x_t, t) \\approx v = \\varepsilon - x_0

    from which the clean-image estimate follows as

    .. math::

        \\hat{x}_0 = x_t - t\\, v_\\theta(x_t, t)

    At inference, images are generated by integrating the probability flow ODE

    .. math::

        \\frac{dx}{dt} = v_\\theta(x_t, t)

    backwards from :math:`t = 1` (pure noise) to :math:`t = 0` (clean data).
    Two discretisations are available through the ``solver`` argument: a
    second-order Euler–Heun scheme (``"heun"``, the default) and a first-order
    Euler scheme (``"euler"``, identical to the DDIM update rule under this
    linear schedule). See :meth:`solver_step`.

    Noise samples are drawn independently from :math:`\\mathcal{N}(0, I)` and
    paired with data samples via **independent (random) coupling**, i.e.
    vanilla CFM / Rectified Flow (Liu et al. 2022). Minibatch Optimal Transport
    coupling (OT-CFM, Tong et al. 2023) is not currently implemented.

    The guidance (DPS/DDS) and posterior-sampling machinery of
    :class:`~zea.models.diffusion.DiffusionModel` is inherited. This class
    overrides the schedule (:meth:`diffusion_schedule`), the denoiser
    (:meth:`denoise`), the per-step update (:meth:`solver_step`,
    :meth:`reverse_diffusion_step`) and the train/test steps.
    """

    def __init__(
        self,
        input_shape,
        input_range=(0, 1),
        network_name="unet_time_conditional",
        network_kwargs=None,
        name="flow_matching_model",
        guidance="dps",
        operator="inpainting",
        ema_val=0.999,
        min_t=0.0,
        max_t=1.0,
        solver="heun",
        **kwargs,
    ):
        """Initialize a flow matching model.

        Args:
            input_shape: Shape of the input data, typically
                ``(height, width, channels)`` for images.
            input_range: Range of the input data. Default ``(0, 1)``.
            network_name: Network architecture. One of
                ``"unet_time_conditional"``, ``"dense_time_conditional"``, or
                ``"dit_time_conditional"`` (Diffusion Transformer).
            network_kwargs: Extra keyword arguments forwarded to the network
                constructor.
            name: Model name. Default ``"flow_matching_model"``.
            guidance: Guidance method. Can be a string (e.g. ``"dps"``), a dict
                with ``"name"`` and optional ``"params"`` keys, or a
                :class:`~zea.models.diffusion.DiffusionGuidance` instance.
            operator: Forward operator. Same format as ``guidance``.
            ema_val: Exponential moving average coefficient for the inference
                network weights. Default ``0.999``.
            min_t: Lower bound of the flow time interval. Default ``0.0``.
            max_t: Upper bound of the flow time interval. Default ``1.0``.
            solver: ODE solver used for (unconditional) sampling. One of
                ``"heun"`` (second-order Euler–Heun, the default) or
                ``"euler"`` (first-order). Heun evaluates the velocity field
                twice per step for higher accuracy and is purely an
                inference-time choice (no retraining needed). See
                :meth:`solver_step`.
            **kwargs: Additional arguments forwarded to
                :class:`~zea.models.diffusion.DiffusionModel`.

        Raises:
            ValueError: If ``solver`` is neither ``"euler"`` nor ``"heun"``.
        """
        # Validate before touching the parent so a bad argument fails fast.
        if solver not in ("euler", "heun"):
            raise ValueError(f"solver must be one of 'euler' or 'heun', got {solver!r}.")

        # min_signal_rate / max_signal_rate are cosine-schedule parameters
        # that are not used by FlowMatchingModel; pass neutral values so the
        # parent __init__ does not fail.
        super().__init__(
            input_shape=input_shape,
            input_range=input_range,
            min_signal_rate=0.0,
            max_signal_rate=1.0,
            network_name=network_name,
            network_kwargs=network_kwargs,
            name=name,
            guidance=guidance,
            operator=operator,
            ema_val=ema_val,
            min_t=min_t,
            max_t=max_t,
            **kwargs,
        )

        # Attributes are assigned only after the parent constructor has run.
        self.solver = solver

        # Velocity-loss tracker reported by ``metrics`` (used in place of a
        # noise-loss tracker for this model).
        self.velocity_loss_tracker = LossTrackerWrapper("v_loss")

    def get_config(self):
        """Return the serialisable configuration of the model.

        Starts from the parent configuration, drops the cosine-schedule keys
        (``min_signal_rate`` / ``max_signal_rate``) that ``__init__`` sets
        itself, and adds the ``solver`` choice.
        """
        config = super().get_config()
        # min/max_signal_rate are meaningless for this model; ``__init__``
        # always passes fixed values for them to the parent.
        config.pop("min_signal_rate", None)
        config.pop("max_signal_rate", None)
        config["solver"] = self.solver
        return config

    def solver_step(
        self,
        noisy_images,
        noise_rates,
        signal_rates,
        next_noise_rates,
        next_signal_rates,
        shape,
        network=None,
        training: bool = False,
        seed=None,
        stochastic_sampling: bool = False,
    ):
        r"""Single ODE solver step for the probability-flow ODE.

        Integrates :math:`\frac{dx}{dt} = v_\theta(x_t, t)` backwards by one
        step using either a first-order Euler update (``solver="euler"``) or a
        second-order Euler–Heun (improved-Euler) update (``solver="heun"``,
        the default).

        The Heun update evaluates the velocity field **twice** per step:

        .. math::

            \tilde{x}_{t-\Delta t} &= x_t - \Delta t\, v_\theta(x_t, t)
                \qquad\text{(Euler predictor)} \\
            x_{t-\Delta t} &= x_t - \tfrac{\Delta t}{2}\big(
                v_\theta(x_t, t)
                + v_\theta(\tilde{x}_{t-\Delta t},\, t-\Delta t)
                \big) \qquad\text{(trapezoidal corrector)}

        where :math:`\Delta t = t - (t-\Delta t)` equals
        ``noise_rates - next_noise_rates`` (positive during reverse sampling).

        Heun's method only changes inference; it reuses the same trained
        velocity network and requires no retraining. Stochastic sampling falls
        back to the first-order update of :meth:`reverse_diffusion_step`,
        since the deterministic Heun corrector does not apply to the Langevin
        SDE.

        Args:
            noisy_images: Current noisy images ``x_t``.
            noise_rates: Flow times ``t`` at the current step.
            signal_rates: ``1 - t`` at the current step.
            next_noise_rates: Flow times ``t - Δt`` at the next step.
            next_signal_rates: ``1 - (t - Δt)`` at the next step.
            shape: Shape of the image tensor.
            network: Explicit network to use (``None`` selects based on
                ``training``).
            training: Whether to call the network in training mode.
            seed: Random seed generator (for stochastic sampling).
            stochastic_sampling: Whether to use stochastic (Langevin)
                sampling.

        Returns:
            A ``(next_noisy_images, pred_images)`` tuple where
            ``next_noisy_images`` is ``x_{t-Δt}`` and ``pred_images`` is the
            clean-image estimate ``x̂₀`` at the current step.
        """
        pred_noises, pred_images = self.denoise(
            noisy_images, noise_rates, signal_rates, training=training, network=network
        )

        # Euler predictor (also the full update for the first-order solver).
        next_noisy_images = self.reverse_diffusion_step(
            shape=shape,
            pred_images=pred_images,
            pred_noises=pred_noises,
            signal_rates=signal_rates,
            next_signal_rates=next_signal_rates,
            next_noise_rates=next_noise_rates,
            seed=seed,
            stochastic_sampling=stochastic_sampling,
        )

        if self.solver == "euler" or stochastic_sampling:
            return next_noisy_images, pred_images

        # --- Heun corrector (second velocity evaluation) ---
        # Velocity at the current point: v = ε̂ − x̂₀
        velocity = pred_noises - pred_images

        # Velocity at the Euler-predicted point and the next (lower) time.
        pred_noises_next, pred_images_next = self.denoise(
            next_noisy_images,
            next_noise_rates,
            next_signal_rates,
            training=training,
            network=network,
        )
        velocity_next = pred_noises_next - pred_images_next

        # Δt = t − (t − Δt) = noise_rates − next_noise_rates (≥ 0 in reverse)
        dt = noise_rates - next_noise_rates
        next_noisy_images = noisy_images - 0.5 * dt * (velocity + velocity_next)

        # The returned clean-image estimate is the one from the first
        # evaluation (current step), not the corrector evaluation.
        return next_noisy_images, pred_images

    def diffusion_schedule(self, diffusion_times):
        """Linear flow-matching schedule.

        .. math::

            \\text{noise\\_rates} = t, \\qquad \\text{signal\\_rates} = 1 - t

        Args:
            diffusion_times: Tensor of flow times in ``[min_t, max_t]``.

        Returns:
            A ``(noise_rates, signal_rates)`` tuple with the same shape as
            ``diffusion_times``.
        """
        noise_rates = ops.cast(diffusion_times, "float32")
        signal_rates = 1.0 - noise_rates
        return noise_rates, signal_rates

    def denoise(
        self,
        noisy_images,
        noise_rates,
        signal_rates,
        training: bool,
        network=None,
    ):
        """Predict the velocity field and derive the clean-image estimate.

        The network predicts the velocity :math:`v_\\theta(x_t, t)`. The
        clean-image estimate follows as

        .. math::

            \\hat{x}_0 = x_t - t\\, v_\\theta(x_t, t)

        To keep full compatibility with the parent's sampling and guidance
        machinery (which expects a ``(pred_noises, pred_images)`` return
        value), the method also returns the corresponding noise estimate

        .. math::

            \\hat{\\varepsilon} = \\hat{x}_0 + v_\\theta
                                = x_t + (1 - t)\\, v_\\theta

        under the name ``pred_noises``. The update formula used by
        :meth:`reverse_diffusion_step`

        .. math::

            x_{t - \\Delta t} = \\alpha_{t-\\Delta t}\\,\\hat{x}_0
                               + \\sigma_{t-\\Delta t}\\,\\hat{\\varepsilon}

        is algebraically equivalent to the Euler step
        :math:`x_{t-\\Delta t} = x_t - \\Delta t\\, v_\\theta` under the linear
        schedule, so no changes to the sampling loop are needed.

        Args:
            noisy_images: Noisy images ``x_t`` of shape
                ``(n_images, *input_shape)``.
            noise_rates: Flow times ``t``, broadcastable to ``noisy_images``.
            signal_rates: ``1 - t``, broadcastable to ``noisy_images``. Not
                read by this implementation; kept for interface compatibility.
            training: Whether to call the network in training mode.
            network: Explicit network to use. If ``None``, chosen based on
                ``training`` (see :meth:`~DiffusionModel.call`).

        Returns:
            A ``(pred_noises_est, pred_images)`` tuple where
            ``pred_noises_est`` is :math:`\\hat{\\varepsilon}` and
            ``pred_images`` is :math:`\\hat{x}_0`.
        """
        pred_velocities = self([noisy_images, noise_rates], training=training, network=network)

        # x̂₀ = x_t - t · v
        pred_images = noisy_images - noise_rates * pred_velocities

        # ε̂ = x̂₀ + v   (since v = ε − x₀  ⟹  ε = x₀ + v)
        pred_noises_est = pred_images + pred_velocities

        return pred_noises_est, pred_images

    def reverse_diffusion_step(
        self,
        shape,
        pred_images,
        pred_noises,
        signal_rates,
        next_signal_rates,
        next_noise_rates,
        seed=None,
        stochastic_sampling=False,
    ):
        """A single reverse flow-matching step.

        The deterministic (ODE) step recombines the clean-image and noise
        estimates at the next time (the same formula as the parent under the
        linear schedule). The stochastic step adds isotropic Langevin noise on
        top of the Euler update, turning the probability-flow ODE into a
        Langevin SDE:

        .. math::

            x_{t - \\Delta t} = x_t - \\Delta t\\, v_\\theta(x_t, t)
                               + \\sqrt{2\\,\\Delta t}\\; \\mathbf{z},
            \\qquad \\mathbf{z} \\sim \\mathcal{N}(0, I)

        Under the linear schedule :math:`\\alpha_t = 1 - t`, the time step is
        recovered as :math:`\\Delta t = \\alpha_{t-\\Delta t} - \\alpha_t` (i.e.
        ``next_signal_rates - signal_rates``), which is always positive during
        reverse sampling.

        Args:
            shape: Shape of the output tensor.
            pred_images: Clean-image estimate :math:`\\hat{x}_0`.
            pred_noises: Noise estimate :math:`\\hat{\\varepsilon}` (equal to
                :math:`\\hat{x}_0 + v_\\theta`).
            signal_rates: Current signal rates :math:`\\alpha_t = 1 - t`.
            next_signal_rates: Next signal rates
                :math:`\\alpha_{t - \\Delta t} = 1 - (t - \\Delta t)`.
            next_noise_rates: Next noise rates :math:`t - \\Delta t`.
            seed: Random seed generator.
            stochastic_sampling: Whether to add Langevin noise. Default
                ``False`` (deterministic Euler step).

        Returns:
            Updated noisy images :math:`x_{t - \\Delta t}`.
        """
        # Deterministic Euler step (same formula as parent under linear schedule)
        next_noisy_images = next_signal_rates * pred_images + next_noise_rates * pred_noises

        if not stochastic_sampling:
            return next_noisy_images

        # Δt = α_{t−Δt} − α_t  (positive; signal_rate = 1−t increases as t decreases)
        dt = next_signal_rates - signal_rates
        z = keras.random.normal(shape=shape, seed=seed)
        return next_noisy_images + ops.sqrt(2.0 * dt) * z

    @property
    def metrics(self):
        """Loss trackers reported during training and evaluation.

        Returns the velocity-loss tracker(s) followed by the image-loss
        tracker(s). ``image_loss_tracker`` is not created in this class and is
        presumably provided by the parent -- TODO confirm.
        """
        return [*self.velocity_loss_tracker, *self.image_loss_tracker]

    def _sample_diffusion_times(self, batch_size, n_dims):
        """Sample flow times for training using logit-normal sampling.

        Returns a tensor of shape ``(batch_size, 1, ..., 1)`` (``n_dims``
        trailing singleton axes) with values in ``[min_t, max_t]``.

        Times are drawn as :math:`t = \\sigma(z)` where
        :math:`z \\sim \\mathcal{N}(0, 1)` (Esser et al., *Scaling Rectified
        Flow Transformers*, 2024), then linearly rescaled to
        ``[min_t, max_t]``. This concentrates training mass on the hard
        intermediate timesteps near :math:`t = 0.5`.

        Args:
            batch_size: Number of samples in the batch (may be a tensor).
            n_dims: Number of trailing singleton axes to append so the result
                broadcasts against the data.
        """
        shape = ops.stack([batch_size, *([1] * n_dims)])
        z = keras.random.normal(shape=shape)
        t01 = ops.sigmoid(z)  # maps ℝ → (0, 1)
        return self.min_t + t01 * (self.max_t - self.min_t)

    def _flow_training_pair(self, data):
        """Build one noisy input / velocity target pair from clean ``data``.

        Shared by :meth:`train_step` and :meth:`test_step`. Noise is drawn
        first and flow times second.

        Args:
            data: Batch of clean samples ``x₀`` with the batch axis first.

        Returns:
            A ``(noisy_data, noise_rates, target_velocities)`` tuple.
        """
        batch_size, *input_shape = ops.shape(data)
        n_dims = len(input_shape)

        noises = keras.random.normal(shape=ops.shape(data))
        diffusion_times = self._sample_diffusion_times(batch_size, n_dims)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)

        # x_t = (1 − t) · x₀ + t · ε
        noisy_data = signal_rates * data + noise_rates * noises

        # Target velocity: v = ε − x₀
        target_velocities = noises - data

        return noisy_data, noise_rates, target_velocities

    def _flow_losses(self, data, noisy_data, noise_rates, target_velocities, training):
        """Run the network once and return ``(velocity_loss, image_loss)``.

        Args:
            data: Clean samples ``x₀`` (target of the image loss).
            noisy_data: Interpolated samples ``x_t`` fed to the network.
            noise_rates: Flow times ``t`` used as conditioning.
            target_velocities: Regression target ``ε − x₀``.
            training: Forwarded to the model call as ``training``.
        """
        pred_velocities = self([noisy_data, noise_rates], training=training)
        # x̂₀ = x_t − t · v
        pred_images = noisy_data - noise_rates * pred_velocities

        velocity_loss = self.loss(target_velocities, pred_velocities)
        image_loss = self.loss(data, pred_images)
        return velocity_loss, image_loss

    def train_step(self, data):
        """Custom train step for Rectified Flow (independent coupling).

        Trains the network to predict the velocity field
        :math:`v = \\varepsilon - x_0` from noisy observations
        :math:`x_t = (1 - t)\\,x_0 + t\\,\\varepsilon`, where
        :math:`\\varepsilon \\sim \\mathcal{N}(0, I)` is sampled independently
        of :math:`x_0`. Only the velocity loss is differentiated; the image
        loss is tracked for monitoring.

        Note:
            Only implemented for the TensorFlow backend.

        Args:
            data: Batch of clean samples ``x₀``.

        Returns:
            Dict mapping each metric name to its current result.

        Raises:
            NotImplementedError: If TensorFlow is not available.
        """
        if tf is None:
            raise NotImplementedError(
                "FlowMatchingModel.train_step is only implemented for the TensorFlow backend."
            )

        noisy_data, noise_rates, target_velocities = self._flow_training_pair(data)

        with tf.GradientTape() as tape:
            velocity_loss, image_loss = self._flow_losses(
                data, noisy_data, noise_rates, target_velocities, training=True
            )

        gradients = tape.gradient(velocity_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.velocity_loss_tracker.update_state(velocity_loss)
        self.image_loss_tracker.update_state(image_loss)

        # Update EMA weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(self.ema_val * ema_weight + (1 - self.ema_val) * weight)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Custom test step for Rectified Flow (independent coupling).

        Computes the same velocity and image losses as :meth:`train_step`
        with the model called in inference mode, and without any weight
        update.

        Args:
            data: Batch of clean samples ``x₀``.

        Returns:
            Dict mapping each metric name to its current result.
        """
        noisy_data, noise_rates, target_velocities = self._flow_training_pair(data)

        velocity_loss, image_loss = self._flow_losses(
            data, noisy_data, noise_rates, target_velocities, training=False
        )

        self.velocity_loss_tracker.update_state(velocity_loss)
        self.image_loss_tracker.update_state(image_loss)

        return {m.name: m.result() for m in self.metrics}
# Associate the entries of ``flow_matching_presets`` with ``FlowMatchingModel``
# at import time. Presumably this is what makes the presets loadable through
# the class -- TODO confirm in ``zea.models.preset_utils.register_presets``.
register_presets(flow_matching_presets, FlowMatchingModel)