zea.backend

Backend utilities for zea.

Note

Most tensor operations are handled by Keras 3. This module only wraps the features that Keras does not expose directly: JIT compilation, automatic differentiation, and device placement.

Public API

jit()

Unified JIT compilation for JAX (jax.jit) and TensorFlow (tf.function). A no-op for the torch backend.

device

Context manager that pins all Keras ops to a specific device. Re-exported as zea.device().

func_on_device()

Run a callable with its tensor arguments moved to a target device. For torch this also calls .to(device) on every input tensor.

AutoGrad

Backend-agnostic automatic differentiation wrapper.

Functions

func_on_device(func, device, *args, **kwargs)

Run func with all tensor arguments placed on device.

jit([func, jax, tensorflow])

Applies JIT compilation to the given function based on the current Keras backend.

tf_function([func, jit_compile])

Applies default tf.function to the given function.

Classes

device(device)

Context manager to run operations on a specific device, regardless of backend.

class zea.backend.device(device)

Bases: object

Context manager to run operations on a specific device, regardless of backend.

Normalises device strings across JAX, TensorFlow, and PyTorch so that 'gpu:0', 'cuda:0' and 'cpu' all work with every backend, then delegates to keras.device(), which handles the per-backend dispatch.
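The normalisation described above can be pictured as a small mapping from user-facing strings to per-backend spellings. This is an illustrative sketch only: normalize_device is a hypothetical helper, and the rules it encodes (torch uses 'cuda', JAX and TensorFlow use 'gpu', a missing index defaults to 0) are assumptions rather than zea's actual logic.

```python
def normalize_device(device: str, backend: str) -> str:
    """Map 'gpu:N' / 'cuda:N' / 'cpu' to an assumed per-backend spelling."""
    kind, _, index = device.strip().lower().partition(":")
    index = index or "0"  # assumption: no index means device 0
    if kind == "cpu":
        return f"cpu:{index}"
    if kind not in ("gpu", "cuda"):
        raise ValueError(f"Unsupported device string: {device!r}")
    # Assumed convention: torch says 'cuda', JAX/TensorFlow say 'gpu'.
    prefix = "cuda" if backend == "torch" else "gpu"
    return f"{prefix}:{index}"
```

With this mapping, normalize_device('gpu:0', 'torch') yields 'cuda:0' and normalize_device('cuda:1', 'jax') yields 'gpu:1', so the same user string is accepted whichever backend is active.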

For the torch backend, keras.device() sets Keras’s internal device-tracking state so that tensors created by Keras ops land on the correct device. Existing input tensors are not moved automatically — use pipeline(device=..., **inputs) or zea.backend.func_on_device() when you also need to relocate pre-existing tensors.
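To see why existing tensors stay where they are, consider a minimal model of device-tracking state. In this sketch, device_scope, make_tensor and FakeTensor are hypothetical stand-ins, not Keras internals: the context manager only changes which device newly created tensors receive, so a tensor made before entering the block keeps its original device.

```python
import contextlib
import contextvars
from dataclasses import dataclass

# Hypothetical "current device" state, read only when a tensor is created.
_current_device = contextvars.ContextVar("current_device", default="cpu:0")


@dataclass
class FakeTensor:
    value: float
    device: str


def make_tensor(value: float) -> FakeTensor:
    # New tensors are placed on whatever device the active scope selects.
    return FakeTensor(value, _current_device.get())


@contextlib.contextmanager
def device_scope(device: str):
    token = _current_device.set(device)
    try:
        yield
    finally:
        # Restore the previous device when the block exits, even on error.
        _current_device.reset(token)


existing = make_tensor(1.0)  # created outside any scope -> cpu:0
with device_scope("cuda:0"):
    fresh = make_tensor(2.0)  # created inside the scope -> cuda:0
    # `existing` is untouched: entering the scope does not move it.
after = make_tensor(3.0)  # scope exited -> back to cpu:0
```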

Parameters:

device (str) – Device string, e.g. 'cuda:0', 'gpu:0', or 'cpu'.

Example

import zea

# Assumes `pipeline` is a zea pipeline and `data` an input tensor, defined earlier.

# All backends: tensors created by Keras ops are placed on gpu:0
with zea.device("gpu:0"):
    output = pipeline(data=data)

# Per-call device with automatic input-tensor movement (all backends)
output = pipeline(device="gpu:0", data=data)

zea.backend.func_on_device(func, device, *args, **kwargs)

Run func with all tensor arguments placed on device.

For the torch backend, every tensor argument is explicitly moved with .to(device) before the call. For JAX and TensorFlow the function is executed inside a zea.backend.device context, which routes newly created tensors to the requested device.

Parameters:
  • func (callable) – Function to call.

  • device (str) – Target device, e.g. 'cpu', 'gpu:0', 'cuda:1'.

  • *args – Positional arguments forwarded to func.

  • **kwargs – Keyword arguments forwarded to func.

Returns:

Output of func(*args, **kwargs).
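The per-backend dispatch described above can be sketched in plain Python. func_on_device_sketch, its backend and device_scope parameters, and FakeTensor are hypothetical; treating any argument with a .to method as a tensor is a simplifying assumption, where a real implementation would check the backend's tensor type instead.

```python
import contextlib


class FakeTensor:
    """Stand-in for a backend tensor; `.to` returns a copy on a new device."""

    def __init__(self, device):
        self.device = device

    def to(self, device):
        return FakeTensor(device)


def func_on_device_sketch(func, device, *args, backend="torch",
                          device_scope=None, **kwargs):
    """Call func(*args, **kwargs) with inputs placed on `device` (sketch)."""
    if backend == "torch":
        # Torch: explicitly relocate every argument that exposes `.to`.
        def move(x):
            return x.to(device) if hasattr(x, "to") else x

        args = tuple(move(a) for a in args)
        kwargs = {k: move(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)
    # JAX / TensorFlow: run inside a device context; inputs are not moved here.
    scope = device_scope(device) if device_scope else contextlib.nullcontext()
    with scope:
        return func(*args, **kwargs)
```

Because backend and device_scope are keyword parameters of the sketch itself, a wrapped function that also expects keywords with those names would need a different calling convention.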

zea.backend.jit(func=None, jax=True, tensorflow=True, **kwargs)

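Applies JIT compilation to the given function based on the current Keras backend: jax.jit for JAX and tf.function for TensorFlow. For the torch backend the function is returned unchanged. Can be used as a decorator or as a function.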

Parameters:
  • func (callable) – The function to be JIT compiled.

  • jax (bool) – Whether to enable JIT compilation in the JAX backend.

  • tensorflow (bool) – Whether to enable JIT compilation in the TensorFlow backend.

  • **kwargs – Keyword arguments forwarded to the backend's JIT compiler (jax.jit or tf.function).

Returns:

The JIT-compiled function.

Return type:

callable
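A decorator that works both bare (@jit) and with arguments (@jit(jax=False)) typically returns a partially applied version of itself when func is None. The sketch below shows that pattern together with backend dispatch. jit_sketch and its explicit backend argument are hypothetical, and defaulting jit_compile=True for tf.function is an assumption, not something stated above.

```python
import functools


def jit_sketch(func=None, jax=True, tensorflow=True, backend=None, **kwargs):
    if func is None:
        # Called as @jit_sketch(...): return a decorator awaiting the function.
        return functools.partial(jit_sketch, jax=jax, tensorflow=tensorflow,
                                 backend=backend, **kwargs)
    if backend is None:
        import keras  # only needed when the backend is not given explicitly

        backend = keras.config.backend()
    if backend == "jax" and jax:
        import jax as jax_lib  # the name `jax` is taken by the bool flag

        return jax_lib.jit(func, **kwargs)
    if backend == "tensorflow" and tensorflow:
        import tensorflow as tf

        kwargs.setdefault("jit_compile", True)  # assumption: XLA compilation
        return tf.function(func, **kwargs)
    # torch backend, or compilation disabled for this backend: no-op.
    return func
```

With the JAX or TensorFlow backend active, the same call returns a jax.jit-compiled or tf.function-wrapped callable instead of the original function.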

zea.backend.tf_function(func=None, jit_compile=False, **kwargs)

Applies tf.function with its default settings to the given function. This only takes effect with the TensorFlow backend.

Modules

autograd

Autograd wrapper for different backends.

jax

JAX utilities for zea.

optimizer

Simple implementation of optimizers that support multi-backend autodiff.

tensorflow

TensorFlow utilities for zea.

torch

PyTorch utilities for zea.