diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py index dbe255748..51f023587 100644 --- a/src/levanter/data/mixture.py +++ b/src/levanter/data/mixture.py @@ -36,7 +36,8 @@ def __init__( self, datasets: Mapping[str, ShardableDataset[T]], weights: Dict[str, float], - stop_strategy: str = StopStrategy.FIRST_STOP_STRATEGY, + # stop_strategy: str = StopStrategy, + stop_strategy: str = None, key: int | PRNGKeyArray = 0, ): self.datasets = datasets @@ -44,7 +45,7 @@ def __init__( if stop_strategy not in [StopStrategy.FIRST_STOP_STRATEGY, StopStrategy.ALL_STOP_STRATEGY]: raise ValueError(f"Stop strategy {stop_strategy} is not supported.") - + print(f"=== Using stop_strategy {stop_strategy} ===") self.stop_strategy = stop_strategy if not isinstance(key, int):