"""Base class for all data types."""
from __future__ import annotations
import copy
import os
import pickle
from abc import abstractmethod
from typing import Any
import numpy as np
from myoverse.datatypes.types import (
DeletedRepresentation,
InputRepresentationName,
LastRepresentationName,
Representation,
)
class _Data:
    """Base class for all data types.

    This class provides common functionality for handling different types of data,
    including maintaining original and processed representations.

    Parameters
    ----------
    raw_data : np.ndarray
        The raw data to store.
    sampling_frequency : float
        The sampling frequency of the data. Must be greater than 0.
    nr_of_dimensions_when_unchunked : int
        Number of array dimensions that ``_check_if_chunked`` compares each
        representation's dimensionality against (see the note in that method).

    Attributes
    ----------
    sampling_frequency : float
        The sampling frequency of the data.
    nr_of_dimensions_when_unchunked : int
        The dimensionality passed at construction, used by ``_check_if_chunked``.
    _last_processing_step : str
        The last processing step applied to the data.
    _data : dict[str, np.ndarray | DeletedRepresentation]
        Dictionary of all data. The keys are the names of the representations and the values are
        either numpy arrays or DeletedRepresentation objects (for representations that have been
        deleted to save memory).

    Raises
    ------
    ValueError
        If the sampling frequency is not greater than 0 (NaN is rejected as well).

    Notes
    -----
    Memory Management:
    When representations are deleted with delete_data(), they are replaced with
    DeletedRepresentation objects that store essential metadata (shape, dtype)
    but don't consume memory for the actual data. The chunking status is determined from
    the shape when needed.

    ``plot`` is decorated with ``abc.abstractmethod``, but this class does not derive
    from ``abc.ABC``, so the decorator is not enforced at instantiation time; a
    subclass that does not implement ``plot`` only fails when ``plot`` is called.

    Examples
    --------
    This is an abstract base class and should not be instantiated directly.
    Instead, use one of the concrete subclasses like EMGData or KinematicsData:

    >>> import numpy as np
    >>> from myoverse.datatypes import EMGData
    >>>
    >>> # Create sample data
    >>> data = np.random.randn(16, 1000)
    >>> emg = EMGData(data, 2000)  # 2000 Hz sampling rate
    >>>
    >>> # Access attributes from the base _Data class
    >>> print(f"Sampling frequency: {emg.sampling_frequency} Hz")
    >>> print(f"Is input data chunked: {emg.is_chunked['Input']}")
    """

    def __init__(
        self,
        raw_data: np.ndarray,
        sampling_frequency: float,
        nr_of_dimensions_when_unchunked: int,
    ):
        self.sampling_frequency: float = sampling_frequency
        self.nr_of_dimensions_when_unchunked: int = nr_of_dimensions_when_unchunked

        # ``not (x > 0)`` instead of ``x <= 0`` so that NaN is rejected too:
        # every ordered comparison involving NaN is False.
        if not (self.sampling_frequency > 0):
            raise ValueError("The sampling frequency should be greater than 0.")

        # The input representation is always present; delete_data() never removes it.
        self._data: dict[str, np.ndarray | DeletedRepresentation] = {
            InputRepresentationName: raw_data,
        }
        self.__last_processing_step: str = InputRepresentationName

    @property
    def is_chunked(self) -> dict[str, bool]:
        """Returns whether each representation is chunked or not.

        Returns
        -------
        dict[str, bool]
            A dictionary where the keys are the representations and the values are whether the data is chunked or not.
        """
        # Recomputed on every access. The check is a single ndim comparison per
        # representation, whereas a cache invalidated only when the number of
        # entries changes returns stale results once an entry of ``_data`` is
        # replaced by data of a different dimensionality. The result is still
        # stored in ``_chunked_cache`` for backward compatibility.
        self._chunked_cache = {
            key: self._check_if_chunked(value) for key, value in self._data.items()
        }
        return self._chunked_cache

    def _check_if_chunked(self, data: np.ndarray | DeletedRepresentation) -> bool:
        """Checks if the data is chunked or not.

        Parameters
        ----------
        data : np.ndarray | DeletedRepresentation
            The data to check. Only its ``shape`` attribute is read, so deleted
            representations are classified from their recorded shape.

        Returns
        -------
        bool
            True when the number of dimensions of ``data`` equals
            ``nr_of_dimensions_when_unchunked``.
        """
        # NOTE(review): despite the attribute's name, this reports "chunked" when
        # ndim *equals* ``nr_of_dimensions_when_unchunked``. Presumably subclasses
        # pass the dimensionality of the chunked layout -- verify against the
        # subclasses before changing this comparison.
        return len(data.shape) == self.nr_of_dimensions_when_unchunked

    @property
    def input_data(self) -> np.ndarray:
        """Returns the input data.

        This is the stored array itself, not a copy; index the object with
        ``InputRepresentationName`` to obtain an independent copy.
        """
        return self._data[InputRepresentationName]

    @input_data.setter
    def input_data(self, value: np.ndarray):
        raise RuntimeError("This property is read-only.")

    @property
    def processed_representations(
        self,
    ) -> dict[str, np.ndarray | DeletedRepresentation]:
        """Returns all stored representations, including the input.

        The internal dictionary is returned directly (not a copy), so changes made
        to it affect this object. Deleted representations appear as
        DeletedRepresentation objects.
        """
        return self._data

    @processed_representations.setter
    def processed_representations(self, value: dict[str, Representation]):
        raise RuntimeError("This property is read-only.")

    @property
    def _last_processing_step(self) -> str:
        """Returns the last processing step applied to the data.

        Returns
        -------
        str
            The last processing step applied to the data.

        Raises
        ------
        ValueError
            If the stored step is None.
        """
        if self.__last_processing_step is None:
            raise ValueError("No processing steps have been applied.")
        return self.__last_processing_step

    @_last_processing_step.setter
    def _last_processing_step(self, value: str):
        """Sets the last processing step applied to the data.

        Parameters
        ----------
        value : str
            The last processing step applied to the data.
        """
        self.__last_processing_step = value

    @abstractmethod
    def plot(self, *_: Any, **__: Any):
        """Plots the data. Must be implemented by subclasses."""
        raise NotImplementedError(
            "This method should be implemented in the child class.",
        )

    def __repr__(self) -> str:
        # Read the shape straight from ``_data`` so no array is copied.
        input_shape = self._data[InputRepresentationName].shape

        lines = [
            f"{self.__class__.__name__}",
            f"Sampling frequency: {self.sampling_frequency} Hz",
            f"(0) Input {input_shape}",
        ]

        # List every representation other than the input, numbered from 1.
        other_reps = [k for k in self._data if k != InputRepresentationName]
        if other_reps:
            lines.append("")
            lines.append("Representations:")
            for idx, rep_name in enumerate(other_reps, 1):
                # Both np.ndarray and DeletedRepresentation expose ``.shape``.
                lines.append(f"({idx}) {rep_name} {self._data[rep_name].shape}")

        return "\n".join(lines)

    def __str__(self) -> str:
        # The replacements only have an effect if a representation name contains
        # "; " or "Filter(s): "; __repr__ itself emits neither.
        return (
            "--\n"
            + self.__repr__()
            .replace("; ", "\n")
            .replace("Filter(s): ", "\nFilter(s):\n")
            + "\n--"
        )

    def __getitem__(self, key: str) -> np.ndarray:
        """Return a copy of a stored representation.

        Parameters
        ----------
        key : str
            Name of the representation. ``LastRepresentationName`` resolves to the
            most recent processing step.

        Returns
        -------
        np.ndarray
            An independent copy of the stored array; modifying it does not change
            the data held by this object.

        Raises
        ------
        KeyError
            If the representation does not exist.
        RuntimeError
            If the representation was deleted with ``delete_data``.
        """
        # Always copy. ``ndarray.view()`` is not a copy: a view shares memory with
        # the stored array, so in-place edits by the caller would silently modify
        # the stored representation despite the read-only API.
        if key == InputRepresentationName:
            return self.input_data.copy()

        if key == LastRepresentationName:
            return self[self._last_processing_step]

        if key not in self._data:
            raise KeyError(f'The representation "{key}" does not exist.')

        data_to_return = self._data[key]
        if isinstance(data_to_return, DeletedRepresentation):
            raise RuntimeError(
                f'The representation "{key}" was deleted and cannot be automatically '
                f"recomputed. Use the new Transform API for preprocessing.",
            )

        return data_to_return.copy()

    def __setitem__(self, key: str, value: np.ndarray) -> None:
        raise RuntimeError(
            "Direct assignment is not supported. Use the Transform API for preprocessing.",
        )

    def delete_data(self, representation_to_delete: str):
        """Delete data from a representation while keeping its metadata.

        This replaces the stored numpy array with a DeletedRepresentation object
        that records the array's shape and dtype. The array's memory is released
        once nothing else references it. A deleted representation can no longer be
        read: indexing it raises RuntimeError.

        Requesting deletion of the input representation is a silent no-op, as is
        deleting a representation that was already deleted.

        Parameters
        ----------
        representation_to_delete : str
            The representation to delete the data from. ``LastRepresentationName``
            resolves to the most recent processing step.

        Raises
        ------
        KeyError
            If the representation does not exist.
        """
        if representation_to_delete == InputRepresentationName:
            return
        if representation_to_delete == LastRepresentationName:
            self.delete_data(self._last_processing_step)
            return

        if representation_to_delete not in self._data:
            raise KeyError(
                f'The representation "{representation_to_delete}" does not exist.',
            )

        data = self._data[representation_to_delete]
        if isinstance(data, np.ndarray):
            self._data[representation_to_delete] = DeletedRepresentation(
                shape=data.shape,
                dtype=data.dtype,
            )

    def __copy__(self) -> _Data:
        """Create a copy of the instance.

        Although this is the ``copy.copy`` hook, every representation in ``_data``
        is deep-copied, so arrays of the copy are independent of the original.
        Other instance attributes (e.g. those added by subclasses) are copied with
        ``copy.copy``.

        The subclass constructor is invoked as ``cls(raw_data, sampling_frequency)``,
        so subclasses must accept that call signature.

        Returns
        -------
        _Data
            The new instance.
        """
        # Share one deepcopy memo between the two copies below: the second
        # deepcopy finds the input array in the memo and reuses ``input_copy``
        # instead of copying the (possibly large) input a second time.
        memo: dict[int, Any] = {}
        input_copy = copy.deepcopy(self._data[InputRepresentationName], memo)

        # Build the new instance through the subclass constructor so that any
        # initialisation it performs still runs.
        new_instance = self.__class__(input_copy, self.sampling_frequency)

        # Deep copy the data dictionary to preserve all representations.
        new_instance._data = copy.deepcopy(self._data, memo)

        # Copy the last processing step (name-mangled attribute of _Data).
        new_instance._Data__last_processing_step = self._Data__last_processing_step

        # Copy any additional instance attributes from subclasses, skipping those
        # handled above or set by the constructor. ``_chunked_cache`` is derived
        # data that ``is_chunked`` rebuilds on demand.
        skip_attrs = {
            "_data",
            "_Data__last_processing_step",
            "sampling_frequency",
            "nr_of_dimensions_when_unchunked",
            "_chunked_cache",
        }
        for name, value in vars(self).items():
            if name not in skip_attrs:
                setattr(new_instance, name, copy.copy(value))

        return new_instance

    def save(self, filename: str):
        """Save the data to a file using pickle.

        Missing parent directories of ``filename`` are created.

        Parameters
        ----------
        filename : str
            The name of the file to save the data to.
        """
        os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
        with open(filename, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, filename: str) -> _Data:
        """Load data from a file written by ``save``.

        Warning: the file is read with ``pickle.load``, which can execute arbitrary
        code embedded in the file. Only load files from trusted sources. The type
        of the loaded object is not checked against ``cls``.

        Parameters
        ----------
        filename : str
            The name of the file to load the data from.

        Returns
        -------
        _Data
            The loaded data.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)

    def memory_usage(self) -> dict[str, tuple[str, int]]:
        """Calculate memory usage of each representation.

        Returns
        -------
        dict[str, tuple[str, int]]
            Dictionary with representation names as keys and tuples containing
            shape as string and memory usage in bytes as values. Deleted
            representations report 0 bytes; values of any other type are omitted.
        """
        memory_usage = {}
        for key, value in self._data.items():
            if isinstance(value, np.ndarray):
                memory_usage[key] = (str(value.shape), value.nbytes)
            elif isinstance(value, DeletedRepresentation):
                # DeletedRepresentation objects use negligible memory.
                memory_usage[key] = (str(value.shape), 0)
        return memory_usage