zea.models.diffusion
Diffusion models for ultrasound image generation and posterior sampling.
To try this model, simply load one of the available presets:
>>> from zea.models.diffusion import DiffusionModel
>>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
See also
A tutorial notebook where this model is used: Diffusion models for ultrasound image generation.
Classes
DDS | Decomposed Diffusion Sampling guidance.
DPS | Diffusion Posterior Sampling guidance.
DiffusionGuidance | Base class for diffusion guidance methods.
DiffusionModel | Implementation of a diffusion generative model.
NuclearDiffusion | Nuclear Diffusion posterior sampling guidance.
- class zea.models.diffusion.DDS(diffusion_model, operator, disable_jit=False)[source]
Bases: DiffusionGuidance
Decomposed Diffusion Sampling guidance.
Reference paper: https://arxiv.org/pdf/2303.05754
Initialize the diffusion guidance.
- Parameters:
diffusion_model (DiffusionModel) – The diffusion model to use for guidance.
operator (Operator) – The forward operator \(A\) that maps clean images to the measurement space.
disable_jit (bool) – Whether to disable JIT compilation of the guidance function.
- __call__(noisy_images, measurements, noise_rates, signal_rates, n_inner=5, eps=1e-05, verbose=False, **op_kwargs)[source]
Run one DDS guidance step (public entry point).
Delegates to call(), which may be JIT-compiled depending on disable_jit.
- Parameters:
noisy_images – Noisy images x_t of shape (n_images, *input_shape).
measurements – Target observations y.
noise_rates – Noise rates at the current diffusion time.
signal_rates – Signal rates at the current diffusion time.
n_inner (int) – Number of conjugate gradient iterations. Default: 5.
eps (float) – Convergence tolerance for the CG solver. Default: 1e-5.
verbose (bool) – When True, compute and return the measurement error for logging. Default: False.
**op_kwargs – Additional keyword arguments forwarded to the operator (e.g. mask).
- Returns:
A (gradients, (measurement_error, (pred_noises, pred_images))) tuple (see call()).
- call(noisy_images, measurements, noise_rates, signal_rates, n_inner, eps, verbose, **op_kwargs)[source]
Run one DDS guidance step via conjugate gradient.
Denoises x_t to obtain an initial x_0 estimate, then refines it by solving the normal equations \(A^\top A\, x = A^\top y\) with n_inner conjugate gradient iterations.
- Parameters:
noisy_images – Noisy images x_t of shape (n_images, *input_shape).
measurements – Target observations y.
noise_rates – Noise rates at the current diffusion time.
signal_rates – Signal rates at the current diffusion time.
n_inner (int) – Number of conjugate gradient iterations.
eps (float) – Convergence tolerance; CG stops early when the residual norm falls below this threshold.
verbose (bool) – When True, compute and return the measurement error ‖y - A(x̂_0)‖. When False, the error is returned as 0.0 to avoid extra computation.
**op_kwargs – Additional keyword arguments forwarded to the operator (e.g. mask).
- Returns:
A (gradients, (measurement_error, (pred_noises, pred_images))) tuple. gradients is a zero tensor because DDS performs guidance entirely inside the CG loop; the caller subtracts it as a no-op.
- conjugate_gradient_inner_loop(i, loop_state, eps=1e-05)[source]
A single iteration of the conjugate gradient method. This involves minimizing the error of x along the current search vector p, and then choosing the next search vector.
Reference code from: https://github.com/svi-diffusion/
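The conjugate gradient update described above can be illustrated with a small NumPy sketch. It is not the library's implementation: the function names `cg_step` and `solve_normal_equations`, the dense matrix operator, and the toy problem are assumptions made for illustration, whereas the real method works on a `loop_state` and an `Operator`.

```python
import numpy as np

def cg_step(A, x, r, p, rs_old):
    """One conjugate gradient iteration for a symmetric positive-definite A."""
    Ap = A @ p
    alpha = rs_old / (p @ Ap)      # step length that minimizes the error along p
    x = x + alpha * p              # move the estimate along the search vector
    r = r - alpha * Ap             # update the residual b - A x
    rs_new = r @ r
    p = r + (rs_new / rs_old) * p  # choose the next (conjugate) search vector
    return x, r, p, rs_new

def solve_normal_equations(A, y, x0, n_inner=5, eps=1e-5):
    """Solve A^T A x = A^T y with at most n_inner CG iterations."""
    AtA, b = A.T @ A, A.T @ y
    x = x0.copy()
    r = b - AtA @ x
    p = r.copy()
    rs = r @ r
    for _ in range(n_inner):
        if np.sqrt(rs) < eps:      # stop early once the residual norm is small
            break
        x, r, p, rs = cg_step(AtA, x, r, p, rs)
    return x

# Toy problem (hypothetical): recover a 3-vector from 8 linear measurements.
rng = np.random.default_rng(0)
A = rng.normal(size=(8, 3))
x_true = np.array([1.0, -2.0, 0.5])
y = A @ x_true
x_hat = solve_normal_equations(A, y, np.zeros(3), n_inner=10, eps=1e-10)
```

In exact arithmetic, CG on an n-dimensional problem converges in at most n iterations, which is why a small `n_inner` can already give a useful refinement of the denoised estimate.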
- class zea.models.diffusion.DPS(diffusion_model, operator, disable_jit=False)[source]
Bases: DiffusionGuidance
Diffusion Posterior Sampling guidance.
Initialize the diffusion guidance.
- Parameters:
diffusion_model (DiffusionModel) – The diffusion model to use for guidance.
operator (Operator) – The forward operator \(A\) that maps clean images to the measurement space.
disable_jit (bool) – Whether to disable JIT compilation of the guidance function.
- __call__(noisy_images, **kwargs)[source]
Compute DPS gradients and denoiser outputs.
Calls the JIT-compiled gradient function obtained from setup().
- Parameters:
noisy_images – Noisy images x_t of shape (n_images, *input_shape).
**kwargs – Keyword arguments forwarded to compute_error() (measurements, noise_rates, signal_rates, omega, and any operator kwargs such as mask).
- Returns:
A (gradients, (measurement_error, (pred_noises, pred_images))) tuple. gradients is the gradient of the measurement error w.r.t. noisy_images and can be subtracted directly from the reverse-diffusion update.
- compute_error(noisy_images, measurements, noise_rates, signal_rates, omega, **kwargs)[source]
Compute the DPS measurement error for gradient computation.
Following the DPS implementation, the loss is a standard L2 norm.
- Parameters:
noisy_images – Noisy images x_t of shape (n_images, *input_shape).
measurements – Target observations y.
noise_rates – Noise rates at the current diffusion time.
signal_rates – Signal rates at the current diffusion time.
omega (float) – Scalar step-size weight for the measurement gradient.
**kwargs – Additional keyword arguments forwarded to the operator's forward method (e.g. mask).
- Returns:
A (measurement_error, (pred_noises, pred_images)) tuple where measurement_error is the scalar loss and pred_noises / pred_images are the denoiser outputs.
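As a rough illustration of the return structure, the sketch below computes an L2 measurement error for a masking operator in NumPy. Everything in it is an assumption for illustration: the names `dps_error_sketch` and `toy_denoise`, the elementwise mask standing in for the operator, the denoiser that predicts zero noise, and the choice to multiply the norm by `omega`. The actual guidance differentiates this scalar with respect to `noisy_images` using automatic differentiation, which this sketch omits.

```python
import numpy as np

def dps_error_sketch(denoise_fn, noisy_images, measurements, noise_rates,
                     signal_rates, omega, mask):
    """Return (measurement_error, (pred_noises, pred_images))."""
    pred_noises, pred_images = denoise_fn(noisy_images, noise_rates, signal_rates)
    # Hypothetical forward operator: elementwise masking of the clean estimate.
    residual = measurements - mask * pred_images
    # Assumption: omega scales the L2 norm of the residual.
    measurement_error = omega * np.linalg.norm(residual)
    return measurement_error, (pred_noises, pred_images)

def toy_denoise(noisy_images, noise_rates, signal_rates):
    """Stand-in for the score network: always predicts zero noise."""
    pred_noises = np.zeros_like(noisy_images)
    pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
    return pred_noises, pred_images

x_t = np.full((1, 2, 2, 1), 0.8)
y = np.zeros((1, 2, 2, 1))
mask = np.ones((1, 2, 2, 1))
error, (pred_noises, pred_images) = dps_error_sketch(
    toy_denoise, x_t, y, noise_rates=0.6, signal_rates=0.8, omega=0.5, mask=mask
)
```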
- class zea.models.diffusion.DiffusionGuidance(diffusion_model, operator, disable_jit=False)[source]
Bases: ABC, Object
Base class for diffusion guidance methods.
Initialize the diffusion guidance.
- Parameters:
diffusion_model (DiffusionModel) – The diffusion model to use for guidance.
operator (Operator) – The forward operator \(A\) that maps clean images to the measurement space.
disable_jit (bool) – Whether to disable JIT compilation of the guidance function.
- class zea.models.diffusion.DiffusionModel(*args, **kwargs)[source]
Bases: DeepGenerativeModel
Implementation of a diffusion generative model. Heavily inspired by https://keras.io/examples/generative/ddim/
Initialize a diffusion model.
- Parameters:
input_shape – Shape of the input data. Typically of the form (height, width, channels) for images.
input_range – Range of the input data.
min_signal_rate – Minimum signal rate for the diffusion schedule.
max_signal_rate – Maximum signal rate for the diffusion schedule.
network_name – Name of the network architecture to use. Options are "unet_time_conditional" or "dense_time_conditional".
network_kwargs – Additional keyword arguments for the network.
name – Name of the model.
guidance – Guidance method to use. Can be a string, a dict with "name" and "params" keys, or a DiffusionGuidance object.
operator – Operator to use. Can be a string, a dict with "name" and "params" keys, or an Operator object.
ema_val – Exponential moving average value for the network weights.
min_t – Minimum diffusion time for sampling during training.
max_t – Maximum diffusion time for sampling during training.
**kwargs – Additional arguments.
- call(inputs, training=False, network=None, **kwargs)[source]
Call the score network.
- Parameters:
inputs – A list [noisy_images, noise_rates_squared] as expected by the underlying time-conditional network.
training (bool) – Whether to run in training mode. When False and network is None, the EMA network is used.
network – Explicit network to call. If None, the EMA network is used during inference and the online network during training.
**kwargs – Extra keyword arguments forwarded to the network.
- Returns:
Predicted noise tensor of the same shape as the input images.
- denoise(noisy_images, noise_rates, signal_rates, training, network=None)[source]
Predict the noise component and derive the clean-image estimate.
Uses the score network to predict the noise ε in x_t, then computes the Tweedie estimate of x_0.
- Parameters:
noisy_images – Noisy images x_t of shape (n_images, *input_shape).
noise_rates – Noise rates at the current diffusion time, broadcastable to noisy_images.
signal_rates – Signal rates at the current diffusion time, broadcastable to noisy_images.
training (bool) – Whether to call the network in training mode.
network – Explicit network to use. If None, chosen based on training (see call()).
- Returns:
A (pred_noises, pred_images) tuple where pred_noises is the predicted noise ε and pred_images is the Tweedie estimate of x_0.
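A minimal sketch of the clean-image estimate, assuming the forward process has the form x_t = signal_rate * x_0 + noise_rate * ε (as in the Keras DDIM example this model is based on). The function name `tweedie_estimate` and the test data are hypothetical, and the true noise is passed in place of a network prediction.

```python
import numpy as np

def tweedie_estimate(noisy_images, pred_noises, noise_rates, signal_rates):
    """Invert x_t = signal * x_0 + noise * eps for x_0, given a noise prediction."""
    return (noisy_images - noise_rates * pred_noises) / signal_rates

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(2, 4, 4, 1))
eps = rng.normal(size=x0.shape)

# Rates are shaped to broadcast against (n_images, height, width, channels).
signal_rates = np.full((2, 1, 1, 1), 0.8)
noise_rates = np.full((2, 1, 1, 1), 0.6)

x_t = signal_rates * x0 + noise_rates * eps
# With a perfect noise prediction, the estimate recovers x0 exactly.
x0_hat = tweedie_estimate(x_t, eps, noise_rates, signal_rates)
```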
- diffusion_schedule(diffusion_times)[source]
Cosine diffusion schedule.
Implements the cosine schedule from Nichol & Dhariwal (2021).
The noisy image at time t is defined as:
- Parameters:
diffusion_times – Tensor of diffusion times in [min_t, max_t].
- Returns:
A (noise_rates, signal_rates) tuple of tensors with the same shape as diffusion_times.
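The sketch below shows one common way to build a cosine schedule, following the Keras DDIM example. It is an assumption that the library uses the same angle interpolation; the default rates 0.02 and 0.95, the [0, 1] time range, and the function name `cosine_schedule` are illustrative only. Under this form, a noisy image would be x_t = signal_rate * x_0 + noise_rate * ε.

```python
import numpy as np

def cosine_schedule(diffusion_times, min_signal_rate=0.02, max_signal_rate=0.95):
    """Map diffusion times in [0, 1] to (noise_rates, signal_rates)."""
    # Interpolate linearly in angle between the two extreme signal rates.
    start_angle = np.arccos(max_signal_rate)
    end_angle = np.arccos(min_signal_rate)
    angles = start_angle + diffusion_times * (end_angle - start_angle)
    signal_rates = np.cos(angles)
    noise_rates = np.sin(angles)  # so signal_rates**2 + noise_rates**2 == 1
    return noise_rates, signal_rates

times = np.linspace(0.0, 1.0, 5)
noise_rates, signal_rates = cosine_schedule(times)
```

Because the two rates are the sine and cosine of the same angle, the total variance of x_t stays constant over time when x_0 and ε have unit variance.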
- get_config()[source]
Returns the config of the object.
An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.
- log_likelihood(data, **kwargs)[source]
Approximate log-likelihood of the data under the model.
- Parameters:
data – Data to compute log-likelihood for.
**kwargs – Additional arguments.
- Returns:
Approximate log-likelihood.
- property metrics
Metrics for training.
- posterior_sample(measurements, n_samples=1, n_steps=20, initial_step=0, initial_samples=None, seed=None, **kwargs)[source]
Sample from the posterior distribution given measurements.
- Parameters:
measurements – Input measurements. Typically of shape (batch_size, *input_shape).
n_samples (int) – Number of posterior samples to generate. Will generate n_samples samples for each measurement in the measurements batch.
n_steps (int) – Number of diffusion steps.
initial_step (int) – Step at which to begin the reverse diffusion loop. 0 runs all diffusion_steps steps from maximum noise. Higher values skip early (high-noise) steps and require initial_samples to be provided. Number of effective steps will be diffusion_steps - initial_step.
initial_samples – Optional initial samples to warm-start the diffusion process. The diffusion process now starts from a noised version of these samples. This can be used to speed up the diffusion process. When initial_step == 0, samples are noised at the maximum noise level (max_t). When initial_step > 0, samples are noised at the noise level corresponding to initial_step. These initial_samples can be initial guesses such as solutions of previous frames (for sequences), see for instance SeqDiff. Must be of shape (batch_size, n_samples, *input_shape).
seed – Random seed generator.
**kwargs – Additional arguments passed to reverse_conditional_diffusion().
- Returns:
Posterior samples p(x|y) of shape (batch_size, n_samples, *input_shape).
- prepare_diffusion(diffusion_steps, initial_step, verbose, disable_jit=False)[source]
Prepare the diffusion process.
Validates initial_step, computes the step size, and optionally creates a Keras progress bar.
- Parameters:
diffusion_steps (int) – Total number of diffusion steps.
initial_step (int) – Step index at which reverse diffusion begins. Must satisfy 0 <= initial_step < diffusion_steps.
verbose (bool) – Whether to create a Keras Progbar.
disable_jit (bool) – When True, skip the initial_step range assertions (required when values are runtime tensors).
- Returns:
A (step_size, progbar) tuple where step_size is the uniform time increment per step and progbar is a Progbar instance or None.
- prepare_schedule(base_diffusion_times, initial_noise, initial_samples, initial_step, step_size)[source]
Prepare the starting noisy images for the reverse diffusion loop.
Constructs the initial x_t tensor that is fed into the first diffusion step. Three cases are handled:
initial_samples provided and initial_step > 0: samples are mixed with noise at the noise level that corresponds to initial_step, skipping the highest-noise diffusion steps.
initial_samples provided and initial_step == 0: samples are mixed with noise at the maximum noise level (max_t), running the full diffusion process from a noised version of the samples.
initial_samples is None and initial_step == 0: the starting point is pure noise (initial_noise).
- Parameters:
base_diffusion_times – Tensor of shape (n_images, *[1]*n_dims) filled with max_t.
initial_noise – Pure noise tensor of shape (n_images, *input_shape).
initial_samples – Optional samples of shape (n_images, *input_shape). The diffusion process always starts from a noised version of these samples.
initial_step (int) – Step index at which reverse diffusion begins.
step_size (float) – Uniform time increment per diffusion step.
- Returns:
Noisy images tensor of shape (n_images, *input_shape) to use as the starting point x_t for the diffusion loop.
- reverse_conditional_diffusion(measurements, initial_noise, diffusion_steps, initial_samples=None, initial_step=0, stochastic_sampling=False, seed=None, verbose=False, track_progress_type='x_0', disable_jit=False, **kwargs)[source]
Reverse diffusion process conditioned on some measurement.
Effectively performs diffusion posterior sampling p(x_0 | y) by interleaving reverse diffusion steps with gradient-based guidance (e.g. DPS or DDS).
- Parameters:
measurements – Conditioning observations of shape (n_images, *measurement_shape).
initial_noise – Initial noise tensor of shape (n_images, *input_shape).
diffusion_steps (int) – Total number of diffusion steps.
initial_samples – Optional initial samples to warm-start the diffusion process. The diffusion process now starts from a noised version of these samples. When initial_step == 0, samples are noised at the maximum noise level (max_t). When initial_step > 0, samples are noised at the noise level corresponding to initial_step.
initial_step (int) – Step at which to begin the reverse diffusion loop. 0 runs all diffusion_steps steps from maximum noise. Higher values skip early (high-noise) steps and require initial_samples to be provided. Number of effective steps will be diffusion_steps - initial_step.
stochastic_sampling (bool) – Whether to use stochastic DDPM sampling instead of deterministic DDIM sampling.
seed – Random seed generator.
verbose (bool) – Whether to show a Keras progress bar with the guidance error at each step.
track_progress_type (Literal[None, 'x_0', 'x_t']) – Intermediate output to store at each step. "x_0" stores the Tweedie-denoised estimate; "x_t" stores the noisy intermediate image; None disables tracking.
disable_jit – Whether to disable JIT compilation.
**kwargs – Additional keyword arguments forwarded to the guidance function and operator (e.g. omega, mask).
- Returns:
Generated images of shape (n_images, *input_shape).
- reverse_diffusion(initial_noise, diffusion_steps, initial_samples=None, initial_step=0, stochastic_sampling=False, seed=None, verbose=True, track_progress_type='x_0', disable_jit=False, training=False, network_type=None)[source]
Reverse diffusion process to generate images from noise.
- Parameters:
initial_noise – Initial noise tensor of shape (n_images, *input_shape).
diffusion_steps (int) – Total number of diffusion steps.
initial_samples – Optional initial samples to warm-start the diffusion process. The diffusion process now starts from a noised version of these samples. When initial_step == 0, samples are noised at the maximum noise level (max_t). When initial_step > 0, samples are noised at the noise level corresponding to initial_step.
initial_step (int) – Step at which to begin the reverse diffusion loop. 0 runs all diffusion_steps steps from maximum noise. Higher values skip early (high-noise) steps and require initial_samples to be provided. Number of effective steps will be diffusion_steps - initial_step.
stochastic_sampling (bool) – Whether to use stochastic DDPM sampling instead of deterministic DDIM sampling.
seed (SeedGenerator | None) – Random seed generator.
verbose (bool) – Whether to show a Keras progress bar.
track_progress_type (Literal[None, 'x_0', 'x_t']) – Intermediate output to store at each step. "x_0" stores the Tweedie-denoised estimate; "x_t" stores the noisy intermediate image; None disables tracking.
disable_jit (bool) – Whether to disable JIT compilation.
training (bool) – Whether to call the network in training mode.
network_type (Literal[None, 'main', 'ema']) – Which network weights to use. "main" uses the online network, "ema" uses the exponential-moving-average network. If None, the choice is determined by the training argument.
- Returns:
Generated images of shape (n_images, *input_shape).
- reverse_diffusion_step(shape, pred_images, pred_noises, signal_rates, next_signal_rates, next_noise_rates, seed=None, stochastic_sampling=False)[source]
A single reverse diffusion step.
- Parameters:
shape – Shape of the input tensor.
pred_images – Predicted images.
pred_noises – Predicted noises.
signal_rates – Current signal rates.
next_signal_rates – Next signal rates.
next_noise_rates – Next noise rates.
seed – Random seed generator.
stochastic_sampling – Whether to use stochastic sampling (DDPM).
- Returns:
next_noisy_images, the noisy images after the reverse diffusion step.
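For the deterministic (DDIM) case, a reverse step can be sketched as re-noising the clean-image estimate at the next, lower noise level. This follows the Keras DDIM example and is an assumption about the library's internals; the function name `ddim_step` is hypothetical, and the stochastic DDPM branch (which uses `shape`, `signal_rates` and `seed`) is deliberately left out.

```python
import numpy as np

def ddim_step(pred_images, pred_noises, next_signal_rates, next_noise_rates):
    """Deterministic reverse step: recombine the predictions at the next noise level."""
    return next_signal_rates * pred_images + next_noise_rates * pred_noises

rng = np.random.default_rng(0)
pred_images = rng.uniform(-1.0, 1.0, size=(1, 4, 4, 1))
pred_noises = rng.normal(size=pred_images.shape)

# Intermediate step: still a mixture of image and noise.
x_next = ddim_step(pred_images, pred_noises, 0.9, np.sqrt(1.0 - 0.9**2))
# Final step (signal rate 1, noise rate 0): returns the clean estimate itself.
x_final = ddim_step(pred_images, pred_noises, 1.0, 0.0)
```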
- sample(n_samples=1, n_steps=20, seed=None, **kwargs)[source]
Sample from the model.
- Parameters:
n_samples – Number of samples to generate.
n_steps – Number of diffusion steps.
seed – Random seed generator.
**kwargs – Additional arguments.
- Returns:
Generated samples of shape (n_samples, *input_shape).
- start_track_progress(diffusion_steps, initial_step=0)[source]
Initialize progress tracking for the diffusion process.
Resets track_progress and sets track_progress_interval so that at most 50 frames are stored during the diffusion trajectory (to keep memory usage bounded for large step counts).
- Parameters:
diffusion_steps (int) – Total number of diffusion steps.
initial_step (int) – Step index at which reverse diffusion begins.
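One way to pick an interval that keeps the number of stored frames at or below 50 is a ceiling division over the effective number of steps. This is only a sketch of the idea: the function name `tracking_interval`, the `max_frames` argument, and the exact rounding are assumptions, and the library may compute its interval differently.

```python
import math

def tracking_interval(diffusion_steps, initial_step=0, max_frames=50):
    """Smallest interval that stores at most max_frames frames."""
    n_effective = diffusion_steps - initial_step
    return max(1, math.ceil(n_effective / max_frames))

interval = tracking_interval(1000)
# Steps at which a frame would be stored when tracking every `interval` steps.
stored_steps = list(range(0, 1000, interval))
```

For short runs (fewer than 50 effective steps) the interval is 1, so every step is stored.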
- store_progress(step, track_progress_type, next_noisy_images, pred_images)[source]
Store an intermediate diffusion frame in track_progress.
Frames are stored every track_progress_interval steps. Does nothing when track_progress_type is None.
- Parameters:
step (int) – Current diffusion step index.
track_progress_type (Literal[None, 'x_0', 'x_t']) – Which tensor to store. "x_0" stores the Tweedie-denoised estimate (predicted clean image); "x_t" stores the noisy intermediate image.
next_noisy_images – Noisy images x_t after the current step.
pred_images – Predicted clean images x_0 at the current step.
- class zea.models.diffusion.NuclearDiffusion(diffusion_model, operator, disable_jit=False)[source]
Bases: DPS
Nuclear Diffusion posterior sampling guidance.
A hybrid framework that combines diffusion posterior sampling (DPS) with low-rank temporal modeling for video restoration. This method replaces the sparsity assumption in Robust Principal Component Analysis (RPCA) with a learned diffusion prior while maintaining a nuclear norm penalty on the background component to encourage low-rank temporal structure.
See also
dehaze_nuclear_diffusion(): The dehazing application of this method
Nuclear diffusion models for ultrasound dehazing: Example notebook demonstrating the method on cardiac ultrasound dehazing
DPS: Base diffusion posterior sampling guidance
Mathematical Formulation:
Given observations \(\mathbf{Y} \in \mathbb{R}^{n \times p}\) (video frames), Nuclear Diffusion jointly samples the signal \(\mathbf{X}\) and low-rank background \(\mathbf{L}\) from the posterior:
\[\mathbf{X}, \mathbf{L} \sim p_\theta(\mathbf{X}, \mathbf{L} \mid \mathbf{Y})\]
The joint distribution, to which the posterior is proportional, is factorized as:
\[p(\mathbf{Y}, \mathbf{L}, \mathbf{X}) = p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) \, p(\mathbf{L}) \, p_\theta(\mathbf{X})\]
where:
\(p(\mathbf{Y} \mid \mathbf{L}, \mathbf{X}) = \mathcal{N}(\mathbf{Y}; \mathbf{L}+\mathbf{X}, \mu^{-1} \mathbf{I})\) is the likelihood (measurement model)
\(p(\mathbf{L}) \propto \exp(-\gamma \|\mathbf{L}\|_*)\) enforces low-rank structure via the nuclear norm \(\|\mathbf{L}\|_* = \sum_i \sigma_i(\mathbf{L})\)
\(p_\theta(\mathbf{X})\) is a learned diffusion prior capturing complex signal structure
The diffusion prior operates on individual frames \(\mathbf{x}^t \in \mathbb{R}^n\), while temporal dependencies are enforced through the nuclear norm on \(\mathbf{L}\).
This guidance method alternates between reverse diffusion and measurement-guided updates, computing gradients from both the measurement error and the nuclear norm penalty:
- Parameters:
diffusion_model (DiffusionModel) – The diffusion model for the signal component.
operator (Operator) – Forward operator defining the measurement model.
disable_jit (bool) – Whether to disable JIT compilation.
Reference
T. Stevens, O. Nolan, J. L. Robert, and R. J. G. van Sloun, "Nuclear Diffusion Models for Low-Rank Background Suppression in Videos," IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2026. https://arxiv.org/abs/2509.20886
Initialize the diffusion guidance.
- Parameters:
diffusion_model (DiffusionModel) – The diffusion model to use for guidance.
operator (Operator) – The forward operator \(A\) that maps clean images to the measurement space.
disable_jit (bool) – Whether to disable JIT compilation of the guidance function.
- __call__(noisy_images1, noisy_images2, measurements, noise_rates, signal_rates, omega=1.0, gamma=1.0, **kwargs)[source]
Compute guidance gradients for posterior sampling.
This method concatenates the noisy foreground and background images, computes the combined loss via compute_error(), and returns separate gradients for each component.
- Parameters:
noisy_images1 – Noisy foreground images \(\mathbf{x}_t\) from the diffusion model, shape (batch, frames, H, W, C).
noisy_images2 – Noisy background images \(\mathbf{L}_t\), shape (batch, frames, H, W, C).
measurements – Target measurements \(\mathbf{Y}\), shape (batch, frames, H, W, C).
noise_rates – Current noise rates from diffusion schedule.
signal_rates – Current signal rates from diffusion schedule.
omega (float) – Weight for the measurement error term. Default is 1.0.
gamma (float) – Weight for the nuclear norm penalty term. Default is 1.0.
**kwargs – Additional arguments passed to compute_error() (e.g., gamma, rank_weight_factor, step, total_steps).
- Returns:
A tuple containing:
gradients (tuple): (grad_foreground, grad_background), the gradients for foreground and background.
loss_info (tuple): (loss, aux), where loss (float) is the combined loss value and aux (tuple) holds the auxiliary outputs from compute_error().
- compute_error(combined_images, measurements, noise_rates, signal_rates, omega=1.0, gamma=1.0, rank_weight_factor=None, step=None, total_steps=None, initial_step=100, max_alpha=0.5, **kwargs)[source]
Compute measurement error for joint diffusion posterior sampling.
- Parameters:
combined_images – Concatenated noisy images, containing both foreground and background components, shape (batch, frames, H, W, 2C). In the context of cardiac ultrasound dehazing, the first C channels correspond to the tissue signal (foreground), and the next C channels correspond to the haze (background) component.
measurements – Target measurements \(\mathbf{Y}\), shape (batch, frames, H, W, C).
noise_rates – Current noise rates from the diffusion schedule, shape (batch, frames, 1, 1, 1).
signal_rates – Current signal rates from the diffusion schedule, shape (batch, frames, 1, 1, 1).
omega (float) – Weight \(\omega\) for the measurement error term (L2 reconstruction loss).
gamma (float) – Weight \(\gamma\) for the nuclear norm penalty term.
rank_weight_factor (float | None) – Optional weight factor for weighted_nuclear_norm_penalty(). If None, uses standard nuclear_norm_penalty().
step (int | None) – Current diffusion step for progressive blending. Used to compute \(\alpha(t)\).
total_steps (int | None) – Total number of diffusion steps.
initial_step (int) – Step at which to start progressive blending.
max_alpha (float) – Maximum value for \(\alpha\) at the final step. The alpha parameter mixes foreground and background predictions, but only after the initial_step to allow the diffusion model to first focus on generating the foreground signal before blending in the background component.
**kwargs – Additional arguments (unused).
- Returns:
A tuple containing:
measurement_error (float): Combined loss \(\mathcal{L}\).
aux (tuple): Auxiliary outputs: (pred_noises_foreground, pred_images_foreground, noisy_background_images, l2_error, nuclear_penalty)
Note
The progressive blending factor \(\alpha(t)\) linearly increases from 0 at initial_step and plateaus at max_alpha once normalized progress reaches max_alpha, allowing the background component to gradually influence the reconstruction and then saturate for the remainder of sampling.
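The blending schedule in the note can be read as a clipped linear ramp. The sketch below is one interpretation, not the library's code: the function name `blending_alpha` and the definition of normalized progress as (step - initial_step) / (total_steps - initial_step) are assumptions.

```python
def blending_alpha(step, total_steps, initial_step=100, max_alpha=0.5):
    """Linear ramp from 0 at initial_step, clipped at max_alpha."""
    if step <= initial_step or total_steps <= initial_step:
        return 0.0  # blending has not started yet
    # Assumed definition of normalized progress after initial_step.
    progress = (step - initial_step) / (total_steps - initial_step)
    return min(progress, max_alpha)

# With 200 total steps the ramp starts at step 100 and saturates at step 150.
alphas = [blending_alpha(s, total_steps=200) for s in (50, 100, 125, 150, 200)]
```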
- static nuclear_norm_penalty(background_images)[source]
Compute nuclear norm penalty for low-rank enforcement.
The nuclear norm (sum of singular values) encourages low-rank structure in the background component across time. For a matrix \(\mathbf{L}\), it is defined as:
\[\|\mathbf{L}\|_* = \sum_{i=1}^{r} \sigma_i(\mathbf{L})\]
where \(\sigma_i\) are the singular values and \(r\) is the rank.
- Parameters:
background_images – Background images of shape (batch, frames, height, width, channels). Each sequence is reshaped to a matrix of shape (frames, height x width x channels) before computing the nuclear norm.
- Returns:
Nuclear norm penalty summed across the batch and normalized by number of frames.
Note
The input is reshaped from (batch, frames, H, W, C) to (batch, frames, HxWxC) before computing the singular values.
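The reshape-then-SVD computation can be sketched in NumPy as follows. The function name `nuclear_norm_sketch` is hypothetical and NumPy stands in for the backend ops used by the library; the normalization follows the description above (sum over the batch, divide by the number of frames).

```python
import numpy as np

def nuclear_norm_sketch(background_images):
    """Sum of singular values per sequence, summed over batch, divided by frames."""
    batch, frames = background_images.shape[:2]
    # (batch, frames, H, W, C) -> (batch, frames, H*W*C): one row per frame.
    matrices = background_images.reshape(batch, frames, -1)
    singular_values = np.linalg.svd(matrices, compute_uv=False)
    return singular_values.sum() / frames

# A static background (every frame identical) gives a rank-1 matrix per sequence.
frame = np.ones((2, 2, 1))
static_background = np.broadcast_to(frame, (2, 4, 2, 2, 1)).copy()
static_penalty = nuclear_norm_sketch(static_background)

# Unrelated random frames are full rank and are penalized more heavily.
rng = np.random.default_rng(0)
random_background = rng.normal(size=(2, 4, 2, 2, 1))
random_background *= np.linalg.norm(static_background) / np.linalg.norm(random_background)
random_penalty = nuclear_norm_sketch(random_background)
```

The comparison at equal total energy illustrates why the penalty favors temporally static, low-rank backgrounds.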
- static weighted_nuclear_norm_penalty(background_images, weight_factor=2.0)[source]
Compute weighted nuclear norm penalty with enhanced rank control.
This implements a WNNM-style (Weighted Nuclear Norm Minimization) penalty that penalizes smaller singular values more heavily than larger ones, suppressing the spectrum tail to enforce low-rank structure. The weighted penalty is:
\[\|\mathbf{L}\|_{w,*} = \sum_{i=1}^{r} w_i \cdot \sigma_i(\mathbf{L})\]
where \(w_i = 1 + \alpha \cdot \frac{i}{r}\) increases linearly with the index \(i\), and \(\alpha\) is the weight_factor. Since ops.svd returns singular values in descending order (\(\sigma_1 \geq \sigma_2 \geq \cdots\)), higher indices correspond to smaller singular values, which receive larger weights.
- Parameters:
background_images – Background images of shape (batch, frames, height, width, channels).
weight_factor (float) – Scaling factor \(\alpha\) controlling how much more to penalize smaller singular values (the spectrum tail). Default is 2.0.
- Returns:
Weighted nuclear norm penalty summed across the batch and normalized by number of frames.
Note
This is a drop-in replacement for nuclear_norm_penalty() that provides better rank control by more aggressively penalizing the tail of the singular value spectrum (smaller singular values) rather than the leading ones.
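The weighting \(w_i = 1 + \alpha \cdot i / r\) can be sketched in NumPy as below. This is an illustration under assumptions: the function name `weighted_nuclear_norm_sketch` is hypothetical, \(r\) is taken to be the number of singular values returned by the SVD, and the index \(i\) runs from 1 to \(r\).

```python
import numpy as np

def weighted_nuclear_norm_sketch(background_images, weight_factor=2.0):
    """Weighted sum of singular values; later (smaller) values get larger weights."""
    batch, frames = background_images.shape[:2]
    matrices = background_images.reshape(batch, frames, -1)
    # NumPy returns singular values in descending order.
    singular_values = np.linalg.svd(matrices, compute_uv=False)
    r = singular_values.shape[-1]          # assumption: r = number of singular values
    indices = np.arange(1, r + 1)          # i = 1, ..., r
    weights = 1.0 + weight_factor * indices / r
    return (weights * singular_values).sum() / frames

# Static background: one non-zero singular value per sequence, weighted by w_1.
frame = np.ones((2, 2, 1))
background = np.broadcast_to(frame, (2, 4, 2, 2, 1)).copy()
penalty = weighted_nuclear_norm_sketch(background)
# With weight_factor=0 all weights are 1, giving the plain nuclear norm penalty.
unweighted = weighted_nuclear_norm_sketch(background, weight_factor=0.0)
```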