Source code for myoverse.datatypes.kinematics

"""Kinematics data type for joint position tracking."""

from __future__ import annotations

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.widgets import Slider

from myoverse.datatypes.base import _Data


class KinematicsData(_Data):
    """Class for storing kinematics data.

    Parameters
    ----------
    input_data : np.ndarray
        The raw kinematics data. The shape of the array should be
        (n_joints, 3, n_samples) or (n_chunks, n_joints, 3, n_samples).

        .. important:: The class will only accept 3D or 4D arrays.
            There is no way to check if you actually have it in
            (n_chunks, n_joints, 3, n_samples) format.
            Please make sure to provide the correct shape of the data.
    sampling_frequency : float
        The sampling frequency of the kinematics data.

    Attributes
    ----------
    input_data : np.ndarray
        The raw kinematics data. The shape of the array should be
        (n_joints, 3, n_samples) or (n_chunks, n_joints, 3, n_samples).
        The 3 represents the x, y, and z coordinates of the joints.
    sampling_frequency : float
        The sampling frequency of the kinematics data.
    processed_data : dict[str, np.ndarray]
        A dictionary where the keys are the names of filters applied to the
        kinematics data and the values are the processed kinematics data.

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datatypes import KinematicsData
    >>>
    >>> # Create sample kinematics data (16 joints, 3 coordinates, 1000 samples)
    >>> # Each joint has x, y, z coordinates
    >>> joint_data = np.random.randn(16, 3, 1000)
    >>>
    >>> # Create a KinematicsData object with 100 Hz sampling rate
    >>> kinematics = KinematicsData(joint_data, 100)
    >>>
    >>> # Access the raw data
    >>> raw_data = kinematics.input_data
    >>> print(f"Data shape: {raw_data.shape}")
    Data shape: (16, 3, 1000)
    """

    def __init__(self, input_data: np.ndarray, sampling_frequency: float):
        """Check the rank of ``input_data`` and pass it on to the base class.

        Parameters
        ----------
        input_data : np.ndarray
            Array with 3 or 4 dimensions, see the class documentation.
        sampling_frequency : float
            The sampling frequency of the kinematics data.

        Raises
        ------
        ValueError
            If ``input_data`` is neither 3D nor 4D.
        """
        if input_data.ndim not in (3, 4):
            raise ValueError(
                "The shape of the raw kinematics data should be (n_joints, 3, n_samples) "
                "or (n_chunks, n_joints, 3, n_samples).",
            )

        super().__init__(
            input_data,
            sampling_frequency,
            nr_of_dimensions_when_unchunked=4,
        )

    def plot(
        self,
        representation: str,
        nr_of_fingers: int,
        wrist_included: bool = True,
    ) -> None:
        """Plots the data.

        Opens an interactive 3D figure showing all joints as black markers and
        one blue line per finger. A slider below the plot selects the sample
        that is displayed.

        Parameters
        ----------
        representation : str
            The representation to plot.

            .. important :: The representation should be a 3D tensor with shape
                (n_joints, 3, n_samples).
        nr_of_fingers : int
            The number of fingers to plot. Every finger is expected to occupy
            4 consecutive joints after the wrist, so the representation needs
            at least ``1 + 4 * nr_of_fingers`` joints (one fewer if the wrist
            is not included).
        wrist_included : bool, optional
            Whether the wrist is included in the representation. The default is True.
            If False, a wrist located at the origin is prepended before plotting.

            .. note :: The wrist is always the first joint in the representation.

        Raises
        ------
        KeyError
            If the representation does not exist.
        ValueError
            If the representation is not a 3D array or does not contain enough
            joints for ``nr_of_fingers`` fingers.

        Examples
        --------
        >>> import numpy as np
        >>> from myoverse.datatypes import KinematicsData
        >>>
        >>> # Create sample kinematics data for a hand with 5 fingers
        >>> # 21 joints: 1 wrist + 4 joints for each of the 5 fingers
        >>> joint_data = np.random.randn(21, 3, 100)
        >>> kinematics = KinematicsData(joint_data, 100)
        >>>
        >>> # Plot the kinematics data
        >>> kinematics.plot('Input', nr_of_fingers=5)
        >>>
        >>> # Plot data that has no wrist joint (20 joints: 4 per finger)
        >>> no_wrist = KinematicsData(joint_data[1:], 100)
        >>> no_wrist.plot('Input', nr_of_fingers=5, wrist_included=False)
        """
        if representation not in self._data:
            raise KeyError(f'The representation "{representation}" does not exist.')

        kinematics = self[representation]

        # Chunked (4D) data would be indexed along the wrong axes below, so
        # reject it explicitly instead of drawing nonsense.
        if kinematics.ndim != 3:
            raise ValueError(
                "Only representations with shape (n_joints, 3, n_samples) can be "
                f"plotted, but the data has {kinematics.ndim} dimensions.",
            )

        if not wrist_included:
            # Prepend a wrist at the origin so joint 0 is always the wrist.
            kinematics = np.concatenate(
                [np.zeros((1, 3, kinematics.shape[2])), kinematics],
                axis=0,
            )

        joints_per_finger = 4
        required_joints = 1 + joints_per_finger * nr_of_fingers
        if kinematics.shape[0] < required_joints:
            raise ValueError(
                f"Plotting {nr_of_fingers} fingers requires at least "
                f"{required_joints} joints (wrist + {joints_per_finger} per finger), "
                f"but only {kinematics.shape[0]} are available.",
            )

        # Each finger line runs from the wrist to the finger's last joint index
        # and then down to its first one. Computed once and reused by the
        # slider callback.
        finger_chains = [
            [
                0,
                *reversed(
                    range(
                        1 + finger * joints_per_finger,
                        1 + (finger + 1) * joints_per_finger,
                    ),
                ),
            ]
            for finger in range(nr_of_fingers)
        ]

        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")

        # Use the largest coordinate span for all three axes so that the
        # proportions are not distorted.
        max_range = (
            np.max(
                [
                    kinematics[:, axis].max() - kinematics[:, axis].min()
                    for axis in range(3)
                ],
            )
            / 2.0
        )

        # Centre every axis on the mean of its coordinate.
        for axis, set_limits in enumerate((ax.set_xlim, ax.set_ylim, ax.set_zlim)):
            centre = kinematics[:, axis].mean()
            set_limits(centre - max_range, centre + max_range)

        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")

        # Create joint and finger plots for the first sample.
        (joints_plot,) = ax.plot(*kinematics[..., 0].T, "o", color="black")
        finger_plots = [
            ax.plot(*kinematics[chain, :, 0].T, color="blue")[0]
            for chain in finger_chains
        ]

        slider_ax = fig.add_axes([0.25, 0.02, 0.65, 0.03])
        sample_slider = Slider(
            slider_ax,
            label="Sample (a. u.)",
            valmin=0,
            valmax=kinematics.shape[2] - 1,
            valstep=1,
            valinit=0,
        )

        def update(val):
            """Redraw joints and fingers for the sample chosen on the slider."""
            kinematics_new_sample = kinematics[..., int(val)]

            joints_plot.set_data_3d(*kinematics_new_sample.T)
            for finger_plot, chain in zip(finger_plots, finger_chains):
                finger_plot.set_data_3d(*kinematics_new_sample[chain, :].T)

            fig.canvas.draw_idle()

        sample_slider.on_changed(update)

        plt.tight_layout()
        plt.show()