xref: /llvm-project/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- TestGpuRewrite.cpp - Test GPU dialect rewrite passes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains test passes that apply the GPU dialect rewrite patterns
// and the gpu.subgroup_reduce lowering patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/GPU/Transforms/Passes.h"
16 #include "mlir/Dialect/Index/IR/IndexDialect.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 struct TestGpuRewritePass
27     : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> {
28   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGpuRewritePass)
29 
30   void getDependentDialects(DialectRegistry &registry) const override {
31     registry.insert<arith::ArithDialect, func::FuncDialect, index::IndexDialect,
32                     memref::MemRefDialect>();
33   }
34   StringRef getArgument() const final { return "test-gpu-rewrite"; }
35   StringRef getDescription() const final {
36     return "Applies all rewrite patterns within the GPU dialect.";
37   }
38   void runOnOperation() override {
39     RewritePatternSet patterns(&getContext());
40     populateGpuRewritePatterns(patterns);
41     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
42   }
43 };
44 
45 struct TestGpuSubgroupReduceLoweringPass
46     : public PassWrapper<TestGpuSubgroupReduceLoweringPass,
47                          OperationPass<ModuleOp>> {
48   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
49       TestGpuSubgroupReduceLoweringPass)
50 
51   TestGpuSubgroupReduceLoweringPass() = default;
52   TestGpuSubgroupReduceLoweringPass(
53       const TestGpuSubgroupReduceLoweringPass &pass)
54       : PassWrapper(pass) {}
55 
56   void getDependentDialects(DialectRegistry &registry) const override {
57     registry.insert<arith::ArithDialect, vector::VectorDialect>();
58   }
59 
60   StringRef getArgument() const final {
61     return "test-gpu-subgroup-reduce-lowering";
62   }
63 
64   StringRef getDescription() const final {
65     return "Applies gpu.subgroup_reduce lowering patterns.";
66   }
67 
68   Option<bool> expandToShuffles{
69       *this, "expand-to-shuffles",
70       llvm::cl::desc("Expand subgroup_reduce ops to shuffle ops."),
71       llvm::cl::init(false)};
72 
73   void runOnOperation() override {
74     RewritePatternSet patterns(&getContext());
75 
76     // Since both pattern sets match on the same ops, set higher benefit to
77     // perform fewer failing matches.
78     populateGpuBreakDownSubgroupReducePatterns(patterns,
79                                                /*maxShuffleBitwidth=*/32,
80                                                PatternBenefit(2));
81     if (expandToShuffles) {
82       populateGpuLowerSubgroupReduceToShufflePatterns(
83           patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
84       populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
85           patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
86     }
87 
88     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
89   }
90 };
91 } // namespace
92 
93 namespace mlir {
94 void registerTestGpuLoweringPasses() {
95   PassRegistration<TestGpuRewritePass>();
96   PassRegistration<TestGpuSubgroupReduceLoweringPass>();
97 }
98 } // namespace mlir
99