Source code for zea.models.flow_matching

"""
Flow matching generative model for ultrasound image generation and posterior sampling.

Replaces the cosine diffusion schedule and noise-prediction objective of
:class:`~zea.models.diffusion.DiffusionModel` with a linear flow-matching
schedule and a velocity-field prediction objective.

.. seealso::

    - :class:`~zea.models.diffusion.DiffusionModel`: DDIM-based counterpart.
    - Liu et al., *Flow Straight and Fast*, 2022. https://arxiv.org/abs/2209.03003
    - Lipman et al., *Flow Matching for Generative Modeling*, 2022. https://arxiv.org/abs/2210.02747
    - Esser et al., *Scaling Rectified Flow Transformers for High-Resolution Image Synthesis*, 2024.
      https://arxiv.org/abs/2403.03206

"""

from __future__ import annotations

import keras
from keras import ops

from zea.backend import _import_tf
from zea.internal.registry import model_registry
from zea.models.diffusion import DiffusionModel
from zea.models.preset_utils import register_presets
from zea.models.presets import flow_matching_presets
from zea.models.utils import LossTrackerWrapper

# TensorFlow module handle returned by `_import_tf()`.
# `FlowMatchingModel.train_step` treats a value of `None` as "TensorFlow not
# available" and raises `NotImplementedError` in that case.
tf = _import_tf()


@model_registry(name="flow_matching")
class FlowMatchingModel(DiffusionModel):
    """Flow matching generative model.

    Implements conditional flow matching (CFM) with straight-line (linear)
    interpolation paths between data and noise. The forward process is:

    .. math::

        x_t = (1 - t)\\, x_0 + t\\, \\varepsilon,
        \\qquad \\varepsilon \\sim \\mathcal{N}(0, I)

    The network is trained to predict the **velocity field**

    .. math::

        v_\\theta(x_t, t) \\approx v = \\varepsilon - x_0

    from which the clean-image estimate follows as

    .. math::

        \\hat{x}_0 = x_t - t\\, v_\\theta(x_t, t)

    At inference, images are generated by integrating the probability flow ODE

    .. math::

        \\frac{dx}{dt} = v_\\theta(x_t, t)

    backwards from :math:`t = 1` (pure noise) to :math:`t = 0` (clean data)
    using a simple Euler discretisation (identical to the DDIM update rule
    under this linear schedule).

    Noise samples are drawn independently from :math:`\\mathcal{N}(0, I)` and
    paired with data samples via **independent (random) coupling**, i.e.
    vanilla CFM / Rectified Flow (Liu et al. 2022). Minibatch Optimal
    Transport coupling (OT-CFM, Tong et al. 2023) is not currently implemented.

    All sampling, guidance (DPS/DDS), and posterior-sampling machinery from
    :class:`~zea.models.diffusion.DiffusionModel` is inherited unchanged.
    """

    def __init__(
        self,
        input_shape,
        input_range=(0, 1),
        network_name="unet_time_conditional",
        network_kwargs=None,
        name="flow_matching_model",
        guidance="dps",
        operator="inpainting",
        ema_val=0.999,
        min_t=0.0,
        max_t=1.0,
        **kwargs,
    ):
        """Initialize a flow matching model.

        Args:
            input_shape: Shape of the input data, typically
                ``(height, width, channels)`` for images.
            input_range: Range of the input data. Default ``(0, 1)``.
            network_name: Network architecture. One of
                ``"unet_time_conditional"`` or ``"dense_time_conditional"``.
            network_kwargs: Extra keyword arguments forwarded to the network
                constructor.
            name: Model name. Default ``"flow_matching_model"``.
            guidance: Guidance method. Can be a string (e.g. ``"dps"``), a
                dict with ``"name"`` and optional ``"params"`` keys, or a
                :class:`~zea.models.diffusion.DiffusionGuidance` instance.
            operator: Forward operator. Same format as ``guidance``.
            ema_val: Exponential moving average coefficient for the inference
                network weights. Default ``0.999``.
            min_t: Lower bound of the flow time interval. Default ``0.0``.
            max_t: Upper bound of the flow time interval. Default ``1.0``.
            **kwargs: Additional arguments forwarded to
                :class:`~zea.models.diffusion.DiffusionModel`.
        """
        # min_signal_rate / max_signal_rate are cosine-schedule parameters
        # that are not used by FlowMatchingModel; pass neutral values so the
        # parent __init__ does not fail.
        super().__init__(
            input_shape=input_shape,
            input_range=input_range,
            min_signal_rate=0.0,
            max_signal_rate=1.0,
            network_name=network_name,
            network_kwargs=network_kwargs,
            name=name,
            guidance=guidance,
            operator=operator,
            ema_val=ema_val,
            min_t=min_t,
            max_t=max_t,
            **kwargs,
        )
        # Tracker for the velocity loss. It is added next to whatever trackers
        # the parent created (nothing is removed here); the `metrics` property
        # below reports it together with `image_loss_tracker`.
        self.velocity_loss_tracker = LossTrackerWrapper("v_loss")

    def get_config(self):
        """Return the parent config without the cosine-schedule entries.

        ``min_signal_rate`` and ``max_signal_rate`` are hard-coded in
        :meth:`__init__`, so they are dropped from the serialised config.
        """
        config = super().get_config()
        # min/max_signal_rate are meaningless for this model
        config.pop("min_signal_rate", None)
        config.pop("max_signal_rate", None)
        return config

    def diffusion_schedule(self, diffusion_times):
        """Linear flow-matching schedule.

        .. math::

            \\text{noise\\_rates} = t, \\qquad \\text{signal\\_rates} = 1 - t

        Args:
            diffusion_times: Tensor of flow times in ``[min_t, max_t]``.

        Returns:
            A ``(noise_rates, signal_rates)`` tuple with the same shape as
            ``diffusion_times``.
        """
        noise_rates = ops.cast(diffusion_times, "float32")
        signal_rates = 1.0 - noise_rates
        return noise_rates, signal_rates

    def denoise(
        self,
        noisy_images,
        noise_rates,
        signal_rates,
        training: bool,
        network=None,
    ):
        """Predict the velocity field and derive the clean-image estimate.

        The network predicts the velocity :math:`v_\\theta(x_t, t)`. The
        clean-image estimate follows as

        .. math::

            \\hat{x}_0 = x_t - t\\, v_\\theta(x_t, t)

        To keep full compatibility with the parent's sampling and guidance
        machinery (which expects a ``(pred_noises, pred_images)`` return
        value), the method also returns the corresponding noise estimate

        .. math::

            \\hat{\\varepsilon} = \\hat{x}_0 + v_\\theta
                               = x_t + (1 - t)\\, v_\\theta

        under the name ``pred_noises``. The parent's
        :meth:`~DiffusionModel.reverse_diffusion_step` formula

        .. math::

            x_{t - \\Delta t} = \\alpha_{t-\\Delta t}\\,\\hat{x}_0
                              + \\sigma_{t-\\Delta t}\\,\\hat{\\varepsilon}

        is algebraically equivalent to the Euler step
        :math:`x_{t-\\Delta t} = x_t - \\Delta t\\, v_\\theta` under the linear
        schedule, so no changes to the sampling loop are needed.

        Args:
            noisy_images: Noisy images ``x_t`` of shape
                ``(n_images, *input_shape)``.
            noise_rates: Flow times ``t``, broadcastable to ``noisy_images``.
            signal_rates: ``1 - t``, broadcastable to ``noisy_images``.
                Not read by this implementation; accepted only so the
                signature matches the parent's.
            training: Whether to call the network in training mode.
            network: Explicit network to use. If ``None``, chosen based on
                ``training`` (see :meth:`~DiffusionModel.call`).

        Returns:
            A ``(pred_noises_est, pred_images)`` tuple where
            ``pred_noises_est`` is :math:`\\hat{\\varepsilon}` and
            ``pred_images`` is :math:`\\hat{x}_0`.
        """
        pred_velocities = self([noisy_images, noise_rates], training=training, network=network)
        # x̂₀ = x_t - t · v
        pred_images = noisy_images - noise_rates * pred_velocities
        # ε̂ = x̂₀ + v  (since v = ε − x₀  ⟹  ε = x₀ + v)
        pred_noises_est = pred_images + pred_velocities
        return pred_noises_est, pred_images

    def reverse_diffusion_step(
        self,
        shape,
        pred_images,
        pred_noises,
        signal_rates,
        next_signal_rates,
        next_noise_rates,
        seed=None,
        stochastic_sampling=False,
    ):
        """A single reverse flow-matching step.

        The deterministic step recombines the clean-image and noise estimates
        at the next time level. The stochastic step adds isotropic Gaussian
        noise on top of that update:

        .. math::

            x_{t - \\Delta t} = x_t - \\Delta t\\, v_\\theta(x_t, t)
                              + \\sqrt{2\\,\\Delta t}\\; \\mathbf{z},
            \\qquad \\mathbf{z} \\sim \\mathcal{N}(0, I)

        Under the linear schedule :math:`\\alpha_t = 1 - t`, the time step is
        recovered as :math:`\\Delta t = \\alpha_{t-\\Delta t} - \\alpha_t`
        (i.e. ``next_signal_rates - signal_rates``), which is positive during
        reverse sampling. It is clamped at zero before the square root so
        that floating-point round-off (or a caller passing the rates in the
        wrong order) cannot produce NaNs.

        NOTE(review): the injected noise scale depends only on
        :math:`\\Delta t`, not on ``next_noise_rates``, so the last step is
        perturbed as well. Confirm that the caller's sampling loop does not
        return this tensor as the final image.

        Args:
            shape: Shape of the output tensor.
            pred_images: Clean-image estimate :math:`\\hat{x}_0`.
            pred_noises: Noise estimate :math:`\\hat{\\varepsilon}`
                (equal to :math:`\\hat{x}_0 + v_\\theta`).
            signal_rates: Current signal rates :math:`\\alpha_t = 1 - t`.
            next_signal_rates: Next signal rates
                :math:`\\alpha_{t - \\Delta t} = 1 - (t - \\Delta t)`.
            next_noise_rates: Next noise rates :math:`t - \\Delta t`.
            seed: Random seed generator.
            stochastic_sampling: Whether to add the extra Gaussian noise.
                Default ``False`` (deterministic Euler step).

        Returns:
            Updated noisy images :math:`x_{t - \\Delta t}`.
        """
        # Deterministic Euler step (same formula as parent under linear schedule)
        next_noisy_images = next_signal_rates * pred_images + next_noise_rates * pred_noises
        if not stochastic_sampling:
            return next_noisy_images

        # Δt = α_{t−Δt} − α_t  (signal_rate = 1−t increases as t decreases).
        # Clamp at 0 so sqrt never sees a (round-off) negative value.
        dt = ops.maximum(next_signal_rates - signal_rates, 0.0)
        z = keras.random.normal(shape=shape, seed=seed)
        return next_noisy_images + ops.sqrt(2.0 * dt) * z

    @property
    def metrics(self):
        """Metrics for training."""
        return [*self.velocity_loss_tracker, *self.image_loss_tracker]

    def _sample_diffusion_times(self, batch_size, n_dims):
        """Sample flow times for training using logit-normal sampling.

        Returns a tensor of shape ``(batch_size, 1, ..., 1)`` (``n_dims``
        trailing singleton axes) with values in ``[min_t, max_t]``.

        Times are drawn as :math:`t = \\sigma(z)` where
        :math:`z \\sim \\mathcal{N}(0, 1)` (Esser et al., *Scaling Rectified
        Flow Transformers*, 2024), then linearly rescaled to
        ``[min_t, max_t]``. This concentrates training mass on the hard
        intermediate timesteps near :math:`t = 0.5`.
        """
        shape = ops.stack([batch_size, *([1] * n_dims)])
        z = keras.random.normal(shape=shape)
        t01 = ops.sigmoid(z)  # maps ℝ → (0, 1)
        return self.min_t + t01 * (self.max_t - self.min_t)

    def _make_flow_training_pair(self, data):
        """Build one noisy sample and its velocity target from clean ``data``.

        Shared by :meth:`train_step` and :meth:`test_step` so both use the
        exact same forward process.

        Returns:
            A ``(noisy_data, noise_rates, target_velocities)`` tuple.
        """
        batch_size, *input_shape = ops.shape(data)
        n_dims = len(input_shape)

        noises = keras.random.normal(shape=ops.shape(data))
        diffusion_times = self._sample_diffusion_times(batch_size, n_dims)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)

        # x_t = (1 − t) · x₀ + t · ε
        noisy_data = signal_rates * data + noise_rates * noises
        # Target velocity: v = ε − x₀
        target_velocities = noises - data
        return noisy_data, noise_rates, target_velocities

    def _flow_losses(self, data, noisy_data, noise_rates, target_velocities, training):
        """Run the network and return ``(velocity_loss, image_loss)``.

        ``velocity_loss`` compares predicted and target velocities;
        ``image_loss`` compares ``data`` with the clean-image estimate
        :math:`\\hat{x}_0 = x_t - t\\, v_\\theta`.
        """
        pred_velocities = self([noisy_data, noise_rates], training=training)
        pred_images = noisy_data - noise_rates * pred_velocities
        velocity_loss = self.loss(target_velocities, pred_velocities)
        image_loss = self.loss(data, pred_images)
        return velocity_loss, image_loss

    def train_step(self, data):
        """Custom train step for Rectified Flow (independent coupling).

        Trains the network to predict the velocity field
        :math:`v = \\varepsilon - x_0` from noisy observations
        :math:`x_t = (1 - t)\\,x_0 + t\\,\\varepsilon`, where
        :math:`\\varepsilon \\sim \\mathcal{N}(0, I)` is sampled independently
        of :math:`x_0`. Only the velocity loss is differentiated; the image
        loss is tracked for monitoring.

        Note:
            Only implemented for the TensorFlow backend.

        Raises:
            NotImplementedError: If TensorFlow could not be imported.
        """
        if tf is None:
            raise NotImplementedError(
                "FlowMatchingModel.train_step is only implemented for the TensorFlow backend."
            )

        noisy_data, noise_rates, target_velocities = self._make_flow_training_pair(data)

        with tf.GradientTape() as tape:
            velocity_loss, image_loss = self._flow_losses(
                data, noisy_data, noise_rates, target_velocities, training=True
            )

        gradients = tape.gradient(velocity_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.velocity_loss_tracker.update_state(velocity_loss)
        self.image_loss_tracker.update_state(image_loss)

        # Update EMA weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(self.ema_val * ema_weight + (1 - self.ema_val) * weight)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Custom test step for Rectified Flow (independent coupling).

        Computes the same velocity and image losses as :meth:`train_step`
        with the model called with ``training=False`` and without any
        weight update.
        """
        noisy_data, noise_rates, target_velocities = self._make_flow_training_pair(data)
        velocity_loss, image_loss = self._flow_losses(
            data, noisy_data, noise_rates, target_velocities, training=False
        )

        self.velocity_loss_tracker.update_state(velocity_loss)
        self.image_loss_tracker.update_state(image_loss)

        return {m.name: m.result() for m in self.metrics}
# Associate the entries of `flow_matching_presets` with `FlowMatchingModel`.
# NOTE(review): presumably this is what makes preset-based loading work for
# this class — confirm against `zea.models.preset_utils.register_presets`.
register_presets(flow_matching_presets, FlowMatchingModel)