
Commit 737b9de
No public description
PiperOrigin-RevId: 651908964
chandrasekhard2 authored and tensorflower-gardener committed Jul 12, 2024
1 parent 5bad749 commit 737b9de
Showing 5 changed files with 47 additions and 15 deletions.
1 change: 1 addition & 0 deletions official/recommendation/ranking/configs/config.py
@@ -74,6 +74,7 @@ class DataConfig(hyperparams.Config):
   cycle_length: int = 10
   sharding: bool = True
   num_shards_per_host: int = 8
+  use_cached_data: bool = False
 
 
 @dataclasses.dataclass
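The new use_cached_data flag defaults to False and is threaded from DataConfig into each of the dataset readers changed below. A minimal usage sketch, mirroring how the updated tests construct the config (values are illustrative):

from official.recommendation.ranking.configs import config

# Sketch only: enable one-batch caching on the training input.
train_data = config.DataConfig(
    global_batch_size=16,
    use_cached_data=True,  # flag added in this commit; defaults to False
)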
9 changes: 8 additions & 1 deletion official/recommendation/ranking/data/data_pipeline.py
@@ -38,12 +38,14 @@ def __init__(self,
                params: config.DataConfig,
                num_dense_features: int,
                vocab_sizes: List[int],
-               use_synthetic_data: bool = False):
+               use_synthetic_data: bool = False,
+               use_cached_data: bool = False):
     self._file_pattern = file_pattern
     self._params = params
     self._num_dense_features = num_dense_features
     self._vocab_sizes = vocab_sizes
     self._use_synthetic_data = use_synthetic_data
+    self._use_cached_data = use_cached_data
 
   def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
     params = self._params
@@ -117,6 +119,8 @@ def make_dataset(shard_index):
           num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
     dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+    if self._use_cached_data:
+      dataset = dataset.take(1).cache().repeat()
 
     return dataset
 
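For context on the idiom added here: take(1) truncates the dataset to its first element, cache() stores that element in memory on the first pass, and repeat() replays it indefinitely, so the pipeline serves identical data forever at near-zero input cost. This is a common way to benchmark or debug model step time with the input pipeline taken out of the picture; it also makes evaluation datasets infinite, so consumers must bound the number of steps. A standalone sketch of the behavior (not part of the diff):

import tensorflow as tf

# Yields the same first batch forever.
dataset = tf.data.Dataset.range(100).batch(8)
cached = dataset.take(1).cache().repeat()

for batch in cached.take(3):
  print(batch.numpy())  # prints [0 1 2 3 4 5 6 7] three times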
@@ -173,6 +177,9 @@ def _generate_synthetic_data(self, ctx: tf.distribute.InputContext,
     if params.is_training:
       dataset = dataset.repeat()
 
+    if self._use_cached_data:
+      dataset = dataset.take(1).cache().repeat()
+
     return dataset.batch(batch_size, drop_remainder=True)
 
 
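Note the ordering in the synthetic path above: the caching block runs before the final .batch(batch_size, drop_remainder=True), so take(1) retains a single synthetic example and every emitted batch is that one cached example repeated batch_size times. In the file-reading path, by contrast, the block sits at the very end of the pipeline, after prefetch.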
12 changes: 10 additions & 2 deletions official/recommendation/ranking/data/data_pipeline_multi_hot.py
@@ -45,13 +45,15 @@ def __init__(self,
                num_dense_features: int,
                vocab_sizes: List[int],
                multi_hot_sizes: List[int],
-               use_synthetic_data: bool = False):
+               use_synthetic_data: bool = False,
+               use_cached_data: bool = False):
     self._file_pattern = file_pattern
     self._params = params
     self._num_dense_features = num_dense_features
     self._vocab_sizes = vocab_sizes
     self._use_synthetic_data = use_synthetic_data
     self._multi_hot_sizes = multi_hot_sizes
+    self._use_cached_data = use_cached_data
 
   def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
     params = self._params
@@ -144,6 +146,8 @@ def make_dataset(shard_index):
           num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
     dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+    if self._use_cached_data:
+      dataset = dataset.take(1).cache().repeat()
 
     return dataset
 
@@ -215,12 +219,14 @@ def __init__(self,
                params: config.DataConfig,
                num_dense_features: int,
                vocab_sizes: List[int],
-               multi_hot_sizes: List[int],):
+               multi_hot_sizes: List[int],
+               use_cached_data: bool = False):
     self._file_pattern = file_pattern
     self._params = params
     self._num_dense_features = num_dense_features
     self._vocab_sizes = vocab_sizes
     self._multi_hot_sizes = multi_hot_sizes
+    self._use_cached_data = use_cached_data
 
     self.label_features = 'label'
     self.dense_features = ['dense-feature-%d' % x for x in range(1, 14)]
@@ -307,6 +313,8 @@ def make_dataset(shard_index):
             num_parallel_calls=tf.data.experimental.AUTOTUNE,
         )
     dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+    if self._use_cached_data:
+      dataset = dataset.take(1).cache().repeat()
 
     return dataset
 
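Both readers implement __call__(self, ctx: tf.distribute.InputContext), the signature tf.distribute expects from a per-replica dataset function, which is how these pipelines are typically consumed. A self-contained sketch of that contract, using a stand-in function rather than this repo's reader classes:

import tensorflow as tf

def dataset_fn(ctx: tf.distribute.InputContext) -> tf.data.Dataset:
  # Derive the per-replica batch size from an assumed global batch size of 16.
  batch_size = ctx.get_per_replica_batch_size(16)
  return tf.data.Dataset.range(1000).batch(batch_size, drop_remainder=True)

strategy = tf.distribute.MirroredStrategy()
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)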
16 changes: 11 additions & 5 deletions official/recommendation/ranking/data/data_pipeline_multi_hot_test.py
@@ -23,9 +23,13 @@
 
 class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
 
-  @parameterized.named_parameters(('Train', True),
-                                  ('Eval', False))
-  def testSyntheticDataPipeline(self, is_training):
+  @parameterized.named_parameters(
+      ('TrainCached', True, True),
+      ('EvalNotCached', False, False),
+      ('TrainNotCached', True, False),
+      ('EvalCached', False, True),
+  )
+  def testSyntheticDataPipeline(self, is_training, use_cached_data):
     task = config.Task(
         model=config.ModelConfig(
             embedding_dim=4,
@@ -39,8 +43,10 @@ def testSyntheticDataPipeline(self, is_training):
             dcn_low_rank_dim=64,
             bottom_mlp=[64, 32, 4],
             top_mlp=[64, 32, 1]),
-        train_data=config.DataConfig(global_batch_size=16),
-        validation_data=config.DataConfig(global_batch_size=16),
+        train_data=config.DataConfig(global_batch_size=16,
+                                     use_cached_data=use_cached_data),
+        validation_data=config.DataConfig(global_batch_size=16,
+                                          use_cached_data=use_cached_data),
         use_synthetic_data=True)
 
     num_dense_features = task.model.num_dense_features
24 changes: 17 additions & 7 deletions official/recommendation/ranking/data/data_pipeline_test.py
@@ -23,19 +23,29 @@
 
 class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
 
-  @parameterized.named_parameters(('Train', True),
-                                  ('Eval', False))
-  def testSyntheticDataPipeline(self, is_training):
+  @parameterized.named_parameters(
+      ('TrainCached', True, True),
+      ('EvalNotCached', False, False),
+      ('TrainNotCached', True, False),
+      ('EvalCached', False, True),
+  )
+  def testSyntheticDataPipeline(self, is_training, use_cached_data):
     task = config.Task(
         model=config.ModelConfig(
             embedding_dim=4,
             num_dense_features=8,
             vocab_sizes=[40, 12, 11, 13, 2, 5],
             bottom_mlp=[64, 32, 4],
-            top_mlp=[64, 32, 1]),
-        train_data=config.DataConfig(global_batch_size=16),
-        validation_data=config.DataConfig(global_batch_size=16),
-        use_synthetic_data=True)
+            top_mlp=[64, 32, 1],
+        ),
+        train_data=config.DataConfig(
+            global_batch_size=16, use_cached_data=use_cached_data
+        ),
+        validation_data=config.DataConfig(
+            global_batch_size=16, use_cached_data=use_cached_data
+        ),
+        use_synthetic_data=True,
+    )
 
     num_dense_features = task.model.num_dense_features
     num_sparse_features = len(task.model.vocab_sizes)
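Both test files receive the same update: the parameterization grows from {Train, Eval} to the full {train, eval} × {cached, uncached} grid, and the chosen use_cached_data value is passed through both train_data and validation_data, so the new flag is exercised end to end against the synthetic pipelines of both readers.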

