Toolkit
Here you can find the toolkit functions built on top of DataChain for common DS/ML operations (e.g., train/test split). Import these functions from datachain.toolkit.
train_test_split
train_test_split(
    dc: DataChain,
    weights: list[float],
    seed: Optional[int] = None,
) -> list[DataChain]
Splits a DataChain into multiple subsets based on the provided weights.
This function partitions the rows or items of a DataChain into disjoint subsets, ensuring that the relative sizes of the subsets correspond to the given weights. It is particularly useful for creating training, validation, and test datasets.
Parameters:

- dc (DataChain): The DataChain instance to split.
- weights (list[float]): A list of weights indicating the relative proportions of the splits. The weights do not need to sum to 1; they are normalized internally. For example, [0.7, 0.3] corresponds to a 70/30 split, and [2, 1, 1] corresponds to a 50/25/25 split.
- seed (int, default: None): The seed for the random number generator. Defaults to None.
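The weight normalization described above can be illustrated with a small standalone sketch (plain Python, not DataChain's internal code):

```python
def normalize_weights(weights: list[float]) -> list[float]:
    """Scale weights so they sum to 1, as train_test_split does internally."""
    total = sum(weights)
    return [w / total for w in weights]

print(normalize_weights([0.7, 0.3]))  # -> [0.7, 0.3], a 70/30 split
print(normalize_weights([2, 1, 1]))   # -> [0.5, 0.25, 0.25], a 50/25/25 split
```

This is why any positive weights work: only their ratios matter.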
Returns:

- list[DataChain]: A list of DataChain instances, one for each weight in the weights list.
Examples:
Train-test split:
from datachain import DataChain
from datachain.toolkit import train_test_split
# Load a DataChain from a storage source (e.g., S3 bucket)
dc = DataChain.from_storage("s3://bucket/dir/")
# Perform a 70/30 train-test split
train, test = train_test_split(dc, [0.7, 0.3])
# Save the resulting splits
train.save("dataset_train")
test.save("dataset_test")
Train-test-validation split:
train, test, val = train_test_split(dc, [0.7, 0.2, 0.1])
train.save("dataset_train")
test.save("dataset_test")
val.save("dataset_val")
Note

The splits are random but deterministic, based on the dataset's sys__rand field.
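To make the determinism concrete, here is a minimal sketch of how a split keyed on a persistent per-row random value could work. The `rand` key below is a hypothetical stand-in for the sys__rand field, and the whole function is an illustration of the idea, not DataChain's actual implementation:

```python
def split_rows(rows: list[dict], weights: list[float]) -> list[list[dict]]:
    """Partition rows into buckets proportional to weights, using each
    row's persistent 'rand' value (illustrative stand-in for sys__rand).
    Because 'rand' is stored with the row, repeated calls with the same
    weights always produce the same partition."""
    total = sum(weights)
    # Cumulative normalized boundaries, e.g. [2, 1, 1] -> [0.5, 0.75, 1.0]
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    resolution = 2**31
    buckets: list[list[dict]] = [[] for _ in weights]
    for row in rows:
        # Map the persistent random value into [0, 1) and pick a bucket
        frac = (row["rand"] % resolution) / resolution
        for i, bound in enumerate(bounds):
            if frac < bound:
                buckets[i].append(row)
                break
    return buckets

# Hypothetical rows with pre-assigned random values
rows = [{"id": i, "rand": i * 982451653} for i in range(100)]
train, test = split_rows(rows, [0.7, 0.3])
```

Running the split twice over the same rows yields identical buckets, which is the property the note above describes: randomness comes from the stored field, not from the call site.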