13fef2d26SRiver Riddle //===- TestAllReduceLowering.cpp - Test gpu.all_reduce lowering -----------===// 23fef2d26SRiver Riddle // 33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63fef2d26SRiver Riddle // 73fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 83fef2d26SRiver Riddle // 93fef2d26SRiver Riddle // This file contains test passes for lowering the gpu.all_reduce op. 103fef2d26SRiver Riddle // 113fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 123fef2d26SRiver Riddle 13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 15d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h" 16dd16cd73SFabian Mora #include "mlir/Dialect/Index/IR/IndexDialect.h" 173fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 182af186f9SJakub Kuderski #include "mlir/Dialect/Vector/IR/VectorOps.h" 19c0345b46SJakub Kuderski #include "mlir/IR/PatternMatch.h" 203fef2d26SRiver Riddle #include "mlir/Pass/Pass.h" 213fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 223fef2d26SRiver Riddle 233fef2d26SRiver Riddle using namespace mlir; 243fef2d26SRiver Riddle 253fef2d26SRiver Riddle namespace { 263fef2d26SRiver Riddle struct TestGpuRewritePass 273fef2d26SRiver Riddle : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> { 285e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGpuRewritePass) 295e50dd04SRiver Riddle 303fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override { 31dd16cd73SFabian Mora registry.insert<arith::ArithDialect, func::FuncDialect, index::IndexDialect, 32a54f4eaeSMogball memref::MemRefDialect>(); 333fef2d26SRiver Riddle } 34b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-gpu-rewrite"; } 35b5e22e6dSMehdi Amini StringRef getDescription() const final { 36b5e22e6dSMehdi Amini return "Applies all rewrite patterns within the GPU dialect."; 37b5e22e6dSMehdi Amini } 383fef2d26SRiver Riddle void runOnOperation() override { 393fef2d26SRiver Riddle RewritePatternSet patterns(&getContext()); 403fef2d26SRiver Riddle populateGpuRewritePatterns(patterns); 41*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 423fef2d26SRiver Riddle } 433fef2d26SRiver Riddle }; 442af186f9SJakub Kuderski 452af186f9SJakub Kuderski struct TestGpuSubgroupReduceLoweringPass 462af186f9SJakub Kuderski : public PassWrapper<TestGpuSubgroupReduceLoweringPass, 472af186f9SJakub Kuderski OperationPass<ModuleOp>> { 482af186f9SJakub Kuderski MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 492af186f9SJakub Kuderski TestGpuSubgroupReduceLoweringPass) 502af186f9SJakub Kuderski 51c0345b46SJakub Kuderski TestGpuSubgroupReduceLoweringPass() = default; 52c0345b46SJakub Kuderski TestGpuSubgroupReduceLoweringPass( 53c0345b46SJakub Kuderski const TestGpuSubgroupReduceLoweringPass &pass) 54c0345b46SJakub Kuderski : PassWrapper(pass) {} 55c0345b46SJakub Kuderski 562af186f9SJakub Kuderski void getDependentDialects(DialectRegistry ®istry) const override { 572af186f9SJakub Kuderski registry.insert<arith::ArithDialect, vector::VectorDialect>(); 582af186f9SJakub Kuderski } 59c0345b46SJakub Kuderski 602af186f9SJakub Kuderski StringRef getArgument() const final { 612af186f9SJakub Kuderski return "test-gpu-subgroup-reduce-lowering"; 622af186f9SJakub Kuderski } 63c0345b46SJakub Kuderski 642af186f9SJakub Kuderski StringRef getDescription() const final { 652af186f9SJakub Kuderski return "Applies gpu.subgroup_reduce lowering patterns."; 662af186f9SJakub Kuderski } 67c0345b46SJakub Kuderski 68c0345b46SJakub Kuderski Option<bool> expandToShuffles{ 69c0345b46SJakub Kuderski *this, "expand-to-shuffles", 70c0345b46SJakub Kuderski llvm::cl::desc("Expand subgroup_reduce ops to shuffle ops."), 71c0345b46SJakub Kuderski llvm::cl::init(false)}; 72c0345b46SJakub Kuderski 732af186f9SJakub Kuderski void runOnOperation() override { 742af186f9SJakub Kuderski RewritePatternSet patterns(&getContext()); 75c0345b46SJakub Kuderski 76c0345b46SJakub Kuderski // Since both pattern sets match on the same ops, set higher benefit to 77c0345b46SJakub Kuderski // perform fewer failing matches. 78fd26f844SAndrea Faulds populateGpuBreakDownSubgroupReducePatterns(patterns, 79c0345b46SJakub Kuderski /*maxShuffleBitwidth=*/32, 80c0345b46SJakub Kuderski PatternBenefit(2)); 81a800ffacSAndrea Faulds if (expandToShuffles) { 82fd26f844SAndrea Faulds populateGpuLowerSubgroupReduceToShufflePatterns( 83c0345b46SJakub Kuderski patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32); 84a800ffacSAndrea Faulds populateGpuLowerClusteredSubgroupReduceToShufflePatterns( 85a800ffacSAndrea Faulds patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32); 86a800ffacSAndrea Faulds } 87c0345b46SJakub Kuderski 88*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 892af186f9SJakub Kuderski } 902af186f9SJakub Kuderski }; 913fef2d26SRiver Riddle } // namespace 923fef2d26SRiver Riddle 933fef2d26SRiver Riddle namespace mlir { 942af186f9SJakub Kuderski void registerTestGpuLoweringPasses() { 95b5e22e6dSMehdi Amini PassRegistration<TestGpuRewritePass>(); 962af186f9SJakub Kuderski PassRegistration<TestGpuSubgroupReduceLoweringPass>(); 973fef2d26SRiver Riddle } 983fef2d26SRiver Riddle } // namespace mlir 99