zea.models.dit

Diffusion Transformer (DiT) backend.

This module provides a Diffusion Transformer (Peebles & Xie, 2023) network that can be used as a drop-in backend for DiffusionModel and FlowMatchingModel.

The network exposes exactly the same call signature as the time-conditional UNet backend (get_time_conditional_unetwork()): it takes a list [noisy_images, time_scalar] where noisy_images has shape (batch, height, width, channels) and time_scalar has shape (batch, 1, 1, 1), and returns a tensor with the same shape as noisy_images. This makes it interchangeable with the UNet backend without any changes to the sampling, training, or guidance machinery.
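Because the DiT and UNet backends share this contract, a candidate backend can be sanity-checked on dummy inputs without any diffusion machinery. The sketch below is illustrative only: `check_backend_contract` and `dummy_backend` are hypothetical helpers (not part of zea), and NumPy arrays stand in for framework tensors.

```python
import numpy as np


def check_backend_contract(backend, image_shape, batch_size=2, seed=0):
    """Call `backend` on [noisy_images, time_scalar] and check the output shape.

    Hypothetical helper for illustration; not part of zea.
    """
    rng = np.random.default_rng(seed)
    height, width, channels = image_shape
    # Noisy images: (batch, height, width, channels).
    noisy_images = rng.standard_normal(
        (batch_size, height, width, channels)
    ).astype("float32")
    # One diffusion time per sample, shaped (batch, 1, 1, 1).
    time_scalar = rng.uniform(size=(batch_size, 1, 1, 1)).astype("float32")
    output = backend([noisy_images, time_scalar])
    if tuple(output.shape) != noisy_images.shape:
        raise ValueError(
            f"expected output shape {noisy_images.shape}, got {tuple(output.shape)}"
        )
    return tuple(output.shape)


def dummy_backend(inputs):
    """Hypothetical stand-in backend that predicts zeros everywhere."""
    noisy_images, _time_scalar = inputs
    return np.zeros_like(noisy_images)


print(check_backend_contract(dummy_backend, (32, 32, 1)))  # (2, 32, 32, 1)
```

Any backend that returns a differently shaped tensor fails this check, which is exactly the property that makes the two backends interchangeable.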

The architecture follows the original DiT with adaLN-Zero conditioning:

  1. The image is split into non-overlapping patches and linearly embedded into a sequence of tokens (patch embedding via a strided convolution).

  2. Learnable positional embeddings are added to the tokens.

  3. The (scalar) diffusion time is embedded with a sinusoidal embedding followed by an MLP to produce a conditioning vector c.

  4. A stack of transformer blocks processes the tokens. Each block modulates its layer-normalised activations with shift/scale/gate parameters regressed from c (adaptive layer norm, zero-initialised so that every gate starts at zero and the block therefore starts as the identity).

  5. A final adaLN-modulated linear layer projects each token back to its pixel patch, and the patches are reassembled (unpatchified) into an image.
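Steps 1 and 2 can be sketched in NumPy. This assumes row-major patch ordering and a valid (unpadded) convolution; `patchify` and `embed_patches` are illustrative names, not zea functions. A convolution whose kernel size and stride both equal `patch_size` touches each patch exactly once, so it is equivalent to flattening every patch and applying one shared linear map; a Keras `Conv2D` kernel of shape (p, p, C, hidden) reshaped to (p * p * C, hidden) would play the role of `kernel` here.

```python
import numpy as np


def patchify(images, patch_size):
    """Split (batch, H, W, C) images into (batch, num_patches, patch_size**2 * C)."""
    b, h, w, c = images.shape
    if h % patch_size or w % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    gh, gw = h // patch_size, w // patch_size
    # (b, gh, p, gw, p, c) -> (b, gh, gw, p, p, c): group pixels by patch.
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Row-major patch order: token index = patch_row * gw + patch_col.
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


def embed_patches(images, kernel, bias, pos_embedding, patch_size):
    """Steps 1 and 2: linear patch embedding plus learnable positions."""
    tokens = patchify(images, patch_size) @ kernel + bias  # (b, N, hidden)
    # pos_embedding has shape (1, N, hidden) and broadcasts over the batch.
    return tokens + pos_embedding


rng = np.random.default_rng(0)
p, hidden = 8, 24
images = rng.standard_normal((2, 32, 32, 3))
num_tokens = (32 // p) * (32 // p)  # 4 * 4 = 16 tokens
kernel = rng.standard_normal((p * p * 3, hidden))
bias = np.zeros(hidden)
pos_embedding = rng.standard_normal((1, num_tokens, hidden))
tokens = embed_patches(images, kernel, bias, pos_embedding, p)
print(tokens.shape)  # (2, 16, 24)
```

The sequence length is (H / p) * (W / p), so the patch size directly controls how many tokens the transformer blocks must process.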

See also

Peebles & Xie, Scalable Diffusion Models with Transformers, 2023. https://arxiv.org/abs/2212.09748

Functions

get_time_conditional_dit_network(image_shape)

Build a time-conditional Diffusion Transformer (DiT) network.

modulate(x, shift, scale)

Apply adaptive layer-norm modulation.

Classes

AddPositionEmbedding(*args, **kwargs)

Add learnable positional embeddings to a sequence of tokens.

DiTBlock(*args, **kwargs)

A single DiT transformer block with adaLN-Zero conditioning.

DiTTimeConditional(*args, **kwargs)

Diffusion Transformer with time-conditional (adaLN-Zero) embedding.

FinalLayer(*args, **kwargs)

Final adaLN-modulated projection from tokens back to pixel patches.

Unpatchify(*args, **kwargs)

Reassemble a sequence of pixel patches into an image.

class zea.models.dit.AddPositionEmbedding(*args, **kwargs)

Bases: Layer

Add learnable positional embeddings to a sequence of tokens.

build(input_shape)
call(x)
class zea.models.dit.DiTBlock(*args, **kwargs)

Bases: Layer

A single DiT transformer block with adaLN-Zero conditioning.
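The adaLN-Zero block can be sketched in NumPy following the reference DiT: SiLU and a linear layer regress six vectors from c (shift, scale, and gate for the attention branch and for the MLP branch), and each residual branch is multiplied by its gate. This is a simplified illustration, not zea's DiTBlock: biases inside attention and the MLP are omitted, GELU uses the tanh approximation, and all function and parameter names are hypothetical. With the modulation weights set to zero, every gate is zero and the block returns its input unchanged.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # No learnable affine terms: adaLN supplies the shift and scale instead.
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def modulate(x, shift, scale):
    # (batch, hidden) parameters broadcast over the token axis.
    return x * (1 + scale[:, None, :]) + shift[:, None, :]


def silu(x):
    return x / (1 + np.exp(-x))


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))


def split_heads(t, num_heads):
    b, n, d = t.shape
    return t.reshape(b, n, num_heads, d // num_heads).transpose(0, 2, 1, 3)


def self_attention(x, w_qkv, w_out, num_heads):
    b, n, d = x.shape
    q, k, v = (split_heads(t, num_heads) for t in np.split(x @ w_qkv, 3, axis=-1))
    logits = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d // num_heads)
    weights = np.exp(logits - logits.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)  # softmax over keys
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(b, n, d)
    return out @ w_out


def dit_block(x, c, params, num_heads):
    # Regress six (batch, hidden) vectors from the conditioning vector c.
    mod = silu(c) @ params["w_ada"] + params["b_ada"]
    shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = np.split(mod, 6, axis=-1)
    h = modulate(layer_norm(x), shift_a, scale_a)
    x = x + gate_a[:, None, :] * self_attention(h, params["w_qkv"], params["w_out"], num_heads)
    h = modulate(layer_norm(x), shift_m, scale_m)
    x = x + gate_m[:, None, :] * (gelu(h @ params["w_1"]) @ params["w_2"])
    return x


def init_params(rng, hidden, mlp_ratio=4.0):
    d_mlp = int(hidden * mlp_ratio)
    return {
        # adaLN-Zero: zero-initialised modulation, so every gate starts at 0.
        "w_ada": np.zeros((hidden, 6 * hidden)),
        "b_ada": np.zeros(6 * hidden),
        "w_qkv": rng.standard_normal((hidden, 3 * hidden)) * 0.02,
        "w_out": rng.standard_normal((hidden, hidden)) * 0.02,
        "w_1": rng.standard_normal((hidden, d_mlp)) * 0.02,
        "w_2": rng.standard_normal((d_mlp, hidden)) * 0.02,
    }


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 32))  # (batch, tokens, hidden)
c = rng.standard_normal((2, 32))      # conditioning vector per sample
params = init_params(rng, hidden=32)
y = dit_block(x, c, params, num_heads=4)
print(np.allclose(y, x))  # True: a freshly initialised block is the identity
```

Starting every block as the identity means a deep stack initially passes the patch embeddings straight through, and the conditioning pathway learns how strongly each branch should contribute.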

call(x, c)
get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

class zea.models.dit.DiTTimeConditional(*args, **kwargs)

Bases: BaseModel

Diffusion Transformer with time-conditional (adaLN-Zero) embedding.
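The time conditioning starts from a sinusoidal embedding of the scalar diffusion time, controlled by the embedding_min_frequency, embedding_max_frequency, and embedding_dims arguments of get_time_conditional_dit_network(). This sketch assumes the frequency layout of the Keras DDIM example on keras.io: embedding_dims // 2 frequencies spaced geometrically between the minimum and maximum, multiplied by 2π, with sine and cosine features concatenated. Whether zea uses exactly this formula and output shape is an assumption; `sinusoidal_embedding` is an illustrative name.

```python
import math

import numpy as np


def sinusoidal_embedding(t, min_frequency=1.0, max_frequency=1000.0, embedding_dims=256):
    """Map times of shape (batch, 1, 1, 1) to features of shape (batch, embedding_dims)."""
    if embedding_dims % 2:
        raise ValueError("embedding_dims must be even")
    # embedding_dims // 2 frequencies, geometrically spaced from min to max.
    frequencies = np.exp(
        np.linspace(math.log(min_frequency), math.log(max_frequency), embedding_dims // 2)
    )
    angular_speeds = 2.0 * math.pi * frequencies
    t = np.reshape(t, (-1, 1))  # (batch, 1)
    # Half sine features, half cosine features.
    return np.concatenate(
        [np.sin(angular_speeds * t), np.cos(angular_speeds * t)], axis=-1
    )


t = np.array([0.0, 0.5]).reshape(2, 1, 1, 1)
emb = sinusoidal_embedding(t)
print(emb.shape)  # (2, 256)
```

Pairing each frequency with a sine and a cosine is why embedding_dims must be even. In the DiT backend, this vector would then pass through the MLP described in step 3 of the architecture overview to produce c.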

call(*args, **kwargs)
get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

class zea.models.dit.FinalLayer(*args, **kwargs)

Bases: Layer

Final adaLN-modulated projection from tokens back to pixel patches.
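A NumPy sketch of such a final layer, modelled on the reference DiT: only a shift and a scale are regressed from c (there is no residual branch to gate), and a linear map sends each token to p * p * C values. The function and parameter names are hypothetical and may not match zea's FinalLayer.

```python
import numpy as np


def final_layer(x, c, w_ada, b_ada, w_out, b_out, eps=1e-6):
    """Map tokens (batch, N, hidden) to flattened patches (batch, N, p * p * C)."""
    silu_c = c / (1 + np.exp(-c))
    # Only shift and scale are regressed: there is no residual branch to gate.
    shift, scale = np.split(silu_c @ w_ada + b_ada, 2, axis=-1)
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    h = (x - mean) / np.sqrt(var + eps)
    h = h * (1 + scale[:, None, :]) + shift[:, None, :]
    return h @ w_out + b_out


rng = np.random.default_rng(0)
batch, num_tokens, hidden, p, channels = 2, 16, 32, 8, 3
x = rng.standard_normal((batch, num_tokens, hidden))
c = rng.standard_normal((batch, hidden))
w_ada = rng.standard_normal((hidden, 2 * hidden)) * 0.02
b_ada = np.zeros(2 * hidden)
w_out = rng.standard_normal((hidden, p * p * channels)) * 0.02
b_out = np.zeros(p * p * channels)
patches = final_layer(x, c, w_ada, b_ada, w_out, b_out)
print(patches.shape)  # (2, 16, 192)
```

In the reference DiT the output projection is also zero-initialised, so the untrained network predicts zeros; whether zea follows that choice is not stated here.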

build(input_shape)
call(x, c)
get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

class zea.models.dit.Unpatchify(*args, **kwargs)

Bases: Layer

Reassemble a sequence of pixel patches into an image.
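Unpatchifying is the inverse of the patch split. The NumPy sketch below assumes tokens are in row-major patch order and each token is a flattened (p, p, C) block; that layout is an assumption about zea's Unpatchify, and both helper functions are hypothetical. A non-square image is used so that a row/column mix-up would break the round trip.

```python
import numpy as np


def patchify(images, patch_size):
    """(batch, H, W, C) -> (batch, N, p * p * C), row-major patch order."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


def unpatchify(tokens, image_shape, patch_size):
    """Inverse of patchify: (batch, N, p * p * C) -> (batch, H, W, C)."""
    h, w, c = image_shape
    gh, gw = h // patch_size, w // patch_size
    b, n, _ = tokens.shape
    if n != gh * gw:
        raise ValueError("token count must equal (H / p) * (W / p)")
    x = tokens.reshape(b, gh, gw, patch_size, patch_size, c)
    # (b, gh, gw, p, p, c) -> (b, gh, p, gw, p, c) -> (b, H, W, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, h, w, c)


images = np.arange(2 * 16 * 24 * 3, dtype=float).reshape(2, 16, 24, 3)
tokens = patchify(images, 8)  # (2, 6, 192)
restored = unpatchify(tokens, (16, 24, 3), 8)
print(np.array_equal(restored, images))  # True
```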

call(x)
get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

zea.models.dit.get_time_conditional_dit_network(image_shape, patch_size=8, hidden_size=384, depth=12, num_heads=6, mlp_ratio=4.0, embedding_min_frequency=1.0, embedding_max_frequency=1000.0, embedding_dims=256)

Build a time-conditional Diffusion Transformer (DiT) network.

The returned model has the same input/output contract as get_time_conditional_unetwork(), so it can be used interchangeably as a backend for diffusion / flow-matching models.

Parameters:
  • image_shape – Tuple (height, width, channels). Both height and width must be divisible by patch_size.

  • patch_size – Side length of the (square) image patches.

  • hidden_size – Token embedding dimension. Must be divisible by num_heads.

  • depth – Number of transformer blocks.

  • num_heads – Number of attention heads.

  • mlp_ratio – Hidden-dimension expansion ratio of the per-token MLP.

  • embedding_min_frequency – Minimum frequency for the sinusoidal time embedding.

  • embedding_max_frequency – Maximum frequency for the sinusoidal time embedding.

  • embedding_dims – Dimensionality of the sinusoidal time embedding (must be even).

Returns:

A functional model mapping [noisy_images, time_scalar] to a tensor with the same shape as noisy_images.

Return type:

keras.Model
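The constraints listed under Parameters can be checked before building the model, and the same arithmetic gives the sequence length and per-head width. `describe_dit_config` is a hypothetical helper, not part of zea; its defaults mirror the documented ones.

```python
def describe_dit_config(image_shape, patch_size=8, hidden_size=384, num_heads=6,
                        embedding_dims=256):
    """Validate the documented constraints and report derived sizes."""
    height, width, channels = image_shape
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    if hidden_size % num_heads:
        raise ValueError("hidden_size must be divisible by num_heads")
    if embedding_dims % 2:
        raise ValueError("embedding_dims must be even")
    return {
        # One token per non-overlapping patch.
        "num_tokens": (height // patch_size) * (width // patch_size),
        # Width of each attention head.
        "head_dim": hidden_size // num_heads,
        # Values each token projects back to in the final layer.
        "patch_dim": patch_size * patch_size * channels,
    }


print(describe_dit_config((128, 128, 1)))
# {'num_tokens': 256, 'head_dim': 64, 'patch_dim': 64}
```

Because attention compares every pair of tokens, halving patch_size quadruples num_tokens and multiplies the number of attention pairs by sixteen.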

zea.models.dit.modulate(x, shift, scale)

Apply adaptive layer-norm modulation.

Parameters:
  • x – Token tensor of shape (batch, num_tokens, hidden_size).

  • shift – Shift tensor of shape (batch, hidden_size).

  • scale – Scale tensor of shape (batch, hidden_size).

Returns:

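Modulated tensor x * (1 + scale) + shift, with the same shape as x. The per-sample shift and scale vectors are broadcast across the token axis, so every token of a sample receives the same modulation.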
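A NumPy sketch of this formula. The explicit token axis (`[:, None, :]`) is needed because NumPy aligns trailing axes, so a (batch, hidden) array would otherwise be matched against the (num_tokens, hidden) axes of x rather than its batch axis. This is an illustration, not zea's implementation.

```python
import numpy as np


def modulate(x, shift, scale):
    """x: (batch, num_tokens, hidden); shift, scale: (batch, hidden)."""
    # Insert a token axis so each sample's vectors apply to all of its tokens.
    return x * (1 + scale[:, None, :]) + shift[:, None, :]


x = np.ones((2, 3, 4))  # (batch=2, num_tokens=3, hidden=4)
shift = np.array([[0.5] * 4, [0.0] * 4])
scale = np.array([[0.0] * 4, [1.0] * 4])
out = modulate(x, shift, scale)
print(out[0, 0], out[1, 2])  # [1.5 1.5 1.5 1.5] [2. 2. 2. 2.]
```

With shift = scale = 0 the function returns x unchanged, so zero-initialised conditioning starts out as plain layer normalisation.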