add drop_remainder option as required by sok
chengmengli06 committed Oct 20, 2023
1 parent ae74f18 commit bbfe530
Showing 9 changed files with 128 additions and 33 deletions.
37 changes: 25 additions & 12 deletions easy_rec/python/compat/feature_column/feature_column.py
@@ -275,6 +275,27 @@ def _get_logits_with_sok(): # pylint: disable=missing-docstring
embedding_weights = shared_weights[shared_name]
else:
with ops.device('/gpu:0'):
if column.ev_params is not None:
embedding_weights = sok.DynamicVariable(name='embedding_weights',
dimension=column.dimension, initializer='random', #column.initializer,
trainable=column.trainable and trainable, dtype=dtypes.float32)
else:
embedding_weights = variable_scope.get_variable(
name='embedding_weights',
shape=embedding_shape,
dtype=dtypes.float32,
initializer=column.initializer,
trainable=column.trainable and trainable,
partitioner=None,
collections=weight_collections)
shared_weights[shared_name] = embedding_weights
else:
with ops.device('/gpu:0'):
if column.ev_params is not None:
embedding_weights = sok.DynamicVariable(name='embedding_weights',
dimension=column.dimension, initializer='random', #column.initializer,
trainable=column.trainable and trainable, dtype=dtypes.float32)
else:
embedding_weights = variable_scope.get_variable(
name='embedding_weights',
shape=embedding_shape,
@@ -283,19 +304,10 @@ def _get_logits_with_sok(): # pylint: disable=missing-docstring
trainable=column.trainable and trainable,
partitioner=None,
collections=weight_collections)
shared_weights[shared_name] = embedding_weights
else:
with ops.device('/gpu:0'):
embedding_weights = variable_scope.get_variable(
name='embedding_weights',
shape=embedding_shape,
dtype=dtypes.float32,
initializer=column.initializer,
trainable=column.trainable and trainable,
partitioner=None,
collections=weight_collections)
# required by sok
embedding_weights.target_gpu = -1
if 'DynamicVariable' not in str(type(embedding_weights)):
embedding_weights.target_gpu = -1
# SparseTensor RaggedTensor
sparse_tensors = column.categorical_column._get_sparse_tensors(
builder, weight_collections=weight_collections, trainable=trainable)
output_id = len(output_tensors)
@@ -332,6 +344,7 @@ def _get_logits_with_sok(): # pylint: disable=missing-docstring
elif len(lookup_output_ids_with_wgt) > 0:
outputs = sok.lookup_sparse(
lookup_embeddings_with_wgt,
# RaggedTensor .values .row_lengths
lookup_indices_with_wgt,
lookup_wgts,
combiners=lookup_combiners_with_wgt)
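The branch above uses a SOK DynamicVariable when a column carries ev_params (embedding variables) and a regular dense variable otherwise; only the dense variable gets the target_gpu attribute that SOK's lookup expects. A minimal sketch of that selection (not the exact EasyRec code), assuming sok is sparse_operation_kit.experiment as imported in easy_rec_estimator.py below, with hypothetical column and embedding_shape arguments:

import tensorflow as tf
from sparse_operation_kit import experiment as sok

def build_embedding_weights(column, embedding_shape, trainable=True):
    with tf.device('/gpu:0'):
        if column.ev_params is not None:
            # dynamically growing, hash-table style embedding managed by SOK
            weights = sok.DynamicVariable(
                name='embedding_weights',
                dimension=column.dimension,
                initializer='random',
                trainable=column.trainable and trainable,
                dtype=tf.float32)
        else:
            weights = tf.compat.v1.get_variable(
                name='embedding_weights',
                shape=embedding_shape,
                dtype=tf.float32,
                initializer=column.initializer,
                trainable=column.trainable and trainable)
            # only static variables need target_gpu for sok.lookup_sparse
            weights.target_gpu = -1
    return weights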
33 changes: 27 additions & 6 deletions easy_rec/python/compat/optimizers.py
@@ -244,17 +244,23 @@ def optimize_loss(loss,
raise ValueError('Unrecognized optimizer: function should return '
'subclass of Optimizer. Got %s.' % str(opt))
else:
raise ValueError('Unrecognized optimizer: should be string, '
'subclass of Optimizer, instance of '
'subclass of Optimizer or function with one argument. '
'Got %s.' % str(optimizer))
opt = optimizer
# raise ValueError('Unrecognized optimizer: should be string, '
# 'subclass of Optimizer, instance of '
# 'subclass of Optimizer or function with one argument. '
# 'Got %s.' % str(optimizer))

# All trainable variables, if specific variables are not specified.
if variables is None:
variables = vars_.trainable_variables()

# Compute gradients.
gradients = opt.compute_gradients(
if 'compute_gradients' not in dir(opt):
import tensorflow as tf
gradients = tf.gradients(loss, variables)
gradients = list(zip(gradients, variables))
else:
gradients = opt.compute_gradients(
loss,
variables,
colocate_gradients_with_ops=colocate_gradients_with_ops)
@@ -331,7 +337,22 @@ def optimize_loss(loss,

# Create gradient updates.
def _apply_grad():
grad_updates = opt.apply_gradients(
if 'compute_gradients' not in dir(opt):
sparse_vars = [ x for x in gradients if 'DynamicVariable' in str(type(x[1])) ]
dense_vars = [ x for x in gradients if 'DynamicVariable' not in str(type(x[1])) ]
sparse_grad_updates = opt.apply_gradients(sparse_vars)
dense_grad_updates = opt._optimizer.apply_gradients(
dense_vars,
global_step=global_step if increment_global_step else None,
name='train')
if sparse_grad_updates is not None and dense_grad_updates is not None:
grad_updates = tf.group([sparse_grad_updates, dense_grad_updates])
elif sparse_grad_updates is not None:
grad_updates = sparse_grad_updates
elif dense_grad_updates is not None:
grad_updates = dense_grad_updates
else:
grad_updates = opt.apply_gradients(
gradients,
global_step=global_step if increment_global_step else None,
name='train')
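The optimizers.py change lets optimize_loss work with sok.OptimizerWrapper, which does not expose compute_gradients: gradients fall back to tf.gradients, and at apply time the (grad, var) pairs are split so SOK updates the DynamicVariables while the wrapped dense optimizer (opt._optimizer) handles the rest. A condensed sketch of that flow, under the assumption that opt may be such a wrapper:

import tensorflow as tf

def compute_and_apply(opt, loss, variables, global_step=None):
    if not hasattr(opt, 'compute_gradients'):
        # e.g. sok.OptimizerWrapper: differentiate manually
        grads = tf.gradients(loss, variables)
        grads_and_vars = list(zip(grads, variables))
        sparse = [gv for gv in grads_and_vars if 'DynamicVariable' in str(type(gv[1]))]
        dense = [gv for gv in grads_and_vars if 'DynamicVariable' not in str(type(gv[1]))]
        sparse_op = opt.apply_gradients(sparse)          # SOK handles DynamicVariables
        dense_op = opt._optimizer.apply_gradients(dense, global_step=global_step)
        return tf.group([sparse_op, dense_op])
    grads_and_vars = opt.compute_gradients(loss, variables)
    return opt.apply_gradients(grads_and_vars, global_step=global_step)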
4 changes: 3 additions & 1 deletion easy_rec/python/input/input.py
@@ -43,14 +43,16 @@ def __init__(self,
task_index=0,
task_num=1,
check_mode=False,
pipeline_config=None):
pipeline_config=None,
**kwargs):
self._pipeline_config = pipeline_config
self._data_config = data_config
self._check_mode = check_mode
logging.info('check_mode: %s ' % self._check_mode)
# tf.estimator.ModeKeys.*, only available before
# calling self._build
self._mode = None
self._has_ev = 'ev_params' in kwargs

if self._data_config.auto_expand_input_fields:
input_fields = [x for x in self._data_config.input_fields]
59 changes: 50 additions & 9 deletions easy_rec/python/input/parquet_input.py
@@ -24,10 +24,11 @@ def __init__(self,
task_index=0,
task_num=1,
check_mode=False,
pipeline_config=None):
pipeline_config=None,
**kwargs):
super(ParquetInput,
self).__init__(data_config, feature_config, input_path, task_index,
task_num, check_mode, pipeline_config)
task_num, check_mode, pipeline_config, **kwargs)
if input_path is None:
return

@@ -50,13 +51,14 @@ def __init__(self,
self._proc_arr = []
for proc_id in range(num_proc):
proc = multiprocessing.Process(
target=self._parse_one_file, args=(proc_id,))
target=self._load_data_proc, args=(proc_id,))
self._proc_arr.append(proc)

def _parse_one_file(self, proc_id):
def _load_data_proc(self, proc_id):
all_fields = list(self._label_fields) + list(self._effective_fields)
logging.info('data proc %d start' % proc_id)
num_files = 0
part_data_dict = {}
while True:
try:
input_file = self._file_que.get(block=False)
@@ -88,17 +90,53 @@ def _parse_one_file(self, proc_id):
self._data_que.put(data_dict)
sid += self._batch_size
if res_num > 0:
logging.info('proc[%d] add final sample' % proc_id)
accum_res_num = 0
data_dict = {}
part_data_dict_n = {}
for k in self._label_fields:
data_dict[k] = np.array([x[0] for x in input_data[k][sid:]],
tmp_lbls = np.array([x[0] for x in input_data[k][sid:]],
dtype=np.float32)
if part_data_dict is not None and k in part_data_dict:
tmp_lbls = np.concatenate([part_data_dict[k], tmp_lbls], axis=0)
if len(tmp_lbls) > self._batch_size:
data_dict[k] = tmp_lbls[:self._batch_size]
part_data_dict_n[k] = tmp_lbls[self._batch_size:]
elif len(tmp_lbls) == self._batch_size:
data_dict[k] = tmp_lbls
else:
part_data_dict_n[k] = tmp_lbls
else:
part_data_dict_n[k] = tmp_lbls
for k in self._effective_fields:
val = input_data[k][sid:]
all_lens = np.array([len(x) for x in val], dtype=np.int32)
all_vals = np.concatenate(list(val))
data_dict[k] = (all_lens, all_vals)
self._data_que.put(data_dict)
if part_data_dict is not None and k in part_data_dict:
tmp_lens = np.concatenate([part_data_dict[k][0], all_lens], axis=0)
tmp_vals = np.concatenate([part_data_dict[k][1], all_vals], axis=0)
if len(tmp_lens) > self._batch_size:
tmp_res_lens = tmp_lens[self._batch_size:]
tmp_lens = tmp_lens[:self._batch_size]
tmp_num_elems = np.sum(tmp_lens)
tmp_res_vals = tmp_vals[tmp_num_elems:]
tmp_vals = tmp_vals[:tmp_num_elems]
part_data_dict_n[k] = (tmp_res_lens, tmp_res_vals)
data_dict[k] = (tmp_lens, tmp_vals)
elif len(tmp_lens) == self._batch_size:
data_dict[k] = (tmp_lens, tmp_vals)
else:
part_data_dict_n[k] = (tmp_lens, tmp_vals)
else:
part_data_dict_n[k] = (all_lens, all_vals)
if len(data_dict) > 0:
self._data_que.put(data_dict)
part_data_dict = part_data_dict_n
if len(part_data_dict) > 0:
if not self._data_config.drop_remainder:
self._data_que.put(part_data_dict)
else:
logging.warning('dropping %d remaining samples as drop_remainder is set' % \
len(part_data_dict[self._label_fields[0]]))
self._data_que.put(None)
logging.info('data proc %d done, file_num=%d' % (proc_id, num_files))

@@ -133,7 +171,10 @@ def _to_fea_dict(self, input_dict):
if fc.feature_type == fc.IdFeature or fc.feature_type == fc.TagFeature:
input_0 = fc.input_names[0]
fea_name = fc.feature_name if fc.HasField('feature_name') else input_0
tmp = input_dict[input_0][1] % fc.num_buckets
if not self._has_ev:
tmp = input_dict[input_0][1] % fc.num_buckets
else:
tmp = input_dict[input_0][1]
fea_dict[fea_name] = tf.RaggedTensor.from_row_lengths(tmp, input_dict[input_0][0])

lbl_dict = {}
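The new remainder handling keeps the tail of each parquet file in part_data_dict, prepends it to the next file's rows, and re-emits it once a full batch_size has accumulated; when a worker runs out of files, the leftover rows are pushed as a final short batch or dropped, depending on data_config.drop_remainder. A simplified sketch of the carry-over for a single dense field (hypothetical names; the real code also tracks variable-length features as (lengths, values) pairs):

import numpy as np

def emit_batches(file_arrays, batch_size, drop_remainder, emit):
    carry = np.zeros((0,), dtype=np.float32)   # remainder carried across files
    for arr in file_arrays:                    # one array of rows per parquet file
        carry = np.concatenate([carry, arr], axis=0)
        while len(carry) >= batch_size:        # emit every full batch
            emit(carry[:batch_size])
            carry = carry[batch_size:]
    if len(carry) > 0 and not drop_remainder:
        emit(carry)                            # final partial batch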
8 changes: 8 additions & 0 deletions easy_rec/python/main.py
@@ -342,6 +342,8 @@ def _train_and_evaluate_impl(pipeline_config,
input_fn_kwargs = {'pipeline_config': pipeline_config}
if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
if pipeline_config.model_config.HasField('ev_params'):
input_fn_kwargs['ev_params'] = pipeline_config.model_config.ev_params

# create train input
train_input_fn = _get_input_fn(
@@ -779,6 +781,8 @@ def export(export_dir,
input_fn_kwargs = {'pipeline_config': pipeline_config}
if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path
if pipeline_config.model_config.HasField('ev_params'):
input_fn_kwargs['ev_params'] = pipeline_config.model_config.ev_params
serving_input_fn = _get_input_fn(data_config, feature_configs, None,
export_config, **input_fn_kwargs)
ckpt_path = _get_ckpt_path(pipeline_config, checkpoint_path)
@@ -841,6 +845,10 @@ def export_checkpoint(pipeline_config=None,
if data_config.input_type == data_config.InputType.OdpsRTPInputV2:
input_fn_kwargs['fg_json_path'] = pipeline_config.fg_json_path

if pipeline_config.model_config.HasField('ev_params'):
input_fn_kwargs['ev_params'] = pipeline_config.model_config.ev_params


# create estimator
params = {'log_device_placement': verbose}
if asset_files:
10 changes: 10 additions & 0 deletions easy_rec/python/model/easy_rec_estimator.py
@@ -44,6 +44,13 @@
except Exception:
hvd = None

try:
from sparse_operation_kit import experiment as sok
except Exception:
sok = None



if tf.__version__ >= '2.0':
tf = tf.compat.v1

@@ -206,6 +213,9 @@ def _train_model_fn(self, features, labels, run_config):
% (len(grouped_vars), len(optimizer_config))
optimizer = MultiOptimizer(all_opts, grouped_vars)

if self.train_config.train_distribute == DistributionStrategy.SokStrategy:
optimizer = sok.OptimizerWrapper(optimizer)

hooks = []
if estimator_utils.has_hvd():
assert not self.train_config.sync_replicas, \
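Under SokStrategy the (possibly grouped) optimizer is wrapped once before training hooks are built, so the gradient split sketched under optimizers.py sees sok.OptimizerWrapper for DynamicVariables while still reaching the underlying dense optimizer. A minimal sketch, with the Adam base optimizer as an assumption:

import tensorflow as tf
from sparse_operation_kit import experiment as sok

base_opt = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)  # assumed base optimizer
optimizer = sok.OptimizerWrapper(base_opt)  # routes DynamicVariable updates through SOK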
2 changes: 1 addition & 1 deletion easy_rec/python/protos/dataset.proto
@@ -298,5 +298,5 @@ message DatasetConfig {
}
optional uint32 eval_batch_size = 1001 [default = 4096];


optional bool drop_remainder = 1002 [default = false];
}
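The new drop_remainder field defaults to false, so existing configs keep emitting the final partial batch; SOK setups will usually turn it on so every lookup sees a fixed batch size, which is what the commit title refers to. A hedged example of enabling it programmatically on a loaded pipeline config (the protobuf import path is an assumption):

from google.protobuf import text_format
from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig

pipeline_config = EasyRecConfig()
with open('samples/model_config/multi_tower_on_taobao_sok.config') as f:
    text_format.Merge(f.read(), pipeline_config)
pipeline_config.data_config.drop_remainder = True  # drop the trailing partial batch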
2 changes: 1 addition & 1 deletion easy_rec/python/utils/hvd_utils.py
@@ -39,7 +39,7 @@ def __init__(self, root_rank, device=''):
def begin(self):
bcast_vars = []
for x in tf.global_variables():
if '/embedding' not in x.name:
if '/embedding' not in x.name and 'DynamicVariable' not in str(type(x)):
bcast_vars.append(x)
logging.info('will broadcast variable: %s' % x.name)
if not self.bcast_op or self.bcast_op.graph != tf.get_default_graph():
6 changes: 3 additions & 3 deletions samples/model_config/multi_tower_on_taobao_sok.config
@@ -23,7 +23,7 @@ train_config {
sync_replicas: False
train_distribute: SokStrategy
num_steps: 200000
# is_profiling: true
is_profiling: true
}

eval_config {
@@ -136,8 +136,8 @@ data_config {
batch_size: 8192
num_epochs: 1000000
prefetch_size: 64
input_type: CSVInput
# input_type: DummyInput
# input_type: CSVInput
input_type: DummyInput
}

feature_config: {
