# Descriptive preamble; not part of any hunk below.
#
# cpu_instruction_fusion_test.cc
#   * Adds kReduceModuleString, an HLO module whose ENTRY computation reduces
#     an f32[13,5,10,62] parameter over dimension 3 with init constant(0),
#     applying %reduce_max. %reduce_max takes the maximum of two f32 scalars,
#     converts the result to bf16 and then back to f32.
#   * SkipReduceComputationsIfFusionEmitters parses that module without a
#     custom config, runs CpuInstructionFusion and expects changed == false.
#   * NoSkipReduceComputationsIfNoFusionEmitters first stores
#     options::kDisableNewFusionEmitters = "true" in the module config's
#     xla_backend_extra_options, then runs the same pass and expects
#     changed == true.
#
# cpu_options.cc
#   * UseExperimentalLoopFusion(config) returns true when
#     xla_backend_extra_options has no entry for kDisableNewFusionEmitters.
#   * FlattenAfterFusion(config) returns true when xla_backend_extra_options
#     has an entry for kFlattenAfterFusion.
#   * Both helpers only call count() on the map; the value stored under the
#     key is never read.
#
# cpu_options.h
#   * Includes absl/strings/string_view.h, defines the two option-name
#     constants ("xla_cpu_disable_new_fusion_emitters",
#     "xla_cpu_flatten_after_fusion") and declares the two helpers above.
#
# NOTE(review): three context lines below (the bare `#include`,
# `std::optional> LlvmIrGemmTileSize(` and
# `absl::StatusOr SmallWhileLoopByteThreshold(`) look as if text between
# angle brackets was lost before this patch was captured. They are kept
# exactly as found; confirm against the target files before applying.
diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc
index b93eccd8e17242..16cd6011fa9e1c 100644
--- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -1078,5 +1078,45 @@ TEST_F(InstructionFusionTest, NoSkipScatterComputationsIfNoFusionEmitters) {
   EXPECT_TRUE(changed);
 }
 
+static constexpr absl::string_view kReduceModuleString = R"(
+  HloModule module
+
+  %reduce_max (param0: f32[], param1: f32[]) -> f32[] {
+    %lhs = f32[] parameter(0)
+    %rhs = f32[] parameter(1)
+    %maximum.1 = f32[] maximum(f32[] lhs, f32[] rhs)
+    %convert.8 = bf16[] convert(f32[] maximum.1)
+    ROOT %convert.9 = f32[] convert(bf16[] convert.8)
+  }
+
+  ENTRY %main (arg0: f32[13,5,10,62])
+    -> f32[13,5,10] {
+    %arg0 = f32[13,5,10,62]{3,2,1,0} parameter(0)
+    %init = f32[] constant(0)
+    ROOT %reduce = f32[13,5,10]{2,1,0} reduce(%arg0, %init),
+      dimensions={3}, to_apply=reduce_max
+  }
+)";
+
+TEST_F(InstructionFusionTest, SkipReduceComputationsIfFusionEmitters) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kReduceModuleString));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          CpuInstructionFusion().Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(InstructionFusionTest, NoSkipReduceComputationsIfNoFusionEmitters) {
+  auto mod_config = GetModuleConfigForTest();
+  auto debug_options = GetDebugOptionsForTest();
+  (*debug_options.mutable_xla_backend_extra_options())
+      [options::kDisableNewFusionEmitters] = "true";
+  mod_config.set_debug_options(debug_options);
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
+                                           kReduceModuleString, mod_config));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          CpuInstructionFusion().Run(module.get()));
+  EXPECT_TRUE(changed);
+}
 }  // namespace
 }  // namespace xla::cpu
diff --git a/third_party/xla/xla/service/cpu/cpu_options.cc b/third_party/xla/xla/service/cpu/cpu_options.cc
index 6de38a21ef8adb..d4de69fffaeff2 100644
--- a/third_party/xla/xla/service/cpu/cpu_options.cc
+++ b/third_party/xla/xla/service/cpu/cpu_options.cc
@@ -145,4 +145,15 @@ std::optional> LlvmIrGemmTileSize(
       tile_size_n_in_vector_width);
 }
 
+bool UseExperimentalLoopFusion(const HloModuleConfig& config) {
+  const auto& extra_options_map =
+      config.debug_options().xla_backend_extra_options();
+  return extra_options_map.count(kDisableNewFusionEmitters) == 0;
+}
+
+bool FlattenAfterFusion(const HloModuleConfig& config) {
+  const auto& extra_options_map =
+      config.debug_options().xla_backend_extra_options();
+  return extra_options_map.count(kFlattenAfterFusion) > 0;
+}
 }  // namespace xla::cpu::options
diff --git a/third_party/xla/xla/service/cpu/cpu_options.h b/third_party/xla/xla/service/cpu/cpu_options.h
index 17f92d5251be66..15b9407ed439a4 100644
--- a/third_party/xla/xla/service/cpu/cpu_options.h
+++ b/third_party/xla/xla/service/cpu/cpu_options.h
@@ -21,12 +21,17 @@ limitations under the License.
 #include
 
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "xla/service/hlo_module_config.h"
 
 // Helper functions for querying options that are specific to the CPU backend.
 
 namespace xla::cpu::options {
 
+inline constexpr absl::string_view kDisableNewFusionEmitters =
+    "xla_cpu_disable_new_fusion_emitters";
+inline constexpr absl::string_view kFlattenAfterFusion =
+    "xla_cpu_flatten_after_fusion";
 bool OptimizeForSizeRequested(const HloModuleConfig& config);
 bool VectorizedReduceDisabled(const HloModuleConfig& config);
 bool SlpVectorizerDisabled(const HloModuleConfig& config);
@@ -38,6 +43,8 @@ std::optional> LlvmIrGemmTileSize(
     const HloModuleConfig& config);
 absl::StatusOr SmallWhileLoopByteThreshold(
     const HloModuleConfig& config);
+bool UseExperimentalLoopFusion(const HloModuleConfig& config);
+bool FlattenAfterFusion(const HloModuleConfig& config);
 
 }  // namespace xla::cpu::options
 