DataSplitter

class myoverse.datasets.utils.DataSplitter(test_ratio=0.2, val_ratio=0.2)

Handles splitting data into training, testing, and validation sets.

The held-out sets are taken from the middle of the array rather than from its ends, which is useful for time-series data where you want to avoid temporal leakage at the boundaries.

Parameters:
  • test_ratio (float) – Ratio of data for testing (0.0 to 1.0).

  • val_ratio (float) – Ratio of data for validation (0.0 to 1.0).

Examples

>>> splitter = DataSplitter(test_ratio=0.2, val_ratio=0.2)
>>> result = splitter.split(data)
>>> print(result.sizes)
(800, 100, 100)

Methods

  • __init__([test_ratio, val_ratio])

  • _split_middle(data, ratio) – Split data by extracting a portion from the middle.

  • _validate_ratios() – Validate split ratios.

  • split(data) – Split data into training, testing, and validation sets.

  • split_dict(data_dict) – Split multiple arrays with the same split indices.

split(data)

Split data into training, testing, and validation sets.

The split is performed by extracting from the middle of the data:

  1. First, test data is extracted from the center.

  2. Then, validation data is extracted from the center of the test data.

Parameters:

data (np.ndarray) – Data to split. First dimension is assumed to be samples.

Returns:

Named tuple with training, testing, and validation arrays.

Return type:

SplitResult
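
The two-step middle extraction described for split can be sketched in plain NumPy. This is an illustration, not the library's implementation: the helper names split_middle and split_sketch are hypothetical, and the sketch assumes that val_ratio is applied as a fraction of the already-extracted test block, which the documentation above does not confirm. The sizes it produces may therefore differ from those returned by DataSplitter.

```python
import numpy as np


def split_middle(data: np.ndarray, ratio: float) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: return (remaining, middle).

    `middle` is a centered block of about `ratio * len(data)` samples taken
    along the first axis; `remaining` is everything on either side of it.
    """
    n = data.shape[0]
    n_middle = int(round(n * ratio))
    start = (n - n_middle) // 2
    stop = start + n_middle
    middle = data[start:stop]
    remaining = np.concatenate([data[:start], data[stop:]], axis=0)
    return remaining, middle


def split_sketch(data: np.ndarray, test_ratio: float = 0.2, val_ratio: float = 0.2):
    """Two-step split: test from the center of the data, validation from the center of test."""
    training, testing = split_middle(data, test_ratio)
    # Assumption: val_ratio is relative to the test block, not to the full array.
    testing, validation = split_middle(testing, val_ratio)
    return training, testing, validation


data = np.arange(1000)
training, testing, validation = split_sketch(data)
print(len(training), len(testing), len(validation))  # 800 160 40
print(validation[0], validation[-1])  # 480 519
```

Under this assumption the validation block sits at the very center of the recording, surrounded on both sides by test samples, which in turn are surrounded by training samples.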

split_dict(data_dict)

Split multiple arrays with the same split indices.

Parameters:

data_dict (dict[str, np.ndarray]) – Dictionary of arrays to split.

Returns:

Dictionary mapping keys to their split results.

Return type:

dict[str, SplitResult]
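
To show why sharing split indices matters, here is a minimal sketch that computes the boundaries once and applies them to every array, so that row i of one array stays aligned with row i of the others. It does not use the library: the function names, the dictionary keys "emg" and "labels", and the plain (training, testing, validation) tuples standing in for SplitResult are all hypothetical. As before, val_ratio is assumed to be a fraction of the test block.

```python
import numpy as np


def middle_bounds(n: int, ratio: float) -> tuple[int, int]:
    """Hypothetical helper: start/stop indices of a centered block of ~ratio * n samples."""
    n_middle = int(round(n * ratio))
    start = (n - n_middle) // 2
    return start, start + n_middle


def split_dict_sketch(data_dict, test_ratio=0.2, val_ratio=0.2):
    """Split every array in `data_dict` with one shared set of indices."""
    lengths = {arr.shape[0] for arr in data_dict.values()}
    if len(lengths) != 1:
        raise ValueError("All arrays must have the same number of samples.")
    n = lengths.pop()

    # Boundaries of the test block within the full array.
    t0, t1 = middle_bounds(n, test_ratio)
    # Boundaries of the validation block, shifted into full-array coordinates.
    v0, v1 = middle_bounds(t1 - t0, val_ratio)
    v0, v1 = t0 + v0, t0 + v1

    result = {}
    for key, arr in data_dict.items():
        training = np.concatenate([arr[:t0], arr[t1:]], axis=0)
        testing = np.concatenate([arr[t0:v0], arr[v1:t1]], axis=0)
        validation = arr[v0:v1]
        result[key] = (training, testing, validation)
    return result


data_dict = {
    "emg": np.arange(400).reshape(100, 4),  # 100 samples, 4 channels
    "labels": np.arange(100),               # one label per sample
}
splits = split_dict_sketch(data_dict)
print(splits["labels"][2])       # [48 49 50 51]
print(splits["emg"][2].shape)    # (4, 4)
```

Because both arrays are cut at the same positions, the validation labels 48 to 51 correspond exactly to rows 48 to 51 of the signal array.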