Skip to content

Commit

Permalink
add custom op demo
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Oct 20, 2023
1 parent 43e17c6 commit efa5683
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
26 changes: 17 additions & 9 deletions easy_rec/python/layers/keras/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,25 @@
"""Convenience blocks for using custom ops."""
import logging
import os

import tensorflow as tf
from tensorflow.python.framework import ops

import easy_rec

# LIB_PATH = tf.sysconfig.get_link_flags()[0][2:]
# LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH')
# if LIB_PATH not in LD_LIBRARY_PATH:
# os.environ['LD_LIBRARY_PATH'] = ':'.join([LIB_PATH, LD_LIBRARY_PATH])
# logging.info('set LD_LIBRARY_PATH=%s' % os.getenv('LD_LIBRARY_PATH'))
curr_dir, _ = os.path.split(__file__)
parent_dir = os.path.dirname(curr_dir)
ops_idr = os.path.dirname(parent_dir)
ops_dir = os.path.join(ops_idr, 'python', 'ops')
if 'PAI' in tf.__version__:
ops_dir = os.path.join(ops_dir, '1.12_pai')
elif tf.__version__.startswith('1.12'):
ops_dir = os.path.join(ops_dir, '1.12')
elif tf.__version__.startswith('1.15'):
if 'IS_ON_PAI' in os.environ:
ops_dir = os.path.join(ops_dir, 'DeepRec')
else:
ops_dir = os.path.join(ops_dir, '1.15')
else:
ops_dir = None

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand All @@ -23,8 +31,8 @@ class EditDistance(tf.keras.layers.Layer):

def __init__(self, params, name='edit_distance', reuse=None, **kwargs):
super(EditDistance, self).__init__(name, **kwargs)

custom_op_path = os.path.join(easy_rec.ops_dir, 'libedit_distance.so')
logging.info("ops_dir is %s" % ops_dir)
custom_op_path = os.path.join(ops_dir, 'libedit_distance.so')
try:
custom_ops = tf.load_op_library(custom_op_path)
logging.info('load edit_distance op from %s succeed' % custom_op_path)
Expand Down
2 changes: 1 addition & 1 deletion easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def test_highway(self):
self.assertTrue(self._success)

@unittest.skipIf(
LooseVersion(tf.__version__) < LooseVersion('2.0.0'),
LooseVersion(tf.__version__) >= LooseVersion('2.0.0'),
'EditDistanceOp only work before tf version == 2.0')
def test_custom_op(self):
self._success = test_utils.test_single_train_eval(
Expand Down
2 changes: 1 addition & 1 deletion easy_rec/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
__version__ = '0.7.5'
__version__ = '0.7.6'

0 comments on commit efa5683

Please sign in to comment.