From 679ae9bc94cac28211c37c7f447d628d4fbead67 Mon Sep 17 00:00:00 2001 From: Wenxiang Hu Date: Wed, 21 Apr 2021 06:02:18 +0000 Subject: [PATCH 1/9] Add Fused Div Add Softmax OP; --- .../core/kernels/cuda_gpu/cuda_langunit.cpp | 183 ++++++++++++++++++ .../kernels/fused_div_add_softmax.cpp | 87 +++++++++ .../graph/bertfusion_optimizer/CMakeLists.txt | 1 + .../softmax_related_fusion_optimizer.cpp | 89 +++++++++ .../softmax_related_fusion_optimizer.hpp | 32 +++ .../engine/pass/graph/bertfusion_pass.cpp | 15 ++ 6 files changed, 407 insertions(+) create mode 100644 src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp create mode 100644 src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.cpp create mode 100644 src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.hpp diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp index 3cbedd785..1a35f9fd0 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp @@ -3159,6 +3159,189 @@ inline void DispatchSoftmax(cudaStream_t stream, const int64_t rows, const int64 DispatchSoftmaxBlockUncachedImpl(stream, rows, cols, x, y); } } + +template +__global__ void FusedSoftmaxWarpImpl(const int64_t rows, const int64_t cols, const T* x, const T* x1, const T* x2, T* y) { + static_assert(cols_per_thread % pack_size == 0, ""); + constexpr int num_packs = cols_per_thread / pack_size; + assert(cols <= cols_per_thread * kWarpSize); + using ComputeType = typename GetComputeType::type; + ComputeType buf[cols_per_thread]; + ComputeType buf_bias[cols_per_thread]; + const int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y; + const int num_global_warp = gridDim.x * blockDim.y; + const int lane_id = threadIdx.x; + for (int64_t row = global_warp_id; row < rows; row += num_global_warp) { + const int64_t row_offset = row * cols; + const T* row_x = x + row_offset; + const T scale = x1[0]; + const T* row_x_bias = x2 + row_offset; + + T* row_y = y + row_offset; + ComputeType thread_max = -Inf(); +#pragma unroll + for (int pack_id = 0; pack_id < num_packs; ++pack_id) { + const int col = (pack_id * kWarpSize + lane_id) * pack_size; + if (!padding || col < cols) { + MultiFetch()(buf + pack_id * pack_size, row_x + col); + MultiFetch()(buf_bias + pack_id * pack_size, row_x_bias + col); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + buf[pack_id * pack_size + i] = add(fdividef(buf[pack_id * pack_size + i], scale), buf_bias[pack_id * pack_size + i]); + thread_max = max(thread_max, buf[pack_id * pack_size + i]); + //thread_max = max(thread_max, buf[pack_id * pack_size + i]); + } + } else { +#pragma unroll + for (int i = 0; i < pack_size; ++i) { buf[pack_id * pack_size + i] = -Inf(); } +#pragma unroll + for (int i = 0; i < pack_size; ++i) { buf_bias[pack_id * pack_size + i] = -Inf(); } + } + } + const ComputeType warp_max = WarpAllReduce(thread_max); + ComputeType thread_sum = 0; +#pragma unroll + for (int i = 0; i < cols_per_thread; ++i) { + buf[i] = exp(buf[i] - warp_max); + thread_sum += buf[i]; + } + const ComputeType warp_sum = WarpAllReduce(thread_sum); +#pragma unroll + for (int i = 0; i < cols_per_thread; ++i) { buf[i] = buf[i] / warp_sum; } +#pragma unroll + for (int i = 0; i < num_packs; ++i) { + const int col = (i * kWarpSize + lane_id) * pack_size; + if (!padding || col < cols) { + MultiStore()(row_y + col, buf + i * pack_size); + } + } + } +} + 
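+// What the kernel above computes, per row r of the flattened (rows x cols) input:
+//   y[r, :] = softmax(x[r, :] / x1[0] + x2[r, :])
+// Each warp owns one or more rows, keeps its slice of the row in registers
+// (buf / buf_bias), and reduces the row max and row sum with WarpAllReduce, so it
+// only handles rows with at most cols_per_thread * kWarpSize elements.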
+template +inline void LaunchFusedSoftmaxWarpImpl(cudaStream_t stream, const int64_t rows, const int64_t cols, + const T* x, const T* x1, const T* x2, T* y) { + constexpr int block_size = 128; + constexpr int waves = 32; + static_assert(block_size % kWarpSize == 0, ""); + constexpr int rows_per_block = block_size / kWarpSize; + dim3 block_dim(kWarpSize, rows_per_block); + const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; + const int grid_dim_x = GetNumBlocks(block_size, num_blocks, waves); + FusedSoftmaxWarpImpl + <<>>(rows, cols, x, x1, x2, y); +} + +template +inline void DispatchFusedSoftmaxWarpImplPadding(cudaStream_t stream, const int64_t rows, + const int64_t cols, const T* x, const T* x1, const T* x2, T* y) { + if (cols == cols_per_thread * kWarpSize) { + LaunchFusedSoftmaxWarpImpl(stream, rows, cols, x, x1, x2, y); + } else { + LaunchFusedSoftmaxWarpImpl(stream, rows, cols, x, x1, x2, y); + } +} + +template +typename std::enable_if::type DispatchFusedSoftmaxWarpImplCols(cudaStream_t stream, + const int64_t rows, + const int64_t cols, + const T* x, const T* x1, const T* x2, T* y) { + if (cols <= 0) { return; } +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + DispatchFusedSoftmaxWarpImplPadding(stream, rows, cols, x, x1, x2, y); \ + } + DEFINE_ONE_ELIF(1) + DEFINE_ONE_ELIF(2) + DEFINE_ONE_ELIF(3) + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(5) + DEFINE_ONE_ELIF(6) + DEFINE_ONE_ELIF(7) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(9) + DEFINE_ONE_ELIF(10) + DEFINE_ONE_ELIF(11) + DEFINE_ONE_ELIF(12) + DEFINE_ONE_ELIF(13) + DEFINE_ONE_ELIF(14) + DEFINE_ONE_ELIF(15) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(17) + DEFINE_ONE_ELIF(18) + DEFINE_ONE_ELIF(19) + DEFINE_ONE_ELIF(20) + DEFINE_ONE_ELIF(21) + DEFINE_ONE_ELIF(22) + DEFINE_ONE_ELIF(23) + DEFINE_ONE_ELIF(24) + DEFINE_ONE_ELIF(25) + DEFINE_ONE_ELIF(26) + DEFINE_ONE_ELIF(27) + DEFINE_ONE_ELIF(28) + DEFINE_ONE_ELIF(29) + DEFINE_ONE_ELIF(30) + DEFINE_ONE_ELIF(31) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF + else { + return; + } +} + +template +typename std::enable_if::type DispatchFusedSoftmaxWarpImplCols(cudaStream_t stream, + const int64_t rows, + const int64_t cols, + const T* x, const T* x1, const T* x2, T* y) { + if (cols <= 0) { return; } +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + DispatchFusedSoftmaxWarpImplPadding(stream, rows, cols, x, x1, x2, y); \ + } + DEFINE_ONE_ELIF(2) + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(6) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(10) + DEFINE_ONE_ELIF(12) + DEFINE_ONE_ELIF(14) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(18) + DEFINE_ONE_ELIF(20) + DEFINE_ONE_ELIF(22) + DEFINE_ONE_ELIF(24) + DEFINE_ONE_ELIF(26) + DEFINE_ONE_ELIF(28) + DEFINE_ONE_ELIF(30) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF + else { + return; + } +} + +template +inline void DispatchFusedSoftmaxWarpImplPackSize(cudaStream_t stream, const int64_t rows, + const int64_t cols, const T* x, const T* x1, const T* x2, T* y) { + DispatchFusedSoftmaxWarpImplCols(stream, rows, cols, x, x1, x2, y); +} + +template +inline void DispatchFusedSoftmaxWarpImpl(cudaStream_t stream, const int64_t rows, const int64_t cols, + const T* x, const T* x1, const T* x2, T* y) { + DispatchFusedSoftmaxWarpImplPackSize(stream, rows, cols, x, x1, x2, y); +} + + +template +inline void DispatchFusedSoftmaxWarp(cudaStream_t stream, const int64_t rows, const int64_t cols, const T* x, const T* x1, const T* x2, T* y) { + if (cols <= 1024) { + DispatchFusedSoftmaxWarpImpl(stream, rows, cols, x, x1, x2, y); + } 
+}
+
 )"
 ,
diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp
new file mode 100644
index 000000000..64f31e3ef
--- /dev/null
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp
@@ -0,0 +1,87 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+// This is the 2nd-generation kernel definition; extending new ops in this style is recommended.
+// Changes needed for creating a new kernel with the 2nd-generation style.
+//
+
+#include "../cuda_emitter.hpp"
+#include "../cuda_langunit.hpp"
+#include "nnfusion/core/operators/generic_op/generic_op.hpp"
+
+namespace nnfusion
+{
+    namespace kernels
+    {
+        namespace cuda
+        {
+            class FusedDivAddSoftmax : public CudaLibEmitter
+            {
+                shared_ptr<nnfusion::op::GenericOp> generic_op;
+                nnfusion::Shape input_shape, output_shape;
+                int N, D;
+                element::Type dtype;
+
+            public:
+                FusedDivAddSoftmax(shared_ptr<KernelContext> ctx)
+                    : CudaLibEmitter(ctx)
+                    , generic_op(
+                          static_pointer_cast<nnfusion::op::GenericOp>(ctx->gnode->get_op_ptr()))
+                {
+                    input_shape = nnfusion::Shape(ctx->inputs[0]->get_shape());
+                    output_shape = nnfusion::Shape(ctx->outputs[0]->get_shape());
+                    dtype = m_context->inputs[0]->get_element_type();
+                    size_t axis = output_shape.size() - 1;
+
+                    // Flatten the input into N rows of D columns; softmax runs over the last axis.
+                    N = 1;
+                    D = 1;
+                    for (size_t i = 0; i < input_shape.size(); i++)
+                    {
+                        if (i < axis)
+                        {
+                            N *= input_shape[i];
+                        }
+                        else
+                        {
+                            D *= input_shape[i];
+                        }
+                    }
+                }
+
+                LanguageUnit_p emit_function_body() override
+                {
+                    LanguageUnit_p _lu(new LanguageUnit(get_function_name()));
+                    auto& lu = *_lu;
+                    auto code = nnfusion::op::create_code_from_template(
+                        R"(
+    DispatchSoftmax<@dtype@>(stream, @N@, @D@, input0, output0);
+    )",
+                        {{"dtype", (dtype == element::f16) ? "half" : "float"}, {"D", D}, {"N", N}});
+
+                    lu << code << "\n";
+                    return _lu;
+                }
+
+                LanguageUnit_p emit_dependency() override
+                {
+                    GENERIC_OP_LOGGING();
+
+                    LanguageUnit_p _lu(new LanguageUnit(get_function_name() + "_dep"));
+                    _lu->require(declaration::oneflow_softmax);
+                    declaration::oneflow_softmax->require(header::math_constants);
+                    declaration::oneflow_softmax->require(header::cub);
+
+                    return _lu;
+                }
+            };
+        } // namespace cuda
+    }     // namespace kernels
+} // namespace nnfusion
+
+using namespace nnfusion;
+using namespace nnfusion::kernels;
+
+REGISTER_KERNEL_EMITTER(
+    "FusedDivAddSoftmax",                                                         // op_name
+    Device(CUDA_GPU).TypeConstraint(element::f32).Tag("cuda_kernel").Priority(2), // attrs
+    cuda::FusedDivAddSoftmax)                                                     // constructor
diff --git a/src/nnfusion/engine/pass/graph/bertfusion_optimizer/CMakeLists.txt b/src/nnfusion/engine/pass/graph/bertfusion_optimizer/CMakeLists.txt
index a71dbf241..80be986ae 100644
--- a/src/nnfusion/engine/pass/graph/bertfusion_optimizer/CMakeLists.txt
+++ b/src/nnfusion/engine/pass/graph/bertfusion_optimizer/CMakeLists.txt
@@ -10,6 +10,7 @@ set(SRC
     embedlayernorm_fusion_optimizer.cpp
     skiplayernorm_fusion_optimizer.cpp
     matmuladd_fusion_optimizer.cpp
+    softmax_related_fusion_optimizer.cpp
 )
 
 add_library(nnfusion_engine_pass_graph_bertfusion STATIC ${SRC})
diff --git a/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.cpp b/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.cpp
new file mode 100644
index 000000000..4a295cc76
--- /dev/null
+++ b/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.cpp
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+ +#include "softmax_related_fusion_optimizer.hpp" +#include "nnfusion/frontend/util/evaluator.hpp" + +using namespace nnfusion; +using namespace nnfusion::graph; +using namespace nnfusion::pass::graph; + +bool SoftmaxRelatedFusionOptimizer::CheckStartingNode(std::shared_ptr node) +{ + if (node->get_op_type() == "Softmax") + return true; + return false; +} + + +bool SoftmaxRelatedFusionOptimizer::FindSubGraph(std::shared_ptr starting_node, + std::shared_ptr bertfusion_group) +{ + NNFUSION_CHECK_NOT_NULLPTR(bertfusion_group); + //reshape, broadcast, divide, add, softmax + auto softmax = starting_node; + auto add = softmax->get_in_edge(0)->get_src(); + auto div = add->get_in_edge(0)->get_src(); + auto reshape = div->get_in_edge(0)->get_src(); + auto broadcast = div->get_in_edge(1)->get_src(); + + bertfusion_group->fuse_group["inputs"] = { + broadcast, reshape, div, add, softmax + }; + bertfusion_group->nodes_to_remove.insert({ + broadcast, reshape, div, add, softmax + }); + + bertfusion_group->helper_nodes.push_back(starting_node); + bertfusion_group->edge_nodes.push_back(softmax); + return true; +} + +bool SoftmaxRelatedFusionOptimizer::FuseSubGraph(std::shared_ptr bertfusion_group) +{ + NNFUSION_CHECK_NOT_NULLPTR(bertfusion_group); + auto& fuse_group = bertfusion_group->fuse_group; + NNFUSION_CHECK(fuse_group.find("inputs") != fuse_group.end()); + auto& inputs = fuse_group["inputs"]; + NNFUSION_CHECK(inputs.size() == 5); + + // fused = softmax(reshape(input0)/input1 + broadcast(input2)); + auto broadcast = inputs[0]; + auto reshape = inputs[1]; + auto div = inputs[2]; + auto add = inputs[3]; + auto softmax = inputs[4]; + + auto input = reshape->get_in_edge(0)->get_src(); + auto input_scale = broadcast ->get_in_edge(0)->get_src(); + auto bias = add->get_in_edge(1)->get_src(); + + nnfusion::op::OpConfig::any myConfig; + auto fused_softmax = std::make_shared( + "Fused_Div_Add_Softmax_" + input->get_name() + "_ " + input_scale->get_name(), + "FusedDivAddSoftmax", myConfig); + + auto fused_softmax_node = m_graph->add_node_and_edge(fused_softmax, + { + GNodeIndex{input, 0}, + GNodeIndex{input_scale, 0}, + GNodeIndex{bias, 0}, + } + ); + fused_softmax_node->set_output_type_and_shape( + 0, fused_softmax_node->get_input_element_type(0), softmax->get_input_shape(0)); + + // replace edge + for (auto edge_node : bertfusion_group->edge_nodes) + { + auto out_edges = edge_node->get_out_edges(); + for (auto out_edge : out_edges) + { + auto dst = out_edge->get_dst(); + int y = out_edge->get_dst_input(); + m_graph->remove_edge(out_edge); + m_graph->add_edge(fused_softmax_node, 0, dst, y); + } + } + return RemoveNodes(bertfusion_group->nodes_to_remove, fused_softmax_node); +} diff --git a/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.hpp b/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.hpp new file mode 100644 index 000000000..26c511b17 --- /dev/null +++ b/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+
+#pragma once
+
+#include "bertfusion_optimizer.hpp"
+
+namespace nnfusion
+{
+    namespace pass
+    {
+        namespace graph
+        {
+            class SoftmaxRelatedFusionOptimizer : public BertFusionOptimizer
+            {
+            public:
+                SoftmaxRelatedFusionOptimizer(std::shared_ptr<nnfusion::graph::Graph> graph)
+                    : BertFusionOptimizer(graph)
+                {
+                }
+
+            private:
+                bool CheckStartingNode(std::shared_ptr<nnfusion::graph::GNode> node) override;
+                bool FindSubGraph(std::shared_ptr<nnfusion::graph::GNode> starting_node,
+                                  std::shared_ptr<BertFusionGroup> bertfusion_group) override;
+                bool FuseSubGraph(std::shared_ptr<BertFusionGroup> bertfusion_group) override;
+            };
+
+        } // namespace graph
+    }     // namespace pass
+} // namespace nnfusion
diff --git a/src/nnfusion/engine/pass/graph/bertfusion_pass.cpp b/src/nnfusion/engine/pass/graph/bertfusion_pass.cpp
index 2cd83bc42..cb1f78951 100644
--- a/src/nnfusion/engine/pass/graph/bertfusion_pass.cpp
+++ b/src/nnfusion/engine/pass/graph/bertfusion_pass.cpp
@@ -9,10 +9,12 @@
 #include "bertfusion_optimizer/layernorm_fusion_optimizer.hpp"
 #include "bertfusion_optimizer/matmuladd_fusion_optimizer.hpp"
 #include "bertfusion_optimizer/skiplayernorm_fusion_optimizer.hpp"
+#include "bertfusion_optimizer/softmax_related_fusion_optimizer.hpp"
 
 using namespace nnfusion;
 using namespace nnfusion::pass::graph;
 
+DEFINE_bool(fsoftmax_related_fusion, false, "");
 DEFINE_bool(fattention_fusion, false, "");
 DEFINE_bool(flayernorm_fusion, false, "");
 DEFINE_bool(fembedlayernorm_fusion, false, "");
@@ -100,5 +102,18 @@ bool BertFusionPass::run_on_graph(std::shared_ptr<Graph>& graph
             NNFUSION_LOG(INFO) << "MatMulAddFusion Optimization Done.";
         }
     }
+
+    if (FLAGS_fsoftmax_related_fusion || FLAGS_fenable_all_bert_fusion)
+    {
+        auto optimizer = std::make_shared<SoftmaxRelatedFusionOptimizer>(graph);
+        if (!optimizer->Optimize())
+        {
+            NNFUSION_LOG(NNFUSION_WARNING) << "Softmax Related Optimization failed.";
+        }
+        else
+        {
+            NNFUSION_LOG(INFO) << "Softmax Related Optimization Done.";
+        }
+    }
     return true;
 }

From 54246396d2ccc3584cfaec501e896032af88ddf6 Mon Sep 17 00:00:00 2001
From: Wenxiang Hu
Date: Wed, 21 Apr 2021 07:37:51 +0000
Subject: [PATCH 2/9] Add op define;

---
 .../generic_op/generic_op_define/FusedDivAddSoftmax.cpp | 9 +++++++++
 1 file changed, 9 insertions(+)
 create mode 100644 src/nnfusion/core/operators/generic_op/generic_op_define/FusedDivAddSoftmax.cpp

diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/FusedDivAddSoftmax.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/FusedDivAddSoftmax.cpp
new file mode 100644
index 000000000..c41543934
--- /dev/null
+++ b/src/nnfusion/core/operators/generic_op/generic_op_define/FusedDivAddSoftmax.cpp
@@ -0,0 +1,9 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+ +#include "nnfusion/core/operators/generic_op/generic_op.hpp" + +REGISTER_OP(FusedDivAddSoftmax).infershape([](std::shared_ptr gnode) -> void { + gnode->set_output_type_and_shape( + 0, gnode->get_input_element_type(0), gnode->get_input_shape(0)); +}); From 4fe83f71031bbdc0a3a5596ce7cdb3eab5aed13b Mon Sep 17 00:00:00 2001 From: Wenxiang Hu Date: Wed, 21 Apr 2021 08:32:15 +0000 Subject: [PATCH 3/9] Add cu; --- .../core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp index 64f31e3ef..d31f817e6 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp @@ -64,10 +64,9 @@ namespace nnfusion LanguageUnit_p emit_dependency() override { - GENERIC_OP_LOGGING(); - LanguageUnit_p _lu(new LanguageUnit(get_function_name() + "_dep")); _lu->require(declaration::oneflow_softmax); + _lu->require(header::cub); declaration::oneflow_softmax->require(header::math_constants); declaration::oneflow_softmax->require(header::cub); From 7b2d5711baaccf074409c16153e8fa594c9b7544 Mon Sep 17 00:00:00 2001 From: Wenxiang Hu Date: Wed, 21 Apr 2021 08:56:22 +0000 Subject: [PATCH 4/9] Add cu; --- .../core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp index d31f817e6..87eb90149 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp @@ -54,7 +54,7 @@ namespace nnfusion auto& lu = *_lu; auto code = nnfusion::op::create_code_from_template( R"( - DispatchSoftmax<@dtype@>(stream, @N@, @D@, input0, output0); + DispatchSoftmax<@dtype@>(0, @N@, @D@, input0, output0); )", {{"dtype", (dtype == element::f16) ? 
"half" : "float"}, {"D", D}, {"N", N}}); From e765aca038cd919c3f635082c72fcfab31757f84 Mon Sep 17 00:00:00 2001 From: Wenxiang Hu Date: Wed, 21 Apr 2021 09:14:40 +0000 Subject: [PATCH 5/9] Add cu; --- src/nnfusion/core/kernels/cuda_gpu/kernels/softmax.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/softmax.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/softmax.cpp index 562c95727..89cde1d1c 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/kernels/softmax.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/softmax.cpp @@ -156,6 +156,7 @@ LanguageUnit_p cuda::Softmax::emit_dependency() if (softmax_ori == 0) { _lu->require(declaration::oneflow_softmax); + _lu->require(header::cub); declaration::oneflow_softmax->require(header::math_constants); declaration::oneflow_softmax->require(header::cub); } From b97ee878d899074eba69de4a58690a2c8eb9924e Mon Sep 17 00:00:00 2001 From: Wenxiang Hu Date: Wed, 21 Apr 2021 10:37:14 +0000 Subject: [PATCH 6/9] Add cu; --- .../core/kernels/cuda_gpu/cuda_langunit.cpp | 182 ++++++++++++++++++ 1 file changed, 182 insertions(+) diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp index 1a35f9fd0..702b2324e 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp +++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_langunit.cpp @@ -2673,6 +2673,188 @@ inline void DispatchSoftmax(cudaStream_t stream, const int64_t rows, const int64 DispatchSoftmaxBlockUncachedImpl(stream, rows, cols, x, y); } } + +template +__global__ void FusedSoftmaxWarpImpl(const int64_t rows, const int64_t cols, const T* x, const T* x1, const T* x2, T* y) { + static_assert(cols_per_thread % pack_size == 0, ""); + constexpr int num_packs = cols_per_thread / pack_size; + assert(cols <= cols_per_thread * kWarpSize); + using ComputeType = typename GetComputeType::type; + ComputeType buf[cols_per_thread]; + ComputeType buf_bias[cols_per_thread]; + const int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y; + const int num_global_warp = gridDim.x * blockDim.y; + const int lane_id = threadIdx.x; + for (int64_t row = global_warp_id; row < rows; row += num_global_warp) { + const int64_t row_offset = row * cols; + const T* row_x = x + row_offset; + const T scale = x1[0]; + const T* row_x_bias = x2 + row_offset; + + T* row_y = y + row_offset; + ComputeType thread_max = -Inf(); +#pragma unroll + for (int pack_id = 0; pack_id < num_packs; ++pack_id) { + const int col = (pack_id * kWarpSize + lane_id) * pack_size; + if (!padding || col < cols) { + MultiFetch()(buf + pack_id * pack_size, row_x + col); + MultiFetch()(buf_bias + pack_id * pack_size, row_x_bias + col); +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + buf[pack_id * pack_size + i] = add(fdividef(buf[pack_id * pack_size + i], scale), buf_bias[pack_id * pack_size + i]); + thread_max = max(thread_max, buf[pack_id * pack_size + i]); + //thread_max = max(thread_max, buf[pack_id * pack_size + i]); + } + } else { +#pragma unroll + for (int i = 0; i < pack_size; ++i) { buf[pack_id * pack_size + i] = -Inf(); } +#pragma unroll + for (int i = 0; i < pack_size; ++i) { buf_bias[pack_id * pack_size + i] = -Inf(); } + } + } + const ComputeType warp_max = WarpAllReduce(thread_max); + ComputeType thread_sum = 0; +#pragma unroll + for (int i = 0; i < cols_per_thread; ++i) { + buf[i] = exp(buf[i] - warp_max); + thread_sum += buf[i]; + } + const ComputeType warp_sum = WarpAllReduce(thread_sum); +#pragma unroll + for (int i = 0; i < 
cols_per_thread; ++i) { buf[i] = buf[i] / warp_sum; } +#pragma unroll + for (int i = 0; i < num_packs; ++i) { + const int col = (i * kWarpSize + lane_id) * pack_size; + if (!padding || col < cols) { + MultiStore()(row_y + col, buf + i * pack_size); + } + } + } +} + +template +inline void LaunchFusedSoftmaxWarpImpl(cudaStream_t stream, const int64_t rows, const int64_t cols, + const T* x, const T* x1, const T* x2, T* y) { + constexpr int block_size = 128; + constexpr int waves = 32; + static_assert(block_size % kWarpSize == 0, ""); + constexpr int rows_per_block = block_size / kWarpSize; + dim3 block_dim(kWarpSize, rows_per_block); + const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; + const int grid_dim_x = GetNumBlocks(block_size, num_blocks, waves); + FusedSoftmaxWarpImpl + <<>>(rows, cols, x, x1, x2, y); +} + +template +inline void DispatchFusedSoftmaxWarpImplPadding(cudaStream_t stream, const int64_t rows, + const int64_t cols, const T* x, const T* x1, const T* x2, T* y) { + if (cols == cols_per_thread * kWarpSize) { + LaunchFusedSoftmaxWarpImpl(stream, rows, cols, x, x1, x2, y); + } else { + LaunchFusedSoftmaxWarpImpl(stream, rows, cols, x, x1, x2, y); + } +} + +template +typename std::enable_if::type DispatchFusedSoftmaxWarpImplCols(cudaStream_t stream, + const int64_t rows, + const int64_t cols, + const T* x, const T* x1, const T* x2, T* y) { + if (cols <= 0) { return; } +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + DispatchFusedSoftmaxWarpImplPadding(stream, rows, cols, x, x1, x2, y); \ + } + DEFINE_ONE_ELIF(1) + DEFINE_ONE_ELIF(2) + DEFINE_ONE_ELIF(3) + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(5) + DEFINE_ONE_ELIF(6) + DEFINE_ONE_ELIF(7) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(9) + DEFINE_ONE_ELIF(10) + DEFINE_ONE_ELIF(11) + DEFINE_ONE_ELIF(12) + DEFINE_ONE_ELIF(13) + DEFINE_ONE_ELIF(14) + DEFINE_ONE_ELIF(15) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(17) + DEFINE_ONE_ELIF(18) + DEFINE_ONE_ELIF(19) + DEFINE_ONE_ELIF(20) + DEFINE_ONE_ELIF(21) + DEFINE_ONE_ELIF(22) + DEFINE_ONE_ELIF(23) + DEFINE_ONE_ELIF(24) + DEFINE_ONE_ELIF(25) + DEFINE_ONE_ELIF(26) + DEFINE_ONE_ELIF(27) + DEFINE_ONE_ELIF(28) + DEFINE_ONE_ELIF(29) + DEFINE_ONE_ELIF(30) + DEFINE_ONE_ELIF(31) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF + else { + return; + } +} + +template +typename std::enable_if::type DispatchFusedSoftmaxWarpImplCols(cudaStream_t stream, + const int64_t rows, + const int64_t cols, + const T* x, const T* x1, const T* x2, T* y) { + if (cols <= 0) { return; } +#define DEFINE_ONE_ELIF(col) \ + else if (cols <= (col)*kWarpSize) { \ + DispatchFusedSoftmaxWarpImplPadding(stream, rows, cols, x, x1, x2, y); \ + } + DEFINE_ONE_ELIF(2) + DEFINE_ONE_ELIF(4) + DEFINE_ONE_ELIF(6) + DEFINE_ONE_ELIF(8) + DEFINE_ONE_ELIF(10) + DEFINE_ONE_ELIF(12) + DEFINE_ONE_ELIF(14) + DEFINE_ONE_ELIF(16) + DEFINE_ONE_ELIF(18) + DEFINE_ONE_ELIF(20) + DEFINE_ONE_ELIF(22) + DEFINE_ONE_ELIF(24) + DEFINE_ONE_ELIF(26) + DEFINE_ONE_ELIF(28) + DEFINE_ONE_ELIF(30) + DEFINE_ONE_ELIF(32) +#undef DEFINE_ONE_ELIF + else { + return; + } +} + +template +inline void DispatchFusedSoftmaxWarpImplPackSize(cudaStream_t stream, const int64_t rows, + const int64_t cols, const T* x, const T* x1, const T* x2, T* y) { + DispatchFusedSoftmaxWarpImplCols(stream, rows, cols, x, x1, x2, y); +} + +template +inline void DispatchFusedSoftmaxWarpImpl(cudaStream_t stream, const int64_t rows, const int64_t cols, + const T* x, const T* x1, const T* x2, T* y) { + DispatchFusedSoftmaxWarpImplPackSize(stream, rows, cols, x, 
x1, x2, y);
+}
+
+
+template<typename T>
+inline void DispatchFusedSoftmaxWarp(cudaStream_t stream, const int64_t rows, const int64_t cols, const T* x, const T* x1, const T* x2, T* y) {
+  if (cols <= 1024) {
+    DispatchFusedSoftmaxWarpImpl<T>(stream, rows, cols, x, x1, x2, y);
+  }
+}
 )",
 R"(
 /*

From c8c9c77dbff8d27913f72c2e6eba7d725e404530 Mon Sep 17 00:00:00 2001
From: Wenxiang Hu
Date: Wed, 21 Apr 2021 10:51:33 +0000
Subject: [PATCH 7/9] Add cu;

---
 .../core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp b/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp
index 87eb90149..7d4221b58 100644
--- a/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp
+++ b/src/nnfusion/core/kernels/cuda_gpu/kernels/fused_div_add_softmax.cpp
@@ -54,7 +54,7 @@ namespace nnfusion
                     auto& lu = *_lu;
                     auto code = nnfusion::op::create_code_from_template(
                         R"(
-    DispatchSoftmax<@dtype@>(0, @N@, @D@, input0, output0);
+    DispatchFusedSoftmaxWarp<@dtype@>(0, @N@, @D@, input0, input1, input2, output0);
     )",
                         {{"dtype", (dtype == element::f16) ? "half" : "float"}, {"D", D}, {"N", N}});

From 4520aa30b8af6344a15d8300293d8bd334f97bf1 Mon Sep 17 00:00:00 2001
From: Wenxiang Hu
Date: Thu, 22 Apr 2021 05:08:00 +0000
Subject: [PATCH 8/9] Enable softmax related fusion by default;

---
 src/nnfusion/engine/pass/graph/bertfusion_pass.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nnfusion/engine/pass/graph/bertfusion_pass.cpp b/src/nnfusion/engine/pass/graph/bertfusion_pass.cpp
index cb1f78951..805206d89 100644
--- a/src/nnfusion/engine/pass/graph/bertfusion_pass.cpp
+++ b/src/nnfusion/engine/pass/graph/bertfusion_pass.cpp
@@ -14,7 +14,7 @@
 using namespace nnfusion;
 using namespace nnfusion::pass::graph;
 
-DEFINE_bool(fsoftmax_related_fusion, false, "");
+DEFINE_bool(fsoftmax_related_fusion, true, "");
 DEFINE_bool(fattention_fusion, false, "");
 DEFINE_bool(flayernorm_fusion, false, "");
 DEFINE_bool(fembedlayernorm_fusion, false, "");

From 82a82f265e3f6bb7e1c3e7cd4ad10c7959e998cf Mon Sep 17 00:00:00 2001
From: Wenxiang Hu
Date: Thu, 22 Apr 2021 05:12:57 +0000
Subject: [PATCH 9/9] check col for softmax;

---
 .../bertfusion_optimizer/softmax_related_fusion_optimizer.cpp | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.cpp b/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.cpp
index 4a295cc76..f190e3998 100644
--- a/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.cpp
+++ b/src/nnfusion/engine/pass/graph/bertfusion_optimizer/softmax_related_fusion_optimizer.cpp
@@ -11,7 +11,13 @@ using namespace nnfusion::pass::graph;
 
 bool SoftmaxRelatedFusionOptimizer::CheckStartingNode(std::shared_ptr<GNode> node)
 {
     if (node->get_op_type() == "Softmax")
-        return true;
+    {
+        auto is = node->get_input_shape(0);
+        auto col = is.back();
+        if (col <= 1024)
+            return true;
+        return false;
+    }
     return false;
 }
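For reference, the op introduced by this series is a row-wise softmax over scaled, biased logits. Below is a minimal unfused host-side sketch of the same computation, not part of the patch; the function name reference_div_add_softmax is illustrative, and x, scale, bias, rows, cols mirror the kernel arguments input0, input1[0], input2, N, D.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Unfused reference: y[r, :] = softmax(x[r, :] / scale + bias[r, :]) for each of `rows` rows.
    std::vector<float> reference_div_add_softmax(const std::vector<float>& x,
                                                 const std::vector<float>& bias,
                                                 float scale, int64_t rows, int64_t cols)
    {
        std::vector<float> y(rows * cols);
        for (int64_t r = 0; r < rows; ++r)
        {
            const float* xr = x.data() + r * cols;
            const float* br = bias.data() + r * cols;
            float* yr = y.data() + r * cols;

            // Row max of the fused operand, for numerical stability.
            float row_max = -INFINITY;
            for (int64_t c = 0; c < cols; ++c)
                row_max = std::max(row_max, xr[c] / scale + br[c]);

            // Exponentiate and accumulate the row sum.
            float row_sum = 0.0f;
            for (int64_t c = 0; c < cols; ++c)
            {
                yr[c] = std::exp(xr[c] / scale + br[c] - row_max);
                row_sum += yr[c];
            }

            // Normalize.
            for (int64_t c = 0; c < cols; ++c)
                yr[c] /= row_sum;
        }
        return y;
    }

With N taken as the product of all dimensions before the last axis and D as the last dimension, as computed in the FusedDivAddSoftmax emitter, this matches the DispatchFusedSoftmaxWarp<@dtype@>(0, @N@, @D@, input0, input1, input2, output0) call emitted after PATCH 7/9. The warp path only covers D <= 1024, which is why both DispatchFusedSoftmaxWarp and CheckStartingNode (PATCH 9/9) guard on the column count.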