Source code for zea.data.convert.echonetlvh

"""
Script to convert the EchoNet-LVH database to zea format.

Each video is cropped so that the scan cone is centered
without padding, such that it can be converted to polar domain.

.. note::
    This cropping requires first computing scan cone parameters
    using :mod:`zea.data.convert.echonetlvh.precompute_crop`, which
    are then passed to this script.

For more information about the dataset, refer to the following links:

- The original dataset can be found at `this link <https://stanfordaimi.azurewebsites.net/datasets/5b7fcc28-579c-4285-8b72-e4238eac7bd1>`_.
"""

import csv
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import jax.numpy as jnp
import numpy as np
from jax import jit, vmap
from tqdm import tqdm

from zea import File, log
from zea.data.convert.echonet import H5Processor
from zea.data.convert.echonetlvh.precompute_crop import precompute_cone_parameters
from zea.data.convert.utils import load_avi, unzip
from zea.display import cartesian_to_polar_matrix
from zea.func.tensor import translate


def overwrite_splits(source_dir, rejection_path=None):
    """
    Overwrite MeasurementsList.csv splits based on manual_rejections.txt
    or another txt file specifying which hashes to reject.

    Every row whose ``HashedFileName`` is listed in the rejection file gets its
    ``split`` column set to ``"rejected"``. The CSV is rewritten through a
    temporary file (``MeasurementsList_temp.csv``) that only replaces the
    original once all rows were written and validated; the temporary file is
    removed again if anything goes wrong.

    Args:
        source_dir: Source directory containing MeasurementsList.csv
        rejection_path: Path to the rejection txt file (one hash per line).
            If None, defaults to manual_rejections.txt located next to this
            module, for which exactly 278 rejected rows are expected.

    Returns:
        None

    Raises:
        AssertionError: If the default rejection file is used and the number of
            rejected rows differs from the expected count.
    """
    if rejection_path is None:
        module_dir = os.path.dirname(os.path.abspath(__file__))
        rejection_path = os.path.join(module_dir, "manual_rejections.txt")
        expected_num_rejections = 278
    else:
        # Unknown number of rejections for a custom rejection file.
        # NOTE: this is used for testing, where we want to use a dummy rejections file
        expected_num_rejections = None

    try:
        with open(rejection_path) as f:
            # A set gives O(1) membership tests for every CSV row.
            rejected_hashes = {line.strip() for line in f}
    except FileNotFoundError:
        log.warning(f"{rejection_path} not found, skipping rejections.")
        return
    # Blank lines must never match a row that happens to have an empty hash.
    rejected_hashes.discard("")

    csv_path = Path(source_dir) / "MeasurementsList.csv"
    temp_path = Path(source_dir) / "MeasurementsList_temp.csv"

    rejection_counter = 0
    try:
        with (
            csv_path.open("r", newline="", encoding="utf-8") as infile,
            temp_path.open("w", encoding="utf-8", newline="") as outfile,
        ):
            reader = csv.DictReader(infile)
            writer = csv.DictWriter(outfile, fieldnames=reader.fieldnames)
            writer.writeheader()

            for row in reader:
                if row["HashedFileName"] in rejected_hashes:
                    row["split"] = "rejected"
                    rejection_counter += 1
                writer.writerow(row)
    except FileNotFoundError:
        # The source CSV is opened first, so no temporary file exists yet.
        log.warning(f"{csv_path} not found, skipping rejections.")
        return
    except Exception:
        # Do not leave a half-written temporary file behind.
        temp_path.unlink(missing_ok=True)
        raise

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if expected_num_rejections is not None and rejection_counter != expected_num_rejections:
        temp_path.unlink(missing_ok=True)
        raise AssertionError(
            f"Expected {expected_num_rejections} rejections, but applied only {rejection_counter}."
        )

    temp_path.replace(csv_path)
    if expected_num_rejections is None:
        log.info(f"Overwritten {rejection_counter} rejections to {csv_path}")
    else:
        log.info(
            f"Overwritten {rejection_counter}/{expected_num_rejections} rejections to {csv_path}"
        )
def load_splits(source_dir):
    """
    Read MeasurementsList.csv and group the AVI filenames by split.

    A file may occur on several rows of the CSV; only the split of its first
    occurrence is used.

    Args:
        source_dir: Source directory containing MeasurementsList.csv

    Returns:
        Dictionary with keys 'train', 'val', 'test', 'rejected' and values
        as lists of avi filenames (hashed filename + ".avi")
    """
    split_of_file = {}
    with open(Path(source_dir) / "MeasurementsList.csv", newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            hashed_name = record["HashedFileName"]
            # Keep the first split seen for each file
            if hashed_name not in split_of_file:
                split_of_file[hashed_name] = record["split"]

    grouped = {"train": [], "val": [], "test": [], "rejected": []}
    for hashed_name, split_name in split_of_file.items():
        grouped[split_name].append(hashed_name + ".avi")
    return grouped
def find_avi_file(source_dir, hashed_filename, batch=None):
    """
    Locate the AVI file belonging to a hashed filename.

    Args:
        source_dir: Source directory containing BatchX subdirectories
        hashed_filename: Hashed filename (with or without .avi extension)
        batch: Specific batch directory to search in (e.g., "Batch2"),
            or None to search every directory matching "Batch*"

    Returns:
        Path to the AVI file if found, else None
    """
    root = Path(source_dir)
    # Normalize so the name carries exactly one .avi extension
    avi_name = hashed_filename.removesuffix(".avi") + ".avi"

    search_dirs = [root / batch] if batch else root.glob("Batch*")
    for directory in search_dirs:
        candidate = directory / avi_name
        if candidate.exists():
            return candidate
    return None
def load_cone_parameters(csv_path):
    """
    Read the cone parameter CSV and index its successful rows by AVI filename.

    Only rows whose ``status`` column equals ``"success"`` are kept.

    Args:
        csv_path: Path to the CSV file containing cone parameters

    Returns:
        Dictionary mapping avi_filename to a dictionary of cone parameters.
        ``avi_filename`` and ``status`` stay strings, ``apex_above_image``
        becomes a bool, every other column becomes a float, or None when the
        cell is empty.
    """
    parsed = {}
    with open(csv_path, "r", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            if record["status"] != "success":
                continue

            entry = {}
            for column, raw in record.items():
                if column == "avi_filename" or column == "status":
                    # Text columns are passed through untouched
                    entry[column] = raw
                elif column == "apex_above_image":
                    entry[column] = raw.lower() == "true"
                elif raw is None or raw == "":
                    entry[column] = None
                else:
                    entry[column] = float(raw)
            parsed[record["avi_filename"]] = entry
    return parsed
def crop_frame_with_params(frame, cone_params):
    """
    Crop one 2D frame with precomputed cone parameters and zero-pad it.

    Args:
        frame: Input frame as a 2D array (height, width)
        cone_params: Dictionary containing the cropping parameters
            ``crop_left``, ``crop_right``, ``crop_top``, ``crop_bottom``
            and ``apex_x``

    Returns:
        Cropped and padded frame
    """
    left, right, top, bottom = (
        int(cone_params[key]) for key in ("crop_left", "crop_right", "crop_top", "crop_bottom")
    )

    # A negative crop_top means the crop starts above the image: take the rows
    # from the first one and fill the missing rows with zeros.
    window = frame[max(top, 0) : bottom, left:right]
    if top < 0:
        window = jnp.pad(window, ((-top, 0), (0, 0)))

    # Horizontal centering: pad on the side where the apex is closer to the edge
    _, width = window.shape
    shift = width / 2 - (cone_params["apex_x"] - left)
    pad_left = max(0, int(shift))
    pad_right = max(0, int(-shift))
    if pad_left > 0 or pad_right > 0:
        window = jnp.pad(window, ((0, 0), (pad_left, pad_right)))

    return window
def crop_sequence_with_params(sequence, cone_params):
    """
    Crop every frame of a sequence with the same predetermined parameters.

    Args:
        sequence: Input sequence of shape (frames, height, width)
        cone_params: Dictionary containing cropping parameters

    Returns:
        Cropped and padded sequence
    """

    def _crop_one(frame):
        # cone_params is shared by all frames, so only the frame is mapped
        return crop_frame_with_params(frame, cone_params)

    return vmap(_crop_one)(sequence)
class LVHProcessor(H5Processor):
    """H5Processor variant used to convert the EchoNet-LVH dataset."""

    def __init__(self, *args, cone_params=None, **kwargs):
        super().__init__(*args, **kwargs)

        def _to_polar(matrix, angle):
            return self.cart2pol_jit(matrix, angle=angle)

        self.cart2pol_jit = jit(cartesian_to_polar_matrix)
        # Map over the sequence of images, keep the angle fixed since it's
        # constant across a sequence
        self.cart2pol_batched = vmap(_to_polar, in_axes=(0, None))
        # Pre-computed cone parameters, keyed by avi filename
        self.cone_parameters = cone_params or {}
        # Overwrite range_to to use the uint8 range to save memory
        self.range_to = (0, 255)

    def get_split(self, avi_file: str, sequence):
        """
        Look up the split a given AVI file belongs to.

        Args:
            avi_file: Path to the AVI file
            sequence: Video sequence (unused)

        Returns:
            Name of the first split in ``self.splits`` that lists the file

        Raises:
            UserWarning: If the file is not listed in any split
        """
        avi_name = Path(avi_file).stem + ".avi"
        split_name = next(
            (name for name, members in self.splits.items() if avi_name in members),
            None,
        )
        if split_name is None:
            raise UserWarning("Unknown split for file: " + avi_name)
        return split_name

    def __call__(self, avi_file):
        """Convert a single avi_file into a zea dataset.

        The video is loaded, cropped with its pre-computed cone parameters,
        converted to the polar domain, and both versions are stored as uint8.

        Args:
            avi_file: String or path to avi_file to be processed

        Returns:
            zea dataset

        Raises:
            UserWarning: If there are no cone parameters for the file, or the
                file is not part of any split
            ValueError: If the processed or the polar sequence is all zeros
        """
        stem = Path(avi_file).stem
        avi_name = stem + ".avi"

        frames = jnp.array(load_avi(avi_file))
        frames = translate(frames, self.range_from, self._process_range)

        # Cropping relies on the pre-computed cone parameters for this file
        cone_params = self.cone_parameters.get(avi_name)
        if cone_params is None:
            raise UserWarning(f"No cone parameters for {avi_name}")
        frames = crop_sequence_with_params(frames, cone_params)

        split = self.get_split(avi_file, frames)
        out_h5 = self.path_out_h5 / split / (stem + ".hdf5")

        # The angular field spans (-angle, +angle)
        half_angle = cone_params["opening_angle"] / 2
        polar = self.cart2pol_batched(frames, half_angle)

        frames = translate(frames, self._process_range, self.range_to)
        assert self.range_to == (0, 255), "Expected range_to to be (0, 255) for uint8 conversion"
        # floor(x + 0.5) rounds to the nearest integer before the uint8 cast
        frames_uint8 = jnp.asarray(jnp.floor(frames + 0.5), dtype=jnp.uint8)
        del frames

        polar = translate(polar, self._process_range, self.range_to)
        polar_uint8 = jnp.asarray(jnp.floor(polar + 0.5), dtype=jnp.uint8)
        del polar

        if jnp.all(frames_uint8 == 0):
            raise ValueError(f"Processed sequence is all zeros for file {avi_file}")
        if jnp.all(polar_uint8 == 0):
            raise ValueError(f"Polar sequence is all zeros for file {avi_file}")

        # Convert JAX arrays to numpy for File.create / spec validation
        image_sc_np = np.asarray(frames_uint8)
        polar_np = np.asarray(polar_uint8)

        # Image spec requires (n_frames, x, z, y) — add y=1 dimension
        polar_4d = polar_np[:, :, :, np.newaxis]

        return File.create(
            out_h5,
            data={
                "image_sc": {"values": image_sc_np},
                "image": {"values": polar_4d},
            },
            scan={},
            probe={"name": "generic"},
            description="EchoNet-LVH dataset converted to zea format",
        )
def transform_measurement_coordinates_with_cone_params(row, cone_params):
    """Transform measurement coordinates using cone parameters from fit_scan_cone.

    The coordinates are shifted by the crop offset and by the left padding
    used for horizontal centering. The crop bounds are truncated to whole
    pixels, exactly like ``crop_frame_with_params`` does, so that the
    transformed coordinates refer to the same pixels as the cropped frames.
    Coordinates that end up outside the final image only trigger a warning;
    the row is still returned.

    Args:
        row: A dict containing measurement data with X1,X2,Y1,Y2 coordinates
            and a HashedFileName entry
        cone_params: Dictionary containing cone parameters from fit_scan_cone
            (``crop_left``, ``crop_right``, ``crop_top``, ``apex_x``,
            ``new_width``, ``new_height``)

    Returns:
        A new row with transformed coordinates (stored as strings),
        or None if cone_params is None
    """
    if cone_params is None:
        log.warning(f"No cone parameters for file {row['HashedFileName']}")
        return None

    new_row = dict(row)

    # Whole-pixel crop bounds, consistent with crop_frame_with_params
    crop_left = int(cone_params["crop_left"])
    crop_right = int(cone_params["crop_right"])
    crop_top = int(cone_params["crop_top"])

    # Apply cropping offset
    for key in ("X1", "X2"):
        new_row[key] = float(row[key]) - crop_left
    for key in ("Y1", "Y2"):
        new_row[key] = float(row[key]) - crop_top

    # Apply horizontal centering offset; only left padding shifts coordinates
    apex_x_in_crop = cone_params["apex_x"] - crop_left
    target_center_x = (crop_right - crop_left) / 2
    left_padding = max(0, int(target_center_x - apex_x_in_crop))
    new_row["X1"] += left_padding
    new_row["X2"] += left_padding

    # Check if coordinates are within the final image bounds
    final_width = cone_params["new_width"]
    final_height = cone_params["new_height"]
    xs = (new_row["X1"], new_row["X2"])
    ys = (new_row["Y1"], new_row["Y2"])
    is_out_of_bounds = (
        any(value < 0 for value in xs + ys)
        or any(x >= final_width for x in xs)
        or any(y >= final_height for y in ys)
    )
    if is_out_of_bounds:
        log.warning(f"Transformed coordinates out of bounds for file {row['HashedFileName']}")

    # Store the coordinates as strings again, as in the CSV they came from
    for key in ("X1", "X2", "Y1", "Y2"):
        new_row[key] = str(new_row[key])

    return new_row
def convert_measurements_csv(source_csv, output_csv, cone_params_csv=None):
    """Write a copy of the measurements CSV with coordinates mapped to the cropped images.

    Rows without cone parameters, and rows that fail to convert, are left
    out of the output and reported in the logged summary.

    Args:
        source_csv: Path to source CSV file
        output_csv: Path to output CSV file
        cone_params_csv: Path to CSV file with cone parameters
    """
    try:
        with open(source_csv, newline="", encoding="utf-8") as src_file:
            reader = csv.DictReader(src_file)
            rows = list(reader)
            fieldnames = reader.fieldnames

        # Without a cone parameter file every lookup below comes back empty
        if cone_params_csv and Path(cone_params_csv).exists():
            cone_parameters = load_cone_parameters(cone_params_csv)
        else:
            cone_parameters = {}
            log.warning("No cone parameters file found. Measurements will not be transformed.")

        transformed_rows = []
        skipped_files = set()
        for row in rows:
            try:
                params = cone_parameters.get(row["HashedFileName"] + ".avi", None)
                converted = transform_measurement_coordinates_with_cone_params(row, params)
                if converted is None:
                    skipped_files.add(row["HashedFileName"])
                else:
                    transformed_rows.append(converted)
            except Exception as e:
                log.error(f"Error processing row for file {row['HashedFileName']}: {str(e)}")
                skipped_files.add(row["HashedFileName"])

        # The header comes from the first converted row, or from the source
        # file when no row was converted (header-only output).
        header = list(transformed_rows[0].keys()) if transformed_rows else fieldnames
        with open(output_csv, "w", newline="", encoding="utf-8") as out_file:
            writer = csv.DictWriter(out_file, fieldnames=header)
            writer.writeheader()
            writer.writerows(transformed_rows)

        # Summary
        log.info("Conversion Summary:")
        log.info(f"Total rows processed: {len(rows)}")
        log.info(f"Rows successfully converted: {len(transformed_rows)}")
        log.info(f"Rows skipped: {len(rows) - len(transformed_rows)}")
        if skipped_files:
            log.info("Skipped files:")
            for filename in sorted(skipped_files):
                log.info(f" - {filename}")
        log.info(f"Converted measurements saved to {output_csv}")

    except Exception as e:
        log.error(f"Error processing CSV file: {str(e)}")
        raise
def _process_file_worker(avi_file, dst, splits, cone_parameters, range_from, process_range):
    """
    Convert a single file with a processor built inside the worker.

    Args:
        avi_file: Path to the AVI file to process
        dst: Destination directory for output
        splits: Dictionary of splits
        cone_parameters: Dictionary of cone parameters
        range_from: Value assigned to the processor's ``range_from``
        process_range: Value assigned to the processor's ``_process_range``

    Returns:
        Result of processing the file
    """
    # A fresh processor is created for every call instead of sharing one
    worker_processor = LVHProcessor(path_out_h5=dst, splits=splits, cone_params=cone_parameters)

    # Carry over the value ranges chosen by the caller
    worker_processor.range_from = range_from
    worker_processor._process_range = process_range

    return worker_processor(avi_file)
def convert_echonetlvh(args):
    """
    Conversion script for the EchoNet-LVH dataset.

    Unzips, overwrites splits if needed, precomputes cone parameters, and
    converts images and/or measurements to zea format and saves dataset.
    Is called with argparse arguments through zea/zea/data/convert/__main__.py

    Args:
        args (argparse.Namespace): Command-line arguments. The attributes read
            are ``src``, ``dst``, ``no_rejection``, ``rejection_path``
            (optional), ``convert_images``, ``convert_measurements``,
            ``batch`` and ``max_files``. When neither conversion flag is set,
            both are switched on in ``args``.
    """
    # Check if unzip is needed
    src = unzip(args.src, "echonetlvh")
    source_path = Path(src)

    # Overwrite the splits if manual rejections are provided. This uses the
    # path returned by unzip, like every other step, so that the splits are
    # also updated when args.src is not the dataset directory itself.
    if not args.no_rejection:
        overwrite_splits(src, getattr(args, "rejection_path", None))

    # Check that cone parameters exist
    cone_params_csv = Path(args.dst) / "cone_parameters.csv"
    if not cone_params_csv.exists():
        precompute_cone_parameters(args)

    # If no specific conversion is requested, convert both
    if not (args.convert_measurements or args.convert_images):
        args.convert_measurements = True
        args.convert_images = True

    # Convert images if requested
    if args.convert_images:
        splits = load_splits(source_path)

        # Load precomputed cone parameters
        cone_parameters = load_cone_parameters(cone_params_csv)
        log.info(f"Loaded cone parameters for {len(cone_parameters)} files")

        files_to_process = []
        for split_files in splits.values():
            for avi_filename in split_files:
                # Strip .avi if present
                base_filename = avi_filename[:-4] if avi_filename.endswith(".avi") else avi_filename
                avi_file = find_avi_file(src, base_filename, batch=args.batch)
                if avi_file:
                    files_to_process.append(avi_file)
                else:
                    log.warning(
                        f"Warning: Could not find AVI file for {base_filename} in batch "
                        f"{args.batch if args.batch else 'any'}"
                    )

        # Names of files that have already been converted. A set keeps the
        # filtering below linear, and only the trailing extension is removed.
        files_done = set()
        for _, _, filenames in os.walk(args.dst):
            for filename in filenames:
                if filename.endswith(".hdf5"):
                    files_done.add(filename.removesuffix(".hdf5"))

        # Filter out already processed files
        files_to_process = [f for f in files_to_process if f.stem not in files_done]

        # Limit files if max_files is specified
        if args.max_files is not None:
            files_to_process = files_to_process[: args.max_files]
            log.info(f"Limited to processing {args.max_files} files due to max_files parameter")

        log.info(f"Files left to process: {len(files_to_process)}")

        # Initialize processor with splits and cone parameters
        processor = LVHProcessor(path_out_h5=args.dst, splits=splits, cone_params=cone_parameters)

        log.info("Starting the conversion process.")
        for file in tqdm(files_to_process):
            # Best effort: a failing file is logged and the run continues
            try:
                processor(file)
            except Exception as e:
                log.error(f"Error processing {file}: {str(e)}")

        log.info("All image conversion tasks are completed.")

    # Convert measurements if requested
    if args.convert_measurements:
        measurements_csv = source_path / "MeasurementsList.csv"
        if measurements_csv.exists():
            output_csv = Path(args.dst) / "MeasurementsList.csv"
            convert_measurements_csv(measurements_csv, output_csv, cone_params_csv)
        else:
            log.warning("MeasurementsList.csv not found in source directory")

    log.info("All tasks are completed.")