diff --git a/src/nnfusion/core/kernels/cpu/reference/constant.cpp b/src/nnfusion/core/kernels/cpu/reference/constant.cpp
index 7917d10ad..b922d2cee 100644
--- a/src/nnfusion/core/kernels/cpu/reference/constant.cpp
+++ b/src/nnfusion/core/kernels/cpu/reference/constant.cpp
@@ -71,4 +71,4 @@ using namespace nnfusion;
 using namespace nnfusion::kernels;
 REGISTER_KERNEL_EMITTER("Constant", //op_name
                         Device(GENERIC_CPU).TypeConstraint(element::f32), //attrs
-                        cpu::Constant) // constructor
\ No newline at end of file
+                        cpu::Constant) // constructor
diff --git a/src/nnfusion/core/kernels/cpu/reference/variable.cpp b/src/nnfusion/core/kernels/cpu/reference/variable.cpp
index 5e16388f6..a4eeeea2b 100644
--- a/src/nnfusion/core/kernels/cpu/reference/variable.cpp
+++ b/src/nnfusion/core/kernels/cpu/reference/variable.cpp
@@ -69,4 +69,4 @@ using namespace nnfusion;
 using namespace nnfusion::kernels;
 REGISTER_KERNEL_EMITTER("Variable", //op_name
                         Device(GENERIC_CPU).TypeConstraint(element::f32), //attrs
-                        cpu::Variable) // constructor
\ No newline at end of file
+                        cpu::Variable) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_helper.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_helper.cpp
index c1809a5cb..54a73cfb1 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/cuda_helper.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_helper.cpp
@@ -33,6 +33,11 @@ LanguageUnit_p cuda::get_math_kernel(const std::string& name,
     writer << ")\n";
     writer << "{\n";
     writer.indent++;
+    if (name == "convert" && data_types[num_inputs] == "half" && data_types[0] == "int64_t")
+    {
+        writer << "return (long long)" + math_kernel << ";\n";
+    }
+    else
     {
         writer << "return " + math_kernel << ";\n";
     }
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/apply_adam.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/apply_adam.cpp
index e42e0eda9..44308b801 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/apply_adam.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/apply_adam.cpp
@@ -108,4 +108,4 @@ using namespace nnfusion::kernels;
 REGISTER_KERNEL_EMITTER(
     "ApplyAdam",
     Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_kernel").Priority(2),
-    cuda::ApplyAdam)
\ No newline at end of file
+    cuda::ApplyAdam)
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp
index 173e95e93..c42d7780c 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/batch_matmul.cpp
@@ -8,6 +8,7 @@
 // [a] ./new_kernel_0.cpp
 // [b] ../../../ops/op_define/new_op_0.cpp
 
+#include
 #include "../cuda_emitter.hpp"
 #include "../cuda_langunit.hpp"
 #include "nnfusion/core/operators/generic_op/generic_op.hpp"
@@ -52,6 +53,15 @@ namespace nnfusion
                 const nnfusion::Shape& input_shape_0 = m_context->inputs[0]->get_shape();
                 const nnfusion::Shape& input_shape_1 = m_context->inputs[1]->get_shape();
 
+                element::Type dtype0 = m_context->inputs[0]->get_element_type();
+                element::Type dtype1 = m_context->inputs[1]->get_element_type();
+                element::Type dtype2 = m_context->outputs[0]->get_element_type();
+                NNFUSION_CHECK(dtype0 == dtype1 && dtype1 == dtype2)
+                    << "Unsupported element type combination of (" << dtype0.c_type_string()
+                    << ", " << dtype1.c_type_string() << ") -> " << dtype2.c_type_string()
+                    << ".";
+                element::Type& dtype = dtype0;
+
                 bool transA = generic_op->localOpConfig.getRoot()["adj_x"]["b"];
                 bool transB = generic_op->localOpConfig.getRoot()["adj_y"]["b"];
                 size_t A1 = 1LU;
@@ -92,10 +102,11 @@ namespace nnfusion
                     stride_b = A2 * A3, ldc = A4, stride_c = A2 * A4;
                 }
 
+                std::string type = dtype.c_type_string();
                 float alpha = 1.0f, beta = 0.0f;
                 auto code = nnfusion::op::create_code_from_template(
                     R"(
-    static const float alpha = @alpha@F, beta = @beta@F;
+    static const @dtype@ alpha = @alpha@, beta = @beta@;
     // if (!@hCublas@)
     //     CUBLAS_SAFE_CALL(@api_create@(&@hCublas@));
     CUBLAS_SAFE_CALL(@api_exec@(
@@ -106,7 +117,9 @@ namespace nnfusion
                     {
                         {"hCublas", "cublas_handle"},
                         {"api_create", "cublasCreate"},
-                        {"api_exec", "cublasSgemmStridedBatched"},
+                        {"api_exec",
+                         dtype == element::f32 ? "cublasSgemmStridedBatched"
+                                               : "cublasHgemmStridedBatched"},
                         {"transA", transB ? "CUBLAS_OP_T" : "CUBLAS_OP_N"},
                         {"transB", transA ? "CUBLAS_OP_T" : "CUBLAS_OP_N"},
                         {"alpha", alpha},
@@ -121,6 +134,7 @@ namespace nnfusion
                         {"stride_b", stride_b},
                         {"stride_c", stride_c},
                         {"batch", A1},
+                        {"dtype", type},
                     });
 
                 LanguageUnit_p _lu(new LanguageUnit(get_function_name()));
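Note: for readers unfamiliar with the template plumbing above, this is a minimal standalone sketch (illustrative names, not part of the patch) of the call that @api_exec@/@dtype@ resolve to when dtype is element::f16. It assumes cuBLAS with row-major A[batch][m][k] * B[batch][k][n], computed as the column-major product B * A, which is why the operands and the m/n extents are swapped:

#include <cublas_v2.h>
#include <cuda_fp16.h>

// Sketch only: the fp16 path the emitted code takes for a non-transposed
// batched matmul. Error handling (CUBLAS_SAFE_CALL) is omitted for brevity.
void batched_hgemm(cublasHandle_t handle, const __half* A, const __half* B,
                   __half* C, int batch, int m, int n, int k)
{
    const __half alpha = __float2half(1.0f);
    const __half beta = __float2half(0.0f);
    cublasHgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                              n, m, k, &alpha,
                              B, n, (long long)k * n, // ldb, stride_b
                              A, k, (long long)m * k, // lda, stride_a
                              &beta,
                              C, n, (long long)m * n, // ldc, stride_c
                              batch);
}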
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/constant.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/constant.cpp
index 73f04f5fd..d43bcca07 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/constant.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/constant.cpp
@@ -121,4 +121,4 @@ using namespace nnfusion;
 using namespace nnfusion::kernels;
 REGISTER_KERNEL_EMITTER("Constant", //op_name
                         Device(CUDA_GPU).TypeConstraint(element::f32).Priority(2), //attrs
-                        cuda::Constant) // constructor
\ No newline at end of file
+                        cuda::Constant) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp
index acd1939b6..9d84ff321 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/dot.cpp
@@ -86,7 +86,7 @@ LanguageUnit_p cuda::Dot::emit_function_body()
     // matrix * vector
     else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1))
     {
-        lu << "const float alpha = 1.0;\n const float beta = 0;\n";
+        lu << "const float alpha = 1.0;\n const float beta = 0.;\n";
         lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, ";
         if (trans_A)
             lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", ";
@@ -107,7 +107,7 @@ LanguageUnit_p cuda::Dot::emit_function_body()
         int n = trans_A ? arg0_shape[1] : arg0_shape[0];
         int k = trans_A ? arg0_shape[0] : arg0_shape[1];
 
-        lu << "const float alpha = 1.0;\nconst float beta = 0;\n";
+        lu << "const float alpha = 1.0;\nconst float beta = 0.;\n";
 
         lu << "CUBLAS_SAFE_CALL(cublasSgemm(cublas_handle,"
            << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,")
@@ -186,7 +186,7 @@ LanguageUnit_p cuda::Dot::emit_function_body()
             }
         }
 
-        lu << "const float alpha = 1.0;\nconst float beta = 0;\n";
+        lu << "const float alpha = 1.0;\nconst float beta = 0.;\n";
 
         lu << "CUBLAS_SAFE_CALL(cublasSgemm(cublas_handle,"
            << " CUBLAS_OP_N,"
@@ -206,162 +206,109 @@ LanguageUnit_p cuda::Dot::emit_function_body()
     }
     else if (dtype == element::f16)
    {
-        // case 1: Scalar * Tensor
-        // if (arg0_shape.empty() || arg1_shape.empty())
-        // {
-        //     auto& second = (arg0_shape.empty() ? arg1_shape : arg0_shape);
-        //     size_t count = nnfusion::shape_size(second);
-
-        //     string firstarg = (arg0_shape.empty() ? "input1" : "input0");
-        //     string secondarg = (arg0_shape.empty() ? "input0" : "input1");
-
-        //     lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
-
-        //     lu << "CUDA_SAFE_CALL(cudaMemcpy(outupt0, " << firstarg << ", " << count << ", cudaMemcpyDeviceToDevice));\n"; // copy `firstarg` to `output0`
-        //     lu << "CUBLAS_SAFE_CALL(nnfusionHalfScale(" << secondarg << ", output0, " << count << "));\n";
-        // }
-        // // case 2: 1d Dot
-        // else if ((arg0_shape.size() == arg1_shape.size()) && (arg0_shape.size() == reduction_axes))
-        // {
-        //     for (int i = 0; i < arg0_shape.size(); i++)
-        //     {
-        //         if (arg0_shape[i] != arg1_shape[i])
-        //         {
-        //             std::vector<std::string> arg_vec{"arg0", "arg1"};
-        //             std::vector<nnfusion::Shape> shape_vec{arg0_shape, arg1_shape};
-
-        //             NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
-        //                                   << nnfusion::join(shape_vec) << " respectively, at Node "
-        //                                   << m_context->gnode->get_name()
-        //                                   << ", do not match for dot op";
-        //         }
-        //     }
-
-        //     size_t count = nnfusion::shape_size(arg0_shape);
-        //     lu << "cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_DEVICE);\n";
-
-        //     lu << "CUBLAS_SAFE_CALL(cublasSdot(cublas_handle, " << count
-        //        << ", static_cast<const float*>(input0), 1, static_cast<const float*>(input1), 1, "
-        //           "static_cast<float*>(output0)));\n";
-        // }
-        // // matrix * vector
-        // else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 1) && (reduction_axes == 1))
-        // {
-        //     lu << "const float alpha = 1.0;\n const float beta = 0;\n";
-        //     lu << "CUBLAS_SAFE_CALL(cublasSgemv(cublas_handle, ";
-        //     if (trans_A)
-        //         lu << "CUBLAS_OP_N, " << arg0_shape[0] << ", " << arg0_shape[1] << ", ";
-        //     else
-        //         lu << "CUBLAS_OP_T, " << arg0_shape[1] << ", " << arg0_shape[0] << ", ";
-        //     lu << " &alpha,"
-        //        << " static_cast<const float*>(input0)," << arg0_shape[1] << ", "
-        //        << " static_cast<const float*>(input1),"
-        //        << " 1,"
-        //        << " &beta,"
-        //        << " static_cast<float*>(output0),"
-        //        << " 1));\n";
-        // }
-        // else if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) &&
-        //          (trans_A || trans_B))
-        // {
-        //     int m = trans_B ? arg1_shape[0] : arg1_shape[1];
-        //     int n = trans_A ? arg0_shape[1] : arg0_shape[0];
-        //     int k = trans_A ? arg0_shape[0] : arg0_shape[1];
-
-        //     lu << "const half alpha = 1.0;\nconst half beta = 0;\n";
-
-        //     lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
-        //        << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,")
-        //        << (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << ","
-        //        << " " << n << ","
-        //        << " " << k << ","
-        //        << " &alpha,"
-        //        << " static_cast<const half*>(input1),"
-        //        << " " << arg1_shape[1] << ","
-        //        << " static_cast<const half*>(input0),"
-        //        << " " << arg0_shape[1] << ","
-        //        << " &beta,"
-        //        << " static_cast<half*>(output0),"
-        //        << " " << m << "));\n";
-        // } else {
-        size_t axes_for_m_count = arg0_shape.size() - reduction_axes;
-        size_t axes_for_n_count = arg1_shape.size() - reduction_axes;
-        size_t axes_for_k_count = reduction_axes;
-        size_t m = 1;
-        size_t n = 1;
-        size_t k = 1;
-
-        // check if input and output size correct
-        // check and calculate k for arg0 and arg1
-        size_t arg0_k_idx = axes_for_m_count; // first axe in arg0 for k
-        size_t arg1_k_idx = 0;                // first axe in arg1 for k
-
-        for (size_t i = 0; i < axes_for_k_count; i++)
-        {
-            k *= arg0_shape[arg0_k_idx];
-            if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++])
-            {
-                std::vector<std::string> arg_vec{"arg0", "arg1"};
-                std::vector<nnfusion::Shape> shape_vec{arg0_shape, arg1_shape};
-
-                NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
-                                      << nnfusion::join(shape_vec) << " respectively, at Node "
-                                      << m_context->gnode->get_name()
-                                      << ", do not match for dot op";
-            }
-        }
-        // check and calculate m for arg0 and out
-        size_t arg0_m_idx = 0; // first axe in arg0 for m
-        size_t out_m_idx = 0;  // first axe in out for m
-        for (size_t i = 0; i < axes_for_m_count; i++)
-        {
-            m *= arg0_shape[arg0_m_idx];
-            if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++])
-            {
-                std::vector<std::string> arg_vec{"arg0", "output"};
-                std::vector<nnfusion::Shape> shape_vec{arg0_shape, out_shape};
-
-                NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
-                                      << nnfusion::join(shape_vec) << " respectively, at Node "
-                                      << m_context->gnode->get_name()
-                                      << ", do not match for dot op";
-            }
-        }
-        // check and calculate n for arg1 and out
-        size_t arg1_n_idx = axes_for_k_count; // first axe in arg1 for n
-        size_t out_n_idx = axes_for_m_count;  // first axe in arg1 for n
-        for (size_t i = 0; i < axes_for_n_count; i++)
-        {
-            n *= arg1_shape[arg1_n_idx];
-            if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++])
-            {
-                std::vector<std::string> arg_vec{"arg1", "output"};
-                std::vector<nnfusion::Shape> shape_vec{arg1_shape, out_shape};
-
-                NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
-                                      << nnfusion::join(shape_vec) << " respectively, at Node "
-                                      << m_context->gnode->get_name()
-                                      << ", do not match for dot op";
-            }
-        }
-
-        lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n";
-
-        lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
-           << " CUBLAS_OP_N,"
-           << " CUBLAS_OP_N,"
-           << " " << n << ","
-           << " " << m << ","
-           << " " << k << ","
-           << " &alpha,"
-           << " static_cast<const half*>(input1),"
-           << " " << n << ","
-           << " static_cast<const half*>(input0),"
-           << " " << k << ","
-           << " &beta,"
-           << " static_cast<half*>(output0),"
-           << " " << n << "));\n";
-        // }
+        if ((arg0_shape.size() == 2) && (arg1_shape.size() == 2) && (reduction_axes == 1) &&
+            (trans_A || trans_B))
+        {
+            int m = trans_B ? arg1_shape[0] : arg1_shape[1];
+            int n = trans_A ? arg0_shape[1] : arg0_shape[0];
+            int k = trans_A ? arg0_shape[0] : arg0_shape[1];
+
+            lu << "const half alpha = 1.0;\nconst half beta = 0.;\n";
+
+            lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
+               << (trans_B ? " CUBLAS_OP_T," : " CUBLAS_OP_N,")
+               << (trans_A ? " CUBLAS_OP_T," : " CUBLAS_OP_N,") << " " << m << ","
+               << " " << n << ","
+               << " " << k << ","
+               << " &alpha,"
+               << " static_cast<const half*>(input1),"
+               << " " << arg1_shape[1] << ","
+               << " static_cast<const half*>(input0),"
+               << " " << arg0_shape[1] << ","
+               << " &beta,"
+               << " static_cast<half*>(output0),"
+               << " " << m << "));\n";
+        }
+        else
+        {
+            size_t axes_for_m_count = arg0_shape.size() - reduction_axes;
+            size_t axes_for_n_count = arg1_shape.size() - reduction_axes;
+            size_t axes_for_k_count = reduction_axes;
+            size_t m = 1;
+            size_t n = 1;
+            size_t k = 1;
+
+            // check if input and output size correct
+            // check and calculate k for arg0 and arg1
+            size_t arg0_k_idx = axes_for_m_count; // first axis in arg0 for k
+            size_t arg1_k_idx = 0;                // first axis in arg1 for k
+
+            for (size_t i = 0; i < axes_for_k_count; i++)
+            {
+                k *= arg0_shape[arg0_k_idx];
+                if (arg0_shape[arg0_k_idx++] != arg1_shape[arg1_k_idx++])
+                {
+                    std::vector<std::string> arg_vec{"arg0", "arg1"};
+                    std::vector<nnfusion::Shape> shape_vec{arg0_shape, arg1_shape};
+
+                    NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
+                                          << nnfusion::join(shape_vec) << " respectively, at Node "
+                                          << m_context->gnode->get_name()
+                                          << ", do not match for dot op";
+                }
+            }
+            // check and calculate m for arg0 and out
+            size_t arg0_m_idx = 0; // first axis in arg0 for m
+            size_t out_m_idx = 0;  // first axis in out for m
+            for (size_t i = 0; i < axes_for_m_count; i++)
+            {
+                m *= arg0_shape[arg0_m_idx];
+                if (arg0_shape[arg0_m_idx++] != out_shape[out_m_idx++])
+                {
+                    std::vector<std::string> arg_vec{"arg0", "output"};
+                    std::vector<nnfusion::Shape> shape_vec{arg0_shape, out_shape};
+
+                    NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
+                                          << nnfusion::join(shape_vec) << " respectively, at Node "
+                                          << m_context->gnode->get_name()
+                                          << ", do not match for dot op";
+                }
+            }
+            // check and calculate n for arg1 and out
+            size_t arg1_n_idx = axes_for_k_count; // first axis in arg1 for n
+            size_t out_n_idx = axes_for_m_count;  // first axis in out for n
+            for (size_t i = 0; i < axes_for_n_count; i++)
+            {
+                n *= arg1_shape[arg1_n_idx];
+                if (arg1_shape[arg1_n_idx++] != out_shape[out_n_idx++])
+                {
+                    std::vector<std::string> arg_vec{"arg1", "output"};
+                    std::vector<nnfusion::Shape> shape_vec{arg1_shape, out_shape};
+
+                    NNFUSION_CHECK_FAIL() << nnfusion::join(arg_vec) << " with "
+                                          << nnfusion::join(shape_vec) << " respectively, at Node "
+                                          << m_context->gnode->get_name()
+                                          << ", do not match for dot op";
+                }
+            }
+
+            lu << "const half alpha = 1.0f;\nconst half beta = 0.f;\n";
+
+            lu << "CUBLAS_SAFE_CALL(cublasHgemm(cublas_handle,"
+               << " CUBLAS_OP_N,"
+               << " CUBLAS_OP_N,"
+               << " " << n << ","
+               << " " << m << ","
+               << " " << k << ","
+               << " &alpha,"
+               << " static_cast<const half*>(input1),"
+               << " " << n << ","
+               << " static_cast<const half*>(input0),"
+               << " " << k << ","
+               << " &beta,"
+               << " static_cast<half*>(output0),"
+               << " " << n << "));\n";
+        }
     }
     else
     {
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/dynamic_stitch.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/dynamic_stitch.cpp
index 4bd847949..00cd81136 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/dynamic_stitch.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/dynamic_stitch.cpp
@@ -124,4 +124,4 @@ LanguageUnit_p cuda::DynamicStitch::emit_dependency()
 REGISTER_KERNEL_EMITTER(
     "DynamicStitch", // op_name
     Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_kernel").Priority(2), // attrs
-    cuda::DynamicStitch) // constructor
\ No newline at end of file
+    cuda::DynamicStitch) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/pad.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/pad.cpp
index faab94fe9..733c037e4 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/pad.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/pad.cpp
@@ -149,4 +149,4 @@ KernelRegistrar kernel_registrar0(
 REGISTER_KERNEL_EMITTER(
     "Pad", // op_name
     Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_kernel").Priority(2), // attrs
-    cuda::Pad) // constructor
\ No newline at end of file
+    cuda::Pad) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/range.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/range.cpp
index 1c5a30279..f7a06a159 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/range.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/range.cpp
@@ -65,4 +65,4 @@ LanguageUnit_p cuda::Range::emit_dependency()
 REGISTER_KERNEL_EMITTER(
     "Range", // op_name
     Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_kernel").Priority(2), // attrs
-    cuda::Range) // constructor
\ No newline at end of file
+    cuda::Range) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp
index 97353e5e8..c9bfb3c26 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/reduce.hpp
@@ -189,7 +189,7 @@ int data_idx_offset = block_idx * width;
 float val = 0.0;
 for (int tidx = thread_idx; tidx < width; tidx += block_size) {
     int data_idx = tidx + data_idx_offset;
-    val += input0[data_idx];
+    val += static_cast<float>(input0[data_idx]);
 }
 val = reduceSum(val, thread_idx, block_size, shm);
 if (thread_idx == 0) output0[block_idx] = val;
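Note: the static_cast in the reduce kernel matters once input0 can be half: accumulating directly in half rounds on every add and loses integer precision past 2048, so each element is widened and the partial sum is kept in fp32. A tiny standalone CUDA sketch of the same pattern (hypothetical kernel, not part of the patch):

#include <cuda_fp16.h>

// One block per row; each thread widens its __half elements to float before
// accumulating, then contributes its partial sum. `out` must be zeroed first.
__global__ void row_sum_fp32_acc(const __half* in, float* out, int width)
{
    const __half* row = in + blockIdx.x * width;
    float val = 0.0f;
    for (int i = threadIdx.x; i < width; i += blockDim.x)
        val += __half2float(row[i]); // widen each element, sum in fp32
    atomicAdd(&out[blockIdx.x], val);
}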
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/result.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/result.cpp
index 229580e6a..73b819fd7 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/result.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/result.cpp
@@ -90,4 +90,4 @@ LanguageUnit_p cuda::Result::emit_dependency()
 REGISTER_KERNEL_EMITTER(
     "Result", // op_name
     Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_lib").Priority(2), // attrs
-    cuda::Result) // constructor
\ No newline at end of file
+    cuda::Result) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/reverse.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/reverse.cpp
index 6d5fc374d..e3be51ffc 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/reverse.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/reverse.cpp
@@ -103,4 +103,4 @@ LanguageUnit_p cuda::Reverse::emit_dependency()
 REGISTER_KERNEL_EMITTER(
     "Reverse", // op_name
     Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_kernel").Priority(2), // attrs
-    cuda::Reverse) // constructor
\ No newline at end of file
+    cuda::Reverse) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/reverse_sequence.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/reverse_sequence.cpp
index 487951930..612c51730 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/reverse_sequence.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/reverse_sequence.cpp
@@ -130,4 +130,4 @@ REGISTER_KERNEL_EMITTER(
 
 REGISTER_KERNEL_EMITTER("ReverseSequence", // op_name
                         Device(ROCM_GPU).TypeConstraint(element::f32).Priority(2), // attrs
-                        cuda::RocmReverseSequence) // constructor
\ No newline at end of file
+                        cuda::RocmReverseSequence) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/softmax.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/softmax.cpp
index 4f5bfa067..b83e9a832 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/softmax.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/softmax.cpp
@@ -203,7 +203,8 @@ LanguageUnit_p
 {
     LanguageUnit_p _lu(new LanguageUnit);
     auto& lu = *_lu;
-    string data_type = "CUDNN_DATA_FLOAT"; //cuda::get_cudnn_datatype(type);
+    element::Type type = m_context->inputs[0]->get_element_type();
+    string data_type = cuda::get_cudnn_datatype(type);
     string tensor_format = "CUDNN_TENSOR_NCHW";
     lu << "cudnnTensorDescriptor_t " << desc << ";\n";
     lu << "CUDNN_SAFE_CALL(cudnnCreateTensorDescriptor(&" << desc << "));\n";
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/strided_slice_grad.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/strided_slice_grad.cpp
index 342edf949..ab27b0ec7 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/strided_slice_grad.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/strided_slice_grad.cpp
@@ -115,4 +115,4 @@ LanguageUnit_p cuda::StridedSliceGrad::emit_dependency()
 REGISTER_KERNEL_EMITTER(
     "StridedSliceGrad", // op_name
     Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_kernel").Priority(2), // attrs
-    cuda::StridedSliceGrad) // constructor
\ No newline at end of file
+    cuda::StridedSliceGrad) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/tile.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/tile.cpp
index 33a869e71..6dc5220d9 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/tile.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/tile.cpp
@@ -153,4 +153,4 @@ REGISTER_KERNEL_EMITTER(
 
 REGISTER_KERNEL_EMITTER("Tile", //op_name
                         Device(ROCM_GPU).TypeConstraint(element::f32).Priority(2), //attrs
-                        cuda::RocmTile) // constructor
\ No newline at end of file
+                        cuda::RocmTile) // constructor
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/variable.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/variable.cpp
index 419124649..80c5cc707 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/variable.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/variable.cpp
@@ -79,4 +79,4 @@ using namespace nnfusion;
 using namespace nnfusion::kernels;
 REGISTER_KERNEL_EMITTER("Variable", //op_name
                         Device(CUDA_GPU).TypeConstraint(element::f32).Priority(2), //attrs
-                        cuda::Variable) // constructor
\ No newline at end of file
+                        cuda::Variable) // constructor
diff --git a/src/nnfusion/engine/pass/graph/dot_transpose_pass.hpp b/src/nnfusion/engine/pass/graph/dot_transpose_pass.hpp
index 76d40424d..42c782ae4 100644
--- a/src/nnfusion/engine/pass/graph/dot_transpose_pass.hpp
+++ b/src/nnfusion/engine/pass/graph/dot_transpose_pass.hpp
@@ -25,6 +25,6 @@ namespace nnfusion
         public:
             bool run_on_graph(std::shared_ptr<Graph>& graph) override;
         };
-    } // namespace pass
-    } // namespace graph
+    } // namespace graph
+    } // namespace pass
 } // namespace nnfusion
diff --git a/src/nnfusion/engine/pass/graph/kernel_tuning.cpp b/src/nnfusion/engine/pass/graph/kernel_tuning.cpp
index 068f7c904..751c5dc40 100644
--- a/src/nnfusion/engine/pass/graph/kernel_tuning.cpp
+++ b/src/nnfusion/engine/pass/graph/kernel_tuning.cpp
@@ -91,6 +91,12 @@ void print_tuning_results(std::vector<std::shared_ptr<TuningStatus>> tuned_kerne
            << std::setw(10) << s->status << " | " << std::setw(6) << s->progress_step << "/"
            << FLAGS_fkernel_tuning_steps << " "
            << " | " << std::setw(12) << s->best_perf << " ms |\n";
+
+        if (fabs(s->best_perf + 1.0) < 1e-5)
+        {
+            NNFUSION_LOG(INFO) << "Kernel named \"" << s->op_name << "\" has not yet been tuned.\n"
+                               << s->ir;
+        }
     }
     NNFUSION_LOG(INFO) << ss.str();
 }
diff --git a/src/nnfusion/frontend/onnx_import/core/tensor.hpp b/src/nnfusion/frontend/onnx_import/core/tensor.hpp
index 415abc080..201aa580d 100644
--- a/src/nnfusion/frontend/onnx_import/core/tensor.hpp
+++ b/src/nnfusion/frontend/onnx_import/core/tensor.hpp
@@ -22,6 +22,7 @@
 #pragma once
 
 #include "../util/util.hpp"
+#include "nnfusion/common/type/data_buffer.hpp"
 
 namespace nnfusion
 {
@@ -55,50 +56,31 @@ namespace nnfusion
                 return detail::get_data<T>(*m_tensor_proto);
             }
 
+            DataBuffer buffer_get_data() const
+            {
+                return detail::buffer_get_data(*m_tensor_proto);
+            }
+
             const std::string& get_name() const
             {
                 NNFUSION_CHECK(m_tensor_proto->has_name()) << "tensor has no name specified.";
                 return m_tensor_proto->name();
             }
 
-            const element::Type& get_ng_type() const
+            element::Type get_ng_type() const
             {
                 NNFUSION_CHECK(m_tensor_proto->has_data_type())
                     << "tensor has no data type specified.";
-                switch (m_tensor_proto->data_type())
-                {
-                case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
-                    return element::boolean;
-                case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
-                case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
-                    return element::f32;
-                case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
-                    return element::f64;
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT8: return element::i8;
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
-                    return element::i16;
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
-                    return element::i32;
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
-                    return element::i64;
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: return element::u8;
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
-                    return element::u16;
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
-                    return element::u32;
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
-                    return element::u64;
-                case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
-                    NNFUSION_CHECK_FAIL() << "data type is not defined";
-                    break;
-                default:
-                    NNFUSION_CHECK_FAIL()
-                        << "unsupported data type: "
-                        << onnx::TensorProto_DataType_Name(
-                               onnx::TensorProto_DataType(m_tensor_proto->data_type()));
-                    break;
-                }
+                element::Type element_type;
+                bool status;
+                status = ONNXDataTypeToNNFusionElementType(
+                    static_cast<onnx::TensorProto_DataType>(m_tensor_proto->data_type()),
+                    &element_type);
+                NNFUSION_CHECK(status) << "Data type not supported: "
+                                       << m_tensor_proto->data_type();
+
+                return element_type;
             }
 
             operator onnx::TensorProto_DataType() const
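Note: get_ng_type() now delegates to the shared ONNXDataTypeToNNFusionElementType mapping instead of keeping its own switch, so FLOAT16 imports as element::f16 rather than being silently widened to f32. Expected behavior, sketched as a hypothetical caller (not part of the patch):

// FLOAT16 now survives import as f16 instead of decaying to f32.
onnx::TensorProto_DataType onnx_dt = onnx::TensorProto_DataType_FLOAT16;
nnfusion::element::Type et;
NNFUSION_CHECK(ONNXDataTypeToNNFusionElementType(onnx_dt, &et));
NNFUSION_CHECK(et == nnfusion::element::f16); // et.c_type_string() == "half"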
diff --git a/src/nnfusion/frontend/onnx_import/op/constant.hpp b/src/nnfusion/frontend/onnx_import/op/constant.hpp
index f3dd0bfc9..87163617d 100644
--- a/src/nnfusion/frontend/onnx_import/op/constant.hpp
+++ b/src/nnfusion/frontend/onnx_import/op/constant.hpp
@@ -65,8 +65,10 @@ namespace nnfusion
                     Node node(node_proto);
                     auto tensor = node.get_attribute_value<Tensor>("value");
 
-                    const auto& func_param = ONNX_CONST_MAP().at(tensor.get_ng_type());
-                    auto op = func_param(tensor.get_ng_type(), tensor);
+                    // const auto& func_param = ONNX_CONST_MAP().at(tensor.get_ng_type());
+                    // auto op = func_param(tensor.get_ng_type(), tensor);
+                    auto op = std::make_shared<op::Constant>(
+                        tensor.get_ng_type(), tensor.get_shape(), tensor.buffer_get_data());
                     op->set_name(node_proto.output(0));
 
                     auto gnode = m_graph->add_node_and_edge(op, graph::GNodeVector({}));
diff --git a/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp b/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp
index 2c29954c2..dfa79661f 100644
--- a/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp
+++ b/src/nnfusion/frontend/onnx_import/util/graph_convert.cpp
@@ -143,7 +143,7 @@ namespace nnfusion
             onnx::ModelProto proto_without_init;
             proto_without_init.CopyFrom(model_proto);
             proto_without_init.mutable_graph()->mutable_initializer()->Clear();
-            NNFUSION_LOG(INFO) << proto_without_init.DebugString();
+            // NNFUSION_LOG(INFO) << proto_without_init.DebugString();
         }
 
         std::string
diff --git a/src/nnfusion/frontend/onnx_import/util/util.cpp b/src/nnfusion/frontend/onnx_import/util/util.cpp
index d6f52653d..6cd9f4316 100644
--- a/src/nnfusion/frontend/onnx_import/util/util.cpp
+++ b/src/nnfusion/frontend/onnx_import/util/util.cpp
@@ -28,7 +28,7 @@ namespace nnfusion
     {
         namespace onnx_import
         {
-            bool ONNXDataTypeToNNFusionElementType(const onnx::TensorProto_DataType onnx_dt,
+            bool ONNXDataTypeToNNFusionElementType(onnx::TensorProto_DataType onnx_dt,
                                                    nnfusion::element::Type* nnfusion_et)
             {
                 switch (onnx_dt)
@@ -36,8 +36,10 @@ namespace nnfusion
                 case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
                     *nnfusion_et = element::boolean;
                     break;
-                case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
                 case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
+                    *nnfusion_et = element::f16;
+                    break;
+                case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
                     *nnfusion_et = element::f32;
                     break;
                 case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
@@ -86,35 +88,9 @@ namespace nnfusion
                 const Shape shape,
                 const Tensor& tensor)
             {
-                switch (onnx_et)
-                {
-                case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
-                    return make_constant_op<bool>(element::boolean, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
-                case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
-                    return make_constant_op<float>(element::f32, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
-                    return make_constant_op<double>(element::f64, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
-                    return make_constant_op<int8_t>(element::i8, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
-                    return make_constant_op<int16_t>(element::i16, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
-                    return make_constant_op<int32_t>(element::i32, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
-                    return make_constant_op<int64_t>(element::i64, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
-                    return make_constant_op<uint8_t>(element::u8, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
-                    return make_constant_op<uint16_t>(element::u16, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
-                    return make_constant_op<uint32_t>(element::u32, shape, tensor);
-                case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
-                    return make_constant_op<uint64_t>(element::u64, shape, tensor);
-                default:
-                    NNFUSION_CHECK_FAIL() << "unsupported value info element type: "
-                                          << onnx::TensorProto_DataType_Name(onnx_et);
-                }
+                element::Type element_type = tensor.get_ng_type();
+                return std::make_shared<op::Constant>(
+                    element_type, shape, tensor.buffer_get_data());
             }
 
             std::shared_ptr<graph::GNode> GetInputNode(const NodeMap& all_ng_nodes,
@@ -280,6 +256,83 @@ namespace nnfusion
                     name, std::vector<size_t>(kernel_shape.size(), 1UL));
             }
 
+            DataBuffer detail::buffer_get_data(const onnx::TensorProto& tensor)
+            {
+                size_t n_element = 1;
+                element::Type type;
+                bool status;
+                auto onnx_dt = static_cast<onnx::TensorProto_DataType>(tensor.data_type());
+
+                status = ONNXDataTypeToNNFusionElementType(onnx_dt, &type);
+
+                NNFUSION_CHECK(status) << "Unsupported ONNX data_type " << tensor.data_type()
+                                       << " is found";
+
+                DataBuffer buf(type);
+
+                for (auto dim : tensor.dims())
+                {
+                    n_element *= dim;
+                }
+                buf.resize(n_element);
+
+                if (tensor.has_raw_data())
+                {
+                    buf.load(tensor.raw_data().data(), n_element);
+                }
+                else
+                {
+#define GET_VALUE(pb_type, mid_type)                                                               \
+    do                                                                                             \
+    {                                                                                              \
+        const void* dat;                                                                           \
+        mid_type m;                                                                                \
+        NNFUSION_CHECK(n_element == tensor.pb_type##_data_size())                                  \
+            << "Tensor shape is not the same as tensor data_size. (" << n_element                  \
+            << " != " << tensor.pb_type##_data_size() << ")";                                      \
+        for (size_t i = 0; i < n_element; ++i)                                                     \
+        {                                                                                          \
+            m = static_cast<mid_type>(tensor.pb_type##_data()[i]);                                 \
+            dat = reinterpret_cast<const void*>(&m);                                               \
+            buf.setElement(i, dat);                                                                \
+        }                                                                                          \
+    } while (0)
+
+                    switch (onnx_dt)
+                    {
+                    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
+                        GET_VALUE(int32, element::half);
+                        break;
+                    case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
+                        GET_VALUE(float, float);
+                        break;
+                    case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
+                        GET_VALUE(double, double);
+                        break;
+                    case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
+                        GET_VALUE(int32, int32_t);
+                        break;
+                    case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
+                        GET_VALUE(int64, int64_t);
+                        break;
+                    case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
+                        GET_VALUE(uint64, uint64_t);
+                        break;
+                    case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
+                    case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
+                    case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
+                    case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
+                    case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
+                    case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
+                    default:
+                        NNFUSION_CHECK_FAIL() << "unsupported onnx element type: "
+                                              << onnx::TensorProto_DataType_Name(onnx_dt);
+                    }
+#undef GET_VALUE
+                }
+                return buf;
+            }
+
         } // namespace onnx_import
     } // namespace frontend
 } // namespace nnfusion
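Note: since the GET_VALUE macro is dense, this is roughly what the FLOAT16 instantiation above expands to. ONNX serializes fp16 payloads inside int32_data, so each stored int is narrowed to element::half and copied into the type-erased buffer one element at a time:

// Approximate expansion of GET_VALUE(int32, element::half):
{
    const void* dat;
    element::half m;
    NNFUSION_CHECK(n_element == tensor.int32_data_size())
        << "Tensor shape is not the same as tensor data_size. (" << n_element
        << " != " << tensor.int32_data_size() << ")";
    for (size_t i = 0; i < n_element; ++i)
    {
        m = static_cast<element::half>(tensor.int32_data()[i]); // narrow to fp16
        dat = reinterpret_cast<const void*>(&m);
        buf.setElement(i, dat); // copies one element's bytes into the buffer
    }
}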
diff --git a/src/nnfusion/frontend/onnx_import/util/util.hpp b/src/nnfusion/frontend/onnx_import/util/util.hpp
index 871ab4801..2f7262842 100644
--- a/src/nnfusion/frontend/onnx_import/util/util.hpp
+++ b/src/nnfusion/frontend/onnx_import/util/util.hpp
@@ -29,6 +29,7 @@
 
 #include "../onnx_base.hpp"
 #include "nnfusion/common/common.hpp"
+#include "nnfusion/common/type/data_buffer.hpp"
 
 namespace nnfusion
 {
@@ -51,6 +52,8 @@ namespace nnfusion
                     return {it, it + (raw_data.size() / sizeof(T))};
                 }
 
+                DataBuffer buffer_get_data(const onnx::TensorProto& tensor);
+
                 template <typename T>
                 inline std::vector<T> get_data(const onnx::TensorProto& tensor)
                 {
@@ -186,7 +189,7 @@ namespace nnfusion
             class Tensor;
             class Node;
 
-            bool ONNXDataTypeToNNFusionElementType(const onnx::TensorProto_DataType onnx_dt,
+            bool ONNXDataTypeToNNFusionElementType(onnx::TensorProto_DataType onnx_dt,
                                                    nnfusion::element::Type* nnfusion_et);
 
             template <typename T>
diff --git a/src/nnfusion/frontend/util/evaluator.hpp b/src/nnfusion/frontend/util/evaluator.hpp
index 46b257922..23d9bc7b8 100644
--- a/src/nnfusion/frontend/util/evaluator.hpp
+++ b/src/nnfusion/frontend/util/evaluator.hpp
@@ -105,21 +105,21 @@ namespace nnfusion
             nnfusion::profiler::IProfilingRuntime::Pointer runtime = nullptr;
             std::vector<std::shared_ptr<const KernelRegistration>> kernel_regs;
 
-            runtime = nnfusion::profiler::RocmDefaultRuntime::Runtime();
+            runtime = nnfusion::profiler::CudaDefaultRuntime::Runtime();
             if (runtime->check_env())
             {
                 kernel_regs = KernelRegistry::Global()->FindKernelRegistrations(
-                    gnode->get_op_type(), ROCM_GPU, element::f32);
-                if (kernel_regs.size() == 0)
-                    kernel_regs = KernelRegistry::Global()->FindKernelRegistrations(
-                        gnode->get_op_type(), CUDA_GPU, element::f32);
+                    gnode->get_op_type(), CUDA_GPU, element::f32);
             }
             else
             {
-                runtime = nnfusion::profiler::CudaDefaultRuntime::Runtime();
+                runtime = nnfusion::profiler::RocmDefaultRuntime::Runtime();
                 NNFUSION_CHECK(runtime->check_env());
                 kernel_regs = KernelRegistry::Global()->FindKernelRegistrations(
-                    gnode->get_op_type(), CUDA_GPU, element::f32);
+                    gnode->get_op_type(), ROCM_GPU, element::f32);
+                if (kernel_regs.size() == 0)
+                    kernel_regs = KernelRegistry::Global()->FindKernelRegistrations(
+                        gnode->get_op_type(), CUDA_GPU, element::f32);
             }
 
             bool const_infer_success = false;
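Note: the net effect of the reordering above is that CUDA becomes the preferred profiling runtime, and the fall-back to CUDA kernel registrations (which a ROCm runtime can also consume) now lives in the ROCm branch instead of the CUDA one. Condensed into a hypothetical helper for clarity (not part of the patch):

// Try the preferred device first, then fall back to CUDA registrations.
std::vector<std::shared_ptr<const KernelRegistration>>
    find_kernels_with_fallback(const std::string& op_type, NNFusion_DeviceType preferred)
{
    auto regs = KernelRegistry::Global()->FindKernelRegistrations(
        op_type, preferred, element::f32);
    if (regs.empty() && preferred == ROCM_GPU) // ROCm can reuse CUDA kernels
        regs = KernelRegistry::Global()->FindKernelRegistrations(
            op_type, CUDA_GPU, element::f32);
    return regs;
}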