zea.models.dit
Diffusion Transformer (DiT) backend.
This module provides a Diffusion Transformer (Peebles & Xie, 2023) network that can be
used as a drop-in backend for DiffusionModel and
FlowMatchingModel.
The network exposes exactly the same call signature as the time-conditional
UNet backend (get_time_conditional_unetwork()): it takes
a list [noisy_images, time_scalar] where noisy_images has shape
(batch, height, width, channels) and time_scalar has shape
(batch, 1, 1, 1), and returns a tensor with the same shape as
noisy_images. This makes it interchangeable with the UNet backend without
any changes to the sampling, training, or guidance machinery.
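The input/output contract can be checked mechanically. The following is a minimal sketch, assuming NumPy arrays stand in for tensors; check_backend_contract and dummy_backend are hypothetical names used only for illustration and are not part of zea.

```python
import numpy as np


def check_backend_contract(network, image_shape, batch_size=2):
    """Call `network` with [noisy_images, time_scalar] and verify the output shape.

    `network` is any callable following the contract described above.
    This helper is hypothetical and not part of zea.
    """
    height, width, channels = image_shape
    noisy_images = np.zeros((batch_size, height, width, channels), dtype="float32")
    # One diffusion time per sample, shaped to broadcast over the image axes.
    time_scalar = np.full((batch_size, 1, 1, 1), 0.5, dtype="float32")
    output = np.asarray(network([noisy_images, time_scalar]))
    if output.shape != noisy_images.shape:
        raise ValueError(
            f"expected output shape {noisy_images.shape}, got {output.shape}"
        )
    return output.shape


def dummy_backend(inputs):
    """Stand-in for a real backend: scales the images by (1 - t)."""
    noisy_images, time_scalar = inputs
    return noisy_images * (1.0 - time_scalar)


contract_shape = check_backend_contract(dummy_backend, image_shape=(32, 32, 1))
```

Any backend that passes such a check, whether UNet or DiT, should be usable by the same sampling and training code.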
The architecture follows the original DiT with adaLN-Zero conditioning:
The image is split into non-overlapping patches and linearly embedded into a sequence of tokens (patch embedding via a strided convolution).
Learnable positional embeddings are added to the tokens.
The (scalar) diffusion time is embedded with a sinusoidal embedding followed by an MLP to produce a conditioning vector c.
A stack of transformer blocks processes the tokens. Each block modulates its layer-normalised activations with shift/scale/gate parameters regressed from c (adaptive layer norm, zero-initialised so the block starts as the identity).
A final adaLN-modulated linear layer projects each token back to its pixel patch, and the patches are reassembled (unpatchified) into an image.
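The "starts as the identity" property of adaLN-Zero can be illustrated numerically. The sketch below is not the zea implementation: it models a single residual branch in NumPy, uses a plain matrix multiply as a stand-in for attention or the per-token MLP, and all names (layer_norm, adaln_zero_branch, w_mod, b_mod, w_inner) are assumptions made for illustration. A full DiT block has two such branches, one for attention and one for the MLP.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_tokens, hidden_size = 2, 16, 8


def layer_norm(x, eps=1e-6):
    """Layer normalisation over the last axis, without learnable affine terms."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def adaln_zero_branch(x, c, w_mod, b_mod, w_inner):
    """One residual branch: x + gate * f(LN(x) * (1 + scale) + shift)."""
    # Regress shift, scale and gate from the conditioning vector c.
    shift, scale, gate = np.split(c @ w_mod + b_mod, 3, axis=-1)
    # Broadcast the per-sample parameters over the token axis.
    h = layer_norm(x) * (1.0 + scale[:, None, :]) + shift[:, None, :]
    h = h @ w_inner  # stand-in for attention or the per-token MLP
    return x + gate[:, None, :] * h


x = rng.normal(size=(batch, num_tokens, hidden_size))
c = rng.normal(size=(batch, hidden_size))
w_inner = rng.normal(size=(hidden_size, hidden_size))

# adaLN-Zero: the modulation projection is initialised to zero, so the gate
# is zero and the residual branch contributes nothing at initialisation.
w_mod = np.zeros((hidden_size, 3 * hidden_size))
b_mod = np.zeros(3 * hidden_size)
out_at_init = adaln_zero_branch(x, c, w_mod, b_mod, w_inner)
```

With the zero-initialised projection, out_at_init equals x; once w_mod moves away from zero during training, the branch starts to contribute.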
See also
Peebles & Xie, Scalable Diffusion Models with Transformers, 2023. https://arxiv.org/abs/2212.09748
Functions
| get_time_conditional_dit_network | Build a time-conditional Diffusion Transformer (DiT) network. |
| modulate | Apply adaptive layer-norm modulation. |
Classes
| AddPositionEmbedding | Add learnable positional embeddings to a sequence of tokens. |
| DiTBlock | A single DiT transformer block with adaLN-Zero conditioning. |
| DiTTimeConditional | Diffusion Transformer with time-conditional (adaLN-Zero) embedding. |
| FinalLayer | Final adaLN-modulated projection from tokens back to pixel patches. |
| Unpatchify | Reassemble a sequence of pixel patches into an image. |
- class zea.models.dit.AddPositionEmbedding(*args, **kwargs)
Bases: Layer
Add learnable positional embeddings to a sequence of tokens.
- class zea.models.dit.DiTBlock(*args, **kwargs)
Bases: Layer
A single DiT transformer block with adaLN-Zero conditioning.
- class zea.models.dit.DiTTimeConditional(*args, **kwargs)
Bases: BaseModel
Diffusion Transformer with time-conditional (adaLN-Zero) embedding.
- class zea.models.dit.FinalLayer(*args, **kwargs)
Bases: Layer
Final adaLN-modulated projection from tokens back to pixel patches.
- class zea.models.dit.Unpatchify(*args, **kwargs)
Bases: Layer
Reassemble a sequence of pixel patches into an image.
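To make the patch/token layout concrete, the following NumPy sketch splits an image batch into non-overlapping patches and reassembles it. It is an illustration only: the functions patchify and unpatchify below are hypothetical, and the pixel ordering inside each patch is an assumption that may differ from what the Unpatchify layer uses internally.

```python
import numpy as np


def patchify(images, patch_size):
    """(batch, H, W, C) -> (batch, num_patches, patch_size * patch_size * C)."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, gh, gw, p, p, c)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


def unpatchify(patches, image_shape, patch_size):
    """Inverse of `patchify`: reassemble pixel patches into (batch, H, W, C)."""
    h, w, c = image_shape
    gh, gw = h // patch_size, w // patch_size
    b = patches.shape[0]
    x = patches.reshape(b, gh, gw, patch_size, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, gh, p, gw, p, c)
    return x.reshape(b, h, w, c)


images = np.arange(2 * 16 * 16 * 3).reshape(2, 16, 16, 3)
patches = patchify(images, patch_size=4)  # 4 x 4 grid of patches per image
restored = unpatchify(patches, image_shape=(16, 16, 3), patch_size=4)
```

The round trip is lossless, which is why height and width need to be divisible by the patch size.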
- zea.models.dit.get_time_conditional_dit_network(image_shape, patch_size=8, hidden_size=384, depth=12, num_heads=6, mlp_ratio=4.0, embedding_min_frequency=1.0, embedding_max_frequency=1000.0, embedding_dims=256)
Build a time-conditional Diffusion Transformer (DiT) network.
The returned model has the same input/output contract as get_time_conditional_unetwork(), so it can be used interchangeably as a backend for diffusion / flow-matching models.
- Parameters:
image_shape – Tuple (height, width, channels). Both height and width must be divisible by patch_size.
patch_size – Side length of the (square) image patches.
hidden_size – Token embedding dimension. Must be divisible by num_heads.
depth – Number of transformer blocks.
num_heads – Number of attention heads.
mlp_ratio – Hidden-dimension expansion ratio of the per-token MLP.
embedding_min_frequency – Minimum frequency for the sinusoidal time embedding.
embedding_max_frequency – Maximum frequency for the sinusoidal time embedding.
embedding_dims – Dimensionality of the sinusoidal time embedding (must be even).
- Returns:
A functional model mapping [noisy_images, time_scalar] to a tensor with the same shape as noisy_images.
- Return type:
keras.Model
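The divisibility constraints on the parameters can be checked before building a network. The helper below is a hypothetical sketch, not part of zea: it only restates the documented constraints in plain Python and derives a few quantities (token count, per-patch dimension, per-head dimension) that follow from them.

```python
def describe_dit_config(
    image_shape, patch_size=8, hidden_size=384, num_heads=6, embedding_dims=256
):
    """Validate the documented constraints and return derived sizes.

    Hypothetical helper for illustration; the defaults mirror the
    signature of get_time_conditional_dit_network.
    """
    height, width, channels = image_shape
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    if hidden_size % num_heads:
        raise ValueError("hidden_size must be divisible by num_heads")
    if embedding_dims % 2:
        raise ValueError("embedding_dims must be even")
    return {
        # One token per non-overlapping patch.
        "num_tokens": (height // patch_size) * (width // patch_size),
        # Number of pixel values each token is projected back to.
        "patch_dim": patch_size * patch_size * channels,
        # Width of each attention head.
        "head_dim": hidden_size // num_heads,
    }


summary = describe_dit_config(image_shape=(256, 256, 1))
```

For a 256 x 256 single-channel image with the default settings, this gives 1024 tokens, 64 pixel values per patch, and a head dimension of 64. The same arguments could then be passed to get_time_conditional_dit_network(image_shape=(256, 256, 1)).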
- zea.models.dit.modulate(x, shift, scale)
Apply adaptive layer-norm modulation.
- Parameters:
x – Token tensor of shape (batch, num_tokens, hidden_size).
shift – Shift tensor of shape (batch, hidden_size).
scale – Scale tensor of shape (batch, hidden_size).
- Returns:
Modulated tensor x * (1 + scale) + shift of the same shape as x.