zea.data.dataloader
H5 dataloader for loading images from zea datasets.
Example
```python
import zea

loader = zea.Dataloader(
    file_paths="/path/to/dataset",
    key="data/image/values",
    batch_size=16,
    image_range=(-60, 0),
    normalization_range=(0, 1),
    image_size=(256, 256),
    num_threads=16,
)

for batch in loader:
    # batch is a numpy array of shape (batch_size, 256, 256, 1)
    ...
```
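The image_range and normalization_range arguments suggest a linear rescaling of pixel values, and the batch shape in the comment above ends in a frame axis of size 1. A minimal sketch of such per-image preprocessing follows, assuming that values are first clipped to image_range, then mapped linearly onto normalization_range, and finally given a trailing frame axis. The helper sketch_preprocess is hypothetical and is not part of the zea API.

```python
import numpy as np


def sketch_preprocess(image, image_range=(-60.0, 0.0), normalization_range=(0.0, 1.0)):
    """Hypothetical per-image preprocessing; not zea's implementation."""
    lo, hi = image_range
    new_lo, new_hi = normalization_range
    # Assumption: values outside image_range are clipped before rescaling.
    clipped = np.clip(image, lo, hi)
    # Linear map sending lo -> new_lo and hi -> new_hi.
    scaled = (clipped - lo) / (hi - lo) * (new_hi - new_lo) + new_lo
    # Assumption: a single frame gets a trailing axis (frame_axis=-1).
    return scaled[..., np.newaxis]


image = np.array([[-80.0, -60.0], [-30.0, 0.0]])
out = sketch_preprocess(image)
print(out.shape)  # (2, 2, 1)
```

Under these assumptions, -80 is clipped to -60 and maps to 0, -30 maps to 0.5, and 0 maps to 1.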
Functions
| generate_h5_indices | Generate indices for h5 files. |
Classes
| Dataloader | High-performance HDF5 dataloader built on Grain. |
| H5DataSource | Thread-safe random-access data source for HDF5 files. |
- class zea.data.dataloader.H5DataSource(file_paths, key='data/image', n_frames=1, frame_index_stride=1, frame_axis=-1, insert_frame_axis=True, initial_frame_axis=0, additional_axes_iter=None, sort_files=True, overlapping_blocks=False, limit_n_samples=None, limit_n_frames=None, return_filename=False, cache=False, validate=True, revision=None, **kwargs)
Bases: object

Thread-safe random-access data source for HDF5 files.
Implements the grain.RandomAccessDataSource protocol (__getitem__ and __len__), so it can be plugged directly into a grain.MapDataset pipeline. Each worker thread gets its own H5FileHandleCache via threading.local(), so h5py file handles are never shared across threads.
- Parameters:
  - file_paths (Union[List[str], str]): Path(s) to HDF5 directory(ies) or file(s).
  - key (str): HDF5 dataset key, e.g. "data/image".
  - n_frames (int): Number of consecutive frames per sample.
  - frame_index_stride (int): Stride between frames.
  - frame_axis (int): Axis along which frames are stacked in the output.
  - insert_frame_axis (bool): Whether to insert a new axis for frames.
  - initial_frame_axis (int): Source axis that stores frames in the file.
  - additional_axes_iter (tuple | None): Extra axes to iterate over.
  - sort_files (bool): Sort files numerically.
  - overlapping_blocks (bool): Allow overlapping frame blocks.
  - limit_n_samples (int | None): Cap the number of samples.
  - limit_n_frames (int | None): Cap the number of frames loaded per file.
  - return_filename (bool): Return filename metadata with each sample.
  - cache (bool): Cache loaded samples in RAM.
  - validate (bool): Validate the dataset against the zea format.
  - revision (str | None): Hugging Face revision (branch, tag, or commit hash) for hf:// paths.
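The per-thread handle cache described above can be sketched with threading.local() alone. The class below is a hypothetical stand-in for H5DataSource, not its implementation: it implements only __len__ and __getitem__, the two methods the grain.RandomAccessDataSource protocol relies on, and it takes an open_fn argument (an assumption made for illustration) so the pattern can be shown without HDF5 files. With h5py, open_fn could be lambda p: h5py.File(p, "r").

```python
import threading
from typing import Any, Callable, Dict, List, Sequence, Tuple


class SketchThreadLocalSource:
    """Random-access source with one handle cache per thread (hypothetical)."""

    def __init__(self, index: Sequence[Tuple[str, Any]], open_fn: Callable[[str], Any]):
        # Each entry is (file_path, item_key); open_fn opens a file and returns a handle.
        self._index = list(index)
        self._open_fn = open_fn
        self._local = threading.local()

    def _handles(self) -> Dict[str, Any]:
        # threading.local gives every thread its own attribute namespace,
        # so this dict (and the handles in it) is never shared across threads.
        if not hasattr(self._local, "handles"):
            self._local.handles = {}
        return self._local.handles

    def __len__(self) -> int:
        return len(self._index)

    def __getitem__(self, i: int) -> Any:
        path, item_key = self._index[i]
        handles = self._handles()
        if path not in handles:
            # Open lazily, once per file per thread.
            handles[path] = self._open_fn(path)
        return handles[path][item_key]


# Usage: a fake opener that records every open call.
opened: List[Tuple[int, str]] = []
lock = threading.Lock()


def fake_open(path: str) -> Dict[str, str]:
    with lock:
        opened.append((threading.get_ident(), path))
    return {"a": path + ":a", "b": path + ":b"}


source = SketchThreadLocalSource([("f1", "a"), ("f1", "b"), ("f2", "a")], fake_open)


def worker() -> None:
    for i in range(len(source)):
        _ = source[i]


threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Because each thread sees a fresh namespace on self._local, two threads reading the same two files open four handles in total: one per file per thread.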
- zea.data.dataloader.generate_h5_indices(file_paths, file_shapes, n_frames, frame_index_stride, key='data/image', initial_frame_axis=0, additional_axes_iter=None, sort_files=True, overlapping_blocks=False, limit_n_frames=None)
Generate indices for h5 files.
Generates a list of indices for extracting images from hdf5 files. The length of this list equals the number of samples in the extracted dataset.
- Parameters:
  - file_paths (List[str]): List of file paths.
  - file_shapes (list): List of file shapes.
  - n_frames (int): Number of frames to load from each hdf5 file.
  - frame_index_stride (int): Interval between frames to load.
  - key (str): Key of the hdf5 dataset to read data from. Defaults to "data/image".
  - initial_frame_axis (int): Axis to iterate over. Defaults to 0.
  - additional_axes_iter (Optional[List[int]]): Additional axes to iterate over in the dataset. Defaults to None.
  - sort_files (bool): Sort files by number. Defaults to True.
  - overlapping_blocks (bool): If True, take n_frames from the sequence, then move the start by 1, so consecutive blocks overlap. Defaults to False.
  - limit_n_frames (int | None): Limit the number of frames to load from each file; only the first limit_n_frames frames of each file are used. Defaults to None.
- Returns:
  List of tuples (file_name, key, indices) used to extract images from hdf5 files, where indices is a tuple of slices.
- Return type:
list
Example
```python
[
    (
        "/folder/path_to_file.hdf5",
        "data/image",
        (slice(0, 1, 1), slice(None, 256, None), slice(None, 256, None)),
    ),
    (
        "/folder/path_to_file.hdf5",
        "data/image",
        (slice(1, 2, 1), slice(None, 256, None), slice(None, 256, None)),
    ),
    ...,
]
```
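The sketch below shows one plausible way to build such (file_name, key, indices) tuples along the frame axis, and how an index tuple is then used to slice an array. It is a simplified, hypothetical version (sketch_h5_indices is not zea's function): it ignores additional_axes_iter and sort_files, and it assumes that non-overlapping blocks advance by n_frames * frame_index_stride frames. A NumPy array stands in for the HDF5 dataset; h5py datasets accept the same basic slice syntax.

```python
from typing import List, Optional, Sequence, Tuple

import numpy as np


def sketch_h5_indices(
    file_paths: Sequence[str],
    file_shapes: Sequence[Tuple[int, ...]],
    n_frames: int,
    frame_index_stride: int,
    key: str = "data/image",
    initial_frame_axis: int = 0,
    overlapping_blocks: bool = False,
    limit_n_frames: Optional[int] = None,
) -> List[tuple]:
    """Hypothetical simplified index generator; not zea's implementation."""
    indices = []
    for path, shape in zip(file_paths, file_shapes):
        n_available = shape[initial_frame_axis]
        if limit_n_frames is not None:
            # Only the first limit_n_frames frames of each file are used.
            n_available = min(n_available, limit_n_frames)
        # Offset from the first to the last frame of one block.
        extent = (n_frames - 1) * frame_index_stride
        # Assumption: overlapping blocks slide by one frame; otherwise a block
        # advances by its full span so no frame appears in two blocks.
        step = 1 if overlapping_blocks else n_frames * frame_index_stride
        for start in range(0, n_available - extent, step):
            # Full slices on every axis, as in the example output above.
            slices = [slice(None, dim, None) for dim in shape]
            slices[initial_frame_axis] = slice(
                start, start + n_frames * frame_index_stride, frame_index_stride
            )
            indices.append((path, key, tuple(slices)))
    return indices


indices = sketch_h5_indices(
    ["/folder/path_to_file.hdf5"], [(3, 256, 256)], n_frames=1, frame_index_stride=1
)
# A NumPy array stands in for the HDF5 dataset at "data/image".
data = np.zeros((3, 256, 256), dtype=np.float32)
_, _, idx = indices[1]
block = data[idx]
print(len(indices), block.shape)  # 3 (1, 256, 256)
```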