diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index def5460..3e63f69 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -67,19 +67,20 @@ def wrapped(proto): _jax_ops = { + # go/keep-sorted start "Abs": _get_jax_op(jnp.abs, {"T"}), + "Acosh": _get_jax_op(jnp.arccosh, {"T"}), "Add": _get_jax_op(anp.add, {"T"}), "AddN": _get_jax_op( lambda *args: anp.sum_(anp.stack(args, axis=0), axis=0, keepdims=False), {"T", "N"}), "AddV2": _get_jax_op(anp.add, {"T"}), + "Angle": _get_jax_op(jnp.angle, {"T", "Tout"}), "ArgMax": _get_jax_op(jnp.argmax, {"T", "Tidx", "output_type"}), "ArgMin": _get_jax_op(jnp.argmin, {"T", "Tidx", "output_type"}), - "Acosh": _get_jax_op(jnp.arccosh, {"T"}), - "Angle": _get_jax_op(jnp.angle, {"T", "Tout"}), "Asinh": _get_jax_op(jnp.arcsinh, {"T"}), - "Atanh": _get_jax_op(jnp.arctanh, {"T"}), "Atan2": _get_jax_op(jnp.arctan2, {"T"}), + "Atanh": _get_jax_op(jnp.arctanh, {"T"}), "BesselI0e": _get_jax_op(jax.lax.bessel_i0e, {"T"}), "BesselI1e": _get_jax_op(jax.lax.bessel_i1e, {"T"}), "BitwiseAnd": _get_jax_op(jnp.bitwise_and, {"T"}), @@ -109,11 +110,10 @@ def wrapped(proto): "FFT3D": _get_jax_op( functools.partial(jnp.fft.fftn, axes=(-3, -2, -1,)), {"Tcomplex"}), "Floor": _get_jax_op(jnp.floor, {"T"}), - "FloorMod": _get_jax_op(anp.mod, {"T"}), "FloorDiv": _get_jax_op(anp.floor_divide, {"T"}), + "FloorMod": _get_jax_op(anp.mod, {"T"}), "Greater": _get_jax_op(anp.greater, {"T"}), "GreaterEqual": _get_jax_op(anp.greater_equal, {"T"}), - "Identity": _get_jax_op(lambda x: x, {"T"}), "IFFT": _get_jax_op( functools.partial(jnp.fft.ifftn, axes=(-1,)), {"Tcomplex"}), "IFFT2D": _get_jax_op( @@ -128,12 +128,13 @@ def wrapped(proto): "IRFFT3D": _get_jax_op( functools.partial( jnp.fft.irfftn, axes=(-3, -2, -1,)), {"Tcomplex", "Treal"}), + "Identity": _get_jax_op(lambda x: x, {"T"}), "Igamma": _get_jax_op(jax.lax.igamma, {"T"}), "Igammac": _get_jax_op(jax.lax.igammac, {"T"}), "Imag": _get_jax_op(jax.lax.imag, {"T", "Tout"}), - "IsFinite": _get_jax_op(jnp.isfinite, {"T"}), "Invert": _get_jax_op(jnp.bitwise_not, {"T"}), "InvertPermutation": _get_jax_op(anp.invert_permutation, {"T"}), + "IsFinite": _get_jax_op(jnp.isfinite, {"T"}), "L2Loss": _get_jax_op(lambda x: 0.5 * jnp.sum(jnp.square(x)), {"T"}), "LeftShift": _get_jax_op(jnp.left_shift, {"T"}), "Less": _get_jax_op(anp.less, {"T", "incompatible_shape_error"}), @@ -144,8 +145,8 @@ def wrapped(proto): "LogicalAnd": _get_jax_op(jnp.logical_and, {"T"}), "LogicalNot": _get_jax_op(jnp.logical_not, {"T"}), "LogicalOr": _get_jax_op(jnp.logical_or, {"T"}), - "Minimum": _get_jax_op(anp.minimum, {"T"}), "Maximum": _get_jax_op(anp.maximum, {"T"}), + "Minimum": _get_jax_op(anp.minimum, {"T"}), "Mul": _get_jax_op(anp.multiply, {"T"}), "Neg": _get_jax_op(anp.negative, {"T"}), "NoOp": _get_jax_op(lambda: _EMPTY_RETURN_VALUE, set({})), @@ -153,14 +154,6 @@ def wrapped(proto): "OnesLike": _get_jax_op(jnp.ones_like, {"T"}), "PopulationCount": _get_jax_op(jax.lax.population_count, {"T"}), "Pow": _get_jax_op(anp.power, {"T"}), - "Rank": _get_jax_op(lambda x: np.array(jnp.ndim(x)), {"T"}), - "Real": _get_jax_op(jax.lax.real, {"T", "Tout"}), - "ReadVariableOp": _get_jax_op(lambda x: x, {"dtype"}), - "RealDiv": _get_jax_op(anp.true_divide, {"T"}), - "Reciprocal": _get_jax_op(anp.reciprocal, {"T"}), - "Relu": _get_jax_op(jax.nn.relu, {"T"}), - "Relu6": _get_jax_op(jax.nn.relu6, {"T"}), - "ReverseV2": _get_jax_op(anp.flip, {"T", "Tidx"}), "RFFT": _get_jax_op( functools.partial(jnp.fft.rfftn, axes=(-1,)), {"Tcomplex", "Treal"}), "RFFT2D": _get_jax_op( @@ -169,6 +162,14 @@ def wrapped(proto): "RFFT3D": _get_jax_op( functools.partial( jnp.fft.rfftn, axes=(-3, -2, -1,)), {"Tcomplex", "Treal"}), + "Rank": _get_jax_op(lambda x: np.array(jnp.ndim(x)), {"T"}), + "ReadVariableOp": _get_jax_op(lambda x: x, {"dtype"}), + "Real": _get_jax_op(jax.lax.real, {"T", "Tout"}), + "RealDiv": _get_jax_op(anp.true_divide, {"T"}), + "Reciprocal": _get_jax_op(anp.reciprocal, {"T"}), + "Relu": _get_jax_op(jax.nn.relu, {"T"}), + "Relu6": _get_jax_op(jax.nn.relu6, {"T"}), + "ReverseV2": _get_jax_op(anp.flip, {"T", "Tidx"}), "RightShift": _get_jax_op(jnp.right_shift, {"T"}), "Round": _get_jax_op(jnp.round, {"T"}), "Rsqrt": _get_jax_op(jax.lax.rsqrt, {"T"}), @@ -203,6 +204,7 @@ def wrapped(proto): {"T", "Tindices", "Tnumsegments"}), "Where": _get_jax_op(jnp.argwhere, {"T"}), "ZerosLike": _get_jax_op(jnp.zeros_like, {"T"}), + # go/keep-sorted end # The assignment logic is handled in _OpNode and convert(). "AssignAddVariableOp": _get_jax_op(jnp.add, {"dtype"}), "AssignSubVariableOp": _get_jax_op(jnp.subtract, {"dtype"}),