Source code for myoverse.datasets.supervised

import pickle
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import zarr
from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
    Progress,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeRemainingColumn,
    SpinnerColumn,
)
from rich.table import Table
from rich.tree import Tree

from myoverse.datasets.filters.emg_augmentations import EMGAugmentation
from myoverse.datasets.filters.generic import ChunkizeDataFilter, FilterBaseClass
from myoverse.datatypes import _Data, DATA_TYPES_MAP


def _split_data(data: np.ndarray, split_ratio: float) -> tuple[np.ndarray, np.ndarray]:
    """Split ``data`` along its first axis by carving a centered block out.

    Returns a tuple ``(remainder, centered_block)`` where the centered block
    holds roughly ``split_ratio`` of the samples, taken symmetrically around
    the middle index.
    """
    split_amount = int(data.shape[0] * split_ratio / 2)
    middle_index = data.shape[0] // 2

    mask = np.ones(data.shape[0], dtype=bool)
    mask[middle_index - split_amount : middle_index + split_amount] = False

    return data[mask], data[~mask]

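# A quick illustration of the split geometry (hypothetical numbers, not part
# of the original module): with 10 samples and split_ratio = 0.4, split_amount
# is int(10 * 0.4 / 2) = 2, so the two samples on either side of the middle
# index are carved out:
#
#     rest, middle = _split_data(np.arange(10)[:, None], 0.4)
#     # rest.shape == (6, 1), middle.shape == (4, 1)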

def _add_to_dataset(group: zarr.Group, data: Optional[np.ndarray], name: str):
    """Add data to a zarr group, handling compatibility with Zarr 3.

    Parameters
    ----------
    group : zarr.Group
        The zarr group to add data to
    data : Optional[np.ndarray]
        The data to add
    name : str
        The name of the dataset
    """
    if data is None or (isinstance(data, np.ndarray) and data.size == 0):
        return

    # Ensure data is a numpy array
    if not isinstance(data, np.ndarray):
        data = np.array(data)

    try:
        if name in group:
            # Don't append empty data
            if data.size == 0:
                return

            # Zarr 3 removed ``append``; grow the array with ``resize`` and
            # assign the new rows via ``__setitem__`` instead
            current_shape = group[name].shape
            new_shape = list(current_shape)
            new_shape[0] += data.shape[0]

            # Resize the dataset
            group[name].resize(new_shape)

            # Insert the new data
            group[name][current_shape[0] :] = data
        else:
            # Don't create empty datasets
            if data.size == 0:
                return

            # Create new dataset with appropriate chunking
            group.create_dataset(
                name, data=data, shape=data.shape, chunks=(1, *data.shape[1:])
            )
    except Exception as e:
        # Handle differences between Zarr 2 and 3
        if "append" in str(e):
            # This is Zarr 2 behavior
            if data.size > 0:  # Only append if there's data
                group[name].append(data)
        else:
            raise

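# A minimal usage sketch for _add_to_dataset (hypothetical store path; the
# helper above is unchanged): the first call creates the array, later calls
# grow it along the first axis.
#
#     root = zarr.open("example.zarr", mode="w")
#     _add_to_dataset(root, np.zeros((4, 8)), "emg")  # creates array of shape (4, 8)
#     _add_to_dataset(root, np.ones((2, 8)), "emg")   # resizes it to (6, 8)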

class EMGDataset:
    """Class for creating a dataset from EMG and ground truth data.

    Parameters
    ----------
    emg_data_path : pathlib.Path
        Path to the EMG data file. It should be a pickle file containing a
        dictionary with the keys being the task number and the values being a
        numpy array of shape (n_channels, n_samples).
    emg_data : dict[str, np.ndarray]
        Optional dictionary containing EMG data if not loading from a file.
    ground_truth_data_path : pathlib.Path
        Path to the ground truth data file. It should be a pickle file
        containing a dictionary with the keys being the task number and the
        values being a numpy array of custom shape (..., n_samples). The custom
        shape can be anything, but the last dimension should be the same as the
        EMG data.
    ground_truth_data : dict[str, np.ndarray]
        Optional dictionary containing ground truth data if not loading from a file.
    ground_truth_data_type : str
        Type of ground truth data, e.g. 'kinematics'.
    sampling_frequency : float
        Sampling frequency of the data in Hz.
    tasks_to_use : Sequence[str]
        Sequence of strings containing the task numbers to use. If empty, all
        tasks will be used.
    save_path : pathlib.Path
        Path to save the dataset to. It should be a zarr file.
    emg_filter_pipeline_before_chunking : list[list[FilterBaseClass]]
        Sequence of filters to apply to the EMG data before chunking.
    emg_representations_to_filter_before_chunking : list[list[str]]
        Representations of EMG data to filter before chunking.
    emg_filter_pipeline_after_chunking : list[list[FilterBaseClass]]
        Sequence of filters to apply to the EMG data after chunking.
    emg_representations_to_filter_after_chunking : list[list[str]]
        Representations of EMG data to filter after chunking.
    ground_truth_filter_pipeline_before_chunking : list[list[FilterBaseClass]]
        Sequence of filters to apply to the ground truth data before chunking.
    ground_truth_representations_to_filter_before_chunking : list[list[str]]
        Representations of ground truth data to filter before chunking.
    ground_truth_filter_pipeline_after_chunking : list[list[FilterBaseClass]]
        Sequence of filters to apply to the ground truth data after chunking.
    ground_truth_representations_to_filter_after_chunking : list[list[str]]
        Representations of ground truth data to filter after chunking.
    chunk_size : int
        Size of the chunks to create from the data.
    chunk_shift : int
        Shift between the chunks.
    testing_split_ratio : float
        Ratio of the data to use for testing. The data will be split in the
        middle. The first half will be used for training and the second half
        will be used for testing. If 0, no data will be used for testing.
    validation_split_ratio : float
        Ratio of the data to use for validation. The data will be split in the
        middle. The first half will be used for training and the second half
        will be used for validation. If 0, no data will be used for validation.
    augmentation_pipelines : list[list[EMGAugmentation]]
        Sequence of augmentation pipelines to apply to the training data.
    amount_of_chunks_to_augment_at_once : int
        Amount of chunks to augment at once. This is done to speed up the process.
    debug_level : int
        Debug level:

        - 0: No debug output (default)
        - 1: Full text debugging with Rich (configuration, progress, tables, data details)
        - 2: Level 1 plus data visualizations (graphs and plots)
    silence_zarr_warnings : bool
        Whether to silence all Zarr-related warnings, including those from
        zarr.codecs and zarr.core modules.

    Methods
    -------
    create_dataset()
        Creates the dataset.
    """
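    # A minimal usage sketch (hypothetical paths and task ids; the "REPLACE ME"
    # defaults below are placeholders that must be overridden):
    #
    #     dataset = EMGDataset(
    #         emg_data_path=Path("recordings/emg.pkl"),
    #         ground_truth_data_path=Path("recordings/kinematics.pkl"),
    #         ground_truth_data_type="kinematics",
    #         sampling_frequency=2048.0,
    #         tasks_to_use=("1", "2", "3"),
    #         save_path=Path("output/dataset.zarr"),
    #     )
    #     dataset.create_dataset()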
    def __init__(
        self,
        emg_data_path: Path = Path("REPLACE ME"),
        emg_data: dict[str, np.ndarray] = {},
        ground_truth_data_path: Path = Path("REPLACE ME"),
        ground_truth_data: dict[str, np.ndarray] = {},
        ground_truth_data_type: str = "kinematics",
        sampling_frequency: float = 0.0,
        tasks_to_use: Sequence[str] = (),
        save_path: Path = Path("REPLACE ME"),
        emg_filter_pipeline_before_chunking: list[list[FilterBaseClass]] = (),
        emg_representations_to_filter_before_chunking: list[list[str]] = (),
        emg_filter_pipeline_after_chunking: list[list[FilterBaseClass]] = (),
        emg_representations_to_filter_after_chunking: list[list[str]] = (),
        ground_truth_filter_pipeline_before_chunking: list[list[FilterBaseClass]] = (),
        ground_truth_representations_to_filter_before_chunking: list[list[str]] = (),
        ground_truth_filter_pipeline_after_chunking: list[list[FilterBaseClass]] = (),
        ground_truth_representations_to_filter_after_chunking: list[list[str]] = (),
        chunk_size: int = 192,
        chunk_shift: int = 64,
        testing_split_ratio: float = 0.2,
        validation_split_ratio: float = 0.2,
        augmentation_pipelines: list[list[EMGAugmentation]] = (),
        amount_of_chunks_to_augment_at_once: int = 250,
        debug_level: int = 0,
        silence_zarr_warnings: bool = True,
    ):
        self.emg_data_path = emg_data_path
        self.emg_data = emg_data
        self.ground_truth_data_path = ground_truth_data_path
        self.ground_truth_data = ground_truth_data

        # Check that at least one source is provided for each kind of data
        if not self.emg_data and not self.emg_data_path:
            raise ValueError("At least one of the EMG data sources should be provided.")
        if not self.ground_truth_data and not self.ground_truth_data_path:
            raise ValueError(
                "At least one of the ground truth data sources should be provided."
            )

        self.ground_truth_data_type = ground_truth_data_type
        self.sampling_frequency = sampling_frequency
        self.tasks_to_use = tasks_to_use
        self.save_path = save_path

        self.emg_filter_pipeline_before_chunking = emg_filter_pipeline_before_chunking
        self.emg_representations_to_filter_before_chunking = (
            emg_representations_to_filter_before_chunking
        )
        self.ground_truth_filter_pipeline_before_chunking = (
            ground_truth_filter_pipeline_before_chunking
        )
        self.ground_truth_representations_to_filter_before_chunking = (
            ground_truth_representations_to_filter_before_chunking
        )

        self.emg_filter_pipeline_after_chunking = emg_filter_pipeline_after_chunking
        self.emg_representations_to_filter_after_chunking = (
            emg_representations_to_filter_after_chunking
        )
        self.ground_truth_filter_pipeline_after_chunking = (
            ground_truth_filter_pipeline_after_chunking
        )
        self.ground_truth_representations_to_filter_after_chunking = (
            ground_truth_representations_to_filter_after_chunking
        )

        self.chunk_size = chunk_size
        self.chunk_shift = chunk_shift

        self.testing_split_ratio = testing_split_ratio
        self.validation_split_ratio = validation_split_ratio

        self.augmentation_pipelines = augmentation_pipelines
        self.amount_of_chunks_to_augment_at_once = amount_of_chunks_to_augment_at_once

        self.debug_level = debug_level
        self.silence_zarr_warnings = silence_zarr_warnings

        # Initialize Rich console for all debug levels
        self.console = Console()

        self._tasks_string_length = 0
    def __add_data_to_dataset(
        self, data: _Data, groups: list[zarr.Group]
    ) -> Tuple[list[int], list[int], list[int]]:
        """Add data to zarr dataset groups.

        Parameters
        ----------
        data : _Data
            The data object to add to the dataset
        groups : list[zarr.Group]
            List of zarr groups for training, testing, and validation

        Returns
        -------
        tuple[list[int], list[int], list[int]]
            Lists of sizes for training, testing, and validation datasets
        """
        training_data_sizes, testing_data_sizes, validation_data_sizes = [], [], []

        if self.debug_level >= 1:
            self.console.print(
                f"[bold green]Adding data with keys:[/bold green] [cyan]{list(data.output_representations.keys())}[/cyan]"
            )
            self.console.print()  # Add empty line

        for k, v in data.output_representations.items():
            validation_data_from_task = None

            if self.debug_level >= 1:
                self.console.print(
                    f"[bold]Splitting representation:[/bold] [yellow]{k}[/yellow] [dim]with shape {v.shape}[/dim]"
                )

            if self.testing_split_ratio > 0:
                training_data_from_task, testing_data_from_task = _split_data(
                    v, self.testing_split_ratio
                )

                if self.debug_level >= 1:
                    self.console.print(
                        f"  [green]Training shape:[/green] {training_data_from_task.shape}"
                    )
                    self.console.print(
                        f"  [yellow]Testing shape:[/yellow] {testing_data_from_task.shape}"
                    )

                if self.validation_split_ratio > 0:
                    testing_data_from_task, validation_data_from_task = _split_data(
                        testing_data_from_task, self.validation_split_ratio
                    )

                    if self.debug_level >= 1:
                        self.console.print(
                            f"  [yellow]After validation split - Testing shape:[/yellow] {testing_data_from_task.shape}"
                        )
                        self.console.print(
                            f"  [blue]Validation shape:[/blue] {validation_data_from_task.shape}"
                        )
            else:
                training_data_from_task = v
                testing_data_from_task = None

                if self.debug_level >= 1:
                    self.console.print(
                        f"  [green]No testing split, all data for training:[/green] {training_data_from_task.shape}"
                    )

            # Add a space between different splits
            if self.debug_level >= 1:
                self.console.print()  # Add empty line

            for g, data_from_task in zip(
                groups,
                (
                    training_data_from_task,
                    testing_data_from_task,
                    validation_data_from_task,
                ),
            ):
                _add_to_dataset(g, data_from_task, k)

            training_data_sizes.append(training_data_from_task.shape[0])
            testing_data_sizes.append(
                testing_data_from_task.shape[0]
                if testing_data_from_task is not None
                else 0
            )
            validation_data_sizes.append(
                validation_data_from_task.shape[0]
                if validation_data_from_task is not None
                else 0
            )

        if self.debug_level >= 1:
            # Create a table for the dataset sizes
            sizes_table = Table(
                title="Dataset Split Sizes",
                show_header=True,
                header_style="bold magenta",
                box=box.ROUNDED,
                padding=(0, 2),
                width=40,
            )
            sizes_table.add_column("Split", style="cyan")
            sizes_table.add_column("Sizes", style="green")

            sizes_table.add_row("Training", str(training_data_sizes))
            sizes_table.add_row("Testing", str(testing_data_sizes))
            sizes_table.add_row("Validation", str(validation_data_sizes))

            self.console.print(sizes_table)
            self.console.print()  # Add empty line

        return training_data_sizes, testing_data_sizes, validation_data_sizes
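    # Worked example of the nested split above (hypothetical numbers): with
    # 1000 chunks and testing_split_ratio = 0.2, _split_data carves the middle
    # 200 chunks out of training; validation_split_ratio = 0.2 then carves the
    # middle 40 chunks out of that testing block, leaving 800 / 160 / 40 chunks
    # for training / testing / validation.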
    def create_dataset(self):
        """Create a supervised dataset from EMG and ground truth data."""
        # Silence zarr warnings if requested
        if self.silence_zarr_warnings:
            import warnings

            # Silence warnings from zarr.codecs
            warnings.filterwarnings(
                "ignore", category=UserWarning, module="zarr.codecs"
            )
            # Silence warnings from zarr core
            warnings.filterwarnings("ignore", category=UserWarning, module="zarr.core")
            # Silence any other zarr-related warnings
            warnings.filterwarnings("ignore", category=UserWarning, module="zarr")

        # Display configuration when debugging is enabled
        if self.debug_level > 0:
            # Create header for dataset creation
            self.console.rule(
                "[bold blue]STARTING DATASET CREATION", style="blue double"
            )
            self.console.print()  # Add empty line

            # Create a table for configuration
            table = Table(
                title="Dataset Configuration",
                show_header=True,
                header_style="bold magenta",
                box=box.ROUNDED,
                padding=(0, 2),
            )
            table.add_column("Parameter", style="dim", width=30)
            table.add_column("Value", style="green")

            # Add configuration parameters to table
            table.add_row("EMG data path", str(self.emg_data_path))
            table.add_row("Ground truth data path", str(self.ground_truth_data_path))
            table.add_row("Ground truth data type", self.ground_truth_data_type)
            table.add_row("Sampling frequency (Hz)", str(self.sampling_frequency))
            table.add_row("Save path", str(self.save_path))
            table.add_row("Chunk size", str(self.chunk_size))
            table.add_row("Chunk shift", str(self.chunk_shift))
            table.add_row("Testing split ratio", str(self.testing_split_ratio))
            table.add_row("Validation split ratio", str(self.validation_split_ratio))
            table.add_row(
                "Amount of chunks to augment at once",
                str(self.amount_of_chunks_to_augment_at_once),
            )
            table.add_row("Debug level", str(self.debug_level))
            table.add_row("Silence Zarr warnings", str(self.silence_zarr_warnings))

            # Display the table
            self.console.print(table)
            self.console.print()  # Add empty line

        # Load data if not provided
        emg_data = self.emg_data or pickle.load(self.emg_data_path.open("rb"))
        ground_truth_data = self.ground_truth_data or pickle.load(
            self.ground_truth_data_path.open("rb")
        )

        # Use tasks_to_use if provided, otherwise use all tasks
        if not self.tasks_to_use:
            self.tasks_to_use = list(emg_data.keys())

        if self.debug_level > 0:
            self.console.print(
                f"[bold cyan]Processing {len(self.tasks_to_use)} tasks:[/bold cyan] {', '.join(self.tasks_to_use)}"
            )
            self.console.print()  # Add empty line

            # Create a tree for data shapes
            data_tree = Tree("[bold yellow]Dataset Structure")

            emg_branch = data_tree.add("[bold green]EMG Data")
            for i, (k, v) in enumerate(list(emg_data.items())[:5]):
                emg_branch.add(f"Task {k}: Shape {v.shape}")
            if len(emg_data) > 5:
                emg_branch.add(f"... {len(emg_data) - 5} more tasks")

            gt_branch = data_tree.add("[bold green]Ground Truth Data")
            for i, (k, v) in enumerate(list(ground_truth_data.items())[:5]):
                gt_branch.add(f"Task {k}: Shape {v.shape}")
            if len(ground_truth_data) > 5:
                gt_branch.add(f"... {len(ground_truth_data) - 5} more tasks")

            self.console.print(data_tree)
            self.console.print()  # Add empty line

        # Create zarr directory and open store
        self.save_path.mkdir(parents=True, exist_ok=True)

        # Open zarr dataset with specified format
        dataset = zarr.open(str(self.save_path), mode="w", zarr_version=2)

        # Create groups for training, testing, and validation
        training_group = dataset.create_group("training")
        testing_group = dataset.create_group("testing")
        validation_group = dataset.create_group("validation")

        # Set the task string length for labels
        self._tasks_string_length = len(max(self.tasks_to_use, key=len))

        # Process each task
        if self.debug_level > 0:
            self.console.rule("[bold blue]PROCESSING TASKS", style="blue double")
            self.console.print()  # Add empty line

        # Create progress bar for task processing regardless of debug level
        with Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}"),
            BarColumn(bar_width=40),
            TaskProgressColumn(),
            TextColumn("•"),
            TimeRemainingColumn(),
            console=self.console,
            expand=True,
            transient=False,
        ) as progress:
            task_progress = progress.add_task(
                f"[bold]Processing [cyan]{len(self.tasks_to_use)}[/cyan] tasks...",
                total=len(self.tasks_to_use),
            )

            # Process each task with progress tracking
            for task_idx, task in enumerate(self.tasks_to_use):
                # Update the progress bar with current task information
                progress.update(
                    task_progress,
                    description=f"[bold]Processing task [cyan]{task}[/cyan] ([cyan]{task_idx + 1}[/cyan]/[cyan]{len(self.tasks_to_use)}[/cyan])",
                )

                # Process the task
                self._process_task(
                    task,
                    emg_data,
                    ground_truth_data,
                    training_group,
                    testing_group,
                    validation_group,
                )

                # Advance the progress bar
                progress.advance(task_progress)

        # Apply data augmentation if requested
        self._apply_augmentations(dataset, training_group)

        # Print dataset summary at the end
        if self.debug_level > 0:
            self._print_dataset_summary(dataset)
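    # Resulting store layout (as produced by the code above; array names under
    # emg/ and ground_truth/ depend on the configured representations):
    #
    #     <save_path>/
    #         training/    emg/<representation>, ground_truth/<representation>,
    #                      label, class, one_hot_class
    #         testing/     (same layout)
    #         validation/  (same layout)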
    def _process_task(
        self,
        task: str,
        emg_data: dict[str, np.ndarray],
        ground_truth_data: dict[str, np.ndarray],
        training_group: zarr.Group,
        testing_group: zarr.Group,
        validation_group: zarr.Group,
    ):
        """Process a single task and add its data to the dataset."""
        emg_data_from_task = emg_data[task]
        ground_truth_data_from_task = ground_truth_data[task]

        # Trim data to same length if needed
        min_length = min(
            emg_data_from_task.shape[-1], ground_truth_data_from_task.shape[-1]
        )
        emg_data_from_task = emg_data_from_task[..., :min_length]
        ground_truth_data_from_task = ground_truth_data_from_task[..., :min_length]

        # Create appropriate data objects
        emg_data_from_task = DATA_TYPES_MAP["emg"](
            input_data=emg_data_from_task,
            sampling_frequency=self.sampling_frequency,
        )
        ground_truth_data_from_task = DATA_TYPES_MAP[self.ground_truth_data_type](
            input_data=ground_truth_data_from_task,
            sampling_frequency=self.sampling_frequency,
        )

        # Verify chunking status is the same
        emg_chunked = emg_data_from_task.is_chunked["Input"]
        ground_truth_chunked = ground_truth_data_from_task.is_chunked["Input"]
        if emg_chunked != ground_truth_chunked:
            raise ValueError(
                f"The EMG and ground truth data should have the same chunking status. "
                f"EMG data is {'chunked' if emg_chunked else 'not chunked'} and "
                f"ground truth data is {'chunked' if ground_truth_chunked else 'not chunked'}."
            )

        # Debug output based on debug level
        if self.debug_level >= 1:
            self.console.print(
                "[bold white on blue] Initial Data [/bold white on blue]",
                justify="center",
            )
            self.console.print()  # Add empty line

            emg_panel = Panel.fit(
                str(emg_data_from_task),
                title=f"[bold green]EMG Data Task {task}[/bold green]",
                border_style="green",
                box=box.ROUNDED,
                padding=(0, 2),
            )
            gt_panel = Panel.fit(
                str(ground_truth_data_from_task),
                title=f"[bold blue]Ground Truth Data Task {task}[/bold blue]",
                border_style="blue",
                box=box.ROUNDED,
                padding=(0, 2),
            )
            self.console.print(emg_panel)
            self.console.print(gt_panel)
            self.console.print()  # Add empty line

            # Plot graphs only for debug level 2
            if self.debug_level >= 2:
                self.console.print(
                    "[bold yellow]Generating EMG data graph...[/bold yellow]"
                )
                emg_data_from_task.plot_graph(
                    title=f"EMG Data - Task: {task} (Pre-Processing)"
                )
                self.console.print(
                    "[bold yellow]Generating ground truth data graph...[/bold yellow]"
                )
                ground_truth_data_from_task.plot_graph(
                    title=f"Ground Truth Data - Task: {task} (Pre-Processing)"
                )
                self.console.print()  # Add empty line

        # Process unchunked data
        if not emg_data_from_task.is_chunked["Input"]:
            if self.debug_level >= 1:
                self.console.print(
                    "[bold white on magenta] Pre-Chunking Processing [/bold white on magenta]",
                    justify="center",
                )
                self.console.print()  # Add empty line

            # Apply filters before chunking
            if self.emg_filter_pipeline_before_chunking:
                if self.debug_level >= 1:
                    self.console.print(
                        "▶ [bold cyan]Applying EMG filters before chunking...[/bold cyan]"
                    )
                emg_data_from_task.apply_filter_pipeline(
                    filter_pipeline=self.emg_filter_pipeline_before_chunking,
                    representations_to_filter=self.emg_representations_to_filter_before_chunking,
                )

            if self.ground_truth_filter_pipeline_before_chunking:
                if self.debug_level >= 1:
                    self.console.print(
                        "▶ [bold cyan]Applying ground truth filters before chunking...[/bold cyan]"
                    )
                ground_truth_data_from_task.apply_filter_pipeline(
                    filter_pipeline=self.ground_truth_filter_pipeline_before_chunking,
                    representations_to_filter=self.ground_truth_representations_to_filter_before_chunking,
                )

            if self.debug_level >= 1:
                self.console.print()  # Add empty line
                self.console.print(
                    "[bold white on green] Chunking Process [/bold white on green]",
                    justify="center",
                )
                self.console.print()  # Add empty line

            # Apply chunking filters
            if self.debug_level >= 1:
                self.console.print("▶ [bold cyan]Chunking EMG data...[/bold cyan]")
            emg_data_from_task.apply_filter(
                filter=ChunkizeDataFilter(
                    chunk_size=self.chunk_size,
                    chunk_shift=self.chunk_shift,
                    is_output=len(self.emg_filter_pipeline_after_chunking) == 0,
                    name="EMG_Chunkizer",
                    input_is_chunked=False,
                ),
                representations_to_filter=["Last"],
            )
            chunked_emg_data_from_task = emg_data_from_task

            if self.debug_level >= 1:
                self.console.print(
                    "▶ [bold cyan]Chunking ground truth data...[/bold cyan]"
                )
            ground_truth_data_from_task.apply_filter(
                filter=ChunkizeDataFilter(
                    chunk_size=self.chunk_size,
                    chunk_shift=self.chunk_shift,
                    is_output=len(self.ground_truth_filter_pipeline_after_chunking)
                    == 0,
                    input_is_chunked=False,
                ),
                representations_to_filter=["Last"],
            )
            chunked_ground_truth_data_from_task = ground_truth_data_from_task

            # Debug output for chunking
            if self.debug_level >= 1:
                self.console.print()  # Add empty line
                self.console.rule("[bold green]After Chunking", style="green")
                self.console.print()  # Add empty line

                emg_panel = Panel.fit(
                    str(chunked_emg_data_from_task),
                    title="[bold green]Chunked EMG Data[/bold green]",
                    border_style="green",
                    box=box.ROUNDED,
                    padding=(0, 2),
                )
                gt_panel = Panel.fit(
                    str(chunked_ground_truth_data_from_task),
                    title="[bold blue]Chunked Ground Truth Data[/bold blue]",
                    border_style="blue",
                    box=box.ROUNDED,
                    padding=(0, 2),
                )
                self.console.print(emg_panel)
                self.console.print(gt_panel)
                self.console.print()  # Add empty line

                # Plot graphs only for debug level 2
                if self.debug_level >= 2:
                    self.console.print(
                        "[bold yellow]Generating chunked EMG data graph...[/bold yellow]"
                    )
                    chunked_emg_data_from_task.plot_graph(
                        title=f"Chunked EMG Data - Task: {task}"
                    )
                    self.console.print(
                        "[bold yellow]Generating chunked ground truth data graph...[/bold yellow]"
                    )
                    chunked_ground_truth_data_from_task.plot_graph(
                        title=f"Chunked Ground Truth Data - Task: {task}"
                    )
                    self.console.print()  # Add empty line
        else:
            # Data is already chunked
            chunked_emg_data_from_task = emg_data_from_task

            # Process in batches to avoid memory issues
            i = 0
            temp = []
            while (
                i + self.amount_of_chunks_to_augment_at_once
                <= chunked_emg_data_from_task.shape[0]
            ):
                temp.append(
                    np.concatenate(
                        chunked_emg_data_from_task[
                            i : i + self.amount_of_chunks_to_augment_at_once
                        ],
                        axis=-1,
                    )
                )
                i += self.amount_of_chunks_to_augment_at_once
            chunked_emg_data_from_task = np.stack(temp, axis=1)

            chunked_ground_truth_data_from_task = ground_truth_data_from_task

        # Post-chunking processing section
        if self.debug_level >= 1:
            self.console.print(
                "[bold white on magenta] Post-Chunking Processing [/bold white on magenta]",
                justify="center",
            )
            self.console.print()  # Add empty line

        # Apply filters after chunking
        if self.emg_filter_pipeline_after_chunking:
            if self.debug_level >= 1:
                self.console.print(
                    "▶ [bold cyan]Applying EMG filters after chunking...[/bold cyan]"
                )
            chunked_emg_data_from_task.apply_filter_pipeline(
                filter_pipeline=self.emg_filter_pipeline_after_chunking,
                representations_to_filter=self.emg_representations_to_filter_after_chunking,
            )

        if self.ground_truth_filter_pipeline_after_chunking:
            if self.debug_level >= 1:
                self.console.print(
                    "▶ [bold cyan]Applying ground truth filters after chunking...[/bold cyan]"
                )
            chunked_ground_truth_data_from_task.apply_filter_pipeline(
                filter_pipeline=self.ground_truth_filter_pipeline_after_chunking,
                representations_to_filter=self.ground_truth_representations_to_filter_after_chunking,
            )

        # Debug output after filtering chunked data
        if self.debug_level >= 1:
            self.console.print()  # Add empty line
            self.console.rule("[bold green]After Filtering Chunked Data", style="green")
            self.console.print()  # Add empty line

            emg_panel = Panel.fit(
                str(chunked_emg_data_from_task),
                title="[bold green]Filtered Chunked EMG Data[/bold green]",
                border_style="green",
                box=box.ROUNDED,
                padding=(0, 2),
            )
            gt_panel = Panel.fit(
                str(chunked_ground_truth_data_from_task),
                title="[bold blue]Filtered Chunked Ground Truth Data[/bold blue]",
                border_style="blue",
                box=box.ROUNDED,
                padding=(0, 2),
            )
            self.console.print(emg_panel)
            self.console.print(gt_panel)
            self.console.print()  # Add empty line

            # Plot graphs only for debug level 2
            if self.debug_level >= 2:
                self.console.print(
                    "[bold yellow]Generating filtered chunked EMG data graph...[/bold yellow]"
                )
                chunked_emg_data_from_task.plot_graph(
                    title=f"Filtered Chunked EMG Data - Task: {task}"
                )
                self.console.print(
                    "[bold yellow]Generating filtered chunked ground truth data graph...[/bold yellow]"
                )
                chunked_ground_truth_data_from_task.plot_graph(
                    title=f"Filtered Chunked Ground Truth Data - Task: {task}"
                )
                self.console.print()  # Add empty line

        # Dataset creation section
        if self.debug_level >= 1:
            self.console.print(
                "[bold white on blue] Dataset Creation [/bold white on blue]",
                justify="center",
            )
            self.console.print(
                "▶ [bold cyan]Adding processed data to dataset...[/bold cyan]"
            )
            self.console.print()  # Add empty line

        for group_name, chunked_data_from_task in zip(
            ["emg", "ground_truth"],
            [chunked_emg_data_from_task, chunked_ground_truth_data_from_task],
        ):
            (
                training_sizes,
                testing_sizes,
                validation_sizes,
            ) = self.__add_data_to_dataset(
                chunked_data_from_task,
                [
                    (
                        g.create_group(group_name)
                        if group_name not in list(g.group_keys())
                        else g[group_name]
                    )
                    for g in (training_group, testing_group, validation_group)
                ],
            )

        # Verify data lengths match
        data_length = list(chunked_emg_data_from_task.output_representations.values())[
            -1
        ].shape[0]
        data_length_ground_truth = list(
            chunked_ground_truth_data_from_task.output_representations.values()
        )[-1].shape[0]

        # Validate sizes
        assert len(set(training_sizes)) == 1, "The training sizes are not the same."
        assert len(set(testing_sizes)) == 1, "The testing sizes are not the same."
        assert len(set(validation_sizes)) == 1, "The validation sizes are not the same."
        assert data_length == data_length_ground_truth, (
            f"The data lengths of the EMG and ground truth data should be the same. "
            f"For task {task}, the EMG data has length {data_length} and the ground "
            f"truth data has length {data_length_ground_truth}."
        )

        # Add labels, class indices, and one-hot encodings
        for g, size in zip(
            (training_group, testing_group, validation_group),
            (training_sizes[0], testing_sizes[0], validation_sizes[0]),
        ):
            # Use consistent unicode string array approach since conversion
            # happens in _add_to_dataset as needed
            label_array = np.array(
                [task] * size, dtype=f"<U{self._tasks_string_length}"
            ).reshape(-1, 1)

            _add_to_dataset(
                g,
                label_array,
                "label",
            )
            _add_to_dataset(
                g,
                np.array(
                    [self.tasks_to_use.index(task)] * size, dtype=np.int8
                ).reshape(-1, 1),
                "class",
            )
            _add_to_dataset(
                g,
                np.repeat(
                    np.array(
                        [
                            np.eye(len(self.tasks_to_use), dtype=np.int8)[
                                self.tasks_to_use.index(task)
                            ]
                        ]
                    ),
                    size,
                    axis=0,
                ),
                "one_hot_class",
            )
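    # Worked example of the label encodings built above (hypothetical tasks):
    # with tasks_to_use = ["1", "2", "3"] and task "2" (index 1),
    # np.eye(3, dtype=np.int8)[1] == [0, 1, 0]; np.repeat(..., size, axis=0)
    # tiles that row into a (size, 3) "one_hot_class" array, while "class"
    # stores the bare index 1 once per sample.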
    def _apply_augmentations(self, dataset: zarr.Group, training_group: zarr.Group):
        """Apply augmentations to the training data."""
        # Start augmentation phase if there are augmentation pipelines
        if self.augmentation_pipelines and len(self.augmentation_pipelines) > 0:
            # Get all available samples in training group
            # Use the first available filter key instead of hardcoding "raw"
            filter_keys = list(training_group["emg"].array_keys())
            if not filter_keys:
                self.console.print(
                    "[bold red]No EMG filters found in training group![/bold red]"
                )
                return

            training_size = training_group["emg"][filter_keys[0]].shape[0]

            if self.debug_level > 0:
                self.console.rule(
                    "[bold blue]APPLYING AUGMENTATIONS", style="blue double"
                )
                self.console.print()  # Add empty line

                # Display augmentation info
                augmentation_info = Table(
                    title="Augmentation Configuration",
                    show_header=True,
                    header_style="bold magenta",
                    box=box.ROUNDED,
                    padding=(0, 2),
                )
                augmentation_info.add_column("Parameter", style="dim", width=30)
                augmentation_info.add_column("Value", style="green")

                augmentation_info.add_row(
                    "Total augmentation pipelines",
                    str(len(self.augmentation_pipelines)),
                )

                pipeline_names = []
                for pipeline in self.augmentation_pipelines:
                    names = [f.name for f in pipeline]
                    pipeline_names.append(" → ".join(names))
                augmentation_info.add_row("Pipelines", "\n".join(pipeline_names))

                augmentation_info.add_row(
                    "Chunks to augment at once",
                    str(self.amount_of_chunks_to_augment_at_once),
                )
                augmentation_info.add_row("Total training samples", str(training_size))

                self.console.print(augmentation_info)
                self.console.print()  # Add empty line

            # Create progress bar for augmentation regardless of debug level
            with Progress(
                SpinnerColumn(),
                TextColumn("[bold green]{task.description}"),
                BarColumn(bar_width=40),
                TaskProgressColumn(),
                TextColumn("•"),
                TimeRemainingColumn(),
                console=self.console,
                expand=True,
                transient=False,
            ) as progress:
                # Calculate total batches for all augmentations
                total_batches = int(
                    np.ceil(training_size / self.amount_of_chunks_to_augment_at_once)
                ) * len(self.augmentation_pipelines)

                # Create main progress task
                aug_progress = progress.add_task(
                    f"[bold]Applying [green]{len(self.augmentation_pipelines)}[/green] augmentation pipelines...",
                    total=total_batches,
                )

                # Process each augmentation pipeline
                for aug_idx, augmentation_pipeline in enumerate(
                    self.augmentation_pipelines
                ):
                    pipeline_name = " → ".join(
                        [f.name for f in augmentation_pipeline]
                    )

                    # Update progress description for this pipeline
                    progress.update(
                        aug_progress,
                        description=f"[bold]Pipeline [green]{aug_idx + 1}/{len(self.augmentation_pipelines)}[/green]: {pipeline_name}",
                    )

                    # Apply augmentation in batches
                    self._apply_augmentation_pipeline(
                        aug_idx,
                        augmentation_pipeline,
                        dataset,
                        training_group,
                        progress,
                        aug_progress,
                    )
    def _apply_augmentation_pipeline(
        self,
        aug_idx: int,
        augmentation_pipeline: list[EMGAugmentation],
        dataset: zarr.Group,
        training_group: zarr.Group,
        progress: Optional[Progress] = None,
        progress_task_id: Optional[int] = None,
    ):
        """Apply a single augmentation pipeline to training data in batches."""
        # Get total samples to process
        # Use the first available filter key instead of hardcoding "raw"
        filter_keys = list(training_group["emg"].array_keys())
        if not filter_keys:
            self.console.print(
                "[bold red]No EMG filters found in training group![/bold red]"
            )
            return

        training_size = training_group["emg"][filter_keys[0]].shape[0]

        # Process in batches
        for start_idx in range(
            0, training_size, self.amount_of_chunks_to_augment_at_once
        ):
            # Calculate end index for current batch
            end_idx = min(
                start_idx + self.amount_of_chunks_to_augment_at_once, training_size
            )

            # Update progress information
            if progress is not None and progress_task_id is not None:
                batch_num = start_idx // self.amount_of_chunks_to_augment_at_once + 1
                total_batches = int(
                    np.ceil(training_size / self.amount_of_chunks_to_augment_at_once)
                )
                progress.update(
                    progress_task_id,
                    description=f"[bold]Pipeline [green]{aug_idx + 1}/{len(self.augmentation_pipelines)}[/green]: Batch [green]{batch_num}/{total_batches}[/green]",
                )

            # Initialize batch accumulators
            emg_to_append = {k: [] for k in dataset["training/emg"]}
            ground_truth_to_append = {k: [] for k in dataset["training/ground_truth"]}
            label_to_append = []
            class_to_append = []
            one_hot_class_to_append = []

            # Process each item in the batch
            for i in range(start_idx, end_idx):
                # Apply augmentation to EMG data
                for k in dataset["training/emg"]:
                    temp = DATA_TYPES_MAP["emg"](
                        input_data=dataset["training/emg"][k][i].astype(np.float32),
                        sampling_frequency=self.sampling_frequency,
                    )
                    temp.apply_filter_pipeline(
                        filter_pipeline=[augmentation_pipeline],
                        representations_to_filter=[["Last"]],
                    )
                    emg_to_append[k].append(temp["Last"])

                # Copy corresponding ground truth data
                for k in dataset["training/ground_truth"]:
                    ground_truth_to_append[k].append(
                        dataset["training/ground_truth"][k][i]
                    )

                # Copy labels and classes
                label_to_append.append(dataset["training/label"][i])
                class_to_append.append(dataset["training/class"][i])
                one_hot_class_to_append.append(dataset["training/one_hot_class"][i])

            # Append the batch to the training group
            self._append_augmented_batch(
                training_group,
                emg_to_append,
                ground_truth_to_append,
                label_to_append,
                class_to_append,
                one_hot_class_to_append,
            )

            # Advance progress if tracking
            if progress is not None and progress_task_id is not None:
                progress.advance(progress_task_id)
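    # Flow sketch for one augmentation batch (illustrative summary of the code
    # above): each stored EMG chunk is re-wrapped as an EMG data object, run
    # through the pipeline's "Last" representation, and the augmented copies
    # are appended after the original training samples together with duplicated
    # ground truth, label, class, and one_hot_class rows.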
    def _append_augmented_batch(
        self,
        training_group: zarr.Group,
        emg_to_append: Dict[str, List[np.ndarray]],
        ground_truth_to_append: Dict[str, List[np.ndarray]],
        label_to_append: List[np.ndarray],
        class_to_append: List[np.ndarray],
        one_hot_class_to_append: List[np.ndarray],
    ):
        """Append a batch of augmented data to the training group."""
        # Debug shapes before appending
        if self.debug_level >= 2:  # Only show shapes in higher debug level
            # Create a table for shapes
            shapes_table = Table(
                title="Augmented Batch Shapes",
                show_header=True,
                header_style="bold magenta",
                box=box.ROUNDED,
                padding=(0, 2),
                width=80,
                title_style="bold magenta",
                title_justify="center",
            )
            shapes_table.add_column("Data Type", style="cyan", width=50)
            shapes_table.add_column("Shape", style="green")

            # Add EMG shapes
            for k, v in emg_to_append.items():
                if v:
                    shapes_table.add_row(f"EMG {k}", str(np.array(v).shape))
                    break

            # Add ground truth shapes
            for k, v in ground_truth_to_append.items():
                if v:
                    shapes_table.add_row(f"Ground Truth {k}", str(np.array(v).shape))
                    break

            # Add label shape
            if label_to_append:
                shapes_table.add_row("Labels", str(np.array(label_to_append).shape))

            self.console.print(shapes_table)

        # Add EMG data
        for k, v in emg_to_append.items():
            if v:  # Check that list is not empty
                _add_to_dataset(training_group["emg"], np.array(v), name=k)

        # Add ground truth data
        for k, v in ground_truth_to_append.items():
            if v:  # Check that list is not empty
                _add_to_dataset(training_group["ground_truth"], np.array(v), name=k)

        # Add labels and classes
        if label_to_append:
            _add_to_dataset(training_group, np.array(label_to_append), name="label")
        if class_to_append:
            _add_to_dataset(training_group, np.array(class_to_append), name="class")
            _add_to_dataset(
                training_group,
                np.array(one_hot_class_to_append),
                name="one_hot_class",
            )
    def _print_dataset_summary(self, dataset: zarr.Group):
        """Print a summary of the created dataset."""
        # Get dataset sizes
        training_emg_sizes = {
            k: dataset["training/emg"][k].shape for k in dataset["training/emg"]
        }
        testing_emg_sizes = (
            {k: dataset["testing/emg"][k].shape for k in dataset["testing/emg"]}
            if "emg" in dataset["testing"]
            else {}
        )
        validation_emg_sizes = (
            {k: dataset["validation/emg"][k].shape for k in dataset["validation/emg"]}
            if "emg" in dataset["validation"]
            else {}
        )

        # Calculate memory usage
        total_size_bytes = 0
        split_sizes = {}

        for split in ["training", "testing", "validation"]:
            split_size_bytes = 0
            for group in ["emg", "ground_truth"]:
                if group in dataset[split]:
                    for k in dataset[split][group]:
                        arr = dataset[split][group][k]
                        item_size = np.dtype(arr.dtype).itemsize
                        arr_size = np.prod(arr.shape) * item_size
                        split_size_bytes += arr_size
                        total_size_bytes += arr_size
            split_sizes[split] = split_size_bytes / (1024 * 1024)  # Convert to MB

        # Total size in MB
        total_size_mb = total_size_bytes / (1024 * 1024)

        # Create a visually appealing summary with Rich
        self.console.rule("[bold blue]DATASET CREATION COMPLETED", style="blue double")
        self.console.print()  # Add empty line

        # Summary table
        summary_table = Table(
            title="Dataset Summary",
            show_header=True,
            header_style="bold magenta",
            box=box.ROUNDED,
            padding=(0, 2),
            width=60,
            title_style="bold magenta",
            title_justify="center",
        )
        summary_table.add_column("Metric", style="dim", width=30)
        summary_table.add_column("Value", style="green")

        # Add summary metrics
        summary_table.add_row("Total tasks", str(len(self.tasks_to_use)))
        summary_table.add_row(
            "Training samples",
            str(
                dataset["training/label"].shape[0]
                if "label" in dataset["training"]
                else 0
            ),
        )
        summary_table.add_row(
            "Testing samples",
            str(
                dataset["testing/label"].shape[0]
                if "label" in dataset["testing"]
                else 0
            ),
        )
        summary_table.add_row(
            "Validation samples",
            str(
                dataset["validation/label"].shape[0]
                if "label" in dataset["validation"]
                else 0
            ),
        )
        summary_table.add_row("Total dataset size", f"{total_size_mb:.2f} MB")

        # Add split sizes
        for split, size_mb in split_sizes.items():
            summary_table.add_row(
                f"{split.capitalize()} split size", f"{size_mb:.2f} MB"
            )

        self.console.print(summary_table)
        self.console.print()  # Add empty line

        # Dataset structure tree
        structure_tree = Tree("[bold yellow]Dataset Structure")
        for split, sizes in [
            ("Training", training_emg_sizes),
            ("Testing", testing_emg_sizes),
            ("Validation", validation_emg_sizes),
        ]:
            if sizes:
                split_branch = structure_tree.add(f"[bold cyan]{split}")
                # EMG representations
                emg_branch = split_branch.add("[bold green]EMG Representations")
                for k, shape in sizes.items():
                    emg_branch.add(f"{k}: {shape}")

        self.console.print(structure_tree)
        self.console.print()  # Add empty line

        self.console.rule(
            "[bold green]Dataset Creation Successfully Completed!", style="green double"
        )
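    # Size arithmetic used above (illustrative numbers): a float32 array of
    # shape (800, 192, 64) occupies 800 * 192 * 64 * 4 bytes ≈ 37.50 MB, i.e.
    # np.prod(shape) * itemsize / 1024**2.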