xref: /llvm-project/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- TestGpuRewrite.cpp - Test GPU rewrite patterns ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains test passes that apply the GPU dialect rewrite patterns
// and the gpu.subgroup_reduce lowering patterns.
//
//===----------------------------------------------------------------------===//
123fef2d26SRiver Riddle 
13abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
15d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h"
16dd16cd73SFabian Mora #include "mlir/Dialect/Index/IR/IndexDialect.h"
173fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
182af186f9SJakub Kuderski #include "mlir/Dialect/Vector/IR/VectorOps.h"
19c0345b46SJakub Kuderski #include "mlir/IR/PatternMatch.h"
203fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
213fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
223fef2d26SRiver Riddle 
233fef2d26SRiver Riddle using namespace mlir;
243fef2d26SRiver Riddle 
253fef2d26SRiver Riddle namespace {
263fef2d26SRiver Riddle struct TestGpuRewritePass
273fef2d26SRiver Riddle     : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> {
285e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGpuRewritePass)
295e50dd04SRiver Riddle 
303fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
31dd16cd73SFabian Mora     registry.insert<arith::ArithDialect, func::FuncDialect, index::IndexDialect,
32a54f4eaeSMogball                     memref::MemRefDialect>();
333fef2d26SRiver Riddle   }
34b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-gpu-rewrite"; }
35b5e22e6dSMehdi Amini   StringRef getDescription() const final {
36b5e22e6dSMehdi Amini     return "Applies all rewrite patterns within the GPU dialect.";
37b5e22e6dSMehdi Amini   }
383fef2d26SRiver Riddle   void runOnOperation() override {
393fef2d26SRiver Riddle     RewritePatternSet patterns(&getContext());
403fef2d26SRiver Riddle     populateGpuRewritePatterns(patterns);
41*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
423fef2d26SRiver Riddle   }
433fef2d26SRiver Riddle };
442af186f9SJakub Kuderski 
452af186f9SJakub Kuderski struct TestGpuSubgroupReduceLoweringPass
462af186f9SJakub Kuderski     : public PassWrapper<TestGpuSubgroupReduceLoweringPass,
472af186f9SJakub Kuderski                          OperationPass<ModuleOp>> {
482af186f9SJakub Kuderski   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
492af186f9SJakub Kuderski       TestGpuSubgroupReduceLoweringPass)
502af186f9SJakub Kuderski 
51c0345b46SJakub Kuderski   TestGpuSubgroupReduceLoweringPass() = default;
52c0345b46SJakub Kuderski   TestGpuSubgroupReduceLoweringPass(
53c0345b46SJakub Kuderski       const TestGpuSubgroupReduceLoweringPass &pass)
54c0345b46SJakub Kuderski       : PassWrapper(pass) {}
55c0345b46SJakub Kuderski 
562af186f9SJakub Kuderski   void getDependentDialects(DialectRegistry &registry) const override {
572af186f9SJakub Kuderski     registry.insert<arith::ArithDialect, vector::VectorDialect>();
582af186f9SJakub Kuderski   }
59c0345b46SJakub Kuderski 
602af186f9SJakub Kuderski   StringRef getArgument() const final {
612af186f9SJakub Kuderski     return "test-gpu-subgroup-reduce-lowering";
622af186f9SJakub Kuderski   }
63c0345b46SJakub Kuderski 
642af186f9SJakub Kuderski   StringRef getDescription() const final {
652af186f9SJakub Kuderski     return "Applies gpu.subgroup_reduce lowering patterns.";
662af186f9SJakub Kuderski   }
67c0345b46SJakub Kuderski 
68c0345b46SJakub Kuderski   Option<bool> expandToShuffles{
69c0345b46SJakub Kuderski       *this, "expand-to-shuffles",
70c0345b46SJakub Kuderski       llvm::cl::desc("Expand subgroup_reduce ops to shuffle ops."),
71c0345b46SJakub Kuderski       llvm::cl::init(false)};
72c0345b46SJakub Kuderski 
732af186f9SJakub Kuderski   void runOnOperation() override {
742af186f9SJakub Kuderski     RewritePatternSet patterns(&getContext());
75c0345b46SJakub Kuderski 
76c0345b46SJakub Kuderski     // Since both pattern sets match on the same ops, set higher benefit to
77c0345b46SJakub Kuderski     // perform fewer failing matches.
78fd26f844SAndrea Faulds     populateGpuBreakDownSubgroupReducePatterns(patterns,
79c0345b46SJakub Kuderski                                                /*maxShuffleBitwidth=*/32,
80c0345b46SJakub Kuderski                                                PatternBenefit(2));
81a800ffacSAndrea Faulds     if (expandToShuffles) {
82fd26f844SAndrea Faulds       populateGpuLowerSubgroupReduceToShufflePatterns(
83c0345b46SJakub Kuderski           patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
84a800ffacSAndrea Faulds       populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
85a800ffacSAndrea Faulds           patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
86a800ffacSAndrea Faulds     }
87c0345b46SJakub Kuderski 
88*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
892af186f9SJakub Kuderski   }
902af186f9SJakub Kuderski };
913fef2d26SRiver Riddle } // namespace
923fef2d26SRiver Riddle 
933fef2d26SRiver Riddle namespace mlir {
942af186f9SJakub Kuderski void registerTestGpuLoweringPasses() {
95b5e22e6dSMehdi Amini   PassRegistration<TestGpuRewritePass>();
962af186f9SJakub Kuderski   PassRegistration<TestGpuSubgroupReduceLoweringPass>();
973fef2d26SRiver Riddle }
983fef2d26SRiver Riddle } // namespace mlir
99