zea.models.diffusion

Diffusion models for ultrasound image generation and posterior sampling.

To try this model, simply load one of the available presets:

>>> from zea.models.diffusion import DiffusionModel

>>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic")

See also

A tutorial notebook where this model is used: Diffusion models for ultrasound image generation.

Classes

DDS(diffusion_model, operator[, disable_jit])

Decomposed Diffusion Sampling guidance.

DPS(diffusion_model, operator[, disable_jit])

Diffusion Posterior Sampling guidance.

DiffusionGuidance(diffusion_model, operator[, disable_jit])

Base class for diffusion guidance methods.

DiffusionModel(*args, **kwargs)

Implementation of a diffusion generative model.

NuclearDiffusion(diffusion_model, operator[, disable_jit])

Nuclear Diffusion posterior sampling guidance.

class zea.models.diffusion.DDS(diffusion_model, operator, disable_jit=False)

Bases: DiffusionGuidance

Decomposed Diffusion Sampling guidance.

Reference paper: https://arxiv.org/pdf/2303.05754

Initialize the diffusion guidance.

Parameters:
  • diffusion_model (DiffusionModel) – The diffusion model to use for guidance.

  • operator (Operator) – The forward operator \(A\) that maps clean images to the measurement space.

  • disable_jit (bool) – Whether to disable JIT compilation of the guidance function.

Acg(x, **op_kwargs)

__call__(noisy_images, measurements, noise_rates, signal_rates, n_inner=5, eps=1e-05, verbose=False, **op_kwargs)

Run one DDS guidance step (public entry point).

Delegates to call(), which may be JIT-compiled depending on disable_jit.

Parameters:
  • noisy_images – Noisy images x_t of shape (n_images, *input_shape).

  • measurements – Target observations y.

  • noise_rates – Noise rates at the current diffusion time.

  • signal_rates – Signal rates at the current diffusion time.

  • n_inner (int) – Number of conjugate gradient iterations. Default: 5.

  • eps (float) – Convergence tolerance for the CG solver. Default: 1e-5.

  • verbose (bool) – When True, compute and return the measurement error for logging. Default: False.

  • **op_kwargs – Additional keyword arguments forwarded to the operator (e.g. mask).

Returns:

A (gradients, (measurement_error, (pred_noises, pred_images))) tuple (see call()).

call(noisy_images, measurements, noise_rates, signal_rates, n_inner, eps, verbose, **op_kwargs)

Run one DDS guidance step via conjugate gradient.

Denoises x_t to obtain an initial x_0 estimate, then refines it by solving the normal equations \(A^\top A\, x = A^\top y\) with n_inner conjugate gradient iterations.

Parameters:
  • noisy_images – Noisy images x_t of shape (n_images, *input_shape).

  • measurements – Target observations y.

  • noise_rates – Noise rates at the current diffusion time.

  • signal_rates – Signal rates at the current diffusion time.

  • n_inner (int) – Number of conjugate gradient iterations.

  • eps (float) – Convergence tolerance; CG stops early when the residual norm falls below this threshold.

  • verbose (bool) – When True, compute and return the measurement error ‖y - A(x̂_0)‖. When False, the error is returned as 0.0 to avoid extra computation.

  • **op_kwargs – Additional keyword arguments forwarded to the operator (e.g. mask).

Returns:

A (gradients, (measurement_error, (pred_noises, pred_images))) tuple. gradients is a zero tensor because DDS performs guidance entirely inside the CG loop; the caller subtracts it as a no-op.

conjugate_gradient_inner_loop(i, loop_state, eps=1e-05)

A single iteration of the conjugate gradient method. This involves minimizing the error of x along the current search vector p, and then choosing the next search vector.

Reference code from: https://github.com/svi-diffusion/
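As a rough illustration of the refinement described for call(), the sketch below runs conjugate gradient on the normal equations \(A^\top A\, x = A^\top y\) for a small dense matrix in NumPy. The function name cg_normal_equations, the dense matrix A, and the reading of Acg as "apply \(A^\top A\)" are assumptions made for this example; the library method works on an operator and a loop_state and may differ in detail.

```python
import numpy as np


def cg_normal_equations(A, y, x0, n_inner=5, eps=1e-5):
    """Refine x0 by running CG on A^T A x = A^T y (illustrative only)."""

    def Acg(v):
        # Assumed role of Acg: apply the normal-equations operator A^T A.
        return A.T @ (A @ v)

    x = x0.astype(float)
    r = A.T @ y - Acg(x)  # residual of the normal equations
    p = r.copy()  # first search direction
    rs_old = r @ r
    for _ in range(n_inner):
        if np.sqrt(rs_old) < eps:  # stop early once the residual is small
            break
        Ap = Acg(p)
        alpha = rs_old / (p @ Ap)  # step length minimizing the error along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next conjugate search direction
        rs_old = rs_new
    return x


rng = np.random.default_rng(0)
A = rng.normal(size=(8, 4))
y = A @ rng.normal(size=4)  # consistent measurements
x_hat = cg_normal_equations(A, y, x0=np.zeros(4), n_inner=10)
```

In this toy setting x0 plays the role of the denoised estimate, and a handful of iterations is enough because the system has only four unknowns.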

setup()

Setup DDS guidance function.

class zea.models.diffusion.DPS(diffusion_model, operator, disable_jit=False)

Bases: DiffusionGuidance

Diffusion Posterior Sampling guidance.

Initialize the diffusion guidance.

Parameters:
  • diffusion_model (DiffusionModel) – The diffusion model to use for guidance.

  • operator (Operator) – The forward operator \(A\) that maps clean images to the measurement space.

  • disable_jit (bool) – Whether to disable JIT compilation of the guidance function.

__call__(noisy_images, **kwargs)

Compute DPS gradients and denoiser outputs.

Calls the JIT-compiled gradient function obtained from setup().

Parameters:
  • noisy_images – Noisy images x_t of shape (n_images, *input_shape).

  • **kwargs – Keyword arguments forwarded to compute_error() (measurements, noise_rates, signal_rates, omega, and any operator kwargs such as mask).

Returns:

A (gradients, (measurement_error, (pred_noises, pred_images))) tuple. gradients is the gradient of the measurement error w.r.t. noisy_images and can be subtracted directly from the reverse-diffusion update.

compute_error(noisy_images, measurements, noise_rates, signal_rates, omega, **kwargs)

Compute the DPS measurement error for gradient computation.

Following the DPS implementation, the loss is a standard L2 norm.

Parameters:
  • noisy_images – Noisy images x_t of shape (n_images, *input_shape).

  • measurements – Target observations y.

  • noise_rates – Noise rates at the current diffusion time.

  • signal_rates – Signal rates at the current diffusion time.

  • omega (float) – Scalar step-size weight for the measurement gradient.

  • **kwargs – Additional keyword arguments forwarded to the operator’s forward method (e.g. mask).

Returns:

A (measurement_error, (pred_noises, pred_images)) tuple where measurement_error is the scalar loss and pred_noises / pred_images are the denoiser outputs.
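The following is a minimal NumPy sketch of such a measurement error, not the library implementation. It assumes a masking forward operator, a hypothetical toy_denoise that always predicts zero noise, and that omega simply scales the L2 norm; in the real method the denoiser is the diffusion network and the gradient with respect to noisy_images is obtained by automatic differentiation.

```python
import numpy as np


def toy_denoise(noisy_images, noise_rates, signal_rates):
    """Hypothetical stand-in for the denoiser: always predicts zero noise."""
    pred_noises = np.zeros_like(noisy_images)
    pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
    return pred_noises, pred_images


def compute_error_sketch(noisy_images, measurements, noise_rates, signal_rates, omega, mask):
    """Weighted L2 distance between y and A(x_0 estimate), with A = masking."""
    pred_noises, pred_images = toy_denoise(noisy_images, noise_rates, signal_rates)
    residual = measurements - mask * pred_images
    measurement_error = omega * np.linalg.norm(residual)
    return measurement_error, (pred_noises, pred_images)


x0 = np.ones((2, 4, 4, 1))
mask = np.zeros_like(x0)
mask[:, ::2] = 1.0  # observe every other row
signal_rates = np.full((2, 1, 1, 1), 0.8)
noise_rates = np.full((2, 1, 1, 1), 0.6)
error, (pred_noises, pred_images) = compute_error_sketch(
    signal_rates * x0, mask * x0, noise_rates, signal_rates, omega=1.0, mask=mask
)
```

Because the noisy input here is a noise-free scaled copy of x0, the estimate matches the measurements and the error vanishes.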

setup()

Setup the autograd function for DPS.

class zea.models.diffusion.DiffusionGuidance(diffusion_model, operator, disable_jit=False)

Bases: ABC, Object

Base class for diffusion guidance methods.

Initialize the diffusion guidance.

Parameters:
  • diffusion_model (DiffusionModel) – The diffusion model to use for guidance.

  • operator (Operator) – The forward operator \(A\) that maps clean images to the measurement space.

  • disable_jit (bool) – Whether to disable JIT compilation of the guidance function.

abstractmethod __call__(*args, **kwargs)

Call the guidance function.

abstractmethod setup()

Setup the guidance function. Should be implemented by subclasses.

class zea.models.diffusion.DiffusionModel(*args, **kwargs)

Bases: DeepGenerativeModel

Implementation of a diffusion generative model. Heavily inspired by https://keras.io/examples/generative/ddim/

Initialize a diffusion model.

Parameters:
  • input_shape – Shape of the input data. Typically of the form (height, width, channels) for images.

  • input_range – Range of the input data.

  • min_signal_rate – Minimum signal rate for the diffusion schedule.

  • max_signal_rate – Maximum signal rate for the diffusion schedule.

  • network_name – Name of the network architecture to use. Options are "unet_time_conditional" or "dense_time_conditional".

  • network_kwargs – Additional keyword arguments for the network.

  • name – Name of the model.

  • guidance – Guidance method to use. Can be a string or a dict with "name" and "params" keys. Can also be a DiffusionGuidance object.

  • operator – Operator to use. Can be a string or a dict with "name" and "params" keys. Can also be an Operator object.

  • ema_val – Exponential moving average value for the network weights.

  • min_t – Minimum diffusion time for sampling during training.

  • max_t – Maximum diffusion time for sampling during training.

  • **kwargs – Additional arguments.

call(inputs, training=False, network=None, **kwargs)

Call the score network.

Parameters:
  • inputs – A list [noisy_images, noise_rates_squared] as expected by the underlying time-conditional network.

  • training (bool) – Whether to run in training mode. When False and network is None, the EMA network is used.

  • network – Explicit network to call. If None, the EMA network is used during inference and the online network during training.

  • **kwargs – Extra keyword arguments forwarded to the network.

Returns:

Predicted noise tensor of the same shape as the input images.

denoise(noisy_images, noise_rates, signal_rates, training, network=None)

Predict the noise component and derive the clean-image estimate.

Uses the score network to predict the noise ε in x_t, then computes the Tweedie estimate of x_0.

Parameters:
  • noisy_images – Noisy images x_t of shape (n_images, *input_shape).

  • noise_rates – Noise rates at the current diffusion time, broadcastable to noisy_images.

  • signal_rates – Signal rates at the current diffusion time, broadcastable to noisy_images.

  • training (bool) – Whether to call the network in training mode.

  • network – Explicit network to use. If None, chosen based on training (see call()).

Returns:

A (pred_noises, pred_images) tuple where pred_noises is the predicted noise ε and pred_images is the Tweedie estimate of x_0.
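A small sketch of the clean-image estimate, assuming the forward relation x_t = signal_rate · x_0 + noise_rate · ε used in the Keras DDIM example this model is based on. The helper name tweedie_estimate is hypothetical, and the true noise is passed in where the network prediction would normally go.

```python
import numpy as np


def tweedie_estimate(noisy_images, pred_noises, noise_rates, signal_rates):
    """Invert x_t = signal_rate * x_0 + noise_rate * eps for x_0."""
    return (noisy_images - noise_rates * pred_noises) / signal_rates


rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(2, 8, 8, 1))
eps = rng.normal(size=x0.shape)
signal_rates = np.full((2, 1, 1, 1), 0.6)  # broadcastable to the images
noise_rates = np.full((2, 1, 1, 1), 0.8)
x_t = signal_rates * x0 + noise_rates * eps
x0_hat = tweedie_estimate(x_t, eps, noise_rates, signal_rates)
```

With a perfect noise prediction the estimate recovers x0 exactly; with a learned network it is only an approximation that improves as the noise level drops.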

diffusion_schedule(diffusion_times)

Cosine diffusion schedule.

Implements the cosine schedule from Nichol & Dhariwal (2021).

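The noisy image at time t is defined as:

\[x_t = \text{signal\_rate}(t) \cdot x_0 + \text{noise\_rate}(t) \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]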

Parameters:

diffusion_times – Tensor of diffusion times in [min_t, max_t].

Returns:

A (noise_rates, signal_rates) tuple of tensors with the same shape as diffusion_times.
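For intuition, here is a sketch of a cosine schedule in the simplified form used by the Keras DDIM example. The function name and the default values 0.02 and 0.95 for min_signal_rate and max_signal_rate are assumptions taken from that example; the library's own defaults and exact formula may differ.

```python
import numpy as np


def cosine_schedule(diffusion_times, min_signal_rate=0.02, max_signal_rate=0.95):
    """Map diffusion times in [0, 1] to (noise_rates, signal_rates)."""
    # Diffusion times are converted to angles between two end points.
    start_angle = np.arccos(max_signal_rate)
    end_angle = np.arccos(min_signal_rate)
    angles = start_angle + diffusion_times * (end_angle - start_angle)
    # cos and sin of the same angle keep signal^2 + noise^2 equal to 1.
    signal_rates = np.cos(angles)
    noise_rates = np.sin(angles)
    return noise_rates, signal_rates


times = np.linspace(0.0, 1.0, 5)
noise_rates, signal_rates = cosine_schedule(times)
```

The signal rate falls monotonically from max_signal_rate to min_signal_rate as the diffusion time grows, while the squared rates always sum to one.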

get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

linear_diffusion_schedule(diffusion_times)

Create a linear diffusion schedule.

log_likelihood(data, **kwargs)

Approximate log-likelihood of the data under the model.

Parameters:
  • data – Data to compute log-likelihood for.

  • **kwargs – Additional arguments.

Returns:

Approximate log-likelihood.

property metrics

Metrics for training.

posterior_sample(measurements, n_samples=1, n_steps=20, initial_step=0, initial_samples=None, seed=None, **kwargs)

Sample from the posterior distribution given measurements.

Parameters:
  • measurements – Input measurements. Typically of shape (batch_size, *input_shape).

  • n_samples (int) – Number of posterior samples to generate. Will generate n_samples samples for each measurement in the measurements batch.

  • n_steps (int) – Number of diffusion steps.

  • initial_step (int) – Step at which to begin the reverse diffusion loop. 0 runs all diffusion_steps steps from maximum noise. Higher values skip early (high-noise) steps and require initial_samples to be provided. Number of effective steps will be diffusion_steps - initial_step.

  • initial_samples – Optional initial samples to warm-start the diffusion process. The diffusion process now starts from a noised version of these samples. This can be used to speed up the diffusion process. When initial_step == 0, samples are noised at the maximum noise level (max_t). When initial_step > 0, samples are noised at the noise level corresponding to initial_step. These initial_samples can be initial guesses such as solutions of previous frames (for sequences), see for instance SeqDiff. Must be of shape (batch_size, n_samples, *input_shape).

  • seed – Random seed generator.

  • **kwargs – Additional arguments passed to reverse_conditional_diffusion().

Returns:

Posterior samples p(x|y) of shape (batch_size, n_samples, *input_shape).

prepare_diffusion(diffusion_steps, initial_step, verbose, disable_jit=False)

Prepare the diffusion process.

Validates initial_step, computes the step size, and optionally creates a Keras progress bar.

Parameters:
  • diffusion_steps (int) – Total number of diffusion steps.

  • initial_step (int) – Step index at which reverse diffusion begins. Must satisfy 0 <= initial_step < diffusion_steps.

  • verbose (bool) – Whether to create a Keras Progbar.

  • disable_jit (bool) – When True, skip the initial_step range assertions (required when values are runtime tensors).

Returns:

A (step_size, progbar) tuple where step_size is the uniform time increment per step and progbar is a Progbar instance or None.

prepare_schedule(base_diffusion_times, initial_noise, initial_samples, initial_step, step_size)

Prepare the starting noisy images for the reverse diffusion loop.

Constructs the initial x_t tensor that is fed into the first diffusion step. Three cases are handled:

  • initial_samples provided and initial_step > 0: samples are mixed with noise at the noise level that corresponds to initial_step, skipping the highest-noise diffusion steps.

  • initial_samples provided and initial_step == 0: samples are mixed with noise at the maximum noise level (max_t), running the full diffusion process from a noised version of the samples.

  • initial_samples is None and initial_step == 0: the starting point is pure noise (initial_noise).

Parameters:
  • base_diffusion_times – Tensor of shape (n_images, *[1]*n_dims) filled with max_t.

  • initial_noise – Pure noise tensor of shape (n_images, *input_shape).

  • initial_samples – Optional samples of shape (n_images, *input_shape). The diffusion process always starts from a noised version of these samples.

  • initial_step (int) – Step index at which reverse diffusion begins.

  • step_size (float) – Uniform time increment per diffusion step.

Returns:

Noisy images tensor of shape (n_images, *input_shape) to use as the starting point x_t for the diffusion loop.
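The three cases can be sketched as follows. This is an illustration under assumptions, not the library code: the placeholder schedule, the helper name prepare_start, and the rule that the starting time equals base_diffusion_times minus initial_step times step_size are all hypothetical choices made for the example.

```python
import numpy as np


def schedule(diffusion_times):
    """Assumed placeholder schedule returning (noise_rates, signal_rates)."""
    angles = diffusion_times * np.pi / 2
    return np.sin(angles), np.cos(angles)


def prepare_start(base_diffusion_times, initial_noise, initial_samples, initial_step, step_size):
    """Build the starting x_t for the reverse diffusion loop."""
    if initial_samples is None:
        # No warm start: begin from pure noise.
        return initial_noise
    # Warm start: noise the samples at the time matching initial_step
    # (initial_step == 0 gives the maximum time in base_diffusion_times).
    start_times = base_diffusion_times - initial_step * step_size
    noise_rates, signal_rates = schedule(start_times)
    return signal_rates * initial_samples + noise_rates * initial_noise


rng = np.random.default_rng(0)
noise = rng.normal(size=(2, 4, 4, 1))
samples = np.ones((2, 4, 4, 1))
base_times = np.full((2, 1, 1, 1), 1.0)  # max_t assumed to be 1.0
cold = prepare_start(base_times, noise, None, 0, 0.05)
warm = prepare_start(base_times, noise, samples, 10, 0.05)
```

In the warm-start call the starting time is 0.5, so under the placeholder schedule the samples and the noise are mixed with equal weight.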

reverse_conditional_diffusion(measurements, initial_noise, diffusion_steps, initial_samples=None, initial_step=0, stochastic_sampling=False, seed=None, verbose=False, track_progress_type='x_0', disable_jit=False, **kwargs)

Reverse diffusion process conditioned on some measurement.

Effectively performs diffusion posterior sampling p(x_0 | y) by interleaving reverse diffusion steps with gradient-based guidance (e.g. DPS or DDS).

Parameters:
  • measurements – Conditioning observations of shape (n_images, *measurement_shape).

  • initial_noise – Initial noise tensor of shape (n_images, *input_shape).

  • diffusion_steps (int) – Total number of diffusion steps.

  • initial_samples – Optional initial samples to warm-start the diffusion process. The diffusion process now starts from a noised version of these samples. When initial_step == 0, samples are noised at the maximum noise level (max_t). When initial_step > 0, samples are noised at the noise level corresponding to initial_step.

  • initial_step (int) – Step at which to begin the reverse diffusion loop. 0 runs all diffusion_steps steps from maximum noise. Higher values skip early (high-noise) steps and require initial_samples to be provided. Number of effective steps will be diffusion_steps - initial_step.

  • stochastic_sampling (bool) – Whether to use stochastic DDPM sampling instead of deterministic DDIM sampling.

  • seed – Random seed generator.

  • verbose (bool) – Whether to show a Keras progress bar with the guidance error at each step.

  • track_progress_type (Literal[None, 'x_0', 'x_t']) – Intermediate output to store at each step. "x_0" stores the Tweedie-denoised estimate; "x_t" stores the noisy intermediate image; None disables tracking.

  • disable_jit – Whether to disable JIT compilation.

  • **kwargs – Additional keyword arguments forwarded to the guidance function and operator (e.g. omega, mask).

Returns:

Generated images of shape (n_images, *input_shape).

reverse_diffusion(initial_noise, diffusion_steps, initial_samples=None, initial_step=0, stochastic_sampling=False, seed=None, verbose=True, track_progress_type='x_0', disable_jit=False, training=False, network_type=None)

Reverse diffusion process to generate images from noise.

Parameters:
  • initial_noise – Initial noise tensor of shape (n_images, *input_shape).

  • diffusion_steps (int) – Total number of diffusion steps.

  • initial_samples – Optional initial samples to warm-start the diffusion process. The diffusion process now starts from a noised version of these samples. When initial_step == 0, samples are noised at the maximum noise level (max_t). When initial_step > 0, samples are noised at the noise level corresponding to initial_step.

  • initial_step (int) – Step at which to begin the reverse diffusion loop. 0 runs all diffusion_steps steps from maximum noise. Higher values skip early (high-noise) steps and require initial_samples to be provided. Number of effective steps will be diffusion_steps - initial_step.

  • stochastic_sampling (bool) – Whether to use stochastic DDPM sampling instead of deterministic DDIM sampling.

  • seed (SeedGenerator | None) – Random seed generator.

  • verbose (bool) – Whether to show a Keras progress bar.

  • track_progress_type (Literal[None, 'x_0', 'x_t']) – Intermediate output to store at each step. "x_0" stores the Tweedie-denoised estimate; "x_t" stores the noisy intermediate image; None disables tracking.

  • disable_jit (bool) – Whether to disable JIT compilation.

  • training (bool) – Whether to call the network in training mode.

  • network_type (Literal[None, 'main', 'ema']) – Which network weights to use. "main" uses the online network, "ema" uses the exponential-moving-average network. If None, the choice is determined by the training argument.

Returns:

Generated images of shape (n_images, *input_shape).

reverse_diffusion_step(shape, pred_images, pred_noises, signal_rates, next_signal_rates, next_noise_rates, seed=None, stochastic_sampling=False)

A single reverse diffusion step.

Parameters:
  • shape – Shape of the input tensor.

  • pred_images – Predicted images.

  • pred_noises – Predicted noises.

  • signal_rates – Current signal rates.

  • next_signal_rates – Next signal rates.

  • next_noise_rates – Next noise rates.

  • seed – Random seed generator.

  • stochastic_sampling – Whether to use stochastic sampling (DDPM).

Returns:

next_noisy_images, the noisy images after the reverse diffusion step.
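A minimal sketch of the deterministic (DDIM) case only, following the update used in the Keras DDIM example: the clean-image estimate is re-noised to the next, lower noise level using the predicted noise. The helper name ddim_step is hypothetical, and the stochastic (DDPM) branch is not covered here.

```python
import numpy as np


def ddim_step(pred_images, pred_noises, next_signal_rates, next_noise_rates):
    """Deterministic update: re-noise the x_0 estimate to the next noise level."""
    return next_signal_rates * pred_images + next_noise_rates * pred_noises


pred_images = np.full((1, 2, 2, 1), 0.5)
pred_noises = np.full((1, 2, 2, 1), -1.0)
next_noisy = ddim_step(pred_images, pred_noises, next_signal_rates=0.9, next_noise_rates=0.1)
```

When the next noise rate reaches zero and the next signal rate reaches one, the update returns the clean-image estimate itself.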

sample(n_samples=1, n_steps=20, seed=None, **kwargs)

Sample from the model.

Parameters:
  • n_samples – Number of samples to generate.

  • n_steps – Number of diffusion steps.

  • seed – Random seed generator.

  • **kwargs – Additional arguments.

Returns:

Generated samples of shape (n_samples, *input_shape).

start_track_progress(diffusion_steps, initial_step=0)

Initialize progress tracking for the diffusion process.

Resets track_progress and sets track_progress_interval so that at most 50 frames are stored during the diffusion trajectory (to keep memory usage bounded for large step counts).

Parameters:
  • diffusion_steps (int) – Total number of diffusion steps.

  • initial_step (int) – Step index at which reverse diffusion begins.

store_progress(step, track_progress_type, next_noisy_images, pred_images)

Store an intermediate diffusion frame in track_progress.

Frames are stored every track_progress_interval steps. Does nothing when track_progress_type is None.

Parameters:
  • step (int) – Current diffusion step index.

  • track_progress_type (Literal[None, 'x_0', 'x_t']) – Which tensor to store. "x_0" stores the Tweedie-denoised estimate (predicted clean image); "x_t" stores the noisy intermediate image.

  • next_noisy_images – Noisy images x_t after the current step.

  • pred_images – Predicted clean images x_0 at the current step.
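One way to bound the number of stored frames is sketched below. The rounding rule (ceiling division over the effective number of steps) and the condition used to decide which steps are stored are assumptions for illustration; the library may compute track_progress_interval differently.

```python
import math


def tracking_interval(diffusion_steps, initial_step=0, max_frames=50):
    """Interval between stored frames so that at most max_frames are kept."""
    effective_steps = diffusion_steps - initial_step
    return max(1, math.ceil(effective_steps / max_frames))


initial_step = 100
interval = tracking_interval(diffusion_steps=500, initial_step=initial_step)
# Steps at which a frame would be stored under this assumed rule.
stored_steps = [s for s in range(initial_step, 500) if (s - initial_step) % interval == 0]
```

For short runs the interval collapses to 1, so every step is stored.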

test_step(data)

Custom test step so we can call model.fit() on the diffusion model.

train_step(data)

Custom train step so we can call model.fit() on the diffusion model.

Note

Only implemented for the TensorFlow backend.

class zea.models.diffusion.NuclearDiffusion(diffusion_model, operator, disable_jit=False)

Bases: DPS

Nuclear Diffusion posterior sampling guidance.

A hybrid framework that combines diffusion posterior sampling (DPS) with low-rank temporal modeling for video restoration. This method replaces the sparsity assumption in Robust Principal Component Analysis (RPCA) with a learned diffusion prior while maintaining a nuclear norm penalty on the background component to encourage low-rank temporal structure.

See also

Mathematical Formulation:

Given observations \(\mathbf{Y} \in \mathbb{R}^{n \times p}\) (video frames), Nuclear Diffusion jointly samples the signal \(\mathbf{X}\) and low-rank background \(\mathbf{L}\) from the posterior:

\[\mathbf{X}, \mathbf{L} \sim p_\theta(\mathbf{X}, \mathbf{L} \mid \mathbf{Y})\]

The posterior is proportional to the joint distribution, which factorizes as:

\[p(\mathbf{Y}, \mathbf{L}, \mathbf{X}) = p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) \, p(\mathbf{L}) \, p_\theta(\mathbf{X})\]

where:

  • \(p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) = \mathcal{N}(\mathbf{Y}; \mathbf{L}+\mathbf{X}, \mu^{-1} \mathbf{I})\) is the likelihood (measurement model)

  • \(p(\mathbf{L}) \propto \exp(-\gamma \|\mathbf{L}\|_*)\) enforces low-rank structure via the nuclear norm \(\|\mathbf{L}\|_* = \sum_i \sigma_i(\mathbf{L})\)

  • \(p_\theta(\mathbf{X})\) is a learned diffusion prior capturing complex signal structure

The diffusion prior operates on individual frames \(\mathbf{x}^t \in \mathbb{R}^n\), while temporal dependencies are enforced through the nuclear norm on \(\mathbf{L}\).
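As a worked step, taking the negative logarithm of the factorized joint distribution, with the Gaussian likelihood and the nuclear norm prior exactly as defined in the list, gives the following objective up to an additive constant:

\[-\log p(\mathbf{Y}, \mathbf{L}, \mathbf{X}) = \frac{\mu}{2} \|\mathbf{Y} - \mathbf{L} - \mathbf{X}\|_F^2 + \gamma \|\mathbf{L}\|_* - \log p_\theta(\mathbf{X}) + \text{const}\]

The first term is a squared L2 measurement error, the second is the nuclear norm penalty on the background, and the third is handled by the diffusion prior. How these terms are weighted in practice is set by the guidance parameters and is not implied by this derivation.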

This guidance method alternates between reverse diffusion and measurement-guided updates, computing gradients from both the measurement error and the nuclear norm penalty:

Parameters:
  • diffusion_model (DiffusionModel) – The diffusion model for the signal component.

  • operator (Operator) – Forward operator defining the measurement model.

  • disable_jit (bool) – Whether to disable JIT compilation.

Reference

T. Stevens, O. Nolan, J. L. Robert, and R. J. G. van Sloun, “Nuclear Diffusion Models for Low-Rank Background Suppression in Videos,” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2026. https://arxiv.org/abs/2509.20886

Initialize the diffusion guidance.

Parameters:
  • diffusion_model (DiffusionModel) – The diffusion model to use for guidance.

  • operator (Operator) – The forward operator \(A\) that maps clean images to the measurement space.

  • disable_jit (bool) – Whether to disable JIT compilation of the guidance function.

__call__(noisy_images1, noisy_images2, measurements, noise_rates, signal_rates, omega=1.0, gamma=1.0, **kwargs)

Compute guidance gradients for posterior sampling.

This method concatenates the noisy foreground and background images, computes the combined loss via compute_error(), and returns separate gradients for each component.

Parameters:
  • noisy_images1 – Noisy foreground images \(\mathbf{x}_t\) from the diffusion model, shape (batch, frames, H, W, C).

  • noisy_images2 – Noisy background images \(\mathbf{L}_t\), shape (batch, frames, H, W, C).

  • measurements – Target measurements \(\mathbf{Y}\), shape (batch, frames, H, W, C).

  • noise_rates – Current noise rates from diffusion schedule.

  • signal_rates – Current signal rates from diffusion schedule.

  • omega (float) – Weight for the measurement error term. Default is 1.0.

  • gamma (float) – Weight for the nuclear norm penalty term. Default is 1.0.

  • **kwargs – Additional arguments passed to compute_error() (e.g., rank_weight_factor, step, total_steps).

Returns:

A tuple containing:

  • gradients (tuple): (grad_foreground, grad_background), the gradients for the foreground and background components.

  • loss_info (tuple): (loss, aux) where:

    • loss (float): Combined loss value.

    • aux (tuple): Auxiliary outputs from compute_error().

compute_error(combined_images, measurements, noise_rates, signal_rates, omega=1.0, gamma=1.0, rank_weight_factor=None, step=None, total_steps=None, initial_step=100, max_alpha=0.5, **kwargs)

Compute measurement error for joint diffusion posterior sampling.

Parameters:
  • combined_images – Concatenated noisy images, containing both foreground and background components, shape (batch, frames, H, W, 2C). In the context of cardiac ultrasound dehazing, the first C channels correspond to the tissue signal (foreground), and the next C channels correspond to the haze (background) component.

  • measurements – Target measurements \(\mathbf{Y}\), shape (batch, frames, H, W, C).

  • noise_rates – Current noise rates from the diffusion schedule, shape (batch, frames, 1, 1, 1).

  • signal_rates – Current signal rates from the diffusion schedule, shape (batch, frames, 1, 1, 1).

  • omega (float) – Weight \(\omega\) for the measurement error term (L2 reconstruction loss).

  • gamma (float) – Weight \(\gamma\) for the nuclear norm penalty term.

  • rank_weight_factor (float | None) – Optional weight factor for weighted_nuclear_norm_penalty(). If None, uses standard nuclear_norm_penalty().

  • step (int | None) – Current diffusion step for progressive blending. Used to compute \(\alpha(t)\).

  • total_steps (int | None) – Total number of diffusion steps.

  • initial_step (int) – Step at which to start progressive blending.

  • max_alpha (float) – Maximum value for \(\alpha\) at the final step. The alpha parameter mixes foreground and background predictions, but only after the initial_step to allow the diffusion model to first focus on generating the foreground signal before blending in the background component.

  • **kwargs – Additional arguments (unused).

Returns:

A tuple containing:

  • measurement_error (float): Combined loss \(\mathcal{L}\).

  • aux (tuple): Auxiliary outputs: (pred_noises_foreground, pred_images_foreground, noisy_background_images, l2_error, nuclear_penalty)

Note

The progressive blending factor \(\alpha(t)\) linearly increases from 0 at initial_step and plateaus at max_alpha once normalized progress reaches max_alpha, allowing the background component to gradually influence the reconstruction and then saturate for the remainder of sampling.
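One reading of this note is sketched below: normalized progress is measured from initial_step to total_steps and then clipped to the range from 0 to max_alpha. The helper name blending_alpha and the exact normalization are assumptions; the library may normalize progress differently.

```python
import numpy as np


def blending_alpha(step, total_steps, initial_step=100, max_alpha=0.5):
    """Linear ramp from 0 at initial_step, saturating at max_alpha."""
    progress = (step - initial_step) / max(total_steps - initial_step, 1)
    return float(np.clip(progress, 0.0, max_alpha))


alphas = [blending_alpha(step, total_steps=500) for step in (50, 100, 200, 300, 500)]
```

Before initial_step the factor stays at zero, so only the foreground prediction contributes during the early steps.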

static nuclear_norm_penalty(background_images)

Compute nuclear norm penalty for low-rank enforcement.

The nuclear norm (sum of singular values) encourages low-rank structure in the background component across time. For a matrix \(\mathbf{L}\), it is defined as:

\[\|\mathbf{L}\|_* = \sum_{i=1}^{r} \sigma_i(\mathbf{L})\]

where \(\sigma_i\) are the singular values and \(r\) is the rank.

Parameters:

background_images – Background images of shape (batch, frames, height, width, channels). Each sequence is reshaped to a matrix of shape (frames, height x width x channels) before computing the nuclear norm.

Returns:

Nuclear norm penalty summed across the batch and normalized by number of frames.

Note

The input is reshaped from (batch, frames, H, W, C) to (batch, frames, HxWxC) before computing the singular values.
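A NumPy sketch of the computation described here, using numpy.linalg.svd in place of the backend op used by the library. The function name carries a _sketch suffix to mark it as an illustration, not the library code.

```python
import numpy as np


def nuclear_norm_penalty_sketch(background_images):
    """Sum of singular values per sequence, summed over the batch, divided by frames."""
    batch, frames = background_images.shape[:2]
    # Each sequence becomes a (frames, H*W*C) matrix.
    matrices = background_images.reshape(batch, frames, -1)
    # Singular values only, shape (batch, min(frames, H*W*C)).
    singular_values = np.linalg.svd(matrices, compute_uv=False)
    return singular_values.sum() / frames


# A static background (identical frames) is rank 1.
video = np.ones((2, 3, 4, 4, 1))
penalty = nuclear_norm_penalty_sketch(video)
```

Each all-ones sequence here has a single non-zero singular value of sqrt(3 · 16), so the penalty is small relative to a sequence whose frames vary independently.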

static weighted_nuclear_norm_penalty(background_images, weight_factor=2.0)

Compute weighted nuclear norm penalty with enhanced rank control.

This implements a WNNM-style (Weighted Nuclear Norm Minimization) penalty that penalizes smaller singular values more heavily than larger ones, suppressing the spectrum tail to enforce low-rank structure. The weighted penalty is:

\[\|\mathbf{L}\|_{w,*} = \sum_{i=1}^{r} w_i \cdot \sigma_i(\mathbf{L})\]

where \(w_i = 1 + \alpha \cdot \frac{i}{r}\) increases linearly with the index \(i\), and \(\alpha\) is the weight_factor. Since ops.svd returns singular values in descending order (\(\sigma_1 \geq \sigma_2 \geq \cdots\)), higher indices correspond to smaller singular values, which receive larger weights.

Parameters:
  • background_images – Background images of shape (batch, frames, height, width, channels).

  • weight_factor (float) – Scaling factor \(\alpha\) controlling how much more to penalize smaller singular values (the spectrum tail). Default is 2.0.

Returns:

Weighted nuclear norm penalty summed across the batch and normalized by number of frames.

Note

This is a drop-in replacement for nuclear_norm_penalty() that provides better rank control by more aggressively penalizing the tail of the singular value spectrum (smaller singular values) rather than the leading ones.
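The weighting can be illustrated with the NumPy sketch below. It follows the formula as written, with the index \(i\) running from 1 to \(r\), and takes \(r\) to be the number of singular values returned by the SVD; both choices are assumptions, and the library may index from zero or define \(r\) differently.

```python
import numpy as np


def weighted_nuclear_norm_penalty_sketch(background_images, weight_factor=2.0):
    """Weighted sum of singular values with larger weights on the spectrum tail."""
    batch, frames = background_images.shape[:2]
    matrices = background_images.reshape(batch, frames, -1)
    # numpy returns singular values in descending order.
    singular_values = np.linalg.svd(matrices, compute_uv=False)
    r = singular_values.shape[-1]
    # w_i = 1 + alpha * i / r, increasing with the index i = 1..r.
    weights = 1.0 + weight_factor * np.arange(1, r + 1) / r
    return (weights * singular_values).sum() / frames


video = np.ones((1, 3, 4, 4, 1))  # rank-1 sequence of identical frames
weighted = weighted_nuclear_norm_penalty_sketch(video, weight_factor=2.0)
unweighted = weighted_nuclear_norm_penalty_sketch(video, weight_factor=0.0)
```

Setting weight_factor to 0 makes all weights equal to 1, which reduces the sketch to the unweighted nuclear norm.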