"""
Diffusion models for ultrasound image generation and posterior sampling.
To try this model, simply load one of the available presets:
.. doctest::
>>> from zea.models.diffusion import DiffusionModel
>>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic") # doctest: +SKIP
.. seealso::
A tutorial notebook where this model is used:
:doc:`../notebooks/models/diffusion_model_example`.
"""
from __future__ import annotations
import abc
from typing import Literal
import keras
from keras import ops
from zea.backend import _import_tf, jit
from zea.backend.autograd import AutoGrad
from zea.func.tensor import L2, fori_loop, split_seed
from zea.internal.core import Object
from zea.internal.operators import Operator
from zea.internal.registry import diffusion_guidance_registry, model_registry, operator_registry
from zea.internal.utils import fn_requires_argument
from zea.models.dense import get_time_conditional_dense_network
from zea.models.generative import DeepGenerativeModel
from zea.models.preset_utils import register_presets
from zea.models.presets import diffusion_model_presets
from zea.models.unet import get_time_conditional_unetwork
from zea.models.utils import LossTrackerWrapper
# Module-level TensorFlow handle. `DiffusionModel.train_step` checks
# `tf is None` and raises NotImplementedError in that case, so this may be
# None when TensorFlow is not available.
tf = _import_tf()
[docs]
@model_registry(name="diffusion")
class DiffusionModel(DeepGenerativeModel):
"""Implementation of a diffusion generative model.
Heavily inspired by https://keras.io/examples/generative/ddim/
"""
def __init__(
    self,
    input_shape,
    input_range=(0, 1),
    min_signal_rate=0.02,
    max_signal_rate=0.95,
    network_name="unet_time_conditional",
    network_kwargs=None,
    name="diffusion_model",
    guidance="dps",
    operator="inpainting",
    ema_val=0.999,
    min_t=0.0,
    max_t=1.0,
    **kwargs,
):
    """Build the score network, its EMA copy, and the guidance machinery.

    Args:
        input_shape: Shape of a single data sample, typically
            `(height, width, channels)` for images.
        input_range: Value range of the input data.
        min_signal_rate: Lower bound of the signal rate in the diffusion
            schedule.
        max_signal_rate: Upper bound of the signal rate in the diffusion
            schedule.
        network_name: Architecture of the score network. Either
            "unet_time_conditional" or "dense_time_conditional".
        network_kwargs: Extra keyword arguments passed to the network
            factory. ``None`` is treated as an empty dict.
        name: Name of the model.
        guidance: Guidance method. A registry name, a dict with "name" and
            "params" keys, or a `DiffusionGuidance` instance.
        operator: Measurement operator. A registry name, a dict with
            "name" and "params" keys, or an `Operator` instance.
        ema_val: Decay factor of the exponential moving average that is
            kept over the network weights.
        min_t: Minimum diffusion time sampled during training.
        max_t: Maximum diffusion time sampled during training.
        **kwargs: Forwarded to the parent class constructor.

    Raises:
        ValueError: If ``network_name`` is not one of the supported names.
    """
    super().__init__(name=name, **kwargs)

    # Data description.
    self.input_shape = input_shape
    self.input_range = input_range

    # Diffusion schedule settings.
    self.min_signal_rate = min_signal_rate
    self.max_signal_rate = max_signal_rate
    # reverse diffusion (i.e. sampling) goes from t = max_t to t = min_t
    self.min_t = min_t
    self.max_t = max_t

    # Network settings.
    self.network_name = network_name
    self.network_kwargs = network_kwargs if network_kwargs else {}
    self.ema_val = ema_val

    # Build the online (trainable) score network.
    if network_name not in ("unet_time_conditional", "dense_time_conditional"):
        raise ValueError("Invalid network name provided.")
    if network_name == "dense_time_conditional":
        assert len(input_shape) == 1, "Dense network only supports 1D input"
        self.network = get_time_conditional_dense_network(
            input_dim=self.input_shape[0],
            **self.network_kwargs,
        )
    else:
        self.network = get_time_conditional_unetwork(
            image_shape=self.input_shape,
            **self.network_kwargs,
        )

    # Frozen clone of the network; train_step writes the moving average
    # of the online weights into it.
    self.ema_network = keras.models.clone_model(self.network)
    self.ema_network.trainable = False

    self.image_loss_tracker = LossTrackerWrapper("i_loss")
    self.noise_loss_tracker = LossTrackerWrapper("n_loss")

    # Storage for intermediate results (i.e. the diffusion trajectory).
    self.track_progress_interval = 1
    self.track_progress = []

    # Guidance / conditional sampling; filled in by the helper below.
    self.guidance_fn = None
    self.operator = None
    self._init_operator_and_guidance(operator, guidance)
[docs]
def get_config(self):
    """Return the parent config extended with this model's hyperparameters.

    Only the schedule and network settings are added. ``guidance`` and
    ``operator`` are not part of the returned dict.

    Returns:
        dict: Configuration dictionary.
    """
    serialized_attributes = (
        "input_shape",
        "input_range",
        "min_signal_rate",
        "max_signal_rate",
        "min_t",
        "max_t",
        "network_name",
        "network_kwargs",
        "ema_val",
    )
    config = super().get_config()
    for attribute in serialized_attributes:
        config[attribute] = getattr(self, attribute)
    return config
def _init_operator_and_guidance(self, operator, guidance):
    """Resolve the ``operator`` and ``guidance`` specs into objects.

    Sets ``self.operator`` and ``self.guidance_fn``. A spec that is
    ``None`` leaves the corresponding attribute untouched.

    Args:
        operator: ``None``, a key of ``operator_registry``, an `Operator`
            instance, or a dict with a "name" key and an optional "params"
            dict. If the operator class takes an ``image_range`` argument
            and "params" does not supply it, ``self.input_range`` is used.
        guidance: ``None``, a key of ``diffusion_guidance_registry``, a
            `DiffusionGuidance` instance, or a dict with a "name" key and
            an optional "params" dict.

    Raises:
        ValueError: If a spec is of an unsupported type.
        AssertionError: If ``guidance`` is given while ``operator`` is
            ``None``.
    """
    if operator is not None:
        if isinstance(operator, str):
            self.operator = operator_registry[operator]()
        elif isinstance(operator, Operator):
            self.operator = operator
        elif isinstance(operator, dict):
            operator_class = operator_registry[operator["name"]]
            # Copy the params so the caller's spec dict is never modified
            # (it may be a shared config that is reused for other models).
            operator_params = dict(operator.get("params") or {})
            if "image_range" not in operator_params and fn_requires_argument(
                operator_class.__init__, "image_range"
            ):
                operator_params["image_range"] = self.input_range
            self.operator = operator_class(**operator_params)
        else:
            raise ValueError(
                f"Invalid operator provided, must be a string, dict or "
                f"Operator object, got {operator}"
            )

    if guidance is not None:
        assert operator is not None, "Operator must be provided for guidance"
        if isinstance(guidance, str):
            guidance_class = diffusion_guidance_registry[guidance]
            self.guidance_fn = guidance_class(
                diffusion_model=self,
                operator=self.operator,
            )
        elif isinstance(guidance, DiffusionGuidance):
            self.guidance_fn = guidance
        elif isinstance(guidance, dict):
            guidance_class = diffusion_guidance_registry[guidance["name"]]
            # "params" is optional, mirroring the operator spec.
            guidance_params = guidance.get("params") or {}
            self.guidance_fn = guidance_class(
                diffusion_model=self, operator=self.operator, **guidance_params
            )
        else:
            raise ValueError(
                f"Invalid guidance provided, must be a string, dict or "
                f"DiffusionGuidance object, got {guidance}"
            )
[docs]
def call(self, inputs, training: bool = False, network=None, **kwargs):
    """Run the score network on ``inputs``.

    Args:
        inputs: A list ``[noisy_images, noise_rates_squared]`` as expected
            by the underlying time-conditional network.
        training: Whether to run in training mode.
        network: Network to call. When ``None``, the online network is
            used if ``training`` is true and the EMA network otherwise.
        **kwargs: Extra keyword arguments forwarded to the network.

    Returns:
        The output of the selected network (the predicted noise).
    """
    if network is not None:
        selected_network = network
    elif training:
        selected_network = self.network
    else:
        selected_network = self.ema_network
    return selected_network(inputs, training=training, **kwargs)
[docs]
def sample(self, n_samples=1, n_steps=20, seed=None, **kwargs):
    """Draw unconditional samples by reverse diffusion from Gaussian noise.

    Args:
        n_samples: Number of samples to generate.
        n_steps: Number of diffusion steps.
        seed: Random seed generator.
        **kwargs: Additional arguments passed to :meth:`reverse_diffusion`.

    Returns:
        Generated samples of shape `(n_samples, *input_shape)`.
    """
    # One seed for the diffusion loop, one for the starting noise.
    loop_seed, noise_seed = split_seed(seed, 2)

    starting_noise = keras.random.normal(
        shape=(n_samples, *self.input_shape),
        seed=noise_seed,
    )

    return self.reverse_diffusion(
        initial_noise=starting_noise,
        diffusion_steps=n_steps,
        seed=loop_seed,
        **kwargs,
    )
[docs]
def posterior_sample(
    self,
    measurements,
    n_samples: int = 1,
    n_steps: int = 20,
    initial_step: int = 0,
    initial_samples=None,
    seed=None,
    **kwargs,
):
    """Sample from the posterior distribution given measurements.

    Every measurement in the batch is repeated ``n_samples`` times, the
    repeated batch is run through :meth:`reverse_conditional_diffusion`,
    and the result is reshaped back to a separate sample axis.

    Args:
        measurements: Input measurements. Typically of shape
            ``(batch_size, *input_shape)``.
        n_samples: Number of posterior samples to generate for each
            measurement in the ``measurements`` batch.
        n_steps: Number of diffusion steps.
        initial_step: Step at which to begin the reverse diffusion loop.
            ``0`` runs all ``n_steps`` steps from maximum noise. Higher
            values skip early (high-noise) steps and require
            ``initial_samples`` to be provided. The number of effective
            steps is ``n_steps - initial_step``.
        initial_samples: Optional initial samples to warm-start the
            diffusion process, which then starts from a *noised* version
            of these samples. When ``initial_step == 0`` they are noised
            at the maximum noise level (``max_t``); when
            ``initial_step > 0`` they are noised at the level
            corresponding to ``initial_step``. They can be initial guesses
            such as solutions of previous frames (for sequences), see for
            instance `SeqDiff <https://arxiv.org/abs/2409.05399>`_.
            Must be of shape ``(batch_size, n_samples, *input_shape)``.
        seed: Random seed generator.
        **kwargs: Additional arguments passed to
            :meth:`reverse_conditional_diffusion`. A ``"mask"`` entry is
            repeated along the sample axis in the same way as
            ``measurements``.

    Returns:
        Posterior samples ``p(x|y)`` of shape
        ``(batch_size, n_samples, *input_shape)``.
    """
    batch_size = ops.shape(measurements)[0]

    def _repeat_per_sample(tensor):
        """Repeat each batch entry ``n_samples`` times and flatten."""
        trailing_shape = ops.shape(tensor)[1:]
        repeated = ops.repeat(tensor[:, None], n_samples, axis=1)  # (batch, n_samples, ...)
        return ops.reshape(repeated, (-1, *trailing_shape))

    measurements = _repeat_per_sample(measurements)
    if "mask" in kwargs:
        kwargs["mask"] = _repeat_per_sample(kwargs["mask"])

    if initial_samples is not None:
        # Merge the batch and sample axes to match the flattened batch.
        initial_samples = ops.reshape(initial_samples, (-1, *self.input_shape))

    noise_seed, loop_seed = split_seed(seed, 2)
    starting_noise = keras.random.normal(
        shape=(batch_size * n_samples, *self.input_shape),
        seed=noise_seed,
    )

    flat_samples = self.reverse_conditional_diffusion(
        measurements=measurements,
        initial_noise=starting_noise,
        diffusion_steps=n_steps,
        initial_samples=initial_samples,
        initial_step=initial_step,
        seed=loop_seed,
        **kwargs,
    )  # (batch_size * n_samples, *input_shape)

    # Restore the separate sample axis: (batch_size, n_samples, *input_shape)
    return ops.reshape(flat_samples, (batch_size, n_samples, *self.input_shape))
[docs]
def log_likelihood(self, data, **kwargs):
    """Approximate log-likelihood of the data under the model.

    Placeholder: likelihood estimation is not available yet.

    Args:
        data: Data to compute log-likelihood for.
        **kwargs: Additional arguments.

    Raises:
        NotImplementedError: Always.
    """
    message = "Likelihood computation for diffusion models not implemented yet"
    raise NotImplementedError(message)
@property
def metrics(self):
    """Training metrics: noise-loss trackers followed by image-loss trackers."""
    trackers = list(self.noise_loss_tracker)
    trackers.extend(self.image_loss_tracker)
    return trackers
[docs]
def train_step(self, data):
    """Custom train step so we can call model.fit() on the diffusion model.

    Noises the batch at random diffusion times, trains the online network
    on the noise-prediction loss, and then updates the EMA network.

    Note:
        - Only implemented for the TensorFlow backend.

    Args:
        data: Batch of clean samples; the first axis is the batch axis.

    Returns:
        dict: Mapping from metric name to its current value.

    Raises:
        NotImplementedError: If TensorFlow is not available.
    """
    if tf is None:
        raise NotImplementedError(
            "DiffusionModel.train_step is only implemented for the TensorFlow backend."
        )

    data_shape = ops.shape(data)
    batch_size, *sample_shape = data_shape

    # Gaussian noise with the same shape as the data.
    noises = keras.random.normal(shape=data_shape)

    # One diffusion time per sample in [min_t, max_t], with singleton axes
    # so that it broadcasts against the data.
    time_shape = ops.stack([batch_size, *([1] * len(sample_shape))])
    diffusion_times = keras.random.uniform(
        shape=time_shape,
        minval=self.min_t,
        maxval=self.max_t,
    )
    noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)

    # Forward diffusion: mix data and noise.
    noisy_data = signal_rates * data + noise_rates * noises

    with tf.GradientTape() as tape:
        pred_noises, pred_images = self.denoise(
            noisy_data, noise_rates, signal_rates, training=True
        )
        noise_loss = self.loss(noises, pred_noises)
        # The image loss is only tracked as a metric, not optimized.
        image_loss = self.loss(data, pred_images)

    trainable_weights = self.network.trainable_weights
    gradients = tape.gradient(noise_loss, trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, trainable_weights))

    self.noise_loss_tracker.update_state(noise_loss)
    self.image_loss_tracker.update_state(image_loss)

    # Track the exponential moving average of the weights;
    # ema_network is the one used for inference / sampling.
    decay = self.ema_val
    for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
        ema_weight.assign(decay * ema_weight + (1 - decay) * weight)

    return {metric.name: metric.result() for metric in self.metrics}
[docs]
def test_step(self, data):
    """Custom evaluation step for the diffusion model.

    Mirrors :meth:`train_step` without a weight update: the batch is
    noised at random diffusion times, denoised with ``training=False``,
    and the noise and image losses are tracked.

    Args:
        data: Batch of clean samples; the first axis is the batch axis.

    Returns:
        dict: Mapping from metric name to its current value.
    """
    data_shape = ops.shape(data)
    batch_size, *sample_shape = data_shape

    noises = keras.random.normal(shape=data_shape)

    # One uniform random diffusion time per sample, broadcastable to data.
    time_shape = ops.stack([batch_size, *([1] * len(sample_shape))])
    diffusion_times = keras.random.uniform(
        shape=time_shape,
        minval=self.min_t,
        maxval=self.max_t,
    )
    noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)

    # Mix the images with noise according to the schedule.
    noisy_images = signal_rates * data + noise_rates * noises

    # Let the network separate the noisy images into their components.
    pred_noises, pred_images = self.denoise(
        noisy_images, noise_rates, signal_rates, training=False
    )

    self.noise_loss_tracker.update_state(self.loss(noises, pred_noises))
    self.image_loss_tracker.update_state(self.loss(data, pred_images))

    return {metric.name: metric.result() for metric in self.metrics}
[docs]
def diffusion_schedule(self, diffusion_times):
    r"""Cosine diffusion schedule.

    Implements the cosine schedule from `Nichol & Dhariwal (2021)
    <https://arxiv.org/abs/2102.09672>`_.

    The noisy image at time ``t`` is defined as:

    .. math::

        x_t = \text{signal\_rate}(t) \, x_0 + \text{noise\_rate}(t) \, \epsilon

    Diffusion times are mapped linearly onto an angle between
    ``arccos(max_signal_rate)`` and ``arccos(min_signal_rate)``; the
    signal rate is the cosine and the noise rate the sine of that angle.

    Args:
        diffusion_times: Tensor of diffusion times in ``[min_t, max_t]``.

    Returns:
        A ``(noise_rates, signal_rates)`` tuple of tensors with the
        same shape as ``diffusion_times``.
    """
    # Angles at the two ends of the schedule.
    angle_at_max_signal = ops.cast(ops.arccos(self.max_signal_rate), "float32")
    angle_at_min_signal = ops.cast(ops.arccos(self.min_signal_rate), "float32")
    angle_span = angle_at_min_signal - angle_at_max_signal

    # diffusion times -> angles
    angles = angle_at_max_signal + diffusion_times * angle_span

    # angles -> rates; sin^2(x) + cos^2(x) = 1, so the squared rates sum to one
    return ops.sin(angles), ops.cos(angles)
[docs]
def linear_diffusion_schedule(self, diffusion_times):
    """Create a linear diffusion schedule.

    For every index ``t`` a cumulative product ``alpha_t`` of
    ``1 - diffusion_times[:t]`` is computed; signal rates are
    ``sqrt(alpha_t)`` and noise rates are ``sqrt(1 - alpha_t)``.

    NOTE(review): the return order is ``(signal_rates, noise_rates)``,
    the opposite of :meth:`diffusion_schedule`, which returns
    ``(noise_rates, signal_rates)`` -- confirm callers expect this.

    Args:
        diffusion_times: Sequence of per-step values that are subtracted
            from one and multiplied cumulatively -- presumably the beta
            values of a linear schedule; TODO confirm.

    Returns:
        A ``(signal_rates, noise_rates)`` tuple.
    """
    def _compute_alpha_t(t):
        """Compute alpha_t for linear diffusion schedule"""
        # NOTE(review): the product is taken over the trailing axes
        # (``shape[1:]``) and ``t`` is used as a slice bound inside
        # ``vectorized_map`` -- verify this works on the targeted backends.
        return ops.prod(1 - diffusion_times[:t], axis=diffusion_times.shape[1:])
    alphas = ops.vectorized_map(_compute_alpha_t, ops.arange(len(diffusion_times)))
    signal_rates = ops.sqrt(alphas)
    noise_rates = ops.sqrt(1 - alphas)
    return signal_rates, noise_rates
[docs]
def denoise(
    self,
    noisy_images,
    noise_rates,
    signal_rates,
    training: bool,
    network=None,
):
    """Predict the noise component and derive the clean-image estimate.

    The model is called on ``[noisy_images, noise_rates**2]`` to predict
    the noise ``ε`` in ``x_t``. The estimate of ``x_0`` then follows by
    removing the scaled noise and dividing by the signal rate.

    Args:
        noisy_images: Noisy images ``x_t`` of shape
            ``(n_images, *input_shape)``.
        noise_rates: Noise rates at the current diffusion time,
            broadcastable to ``noisy_images``.
        signal_rates: Signal rates at the current diffusion time,
            broadcastable to ``noisy_images``.
        training: Whether to call the network in training mode.
        network: Explicit network to use. If ``None``, chosen based on
            ``training`` (see :meth:`call`).

    Returns:
        A ``(pred_noises, pred_images)`` tuple where ``pred_noises`` is
        the predicted noise ``ε`` and ``pred_images`` is the estimate of
        ``x_0``.
    """
    # The network is conditioned on the noise variance, not the rate.
    noise_variances = noise_rates**2
    pred_noises = self(
        [noisy_images, noise_variances],
        training=training,
        network=network,
    )

    # Invert x_t = signal_rate * x_0 + noise_rate * eps for x_0.
    signal_component = noisy_images - noise_rates * pred_noises
    pred_images = signal_component / signal_rates
    return pred_noises, pred_images
[docs]
def reverse_diffusion_step(
    self,
    shape,
    pred_images,
    pred_noises,
    signal_rates,
    next_signal_rates,
    next_noise_rates,
    seed=None,
    stochastic_sampling=False,
):
    """A single reverse diffusion step.

    Without stochastic sampling, the predicted image and noise are remixed
    at the next signal / noise rates. With stochastic sampling, part of
    the noise budget of the next step is replaced by freshly drawn
    Gaussian noise with standard deviation ``sigma``.

    Args:
        shape: Shape of the noise tensor to draw (only used when
            ``stochastic_sampling`` is true).
        pred_images: Predicted images.
        pred_noises: Predicted noises.
        signal_rates: Current signal rates.
        next_signal_rates: Next signal rates.
        next_noise_rates: Next noise rates. Only used for deterministic
            sampling; the stochastic branch recomputes its own rate.
        seed: Random seed generator.
        stochastic_sampling: Whether to use stochastic sampling (DDPM).

    Returns:
        next_noisy_images: Noisy images after the reverse diffusion step.
    """
    if stochastic_sampling:
        # Squared signal rates of the current and of the next (less noisy) step.
        alpha_current = signal_rates**2
        alpha_next = next_signal_rates**2

        variance_ratio = (1 - alpha_next) / (1 - alpha_current)
        sigma = ops.sqrt(variance_ratio) * ops.sqrt(1 - alpha_current / alpha_next)
        fresh_noise = keras.random.normal(shape=shape, seed=seed)

        # Remaining weight for the predicted noise once sigma^2 is spent
        # on the fresh noise.
        predicted_noise_rate = ops.sqrt(1 - alpha_next - sigma**2)
        return (
            next_signal_rates * pred_images
            + predicted_noise_rate * pred_noises
            + sigma * fresh_noise
        )

    return next_signal_rates * pred_images + next_noise_rates * pred_noises
[docs]
def reverse_diffusion(
    self,
    initial_noise,
    diffusion_steps: int,
    initial_samples=None,
    initial_step: int = 0,
    stochastic_sampling: bool = False,
    seed: keras.random.SeedGenerator | None = None,
    verbose: bool = True,
    track_progress_type: Literal[None, "x_0", "x_t"] = "x_0",
    disable_jit: bool = False,
    training: bool = False,
    network_type: Literal[None, "main", "ema"] = None,
):
    """Reverse diffusion process to generate images from noise.

    Args:
        initial_noise: Initial noise tensor of shape
            ``(n_images, *input_shape)``.
        diffusion_steps: Total number of diffusion steps.
        initial_samples: Optional initial samples to warm-start the
            diffusion process, which then starts from a *noised* version
            of these samples. When ``initial_step == 0`` they are noised
            at the maximum noise level (``max_t``); when
            ``initial_step > 0`` they are noised at the level
            corresponding to ``initial_step``.
        initial_step: Step at which to begin the reverse diffusion loop.
            ``0`` runs all ``diffusion_steps`` steps from maximum noise.
            Higher values skip early (high-noise) steps and require
            ``initial_samples`` to be provided. The number of effective
            steps is ``diffusion_steps - initial_step``.
        stochastic_sampling: Whether to use stochastic DDPM sampling
            instead of deterministic DDIM sampling.
        seed: Random seed generator.
        verbose: Whether to show a Keras progress bar.
        track_progress_type: Intermediate output to store at each step.
            ``"x_0"`` stores the denoised estimate; ``"x_t"`` stores the
            noisy intermediate image; ``None`` disables tracking.
        disable_jit: Whether to disable JIT compilation. JIT is also
            disabled whenever ``verbose`` or ``track_progress_type`` is
            truthy.
        training: Whether to call the network in training mode.
        network_type: Which network weights to use. ``"main"`` uses the
            online network, ``"ema"`` uses the exponential-moving-average
            network. If ``None``, the choice is determined by the
            ``training`` argument.

    Returns:
        Generated images of shape ``(n_images, *input_shape)``: the
        denoised estimate produced by the last step.
    """
    num_images, *input_shape = ops.shape(initial_noise)
    step_size, progbar = self.prepare_diffusion(diffusion_steps, initial_step, verbose)

    # One diffusion time per image, broadcastable against the images.
    time_shape = (num_images, *[1] * len(input_shape))
    base_diffusion_times = ops.ones(time_shape) * self.max_t

    starting_images = self.prepare_schedule(
        base_diffusion_times,
        initial_noise,
        initial_samples,
        initial_step,
        step_size,
    )

    # The network choice does not depend on the step; None lets
    # `denoise` / `call` pick a network based on `training`.
    if network_type == "ema":
        network = self.ema_network
    elif network_type == "main":
        network = self.network
    else:
        network = None

    def step_fn(step, loop_state):
        noisy_images, _, step_seed = loop_state

        # Rates at the current step and at the step we are moving to.
        current_times = base_diffusion_times - step * step_size
        noise_rates, signal_rates = self.diffusion_schedule(current_times)
        next_noise_rates, next_signal_rates = self.diffusion_schedule(
            current_times - step_size
        )

        # Separate the current noisy image into its components.
        pred_noises, pred_images = self.denoise(
            noisy_images,
            noise_rates,
            signal_rates,
            training=training,
            network=network,
        )

        # Remix the components at the next rates.
        step_seed, noise_seed = split_seed(step_seed, 2)
        next_noisy_images = self.reverse_diffusion_step(
            shape=(num_images, *input_shape),
            pred_images=pred_images,
            pred_noises=pred_noises,
            signal_rates=signal_rates,
            next_signal_rates=next_signal_rates,
            next_noise_rates=next_noise_rates,
            seed=noise_seed,
            stochastic_sampling=stochastic_sampling,
        )

        if progbar is not None:
            progbar.update(step + 1)
        self.store_progress(step, track_progress_type, next_noisy_images, pred_images)

        # next_noisy_images is the input of the following step.
        return next_noisy_images, pred_images, step_seed

    initial_state = (starting_images, ops.zeros_like(initial_noise), seed)
    _, final_pred_images, _ = fori_loop(
        initial_step,
        diffusion_steps,
        step_fn,
        initial_state,
        # can't jit this with progbar or tracking intermediate values
        disable_jit=verbose or track_progress_type or disable_jit,
    )
    return final_pred_images
[docs]
def reverse_conditional_diffusion(
    self,
    measurements,
    initial_noise,
    diffusion_steps: int,
    initial_samples=None,
    initial_step: int = 0,
    stochastic_sampling: bool = False,
    seed=None,
    verbose: bool = False,
    track_progress_type: Literal[None, "x_0", "x_t"] = "x_0",
    disable_jit=False,
    **kwargs,
):
    """Reverse diffusion process conditioned on some measurement.

    Effectively performs diffusion posterior sampling ``p(x_0 | y)``
    by interleaving reverse diffusion steps with gradient-based guidance
    (e.g. DPS or DDS). At every step ``self.guidance_fn`` supplies the
    denoiser outputs together with a correction that is subtracted from
    both the next noisy image and the clean-image estimate.

    Args:
        measurements: Conditioning observations of shape
            ``(n_images, *measurement_shape)``.
        initial_noise: Initial noise tensor of shape
            ``(n_images, *input_shape)``.
        diffusion_steps: Total number of diffusion steps.
        initial_samples: Optional initial samples to warm-start the
            diffusion process, which then starts from a *noised* version
            of these samples. When ``initial_step == 0`` they are noised
            at the maximum noise level (``max_t``); when
            ``initial_step > 0`` they are noised at the level
            corresponding to ``initial_step``.
        initial_step: Step at which to begin the reverse diffusion loop.
            ``0`` runs all ``diffusion_steps`` steps from maximum noise.
            Higher values skip early (high-noise) steps and require
            ``initial_samples`` to be provided. The number of effective
            steps is ``diffusion_steps - initial_step``.
        stochastic_sampling: Whether to use stochastic DDPM sampling
            instead of deterministic DDIM sampling.
        seed: Random seed generator.
        verbose: Whether to show a Keras progress bar with the guidance
            error at each step.
        track_progress_type: Intermediate output to store at each step.
            ``"x_0"`` stores the denoised estimate; ``"x_t"`` stores the
            noisy intermediate image; ``None`` disables tracking.
        disable_jit: Whether to disable JIT compilation. JIT is also
            disabled whenever ``verbose`` or ``track_progress_type`` is
            truthy.
        **kwargs: Additional keyword arguments forwarded to the guidance
            function and operator (e.g. ``omega``, ``mask``).

    Returns:
        Generated images of shape ``(n_images, *input_shape)``: the
        guided clean-image estimate produced by the last step.
    """
    num_images, *input_shape = ops.shape(initial_noise)
    step_size, progbar = self.prepare_diffusion(diffusion_steps, initial_step, verbose)

    # One diffusion time per image, broadcastable against the images.
    time_shape = (num_images, *[1] * len(input_shape))
    base_diffusion_times = ops.ones(time_shape) * self.max_t

    starting_images = self.prepare_schedule(
        base_diffusion_times,
        initial_noise,
        initial_samples,
        initial_step,
        step_size,
    )

    def step_fn(step, loop_state):
        noisy_images, _, step_seed = loop_state

        # Rates at the current step and at the step we are moving to.
        current_times = base_diffusion_times - step * step_size
        noise_rates, signal_rates = self.diffusion_schedule(current_times)
        next_noise_rates, next_signal_rates = self.diffusion_schedule(
            current_times - step_size
        )

        # Guidance returns the measurement-consistency correction along
        # with the denoiser outputs it computed internally.
        gradients, (error, (pred_noises, pred_images)) = self.guidance_fn(
            noisy_images,
            measurements=measurements,
            noise_rates=noise_rates,
            signal_rates=signal_rates,
            **kwargs,
        )

        # Remix the (unguided) components at the next rates.
        step_seed, noise_seed = split_seed(step_seed, 2)
        unguided_next_images = self.reverse_diffusion_step(
            shape=(num_images, *input_shape),
            pred_images=pred_images,
            pred_noises=pred_noises,
            signal_rates=signal_rates,
            next_signal_rates=next_signal_rates,
            next_noise_rates=next_noise_rates,
            seed=noise_seed,
            stochastic_sampling=stochastic_sampling,
        )

        # Apply the guidance correction to both outputs.
        next_noisy_images = unguided_next_images - gradients
        pred_images = pred_images - gradients

        if verbose:
            progbar.update(step + 1, [("error", error)])
        self.store_progress(step, track_progress_type, next_noisy_images, pred_images)

        # next_noisy_images is the input of the following step.
        return next_noisy_images, pred_images, step_seed

    initial_state = (starting_images, ops.zeros_like(initial_noise), seed)
    _, final_pred_images, _ = fori_loop(
        initial_step,
        diffusion_steps,
        step_fn,
        initial_state,
        # can't jit this with progbar or tracking intermediate values
        disable_jit=verbose or track_progress_type or disable_jit,
    )
    return final_pred_images
[docs]
def prepare_diffusion(
    self,
    diffusion_steps: int,
    initial_step: int,
    verbose: bool,
    disable_jit: bool = False,
):
    """Prepare the diffusion process.

    Validates ``initial_step``, computes the step size, optionally
    creates a Keras progress bar, and resets progress tracking.

    Args:
        diffusion_steps: Total number of diffusion steps.
        initial_step: Step index at which reverse diffusion begins.
            Must satisfy ``0 <= initial_step < diffusion_steps``.
        verbose: Whether to create a Keras :class:`~keras.utils.Progbar`.
        disable_jit: When ``True``, the ``initial_step`` range assertions
            are skipped and ``initial_step`` is not inspected at all.

    Returns:
        A ``(step_size, progbar)`` tuple where ``step_size`` is
        ``max_t / diffusion_steps`` and ``progbar`` is a
        :class:`~keras.utils.Progbar` instance or ``None``.

    Raises:
        AssertionError: If ``initial_step`` is out of range (only checked
            when ``disable_jit`` is ``False``).
    """
    if not disable_jit:
        assert initial_step >= 0, f"initial_step must be non-negative, got {initial_step}"
        assert initial_step < diffusion_steps, (
            f"initial_step must be less than diffusion_steps, got {initial_step}"
        )

    step_size = self.max_t / diffusion_steps

    progbar = keras.utils.Progbar(diffusion_steps, verbose=verbose) if verbose else None

    # Size the tracking interval on the steps that will actually run.
    # In the branch where the assertions were skipped, `initial_step` is
    # left untouched as well, matching the previous behaviour there.
    if disable_jit:
        self.start_track_progress(diffusion_steps)
    else:
        self.start_track_progress(diffusion_steps, initial_step)

    return step_size, progbar
[docs]
def prepare_schedule(
    self,
    base_diffusion_times,
    initial_noise,
    initial_samples,
    initial_step: int,
    step_size: float,
):
    """Prepare the starting noisy images for the reverse diffusion loop.

    Constructs the initial ``x_t`` tensor that is fed into the first
    diffusion step. Three cases are handled:

    - ``initial_samples`` provided and ``initial_step > 0``:
      samples are mixed with noise at a reduced noise level derived from
      ``initial_step``, skipping the highest-noise diffusion steps.
    - ``initial_samples`` provided and ``initial_step == 0``:
      samples are mixed with noise at the maximum noise level
      (``max_t``), running the full diffusion process from a noised
      version of the samples.
    - ``initial_samples is None`` and ``initial_step == 0``:
      the starting point is pure noise (``initial_noise``).

    Args:
        base_diffusion_times: Tensor of shape
            ``(n_images, *[1]*n_dims)`` filled with ``max_t``.
        initial_noise: Pure noise tensor of shape
            ``(n_images, *input_shape)``.
        initial_samples: Optional samples of shape
            ``(n_images, *input_shape)``. When given, the diffusion
            process always starts from a *noised* version of them.
        initial_step: Step index at which reverse diffusion begins.
        step_size: Uniform time increment per diffusion step.

    Returns:
        Noisy images tensor of shape ``(n_images, *input_shape)`` to
        use as the starting point ``x_t`` for the diffusion loop.

    Raises:
        ValueError: If ``initial_samples`` is ``None`` while
            ``initial_step`` is not ``0``.
    """
    if initial_samples is None:
        if initial_step == 0:
            # important: at the first sampling step the "noisy image" is
            # pure noise, but its signal rate is assumed to be nonzero
            # (min_signal_rate)
            return initial_noise
        raise ValueError(
            "Why are you trying to do this? Initial samples should be provided "
            "if initial_step is greater than 0 (i.e. you want to start with "
            "a partially noised image)"
        )

    # Warm start: noise the provided samples to the starting level.
    if initial_step > 0:
        # NOTE(review): the noise level used here is that of step
        # `initial_step - 1`, one step noisier than the level the loop
        # assumes at `initial_step` -- confirm this offset is intended.
        starting_diffusion_times = base_diffusion_times - ((initial_step - 1) * step_size)
    else:
        starting_diffusion_times = base_diffusion_times

    noise_rates, signal_rates = self.diffusion_schedule(starting_diffusion_times)
    return signal_rates * initial_samples + noise_rates * initial_noise
[docs]
def start_track_progress(self, diffusion_steps: int, initial_step: int = 0):
    """Initialize progress tracking for the diffusion process.

    Clears :attr:`track_progress` and chooses
    :attr:`track_progress_interval` from the number of remaining steps:
    ``remaining // 50`` when more than 50 steps remain, otherwise ``1``.
    This keeps the number of stored frames bounded for large step counts.

    Args:
        diffusion_steps: Total number of diffusion steps.
        initial_step: Step index at which reverse diffusion begins.
    """
    self.track_progress = []
    remaining_steps = max(1, diffusion_steps - int(initial_step))
    # For 50 or fewer remaining steps the floor division yields 0 or 1,
    # so the interval is clamped to 1 (store every step).
    self.track_progress_interval = max(1, remaining_steps // 50)
[docs]
def store_progress(
    self,
    step: int,
    track_progress_type: Literal[None, "x_0", "x_t"],
    next_noisy_images,
    pred_images,
):
    """Store an intermediate diffusion frame in :attr:`track_progress`.

    A frame is stored only when ``step`` is a multiple of
    :attr:`track_progress_interval`. Does nothing when
    ``track_progress_type`` is ``None``.

    Args:
        step: Current diffusion step index.
        track_progress_type: Which tensor to store. ``"x_0"`` stores
            the predicted clean image; ``"x_t"`` stores the noisy
            intermediate image.
        next_noisy_images: Noisy images ``x_t`` after the current step.
        pred_images: Predicted clean images ``x_0`` at the current step.

    Raises:
        ValueError: If ``track_progress_type`` is neither ``"x_0"`` nor
            ``"x_t"`` on a step that would be stored.
    """
    # Tracking disabled, or not a step that falls on the interval.
    if not track_progress_type or step % self.track_progress_interval != 0:
        return

    if track_progress_type == "x_0":
        frame = pred_images
    elif track_progress_type == "x_t":
        frame = next_noisy_images
    else:
        raise ValueError("Invalid track_progress_type")

    self.track_progress.append(ops.convert_to_numpy(frame))
# Attach the preset table to DiffusionModel at import time -- presumably
# this is what makes names such as "diffusion-echonet-dynamic" resolvable
# through `DiffusionModel.from_preset(...)`; confirm in `preset_utils`.
register_presets(diffusion_model_presets, DiffusionModel)
[docs]
class DiffusionGuidance(abc.ABC, Object):
    """Base class for diffusion guidance methods.

    Subclasses implement :meth:`setup`, which is invoked at the end of
    construction, and :meth:`__call__`, which performs the guidance.
    """

    def __init__(
        self,
        diffusion_model: DiffusionModel,
        operator: Operator,
        disable_jit: bool = False,
    ):
        """Store the collaborators and run the subclass :meth:`setup`.

        Args:
            diffusion_model: The diffusion model to use for guidance.
            operator: The forward operator :math:`A` that maps clean images
                to the measurement space.
            disable_jit: Whether to disable JIT compilation of the guidance
                function.
        """
        super().__init__()
        self.diffusion_model, self.operator, self.disable_jit = (
            diffusion_model,
            operator,
            disable_jit,
        )
        # Runs last so that `setup` can rely on the attributes above.
        self.setup()

    @abc.abstractmethod
    def setup(self):
        """Prepare the guidance function. Must be implemented by subclasses."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, *args, **kwargs):
        """Apply the guidance function. Must be implemented by subclasses."""
        raise NotImplementedError
[docs]
@diffusion_guidance_registry(name="dps")
class DPS(DiffusionGuidance):
    """Diffusion Posterior Sampling guidance."""

    def setup(self):
        """Build the gradient function of :meth:`compute_error`."""
        self.autograd = AutoGrad()
        self.autograd.set_function(self.compute_error)
        # has_aux=True: compute_error also returns the denoiser outputs
        # next to the scalar error.
        gradient_options = {"has_aux": True, "disable_jit": self.disable_jit}
        self.gradient_fn = self.autograd.get_gradient_and_value_jit_fn(**gradient_options)

    def compute_error(
        self,
        noisy_images,
        measurements,
        noise_rates,
        signal_rates,
        omega: float,
        **kwargs,
    ):
        r"""Compute the DPS measurement error for gradient computation.

        The noisy images are denoised with the diffusion model, the
        operator is applied to the clean-image estimate, and the ``L2``
        of the residual with the measurements is scaled by ``omega``.

        Args:
            noisy_images: Noisy images ``x_t`` of shape
                ``(n_images, *input_shape)``.
            measurements: Target observations ``y``.
            noise_rates: Noise rates at the current diffusion time.
            signal_rates: Signal rates at the current diffusion time.
            omega: Scalar step-size weight for the measurement error.
            **kwargs: Additional keyword arguments forwarded to the
                operator's ``forward`` method (e.g. ``mask``).

        Returns:
            A ``(measurement_error, (pred_noises, pred_images))`` tuple
            where ``measurement_error`` is the weighted loss and
            ``pred_noises`` / ``pred_images`` are the denoiser outputs.
        """
        denoiser_outputs = self.diffusion_model.denoise(
            noisy_images,
            noise_rates,
            signal_rates,
            training=False,
        )
        _, pred_images = denoiser_outputs

        # See the original DPS implementation
        # https://github.com/DPS2022/diffusion-posterior-sampling/blob/effbde7325b22ce8dc3e2c06c160c021e743a12d/guided_diffusion/condition_methods.py#L31 # noqa: E501
        # As well as interesting discussion on the DPS loss:
        # https://github.com/DPS2022/diffusion-posterior-sampling/issues/20
        residual = measurements - self.operator.forward(pred_images, **kwargs)
        return omega * L2(residual), denoiser_outputs

    def __call__(self, noisy_images, **kwargs):
        """Compute DPS gradients and denoiser outputs.

        Delegates to the gradient function created in :meth:`setup`.

        Args:
            noisy_images: Noisy images ``x_t`` of shape
                ``(n_images, *input_shape)``.
            **kwargs: Keyword arguments forwarded to :meth:`compute_error`
                (``measurements``, ``noise_rates``, ``signal_rates``,
                ``omega``, and any operator kwargs such as ``mask``).

        Returns:
            A ``(gradients, (measurement_error, (pred_noises, pred_images)))``
            tuple. ``gradients`` is the gradient of the measurement error
            w.r.t. ``noisy_images`` and can be subtracted directly from the
            reverse-diffusion update.
        """
        gradient_and_value = self.gradient_fn(noisy_images, **kwargs)
        return gradient_and_value
[docs]
@diffusion_guidance_registry(name="dds")
class DDS(DiffusionGuidance):
    """Guidance via Decomposed Diffusion Sampling (DDS).

    The denoised estimate is refined with a few conjugate gradient
    iterations on the normal equations of the measurement operator.

    Reference paper: https://arxiv.org/pdf/2303.05754
    """

    def setup(self):
        """Wrap :meth:`call` with ``jit`` unless ``self.disable_jit`` is set."""
        if self.disable_jit:
            return
        self.call = jit(self.call)

    def Acg(self, x, **op_kwargs):
        """Apply the normal-equations operator ``A^T A`` to ``x``.

        Conjugate gradient needs a symmetric, positive definite system, so
        ``A x = y`` is replaced by the normal equations ``A^T A x = A^T y``.

        Args:
            x: Input passed to the operator's ``forward`` method.
            **op_kwargs: Keyword arguments given to both ``forward`` and
                ``transpose`` of the operator.

        Returns:
            ``self.operator.transpose(self.operator.forward(x))``.
        """
        projected = self.operator.forward(x, **op_kwargs)
        return self.operator.transpose(projected, **op_kwargs)

    def conjugate_gradient_inner_loop(self, i, loop_state, eps=1e-5):
        """Perform one conjugate gradient iteration.

        The estimate is moved to the minimum along the current search
        direction, after which the next search direction is constructed.

        Reference code from: https://github.com/svi-diffusion/

        Args:
            i: Loop index supplied by the loop driver (not used).
            loop_state: Tuple ``(p, rs_old, r, x, eps, op_kwargs)`` holding the
                search direction, squared residual norm, residual, current
                estimate, tolerance and operator keyword arguments.
            eps: Not used; the tolerance is always read from ``loop_state``.

        Returns:
            The next loop state, laid out like ``loop_state``. The incoming
            state is returned unchanged when the square root of ``rs_old`` is
            already below the tolerance.
        """
        search_dir, res_sq, residual, estimate, tol, op_kwargs = loop_state

        # Exact line search along the current direction.
        a_search_dir = self.Acg(search_dir, **op_kwargs)
        step_size = res_sq / ops.sum(search_dir * a_search_dir)
        estimate_next = estimate + step_size * search_dir
        # The residual can be updated without another operator application.
        residual_next = residual - step_size * a_search_dir

        # Gram-Schmidt coefficient: makes the next direction A-orthogonal to
        # the current one.
        res_sq_next = ops.sum(residual_next * residual_next)
        beta = res_sq_next / res_sq
        search_dir_next = residual_next + beta * search_dir

        def keep_state():
            return (search_dir, res_sq, residual, estimate, tol, op_kwargs)

        def advance_state():
            return (search_dir_next, res_sq_next, residual_next, estimate_next, tol, op_kwargs)

        # Jittable stand-in for ``break``: once converged, the state is simply
        # carried through the remaining iterations.
        converged = ops.abs(ops.sqrt(res_sq)) < tol
        return ops.cond(converged, keep_state, advance_state)

    def call(
        self,
        noisy_images,
        measurements,
        noise_rates,
        signal_rates,
        n_inner: int,
        eps: float,
        verbose: bool,
        **op_kwargs,
    ):
        r"""Perform a single DDS guidance step with conjugate gradient.

        ``x_t`` is denoised to an initial ``x_0`` estimate, which is then
        refined with ``n_inner`` conjugate gradient iterations on the normal
        equations :math:`A^\top A\, x = A^\top y`.

        Args:
            noisy_images: Noisy images ``x_t`` of shape
                ``(n_images, *input_shape)``.
            measurements: Target observations ``y``.
            noise_rates: Noise rates at the current diffusion time.
            signal_rates: Signal rates at the current diffusion time.
            n_inner: Number of conjugate gradient iterations.
            eps: Convergence tolerance; iterations stop updating once the
                residual norm drops below it.
            verbose: If ``True`` the measurement error
                ``‖y - A(x̂_0)‖`` is computed and returned, otherwise ``0.0``
                is returned in its place.
            **op_kwargs: Extra keyword arguments passed on to the operator
                (e.g. ``mask``).

        Returns:
            A ``(gradients, (measurement_error, (pred_noises, pred_images)))``
            tuple. ``gradients`` is all zeros since the guidance has already
            been applied to ``pred_images`` by the conjugate gradient loop.
        """
        pred_noises, pred_images = self.diffusion_model.denoise(
            noisy_images, noise_rates, signal_rates, training=False
        )

        # Right-hand side of the normal equations: A^T y.
        rhs = self.operator.transpose(measurements, **op_kwargs)
        residual = rhs - self.Acg(pred_images, **op_kwargs)
        search_dir = ops.copy(residual)  # first search direction is the residual
        res_sq = ops.sum(residual * residual)

        initial_state = (search_dir, res_sq, residual, pred_images, eps, op_kwargs)
        _p, _rs, _r, refined_images, _eps, _kwargs = fori_loop(
            0, n_inner, self.conjugate_gradient_inner_loop, initial_state
        )

        def measurement_error():
            return L2(measurements - self.operator.forward(refined_images, **op_kwargs))

        # Only needed for debugging / logging.
        error = ops.cond(verbose, measurement_error, lambda: 0.0)

        # Guidance was already applied inside conjugate_gradient_inner_loop,
        # so the gradients handed back to the caller are zero.
        gradients = ops.zeros_like(refined_images)
        return gradients, (error, (pred_noises, refined_images))

    def __call__(
        self,
        noisy_images,
        measurements,
        noise_rates,
        signal_rates,
        n_inner: int = 5,
        eps: float = 1e-5,
        verbose: bool = False,
        **op_kwargs,
    ):
        """Public entry point for one DDS guidance step.

        Forwards every argument positionally to :meth:`call`, which is
        JIT-wrapped in :meth:`setup` unless ``disable_jit`` is set.

        Args:
            noisy_images: Noisy images ``x_t`` of shape
                ``(n_images, *input_shape)``.
            measurements: Target observations ``y``.
            noise_rates: Noise rates at the current diffusion time.
            signal_rates: Signal rates at the current diffusion time.
            n_inner: Number of conjugate gradient iterations. Default: ``5``.
            eps: Convergence tolerance of the CG solver. Default: ``1e-5``.
            verbose: If ``True`` the measurement error is computed for
                logging. Default: ``False``.
            **op_kwargs: Extra keyword arguments passed on to the operator
                (e.g. ``mask``).

        Returns:
            A ``(gradients, (measurement_error, (pred_noises, pred_images)))``
            tuple (see :meth:`call`).
        """
        positional = (
            noisy_images,
            measurements,
            noise_rates,
            signal_rates,
            n_inner,
            eps,
            verbose,
        )
        return self.call(*positional, **op_kwargs)
[docs]
@diffusion_guidance_registry(name="nuclear-dps")
class NuclearDiffusion(DPS):
    r"""Nuclear Diffusion posterior sampling guidance.

    A hybrid framework that combines diffusion posterior sampling (DPS) with low-rank
    temporal modeling for video restoration. This method replaces the sparsity assumption
    in Robust Principal Component Analysis (RPCA) with a learned diffusion prior while
    maintaining a nuclear norm penalty on the background component to encourage low-rank
    temporal structure.

    .. seealso::
        - :func:`~zea.func.dehaze_nuclear_diffusion`: The dehazing application of this method
        - :doc:`../notebooks/models/nuclear_dehazing_example`: Example notebook demonstrating
          the method on cardiac ultrasound dehazing
        - :class:`DPS`: Base diffusion posterior sampling guidance

    **Mathematical Formulation:**

    Given observations :math:`\mathbf{Y} \in \mathbb{R}^{n \times p}` (video frames),
    Nuclear Diffusion jointly samples the signal :math:`\mathbf{X}` and low-rank background
    :math:`\mathbf{L}` from the posterior:

    .. math::
        \mathbf{X}, \mathbf{L} \sim p_\theta(\mathbf{X}, \mathbf{L} \mid \mathbf{Y})

    The joint distribution factorizes as:

    .. math::
        p(\mathbf{Y}, \mathbf{L}, \mathbf{X}) = p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) \, p(\mathbf{L}) \, p_\theta(\mathbf{X})

    where:

    - :math:`p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) = \mathcal{N}(\mathbf{Y}; \mathbf{L}+\mathbf{X}, \mu^{-1} \mathbf{I})`
      is the likelihood (measurement model)
    - :math:`p(\mathbf{L}) \propto \exp(-\gamma \|\mathbf{L}\|_*)` enforces low-rank structure
      via the nuclear norm :math:`\|\mathbf{L}\|_* = \sum_i \sigma_i(\mathbf{L})`
    - :math:`p_\theta(\mathbf{X})` is a learned diffusion prior capturing complex signal structure

    The diffusion prior operates on individual frames :math:`\mathbf{x}^t \in \mathbb{R}^n`,
    while temporal dependencies are enforced through the nuclear norm on :math:`\mathbf{L}`.

    This guidance method alternates between reverse diffusion and measurement-guided updates,
    computing gradients from both the measurement error and the nuclear norm penalty
    (see :meth:`compute_error`).

    Args:
        diffusion_model: The diffusion model for the signal component.
        operator: Forward operator defining the measurement model.
        disable_jit: Whether to disable JIT compilation.

    .. admonition:: Reference

        T. Stevens, O. Nolan, J. L. Robert, and R. J. G. van Sloun,
        "Nuclear Diffusion Models for Low-Rank Background Suppression in Videos,"
        *IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)*, 2026.
        https://arxiv.org/abs/2509.20886
    """  # noqa: E501

    @staticmethod
    def nuclear_norm_penalty(background_images):
        r"""Compute nuclear norm penalty for low-rank enforcement.

        The nuclear norm (sum of singular values) encourages low-rank structure in
        the background component across time. For a matrix :math:`\mathbf{L}`, it is
        defined as:

        .. math::
            \|\mathbf{L}\|_* = \sum_{i=1}^{r} \sigma_i(\mathbf{L})

        where :math:`\sigma_i` are the singular values and :math:`r` is the rank.

        Args:
            background_images: Background images of shape
                ``(batch, frames, height, width, channels)``.
                Each sequence is reshaped to a matrix of shape ``(frames, height x width x channels)``
                before computing the nuclear norm.

        Returns:
            Nuclear norm penalty summed across the batch and normalized by number of frames.

        Note:
            The input is reshaped from ``(batch, frames, H, W, C)`` to ``(batch, frames, HxWxC)``
            before computing the singular values.
        """  # noqa: E501
        n_batch, n_frames, height, width, channels = ops.shape(background_images)
        background_images_flattened = ops.reshape(
            background_images, (n_batch, n_frames, height * width * channels)
        )
        # One nuclear norm per batch element, taken over the (frames, pixels) matrix.
        background_nuclear_penalty = ops.norm(background_images_flattened, axis=(1, 2), ord="nuc")
        # normalize nuclear penalty
        background_nuclear_penalty /= n_frames
        # sum across batches
        return ops.sum(background_nuclear_penalty)

    @staticmethod
    def weighted_nuclear_norm_penalty(background_images, weight_factor: float = 2.0):
        r"""Compute weighted nuclear norm penalty with enhanced rank control.

        This implements a WNNM-style (Weighted Nuclear Norm Minimization) penalty that
        penalizes smaller singular values more heavily than larger ones, suppressing the
        spectrum tail to enforce low-rank structure. The weighted penalty is:

        .. math::
            \|\mathbf{L}\|_{w,*} = \sum_{i=1}^{r} w_i \cdot \sigma_i(\mathbf{L})

        where :math:`w_i = 1 + \alpha \cdot \frac{i - 1}{r}` increases linearly with the
        index :math:`i` (the largest singular value gets weight 1), :math:`r` is the number
        of singular values, and :math:`\alpha` is the ``weight_factor``. Since ``ops.svd``
        returns singular values in descending order (:math:`\sigma_1 \geq \sigma_2 \geq \cdots`),
        higher indices correspond to smaller singular values, which receive larger weights.

        Args:
            background_images: Background images of shape ``(batch, frames, height, width, channels)``.
            weight_factor: Scaling factor :math:`\alpha` controlling how much more to penalize
                smaller singular values (the spectrum tail). Default is 2.0.

        Returns:
            Weighted nuclear norm penalty summed across the batch and normalized by number of frames.

        Note:
            This is a drop-in replacement for :meth:`nuclear_norm_penalty` that provides
            better rank control by more aggressively penalizing the tail of the singular value
            spectrum (smaller singular values) rather than the leading ones.
        """  # noqa: E501
        n_batch, n_frames, height, width, channels = ops.shape(background_images)
        background_images_flattened = ops.reshape(
            background_images, (n_batch, n_frames, height * width * channels)
        )

        def weighted_svd_penalty(matrix):
            """Compute weighted SVD penalty for a matrix"""
            _, s_vals, _ = ops.svd(matrix, full_matrices=False)
            n_sv = ops.shape(s_vals)[0]
            # Zero-based ramp: weights run from 1 up to 1 + weight_factor * (n_sv - 1) / n_sv.
            weights = 1.0 + weight_factor * ops.arange(n_sv, dtype="float32") / ops.cast(
                n_sv, "float32"
            )
            return ops.sum(weights * s_vals)

        # Apply weighted penalty to each batch element
        weighted_penalties = ops.vectorized_map(weighted_svd_penalty, background_images_flattened)
        # normalize by number of frames
        weighted_penalties /= n_frames
        # sum across batches (same as nuclear_norm_penalty)
        return ops.sum(weighted_penalties)

    def compute_error(
        self,
        combined_images,
        measurements,
        noise_rates,
        signal_rates,
        omega: float = 1.0,
        gamma: float = 1.0,
        rank_weight_factor: float | None = None,
        step: int | None = None,
        total_steps: int | None = None,
        initial_step: int = 100,
        max_alpha: float = 0.5,
        **kwargs,
    ):
        r"""Compute measurement error for joint diffusion posterior sampling.

        The foreground half of ``combined_images`` is denoised frame by frame, blended
        with the (noisy) background half using the factor :math:`\alpha`, and compared to
        the measurements with an L2 loss. A nuclear norm penalty on the background is added.

        Args:
            combined_images: Concatenated noisy images, containing both foreground and background
                components, shape ``(batch, frames, H, W, 2C)``. In the context of cardiac
                ultrasound dehazing, the first C channels correspond to the tissue signal
                (foreground), and the next C channels correspond to the haze (background) component.
            measurements: Target measurements :math:`\mathbf{Y}`, shape ``(batch, frames, H, W, C)``.
            noise_rates: Current noise rates from the diffusion schedule, shape ``(batch, frames, 1, 1, 1)``.
            signal_rates: Current signal rates from the diffusion schedule, shape ``(batch, frames, 1, 1, 1)``.
            omega: Weight :math:`\omega` for the measurement error term (L2 reconstruction loss).
            gamma: Weight :math:`\gamma` for the nuclear norm penalty term.
            rank_weight_factor: Optional weight factor for :meth:`weighted_nuclear_norm_penalty`.
                If ``None``, uses standard :meth:`nuclear_norm_penalty`.
            step: Current diffusion step for progressive blending. Used to compute :math:`\alpha(t)`.
                Required, despite the ``None`` default in the signature.
            total_steps: Total number of diffusion steps. Required, despite the ``None`` default
                in the signature; must differ from ``initial_step`` because the difference is
                used as a divisor.
            initial_step: Step at which to start progressive blending.
            max_alpha: Maximum value for :math:`\alpha` at the final step. The alpha parameter mixes
                foreground and background predictions, but only after the initial_step to allow the
                diffusion model to first focus on generating the foreground signal before blending
                in the background component.
            **kwargs: Additional arguments (unused).

        Returns:
            A tuple containing:
                - **measurement_error** (float): Combined loss :math:`\mathcal{L}`.
                - **aux** (tuple): Auxiliary outputs:
                  ``(pred_noises_foreground, pred_images_foreground, noisy_background_images, l2_error, nuclear_penalty)``

        Raises:
            TypeError: If ``step`` or ``total_steps`` is ``None``.

        .. note::
            The progressive blending factor :math:`\alpha(t)` linearly increases from 0
            at ``initial_step`` and plateaus at ``max_alpha`` once normalized progress
            reaches ``max_alpha``, allowing the background component to gradually influence
            the reconstruction and then saturate for the remainder of sampling.
        """  # noqa: E501
        # Fail fast with a readable message: both values feed the arithmetic for alpha
        # below, and ``None`` would otherwise only surface as an opaque operand error
        # after the denoiser has already been evaluated.
        if step is None or total_steps is None:
            raise TypeError(
                "NuclearDiffusion.compute_error() requires both `step` and `total_steps` "
                "to compute the blending factor alpha, but got "
                f"step={step!r} and total_steps={total_steps!r}."
            )

        channels = ops.shape(combined_images)[-1] // 2
        noisy_foreground_images = combined_images[..., :channels]
        noisy_background_images = combined_images[..., channels:]

        # Transpose for ops.map
        noisy_tissue_seq = ops.swapaxes(noisy_foreground_images, 0, 1)  # [S, B, H, W, C]

        # Signal and noise rates are the same throughout the sequence, so we can just
        # grab those of the first frame and reuse them for every frame.
        noise_rates_s = noise_rates[:, 0, ...]
        signal_rates_s = signal_rates[:, 0, ...]

        def denoise_step(x_s):
            pred_noises, pred_images = self.diffusion_model.denoise(
                x_s, noise_rates_s, signal_rates_s, training=False
            )
            return {"pred_noises": pred_noises, "pred_images": pred_images}

        denoised = ops.map(denoise_step, noisy_tissue_seq)
        pred_noises_foreground = ops.swapaxes(denoised["pred_noises"], 0, 1)  # [B, S, H, W, C]
        pred_images_foreground = ops.swapaxes(denoised["pred_images"], 0, 1)  # [B, S, H, W, C]

        alpha = ops.clip(
            (step - initial_step) / (total_steps - initial_step), 0.0, max_alpha
        )  # linear after initial_step

        pred_measurements = (1 - alpha) * pred_images_foreground + (alpha) * noisy_background_images
        l2_error = L2(measurements - pred_measurements)

        # Choose penalty function for nuclear norm
        if rank_weight_factor is not None:
            background_nuclear_penalty = self.weighted_nuclear_norm_penalty(
                noisy_background_images, rank_weight_factor
            )
        else:
            background_nuclear_penalty = self.nuclear_norm_penalty(noisy_background_images)

        # NOTE: we sum across batches for the nuclear norm here.
        # the gradient of sums = sum of gradients
        nuclear_penalty = ops.sum(background_nuclear_penalty)

        # Combine all penalty terms
        measurement_error = omega * l2_error + gamma * nuclear_penalty

        return measurement_error, (
            pred_noises_foreground,
            pred_images_foreground,
            noisy_background_images,
            l2_error,
            nuclear_penalty,
        )

    def __call__(
        self,
        noisy_images1,
        noisy_images2,
        measurements,
        noise_rates,
        signal_rates,
        omega: float = 1.0,
        gamma: float = 1.0,
        **kwargs,
    ):
        r"""Compute guidance gradients for posterior sampling.

        This method concatenates the noisy foreground and background images, computes the
        combined loss via :meth:`compute_error`, and returns separate gradients
        for each component.

        Args:
            noisy_images1: Noisy foreground images :math:`\mathbf{x}_t` from the diffusion model,
                shape ``(batch, frames, H, W, C)``.
            noisy_images2: Noisy background images :math:`\mathbf{L}_t`,
                shape ``(batch, frames, H, W, C)``.
            measurements: Target measurements :math:`\mathbf{Y}`, shape ``(batch, frames, H, W, C)``.
            noise_rates: Current noise rates from diffusion schedule.
            signal_rates: Current signal rates from diffusion schedule.
            omega: Weight for the measurement error term. Default is 1.0.
            gamma: Weight for the nuclear norm penalty term. Default is 1.0.
            **kwargs: Additional arguments passed to :meth:`compute_error` (e.g.,
                ``rank_weight_factor``, ``step``, ``total_steps``, ``initial_step``,
                ``max_alpha``).

        Returns:
            A tuple containing:
                - **gradients** (tuple): ``(grad_foreground, grad_background)`` - gradients for foreground and background.
                - **loss_info** (tuple): ``(loss, aux)`` where:
                    - **loss** (float): Combined loss value.
                    - **aux** (tuple): Auxiliary outputs from :meth:`compute_error`.
        """  # noqa: E501
        # Differentiate with respect to a single tensor holding both components.
        combined_input = ops.concatenate([noisy_images1, noisy_images2], axis=-1)

        gradients, (loss, aux) = self.gradient_fn(
            combined_input,
            measurements=measurements,
            noise_rates=noise_rates,
            signal_rates=signal_rates,
            omega=omega,
            gamma=gamma,
            **kwargs,
        )

        # Undo the channel concatenation: first half foreground, second half background.
        channels = ops.shape(gradients)[-1] // 2
        grad1 = gradients[..., :channels]
        grad2 = gradients[..., channels:]

        return (grad1, grad2), (loss, aux)