zea.device

class zea.device(device)

Bases: object

Context manager to run operations on a specific device, regardless of backend.

Normalises device strings across JAX, TensorFlow, and PyTorch so that 'gpu:0', 'cuda:0', and 'cpu' all work with every backend, then delegates to keras.device(), which handles the per-backend dispatch.
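As a rough illustration of the normalisation step, the sketch below maps the three accepted spellings onto a single backend-specific form. It is not zea's implementation. The function name `normalize_device`, the `_GPU_PREFIX` table, and the assumption that torch expects `cuda:N` while JAX and TensorFlow accept `gpu:N` are all illustrative choices.

```python
import re

# Assumed target spelling of a GPU device per backend (illustrative only):
# torch names GPUs "cuda:N"; jax and tensorflow are assumed to accept "gpu:N".
_GPU_PREFIX = {"torch": "cuda", "jax": "gpu", "tensorflow": "gpu"}


def normalize_device(device: str, backend: str) -> str:
    """Hypothetical helper: map 'cpu', 'gpu[:N]' or 'cuda[:N]' to one backend's spelling."""
    if backend not in _GPU_PREFIX:
        raise ValueError(f"unknown backend: {backend!r}")
    match = re.fullmatch(r"(cpu|gpu|cuda)(?::(\d+))?", device.strip().lower())
    if match is None:
        raise ValueError(f"unsupported device string: {device!r}")
    kind, index = match.group(1), match.group(2)
    if kind == "cpu":
        # Keep an explicit CPU index if one was given, otherwise plain "cpu".
        return "cpu" if index is None else f"cpu:{index}"
    # 'gpu' and 'cuda' are treated as synonyms; a missing index defaults to 0.
    return f"{_GPU_PREFIX[backend]}:{index or '0'}"


print(normalize_device("gpu:0", "torch"))   # cuda:0
print(normalize_device("cuda:1", "jax"))    # gpu:1
print(normalize_device("cpu", "tensorflow"))  # cpu
```

Treating 'gpu' and 'cuda' as synonyms up front lets a caller write one device string regardless of which backend Keras is configured with.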

For the torch backend, keras.device() only sets Keras’s internal device-tracking state, so tensors created by Keras ops inside the block land on the requested device. Existing input tensors are not moved automatically; to relocate pre-existing tensors as well, use pipeline(device=..., **inputs) or zea.backend.func_on_device().

Parameters:

device (str) – Device string, e.g. 'cuda:0', 'gpu:0', or 'cpu'.

Example

import zea

# `pipeline` is assumed to be an existing zea pipeline and `data` its input.

# All backends: tensors created by Keras ops are placed on gpu:0
with zea.device("gpu:0"):
    output = pipeline(data=data)

# Per-call device with automatic input-tensor movement (all backends)
output = pipeline(device="gpu:0", data=data)

__init__(device)

Methods

__init__(device)