From f1ebc42c7016cef5839722ab4274b1e449ff9c4d Mon Sep 17 00:00:00 2001 From: liqize Date: Fri, 12 Jun 2026 17:05:29 +0800 Subject: [PATCH] cpu: rv64: share RVV eltwise emitters --- src/cpu/rv64/jit_rvv_eltwise_emitter.cpp | 530 +++++++++++++++++++++++ src/cpu/rv64/jit_rvv_eltwise_emitter.hpp | 119 +++++ src/cpu/rv64/jit_rvv_eltwise_kernel.cpp | 239 ++++++---- src/cpu/rv64/jit_rvv_eltwise_kernel.hpp | 15 +- src/cpu/rv64/jit_rvv_softmax_kernel.cpp | 71 +-- src/cpu/rv64/rvv_eltwise.hpp | 17 +- 6 files changed, 828 insertions(+), 163 deletions(-) create mode 100644 src/cpu/rv64/jit_rvv_eltwise_emitter.cpp create mode 100644 src/cpu/rv64/jit_rvv_eltwise_emitter.hpp diff --git a/src/cpu/rv64/jit_rvv_eltwise_emitter.cpp b/src/cpu/rv64/jit_rvv_eltwise_emitter.cpp new file mode 100644 index 00000000000..642fabbca5c --- /dev/null +++ b/src/cpu/rv64/jit_rvv_eltwise_emitter.cpp @@ -0,0 +1,530 @@ +/******************************************************************************* +* Copyright 2026 SpacemiT Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/
+
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+
+#include "cpu/rv64/jit_rvv_eltwise_emitter.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace rv64 {
+
+using namespace Xbyak_riscv;
+
+namespace {
+
+// Constants consumed by jit_rvv_eltwise_fwd_emitter_t::exp().
+// The uint32_t members are IEEE-754 binary32 bit patterns that are moved
+// into FP registers verbatim; the two int32_t members are used as integer
+// operands when the exponent field of the scale factors is assembled.
+constexpr struct {
+    float LowerRange;
+    float UpperRange;
+    uint32_t RoundingBias;
+    uint32_t Log2Reciprocal;
+    uint32_t Log2High;
+    uint32_t Log2Low;
+    uint32_t poly_0;
+    uint32_t poly_1;
+    uint32_t poly_2;
+    uint32_t poly_3;
+    uint32_t poly_4;
+    uint32_t poly_56;
+    int32_t MinimumExponent;
+    int32_t MaximumExponent;
+} exp_constants = {
+        -103.9720840454f,
+        88.7762626647950f,
+        0x4b400000,
+        0x3fb8aa3b,
+        0xbf317200,
+        0xb5bfbe8e,
+        0x3ab4a000,
+        0x3c092f6e,
+        0x3d2aadad,
+        0x3e2aaa28,
+        0x3efffffb,
+        0x3f800000,
+        int32_t(0xc1000000),
+        int32_t(0x3f800000),
+};
+
+// Constants consumed by tanh(): the clamp range, the numerator coefficients
+// (alpha_*, evaluated in x^2 and then multiplied by x) and the denominator
+// coefficients (beta_*, evaluated in x^2).
+constexpr struct {
+    float LowerRange;
+    float UpperRange;
+    float alpha_13;
+    float alpha_11;
+    float alpha_9;
+    float alpha_7;
+    float alpha_5;
+    float alpha_3;
+    float alpha_1;
+    float beta_6;
+    float beta_4;
+    float beta_2;
+    float beta_0;
+} tanh_constants = {
+        -9.0f,
+        9.0f,
+        -2.76076847742355e-16f,
+        2.00018790482477e-13f,
+        -8.60467152213735e-11f,
+        5.12229709037114e-08f,
+        1.48572235717979e-05f,
+        6.37261928875436e-04f,
+        4.89352455891786e-03f,
+        1.19825839466702e-06f,
+        1.18534705686654e-04f,
+        2.26843463243900e-03f,
+        4.89352518554385e-03f,
+};
+
+// Constants for gelu_tanh(). Only GELU_COEF_A and SQRT_2_OVER_PI are read by
+// the emitters in this file; the remaining members are kept as declared.
+constexpr struct {
+    float alpha_13;
+    float alpha_11;
+    float alpha_9;
+    float alpha_7;
+    float alpha_5;
+    float alpha_3;
+    float alpha_1;
+    float beta_6;
+    float beta_4;
+    float beta_2;
+    float beta_0;
+    float GELU_COEF_A;
+    float GELU_QUICK_COEF;
+    float SQRT_2_OVER_PI;
+} gelu_tanh_constants = {
+        -2.76076847742355e-16f,
+        2.00018790482477e-13f,
+        -8.60467152213735e-11f,
+        5.12229709037114e-08f,
+        1.48572235717979e-05f,
+        6.37261928875436e-04f,
+        4.89352455891786e-03f,
+        1.19825839466702e-06f,
+        1.18534705686654e-04f,
+        2.26843463243900e-03f,
+        4.89352518554385e-03f,
+        0.044715f,
+        -1.702f,
+        0.79788456080286535587989211986876f,
+};
+
+// Constants for erf() and its private exponential erf_exp_complement().
+// ErfSplitBoundary selects between the "small" and "big" polynomial branch,
+// ErfUpperAbsRange bounds |x| before the "big" branch is evaluated.
+constexpr struct {
+    float ErfUpperAbsRange;
+    float ErfSplitBoundary;
+    float ErfSMALL_P0;
+    float ErfSMALL_P1;
+    float ErfSMALL_P2;
+    float ErfSMALL_P3;
+    float ErfSMALL_P4;
+    float ErfSMALL_P5_Minus_One;
+    float ErfReserved0;
+    float ErfBIG_P0;
+    float ErfBIG_P1;
+    float ErfBIG_P2;
+    float ErfBIG_P3;
+    float ErfBIG_P4;
+    float ErfBIG_P5;
+    float ErfBIG_P6_Minus_One;
+    float ErfNegZero;
+    float ErfOne;
+    float Exp_UpperRange;
+    float Exp_LowerRange;
+    float Exp_Log2Reciprocal;
+    float Exp_log2_hi;
+    float Exp_log2_lo;
+    float Exp_P0;
+    float Exp_P1;
+    float Exp_P2;
+    float Exp_P3;
+    float Exp_P4;
+    float Exp_P5;
+    float Exp_P6;
+    float Exp_C;
+    int32_t Exp_X7F;
+} erf_constants = {
+        3.925f,
+        0.921875f,
+        -5.99104969e-4f,
+        4.99339588e-3f,
+        -2.67667342e-2f,
+        1.12818025e-1f,
+        -3.76124859e-1f,
+        1.28379151e-1f,
+        0.0f,
+        1.72948930e-5f,
+        -3.83208680e-4f,
+        3.88393435e-3f,
+        -2.42545605e-2f,
+        1.06777847e-1f,
+        6.34846687e-1f,
+        1.28717512e-1f,
+        -0.0f,
+        1.0f,
+        // Independent parameters to calculate Exp for Erff()
+        88.3762626647950f,
+        -88.3762626647949f,
+        1.44269504088896341f,
+        -6.93145752e-1f,
+        -1.42860677e-6f,
+        1.38319808e-3f,
+        8.37550033e-3f,
+        4.16689515e-2f,
+        1.66664466e-1f,
+        4.99999851e-1f,
+        1.00000000e+0f,
+        1.00000000e+0f,
+        1.25829120e+7f,
+        127,
+};
+
+// Coefficient tables in evaluation order (highest-order term first), as
+// expected by horner_f32() / horner_f32_bits().
+constexpr uint32_t exp_poly[] = {exp_constants.poly_0, exp_constants.poly_1,
+        exp_constants.poly_2, exp_constants.poly_3, exp_constants.poly_4,
+        exp_constants.poly_56};
+constexpr float tanh_num_poly[] = {tanh_constants.alpha_13,
+        tanh_constants.alpha_11, tanh_constants.alpha_9, tanh_constants.alpha_7,
+        tanh_constants.alpha_5, tanh_constants.alpha_3, tanh_constants.alpha_1};
+constexpr float tanh_den_poly[] = {tanh_constants.beta_6, tanh_constants.beta_4,
+        tanh_constants.beta_2, tanh_constants.beta_0};
+constexpr float erf_small_poly[]
+        = {erf_constants.ErfSMALL_P0, erf_constants.ErfSMALL_P1,
+                erf_constants.ErfSMALL_P2, erf_constants.ErfSMALL_P3,
+                erf_constants.ErfSMALL_P4, erf_constants.ErfSMALL_P5_Minus_One};
+constexpr float erf_big_poly[] = {erf_constants.ErfBIG_P0,
+        erf_constants.ErfBIG_P1, erf_constants.ErfBIG_P2,
+        erf_constants.ErfBIG_P3, erf_constants.ErfBIG_P4,
+        erf_constants.ErfBIG_P5, erf_constants.ErfBIG_P6_Minus_One};
+constexpr float erf_exp_poly[] = {erf_constants.Exp_P0, erf_constants.Exp_P1,
+        erf_constants.Exp_P2, erf_constants.Exp_P3, erf_constants.Exp_P4,
+        erf_constants.Exp_P5, erf_constants.Exp_P6};
+
+// Bit pattern of 2^23 as binary32. Every finite f32 whose magnitude is at
+// least this large is already an integer.
+constexpr uint32_t f32_two_pow_23_bits = 0x4b000000u;
+
+// Number of elements of a built-in array.
+template <typename T, int N>
+constexpr int array_size(const T (&)[N]) {
+    return N;
+}
+
+// Returns the object representation of an f32 value as a 32-bit integer.
+uint32_t float_to_bits(float value) {
+    uint32_t bits = 0;
+    std::memcpy(&bits, &value, sizeof(bits));
+    return bits;
+}
+
+} // namespace
+
+// Binds the emitter to the generator that receives the instructions.
+jit_rvv_eltwise_fwd_emitter_t::jit_rvv_eltwise_fwd_emitter_t(
+        jit_generator_t *jit)
+    : jit_(jit) {}
+
+// dst = f32 with the given bit pattern. Clobbers `tmp`.
+void jit_rvv_eltwise_fwd_emitter_t::load_f32_bits(
+        const FReg &dst, const Reg &tmp, uint32_t bits) {
+    jit_->li(tmp, static_cast<int32_t>(bits));
+    jit_->fmv_w_x(dst, tmp);
+}
+
+// dst = value. Clobbers `tmp`.
+void jit_rvv_eltwise_fwd_emitter_t::load_f32_const(
+        const FReg &dst, const Reg &tmp, float value) {
+    load_f32_bits(dst, tmp, float_to_bits(value));
+}
+
+// Splats the f32 with the given bit pattern into every element of dst.
+// Clobbers r.ftmp0 and r.tmp_reg0.
+void jit_rvv_eltwise_fwd_emitter_t::broadcast_f32_bits(
+        const eltwise_aux_regs_t &r, const VReg &dst, uint32_t bits) {
+    load_f32_bits(r.ftmp0, r.tmp_reg0, bits);
+    jit_->vfmv_v_f(dst, r.ftmp0);
+}
+
+// Splats `value` into every element of dst.
+// Clobbers r.ftmp0 and r.tmp_reg0.
+void jit_rvv_eltwise_fwd_emitter_t::broadcast_f32_const(
+        const eltwise_aux_regs_t &r, const VReg &dst, float value) {
+    broadcast_f32_bits(r, dst, float_to_bits(value));
+}
+
+// dst = min(max(dst, lower), upper), in place.
+// Clobbers r.ftmp0 and r.tmp_reg0.
+void jit_rvv_eltwise_fwd_emitter_t::clamp_f32(const eltwise_aux_regs_t &r,
+        const VReg &dst, float lower, float upper) {
+    load_f32_const(r.ftmp0, r.tmp_reg0, lower);
+    jit_->vfmax_vf(dst, dst, r.ftmp0);
+    load_f32_const(r.ftmp0, r.tmp_reg0, upper);
+    jit_->vfmin_vf(dst, dst, r.ftmp0);
+}
+
+// dst = src * value. Clobbers r.ftmp0 and r.tmp_reg0.
+void jit_rvv_eltwise_fwd_emitter_t::mul_f32_const(const eltwise_aux_regs_t &r,
+        const VReg &dst, const VReg &src, float value) {
+    load_f32_const(r.ftmp0, r.tmp_reg0, value);
+    jit_->vfmul_vf(dst, src, r.ftmp0);
+}
+
+// Evaluates a polynomial in x with Horner's scheme:
+//     dst = (...((c[0] * x + c[1]) * x + c[2]) ...) * x + c[count - 1]
+// dst is used as the accumulator, so it must not be the same register as x.
+// Clobbers r.ftmp0 and r.tmp_reg0.
+void jit_rvv_eltwise_fwd_emitter_t::horner_f32(const eltwise_aux_regs_t &r,
+        const VReg &dst, const VReg &x, const float *coeffs, int coeff_count) {
+    assert(coeff_count > 0);
+    broadcast_f32_const(r, dst, coeffs[0]);
+    for (int i = 1; i < coeff_count; ++i) {
+        jit_->vfmul_vv(dst, dst, x);
+        load_f32_const(r.ftmp0, r.tmp_reg0, coeffs[i]);
+        jit_->vfadd_vf(dst, dst, r.ftmp0);
+    }
+}
+
+// Same as horner_f32(), with the coefficients given as f32 bit patterns.
+void jit_rvv_eltwise_fwd_emitter_t::horner_f32_bits(const eltwise_aux_regs_t &r,
+        const VReg &dst, const VReg &x, const uint32_t *coeffs,
+        int coeff_count) {
+    assert(coeff_count > 0);
+    broadcast_f32_bits(r, dst, coeffs[0]);
+    for (int i = 1; i < coeff_count; ++i) {
+        jit_->vfmul_vv(dst, dst, x);
+        load_f32_bits(r.ftmp0, r.tmp_reg0, coeffs[i]);
+        jit_->vfadd_vf(dst, dst, r.ftmp0);
+    }
+}
+
+// dst = |src|
+void jit_rvv_eltwise_fwd_emitter_t::abs(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    UNUSED(r);
+    jit_->vfabs_v(dst, src);
+}
+
+// dst = min(max(src, alpha), beta)
+void jit_rvv_eltwise_fwd_emitter_t::clip(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    jit_->vfmax_vf(dst, src, r.alpha);
+    jit_->vfmin_vf(dst, dst, r.beta);
+}
+
+// dst = min(max(alpha * src + beta, zero), one)
+void jit_rvv_eltwise_fwd_emitter_t::hardsigmoid(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    jit_->vfmul_vf(dst, src, r.alpha);
+    jit_->vfadd_vf(dst, dst, r.beta);
+    jit_->vfmax_vf(dst, dst, r.zero);
+    jit_->vfmin_vf(dst, dst, r.one);
+}
+
+// dst = src * min(max(alpha * src + beta, zero), one). Clobbers r.tmp0.
+void jit_rvv_eltwise_fwd_emitter_t::hardswish(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    jit_->vfmul_vf(r.tmp0, src, r.alpha);
+    jit_->vfadd_vf(r.tmp0, r.tmp0, r.beta);
+    jit_->vfmax_vf(r.tmp0, r.tmp0, r.zero);
+    jit_->vfmin_vf(r.tmp0, r.tmp0, r.one);
+    jit_->vfmul_vv(dst, src, r.tmp0);
+}
+
+// dst = alpha * src + beta
+void jit_rvv_eltwise_fwd_emitter_t::linear(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    jit_->vfmul_vf(dst, src, r.alpha);
+    jit_->vfadd_vf(dst, dst, r.beta);
+}
+
+// dst = src > zero ? src : alpha * src. Clobbers v0 (mask) and r.tmp0.
+void jit_rvv_eltwise_fwd_emitter_t::relu(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    const VReg v_mask(0);
+    jit_->vmfgt_vf(v_mask, src, r.zero);
+    jit_->vfmul_vf(r.tmp0, src, r.alpha);
+    // Masked lanes take the last operand (src), the others take r.tmp0.
+    jit_->vmerge_vvm(dst, r.tmp0, src);
+}
+
+// dst = sqrt(src)
+void jit_rvv_eltwise_fwd_emitter_t::sqrt(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    UNUSED(r);
+    jit_->vfsqrt_v(dst, src);
+}
+
+// dst = src * src
+void jit_rvv_eltwise_fwd_emitter_t::square(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    UNUSED(r);
+    jit_->vfmul_vv(dst, src, src);
+}
+
+// dst = max(src, zero) + alpha * min(src, zero). Clobbers r.tmp0.
+void jit_rvv_eltwise_fwd_emitter_t::leakyrelu(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    jit_->vfmax_vf(dst, src, r.zero);
+    jit_->vfmin_vf(r.tmp0, src, r.zero);
+    jit_->vfmul_vf(r.tmp0, r.tmp0, r.alpha);
+    jit_->vfadd_vv(dst, dst, r.tmp0);
+}
+
+// dst = src rounded to an integral value through an f32 -> i32 -> f32 round
+// trip. The float-to-integer conversion saturates, so the round trip is only
+// applied to lanes with |src| < 2^23. All other lanes (already integral
+// values, +-Inf and NaN, for which the comparison is false) are copied from
+// src unchanged.
+// Clobbers v0 (mask), r.tmp0, r.ftmp0 and r.tmp_reg0.
+void jit_rvv_eltwise_fwd_emitter_t::round(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    const VReg v_mask(0);
+
+    jit_->vfabs_v(r.tmp0, src);
+    load_f32_bits(r.ftmp0, r.tmp_reg0, f32_two_pow_23_bits);
+    jit_->vmflt_vf(v_mask, r.tmp0, r.ftmp0);
+
+    jit_->vfcvt_x_f_v(r.tmp0, src);
+    jit_->vfcvt_f_x_v(r.tmp0, r.tmp0);
+
+    // Masked lanes take the rounded value, the others keep src.
+    jit_->vmerge_vvm(dst, src, r.tmp0);
+}
+
+// dst = exp(src).
+// dst may be the same register as src, but neither may be r.tmp0, r.tmp1
+// or r.tmp2, which are used as scratch.
+// Clobbers r.tmp0..r.tmp2, r.ftmp0, r.tmp_reg0 and r.tmp_reg1.
+void jit_rvv_eltwise_fwd_emitter_t::exp(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    // x = clamp(src, LowerRange, UpperRange)
+    jit_->vmv_v_v(dst, src);
+    clamp_f32(r, dst, exp_constants.LowerRange, exp_constants.UpperRange);
+
+    // tmp0 = RoundingBias + Log2Reciprocal * x
+    broadcast_f32_bits(r, r.tmp0, exp_constants.RoundingBias);
+    load_f32_bits(r.ftmp0, r.tmp_reg0, exp_constants.Log2Reciprocal);
+    jit_->vfmacc_vf(r.tmp0, r.ftmp0, dst);
+
+    // tmp1 = tmp0 - RoundingBias, then x += Log2High * tmp1 + Log2Low * tmp1
+    load_f32_bits(r.ftmp0, r.tmp_reg0, exp_constants.RoundingBias);
+    jit_->vfsub_vf(r.tmp1, r.tmp0, r.ftmp0);
+    load_f32_bits(r.ftmp0, r.tmp_reg0, exp_constants.Log2High);
+    jit_->vfmacc_vf(dst, r.ftmp0, r.tmp1);
+    load_f32_bits(r.ftmp0, r.tmp_reg0, exp_constants.Log2Low);
+    jit_->vfmacc_vf(dst, r.ftmp0, r.tmp1);
+
+    // tmp2 = polynomial in the reduced argument.
+    horner_f32_bits(r, r.tmp2, dst, exp_poly, array_size(exp_poly));
+
+    // Build two scale factors with integer arithmetic on the biased sum:
+    //   tmp1 = clamp(tmp0 << 23, MinimumExponent, MaximumExponent)
+    //   tmp0 = (tmp0 << 23) - tmp1 + MaximumExponent
+    //   tmp1 = tmp1 + MaximumExponent
+    jit_->vsll_vi(r.tmp0, r.tmp0, 23);
+    jit_->li(r.tmp_reg0, exp_constants.MaximumExponent);
+    jit_->li(r.tmp_reg1, exp_constants.MinimumExponent);
+    jit_->vmin_vx(r.tmp1, r.tmp0, r.tmp_reg0);
+    jit_->vmax_vx(r.tmp1, r.tmp1, r.tmp_reg1);
+    jit_->vsub_vv(r.tmp0, r.tmp0, r.tmp1);
+    jit_->vadd_vx(r.tmp0, r.tmp0, r.tmp_reg0);
+    jit_->vadd_vx(r.tmp1, r.tmp1, r.tmp_reg0);
+
+    // dst = (tmp2 * (x * tmp0) + tmp0) * tmp1
+    jit_->vfmul_vv(dst, dst, r.tmp0);
+    jit_->vfmadd_vv(r.tmp2, dst, r.tmp0);
+    jit_->vfmul_vv(dst, r.tmp2, r.tmp1);
+}
+
+// dst = one - exp(src), using the exponential parameters of the erf table.
+// Helper of erf(). src may be r.tmp1. dst must not be r.tmp0 or r.tmp1.
+// Clobbers r.tmp0, r.tmp1, r.ftmp0 and r.tmp_reg0.
+void jit_rvv_eltwise_fwd_emitter_t::erf_exp_complement(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    // tmp1 = max(src, Exp_LowerRange)
+    jit_->vmv_v_v(r.tmp1, src);
+    load_f32_const(r.ftmp0, r.tmp_reg0, erf_constants.Exp_LowerRange);
+    jit_->vfmax_vf(r.tmp1, r.tmp1, r.ftmp0);
+
+    // tmp0 = (tmp1 * Exp_Log2Reciprocal + Exp_C) - Exp_C
+    broadcast_f32_const(r, r.tmp0, erf_constants.Exp_Log2Reciprocal);
+    jit_->vfmul_vv(r.tmp0, r.tmp0, r.tmp1);
+    load_f32_const(r.ftmp0, r.tmp_reg0, erf_constants.Exp_C);
+    jit_->vfadd_vf(r.tmp0, r.tmp0, r.ftmp0);
+    jit_->vfsub_vf(r.tmp0, r.tmp0, r.ftmp0);
+
+    // tmp1 += Exp_log2_hi * tmp0 + Exp_log2_lo * tmp0
+    load_f32_const(r.ftmp0, r.tmp_reg0, erf_constants.Exp_log2_hi);
+    jit_->vfmacc_vf(r.tmp1, r.ftmp0, r.tmp0);
+    load_f32_const(r.ftmp0, r.tmp_reg0, erf_constants.Exp_log2_lo);
+    jit_->vfmacc_vf(r.tmp1, r.ftmp0, r.tmp0);
+
+    horner_f32(r, dst, r.tmp1, erf_exp_poly, array_size(erf_exp_poly));
+
+    // Scale factor: (int(tmp0) + Exp_X7F) << 23, reinterpreted as f32.
+    jit_->vfcvt_x_f_v(r.tmp0, r.tmp0);
+    jit_->li(r.tmp_reg0, erf_constants.Exp_X7F);
+    jit_->vadd_vx(r.tmp0, r.tmp0, r.tmp_reg0);
+    jit_->vsll_vi(r.tmp0, r.tmp0, 23);
+    jit_->vfmul_vv(dst, dst, r.tmp0);
+    jit_->vfrsub_vf(dst, dst, r.one);
+}
+
+// dst = tanh(src), evaluated as x * P(x^2) / Q(x^2) on the clamped input.
+// dst may be the same register as src, but neither may be r.tmp0, r.tmp1
+// or r.tmp2 unless src is r.tmp0 (src is copied to dst before r.tmp0 is
+// overwritten).
+// Clobbers r.tmp0..r.tmp2, r.ftmp0 and r.tmp_reg0.
+void jit_rvv_eltwise_fwd_emitter_t::tanh(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    jit_->vmv_v_v(dst, src);
+    clamp_f32(r, dst, tanh_constants.LowerRange, tanh_constants.UpperRange);
+
+    jit_->vfmul_vv(r.tmp0, dst, dst);
+    horner_f32(r, r.tmp1, r.tmp0, tanh_num_poly, array_size(tanh_num_poly));
+    jit_->vfmul_vv(r.tmp1, r.tmp1, dst);
+    horner_f32(r, r.tmp2, r.tmp0, tanh_den_poly, array_size(tanh_den_poly));
+
+    jit_->vfdiv_vv(dst, r.tmp1, r.tmp2);
+}
+
+// dst = erf(src).
+// Two polynomial branches are evaluated on |src| and merged per lane; the
+// sign of src is re-applied at the end.
+// src may be r.tmp0 (it is only read before r.tmp0 is overwritten); dst must
+// not be any of r.tmp0..r.tmp3.
+// Clobbers v0 (mask), r.tmp0..r.tmp3, r.ftmp0 and r.tmp_reg0. Reads r.one.
+void jit_rvv_eltwise_fwd_emitter_t::erf(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    const VReg v_mask(0);
+
+    // tmp2 = sign bit of src, tmp0 = |src|, tmp1 = src^2
+    broadcast_f32_const(r, r.tmp3, erf_constants.ErfNegZero);
+    jit_->vand_vv(r.tmp2, src, r.tmp3);
+    jit_->vfabs_v(r.tmp0, src);
+    jit_->vfmul_vv(r.tmp1, r.tmp0, r.tmp0);
+
+    // Small branch: dst = P_small(src^2) * |src| + |src|
+    horner_f32(r, dst, r.tmp1, erf_small_poly, array_size(erf_small_poly));
+    jit_->vfmul_vv(dst, dst, r.tmp0);
+    jit_->vfadd_vv(dst, dst, r.tmp0);
+
+    // mask = |src| > ErfSplitBoundary; then bound |src| for the big branch.
+    load_f32_const(r.ftmp0, r.tmp_reg0, erf_constants.ErfSplitBoundary);
+    jit_->vmfgt_vf(v_mask, r.tmp0, r.ftmp0);
+    load_f32_const(r.ftmp0, r.tmp_reg0, erf_constants.ErfUpperAbsRange);
+    jit_->vfmin_vf(r.tmp0, r.tmp0, r.ftmp0);
+
+    // Big branch: tmp3 = one - exp(-(P_big(|src|) * |src| + |src|))
+    horner_f32(r, r.tmp1, r.tmp0, erf_big_poly, array_size(erf_big_poly));
+    jit_->vfmul_vv(r.tmp1, r.tmp1, r.tmp0);
+    jit_->vfadd_vv(r.tmp1, r.tmp1, r.tmp0);
+
+    jit_->vfneg_v(r.tmp1, r.tmp1);
+    erf_exp_complement(r, r.tmp3, r.tmp1);
+
+    // Masked lanes keep the big branch, the others take the small branch.
+    jit_->vmerge_vvm(r.tmp3, dst, r.tmp3);
+    jit_->vor_vv(dst, r.tmp3, r.tmp2);
+}
+
+// dst = 1 / src: vfrec7 estimate refined by two Newton-Raphson steps
+// y = y * (2 - src * y).
+// dst must not be the same register as src (src is re-read after dst has
+// been written), and neither may be r.tmp2.
+// Clobbers r.tmp2, r.ftmp0 and r.tmp_reg0.
+void jit_rvv_eltwise_fwd_emitter_t::reciprocal(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    jit_->vfrec7_v(dst, src);
+    // The scalar 2.0 is loaded once; only the vector splat is repeated.
+    load_f32_const(r.ftmp0, r.tmp_reg0, 2.0f);
+    for (int i = 0; i < 2; ++i) {
+        jit_->vfmv_v_f(r.tmp2, r.ftmp0);
+        jit_->vfnmsac_vv(r.tmp2, src, dst);
+        jit_->vfmul_vv(dst, dst, r.tmp2);
+    }
+}
+
+// dst = 1 / (1 + exp(-src)).
+// With e = exp(-|src|), lanes where src == |src| take 1 / (1 + e) and the
+// remaining lanes take e / (1 + e).
+// src may be r.tmp0 (it is only read before the scratch registers are
+// overwritten); dst may be r.tmp0.
+// Clobbers v0 (mask), r.tmp0..r.tmp3, r.ftmp0, r.tmp_reg0 and r.tmp_reg1.
+// Reads r.one.
+void jit_rvv_eltwise_fwd_emitter_t::sigmoid(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    const VReg v_mask(0);
+
+    jit_->vfabs_v(r.tmp3, src);
+    jit_->vmfeq_vv(v_mask, src, r.tmp3);
+    jit_->vfneg_v(r.tmp3, r.tmp3);
+    exp(r, r.tmp3, r.tmp3);
+
+    jit_->vfadd_vf(r.tmp0, r.tmp3, r.one);
+    reciprocal(r, r.tmp1, r.tmp0);
+    jit_->vfmul_vv(r.tmp3, r.tmp3, r.tmp1);
+    jit_->vmerge_vvm(dst, r.tmp3, r.tmp1);
+}
+
+// dst = src * sigmoid(alpha * src).
+// Clobbers everything sigmoid() clobbers; src must not be r.tmp0..r.tmp3.
+void jit_rvv_eltwise_fwd_emitter_t::swish(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    jit_->vfmul_vf(r.tmp0, src, r.alpha);
+    sigmoid(r, r.tmp0, r.tmp0);
+    jit_->vfmul_vv(dst, src, r.tmp0);
+}
+
+// dst = src < zero ? alpha * (exp(src) - one) : src.
+// Clobbers v0 (mask) and everything exp() clobbers plus r.tmp3; src must not
+// be r.tmp0..r.tmp3.
+void jit_rvv_eltwise_fwd_emitter_t::elu(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    const VReg v_mask(0);
+    exp(r, r.tmp3, src);
+    jit_->vfsub_vf(r.tmp3, r.tmp3, r.one);
+    jit_->vfmul_vf(r.tmp3, r.tmp3, r.alpha);
+    jit_->vmflt_vf(v_mask, src, r.zero);
+    jit_->vmerge_vvm(dst, src, r.tmp3);
+}
+
+// dst = 0.5 * src * (1 + tanh(SQRT_2_OVER_PI * src
+//                                 * (1 + GELU_COEF_A * src^2))).
+// The result is not clamped: for large positive inputs it follows src.
+// Clobbers r.tmp0..r.tmp3, r.ftmp0 and r.tmp_reg0; src must not be one of
+// r.tmp0..r.tmp3. Reads r.one.
+void jit_rvv_eltwise_fwd_emitter_t::gelu_tanh(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    jit_->vfmul_vv(r.tmp0, src, src);
+    mul_f32_const(r, r.tmp0, r.tmp0, gelu_tanh_constants.GELU_COEF_A);
+    jit_->vfadd_vf(r.tmp0, r.tmp0, r.one);
+    jit_->vfmul_vv(r.tmp0, r.tmp0, src);
+    mul_f32_const(r, r.tmp0, r.tmp0, gelu_tanh_constants.SQRT_2_OVER_PI);
+
+    tanh(r, r.tmp3, r.tmp0);
+    jit_->vfadd_vf(r.tmp3, r.tmp3, r.one);
+    jit_->vfmul_vv(dst, r.tmp3, src);
+    mul_f32_const(r, dst, dst, 0.5f);
+}
+
+// dst = 0.5 * src * (1 + erf(src / sqrt(2))).
+// Clobbers everything erf() clobbers; src and dst must not be one of
+// r.tmp0..r.tmp3, and dst must not be the same register as src.
+void jit_rvv_eltwise_fwd_emitter_t::gelu_erf(
+        const eltwise_aux_regs_t &r, const VReg &dst, const VReg &src) {
+    mul_f32_const(r, r.tmp0, src, 0.70710678118654752440f);
+    erf(r, dst, r.tmp0);
+
+    jit_->vfadd_vf(dst, dst, r.one);
+    mul_f32_const(r, r.tmp0, src, 0.5f);
+    jit_->vfmul_vv(dst, r.tmp0, dst);
+}
+
+} // namespace rv64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/rv64/jit_rvv_eltwise_emitter.hpp b/src/cpu/rv64/jit_rvv_eltwise_emitter.hpp
new file mode 100644
index 00000000000..ca88ec0c8f3
--- /dev/null
+++ b/src/cpu/rv64/jit_rvv_eltwise_emitter.hpp
@@ -0,0 +1,119 @@
+/*******************************************************************************
+* Copyright 2026 SpacemiT Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_RV64_JIT_RVV_ELTWISE_EMITTER_HPP
+#define CPU_RV64_JIT_RVV_ELTWISE_EMITTER_HPP
+
+#include <cstdint>
+
+#include "common/c_types_map.hpp"
+#include "cpu/rv64/jit_generator.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace rv64 {
+
+// Registers an eltwise emitter may read or clobber. The owning kernel picks
+// the registers and is responsible for initializing the scalar inputs.
+struct eltwise_aux_regs_t {
+    // Vector scratch registers, clobbered freely by the emitters.
+    Xbyak_riscv::VReg tmp0;
+    Xbyak_riscv::VReg tmp1;
+    Xbyak_riscv::VReg tmp2;
+    Xbyak_riscv::VReg tmp3;
+    // Scalar inputs, read-only for the emitters: the algorithm parameters
+    // and the constants that the emitters use as 0.0f and 1.0f.
+    Xbyak_riscv::FReg alpha;
+    Xbyak_riscv::FReg beta;
+    Xbyak_riscv::FReg zero;
+    Xbyak_riscv::FReg one;
+    // Scalar FP scratch registers.
+    Xbyak_riscv::FReg ftmp0;
+    Xbyak_riscv::FReg ftmp1;
+    // Integer scratch registers.
+    Xbyak_riscv::Reg tmp_reg0;
+    Xbyak_riscv::Reg tmp_reg1;
+};
+
+// Emits f32 forward eltwise computations into an existing JIT generator.
+// Every emitter has the shape f(r, dst, src): it reads the vector register
+// src, writes the result to dst and may clobber the scratch registers in r.
+// Several emitters also use v0 as a mask register. Unless an emitter is
+// known to tolerate it, dst and src should not be any of r.tmp0..r.tmp3.
+struct jit_rvv_eltwise_fwd_emitter_t {
+    // `jit` receives all emitted instructions; it is not owned.
+    explicit jit_rvv_eltwise_fwd_emitter_t(jit_generator_t *jit);
+
+    // Scalar / vector constant helpers. The *_bits variants take the IEEE-754
+    // bit pattern of the constant instead of its value.
+    void load_f32_bits(const Xbyak_riscv::FReg &dst,
+            const Xbyak_riscv::Reg &tmp, uint32_t bits);
+    void load_f32_const(const Xbyak_riscv::FReg &dst,
+            const Xbyak_riscv::Reg &tmp, float value);
+    void broadcast_f32_bits(const eltwise_aux_regs_t &r,
+            const Xbyak_riscv::VReg &dst, uint32_t bits);
+    void broadcast_f32_const(const eltwise_aux_regs_t &r,
+            const Xbyak_riscv::VReg &dst, float value);
+    // In-place clamp of dst to [lower, upper].
+    void clamp_f32(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            float lower, float upper);
+    // dst = src * value
+    void mul_f32_const(const eltwise_aux_regs_t &r,
+            const Xbyak_riscv::VReg &dst, const Xbyak_riscv::VReg &src,
+            float value);
+    // Horner evaluation of a polynomial in x; coeffs[0] is the highest-order
+    // coefficient. dst is the accumulator and must differ from x.
+    void horner_f32(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &x, const float *coeffs, int coeff_count);
+    void horner_f32_bits(const eltwise_aux_regs_t &r,
+            const Xbyak_riscv::VReg &dst, const Xbyak_riscv::VReg &x,
+            const uint32_t *coeffs, int coeff_count);
+
+    // Algorithms built from a handful of vector instructions.
+    void abs(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void clip(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void hardsigmoid(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void hardswish(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void linear(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void relu(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void sqrt(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void square(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void leakyrelu(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    // Rounds to an integral value through a float -> int -> float conversion.
+    void round(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+
+    // Transcendental approximations and the algorithms built on top of them.
+    // exp() uses r.tmp0..r.tmp2 as scratch; dst may alias src.
+    void exp(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    // dst = one - exp(src); helper of erf().
+    void erf_exp_complement(const eltwise_aux_regs_t &r,
+            const Xbyak_riscv::VReg &dst, const Xbyak_riscv::VReg &src);
+    void tanh(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void erf(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    // dst = 1 / src; dst must differ from src.
+    void reciprocal(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void sigmoid(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void swish(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+
+    void elu(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void gelu_tanh(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+    void gelu_erf(const eltwise_aux_regs_t &r, const Xbyak_riscv::VReg &dst,
+            const Xbyak_riscv::VReg &src);
+
+private:
+    // Generator that receives the emitted instructions (not owned).
+    jit_generator_t *jit_;
+};
+
+} // namespace rv64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+ +#endif // CPU_RV64_JIT_RVV_ELTWISE_EMITTER_HPP diff --git a/src/cpu/rv64/jit_rvv_eltwise_kernel.cpp b/src/cpu/rv64/jit_rvv_eltwise_kernel.cpp index 2f594c0e515..475b5e51aea 100644 --- a/src/cpu/rv64/jit_rvv_eltwise_kernel.cpp +++ b/src/cpu/rv64/jit_rvv_eltwise_kernel.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2026 Institute of Software, Chinese Academy of Sciences * Copyright 2026 openKylin community +* Copyright 2026 SpacemiT Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +19,7 @@ #include #include "cpu/rv64/jit_rvv_eltwise_kernel.hpp" +#include "cpu/rv64/jit_rvv_eltwise_emitter.hpp" namespace dnnl { namespace impl { @@ -65,7 +67,15 @@ bool jit_rvv_eltwise_f32_supported(alg_kind_t alg) { case alg_kind::eltwise_linear: case alg_kind::eltwise_relu: case alg_kind::eltwise_sqrt: - case alg_kind::eltwise_square: return true; + case alg_kind::eltwise_square: + case alg_kind::eltwise_round: + case alg_kind::eltwise_tanh: + case alg_kind::eltwise_logistic: + case alg_kind::eltwise_swish: + case alg_kind::eltwise_elu: + case alg_kind::eltwise_gelu_tanh: + case alg_kind::eltwise_gelu_erf: + case alg_kind::eltwise_exp: return true; default: return false; } } @@ -99,49 +109,61 @@ void jit_rvv_eltwise_apply_f32(alg_kind_t alg, const float *src, float *dst, case alg_kind::eltwise_square: dispatch_jit_eltwise_f32(&p); break; + case alg_kind::eltwise_tanh: + dispatch_jit_eltwise_f32(&p); + break; + case alg_kind::eltwise_logistic: + dispatch_jit_eltwise_f32(&p); + break; + case alg_kind::eltwise_round: + dispatch_jit_eltwise_f32(&p); + break; + case alg_kind::eltwise_swish: + dispatch_jit_eltwise_f32(&p); + break; + case alg_kind::eltwise_elu: + dispatch_jit_eltwise_f32(&p); + break; + case alg_kind::eltwise_gelu_tanh: + dispatch_jit_eltwise_f32(&p); + break; + case alg_kind::eltwise_gelu_erf: + 
dispatch_jit_eltwise_f32(&p); + break; + case alg_kind::eltwise_exp: + dispatch_jit_eltwise_f32(&p); + break; default: assert(!"unsupported f32 eltwise JIT alg"); } } -void jit_rvv_eltwise_fwd_kernel_t::compute_vector(const VReg &v_dst, - const VReg &v_src, const VReg &v_tmp, const FReg &f_alpha, - const FReg &f_beta, const FReg &f_zero, const FReg &f_one) { +void jit_rvv_eltwise_fwd_kernel_t::compute_vector( + const eltwise_aux_regs_t &r, const VReg &v_dst, const VReg &v_src) { #if defined(XBYAK_RISCV_V) && XBYAK_RISCV_V == 1 - const VReg v_mask(0); + jit_rvv_eltwise_fwd_emitter_t elt(this); switch (alg_) { - case alg_kind::eltwise_abs: vfabs_v(v_dst, v_src); break; - case alg_kind::eltwise_clip: - vfmax_vf(v_dst, v_src, f_alpha); - vfmin_vf(v_dst, v_dst, f_beta); - break; + case alg_kind::eltwise_abs: elt.abs(r, v_dst, v_src); break; + case alg_kind::eltwise_clip: elt.clip(r, v_dst, v_src); break; case alg_kind::eltwise_hardsigmoid: - vfmul_vf(v_dst, v_src, f_alpha); - vfadd_vf(v_dst, v_dst, f_beta); - vfmax_vf(v_dst, v_dst, f_zero); - vfmin_vf(v_dst, v_dst, f_one); - break; - case alg_kind::eltwise_hardswish: - vfmul_vf(v_tmp, v_src, f_alpha); - vfadd_vf(v_tmp, v_tmp, f_beta); - vfmax_vf(v_tmp, v_tmp, f_zero); - vfmin_vf(v_tmp, v_tmp, f_one); - vfmul_vv(v_dst, v_src, v_tmp); - break; - case alg_kind::eltwise_linear: - vfmul_vf(v_dst, v_src, f_alpha); - vfadd_vf(v_dst, v_dst, f_beta); - break; - case alg_kind::eltwise_relu: - vmfgt_vf(v_mask, v_src, f_zero); - vfmul_vf(v_tmp, v_src, f_alpha); - vmerge_vvm(v_dst, v_tmp, v_src); - break; - case alg_kind::eltwise_sqrt: vfsqrt_v(v_dst, v_src); break; - case alg_kind::eltwise_square: vfmul_vv(v_dst, v_src, v_src); break; - default: assert(!"unsupported f32 eltwise JIT alg"); + elt.hardsigmoid(r, v_dst, v_src); + break; + case alg_kind::eltwise_hardswish: elt.hardswish(r, v_dst, v_src); break; + case alg_kind::eltwise_linear: elt.linear(r, v_dst, v_src); break; + case alg_kind::eltwise_relu: elt.relu(r, v_dst, v_src); 
break; + case alg_kind::eltwise_sqrt: elt.sqrt(r, v_dst, v_src); break; + case alg_kind::eltwise_square: elt.square(r, v_dst, v_src); break; + case alg_kind::eltwise_round: elt.round(r, v_dst, v_src); break; + case alg_kind::eltwise_tanh: elt.tanh(r, v_dst, v_src); break; + case alg_kind::eltwise_logistic: elt.sigmoid(r, v_dst, v_src); break; + case alg_kind::eltwise_swish: elt.swish(r, v_dst, v_src); break; + case alg_kind::eltwise_elu: elt.elu(r, v_dst, v_src); break; + case alg_kind::eltwise_gelu_tanh: elt.gelu_tanh(r, v_dst, v_src); break; + case alg_kind::eltwise_gelu_erf: elt.gelu_erf(r, v_dst, v_src); break; + case alg_kind::eltwise_exp: elt.exp(r, v_dst, v_src); break; + default: assert(!"unsupported eltwise fwd JIT alg"); } #else - UNUSED(v_dst, v_src, v_tmp, f_alpha, f_beta, f_zero, f_one); + UNUSED(r, v_dst, v_src); #endif } @@ -153,16 +175,23 @@ void jit_rvv_eltwise_fwd_kernel_t::generate() { const Reg reg_len = a3; const Reg reg_vl = t0; const Reg reg_bytes = t1; - const Reg reg_tmp = t2; const FReg f_alpha = fa0; const FReg f_beta = fa1; const FReg f_zero = fa2; const FReg f_one = fa3; + const FReg f_tmp0 = fa4; + const FReg f_tmp1 = fa5; const VReg v_src(4); - const VReg v_tmp(8); + const VReg v_tmp0(8); const VReg v_dst(12); + const VReg v_tmp1(16); + const VReg v_tmp2(20); + const VReg v_tmp3(24); + + const eltwise_aux_regs_t regs {v_tmp0, v_tmp1, v_tmp2, v_tmp3, f_alpha, + f_beta, f_zero, f_one, f_tmp0, f_tmp1, t2, t3}; ld(reg_src, reg_param, 0); ld(reg_dst, reg_param, 8); @@ -170,15 +199,15 @@ void jit_rvv_eltwise_fwd_kernel_t::generate() { flw(f_alpha, reg_param, 24); flw(f_beta, reg_param, 28); fmv_w_x(f_zero, x0); - li(reg_tmp, 0x3f800000); - fmv_w_x(f_one, reg_tmp); + li(regs.tmp_reg0, 0x3f800000); + fmv_w_x(f_one, regs.tmp_reg0); Label loop, done; L(loop); beqz(reg_len, done); vsetvli(reg_vl, reg_len, SEW::e32, LMUL::m4); vle32_v(v_src, reg_src); - compute_vector(v_dst, v_src, v_tmp, f_alpha, f_beta, f_zero, f_one); + compute_vector(regs, 
v_dst, v_src); vse32_v(v_dst, reg_dst); slli(reg_bytes, reg_vl, 2); add(reg_src, reg_src, reg_bytes); @@ -212,7 +241,15 @@ bool jit_rvv_eltwise_fwd_f16_supported(alg_kind_t alg) { case alg_kind::eltwise_linear: case alg_kind::eltwise_relu: case alg_kind::eltwise_sqrt: - case alg_kind::eltwise_square: return true; + case alg_kind::eltwise_square: + case alg_kind::eltwise_round: + case alg_kind::eltwise_tanh: + case alg_kind::eltwise_logistic: + case alg_kind::eltwise_swish: + case alg_kind::eltwise_elu: + case alg_kind::eltwise_gelu_tanh: + case alg_kind::eltwise_gelu_erf: + case alg_kind::eltwise_exp: return true; default: return false; } } @@ -246,51 +283,61 @@ void jit_rvv_eltwise_apply_fwd_f16(alg_kind_t alg, const void *src, void *dst, case alg_kind::eltwise_square: dispatch_jit_eltwise_fwd_f16(&p); break; + case alg_kind::eltwise_tanh: + dispatch_jit_eltwise_fwd_f16(&p); + break; + case alg_kind::eltwise_logistic: + dispatch_jit_eltwise_fwd_f16(&p); + break; + case alg_kind::eltwise_round: + dispatch_jit_eltwise_fwd_f16(&p); + break; + case alg_kind::eltwise_swish: + dispatch_jit_eltwise_fwd_f16(&p); + break; + case alg_kind::eltwise_elu: + dispatch_jit_eltwise_fwd_f16(&p); + break; + case alg_kind::eltwise_gelu_tanh: + dispatch_jit_eltwise_fwd_f16(&p); + break; + case alg_kind::eltwise_gelu_erf: + dispatch_jit_eltwise_fwd_f16(&p); + break; + case alg_kind::eltwise_exp: + dispatch_jit_eltwise_fwd_f16(&p); + break; default: assert(!"unsupported f16 eltwise fwd JIT alg"); } } -void jit_rvv_eltwise_fwd_kernel_f16_t::compute_vector(const VReg &v_dst, - const VReg &v_src, const VReg &v_tmp, const FReg &f_alpha, - const FReg &f_beta, const FReg &f_zero, const FReg &f_one) { +void jit_rvv_eltwise_fwd_kernel_f16_t::compute_vector( + const eltwise_aux_regs_t &r, const VReg &v_dst, const VReg &v_src) { #if defined(XBYAK_RISCV_V) && XBYAK_RISCV_V == 1 - // Compute runs at SEW=e32, LMUL=m4 — same operands and shape as the - // upstream f32 eltwise kernel, since the 
widening makes v_src an f32 group. - const VReg v_mask(0); + jit_rvv_eltwise_fwd_emitter_t elt(this); switch (alg_) { - case alg_kind::eltwise_abs: vfabs_v(v_dst, v_src); break; - case alg_kind::eltwise_clip: - vfmax_vf(v_dst, v_src, f_alpha); - vfmin_vf(v_dst, v_dst, f_beta); - break; + case alg_kind::eltwise_abs: elt.abs(r, v_dst, v_src); break; + case alg_kind::eltwise_clip: elt.clip(r, v_dst, v_src); break; case alg_kind::eltwise_hardsigmoid: - vfmul_vf(v_dst, v_src, f_alpha); - vfadd_vf(v_dst, v_dst, f_beta); - vfmax_vf(v_dst, v_dst, f_zero); - vfmin_vf(v_dst, v_dst, f_one); - break; - case alg_kind::eltwise_hardswish: - vfmul_vf(v_tmp, v_src, f_alpha); - vfadd_vf(v_tmp, v_tmp, f_beta); - vfmax_vf(v_tmp, v_tmp, f_zero); - vfmin_vf(v_tmp, v_tmp, f_one); - vfmul_vv(v_dst, v_src, v_tmp); - break; - case alg_kind::eltwise_linear: - vfmul_vf(v_dst, v_src, f_alpha); - vfadd_vf(v_dst, v_dst, f_beta); - break; - case alg_kind::eltwise_relu: - vmfgt_vf(v_mask, v_src, f_zero); - vfmul_vf(v_tmp, v_src, f_alpha); - vmerge_vvm(v_dst, v_tmp, v_src); - break; - case alg_kind::eltwise_sqrt: vfsqrt_v(v_dst, v_src); break; - case alg_kind::eltwise_square: vfmul_vv(v_dst, v_src, v_src); break; - default: assert(!"unsupported f16 eltwise fwd JIT alg"); + elt.hardsigmoid(r, v_dst, v_src); + break; + case alg_kind::eltwise_hardswish: elt.hardswish(r, v_dst, v_src); break; + case alg_kind::eltwise_linear: elt.linear(r, v_dst, v_src); break; + case alg_kind::eltwise_relu: elt.relu(r, v_dst, v_src); break; + case alg_kind::eltwise_sqrt: elt.sqrt(r, v_dst, v_src); break; + case alg_kind::eltwise_square: elt.square(r, v_dst, v_src); break; + case alg_kind::eltwise_round: elt.round(r, v_dst, v_src); break; + case alg_kind::eltwise_tanh: elt.tanh(r, v_dst, v_src); break; + case alg_kind::eltwise_logistic: elt.sigmoid(r, v_dst, v_src); break; + case alg_kind::eltwise_swish: elt.swish(r, v_dst, v_src); break; + case alg_kind::eltwise_elu: elt.elu(r, v_dst, v_src); break; + case 
alg_kind::eltwise_gelu_tanh: elt.gelu_tanh(r, v_dst, v_src); break; + case alg_kind::eltwise_gelu_erf: elt.gelu_erf(r, v_dst, v_src); break; + case alg_kind::eltwise_exp: elt.exp(r, v_dst, v_src); break; + default: assert(!"unsupported eltwise fwd JIT alg"); } #else - UNUSED(v_dst, v_src, v_tmp, f_alpha, f_beta, f_zero, f_one); + UNUSED(r, v_dst, v_src); #endif } @@ -302,24 +349,34 @@ void jit_rvv_eltwise_fwd_kernel_f16_t::generate() { const Reg reg_len = a3; const Reg reg_vl = t0; const Reg reg_bytes = t1; - const Reg reg_tmp = t2; const FReg f_alpha = fa0; const FReg f_beta = fa1; const FReg f_zero = fa2; const FReg f_one = fa3; + const FReg f_tmp0 = fa4; + const FReg f_tmp1 = fa5; // Reg layout (widen-narrow at SEW=e32 LMUL=m4 compute): - // v_in_f16 (m2, regs 2-3) ← f16 load - // v_src (m4, regs 4-7) ← widened f32 input - // v_tmp (m4, regs 8-11) ← scratch - // v_dst (m4, regs 12-15) ← compute output - // v_out_f16 (m2, regs 16-17) ← narrow store buffer + // v_in_f16 (m2, regs 2-3) <- f16 load + // v_src (m4, regs 4-7) <- widened f32 input + // v_tmp0 (m4, regs 8-11) <- scratch + // v_dst (m4, regs 12-15) <- compute output + // v_tmp1 (m4, regs 16-19) <- scratch + // v_tmp2 (m4, regs 20-23) <- scratch + // v_tmp3 (m4, regs 24-27) <- scratch + // v_out_f16 (m2, regs 28-29) <- narrow store buffer const VReg v_in_f16(2); const VReg v_src(4); - const VReg v_tmp(8); + const VReg v_tmp0(8); const VReg v_dst(12); - const VReg v_out_f16(16); + const VReg v_tmp1(16); + const VReg v_tmp2(20); + const VReg v_tmp3(24); + const VReg v_out_f16(28); + + const eltwise_aux_regs_t regs {v_tmp0, v_tmp1, v_tmp2, v_tmp3, f_alpha, + f_beta, f_zero, f_one, f_tmp0, f_tmp1, t2, t3}; ld(reg_src, reg_param, 0); ld(reg_dst, reg_param, 8); @@ -327,26 +384,24 @@ void jit_rvv_eltwise_fwd_kernel_f16_t::generate() { flw(f_alpha, reg_param, 24); flw(f_beta, reg_param, 28); fmv_w_x(f_zero, x0); - li(reg_tmp, 0x3f800000); - fmv_w_x(f_one, reg_tmp); + li(regs.tmp_reg0, 0x3f800000); + 
fmv_w_x(f_one, regs.tmp_reg0); Label loop, done; L(loop); beqz(reg_len, done); // Three vsetvli phases. vfwcvt reads SEW as the source narrow width; - // vfncvt reads SEW as the *destination* narrow width — opposite - // convention. The compute runs at the wide f32 SEW. VLMAX matches at - // e16/m2 and e32/m4 for VLEN >= 64, so reg_vl is preserved. + // vfncvt reads SEW as the destination narrow width. The compute runs at + // e32/m4, and VLMAX matches e16/m2 for VLEN >= 64. vsetvli(reg_vl, reg_len, SEW::e16, LMUL::m2); vle16_v(v_in_f16, reg_src); vfwcvt_f_f_v(v_src, v_in_f16); vsetvli(reg_vl, reg_vl, SEW::e32, LMUL::m4); - compute_vector(v_dst, v_src, v_tmp, f_alpha, f_beta, f_zero, f_one); + compute_vector(regs, v_dst, v_src); vsetvli(reg_vl, reg_vl, SEW::e16, LMUL::m2); vfncvt_f_f_w(v_out_f16, v_dst); vse16_v(v_out_f16, reg_dst); - // 2 bytes per f16 element. slli(reg_bytes, reg_vl, 1); add(reg_src, reg_src, reg_bytes); add(reg_dst, reg_dst, reg_bytes); diff --git a/src/cpu/rv64/jit_rvv_eltwise_kernel.hpp b/src/cpu/rv64/jit_rvv_eltwise_kernel.hpp index 7da5867df0a..ca125cdf712 100644 --- a/src/cpu/rv64/jit_rvv_eltwise_kernel.hpp +++ b/src/cpu/rv64/jit_rvv_eltwise_kernel.hpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2026 Institute of Software, Chinese Academy of Sciences * Copyright 2026 openKylin community +* Copyright 2026 SpacemiT Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -26,6 +27,8 @@ namespace impl { namespace cpu { namespace rv64 { +struct eltwise_aux_regs_t; + struct jit_rvv_eltwise_fwd_kernel_t : public jit_generator_t { struct call_params_t { const float *src; @@ -47,10 +50,8 @@ struct jit_rvv_eltwise_fwd_kernel_t : public jit_generator_t { void generate() override; private: - void compute_vector(const Xbyak_riscv::VReg &v_dst, - const Xbyak_riscv::VReg &v_src, const Xbyak_riscv::VReg &v_tmp, - const Xbyak_riscv::FReg &f_alpha, const Xbyak_riscv::FReg &f_beta, - const Xbyak_riscv::FReg &f_zero, const Xbyak_riscv::FReg &f_one); + void compute_vector(const eltwise_aux_regs_t &r, + const Xbyak_riscv::VReg &v_dst, const Xbyak_riscv::VReg &v_src); alg_kind_t alg_; }; @@ -84,10 +85,8 @@ struct jit_rvv_eltwise_fwd_kernel_f16_t : public jit_generator_t { void generate() override; private: - void compute_vector(const Xbyak_riscv::VReg &v_dst, - const Xbyak_riscv::VReg &v_src, const Xbyak_riscv::VReg &v_tmp, - const Xbyak_riscv::FReg &f_alpha, const Xbyak_riscv::FReg &f_beta, - const Xbyak_riscv::FReg &f_zero, const Xbyak_riscv::FReg &f_one); + void compute_vector(const eltwise_aux_regs_t &r, + const Xbyak_riscv::VReg &v_dst, const Xbyak_riscv::VReg &v_src); alg_kind_t alg_; }; diff --git a/src/cpu/rv64/jit_rvv_softmax_kernel.cpp b/src/cpu/rv64/jit_rvv_softmax_kernel.cpp index bab8fae2571..b62161c7746 100644 --- a/src/cpu/rv64/jit_rvv_softmax_kernel.cpp +++ b/src/cpu/rv64/jit_rvv_softmax_kernel.cpp @@ -17,6 +17,7 @@ #include +#include "cpu/rv64/jit_rvv_eltwise_emitter.hpp" #include "cpu/rv64/jit_rvv_softmax_kernel.hpp" namespace dnnl { @@ -284,20 +285,11 @@ void jit_rvv_softmax_f16_exp_sub_sum_kernel_t::generate() { const Reg reg_sub_tmp = t2; const Reg reg_vl = t0; const Reg reg_bytes = t1; - const Reg reg_maxexp = t4; - const Reg reg_minexp = t5; - const Reg reg_imm = t6; const FReg f_zero = ft0; - const FReg f_lower = ft1; - const FReg f_upper = ft2; - const FReg f_round = ft3; - const FReg f_log2_recip = ft4; - const FReg 
f_log2_high = ft5; - const FReg f_log2_low = ft6; + const FReg f_tmp0 = ft1; + const FReg f_tmp1 = ft2; const FReg f_sub = fa0; - const FReg f_poly = ft7; - const FReg f_poly_coeff = ft8; const FReg f_sum = ft10; const VReg v_in16(0); @@ -308,13 +300,9 @@ void jit_rvv_softmax_f16_exp_sub_sum_kernel_t::generate() { const VReg v_acc(24); const VReg v_red(28); - auto load_f32_bits = [&](const FReg &freg, uint32_t bits) { - li(reg_imm, static_cast(bits)); - fmv_w_x(freg, reg_imm); - }; - auto load_f32 = [&](const FReg &freg, float value) { - load_f32_bits(freg, utils::bit_cast(value)); - }; + jit_rvv_eltwise_fwd_emitter_t elt(this); + const eltwise_aux_regs_t regs {v_bias, v_tmpv, v_poly, v_red, f_sub, f_zero, + f_zero, f_zero, f_tmp0, f_tmp1, t4, t5}; ld(reg_src, reg_param, F16_EXP_SUB_SUM_OFF(src)); ld(reg_tmp, reg_param, F16_EXP_SUB_SUM_OFF(tmp)); @@ -322,17 +310,7 @@ void jit_rvv_softmax_f16_exp_sub_sum_kernel_t::generate() { ld(reg_sum, reg_param, F16_EXP_SUB_SUM_OFF(sum)); lw(reg_sub_tmp, reg_param, F16_EXP_SUB_SUM_OFF(sub)); fmv_w_x(f_sub, reg_sub_tmp); - fmv_w_x(f_zero, x0); - load_f32(f_lower, -103.9720840454f); - load_f32(f_upper, 88.7762626647950f); - load_f32_bits(f_round, 0x4b400000u); - load_f32_bits(f_log2_recip, 0x3fb8aa3bu); - load_f32_bits(f_log2_high, 0xbf317200u); - load_f32_bits(f_log2_low, 0xb5bfbe8eu); - load_f32_bits(f_poly, 0x3ab4a000u); - li(reg_minexp, static_cast(0xC1000000u)); - li(reg_maxexp, static_cast(0x3F800000u)); vsetvli(reg_vl, x0, SEW::e32, LMUL::m4); vfmv_v_f(v_acc, f_zero); @@ -349,40 +327,9 @@ void jit_rvv_softmax_f16_exp_sub_sum_kernel_t::generate() { vsetvli(reg_vl, reg_len, SEW::e32, LMUL::m4); vfsub_vf(v_x, v_x, f_sub); - vfmv_v_f(v_bias, f_round); - vfmax_vf(v_x, v_x, f_lower); - vfmin_vf(v_x, v_x, f_upper); - vfmacc_vf(v_bias, f_log2_recip, v_x); - vfsub_vf(v_tmpv, v_bias, f_round); - vfmacc_vf(v_x, f_log2_high, v_tmpv); - vfmacc_vf(v_x, f_log2_low, v_tmpv); - vfmv_v_f(v_poly, f_poly); - load_f32_bits(f_poly_coeff, 
0x3c092f6eu); - vfmul_vv(v_poly, v_poly, v_x); - vfadd_vf(v_poly, v_poly, f_poly_coeff); - load_f32_bits(f_poly_coeff, 0x3d2aadadu); - vfmul_vv(v_poly, v_poly, v_x); - vfadd_vf(v_poly, v_poly, f_poly_coeff); - load_f32_bits(f_poly_coeff, 0x3e2aaa28u); - vfmul_vv(v_poly, v_poly, v_x); - vfadd_vf(v_poly, v_poly, f_poly_coeff); - load_f32_bits(f_poly_coeff, 0x3efffffbu); - vfmul_vv(v_poly, v_poly, v_x); - vfadd_vf(v_poly, v_poly, f_poly_coeff); - load_f32_bits(f_poly_coeff, 0x3f800000u); - vfmul_vv(v_poly, v_poly, v_x); - vfadd_vf(v_poly, v_poly, f_poly_coeff); - vsll_vi(v_bias, v_bias, 23); - vmin_vx(v_tmpv, v_bias, reg_maxexp); - vmax_vx(v_tmpv, v_tmpv, reg_minexp); - vsub_vv(v_bias, v_bias, v_tmpv); - vadd_vx(v_bias, v_bias, reg_maxexp); - vadd_vx(v_tmpv, v_tmpv, reg_maxexp); - vfmul_vv(v_x, v_x, v_bias); - vfmadd_vv(v_poly, v_x, v_bias); - vfmul_vv(v_poly, v_poly, v_tmpv); - vfadd_vv(v_acc, v_acc, v_poly); - vse32_v(v_poly, reg_tmp); + elt.exp(regs, v_x, v_x); + vfadd_vv(v_acc, v_acc, v_x); + vse32_v(v_x, reg_tmp); slli(reg_bytes, reg_vl, 2); add(reg_tmp, reg_tmp, reg_bytes); sub(reg_len, reg_len, reg_vl); diff --git a/src/cpu/rv64/rvv_eltwise.hpp b/src/cpu/rv64/rvv_eltwise.hpp index 3fe288b087e..9163535cf42 100644 --- a/src/cpu/rv64/rvv_eltwise.hpp +++ b/src/cpu/rv64/rvv_eltwise.hpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright 2025 ZTE Corporation +* Copyright 2026 SpacemiT Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -73,7 +74,21 @@ struct rvv_eltwise_fwd_t : public primitive_t { bool use_dense_; bool check_alg_kind() const { - return utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu, + const auto alg = desc()->alg_kind; + using namespace dnnl::impl::data_type; + if (utils::one_of(dst_md()->data_type, f32, f16)) { + return utils::one_of(alg, alg_kind::eltwise_relu, + alg_kind::eltwise_square, alg_kind::eltwise_abs, + alg_kind::eltwise_sqrt, alg_kind::eltwise_linear, + alg_kind::eltwise_clip, alg_kind::eltwise_hardsigmoid, + alg_kind::eltwise_hardswish, alg_kind::eltwise_tanh, + alg_kind::eltwise_logistic, alg_kind::eltwise_round, + alg_kind::eltwise_swish, alg_kind::eltwise_elu, + alg_kind::eltwise_gelu_tanh, alg_kind::eltwise_gelu_erf, + alg_kind::eltwise_exp); + } + + return utils::one_of(alg, alg_kind::eltwise_relu, alg_kind::eltwise_square, alg_kind::eltwise_abs, alg_kind::eltwise_sqrt, alg_kind::eltwise_linear, alg_kind::eltwise_clip, alg_kind::eltwise_hardsigmoid,