zea.models.taesd
Tiny Autoencoder (TAESD) model.
>>> from zea.models.taesd import TinyAutoencoder
>>> model = TinyAutoencoder.from_preset("taesdxl")
Important
This is a zea implementation of the model.
For the original code, see here.
Classes
TinyAutoencoder(*args, **kwargs) – Tiny Autoencoder model.
TinyBase(*args, **kwargs) – Base class for TAESD encoder and decoder.
TinyDecoder(*args, **kwargs) – Decoder from TAESD model.
TinyEncoder(*args, **kwargs) – Encoder from TAESD model.
-
class zea.models.taesd.TinyAutoencoder(*args, **kwargs)[source]
Bases: BaseModel
Tiny Autoencoder model.
Note
This model currently supports only the TensorFlow and JAX backends.
Initializes the TAESD model with the given parameters.
- Parameters:
**kwargs – Additional keyword arguments to pass to the superclass initializer.
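Because of the backend restriction noted above, it can help to pin the backend explicitly before anything imports the model. The sketch below assumes that zea follows the Keras 3 convention of reading the `KERAS_BACKEND` environment variable at import time; `select_backend` and `SUPPORTED_BACKENDS` are hypothetical names used for illustration and are not part of zea.

```python
import os

# Backends listed as supported in the note above.
SUPPORTED_BACKENDS = ("tensorflow", "jax")


def select_backend(name):
    """Set KERAS_BACKEND to a supported backend.

    Hypothetical helper (not part of zea). Assumes the backend is read from
    the KERAS_BACKEND environment variable, so it must run before zea or
    keras is imported.
    """
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"TinyAutoencoder supports {SUPPORTED_BACKENDS}, got {name!r}"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")
# Only after this point:
# from zea.models.taesd import TinyAutoencoder
```

Failing early with a `ValueError` keeps an unsupported backend choice from surfacing later as a less obvious error during model construction.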
-
call(inputs)[source]
Applies the full autoencoder to the input.
-
custom_load_weights(preset, **kwargs)[source]
Load the weights for the encoder and decoder.
-
decode(inputs)[source]
Decode the encoded images.
- Parameters:
inputs (tensor) – Encoded images (latents) of shape (batch_size, height, width, 4).
-
encode(inputs)[source]
Encode the input images.
- Parameters:
inputs (tensor) – Input images of shape (batch_size, height, width, channels).
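The shapes documented for encode and decode can be related with a small shape calculation. The sketch below assumes the layout of the original TAESD design, where the encoder downsamples height and width by a factor of 8 and emits 4 latent channels; the 4 channels match the shape given for decode, but the factor of 8 is an assumption that this page does not confirm for the zea implementation. `expected_latent_shape` is a hypothetical helper, not part of zea.

```python
# Hypothetical helper, not part of zea: predicts the latent shape that
# encode() would produce under the assumptions stated below.
DOWNSAMPLE_FACTOR = 8  # assumption taken from the original TAESD design
LATENT_CHANNELS = 4    # matches the shape documented for decode()


def expected_latent_shape(image_shape):
    """Map (batch, height, width, channels) to the assumed latent shape."""
    batch_size, height, width, _channels = image_shape
    if height % DOWNSAMPLE_FACTOR or width % DOWNSAMPLE_FACTOR:
        # Sizes that are not multiples of the factor would not round-trip
        # to the same spatial size under this assumption.
        raise ValueError(
            f"height and width should be multiples of {DOWNSAMPLE_FACTOR}, "
            f"got {height}x{width}"
        )
    return (
        batch_size,
        height // DOWNSAMPLE_FACTOR,
        width // DOWNSAMPLE_FACTOR,
        LATENT_CHANNELS,
    )


print(expected_latent_shape((2, 256, 256, 3)))  # (2, 32, 32, 4)
```

A check like this is useful before calling decode on tensors produced elsewhere, since a mismatch in the last dimension is the most likely shape error.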
-
class zea.models.taesd.TinyBase(*args, **kwargs)[source]
Bases: BaseModel
Base class for TAESD encoder and decoder.
-
build(input_shape)[source]
Builds the network.
-
call(inputs)[source]
Applies the network to the input.
-
custom_load_weights(preset, **kwargs)[source]
Load the weights for the encoder or decoder.
-
maybe_convert_to_jax(input_shape)[source]
Converts the network to JAX if the backend is JAX.
-
class zea.models.taesd.TinyDecoder(*args, **kwargs)[source]
Bases: TinyBase
Decoder from TAESD model.
Initializes the TAESD decoder.
- Parameters:
**kwargs – Additional keyword arguments passed to the superclass initializer.
-
class zea.models.taesd.TinyEncoder(*args, **kwargs)[source]
Bases: TinyBase
Encoder from TAESD model.
Initializes the TAESD encoder.
- Parameters:
**kwargs – Additional keyword arguments passed to the superclass initializer.