Source code for myoverse.datatypes

import copy
import inspect
import os
import pickle
from abc import abstractmethod
from typing import (
    Dict,
    Optional,
    TypedDict,
    Any,
    Union,
    List,
    Tuple,
    NamedTuple,
    Final,
)

import mplcursors
import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.widgets import Slider

from myoverse.datasets.filters._template import FilterBaseClass


class DeletedRepresentation(NamedTuple):
    """Class to hold metadata about deleted representations.

    This stores the shape and dtype of the deleted array, making it compatible with the numpy array interface.

    Attributes
    ----------
    shape : tuple
        The shape of the deleted array
    dtype : np.dtype
        The data type of the deleted array
    """

    shape: tuple
    dtype: np.dtype

    def __str__(self) -> str:
        """String representation of the deleted data."""
        return str(self.shape)


Representation = TypedDict(
    "Representation",
    {"data": np.ndarray, "filter_sequence": List[FilterBaseClass]},
)

InputRepresentationName: Final[str] = "Input"
OutputRepresentationName: Final[str] = "Output"
LastRepresentationName: Final[str] = "Last"
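
# Illustrative note (not executed): these reserved names are the keys used when
# indexing a data object. For an EMGData instance ``emg``:
#
#   emg["Input"]  # the raw array passed to the constructor
#   emg["Last"]   # resolves to the most recently computed representation
#
# "Output" is only a node in the processing graph; the representations feeding
# it are exposed through the ``output_representations`` property instead.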


def create_grid_layout(
    rows: int,
    cols: int,
    n_electrodes: Optional[int] = None,
    fill_pattern: str = "row",
    missing_indices: Optional[List[Tuple[int, int]]] = None,
) -> np.ndarray:
    """Creates a grid layout based on specified parameters.

    Parameters
    ----------
    rows : int
        Number of rows in the grid.
    cols : int
        Number of columns in the grid.
    n_electrodes : int, optional
        Number of electrodes in the grid. If None, will be set to rows*cols minus
        the number of missing indices. Default is None.
    fill_pattern : str, optional
        Pattern to fill the grid. Options are 'row' (row-wise) or 'column' (column-wise).
        Default is 'row'.
    missing_indices : List[Tuple[int, int]], optional
        List of (row, col) indices that should be left empty (-1). Default is None.

    Returns
    -------
    np.ndarray
        2D array representing the grid layout.

    Raises
    ------
    ValueError
        If the parameters are invalid.

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datatypes import create_grid_layout
    >>>
    >>> # Create a 4×4 grid with row-wise numbering (0-15)
    >>> grid1 = create_grid_layout(4, 4, fill_pattern='row')
    >>> print(grid1)
    [[ 0  1  2  3]
     [ 4  5  6  7]
     [ 8  9 10 11]
     [12 13 14 15]]
    >>>
    >>> # Create a 4×4 grid with column-wise numbering (0-15)
    >>> grid2 = create_grid_layout(4, 4, fill_pattern='column')
    >>> print(grid2)
    [[ 0  4  8 12]
     [ 1  5  9 13]
     [ 2  6 10 14]
     [ 3  7 11 15]]
    >>>
    >>> # Create a 3×3 grid with only 8 electrodes (missing bottom-right)
    >>> grid3 = create_grid_layout(3, 3, 8, 'row',
    ...                           missing_indices=[(2, 2)])
    >>> print(grid3)
    [[ 0  1  2]
     [ 3  4  5]
     [ 6  7 -1]]
    """
    # Initialize grid with -1 (gaps)
    grid = np.full((rows, cols), -1, dtype=int)

    # Process missing indices
    if missing_indices is None:
        missing_indices = []

    missing_positions = set(
        (r, c) for r, c in missing_indices if 0 <= r < rows and 0 <= c < cols
    )
    max_electrodes = rows * cols - len(missing_positions)

    # Validate n_electrodes
    if n_electrodes is None:
        n_electrodes = max_electrodes
    elif n_electrodes > max_electrodes:
        raise ValueError(
            f"Number of electrodes ({n_electrodes}) exceeds available positions "
            f"({max_electrodes} = {rows}×{cols} - {len(missing_positions)} missing)"
        )

    # Fill the grid based on the pattern
    electrode_idx = 0
    if fill_pattern.lower() == "row":
        for r in range(rows):
            for c in range(cols):
                if (r, c) not in missing_positions and electrode_idx < n_electrodes:
                    grid[r, c] = electrode_idx
                    electrode_idx += 1
    elif fill_pattern.lower() == "column":
        for c in range(cols):
            for r in range(rows):
                if (r, c) not in missing_positions and electrode_idx < n_electrodes:
                    grid[r, c] = electrode_idx
                    electrode_idx += 1
    else:
        raise ValueError(
            f"Invalid fill pattern: {fill_pattern}. Use 'row' or 'column'."
        )

    return grid
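
# Illustrative sketch (not executed): when several grids are used together (see
# EMGData below), electrode indices must stay unique and contiguous across
# grids, so each subsequent layout is shifted by the number of electrodes
# preceding it:
#
#   grid_a = create_grid_layout(4, 4, fill_pattern="row")     # indices 0-15
#   grid_b = create_grid_layout(4, 4, fill_pattern="column")
#   grid_b[grid_b >= 0] += 16                                 # indices 16-31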


class _Data:
    """Base class for all data types.

    This class provides common functionality for handling different types of data,
    including maintaining original and processed representations, tracking filters
    applied, and managing data flow.

    Parameters
    ----------
    raw_data : np.ndarray
        The raw data to store.
    sampling_frequency : float
        The sampling frequency of the data.
    nr_of_dimensions_when_unchunked : int
        Number of array dimensions used to decide whether a representation is
        chunked; a representation whose array has this many dimensions is
        treated as chunked.

    Attributes
    ----------
    sampling_frequency : float
        The sampling frequency of the data.
    _last_processing_step : str
        The last processing step applied to the data.
    _processed_representations : networkx.DiGraph
        The graph of the processed representations.
    _filters_used : Dict[str, FilterBaseClass]
        Dictionary of all filters used in the data. The keys are the names of the
        filters and the values are the filters themselves.
    _data : Dict[str, Union[np.ndarray, DeletedRepresentation]]
        Dictionary of all data. The keys are the names of the representations and
        the values are either numpy arrays or DeletedRepresentation objects (for
        representations that have been deleted to save memory but can be
        regenerated when needed).

    Raises
    ------
    ValueError
        If the sampling frequency is less than or equal to 0.

    Notes
    -----
    Memory Management:
    When representations are deleted with delete_data(), they are replaced with
    DeletedRepresentation objects that store essential metadata (shape, dtype)
    but don't consume memory for the actual data. These representations can be
    automatically recomputed when accessed. The chunking status is determined
    from the shape when needed.

    Examples
    --------
    This is an abstract base class and should not be instantiated directly.
    Instead, use one of the concrete subclasses like EMGData or KinematicsData:

    >>> import numpy as np
    >>> from myoverse.datatypes import EMGData
    >>>
    >>> # Create sample data
    >>> data = np.random.randn(16, 1000)
    >>> emg = EMGData(data, 2000)  # 2000 Hz sampling rate
    >>>
    >>> # Access attributes from the base _Data class
    >>> print(f"Sampling frequency: {emg.sampling_frequency} Hz")
    >>> print(f"Is input data chunked: {emg.is_chunked['Input']}")
    """
    def __init__(
        self,
        raw_data: np.ndarray,
        sampling_frequency: float,
        nr_of_dimensions_when_unchunked: int,
    ):
        self.sampling_frequency: float = sampling_frequency
        self.nr_of_dimensions_when_unchunked: int = nr_of_dimensions_when_unchunked

        if self.sampling_frequency <= 0:
            raise ValueError("The sampling frequency should be greater than 0.")

        self._data: Dict[str, Union[np.ndarray, DeletedRepresentation]] = {
            InputRepresentationName: raw_data,
        }

        self._filters_used: Dict[str, FilterBaseClass] = {}

        self._processed_representations: nx.DiGraph = nx.DiGraph()
        self._processed_representations.add_node(InputRepresentationName)
        self._processed_representations.add_node(OutputRepresentationName)

        self.__last_processing_step: str = InputRepresentationName
    @property
    def is_chunked(self) -> Dict[str, bool]:
        """Returns whether the data is chunked or not.

        Returns
        -------
        Dict[str, bool]
            A dictionary where the keys are the representations and the values
            are whether the data is chunked or not.
        """
        # Create cache if it doesn't exist or if _data might have changed
        if not hasattr(self, "_chunked_cache") or len(self._chunked_cache) != len(
            self._data
        ):
            self._chunked_cache = {
                key: self._check_if_chunked(value)
                for key, value in self._data.items()
            }

        return self._chunked_cache
    def _check_if_chunked(
        self, data: Union[np.ndarray, DeletedRepresentation]
    ) -> bool:
        """Checks if the data is chunked or not.

        Parameters
        ----------
        data : Union[np.ndarray, DeletedRepresentation]
            The data to check.

        Returns
        -------
        bool
            Whether the data is chunked or not.
        """
        return len(data.shape) == self.nr_of_dimensions_when_unchunked
    @property
    def input_data(self) -> np.ndarray:
        """Returns the input data."""
        return self._data[InputRepresentationName]

    @input_data.setter
    def input_data(self, value: np.ndarray):
        raise RuntimeError("This property is read-only.")

    @property
    def processed_representations(self) -> Dict[str, np.ndarray]:
        """Returns the processed representations of the data."""
        return self._data

    @processed_representations.setter
    def processed_representations(self, value: Dict[str, Representation]):
        raise RuntimeError("This property is read-only.")

    @property
    def output_representations(self) -> Dict[str, np.ndarray]:
        """Returns the output representations of the data."""
        # Convert to set for faster lookups
        output_nodes = set(
            self._processed_representations.predecessors(OutputRepresentationName)
        )
        return {key: value for key, value in self._data.items() if key in output_nodes}

    @output_representations.setter
    def output_representations(self, value: Dict[str, Representation]):
        raise RuntimeError("This property is read-only.")

    @property
    def _last_processing_step(self) -> str:
        """Returns the last processing step applied to the data.

        Returns
        -------
        str
            The last processing step applied to the data.
        """
        if self.__last_processing_step is None:
            raise ValueError("No processing steps have been applied.")
        return self.__last_processing_step

    @_last_processing_step.setter
    def _last_processing_step(self, value: str):
        """Sets the last processing step applied to the data.

        Parameters
        ----------
        value : str
            The last processing step applied to the data.
        """
        self.__last_processing_step = value
    @abstractmethod
    def plot(self, *_: Any, **__: Any):
        """Plots the data."""
        raise NotImplementedError(
            "This method should be implemented in the child class."
        )
    def plot_graph(self, title: Optional[str] = None):
        """Draws the graph of the processed representations.

        Parameters
        ----------
        title : Optional[str], default=None
            Optional title for the graph. If None, no title will be displayed.
        """
        # Use spectral layout but with enhancements for better flow
        G = self._processed_representations

        # Initial layout using spectral positioning
        pos = nx.spectral_layout(G)

        # Always position input node on the left and output node on the right
        min_x = min(p[0] for p in pos.values())
        max_x = max(p[0] for p in pos.values())

        # Normalize x positions to ensure full range is used
        for node in pos:
            pos[node][0] = (
                (pos[node][0] - min_x) / (max_x - min_x) if max_x != min_x else 0.5
            )

        # Force input/output node positions
        pos[InputRepresentationName][0] = 0.0  # Left edge
        pos[OutputRepresentationName][0] = 1.0  # Right edge

        # Use topological sort to improve node positioning
        try:
            # Get topologically sorted nodes (excluding input and output)
            topo_nodes = [
                node
                for node in nx.topological_sort(G)
                if node not in [InputRepresentationName, OutputRepresentationName]
            ]

            # Group nodes by their topological "level" (distance from input)
            node_levels = {}
            for node in topo_nodes:
                # Find all paths from input to this node
                paths = list(nx.all_simple_paths(G, InputRepresentationName, node))
                if paths:
                    # Level is the longest path length (minus 1 for the input node)
                    level = max(len(path) - 1 for path in paths)
                    if level not in node_levels:
                        node_levels[level] = []
                    node_levels[level].append(node)

            # Calculate the total number of levels
            max_level = max(node_levels.keys()) if node_levels else 0

            # Adjust x-positions based on level - without losing the original
            # y-positions from spectral layout
            for level, nodes in node_levels.items():
                # Calculate new x-position (divide evenly between input and output)
                x_pos = level / (max_level + 1) if max_level > 0 else 0.5

                # Preserve the relative y-positions from spectral layout
                for node in nodes:
                    # Update only the x-position
                    pos[node][0] = x_pos
        except nx.NetworkXUnfeasible:
            # If topological sort fails, keep the original spectral layout
            print("Warning: Topological sort failed, using default layout")
        except Exception as e:
            # Catch other exceptions
            print(f"Warning: Error in layout algorithm: {str(e)}")

        # Identify related nodes (nodes that share the same filter parent name)
        # This is particularly useful for filters that return multiple outputs
        related_nodes = {}
        for node in G.nodes():
            if node in [InputRepresentationName, OutputRepresentationName]:
                continue

            # Extract base filter name (part before the underscore)
            if "_" in node:
                base_name = node.split("_")[0]
                if base_name not in related_nodes:
                    related_nodes[base_name] = []
                related_nodes[base_name].append(node)

        # Adjust positions for related nodes to prevent overlap
        for base_name, nodes in related_nodes.items():
            if len(nodes) > 1:
                # Find average position for this group
                avg_x = sum(pos[node][0] for node in nodes) / len(nodes)

                # Calculate better vertical spacing
                vertical_spacing = 0.3 / len(nodes)

                # Arrange nodes vertically around their average x-position
                for i, node in enumerate(nodes):
                    # Keep the same x position but adjust y position
                    pos[node][0] = avg_x
                    # Distribute nodes vertically, centered around original y position
                    # Start from -0.15 to +0.15 to ensure good spacing
                    vertical_offset = -0.15 + (i * vertical_spacing)
                    pos[node][1] = pos[node][1] + vertical_offset

        # Apply gentle force-directed adjustments to improve layout
        # without completely changing the spectral positioning
        for _ in range(10):  # Reduced from 20 to 10 iterations
            # Store current positions
            old_pos = {n: p.copy() for n, p in pos.items()}

            for node in G.nodes():
                if node in [InputRepresentationName, OutputRepresentationName]:
                    continue  # Skip fixed nodes

                # Get node neighbors
                neighbors = list(G.neighbors(node))
                if not neighbors:
                    continue

                # Calculate average position of neighbors, weighted by in/out direction
                pred_force = np.zeros(2)
                succ_force = np.zeros(2)

                # Predecessors pull left
                predecessors = list(G.predecessors(node))
                if predecessors:
                    pred_force = (
                        np.mean([old_pos[p] for p in predecessors], axis=0)
                        - old_pos[node]
                    )
                    # Scale down x-force to maintain left-to-right flow
                    pred_force[0] *= 0.05  # Reduced from 0.1 to 0.05

                # Successors pull right
                successors = list(G.successors(node))
                if successors:
                    succ_force = (
                        np.mean([old_pos[s] for s in successors], axis=0)
                        - old_pos[node]
                    )
                    # Scale down x-force to maintain left-to-right flow
                    succ_force[0] *= 0.05  # Reduced from 0.1 to 0.05

                # Apply force (weighted more toward maintaining x position)
                force = pred_force + succ_force

                # Reduce force magnitude to avoid disrupting the topological ordering
                pos[node] += 0.05 * force  # Reduced from 0.1 to 0.05

                # Maintain x position within 0-1 range
                pos[node][0] = max(0.05, min(0.95, pos[node][0]))

        # Final overlap prevention - ensure minimum distance between nodes
        min_distance = 0.1  # Minimum distance between nodes
        for _ in range(3):  # Reduced from 5 to 3 iterations
            overlap_forces = {node: np.zeros(2) for node in G.nodes()}

            # Calculate repulsion forces between every pair of nodes
            node_list = list(G.nodes())
            for i, node1 in enumerate(node_list):
                for node2 in node_list[i + 1 :]:
                    # Skip input/output nodes
                    if node1 in [
                        InputRepresentationName,
                        OutputRepresentationName,
                    ] or node2 in [InputRepresentationName, OutputRepresentationName]:
                        continue

                    # Calculate distance between nodes
                    dist_vec = pos[node1] - pos[node2]
                    dist = np.linalg.norm(dist_vec)

                    # Apply repulsion if nodes are too close
                    if dist < min_distance and dist > 0:
                        # Normalize the vector
                        repulsion = dist_vec / dist
                        # Scale by how much they overlap
                        scale = (min_distance - dist) * 0.4  # Modified from 0.5 to 0.4
                        # Add to both nodes' forces (in opposite directions)
                        overlap_forces[node1] += repulsion * scale
                        overlap_forces[node2] -= repulsion * scale

            # Apply forces
            for node, force in overlap_forces.items():
                if node not in [InputRepresentationName, OutputRepresentationName]:
                    pos[node] += force
                    # Maintain x position closer to its original value
                    # to preserve the topological ordering
                    x_original = pos[node][0]
                    # Make sure nodes stay within bounds
                    pos[node][0] = max(0.05, min(0.95, pos[node][0]))
                    pos[node][1] = max(-0.95, min(0.95, pos[node][1]))
                    # Restore x position with a small adjustment
                    pos[node][0] = 0.9 * x_original + 0.1 * pos[node][0]

        # Create the figure and axis with a larger size for better visualization
        plt.figure(figsize=(16, 12))  # Increased from (14, 10)
        ax = plt.gca()

        # Add title if provided
        if title is not None:
            plt.title(title, fontsize=16, pad=20)

        # Create dictionaries for node attributes
        node_colors = {}
        node_sizes = {}
        node_shapes = {}

        # Set attributes based on node type
        for node in G.nodes():
            if node == InputRepresentationName:
                node_colors[node] = "crimson"
                node_sizes[node] = 1500
                node_shapes[node] = "o"  # Circle
            elif node == OutputRepresentationName:
                node_colors[node] = "forestgreen"
                node_sizes[node] = 1500
                node_shapes[node] = "o"  # Circle
            elif node not in self._data:
                # If the node is not in the data dictionary, it's a dummy node
                # (like a filter name)
                node_colors[node] = "dimgray"
                node_sizes[node] = 1200
                node_shapes[node] = "o"  # Circle
            elif isinstance(self._data[node], DeletedRepresentation):
                node_colors[node] = "dimgray"
                node_sizes[node] = 1200
                node_shapes[node] = "s"  # Square for deleted representations
            else:
                node_colors[node] = "royalblue"
                node_sizes[node] = 1200
                node_shapes[node] = "o"  # Circle

        # Group nodes by shape for drawing
        node_groups = {}
        for shape in set(node_shapes.values()):
            node_groups[shape] = [node for node, s in node_shapes.items() if s == shape]

        # Draw each group of nodes with the correct shape
        drawn_nodes = {}
        for shape, nodes in node_groups.items():
            if not nodes:
                continue

            # Create lists of node properties
            node_list = nodes
            color_list = [node_colors[node] for node in node_list]
            size_list = [node_sizes[node] for node in node_list]

            # Draw nodes with the current shape
            if shape == "o":  # Circle
                drawn_nodes[shape] = nx.draw_networkx_nodes(
                    G,
                    pos,
                    nodelist=node_list,
                    node_color=color_list,
                    node_size=size_list,
                    alpha=0.8,
                    ax=ax,
                )
            elif shape == "s":  # Square
                drawn_nodes[shape] = nx.draw_networkx_nodes(
                    G,
                    pos,
                    nodelist=node_list,
                    node_color=color_list,
                    node_size=size_list,
                    node_shape="s",
                    alpha=0.8,
                    ax=ax,
                )

            # Set z-order for nodes
            if drawn_nodes[shape] is not None:
                drawn_nodes[shape].set_zorder(1)

        # Draw node labels with different colors based on node type
        label_objects = {}

        # Create custom labels: "I" for input, "O" for output, numbers for the
        # others starting from 1
        node_labels = {}

        # Filter out input and output nodes for separate labeling
        intermediate_nodes = [
            node
            for node in G.nodes
            if node not in [InputRepresentationName, OutputRepresentationName]
        ]

        # Add labels for input and output nodes
        node_labels[InputRepresentationName] = "I"
        node_labels[OutputRepresentationName] = "O"

        # For intermediate nodes, use sequential numbers (1 to n)
        for i, node in enumerate(intermediate_nodes, 1):
            node_labels[node] = str(i)

        label_objects["nodes"] = nx.draw_networkx_labels(
            G, pos, labels=node_labels, font_size=18, font_color="white", ax=ax
        )

        # Set z-order for all labels
        for label_group in label_objects.values():
            for text in label_group.values():
                text.set_zorder(3)

        # The grid annotations were removed since the representation names are
        # now shown directly in the nodes; additional text annotations could be
        # added here for extra information. This section is intentionally empty.

        # Create edge styles based on connection type
        edge_styles = []
        edge_colors = []
        edge_widths = []

        for u, v in G.edges():
            # Define edge properties based on connection type
            if u == InputRepresentationName:
                edge_colors.append("crimson")  # Input connections
                edge_widths.append(2.0)
                edge_styles.append("solid")
            elif v == OutputRepresentationName:
                edge_colors.append("forestgreen")  # Output connections
                edge_widths.append(2.0)
                edge_styles.append("solid")
            else:
                edge_colors.append("dimgray")  # Intermediate connections
                edge_widths.append(1.5)
                edge_styles.append("solid")

        # Draw all edges with the defined styles
        edges = nx.draw_networkx_edges(
            G,
            pos,
            ax=ax,
            edge_color=edge_colors,
            width=edge_widths,
            arrowstyle="-|>",
            arrowsize=20,
            connectionstyle="arc3,rad=0.2",  # Slightly increased curve for visibility
            alpha=0.8,
        )

        # Set z-order for edges to be above nodes
        if isinstance(edges, list):
            for edge_collection in edges:
                edge_collection.set_zorder(2)
        elif edges is not None:
            edges.set_zorder(2)

        # Create annotation for hover information (initially invisible)
        annot = ax.annotate(
            "",
            xy=(0, 0),
            xytext=(20, 20),
            textcoords="offset points",
            bbox=dict(boxstyle="round,pad=0.5", fc="white", alpha=0.9),
            fontsize=12,
            fontweight="normal",
            color="black",
            zorder=5,
        )
        annot.set_visible(False)

        # Add hover functionality for interactive exploration
        # Combine all node collections for the hover effect
        all_node_collections = [
            collection for collection in drawn_nodes.values() if collection is not None
        ]

        if all_node_collections:
            # Initialize the cursor without the hover behavior first
            cursor = mplcursors.cursor(all_node_collections, hover=True)

            # Map to keep track of the nodes for each collection
            node_collection_map = {}
            for shape, collection in drawn_nodes.items():
                if collection is not None:
                    node_collection_map[collection] = node_groups[shape]

            def on_hover(sel):
                try:
                    # Get the artist (the PathCollection) and find its shape
                    artist = sel.artist

                    # Get the target index - this is called 'target.index' in mplcursors
                    if hasattr(sel, "target") and hasattr(sel.target, "index"):
                        idx = sel.target.index
                    else:
                        # Fall back to other possible attribute names
                        idx = getattr(sel, "index", 0)

                    # Look up which nodes correspond to this artist
                    for shape, collection in drawn_nodes.items():
                        if collection == artist:
                            # Get list of nodes for this shape
                            shape_nodes = node_groups[shape]
                            if idx < len(shape_nodes):
                                hovered_node_name = shape_nodes[idx]

                                # Create the annotation text with the full
                                # representation name
                                annotation = f"Representation: {hovered_node_name}\n\n"

                                # Add whether the node needs to be recomputed
                                if (
                                    hovered_node_name != OutputRepresentationName
                                    and hovered_node_name in self._data
                                ):
                                    data = self._data[hovered_node_name]
                                    if isinstance(data, DeletedRepresentation):
                                        annotation += "needs to be\nrecomputed\n\n"

                                # Add info whether the node is chunked or not
                                annotation += "chunked: "
                                if (
                                    hovered_node_name != OutputRepresentationName
                                    and hovered_node_name in self.is_chunked
                                ):
                                    annotation += str(
                                        self.is_chunked[hovered_node_name]
                                    )
                                else:
                                    annotation += "(see previous node(s))"

                                # Add shape information to the annotation
                                annotation += "\n" + "shape: "
                                if (
                                    hovered_node_name != OutputRepresentationName
                                    and hovered_node_name in self._data
                                ):
                                    data = self._data[hovered_node_name]
                                    if isinstance(
                                        data, (np.ndarray, DeletedRepresentation)
                                    ):
                                        annotation += str(data.shape)
                                else:
                                    annotation += "(see previous node(s))"

                                sel.annotation.set_text(annotation)
                                sel.annotation.get_bbox_patch().set(
                                    fc="white", alpha=0.9
                                )  # Background color
                                sel.annotation.set_fontsize(12)  # Font size
                                sel.annotation.set_fontstyle("italic")
                            break
                except Exception as e:
                    # If any error occurs, show a simplified annotation with
                    # detailed error info
                    error_info = f"Error in hover: {str(e)}\n"
                    if hasattr(sel, "target"):
                        error_info += f"Sel has target: {True}\n"
                        if hasattr(sel.target, "index"):
                            error_info += f"Target has index: {True}\n"
                    error_info += f"Available attributes: {dir(sel)}"
                    sel.annotation.set_text(error_info)

            cursor.connect("add", on_hover)

        # Improve visual appearance
        plt.grid(False)
        plt.axis("off")
        plt.margins(0.2)  # Increased from 0.15 to give more space around nodes
        plt.tight_layout(pad=2.0)  # Increased padding

        plt.show()
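
    # Usage sketch (illustrative): plot_graph is most useful after a few filters
    # have been applied; the hover annotations require an interactive matplotlib
    # backend since they are driven by mplcursors:
    #
    #   emg = EMGData(np.random.randn(16, 1000), 2000)
    #   ...  # apply some filters
    #   emg.plot_graph(title="EMG processing graph")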
    def apply_filter(
        self,
        filter: FilterBaseClass,
        representations_to_filter: list[str] | None = None,
        keep_representation_to_filter: bool = True,
    ) -> str:
        """Applies a filter to the data.

        Parameters
        ----------
        filter : FilterBaseClass
            The filter to apply.
        representations_to_filter : list[str], optional
            A list of representations to filter. The filter is responsible for
            handling the appropriate number of inputs or raising an error if
            incompatible. If None, an empty list is used, which raises a
            ValueError since every filter requires at least one input.
        keep_representation_to_filter : bool
            Whether to keep the representation(s) to filter or not.
            If the representation to filter is "Input", this parameter is ignored.

        Returns
        -------
        str
            The name of the representation after applying the filter.

        Raises
        ------
        ValueError
            If representations_to_filter is a string instead of a list.
            If no representation to filter is provided.
        """
        representation_name = filter.name

        # Ensure representations_to_filter is a list, not a string
        if isinstance(representations_to_filter, str):
            raise ValueError(
                f"representations_to_filter must be a list, not a string. "
                f"Use ['{representations_to_filter}'] instead of "
                f"'{representations_to_filter}'."
            )

        # If representations_to_filter is None, create an empty list
        if representations_to_filter is None:
            representations_to_filter = []

        # Check if the list is empty
        if len(representations_to_filter) == 0:
            raise ValueError(
                f"The filter {filter.name} requires an input representation. "
                f"Please provide at least one representation to filter."
            )

        # Replace LastRepresentationName with the actual last processing step
        representations_to_filter = [
            self._last_processing_step if rep == LastRepresentationName else rep
            for rep in representations_to_filter
        ]

        # Add edges to the graph for all input representations
        for rep in representations_to_filter:
            if rep not in self._processed_representations:
                self._processed_representations.add_node(rep)

            # Add filter node and create edges from inputs to filter
            if representation_name not in self._processed_representations:
                self._processed_representations.add_node(representation_name)

            # Add edge from the representation to filter to the new representation
            # if it doesn't exist yet
            if not self._processed_representations.has_edge(rep, representation_name):
                self._processed_representations.add_edge(rep, representation_name)

        # Get the data for each representation
        input_arrays = [self[rep] for rep in representations_to_filter]

        # Automatically extract all data object parameters to pass to the filter
        data_params = {}
        # Use inspect to get all instance attributes
        for attr_name, attr_value in inspect.getmembers(self):
            # Skip private attributes, methods, and callables
            if (
                not attr_name.startswith("_")
                and not callable(attr_value)
                and not isinstance(attr_value, property)
            ):
                data_params[attr_name] = attr_value

        # Check if a standard filter is receiving multiple inputs inappropriately
        if len(input_arrays) > 1:
            raise ValueError(
                f"You're trying to pass multiple representations "
                f"({', '.join(representations_to_filter)}) to a standard filter "
                f"that only accepts a single input."
            )

        # If there's only one input, pass it directly; otherwise pass the list
        # This maintains backward compatibility with existing filters
        if len(input_arrays) == 1:
            filtered_data = filter(input_arrays[0], **data_params)
        else:
            filtered_data = filter(input_arrays, **data_params)

        # Store the filtered data
        self._data[representation_name] = filtered_data

        # Check if the filter is going to be an output
        # If so, add an edge from the representation to add to the output node
        if filter.is_output:
            self._processed_representations.add_edge(
                representation_name, OutputRepresentationName
            )

        # Save the used filter
        self._filters_used[representation_name] = filter

        # Set the last processing step
        self._last_processing_step = representation_name

        # Remove the representations to filter if needed
        if keep_representation_to_filter is False:
            for rep in representations_to_filter:
                # Never delete the raw representation
                if rep != InputRepresentationName:
                    self.delete_data(rep)

        return representation_name
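
    # Usage sketch (illustrative), reusing the ApplyFunctionFilter pattern from
    # the apply_filter_pipeline docstring below; exact constructor arguments may
    # vary between filters, and the "rectified" name is hypothetical:
    #
    #   from myoverse.datasets.filters.generic import ApplyFunctionFilter
    #
    #   emg = EMGData(np.random.randn(16, 1000), 2000)
    #   name = emg.apply_filter(
    #       ApplyFunctionFilter(function=np.abs, name="rectified"),
    #       representations_to_filter=["Input"],
    #   )
    #   rectified = emg[name]  # same as emg["rectified"]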
    def apply_filter_sequence(
        self,
        filter_sequence: List[FilterBaseClass],
        representations_to_filter: List[str] | None = None,
        keep_individual_filter_steps: bool = True,
        keep_representation_to_filter: bool = True,
    ) -> str:
        """Applies a sequence of filters to the data sequentially.

        Parameters
        ----------
        filter_sequence : list[FilterBaseClass]
            The sequence of filters to apply.
        representations_to_filter : List[str], optional
            A list of representations to filter for the first filter in the
            sequence. Each filter is responsible for validating and handling its
            inputs appropriately. For subsequent filters in the sequence, the
            output of the previous filter is used.
        keep_individual_filter_steps : bool
            Whether to keep the results of each filter or not.
        keep_representation_to_filter : bool
            Whether to keep the representation(s) to filter or not.
            If the representation to filter is "Input", this parameter is ignored.

        Returns
        -------
        str
            The name of the last representation after applying all filters.

        Raises
        ------
        ValueError
            If filter_sequence is empty.
            If representations_to_filter is empty.
            If representations_to_filter is a string instead of a list.
        """
        if len(filter_sequence) == 0:
            raise ValueError("filter_sequence cannot be empty.")

        # Ensure representations_to_filter is a list, not a string
        if isinstance(representations_to_filter, str):
            raise ValueError(
                f"representations_to_filter must be a list, not a string. "
                f"Use ['{representations_to_filter}'] instead of "
                f"'{representations_to_filter}'."
            )

        # If representations_to_filter is None, create an empty list
        if representations_to_filter is None:
            representations_to_filter = []

        # Replace LastRepresentationName with the actual last processing step
        representations_to_filter = [
            self._last_processing_step if rep == LastRepresentationName else rep
            for rep in representations_to_filter
        ]

        # Apply the first filter with the provided representations
        result = self.apply_filter(
            filter=filter_sequence[0],
            representations_to_filter=representations_to_filter,
            keep_representation_to_filter=True,  # We'll handle this at the end
        )

        # Collect intermediate results for potential cleanup later
        intermediate_results = [result]
        what_to_filter = [result]

        # Apply subsequent filters in sequence
        for f in filter_sequence[1:]:
            # Apply the next filter using the previous result
            result = self.apply_filter(
                filter=f,
                representations_to_filter=what_to_filter,
                keep_representation_to_filter=True,  # Keep intermediates until the end
            )

            # Update what to filter for the next iteration
            intermediate_results.append(result)
            what_to_filter = [result]

        # Remove intermediate filter steps if needed, keeping the final result
        if not keep_individual_filter_steps:
            # Delete all intermediates except the final result
            for rep in intermediate_results[:-1]:  # Skip the last result
                self.delete_data(rep)

        # Remove the representation to filter if needed
        if not keep_representation_to_filter:
            for rep in representations_to_filter:
                # Never delete the input representation
                if rep != InputRepresentationName:
                    self.delete_data(rep)

        return result
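
    # Usage sketch (illustrative): chaining two filters; the second automatically
    # consumes the output of the first. The filter names are hypothetical:
    #
    #   last = emg.apply_filter_sequence(
    #       filter_sequence=[
    #           ApplyFunctionFilter(function=np.abs, name="rectified"),
    #           ApplyFunctionFilter(function=lambda x: x**2, name="squared"),
    #       ],
    #       representations_to_filter=["Input"],
    #       keep_individual_filter_steps=False,  # drops the "rectified" array
    #   )
    #   assert last == "squared"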
    def apply_filter_pipeline(
        self,
        filter_pipeline: List[List[FilterBaseClass]],
        representations_to_filter: List[List[str]],
        keep_individual_filter_steps: bool = True,
        keep_representation_to_filter: bool = True,
    ) -> List[str]:
        """Applies a pipeline of filters to the data.

        Parameters
        ----------
        filter_pipeline : list[list[FilterBaseClass]]
            The pipeline of filters to apply. Each inner list represents a branch
            of filters.
        representations_to_filter : list[list[str]]
            A list of input representations for each branch. Each element
            corresponds to a branch in the filter_pipeline and must be:

            - A list with a single string for standard branches that take one input
            - A list with multiple strings for branches starting with a
              multi-input filter
            - An empty list is not allowed unless the filter explicitly accepts
              no input

            .. note:: The length of representations_to_filter should be the same
                as the number of branches in the filter_pipeline.

        keep_individual_filter_steps : bool
            Whether to keep the results of each filter or not.
        keep_representation_to_filter : bool
            Whether to keep the representation(s) to filter or not.
            If the representation to filter is "Input", this parameter is ignored.

        Returns
        -------
        List[str]
            A list containing the names of the final representations from all
            branches.

        Raises
        ------
        ValueError
            If the number of filter branches and representations to filter is
            different.
            If a standard filter is provided with multiple representations.
            If no representations are provided for a filter that requires input.
            If any representations_to_filter element is a string instead of a list.

        Notes
        -----
        Each branch in the pipeline is processed sequentially using
        apply_filter_sequence.

        Examples
        --------
        >>> # Example of a pipeline with multiple processing branches
        >>> from myoverse.datatypes import EMGData
        >>> from myoverse.datasets.filters.generic import ApplyFunctionFilter
        >>> import numpy as np
        >>>
        >>> # Create sample data
        >>> data = EMGData(np.random.rand(10, 8), sampling_frequency=1000)
        >>>
        >>> # Define filter branches that perform different operations on the same input
        >>> branch1 = [ApplyFunctionFilter(function=np.abs, name="absolute_values")]
        >>> branch2 = [ApplyFunctionFilter(function=lambda x: x**2, name="squared_values")]
        >>>
        >>> # Apply pipeline with two branches
        >>> data.apply_filter_pipeline(
        ...     filter_pipeline=[branch1, branch2],
        ...     representations_to_filter=[
        ...         ["Input"],  # Process branch1 on the raw input
        ...         ["Input"],  # Process branch2 on the raw input
        ...     ],
        ... )
        >>>
        >>> # The results are now available as separate representations
        >>> abs_values = data["absolute_values"]
        >>> squared_values = data["squared_values"]
        """
        if len(filter_pipeline) == 0:
            return []

        if len(filter_pipeline) != len(representations_to_filter):
            raise ValueError(
                f"The number of filter branches ({len(filter_pipeline)}) and "
                f"representations to filter ({len(representations_to_filter)}) "
                f"must be the same."
            )

        # Ensure all elements in representations_to_filter are lists, not strings
        for branch_idx, branch_inputs in enumerate(representations_to_filter):
            if isinstance(branch_inputs, str):
                raise ValueError(
                    f"Element {branch_idx} of representations_to_filter is a string "
                    f"('{branch_inputs}'), but must be a list. "
                    f"Use ['{branch_inputs}'] instead."
                )
            if branch_inputs is None:
                raise ValueError(
                    f"Element {branch_idx} of representations_to_filter is None, "
                    f"but must be a list. Use an empty list [] for filters that "
                    f"do not require input."
                )

            # Replace LastRepresentationName with the actual last processing step
            # in each branch input
            representations_to_filter[branch_idx] = [
                self._last_processing_step if rep == LastRepresentationName else rep
                for rep in branch_inputs
            ]

        # Collect intermediates to delete after all branches are processed
        intermediates_to_delete = []
        all_results = []

        # Process each branch without deleting intermediates
        for branch_idx, (filter_sequence, branch_inputs) in enumerate(
            zip(filter_pipeline, representations_to_filter)
        ):
            try:
                # Apply filter sequence and get results
                branch_result = self.apply_filter_sequence(
                    filter_sequence=filter_sequence,
                    representations_to_filter=branch_inputs,
                    keep_individual_filter_steps=True,  # Always keep during processing
                    keep_representation_to_filter=keep_representation_to_filter,
                )

                # Track the branch result
                all_results.append(branch_result)

                # Track intermediates that might need to be deleted
                if not keep_individual_filter_steps:
                    # For each filter in the sequence (except the last),
                    # add its name to intermediates to delete
                    for f in filter_sequence[:-1]:
                        if hasattr(f, "name") and f.name:
                            intermediates_to_delete.append(f.name)
            except ValueError as e:
                # Enhance error message with branch information
                raise ValueError(
                    f"Error in branch {branch_idx + 1}/{len(filter_pipeline)}: {str(e)}"
                ) from e

        # After all branches are processed, delete collected intermediates if needed
        if not keep_individual_filter_steps:
            # First, identify all final outputs from the pipeline
            final_outputs = set(all_results)

            # For each representation in the data
            for rep_name in list(self._data.keys()):
                # Skip input and final outputs
                if rep_name == InputRepresentationName or rep_name in final_outputs:
                    continue

                # Check if this is an intermediate from any branch
                is_intermediate = False
                for base_name in intermediates_to_delete:
                    # Either exact match or prefix match for multi-output filters
                    if rep_name == base_name or rep_name.startswith(f"{base_name}_"):
                        is_intermediate = True
                        break

                if is_intermediate:
                    try:
                        self.delete_data(rep_name)
                    except KeyError:
                        # If already deleted or doesn't exist, just continue
                        pass

        return all_results
    def get_representation_history(self, representation: str) -> List[str]:
        """Returns the history of a representation.

        Parameters
        ----------
        representation : str
            The representation to get the history of.

        Returns
        -------
        list[str]
            The history of the representation.
        """
        return list(
            nx.shortest_path(
                self._processed_representations,
                InputRepresentationName,
                representation,
            )
        )
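
    # Usage sketch (illustrative): continuing the hypothetical sequence above,
    # the history is the shortest path from the input node through the graph:
    #
    #   emg.get_representation_history("squared")
    #   # ['Input', 'rectified', 'squared']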
    def __repr__(self) -> str:
        # Get input data shape directly from _data dictionary to avoid copying
        input_shape = self._data[InputRepresentationName].shape

        # Build a structured string representation
        lines = []
        lines.append(f"{self.__class__.__name__}")
        lines.append(f"Sampling frequency: {self.sampling_frequency} Hz")
        lines.append(f"(0) Input {input_shape}")

        if len(self._processed_representations.nodes) >= 3:
            # Add an empty line for spacing between input and filters
            lines.append("")
            lines.append("Filter(s):")

            # Create mapping of representation to index only if needed
            if self._filters_used:
                representation_indices = {
                    key: index for index, key in enumerate(self._filters_used.keys())
                }

                # Precompute output predecessors for faster lookup
                output_predecessors = set(
                    self._processed_representations.predecessors(
                        OutputRepresentationName
                    )
                )

                for filter_index, (filter_name, filter_representation) in enumerate(
                    self._data.items()
                ):
                    if filter_name == InputRepresentationName:
                        continue

                    # Get history and format it more efficiently
                    history = self.get_representation_history(filter_name)
                    history_str = " -> ".join(
                        str(representation_indices[rep] + 1) for rep in history[1:]
                    )

                    # Build filter representation string
                    is_output = filter_name in output_predecessors
                    shape_str = (
                        filter_representation.shape
                        if not isinstance(filter_representation, str)
                        else filter_representation
                    )

                    filter_str = f"({filter_index} | {history_str}) "
                    if is_output:
                        filter_str += "(Output) "
                    filter_str += f"{filter_name} {shape_str}"

                    lines.append(filter_str)

        # Join all parts with newlines
        return "\n".join(lines)
    def __str__(self) -> str:
        return (
            "--\n"
            + self.__repr__()
            .replace("; ", "\n")
            .replace("Filter(s): ", "\nFilter(s):\n")
            + "\n--"
        )
    def __getitem__(self, key: str) -> np.ndarray:
        if key == InputRepresentationName:
            # Use array.view() for more efficient copying when possible
            data = self.input_data
            return data.view() if data.flags.writeable else data.copy()

        if key == LastRepresentationName:
            return self[self._last_processing_step]

        if key not in self._processed_representations:
            raise KeyError(f'The representation "{key}" does not exist.')

        data_to_return = self._data[key]

        if isinstance(data_to_return, DeletedRepresentation):
            print(f'Recomputing representation "{key}"')
            history = self.get_representation_history(key)

            self.apply_filter_sequence(
                filter_sequence=[
                    self._filters_used[filter_name] for filter_name in history[1:]
                ],
                representations_to_filter=[history[0]],
            )

        # Use view when possible for more efficient memory usage
        data = self._data[key]
        return data.view() if data.flags.writeable else data.copy()
    def __setitem__(self, key: str, value: np.ndarray) -> None:
        raise RuntimeError(
            "This method is not supported. Run apply_filter or "
            "apply_filter_sequence instead."
        )
    def delete_data(self, representation_to_delete: str):
        """Delete data from a representation while keeping its metadata.

        This replaces the actual numpy array with a DeletedRepresentation object
        that contains metadata about the array, saving memory while allowing
        regeneration when needed.

        Parameters
        ----------
        representation_to_delete : str
            The representation to delete the data from.
        """
        if representation_to_delete == InputRepresentationName:
            return
        if representation_to_delete == LastRepresentationName:
            self.delete_data(self._last_processing_step)
            return

        if representation_to_delete not in self._data:
            raise KeyError(
                f'The representation "{representation_to_delete}" does not exist.'
            )

        data = self._data[representation_to_delete]
        if isinstance(data, np.ndarray):
            self._data[representation_to_delete] = DeletedRepresentation(
                shape=data.shape, dtype=data.dtype
            )
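
    # Usage sketch (illustrative) of the memory-management round trip: deleting
    # a representation keeps only its shape/dtype metadata, and indexing it again
    # triggers the recomputation path in __getitem__:
    #
    #   emg.delete_data("rectified")  # array replaced by DeletedRepresentation
    #   arr = emg["rectified"]        # prints 'Recomputing representation ...'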
    def delete_history(self, representation_to_delete: str):
        """Delete the processing history for a representation.

        Parameters
        ----------
        representation_to_delete : str
            The representation to delete the history for.
        """
        if representation_to_delete == InputRepresentationName:
            return
        if representation_to_delete == LastRepresentationName:
            self.delete_history(self._last_processing_step)
            return

        if representation_to_delete not in self._processed_representations.nodes:
            raise KeyError(
                f'The representation "{representation_to_delete}" does not exist.'
            )

        self._filters_used.pop(representation_to_delete, None)
        self._processed_representations.remove_node(representation_to_delete)
    def delete(self, representation_to_delete: str):
        """Delete both the data and history for a representation.

        Parameters
        ----------
        representation_to_delete : str
            The representation to delete.
        """
        self.delete_data(representation_to_delete)
        self.delete_history(representation_to_delete)
    def __copy__(self) -> "_Data":
        """Create a shallow copy of the instance.

        Returns
        -------
        _Data
            A shallow copy of the instance.
        """
        # Create a new instance with the basic initialization
        new_instance = self.__class__(
            self._data[InputRepresentationName].copy(), self.sampling_frequency
        )

        # Get all attributes of the current instance
        for name, value in inspect.getmembers(self):
            # Skip special methods, methods, and the already initialized attributes
            if (
                (
                    not name.startswith("_")
                    or name
                    in [
                        "_data",
                        "_processed_representations",
                        "_last_processing_step",
                        "_filters_used",
                    ]
                )
                and not inspect.ismethod(value)
                and not name == "sampling_frequency"
            ):
                # Handle different attribute types appropriately
                if name == "_data":
                    # Deep copy the data dictionary
                    setattr(new_instance, name, copy.deepcopy(value))
                elif name == "_processed_representations":
                    # Use the graph's copy method
                    setattr(new_instance, name, value.copy())
                elif name == "_filters_used":
                    # Deep copy the filters used
                    setattr(new_instance, name, copy.deepcopy(value))
                else:
                    # Shallow copy for other attributes
                    setattr(new_instance, name, copy.copy(value))

        return new_instance
    def save(self, filename: str):
        """Save the data to a file.

        Parameters
        ----------
        filename : str
            The name of the file to save the data to.
        """
        # Make sure directory exists
        os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)

        with open(filename, "wb") as f:
            pickle.dump(self, f)
    @classmethod
    def load(cls, filename: str) -> "_Data":
        """Load data from a file.

        Parameters
        ----------
        filename : str
            The name of the file to load the data from.

        Returns
        -------
        _Data
            The loaded data.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)
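
    # Usage sketch (illustrative) of a pickle round trip; the file name is
    # hypothetical:
    #
    #   emg.save("session1_emg.pkl")
    #   emg2 = EMGData.load("session1_emg.pkl")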
    def memory_usage(self) -> Dict[str, Tuple[str, int]]:
        """Calculate memory usage of each representation.

        Returns
        -------
        Dict[str, Tuple[str, int]]
            Dictionary with representation names as keys and tuples containing
            shape as string and memory usage in bytes as values.
        """
        memory_usage = {}
        for key, value in self._data.items():
            if isinstance(value, np.ndarray):
                memory_usage[key] = (str(value.shape), value.nbytes)
            elif isinstance(value, DeletedRepresentation):
                memory_usage[key] = (
                    str(value.shape),
                    0,  # DeletedRepresentation objects use negligible memory
                )

        return memory_usage
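
    # Usage sketch (illustrative): inspecting where the memory goes; deleted
    # representations report 0 bytes:
    #
    #   for name, (shape, nbytes) in emg.memory_usage().items():
    #       print(f"{name}: {shape} -> {nbytes / 1e6:.2f} MB")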
class EMGData(_Data):
    """Class for storing EMG data.

    Parameters
    ----------
    input_data : np.ndarray
        The raw EMG data. The shape of the array should be (n_channels, n_samples)
        or (n_chunks, n_channels, n_samples).

        .. important:: The class will only accept 2D or 3D arrays. There is no
            way to check if you actually have the data in (n_channels, n_samples)
            or (n_chunks, n_channels, n_samples) format. Please make sure to
            provide the correct shape of the data.

    sampling_frequency : float
        The sampling frequency of the EMG data.
    grid_layouts : Optional[List[np.ndarray]], optional
        List of 2D arrays specifying the exact electrode arrangement for each grid.
        Each array element contains the electrode index (0-based).

        .. note:: All electrode numbers must be unique and non-negative. The
            numbers must be contiguous (0 to n) spread over however many grids.

        Default is None.

    Attributes
    ----------
    input_data : np.ndarray
        The raw EMG data. The shape of the array should be (n_channels, n_samples)
        or (n_chunks, n_channels, n_samples).
    sampling_frequency : float
        The sampling frequency of the EMG data.
    grid_layouts : Optional[List[np.ndarray]]
        List of 2D arrays specifying the exact electrode arrangement for each grid.
        Each array element contains the electrode index (0-based).
    processed_representations : Dict[str, np.ndarray]
        A dictionary where the keys are the names of filters applied to the EMG
        data and the values are the processed EMG data.

    Raises
    ------
    ValueError
        If the shape of the raw EMG data is not (n_channels, n_samples) or
        (n_chunks, n_channels, n_samples).
        If the provided grid layouts are not valid.

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datatypes import EMGData, create_grid_layout
    >>>
    >>> # Create sample EMG data (16 channels, 1000 samples)
    >>> emg_data = np.random.randn(16, 1000)
    >>> sampling_freq = 2000  # 2000 Hz
    >>>
    >>> # Create a basic EMGData object
    >>> emg = EMGData(emg_data, sampling_freq)
    >>>
    >>> # Create an EMGData object with grid layouts
    >>> # Define a 4×4 electrode grid with row-wise numbering
    >>> grid = create_grid_layout(4, 4, fill_pattern='row')
    >>> emg_with_grid = EMGData(emg_data, sampling_freq, grid_layouts=[grid])

    Working with Multiple Grid Layouts
    ----------------------------------
    Grid layouts enable precise specification of how electrodes are arranged
    physically. This is especially useful for visualizing and analyzing
    high-density EMG recordings with multiple electrode grids:

    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from myoverse.datatypes import EMGData, create_grid_layout
    >>>
    >>> # Create sample EMG data for 73 electrodes with 1000 samples each
    >>> emg_data = np.random.randn(73, 1000)
    >>> sampling_freq = 2048  # Hz
    >>>
    >>> # Create layouts for three different electrode grids
    >>> # First grid: 5×5 array with sequential numbering (0-24)
    >>> grid1 = create_grid_layout(5, 5, fill_pattern='row')
    >>>
    >>> # Second grid: 6×6 array with column-wise numbering
    >>> grid2 = create_grid_layout(6, 6, fill_pattern='column')
    >>> # Shift indices to start after the first grid (add 25)
    >>> grid2[grid2 >= 0] += 25
    >>>
    >>> # Third grid: 3×4 array
    >>> grid3 = create_grid_layout(3, 4, fill_pattern='row')
    >>> # Shift indices to start after the first two grids (add 25 + 36 = 61)
    >>> grid3[grid3 >= 0] += 61
    >>>
    >>> # Create EMGData with all three grids
    >>> emg = EMGData(emg_data, sampling_freq, grid_layouts=[grid1, grid2, grid3])
    >>>
    >>> # Visualize the three grid layouts
    >>> for i in range(3):
    ...     emg.plot_grid_layout(i)
    >>>
    >>> # Plot the raw EMG data using the grid arrangements
    >>> emg.plot('Input', scaling_factor=[15.0, 12.0, 20.0])
    >>>
    >>> # Access individual grid dimensions
    >>> grid_dimensions = emg._get_grid_dimensions()
    >>> for i, (rows, cols, electrodes) in enumerate(grid_dimensions):
    ...     print(f"Grid {i+1}: {rows}×{cols} with {electrodes} electrodes")
    """
    def __init__(
        self,
        input_data: np.ndarray,
        sampling_frequency: float,
        grid_layouts: Optional[List[np.ndarray]] = None,
    ):
        if input_data.ndim != 2 and input_data.ndim != 3:
            raise ValueError(
                "The shape of the raw EMG data should be (n_channels, n_samples) "
                "or (n_chunks, n_channels, n_samples)."
            )

        super().__init__(
            input_data, sampling_frequency, nr_of_dimensions_when_unchunked=3
        )

        self.grid_layouts = None  # Initialize to None first

        # Process and validate grid layouts if provided
        if grid_layouts is not None:
            # Transform to list if it is a numpy array
            if isinstance(grid_layouts, np.ndarray):
                grid_layouts = list(grid_layouts)

            for i, layout in enumerate(grid_layouts):
                if not isinstance(layout, np.ndarray) or layout.ndim != 2:
                    raise ValueError(f"Grid layout {i + 1} must be a 2D numpy array")

                # Check that not all elements are -1
                if np.all(layout == -1):
                    raise ValueError(
                        f"Grid layout {i + 1} contains all -1 values, "
                        f"indicating no electrodes!"
                    )

                # Check for duplicate electrode indices
                valid_indices = layout[layout >= 0]
                if len(np.unique(valid_indices)) != len(valid_indices):
                    raise ValueError(
                        f"Grid layout {i + 1} contains duplicate electrode indices"
                    )

            # Store the validated grid layouts
            self.grid_layouts = grid_layouts
    def _get_grid_dimensions(self):
        """Get dimensions and electrode counts for each grid.

        Returns
        -------
        List[Tuple[int, int, int]]
            List of (rows, cols, electrodes) tuples for each grid, or an empty
            list if no grid layouts are available.
        """
        if self.grid_layouts is None:
            return []

        return [
            (layout.shape[0], layout.shape[1], np.sum(layout >= 0))
            for layout in self.grid_layouts
        ]
    def plot(
        self,
        representation: str,
        nr_of_grids: Optional[int] = None,
        nr_of_electrodes_per_grid: Optional[int] = None,
        scaling_factor: Union[float, List[float]] = 20.0,
        use_grid_layouts: bool = True,
    ):
        """Plots the data for a specific representation.

        Parameters
        ----------
        representation : str
            The representation to plot.
        nr_of_grids : Optional[int], optional
            The number of electrode grids to plot. If None and grid_layouts is
            provided, will use the number of grids in grid_layouts.
            Default is None.
        nr_of_electrodes_per_grid : Optional[int], optional
            The number of electrodes per grid to plot. If None, will be determined
            from data shape or grid_layouts if available. Default is None.
        scaling_factor : Union[float, List[float]], optional
            The scaling factor for the data. The default is 20.0.
            If a list is provided, the scaling factor for each grid is used.
        use_grid_layouts : bool, optional
            Whether to use the grid_layouts for plotting. Default is True.
            If False, will use the nr_of_grids and nr_of_electrodes_per_grid
            parameters.

        Examples
        --------
        >>> import numpy as np
        >>> from myoverse.datatypes import EMGData, create_grid_layout
        >>>
        >>> # Create sample EMG data (64 channels, 1000 samples)
        >>> emg_data = np.random.randn(64, 1000)
        >>>
        >>> # Create EMGData with two 4×8 grids (32 electrodes each)
        >>> grid1 = create_grid_layout(4, 8, 32, fill_pattern='row')
        >>> grid2 = create_grid_layout(4, 8, 32, fill_pattern='row')
        >>>
        >>> # Adjust indices for second grid
        >>> grid2[grid2 >= 0] += 32
        >>>
        >>> emg = EMGData(emg_data, 2000, grid_layouts=[grid1, grid2])
        >>>
        >>> # Plot the raw data using the grid layouts
        >>> emg.plot('Input')
        >>>
        >>> # Adjust scaling for better visualization
        >>> emg.plot('Input', scaling_factor=[15.0, 25.0])
        >>>
        >>> # Plot without using grid layouts (specify manual grid configuration)
        >>> emg.plot('Input', nr_of_grids=2, nr_of_electrodes_per_grid=32,
        ...          use_grid_layouts=False)
        """
        data = self[representation]

        # Use grid_layouts if available and requested
        if self.grid_layouts is not None and use_grid_layouts:
            grid_dimensions = self._get_grid_dimensions()

            if nr_of_grids is not None and nr_of_grids != len(self.grid_layouts):
                print(
                    f"Warning: nr_of_grids ({nr_of_grids}) does not match "
                    f"grid_layouts length ({len(self.grid_layouts)}). "
                    f"Using grid_layouts."
                )

            nr_of_grids = len(self.grid_layouts)
            electrodes_per_grid = [dims[2] for dims in grid_dimensions]
        else:
            # Auto-determine nr_of_grids if not provided
            if nr_of_grids is None:
                nr_of_grids = 1

            # Auto-determine nr_of_electrodes_per_grid if not provided
            if nr_of_electrodes_per_grid is None:
                if self.is_chunked[representation]:
                    total_electrodes = data.shape[1]
                else:
                    total_electrodes = data.shape[0]

                # Try to determine a sensible default
                nr_of_electrodes_per_grid = total_electrodes // nr_of_grids

            electrodes_per_grid = [nr_of_electrodes_per_grid] * nr_of_grids

        # Prepare scaling factors (accept a plain int as well as a float)
        if isinstance(scaling_factor, (int, float)):
            scaling_factor = [scaling_factor] * nr_of_grids

        assert len(scaling_factor) == nr_of_grids, (
            "The number of scaling factors should be equal to the number of grids."
        )

        fig = plt.figure(figsize=(5 * nr_of_grids, 6))

        # Calculate electrode index offset for each grid
        electrode_offsets = [0]
        for i in range(len(electrodes_per_grid) - 1):
            electrode_offsets.append(electrode_offsets[-1] + electrodes_per_grid[i])

        # Make a subplot for each grid
        for grid_idx in range(nr_of_grids):
            ax = fig.add_subplot(1, nr_of_grids, grid_idx + 1)

            grid_title = f"Grid {grid_idx + 1}"
            if self.grid_layouts is not None and use_grid_layouts:
                rows, cols, _ = grid_dimensions[grid_idx]
                grid_title += f" ({rows}×{cols})"
            ax.set_title(grid_title)

            offset = electrode_offsets[grid_idx]
            n_electrodes = electrodes_per_grid[grid_idx]

            for electrode_idx in range(n_electrodes):
                data_idx = offset + electrode_idx
                if self.is_chunked[representation]:
                    # Handle chunked data - plot first chunk for visualization
                    ax.plot(
                        data[0, data_idx]
                        + electrode_idx * data[0].mean() * scaling_factor[grid_idx]
                    )
                else:
                    ax.plot(
                        data[data_idx]
                        + electrode_idx * data.mean() * scaling_factor[grid_idx]
                    )

            ax.set_xlabel("Time (samples)")
            ax.set_ylabel("Electrode #")

            # Set the y-axis ticks to the electrode numbers beginning from 1
            mean_val = (
                data[0].mean() if self.is_chunked[representation] else data.mean()
            )
            ax.set_yticks(
                np.arange(0, n_electrodes) * mean_val * scaling_factor[grid_idx],
                np.arange(1, n_electrodes + 1),
            )

            # Only for grid 1 keep the y-axis label
            if grid_idx != 0:
                ax.set_ylabel("")

        plt.tight_layout()
        plt.show()
[docs] def plot_grid_layout( self, grid_idx: int = 0, show_indices: bool = True, cmap: Optional[plt.cm.ScalarMappable] = None, figsize: Optional[Tuple[float, float]] = None, title: Optional[str] = None, colorbar: bool = True, grid_color: str = "black", grid_alpha: float = 0.7, text_color: str = "white", text_fontsize: int = 10, text_fontweight: str = "bold", highlight_electrodes: Optional[List[int]] = None, highlight_color: str = "red", save_path: Optional[str] = None, dpi: int = 150, return_fig: bool = False, ax: Optional[plt.Axes] = None, autoshow: bool = True, ): """Plots the 2D layout of a specific electrode grid with enhanced visualization. Parameters ---------- grid_idx : int, optional The index of the grid to plot. Default is 0. show_indices : bool, optional Whether to show the electrode indices in the plot. Default is True. cmap : Optional[plt.cm.ScalarMappable], optional Custom colormap to use for visualization. If None, a default viridis colormap is used. figsize : Optional[Tuple[float, float]], optional Custom figure size as (width, height) in inches. If None, size is calculated based on grid dimensions. Ignored if an existing axes object is provided. title : Optional[str], optional Custom title for the plot. If None, a default title showing grid dimensions is used. colorbar : bool, optional Whether to show a colorbar. Default is True. grid_color : str, optional Color of the grid lines. Default is "black". grid_alpha : float, optional Transparency of grid lines (0-1). Default is 0.7. text_color : str, optional Color of the electrode indices text. Default is "white". text_fontsize : int, optional Font size for electrode indices. Default is 10. text_fontweight : str, optional Font weight for electrode indices. Default is "bold". highlight_electrodes : Optional[List[int]], optional List of electrode indices to highlight. Default is None. highlight_color : str, optional Color to use for highlighting electrodes. Default is "red". save_path : Optional[str], optional Path to save the figure. If None, figure is not saved. Default is None. dpi : int, optional DPI for saved figure. Default is 150. return_fig : bool, optional Whether to return the figure and axes. Default is False. ax : Optional[plt.Axes], optional Existing axes object to plot on. If None, a new figure and axes will be created. autoshow : bool, optional Whether to automatically show the figure. Default is True. Set to False when plotting multiple grids on the same figure. Returns ------- Optional[Tuple[plt.Figure, plt.Axes]] Figure and axes objects if return_fig is True. Raises ------ ValueError If grid_layouts is not available or the grid_idx is out of range. Examples -------- >>> import numpy as np >>> from myoverse.datatypes import EMGData, create_grid_layout >>> >>> # Create sample EMG data (64 channels, 1000 samples) >>> emg_data = np.random.randn(64, 1000) >>> >>> # Create an 8×8 grid with some missing electrodes >>> grid = create_grid_layout(8, 8, 64, fill_pattern='row', ... missing_indices=[(7, 7), (0, 0)]) >>> >>> emg = EMGData(emg_data, 2000, grid_layouts=[grid]) >>> >>> # Basic visualization >>> emg.plot_grid_layout(0) >>> >>> # Advanced visualization >>> emg.plot_grid_layout( ... 0, ... figsize=(10, 10), ... colorbar=True, ... highlight_electrodes=[10, 20, 30], ... grid_alpha=0.5 ... 
        ... )
        >>>
        >>> # Multiple grids in one figure (assumes emg was created with two
        >>> # grid layouts)
        >>> import matplotlib.pyplot as plt
        >>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
        >>> emg.plot_grid_layout(0, title="Grid 1", ax=ax1, autoshow=False)
        >>> emg.plot_grid_layout(1, title="Grid 2", ax=ax2, autoshow=False)
        >>> plt.tight_layout()
        >>> plt.show()
        """
        if self.grid_layouts is None:
            raise ValueError("Cannot plot grid layout: grid_layouts not provided.")

        if grid_idx < 0 or grid_idx >= len(self.grid_layouts):
            raise ValueError(
                f"Grid index {grid_idx} out of range (0 to {len(self.grid_layouts) - 1})."
            )

        # Get the grid layout
        grid = self.grid_layouts[grid_idx]
        rows, cols = grid.shape

        # Get the number of electrodes
        n_electrodes = np.sum(grid >= 0)

        # Set a default title if none was provided
        if title is None:
            title = f"Grid {grid_idx + 1} layout ({rows}×{cols}) with {n_electrodes} electrodes"

        # Create a masked array for plotting (gaps are masked out)
        masked_grid = np.ma.masked_less(grid, 0)

        # Create figure and axes if not provided
        if ax is None:
            # Calculate an optimal figure size if none was provided
            if figsize is None:
                # Scale based on grid dimensions with a minimum size
                width = max(6, cols * 0.75 + 2)
                height = max(5, rows * 0.75 + 1)
                if colorbar:
                    width += 1  # Add space for the colorbar
                figsize = (width, height)
            fig, ax = plt.subplots(figsize=figsize)
        else:
            # Get the figure object from the provided axes
            fig = ax.figure

        # Set up the colormap; copy the default so that set_bad below does not
        # mutate the globally registered colormap
        if cmap is None:
            cmap = plt.cm.viridis.copy()
        cmap.set_bad("white", 1.0)

        # Create a custom norm so integer values are centered in the color bands
        norm = plt.Normalize(vmin=-0.5, vmax=np.max(grid) + 0.5)

        # Plot the grid
        im = ax.imshow(masked_grid, cmap=cmap, norm=norm, interpolation="nearest")

        # Add a colorbar if requested
        if colorbar:
            cbar = plt.colorbar(im, ax=ax, pad=0.01)
            cbar.set_label("Electrode Index")
            # Add tick labels only at integer positions
            cbar.set_ticks(np.arange(0, np.max(grid) + 1))

        # Improve grid lines:
        # major ticks at electrode centers
        ax.set_xticks(np.arange(0, cols, 1))
        ax.set_yticks(np.arange(0, rows, 1))
        # minor ticks at grid boundaries
        ax.set_xticks(np.arange(-0.5, cols, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, rows, 1), minor=True)
        # hide major tick labels for a cleaner look
        ax.set_xticklabels([])
        ax.set_yticklabels([])

        # Apply grid styling
        ax.grid(
            which="minor",
            color=grid_color,
            linestyle="-",
            linewidth=1,
            alpha=grid_alpha,
        )
        ax.tick_params(which="minor", bottom=False, left=False)

        # Add axis labels
        ax.set_xlabel("Columns", fontsize=text_fontsize + 1)
        ax.set_ylabel("Rows", fontsize=text_fontsize + 1)

        # Add the electrode numbers
        if show_indices:
            for i in range(rows):
                for j in range(cols):
                    if grid[i, j] >= 0:
                        # Text properties for the electrode index
                        text_props = {
                            "ha": "center",
                            "va": "center",
                            "color": text_color,
                            "fontsize": text_fontsize,
                            "fontweight": text_fontweight,
                        }

                        # Highlight this electrode if requested
                        if highlight_electrodes and grid[i, j] in highlight_electrodes:
                            # Draw a circle around the highlighted electrode
                            circle = plt.Circle(
                                (j, i),
                                0.4,
                                fill=False,
                                edgecolor=highlight_color,
                                linewidth=2,
                                alpha=0.8,
                            )
                            ax.add_patch(circle)

                            # Emphasize the text of the highlighted electrode
                            text_props["fontweight"] = "extra bold"

                        # Add the electrode index text
                        ax.text(j, i, str(grid[i, j]), **text_props)

        # Add the title
        ax.set_title(title, fontsize=text_fontsize + 4, pad=10)

        # Use an equal aspect ratio
        ax.set_aspect("equal")

        # Save the figure if a path was provided
        if save_path:
            plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
        # Show the figure if autoshow is True
        if autoshow:
            plt.tight_layout()
            plt.show()

        # Return figure and axes if requested
        if return_fig:
            return fig, ax
        return None
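    # Editorial sketch (hypothetical usage, not from the original docs): saving
    # a grid figure to disk without displaying it, using only parameters defined
    # above; "grid0.png" is an illustrative path.
    #
    #   >>> emg.plot_grid_layout(0, autoshow=False, save_path="grid0.png", dpi=300)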
[docs]
class KinematicsData(_Data):
    """Class for storing kinematics data.

    Parameters
    ----------
    input_data : np.ndarray
        The raw kinematics data. The shape of the array should be
        (n_joints, 3, n_samples) or (n_chunks, n_joints, 3, n_samples).

        .. important:: The class will only accept 3D or 4D arrays. There is no
            way to check if you actually have it in
            (n_chunks, n_joints, 3, n_samples) format. Please make sure to
            provide the correct shape of the data.

    sampling_frequency : float
        The sampling frequency of the kinematics data.

    Attributes
    ----------
    input_data : np.ndarray
        The raw kinematics data. The shape of the array should be
        (n_joints, 3, n_samples) or (n_chunks, n_joints, 3, n_samples).
        The 3 represents the x, y, and z coordinates of the joints.
    sampling_frequency : float
        The sampling frequency of the kinematics data.
    processed_data : Dict[str, np.ndarray]
        A dictionary where the keys are the names of filters applied to the
        kinematics data and the values are the processed kinematics data.

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datatypes import KinematicsData
    >>>
    >>> # Create sample kinematics data (16 joints, 3 coordinates, 1000 samples)
    >>> # Each joint has x, y, z coordinates
    >>> joint_data = np.random.randn(16, 3, 1000)
    >>>
    >>> # Create a KinematicsData object with a 100 Hz sampling rate
    >>> kinematics = KinematicsData(joint_data, 100)
    >>>
    >>> # Access the raw data
    >>> raw_data = kinematics.input_data
    >>> print(f"Data shape: {raw_data.shape}")
    Data shape: (16, 3, 1000)
    """
[docs]
    def __init__(self, input_data: np.ndarray, sampling_frequency: float):
        if input_data.ndim not in (3, 4):
            raise ValueError(
                "The shape of the raw kinematics data should be (n_joints, 3, n_samples) "
                "or (n_chunks, n_joints, 3, n_samples)."
            )

        super().__init__(
            input_data, sampling_frequency, nr_of_dimensions_when_unchunked=4
        )
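    # Editorial note (illustrative, not from the original docs): the shape check
    # above rejects anything that is not 3D or 4D, so a 2D (n_joints, n_samples)
    # array raises a ValueError:
    #
    #   >>> import numpy as np
    #   >>> KinematicsData(np.random.randn(16, 1000), 100)  # doctest: +SKIP
    #   ValueError: The shape of the raw kinematics data should be ...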
[docs]
    def plot(
        self, representation: str, nr_of_fingers: int, wrist_included: bool = True
    ):
        """Plots the data.

        Parameters
        ----------
        representation : str
            The representation to plot.

            .. important:: The representation should be a 3D tensor with shape
                (n_joints, 3, n_samples).

        nr_of_fingers : int
            The number of fingers to plot.
        wrist_included : bool, optional
            Whether the wrist is included in the representation. The default is
            True.

            .. note:: The wrist is always the first joint in the representation.

        Raises
        ------
        KeyError
            If the representation does not exist.

        Examples
        --------
        >>> import numpy as np
        >>> from myoverse.datatypes import KinematicsData
        >>>
        >>> # Create sample kinematics data for a hand with 5 fingers
        >>> # 21 joints: 1 wrist + 4 joints for each of the 5 fingers
        >>> joint_data = np.random.randn(21, 3, 100)
        >>> kinematics = KinematicsData(joint_data, 100)
        >>>
        >>> # Plot the kinematics data
        >>> kinematics.plot('Input', nr_of_fingers=5)
        >>>
        >>> # Plot without the wrist
        >>> kinematics.plot('Input', nr_of_fingers=5, wrist_included=False)
        """
        if representation not in self._data:
            raise KeyError(f'The representation "{representation}" does not exist.')

        kinematics = self[representation]

        if not wrist_included:
            # Prepend a dummy wrist joint at the origin so that joint 0 is
            # always the wrist
            kinematics = np.concatenate(
                [np.zeros((1, 3, kinematics.shape[2])), kinematics], axis=0
            )

        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")

        # Get the biggest axis range
        max_range = (
            np.array(
                [
                    kinematics[:, 0].max() - kinematics[:, 0].min(),
                    kinematics[:, 1].max() - kinematics[:, 1].min(),
                    kinematics[:, 2].max() - kinematics[:, 2].min(),
                ]
            ).max()
            / 2.0
        )

        # Set the axis limits
        ax.set_xlim(
            kinematics[:, 0].mean() - max_range,
            kinematics[:, 0].mean() + max_range,
        )
        ax.set_ylim(
            kinematics[:, 1].mean() - max_range,
            kinematics[:, 1].mean() + max_range,
        )
        ax.set_zlim(
            kinematics[:, 2].mean() - max_range,
            kinematics[:, 2].mean() + max_range,
        )

        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")

        # Create the joint and finger plots
        (joints_plot,) = ax.plot(*kinematics[..., 0].T, "o", color="black")

        finger_plots = []
        for finger in range(nr_of_fingers):
            finger_plots.append(
                ax.plot(
                    *kinematics[
                        [0] + list(reversed(range(1 + finger * 4, 5 + finger * 4))),
                        :,
                        0,
                    ].T,
                    color="blue",
                )
            )

        # Add a slider to scrub through the samples
        samp = plt.axes([0.25, 0.02, 0.65, 0.03])
        sample_slider = Slider(
            samp,
            label="Sample (a. u.)",
            valmin=0,
            valmax=kinematics.shape[2] - 1,
            valstep=1,
            valinit=0,
        )

        def update(val):
            kinematics_new_sample = kinematics[..., int(val)]

            joints_plot._verts3d = tuple(kinematics_new_sample.T)

            for finger in range(nr_of_fingers):
                finger_plots[finger][0]._verts3d = tuple(
                    kinematics_new_sample[
                        [0] + list(reversed(range(1 + finger * 4, 5 + finger * 4))),
                        :,
                    ].T
                )

            fig.canvas.draw_idle()

        sample_slider.on_changed(update)

        plt.tight_layout()
        plt.show()
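# Editorial note (illustrative, not from the original module): the finger chains
# in KinematicsData.plot() are built with
# [0] + list(reversed(range(1 + finger * 4, 5 + finger * 4))), which assumes
# 1 wrist joint followed by 4 joints per finger:
#
#   >>> [[0] + list(reversed(range(1 + f * 4, 5 + f * 4))) for f in range(2)]
#   [[0, 4, 3, 2, 1], [0, 8, 7, 6, 5]]
#
# i.e. each plotted chain starts at the wrist (joint 0) and traverses the four
# finger joints in reverse index order.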
[docs]
class VirtualHandKinematics(_Data):
    """Class for storing virtual hand kinematics data from MyoGestic [1]_.

    Parameters
    ----------
    input_data : np.ndarray
        The raw kinematics data for a virtual hand. The shape of the array
        should be (9, n_samples) or (n_chunks, 9, n_samples).

        .. important:: The class will only accept 2D or 3D arrays. There is no
            way to check if you actually have it in (9, n_samples) or
            (n_chunks, 9, n_samples) format. Please make sure to provide the
            correct shape of the data.

    sampling_frequency : float
        The sampling frequency of the kinematics data.

    Attributes
    ----------
    input_data : np.ndarray
        The raw kinematics data for a virtual hand. The shape of the array
        should be (9, n_samples) or (n_chunks, 9, n_samples). The 9 typically
        represents the degrees of freedom: wrist flexion/extension, wrist
        pronation/supination, wrist deviation, and the flexion of all 5 fingers.
    sampling_frequency : float
        The sampling frequency of the kinematics data.
    processed_data : Dict[str, np.ndarray]
        A dictionary where the keys are the names of filters applied to the
        kinematics data and the values are the processed kinematics data.

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datatypes import VirtualHandKinematics
    >>>
    >>> # Create sample virtual hand kinematics data (9 DOFs, 1000 samples)
    >>> joint_data = np.random.randn(9, 1000)
    >>>
    >>> # Create a VirtualHandKinematics object with a 100 Hz sampling rate
    >>> kinematics = VirtualHandKinematics(joint_data, 100)
    >>>
    >>> # Access the raw data
    >>> raw_data = kinematics.input_data
    >>> print(f"Data shape: {raw_data.shape}")
    Data shape: (9, 1000)

    References
    ----------
    .. [1] MyoGestic: https://github.com/NsquaredLab/MyoGestic
    """
[docs]
    def __init__(self, input_data: np.ndarray, sampling_frequency: float):
        if input_data.ndim not in (2, 3):
            raise ValueError(
                "The shape of the raw kinematics data should be (9, n_samples) "
                "or (n_chunks, 9, n_samples)."
            )

        super().__init__(
            input_data, sampling_frequency, nr_of_dimensions_when_unchunked=3
        )
[docs]
    def plot(
        self, representation: str, nr_of_fingers: int = 5, visualize_wrist: bool = True
    ):
        """Plots the virtual hand kinematics data.

        Parameters
        ----------
        representation : str
            The representation to plot. The representation should be a 2D tensor
            with shape (9, n_samples) or a 3D tensor with shape
            (n_chunks, 9, n_samples).
        nr_of_fingers : int, optional
            The number of fingers to plot. Default is 5.
        visualize_wrist : bool, optional
            Whether to visualize wrist movements. Default is True.

        Raises
        ------
        KeyError
            If the representation does not exist.
        ValueError
            If the representation does not have 9 degrees of freedom.
        """
        if representation not in self._data:
            raise KeyError(f'The representation "{representation}" does not exist.')

        data = self[representation]
        is_chunked = self.is_chunked[representation]

        if is_chunked:
            # Use only the first chunk for visualization
            data = data[0]

        # Check that we have the expected number of DOFs
        if data.shape[0] != 9:
            raise ValueError(f"Expected 9 degrees of freedom, but got {data.shape[0]}")

        fig = plt.figure(figsize=(12, 8))

        # Create separate subplots for the wrist and finger DOFs
        wrist_ax = fig.add_subplot(2, 1, 1)
        fingers_ax = fig.add_subplot(2, 1, 2)

        # Plot wrist DOFs (first 3 channels)
        if visualize_wrist:
            wrist_ax.set_title("Wrist Kinematics")
            wrist_ax.plot(data[0], label="Wrist Flexion/Extension")
            wrist_ax.plot(data[1], label="Wrist Pronation/Supination")
            wrist_ax.plot(data[2], label="Wrist Deviation")
            wrist_ax.legend()
            wrist_ax.set_xlabel("Time (samples)")
            wrist_ax.set_ylabel("Normalized Position")
            wrist_ax.grid(True)

        # Plot finger DOFs (remaining channels)
        fingers_ax.set_title("Finger Kinematics")
        finger_names = ["Thumb", "Index", "Middle", "Ring", "Pinky"]
        for i in range(min(nr_of_fingers, 5)):
            fingers_ax.plot(data[i + 3], label=finger_names[i])

        fingers_ax.legend()
        fingers_ax.set_xlabel("Time (samples)")
        fingers_ax.set_ylabel("Normalized Flexion")
        fingers_ax.grid(True)

        plt.tight_layout()
        plt.show()
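# Editorial sketch (hypothetical usage, not from the original docs): plotting a
# chunked representation; per the plot() code above, only the first chunk is
# visualized.
#
#   >>> import numpy as np
#   >>> vhk = VirtualHandKinematics(np.random.randn(4, 9, 500), 60)  # 4 chunks
#   >>> vhk.plot("Input", nr_of_fingers=5, visualize_wrist=True)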
DATA_TYPES_MAP = {
    "emg": EMGData,
    "kinematics": KinematicsData,
    "virtual_hand": VirtualHandKinematics,
}
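# Editorial sketch (illustrative): DATA_TYPES_MAP allows string-keyed dispatch,
# e.g. constructing a datatype class from a config value. The "emg" key and the
# EMGData signature come from this module; the config dict is hypothetical.
#
#   >>> import numpy as np
#   >>> config = {"type": "emg", "sampling_frequency": 2000}
#   >>> data_cls = DATA_TYPES_MAP[config["type"]]
#   >>> data = data_cls(np.random.randn(64, 1000), config["sampling_frequency"])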