[feat]: add parquet dataset (#435)
add parquet_input_v3
tiankongdeguiji authored Dec 7, 2023
1 parent 1fb889d commit 75a97ec
Showing 6 changed files with 209 additions and 5 deletions.
5 changes: 4 additions & 1 deletion easy_rec/__init__.py
@@ -17,9 +17,12 @@
 
 # Avoid import tensorflow which conflicts with the version used in EasyRecProcessor
 if 'PROCESSOR_TEST' not in os.environ:
+  from tensorflow.python.platform import tf_logging
+  tf_logging.set_verbosity(tf_logging.INFO)
+
   if platform.system() == 'Linux':
-    ops_dir = os.path.join(curr_dir, 'python/ops')
     import tensorflow as tf
+    ops_dir = os.path.join(curr_dir, 'python/ops')
     if 'PAI' in tf.__version__:
       ops_dir = os.path.join(ops_dir, '1.12_pai')
     elif tf.__version__.startswith('1.12'):
3 changes: 2 additions & 1 deletion easy_rec/python/input/input.py
@@ -267,7 +267,8 @@ def create_multi_placeholders(self, export_config):
       else:
         placeholder_name = 'input_%d' % fid
       if input_name in export_fields_name:
-        tf_type = self._multi_value_types[input_name]
+        tf_type = self._multi_value_types[input_name] if input_name in self._multi_value_types \
+            else get_tf_type(self._input_field_types[fid])
         logging.info('multi value input_name: %s, dtype: %s' %
                      (input_name, tf_type))
         finput = tf.placeholder(tf_type, [None, None], name=placeholder_name)
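This fallback matters for the new ParquetInputV3 below, which leaves _multi_value_types empty, so export-time placeholders now derive their dtype from the configured field type. A minimal TF1-style sketch of the fallback (the field name 'tag_kv' and the STRING-to-tf.string mapping are illustrative assumptions, not from this commit):

import tensorflow as tf

multi_value_types = {}     # ParquetInputV3 registers no multi-value types
fallback_type = tf.string  # what get_tf_type(...) would return for a STRING input field
tf_type = multi_value_types.get('tag_kv', fallback_type)
placeholder = tf.placeholder(tf_type, [None, None], name='input_0')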
197 changes: 197 additions & 0 deletions easy_rec/python/input/parquet_input_v3.py
@@ -0,0 +1,197 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf
from tensorflow.python.platform import gfile

from easy_rec.python.input.input import Input
from easy_rec.python.utils.input_utils import get_type_defaults

try:
  from tensorflow.python.data.experimental.ops import parquet_dataset_ops
  from tensorflow.python.data.experimental.ops import parquet_pybind
  from tensorflow.python.data.experimental.ops import dataframe
  from tensorflow.python.ops import gen_ragged_conversion_ops
  from tensorflow.python.ops.work_queue import WorkQueue
except Exception:
  logging.error('You should install DeepRec first.')
  pass


class ParquetInputV3(Input):

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None,
               **kwargs):
    super(ParquetInputV3,
          self).__init__(data_config, feature_config, input_path, task_index,
                         task_num, check_mode, pipeline_config)

    self._ignore_val_dict = {}
    for f in data_config.input_fields:
      if f.HasField('ignore_val'):
        self._ignore_val_dict[f.input_name] = get_type_defaults(
            f.input_type, f.ignore_val)

    self._true_type_dict = {}
    for fc in self._feature_configs:
      if fc.feature_type in [fc.IdFeature, fc.TagFeature, fc.SequenceFeature]:
        if fc.hash_bucket_size > 0:
          self._true_type_dict[fc.input_names[0]] = tf.string
        elif fc.num_buckets > 0:
          self._true_type_dict[fc.input_names[0]] = tf.int64
        if len(fc.input_names) > 1:
          self._true_type_dict[fc.input_names[1]] = tf.float32
      if fc.feature_type == fc.RawFeature:
        self._true_type_dict[fc.input_names[0]] = tf.float32

    self._reserve_fields = None
    self._reserve_types = None
    if 'reserve_fields' in kwargs and 'reserve_types' in kwargs:
      self._reserve_fields = kwargs['reserve_fields']
      self._reserve_types = kwargs['reserve_types']

    # In ParquetDataset multi_value use input type
    self._multi_value_types = {}

  def _ignore_and_cast(self, name, value):
    ignore_value = self._ignore_val_dict.get(name, None)
    if ignore_value:
      if isinstance(value, tf.SparseTensor):
        # keep only the entries that differ from the ignore value
        mask = tf.not_equal(value.values, ignore_value)
        value = tf.SparseTensor(
            tf.boolean_mask(value.indices, mask),
            tf.boolean_mask(value.values, mask), value.dense_shape)
      elif isinstance(value, tf.Tensor):
        indices = tf.where(tf.not_equal(value, ignore_value), name='indices')
        value = tf.SparseTensor(
            indices=indices,
            values=tf.gather_nd(value, indices),
            dense_shape=tf.shape(value, out_type=tf.int64))
    dtype = self._true_type_dict.get(name, None)
    if dtype:
      value = tf.cast(value, dtype)
    return value

  def _parse_dataframe_value(self, value):
    if len(value.nested_row_splits) == 0:
      return value.values
    value.values.set_shape([None])
    sparse_value = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
        value.nested_row_splits, value.values)
    return tf.SparseTensor(sparse_value.sparse_indices,
                           sparse_value.sparse_values,
                           sparse_value.sparse_dense_shape)

  def _parse_dataframe(self, df):
    inputs = {}
    for k, v in df.items():
      if k in self._effective_fields:
        if isinstance(v, dataframe.DataFrame.Value):
          v = self._parse_dataframe_value(v)
      elif k in self._label_fields:
        if isinstance(v, dataframe.DataFrame.Value):
          v = v.values
      elif k in self._reserve_fields:
        if isinstance(v, dataframe.DataFrame.Value):
          v = v.values
      else:
        continue
      inputs[k] = v
    return inputs

  def _build(self, mode, params):
    input_files = []
    for sub_path in self._input_path.strip().split(','):
      input_files.extend(gfile.Glob(sub_path))
    file_num = len(input_files)
    logging.info('[task_index=%d] total_file_num=%d task_num=%d' %
                 (self._task_index, file_num, self._task_num))

    task_index = self._task_index
    task_num = self._task_num
    if self._data_config.chief_redundant:
      task_index = max(self._task_index - 1, 0)
      task_num = max(self._task_num - 1, 1)

    if self._data_config.pai_worker_queue and \
        mode == tf.estimator.ModeKeys.TRAIN:
      work_queue = WorkQueue(
          input_files,
          num_epochs=self.num_epochs,
          shuffle=self._data_config.shuffle)
      my_files = work_queue.input_dataset()
    else:
      my_files = []
      for file_id in range(file_num):
        if (file_id % task_num) == task_index:
          my_files.append(input_files[file_id])

    parquet_fields = parquet_pybind.parquet_fields(input_files[0])
    parquet_input_fields = []
    for f in parquet_fields:
      if f.name in self._input_fields:
        parquet_input_fields.append(f)

    all_fields = set(self._effective_fields)
    if mode != tf.estimator.ModeKeys.PREDICT:
      all_fields |= set(self._label_fields)
    if self._reserve_fields:
      all_fields |= set(self._reserve_fields)

    selected_fields = []
    for f in parquet_input_fields:
      if f.name in all_fields:
        selected_fields.append(f)

    num_parallel_reads = min(self._data_config.num_parallel_calls,
                             len(input_files) // task_num)
    dataset = parquet_dataset_ops.ParquetDataset(
        my_files,
        batch_size=self._batch_size,
        fields=selected_fields,
        drop_remainder=self._data_config.drop_remainder,
        num_parallel_reads=num_parallel_reads)
    # partition_count=task_num,
    # partition_index=task_index)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if self._data_config.shuffle:
        dataset = dataset.shuffle(
            self._data_config.shuffle_buffer_size,
            seed=2020,
            reshuffle_each_iteration=True)
      dataset = dataset.repeat(self.num_epochs)
    else:
      dataset = dataset.repeat(1)

    dataset = dataset.map(
        self._parse_dataframe,
        num_parallel_calls=self._data_config.num_parallel_calls)

    # preprocess is necessary to transform data
    # so that they could be feed into FeatureColumns
    dataset = dataset.map(
        map_func=self._preprocess,
        num_parallel_calls=self._data_config.num_parallel_calls)

    dataset = dataset.prefetch(buffer_size=self._prefetch_size)

    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(lambda x:
                            (self._get_features(x), self._get_labels(x)))
    else:
      dataset = dataset.map(lambda x: (self._get_features(x)))
    return dataset

  def _preprocess(self, field_dict):
    for k, v in field_dict.items():
      field_dict[k] = self._ignore_and_cast(k, v)
    return super(ParquetInputV3, self)._preprocess(field_dict)
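Two minimal sketches, not part of the commit, that illustrate the preprocessing above.

First, the dense-tensor branch of _ignore_and_cast: entries equal to the configured ignore value are dropped and the survivors are re-packed into a tf.SparseTensor. The batch column and ignore value below are hypothetical:

import tensorflow as tf

values = tf.constant([[3, 0, 7], [0, 5, 0]], dtype=tf.int64)  # hypothetical dense batch column
ignore_value = 0
indices = tf.where(tf.not_equal(values, ignore_value))  # positions of the non-ignored entries
sparse = tf.SparseTensor(
    indices=indices,
    values=tf.gather_nd(values, indices),
    dense_shape=tf.shape(values, out_type=tf.int64))
# sparse keeps 3, 7, 5 at positions (0, 0), (0, 2), (1, 1)

Second, an approximate public-API equivalent of _parse_dataframe_value: a DataFrame.Value carries flat values plus nested row splits describing variable-length rows, which are converted to a tf.SparseTensor (here via tf.RaggedTensor instead of DeepRec's gen_ragged_conversion_ops):

import tensorflow as tf

flat_values = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
row_splits = tf.constant([0, 2, 2, 5], dtype=tf.int64)  # rows: [1, 2], [], [3, 4, 5]
ragged = tf.RaggedTensor.from_row_splits(flat_values, row_splits)
sparse = ragged.to_sparse()  # same kind of result as the ragged_tensor_to_sparse op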
5 changes: 4 additions & 1 deletion easy_rec/python/protos/dataset.proto
@@ -156,6 +156,8 @@ message DatasetConfig {
     optional string user_define_fn_path = 7;
     // output field type of user-defined function.
     optional FieldType user_define_fn_res_type = 8;
+    // ignore value
+    optional string ignore_val = 9;
   }
 
   // set auto_expand_input_fields to true to
@@ -223,6 +225,7 @@
     HiveInput = 16;
     HiveRTPInput = 17;
     HiveParquetInput = 18;
+    ParquetInputV3 = 21;
     CriteoInput = 1001;
   }
   required InputType input_type = 10;
@@ -297,5 +300,5 @@ message DatasetConfig {
   }
   optional uint32 eval_batch_size = 1001 [default = 4096];
 
-
+  optional bool drop_remainder = 1002 [default = false];
 }
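A minimal sketch, not part of the commit, of how the new fields might be set in a pipeline config's data_config. The field names input_type, drop_remainder, input_fields, input_name and ignore_val come from the diff above; the feature name 'tag_kv', the ignore value '0' and the compiled proto module path are assumptions:

from google.protobuf import text_format
from easy_rec.python.protos.dataset_pb2 import DatasetConfig  # assumed generated module path

data_config = text_format.Parse(
    """
    input_type: ParquetInputV3
    drop_remainder: true
    input_fields {
      input_name: 'tag_kv'
      ignore_val: '0'
    }
    """, DatasetConfig())
# ParquetInputV3 is the new enum value (21); drop_remainder is passed through to
# ParquetDataset's drop_remainder argument in parquet_input_v3.py.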
2 changes: 1 addition & 1 deletion easy_rec/version.py
@@ -1,3 +1,3 @@
 # -*- encoding:utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
-__version__ = '0.7.6'
+__version__ = '0.7.7'
2 changes: 1 addition & 1 deletion setup.cfg
@@ -10,7 +10,7 @@ multi_line_output = 7
 force_single_line = true
 known_standard_library = setuptools
 known_first_party = easy_rec
-known_third_party = absl,common_io,docutils,eas_prediction,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
+known_third_party = absl,common_io,distutils,docutils,eas_prediction,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
 no_lines_before = LOCALFOLDER
 default_section = THIRDPARTY
 skip = easy_rec/python/protos
