xref: /llvm-project/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- TestMathToVCIXConversion.cpp - Test conversion to VCIX ops ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Dialect/LLVMIR/VCIXDialect.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/Vector/IR/VectorOps.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 
19 namespace mlir {
20 namespace {
21 
22 /// Return number of extracts required to make input VectorType \vt legal and
23 /// also return thatlegal vector type.
24 /// For fixed vectors nothing special is needed. Scalable vectors are legalizes
25 /// according to LLVM's encoding:
26 /// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html
27 static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
28   VectorType vt = cast<VectorType>(type);
29   // To simplify test pass, avoid multi-dimensional vectors.
30   if (!vt || vt.getRank() != 1)
31     return {0, nullptr};
32 
33   if (!vt.isScalable())
34     return {1, vt};
35 
36   Type eltTy = vt.getElementType();
37   unsigned sew = 0;
38   if (eltTy.isF32())
39     sew = 32;
40   else if (eltTy.isF64())
41     sew = 64;
42   else if (auto intTy = dyn_cast<IntegerType>(eltTy))
43     sew = intTy.getWidth();
44   else
45     return {0, nullptr};
46 
47   unsigned eltCount = vt.getShape()[0];
48   const unsigned lmul = eltCount * sew / 64;
49 
50   unsigned n = lmul > 8 ? llvm::Log2_32(lmul) - 2 : 1;
51   return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})};
52 }
53 
54 /// Replace math.cos(v) operation with vcix.v.iv(v).
55 struct MathCosToVCIX final : OpRewritePattern<math::CosOp> {
56   using OpRewritePattern::OpRewritePattern;
57 
58   LogicalResult matchAndRewrite(math::CosOp op,
59                                 PatternRewriter &rewriter) const override {
60     const Type opType = op.getOperand().getType();
61     auto [n, legalType] = legalizeVectorType(opType);
62     if (!legalType)
63       return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
64     Location loc = op.getLoc();
65     Value vec = op.getOperand();
66     Attribute immAttr = rewriter.getI32IntegerAttr(0);
67     Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
68     Value rvl = nullptr;
69     if (legalType.isScalable())
70       // Use arbitrary runtime vector length when vector type is scalable.
71       // Proper conversion pass should take it from the IR.
72       rvl = rewriter.create<arith::ConstantOp>(loc,
73                                                rewriter.getI64IntegerAttr(9));
74     Value res;
75     if (n == 1) {
76       res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec,
77                                                immAttr, rvl);
78     } else {
79       const unsigned eltCount = legalType.getShape()[0];
80       Type eltTy = legalType.getElementType();
81       Value zero = rewriter.create<arith::ConstantOp>(
82           loc, eltTy, rewriter.getZeroAttr(eltTy));
83       res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
84       for (unsigned i = 0; i < n; ++i) {
85         Value extracted = rewriter.create<vector::ScalableExtractOp>(
86             loc, legalType, vec, i * eltCount);
87         Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr,
88                                                      extracted, immAttr, rvl);
89         res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
90                                                         i * eltCount);
91       }
92     }
93     rewriter.replaceOp(op, res);
94     return success();
95   }
96 };
97 
98 // Replace math.sin(v) operation with vcix.v.sv(v, v).
99 struct MathSinToVCIX final : OpRewritePattern<math::SinOp> {
100   using OpRewritePattern::OpRewritePattern;
101 
102   LogicalResult matchAndRewrite(math::SinOp op,
103                                 PatternRewriter &rewriter) const override {
104     const Type opType = op.getOperand().getType();
105     auto [n, legalType] = legalizeVectorType(opType);
106     if (!legalType)
107       return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
108     Location loc = op.getLoc();
109     Value vec = op.getOperand();
110     Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
111     Value rvl = nullptr;
112     if (legalType.isScalable())
113       // Use arbitrary runtime vector length when vector type is scalable.
114       // Proper conversion pass should take it from the IR.
115       rvl = rewriter.create<arith::ConstantOp>(loc,
116                                                rewriter.getI64IntegerAttr(9));
117     Value res;
118     if (n == 1) {
119       res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
120                                             vec, rvl);
121     } else {
122       const unsigned eltCount = legalType.getShape()[0];
123       Type eltTy = legalType.getElementType();
124       Value zero = rewriter.create<arith::ConstantOp>(
125           loc, eltTy, rewriter.getZeroAttr(eltTy));
126       res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
127       for (unsigned i = 0; i < n; ++i) {
128         Value extracted = rewriter.create<vector::ScalableExtractOp>(
129             loc, legalType, vec, i * eltCount);
130         Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
131                                                   extracted, extracted, rvl);
132         res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
133                                                         i * eltCount);
134       }
135     }
136     rewriter.replaceOp(op, res);
137     return success();
138   }
139 };
140 
141 // Replace math.tan(v) operation with vcix.v.sv(v, 0.0f).
142 struct MathTanToVCIX final : OpRewritePattern<math::TanOp> {
143   using OpRewritePattern::OpRewritePattern;
144 
145   LogicalResult matchAndRewrite(math::TanOp op,
146                                 PatternRewriter &rewriter) const override {
147     const Type opType = op.getOperand().getType();
148     auto [n, legalType] = legalizeVectorType(opType);
149     Type eltTy = legalType.getElementType();
150     if (!legalType)
151       return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
152     Location loc = op.getLoc();
153     Value vec = op.getOperand();
154     Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
155     Value zero = rewriter.create<arith::ConstantOp>(
156         loc, eltTy, rewriter.getZeroAttr(eltTy));
157     Value rvl = nullptr;
158     if (legalType.isScalable())
159       // Use arbitrary runtime vector length when vector type is scalable.
160       // Proper conversion pass should take it from the IR.
161       rvl = rewriter.create<arith::ConstantOp>(loc,
162                                                rewriter.getI64IntegerAttr(9));
163     Value res;
164     if (n == 1) {
165       res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
166                                             zero, rvl);
167     } else {
168       const unsigned eltCount = legalType.getShape()[0];
169       res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
170       for (unsigned i = 0; i < n; ++i) {
171         Value extracted = rewriter.create<vector::ScalableExtractOp>(
172             loc, legalType, vec, i * eltCount);
173         Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
174                                                   extracted, zero, rvl);
175         res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
176                                                         i * eltCount);
177       }
178     }
179     rewriter.replaceOp(op, res);
180     return success();
181   }
182 };
183 
184 // Replace math.log(v) operation with vcix.v.sv(v, 0).
185 struct MathLogToVCIX final : OpRewritePattern<math::LogOp> {
186   using OpRewritePattern::OpRewritePattern;
187 
188   LogicalResult matchAndRewrite(math::LogOp op,
189                                 PatternRewriter &rewriter) const override {
190     const Type opType = op.getOperand().getType();
191     auto [n, legalType] = legalizeVectorType(opType);
192     if (!legalType)
193       return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
194     Location loc = op.getLoc();
195     Value vec = op.getOperand();
196     Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
197     Value rvl = nullptr;
198     Value zeroInt = rewriter.create<arith::ConstantOp>(
199         loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
200     if (legalType.isScalable())
201       // Use arbitrary runtime vector length when vector type is scalable.
202       // Proper conversion pass should take it from the IR.
203       rvl = rewriter.create<arith::ConstantOp>(loc,
204                                                rewriter.getI64IntegerAttr(9));
205     Value res;
206     if (n == 1) {
207       res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
208                                             zeroInt, rvl);
209     } else {
210       const unsigned eltCount = legalType.getShape()[0];
211       Type eltTy = legalType.getElementType();
212       Value zero = rewriter.create<arith::ConstantOp>(
213           loc, eltTy, rewriter.getZeroAttr(eltTy));
214       res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
215       for (unsigned i = 0; i < n; ++i) {
216         Value extracted = rewriter.create<vector::ScalableExtractOp>(
217             loc, legalType, vec, i * eltCount);
218         Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
219                                                   extracted, zeroInt, rvl);
220         res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
221                                                         i * eltCount);
222       }
223     }
224     rewriter.replaceOp(op, res);
225     return success();
226   }
227 };
228 
229 struct TestMathToVCIX
230     : PassWrapper<TestMathToVCIX, OperationPass<func::FuncOp>> {
231   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMathToVCIX)
232 
233   StringRef getArgument() const final { return "test-math-to-vcix"; }
234 
235   StringRef getDescription() const final {
236     return "Test lowering patterns that converts some vector operations to "
237            "VCIX. Since DLA can implement VCIX instructions in completely "
238            "different way, conversions of that test pass only lives here.";
239   }
240 
241   void getDependentDialects(DialectRegistry &registry) const override {
242     registry.insert<arith::ArithDialect, func::FuncDialect, math::MathDialect,
243                     vcix::VCIXDialect, vector::VectorDialect>();
244   }
245 
246   void runOnOperation() override {
247     MLIRContext *ctx = &getContext();
248     RewritePatternSet patterns(ctx);
249     patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>(
250         ctx);
251     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
252   }
253 };
254 
255 } // namespace
256 
257 namespace test {
258 void registerTestMathToVCIXPass() { PassRegistration<TestMathToVCIX>(); }
259 } // namespace test
260 } // namespace mlir
261