zea.backend
Backend utilities for zea.
Note
Most tensor operations are handled by Keras 3. This module only wraps the features that Keras does not expose directly: JIT compilation, automatic differentiation, and device placement.
Public API
- jit(): Unified JIT compilation for JAX (jax.jit) and TensorFlow (tf.function). A no-op for the torch backend.
- device: Context manager that pins all Keras ops to a specific device. Re-exported as zea.device().
- func_on_device(): Run a callable on a target device. For torch, this also calls .to(device) on every input tensor.
- AutoGrad: Backend-agnostic automatic differentiation wrapper.
Functions
- func_on_device: Run func with all tensor arguments placed on device.
- jit: Applies JIT compilation to the given function based on the current Keras backend.
- tf_function: Applies default tf.function to the given function.
Classes
- device: Context manager to run operations on a specific device, regardless of backend.
- class zea.backend.device(device)
Bases: object
Context manager to run operations on a specific device, regardless of backend.
Normalises device strings across JAX, TensorFlow, and PyTorch so that 'gpu:0', 'cuda:0', and 'cpu' all work with every backend, then delegates to keras.device(), which handles the per-backend dispatch.
For the torch backend, keras.device() sets Keras’s internal device-tracking state so that tensors created by Keras ops land on the correct device. Existing input tensors are not moved automatically; use pipeline(device=..., **inputs) or zea.backend.func_on_device() when you also need to relocate pre-existing tensors.
- Parameters:
  device (str) – Device string, e.g. 'cuda:0', 'gpu:0', or 'cpu'.
Example
```python
import zea

# Assumes `pipeline` is an existing zea pipeline and `data` is its input.

# All backends: tensors created by Keras ops are placed on gpu:0
with zea.device("gpu:0"):
    output = pipeline(data=data)

# Per-call device with automatic input-tensor movement (all backends)
output = pipeline(device="gpu:0", data=data)
```
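The string normalisation described above can be pictured with a small pure-Python sketch. It is not zea's implementation: the helper name normalize_device and its rules (gpu and cuda treated as synonyms, "cuda" spelling for torch, "gpu" spelling for JAX and TensorFlow, everything else passed through) are assumptions made for illustration.

```python
def normalize_device(device: str, backend: str) -> str:
    """Translate a device string into the spelling `backend` expects.

    Hypothetical helper for illustration only; zea's own normalisation may differ.
    """
    kind, sep, index = device.strip().lower().partition(":")
    if kind in ("gpu", "cuda"):
        # PyTorch names GPUs "cuda"; JAX and TensorFlow name them "gpu".
        kind = "cuda" if backend == "torch" else "gpu"
    # Other device types such as "cpu" are passed through unchanged.
    return f"{kind}{sep}{index}"


# The same user-facing strings resolve differently per backend.
for backend in ("jax", "tensorflow", "torch"):
    print(backend, [normalize_device(d, backend) for d in ("gpu:0", "cuda:0", "cpu")])
```

In practice you pass the original string straight to zea.device(); the sketch only shows why 'gpu:0' and 'cuda:0' can be used interchangeably.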
- zea.backend.func_on_device(func, device, *args, **kwargs)
Run func with all tensor arguments placed on device.
For the torch backend, every tensor argument is explicitly moved with .to(device) before the call. For JAX and TensorFlow, the function is executed inside a zea.backend.device context, which routes newly created tensors to the requested device.
- Parameters:
  func (callable) – Function to call.
  device (str) – Target device, e.g. 'cpu', 'gpu:0', 'cuda:1'.
  *args – Positional arguments forwarded to func.
  **kwargs – Keyword arguments forwarded to func.
- Returns:
  Output of func(*args, **kwargs).
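The torch branch described above amounts to moving each tensor argument before forwarding the call. The following is a minimal sketch of that idea, not zea's code: call_with_moved_tensors and FakeTensor are hypothetical names, any object exposing a .to() method is assumed to count as a tensor (duck typing, as torch.Tensor does), only top-level arguments are moved, and the JAX/TensorFlow branch (running inside zea.backend.device) is omitted.

```python
from typing import Any, Callable


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like torch.Tensor.to, return a new object on the target device.
        return FakeTensor(device)


def call_with_moved_tensors(func: Callable[..., Any], device: str, *args: Any, **kwargs: Any) -> Any:
    """Hypothetical sketch of the torch branch of func_on_device.

    Top-level arguments with a .to() method are moved with .to(device);
    other arguments are forwarded unchanged. Nested containers are not handled.
    """

    def move(value: Any) -> Any:
        return value.to(device) if hasattr(value, "to") else value

    moved_args = [move(arg) for arg in args]
    moved_kwargs = {name: move(value) for name, value in kwargs.items()}
    return func(*moved_args, **moved_kwargs)


def describe(x, scale, y=None):
    # Report where each tensor ended up; `scale` is a plain float.
    return x.device, scale, y.device


result = call_with_moved_tensors(describe, "cuda:0", FakeTensor(), 2.0, y=FakeTensor())
```

Non-tensor arguments such as the float scale pass through untouched, which mirrors the "every tensor argument" wording rather than moving every argument blindly.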
- zea.backend.jit(func=None, jax=True, tensorflow=True, **kwargs)
Applies JIT compilation to the given function based on the current Keras backend. Can be used as a decorator or as a function.
- Parameters:
func (callable) – The function to be JIT compiled.
jax (bool) – Whether to enable JIT compilation in the JAX backend.
tensorflow (bool) – Whether to enable JIT compilation in the TensorFlow backend.
**kwargs – Keyword arguments to be passed to the JIT compiler.
- Returns:
The JIT-compiled function.
- Return type:
callable
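The dual decorator/function usage and the per-backend switches can be illustrated with a minimal sketch. jit_sketch is hypothetical: it takes the backend name as an explicit backend argument instead of reading the active Keras backend, and it assumes the TensorFlow path means tf.function with jit_compile=True. zea's actual choices may differ.

```python
import functools
from typing import Any, Callable, Optional


def jit_sketch(
    func: Optional[Callable[..., Any]] = None,
    *,
    backend: str = "jax",
    jax: bool = True,
    tensorflow: bool = True,
    **kwargs: Any,
):
    """Hypothetical backend-dispatching JIT wrapper (not zea's implementation).

    `backend` stands in for the active Keras backend name.
    """
    if func is None:
        # Used as @jit_sketch(...): return a decorator that remembers the options.
        return functools.partial(
            jit_sketch, backend=backend, jax=jax, tensorflow=tensorflow, **kwargs
        )
    if backend == "jax" and jax:
        import jax as jax_lib  # imported lazily so other backends do not need JAX

        return jax_lib.jit(func, **kwargs)
    if backend == "tensorflow" and tensorflow:
        import tensorflow as tf  # imported lazily as well

        # Assumption: XLA compilation by default; a caller's jit_compile overrides it.
        return tf.function(func, **{"jit_compile": True, **kwargs})
    # torch, or JIT disabled for this backend: return the function unchanged.
    return func


@jit_sketch(backend="torch")  # decorator form with options
def double(x):
    return 2 * x


square = jit_sketch(lambda x: x * x, backend="jax", jax=False)  # function form, JAX disabled
```

Because JAX and TensorFlow are imported only inside their branches, the torch path and the disabled paths run without either library installed.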
- zea.backend.tf_function(func=None, jit_compile=False, **kwargs)
Applies tf.function with default settings to the given function. Only takes effect with the TensorFlow backend.
Modules
- Autograd wrapper for different backends.
- JAX utilities for zea.
- Simple implementation of optimizers that support multi-backend autodiff.
- TensorFlow utilities for zea.
- PyTorch utilities for zea.