Source code for myoverse.datasets.utils.formatter

"""Rich console formatting utilities for dataset creation."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import zarr
from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.tree import Tree


@dataclass
class DatasetConfig:
    """Configuration for dataset creation display."""

    emg_data_path: Path
    ground_truth_data_path: Path
    ground_truth_data_type: str
    sampling_frequency: float
    save_path: Path
    chunk_size: int
    chunk_shift: int
    test_ratio: float
    val_ratio: float
    augmentation_batch_size: int
    debug_level: int
    silence_warnings: bool


[docs] class DatasetFormatter: """Handles Rich console output for dataset creation. Extracts all formatting and display logic from EMGDataset to provide a clean separation of concerns. Parameters ---------- console : Console | None Rich console instance. If None, creates a new one. debug_level : int Debug level (0=none, 1=text, 2=text+graphs). Examples -------- >>> formatter = DatasetFormatter(debug_level=1) >>> formatter.print_header() >>> formatter.print_config(config) >>> formatter.print_summary(dataset) """
[docs] def __init__(self, console: Console | None = None, debug_level: int = 0): self.console = console or Console(color_system=None, highlight=False) self.debug_level = debug_level
[docs] def should_print(self, level: int = 1) -> bool: """Check if output should be printed at the given level.""" return self.debug_level >= level
[docs] def print_header(self, title: str = "STARTING DATASET CREATION") -> None: """Print a section header.""" if not self.should_print(): return self.console.rule(title) self.console.print()
[docs] def print_config(self, config: DatasetConfig) -> None: """Print dataset configuration table.""" if not self.should_print(): return table = Table( title="Dataset Configuration", show_header=True, box=box.ROUNDED, padding=(0, 2), ) table.add_column("Parameter", width=30) table.add_column("Value") table.add_row("EMG data path", str(config.emg_data_path)) table.add_row("Ground truth data path", str(config.ground_truth_data_path)) table.add_row("Ground truth data type", config.ground_truth_data_type) table.add_row("Sampling frequency (Hz)", str(config.sampling_frequency)) table.add_row("Save path", str(config.save_path)) table.add_row("Chunk size", str(config.chunk_size)) table.add_row("Chunk shift", str(config.chunk_shift)) table.add_row("Testing split ratio", str(config.test_ratio)) table.add_row("Validation split ratio", str(config.val_ratio)) table.add_row("Augmentation batch size", str(config.augmentation_batch_size)) table.add_row("Debug level", str(config.debug_level)) table.add_row("Silence Zarr warnings", str(config.silence_warnings)) self.console.print(table) self.console.print()
[docs] def print_tasks_info(self, tasks: list[str]) -> None: """Print information about tasks to process.""" if not self.should_print(): return self.console.print(f"Processing {len(tasks)} tasks: {', '.join(tasks)}") self.console.print()
[docs] def print_data_structure( self, emg_data: dict[str, np.ndarray], ground_truth_data: dict[str, np.ndarray], ) -> None: """Print data structure tree.""" if not self.should_print(): return tree = Tree("Dataset Structure") emg_branch = tree.add("EMG Data") for i, (k, v) in enumerate(list(emg_data.items())[:5]): emg_branch.add(f"Task {k}: Shape {v.shape}") if len(emg_data) > 5: emg_branch.add(f"... {len(emg_data) - 5} more tasks") gt_branch = tree.add("Ground Truth Data") for i, (k, v) in enumerate(list(ground_truth_data.items())[:5]): gt_branch.add(f"Task {k}: Shape {v.shape}") if len(ground_truth_data) > 5: gt_branch.add(f"... {len(ground_truth_data) - 5} more tasks") self.console.print(tree) self.console.print()
[docs] def print_data_panel(self, data: Any, title: str) -> None: """Print a data object in a styled panel.""" if not self.should_print(): return panel = Panel.fit( str(data), title=title, box=box.ROUNDED, padding=(0, 2), ) self.console.print(panel)
[docs] def print_section(self, title: str) -> None: """Print a section label.""" if not self.should_print(): return self.console.rule(title) self.console.print()
[docs] def print_action(self, action: str) -> None: """Print an action being performed.""" if not self.should_print(): return self.console.print(action)
[docs] def print_split_sizes( self, training_sizes: list[int], testing_sizes: list[int], validation_sizes: list[int], ) -> None: """Print dataset split sizes table.""" if not self.should_print(): return table = Table( title="Dataset Split Sizes", show_header=True, box=box.ROUNDED, padding=(0, 2), width=40, ) table.add_column("Split") table.add_column("Sizes") table.add_row("Training", str(training_sizes)) table.add_row("Testing", str(testing_sizes)) table.add_row("Validation", str(validation_sizes)) self.console.print(table) self.console.print()
[docs] def print_augmentation_config( self, num_pipelines: int, pipeline_names: list[str], batch_size: int, training_size: int, ) -> None: """Print augmentation configuration.""" if not self.should_print(): return table = Table( title="Augmentation Configuration", show_header=True, box=box.ROUNDED, padding=(0, 2), ) table.add_column("Parameter", width=30) table.add_column("Value") table.add_row("Total augmentation pipelines", str(num_pipelines)) table.add_row("Pipelines", "\n".join(pipeline_names)) table.add_row("Chunks to augment at once", str(batch_size)) table.add_row("Total training samples", str(training_size)) self.console.print(table) self.console.print()
[docs] def print_summary(self, dataset: zarr.Group) -> None: """Print final dataset summary.""" if not self.should_print(): return # Calculate sizes sizes = self._calculate_sizes(dataset) self.console.rule("DATASET CREATION COMPLETED") self.console.print() # Summary table table = Table( title="Dataset Summary", show_header=True, box=box.ROUNDED, padding=(0, 2), width=60, ) table.add_column("Metric", width=30) table.add_column("Value") table.add_row( "Training samples", str( dataset["training/label"].shape[0] if "label" in dataset["training"] else 0 ), ) table.add_row( "Testing samples", str( dataset["testing/label"].shape[0] if "label" in dataset["testing"] else 0 ), ) table.add_row( "Validation samples", str( dataset["validation/label"].shape[0] if "label" in dataset["validation"] else 0 ), ) table.add_row("Total dataset size", f"{sizes['total']:.2f} MB") for split, size_mb in sizes["splits"].items(): table.add_row(f"{split.capitalize()} split size", f"{size_mb:.2f} MB") self.console.print(table) self.console.print() # Structure tree self._print_structure_tree(dataset) self.console.rule("Dataset Creation Successfully Completed!")
[docs] def _calculate_sizes(self, dataset: zarr.Group) -> dict: """Calculate dataset sizes in MB.""" total_bytes = 0 split_sizes = {} for split in ["training", "testing", "validation"]: split_bytes = 0 for group in ["emg", "ground_truth"]: if group in dataset[split]: for k in dataset[split][group]: arr = dataset[split][group][k] item_size = np.dtype(arr.dtype).itemsize arr_size = np.prod(arr.shape) * item_size split_bytes += arr_size total_bytes += arr_size split_sizes[split] = split_bytes / (1024 * 1024) return { "total": total_bytes / (1024 * 1024), "splits": split_sizes, }
[docs] def _print_structure_tree(self, dataset: zarr.Group) -> None: """Print dataset structure as a tree.""" tree = Tree("Dataset Structure") for split in ["training", "testing", "validation"]: if "emg" in dataset[split]: emg_sizes = { k: dataset[f"{split}/emg"][k].shape for k in dataset[f"{split}/emg"] } if emg_sizes: split_branch = tree.add(split.capitalize()) emg_branch = split_branch.add("EMG Representations") for k, shape in emg_sizes.items(): emg_branch.add(f"{k}: {shape}") self.console.print(tree) self.console.print()