Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add op remainder for all platform #4912

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/layer/arm/binaryop_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, div_ps(y, x))
MAKE_FUNCTION(binary_op_rpow, (float)powf(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2f(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2f(y, x), atan2_ps(y, x))
// remainder: the scalar expression is the floored modulo x - floor(x / y) * y (same formula as the generic BinaryOp layer), not C remainderf() which rounds the quotient to nearest
MAKE_FUNCTION(binary_op_remainder, x - floorf(x / y) * y, remainder_ps(x, y))
// *INDENT-ON*
// clang-format on

Expand All @@ -308,6 +309,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
2 changes: 2 additions & 0 deletions src/layer/arm/binaryop_arm_asimdhp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vdiv_f16(y, x), vdivq_f16(y, x))
MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)powf(y, x), vcvt_f16_f32(pow_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x))))))
MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2f(x, y), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_low_f16(y)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(x)), vcvt_f32_f16(vget_high_f16(y))))))
MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2f(y, x), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x))))))
// remainder (fp16 storage): the scalar expression is evaluated in fp32 as the floored modulo x - floor(x / y) * y (same formula as the generic BinaryOp layer), not C remainderf()
MAKE_FUNCTION(binary_op_remainder_fp16s, (__fp16)((float)x - floorf((float)x / (float)y) * (float)y), vcvt_f16_f32(remainder_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))), vcombine_f16(vcvt_f16_f32(remainder_ps(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_low_f16(y)))), vcvt_f16_f32(remainder_ps(vcvt_f32_f16(vget_high_f16(x)), vcvt_f32_f16(vget_high_f16(y))))))
// *INDENT-ON*
// clang-format on

Expand All @@ -352,6 +353,7 @@ static void binary_op_vector_fp16s(const __fp16* ptr, const __fp16* ptr1, __fp16
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_fp16s<binary_op_rpow_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_fp16s<binary_op_atan2_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_fp16s<binary_op_ratan2_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_fp16s<binary_op_remainder_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
11 changes: 11 additions & 0 deletions src/layer/arm/neon_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,5 +395,16 @@ static inline float32x4_t atan2_ps(float32x4_t a, float32x4_t b)
return vld1q_f32(tmpx);
}

// Lane-wise floored remainder of four packed floats: x - floor(x / y) * y.
//
// The formula deliberately mirrors the generic BinaryOp remainder functor
// (x / y, floorf, multiply, subtract) instead of calling C remainderf(),
// whose round-to-nearest quotient yields different values (e.g. -1 instead
// of 2 for x = 5, y = 3).
//
// x: dividends, y: divisors. Returns the four per-lane results.
static inline float32x4_t remainder_ps(float32x4_t x, float32x4_t y)
{
    float tmpx[4];
    float tmpy[4];
    vst1q_f32(tmpx, x);
    vst1q_f32(tmpy, y);
    for (int i = 0; i < 4; i++)
    {
        // same operation order as the generic functor
        const float quotient = floorf(tmpx[i] / tmpy[i]);
        const float whole = quotient * tmpy[i];
        tmpx[i] = tmpx[i] - whole;
    }
    return vld1q_f32(tmpx);
}

#include "neon_mathfun_tanh.h"
#endif // NEON_MATHFUN_H
13 changes: 13 additions & 0 deletions src/layer/binaryop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@ struct binary_op_ratan2
}
};

// Scalar functor computing the floored remainder x - floor(x / y) * y.
struct binary_op_remainder
{
    float operator()(const float& x, const float& y) const
    {
        // largest integer not above the quotient
        const float quotient = floorf(x / y);
        // whole multiples of y contained in x
        const float whole = quotient * y;
        return x - whole;
    }
};

static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type, const Option& opt)
{
if (op_type == BinaryOp::Operation_ADD) return binary_op_broadcast<binary_op_add>(a, b, c, opt);
Expand All @@ -251,6 +262,7 @@ static void binary_op_broadcast(const Mat& a, const Mat& b, Mat& c, int op_type,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast<binary_op_pow>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast<binary_op_atan2>(b, a, c, opt);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_broadcast<binary_op_remainder>(a, b, c, opt);

// should never reach here
}
Expand All @@ -269,6 +281,7 @@ static void binary_op_scalar_inplace(Mat& bottom_top_blob, float b, int op_type,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_inplace<binary_op_rpow>(bottom_top_blob, b, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar_inplace<binary_op_atan2>(bottom_top_blob, b, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar_inplace<binary_op_ratan2>(bottom_top_blob, b, opt);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_scalar_inplace<binary_op_remainder>(bottom_top_blob, b, opt);

// should never reach here
}
Expand Down
3 changes: 2 additions & 1 deletion src/layer/binaryop.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class BinaryOp : public Layer
Operation_RDIV = 8,
Operation_RPOW = 9,
Operation_ATAN2 = 10,
Operation_RATAN2 = 11
Operation_RATAN2 = 11,
Operation_REMAINDER = 12
};

public:
Expand Down
2 changes: 2 additions & 0 deletions src/layer/loongarch/binaryop_loongarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __lsx_vfdiv_s(y, x))
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x))
// remainder: the scalar expression is the floored modulo x - floor(x / y) * y (same formula as the generic BinaryOp layer), not C remainderf() which rounds the quotient to nearest
MAKE_FUNCTION(binary_op_remainder, x - floorf(x / y) * y, remainder_ps(x, y))
// *INDENT-ON*
// clang-format on

Expand All @@ -335,6 +336,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
13 changes: 13 additions & 0 deletions src/layer/loongarch/lsx_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,17 @@ static inline __m128 atan2_ps(__m128 a, __m128 b)
return (__m128)__lsx_vld(tmpx, 0);
}

// Lane-wise floored remainder of four packed floats: x - floor(x / y) * y.
//
// The formula deliberately mirrors the generic BinaryOp remainder functor
// (x / y, floorf, multiply, subtract) instead of calling C remainderf(),
// whose round-to-nearest quotient yields different values (e.g. -1 instead
// of 2 for x = 5, y = 3).
//
// x: dividends, y: divisors. Returns the four per-lane results.
static inline __m128 remainder_ps(__m128 x, __m128 y)
{
    float tmpx[4];
    float tmpy[4];
    __lsx_vst(x, tmpx, 0);
    __lsx_vst(y, tmpy, 0);
    for (int i = 0; i < 4; i++)
    {
        // same operation order as the generic functor
        const float quotient = floorf(tmpx[i] / tmpy[i]);
        const float whole = quotient * tmpy[i];
        tmpx[i] = tmpx[i] - whole;
    }
    return (__m128)__lsx_vld(tmpx, 0);
}

#endif // LSX_MATHFUN_H
2 changes: 2 additions & 0 deletions src/layer/mips/binaryop_mips.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, __msa_fdiv_w(y, x))
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x))
// remainder: the scalar expression is the floored modulo x - floor(x / y) * y (same formula as the generic BinaryOp layer), not C remainderf() which rounds the quotient to nearest
MAKE_FUNCTION(binary_op_remainder, x - floorf(x / y) * y, remainder_ps(x, y))
// *INDENT-ON*
// clang-format on

Expand All @@ -335,6 +336,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
13 changes: 13 additions & 0 deletions src/layer/mips/msa_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,17 @@ static inline v4f32 atan2_ps(v4f32 a, v4f32 b)
return (v4f32)__msa_ld_w(tmpx, 0);
}

// Lane-wise floored remainder of four packed floats: x - floor(x / y) * y.
//
// The formula deliberately mirrors the generic BinaryOp remainder functor
// (x / y, floorf, multiply, subtract) instead of calling C remainderf(),
// whose round-to-nearest quotient yields different values (e.g. -1 instead
// of 2 for x = 5, y = 3).
//
// x: dividends, y: divisors. Returns the four per-lane results.
static inline v4f32 remainder_ps(v4f32 x, v4f32 y)
{
    float tmpx[4];
    float tmpy[4];
    __msa_st_w((v4i32)x, tmpx, 0);
    __msa_st_w((v4i32)y, tmpy, 0);
    for (int i = 0; i < 4; i++)
    {
        // same operation order as the generic functor
        const float quotient = floorf(tmpx[i] / tmpy[i]);
        const float whole = quotient * tmpy[i];
        tmpx[i] = tmpx[i] - whole;
    }
    return (v4f32)__msa_ld_w(tmpx, 0);
}

#endif // MSA_MATHFUN_H
4 changes: 4 additions & 0 deletions src/layer/riscv/binaryop_riscv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, vfdiv_vv_f32m8(y, x, vl), vfrdiv_vf_f32m8(x
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f32m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f32m8(x, vl), vl))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y, vl), atan2_ps(x, vfmv_v_f_f32m8(y, vl), vl), atan2_ps(vfmv_v_f_f32m8(x, vl), y, vl))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x, vl), atan2_ps(vfmv_v_f_f32m8(y, vl), x, vl), atan2_ps(y, vfmv_v_f_f32m8(x, vl), vl))
// remainder: the scalar expression is the floored modulo x - floor(x / y) * y (same formula as the generic BinaryOp layer), not C remainderf() which rounds the quotient to nearest
MAKE_FUNCTION(binary_op_remainder, x - floorf(x / y) * y, remainder_ps(x, y, vl), remainder_ps(x, vfmv_v_f_f32m8(y, vl), vl), remainder_ps(vfmv_v_f_f32m8(x, vl), y, vl))
// *INDENT-ON*
// clang-format on

Expand All @@ -316,6 +317,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down Expand Up @@ -887,6 +889,7 @@ MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vfdiv_vv_f16m8(y, x, vl), vfrdiv_vf_f
MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)pow((float)y, (float)x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f16m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f16m8(x, vl), vl))
MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2((float)x, (float)y), atan2_ps(x, y, vl), atan2_ps(x, vfmv_v_f_f16m8(y, vl), vl), atan2_ps(vfmv_v_f_f16m8(x, vl), y, vl))
MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2((float)y, (float)x), atan2_ps(y, x, vl), atan2_ps(vfmv_v_f_f16m8(y, vl), x, vl), atan2_ps(y, vfmv_v_f_f16m8(x, vl), vl))
// remainder (fp16 storage): the scalar expression is evaluated in fp32 as the floored modulo x - floor(x / y) * y (same formula as the generic BinaryOp layer), not C remainderf()
MAKE_FUNCTION(binary_op_remainder_fp16s, (__fp16)((float)x - floorf((float)x / (float)y) * (float)y), remainder_ps(x, y, vl), remainder_ps(x, vfmv_v_f_f16m8(y, vl), vl), remainder_ps(vfmv_v_f_f16m8(x, vl), y, vl))
// *INDENT-ON*
// clang-format on

Expand All @@ -910,6 +913,7 @@ static void binary_op_vector_fp16s(const __fp16* ptr, const __fp16* ptr1, __fp16
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector_fp16s<binary_op_rpow_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector_fp16s<binary_op_atan2_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector_fp16s<binary_op_ratan2_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector_fp16s<binary_op_remainder_fp16s>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
19 changes: 19 additions & 0 deletions src/layer/riscv/rvv_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,4 +580,23 @@ _RVV_FLOAT32_ATAN2_OP(2, 16)
_RVV_FLOAT32_ATAN2_OP(4, 8)
_RVV_FLOAT32_ATAN2_OP(8, 4)

// Generates remainder_ps() for vfloat32m<LMUL>_t: element-wise floored
// remainder x - floor(x / y) * y over the first vl elements.
//
// Both operands are spilled to temporary buffers, processed with scalar
// code, and reloaded. The per-element formula mirrors the generic BinaryOp
// remainder functor (x / y, floorf, multiply, subtract) rather than C
// remainderf(), whose round-to-nearest quotient yields different values
// (e.g. -1 instead of 2 for x = 5, y = 3).
//
// MLEN is not referenced by the macro body.
#define _RVV_FLOAT32_REMAINDER_OP(LMUL, MLEN) \
    static inline vfloat32m##LMUL##_t remainder_ps(vfloat32m##LMUL##_t x, vfloat32m##LMUL##_t y, size_t vl) \
    { \
        std::vector<float> tmpx(vl); \
        std::vector<float> tmpy(vl); \
        vse32_v_f32m##LMUL(tmpx.data(), x, vl); \
        vse32_v_f32m##LMUL(tmpy.data(), y, vl); \
        for (size_t i = 0; i < vl; i++) \
        { \
            const float quotient = floorf(tmpx[i] / tmpy[i]); \
            const float whole = quotient * tmpy[i]; \
            tmpx[i] = tmpx[i] - whole; \
        } \
        return vle32_v_f32m##LMUL(tmpx.data(), vl); \
    }

_RVV_FLOAT32_REMAINDER_OP(1, 32)
_RVV_FLOAT32_REMAINDER_OP(2, 16)
_RVV_FLOAT32_REMAINDER_OP(4, 8)
_RVV_FLOAT32_REMAINDER_OP(8, 4)

#endif // RVV_MATHFUN_H
19 changes: 19 additions & 0 deletions src/layer/riscv/rvv_mathfun_fp16s.h
Original file line number Diff line number Diff line change
Expand Up @@ -416,4 +416,23 @@ _RVV_FLOAT16_ATAN2_OP(2, 16)
_RVV_FLOAT16_ATAN2_OP(4, 8)
_RVV_FLOAT16_ATAN2_OP(8, 4)

// Generates remainder_ps() for vfloat16m<LMUL>_t: element-wise floored
// remainder x - floor(x / y) * y over the first vl elements.
//
// Both operands are spilled to temporary buffers; each element is widened
// to float, processed with scalar code, narrowed back to __fp16, and the
// buffer is reloaded. The per-element formula mirrors the generic BinaryOp
// remainder functor (x / y, floorf, multiply, subtract) rather than C
// remainderf(), whose round-to-nearest quotient yields different values
// (e.g. -1 instead of 2 for x = 5, y = 3).
//
// MLEN is not referenced by the macro body.
#define _RVV_FLOAT16_REMAINDER_OP(LMUL, MLEN) \
    static inline vfloat16m##LMUL##_t remainder_ps(vfloat16m##LMUL##_t x, vfloat16m##LMUL##_t y, size_t vl) \
    { \
        std::vector<__fp16> tmpx(vl); \
        std::vector<__fp16> tmpy(vl); \
        vse16_v_f16m##LMUL(tmpx.data(), x, vl); \
        vse16_v_f16m##LMUL(tmpy.data(), y, vl); \
        for (size_t i = 0; i < vl; i++) \
        { \
            const float xf = (float)tmpx[i]; \
            const float yf = (float)tmpy[i]; \
            const float quotient = floorf(xf / yf); \
            const float whole = quotient * yf; \
            tmpx[i] = (__fp16)(xf - whole); \
        } \
        return vle16_v_f16m##LMUL(tmpx.data(), vl); \
    }

_RVV_FLOAT16_REMAINDER_OP(1, 32)
_RVV_FLOAT16_REMAINDER_OP(2, 16)
_RVV_FLOAT16_REMAINDER_OP(4, 8)
_RVV_FLOAT16_REMAINDER_OP(8, 4)

#endif // RVV_MATHFUN_FP16S_H
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop.comp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop_broadcast.comp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop_broadcast_pack1to4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
5 changes: 5 additions & 0 deletions src/layer/vulkan/shader/binaryop_broadcast_pack1to8.comp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ void main()
res[1] = atan(v2[1], v1[1]);
#endif
}
if (op_type == 12)
{
res[0] = v1[0] - floorf(v1[0] / v2[0]) * v2[0];
res[1] = v1[1] - floorf(v1[1] / v2[1]) * v2[1];
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop_broadcast_pack4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
5 changes: 5 additions & 0 deletions src/layer/vulkan/shader/binaryop_broadcast_pack8.comp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,11 @@ void main()
res[1] = atan(v2[1], v1[1]);
#endif
}
if (op_type == 12)
{
res[0] = v1[0] - floorf(v1[0] / v2[0]) * v2[0];
res[1] = v1[1] - floorf(v1[1] / v2[1]) * v2[1];
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
1 change: 1 addition & 0 deletions src/layer/vulkan/shader/binaryop_pack4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ void main()
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
if (op_type == 12) res = v1 - floorf(v1 / v2) * v2;

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
5 changes: 5 additions & 0 deletions src/layer/vulkan/shader/binaryop_pack8.comp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ void main()
res[1] = atan(v2[1], v1[1]);
#endif
}
if (op_type == 12)
{
res[0] = v1[0] - floorf(v1[0] / v2[0]) * v2[0];
res[1] = v1[1] - floorf(v1[1] / v2[1]) * v2[1];
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
8 changes: 8 additions & 0 deletions src/layer/x86/avx512_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -856,4 +856,12 @@ static NCNN_FORCEINLINE __m512 abs512_ps(__m512 x)
return _mm512_andnot_ps(magic_negative_zero, x);
}

// Lane-wise floored remainder on a 512-bit float vector: x - floor(x / y) * y.
static NCNN_FORCEINLINE __m512 remainder512_ps(__m512 x, __m512 y)
{
    // floor of the per-lane quotient
    const __m512 quotient = _mm512_floor_ps(_mm512_div_ps(x, y));
    // whole multiples of y contained in x
    const __m512 whole = _mm512_mul_ps(y, quotient);
    return _mm512_sub_ps(x, whole);
}

#endif // AVX512_MATHFUN_H
8 changes: 8 additions & 0 deletions src/layer/x86/avx_mathfun.h
Original file line number Diff line number Diff line change
Expand Up @@ -1087,4 +1087,12 @@ static NCNN_FORCEINLINE __m256 abs256_ps(__m256 x)
return _mm256_andnot_ps(magic_negative_zero, x);
}

// Lane-wise floored remainder on a 256-bit float vector: x - floor(x / y) * y.
static NCNN_FORCEINLINE __m256 remainder256_ps(__m256 x, __m256 y)
{
    // floor of the per-lane quotient
    const __m256 quotient = _mm256_floor_ps(_mm256_div_ps(x, y));
    // whole multiples of y contained in x
    const __m256 whole = _mm256_mul_ps(y, quotient);
    return _mm256_sub_ps(x, whole);
}

#endif // AVX_MATHFUN_H
30 changes: 30 additions & 0 deletions src/layer/x86/binaryop_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,35 @@ struct binary_op_ratan2
#endif // __SSE2__
};

// Functor for the remainder operation: a scalar implementation plus packed
// variants that are compiled in according to the available x86 SIMD level.
struct binary_op_remainder
{
    // Scalar path: x - floor(x / y) * y.
    float func(const float& x, const float& y) const
    {
        const float quotient = floorf(x / y);
        const float whole = quotient * y;
        return x - whole;
    }
#if __SSE2__
    // 128-bit path, forwarded to remainder_ps().
    __m128 func_pack4(const __m128& x, const __m128& y) const
    {
        return remainder_ps(x, y);
    }
#if __AVX__
    // 256-bit path, forwarded to remainder256_ps().
    __m256 func_pack8(const __m256& x, const __m256& y) const
    {
        return remainder256_ps(x, y);
    }
#if __AVX512F__
    // 512-bit path, forwarded to remainder512_ps().
    __m512 func_pack16(const __m512& x, const __m512& y) const
    {
        return remainder512_ps(x, y);
    }
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__
};

} // namespace BinaryOp_x86_functor

static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr, int aw, int bw, int ap, int bp, int op_type)
Expand All @@ -807,6 +836,7 @@ static void binary_op_vector(const float* ptr, const float* ptr1, float* outptr,
if (op_type == BinaryOp::Operation_RPOW) return binary_op_vector<binary_op_rpow>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_vector<binary_op_atan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_vector<binary_op_ratan2>(ptr, ptr1, outptr, aw, bw, ap, bp);
if (op_type == BinaryOp::Operation_REMAINDER) return binary_op_vector<binary_op_remainder>(ptr, ptr1, outptr, aw, bw, ap, bp);

// should never reach here
}
Expand Down
Loading