diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py index 21cf8ba02..23f2ba37d 100644 --- a/src/levanter/data/mixture.py +++ b/src/levanter/data/mixture.py @@ -35,11 +35,12 @@ def __init__( self, datasets: Mapping[str, ShardableDataset[T]], weights: Dict[str, float], - stop_strategy: str = StopStrategy, + stop_strategy: str = StopStrategy.FIRST_STOP_STRATEGY, key: int | PRNGKeyArray = 0, ): self.datasets = datasets self.weights = MixtureDataset._normalize_weights(weights) + if stop_strategy not in [StopStrategy.FIRST_STOP_STRATEGY, StopStrategy.ALL_STOP_STRATEGY]: raise ValueError(f"Stop strategy {stop_strategy} is not supported.") print(f"=== class MixtureDataset: self.datasets.keys() = {self.datasets.keys()}, self.weights = {self.weights}, stop_strategy = {stop_strategy} ===")