xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp (revision a58e774fba42e13aa00667d644e96b783fc914b4)
1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Insert reshape to binary op's input if needed to match rank
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
17 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 namespace mlir {
23 namespace tosa {
24 #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
25 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
26 } // namespace tosa
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::tosa;
31 
32 namespace {
33 
34 /// Common code to create the reshape op where necessary to make the rank of the
35 /// operations equal. input1 and input2 will be updated when the rank has
36 /// changed. The caller is expected to use these to rewrite the original
37 /// operator with the RESHAPE now in the graph.
38 /// return failure when (1) no reshape needed, or (2) output_type is specified
39 /// and it has different rank
40 LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
41                                    RankedTensorType outputType, Value &input1,
42                                    Value &input2) {
43   auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
44   auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
45 
46   if (!input1Ty || !input2Ty) {
47     return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
48   }
49 
50   int64_t input1Rank = input1Ty.getRank();
51   int64_t input2Rank = input2Ty.getRank();
52 
53   if (input1Rank == input2Rank)
54     return rewriter.notifyMatchFailure(loc,
55                                        "cannot rewrite as its already correct");
56 
57   Value input1Copy = input1;
58   Value input2Copy = input2;
59   if (EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) {
60     return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
61   }
62 
63   // Verify the rank agrees with the output type if the output type is ranked.
64   if (outputType) {
65     if (outputType.getRank() !=
66             llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() ||
67         outputType.getRank() !=
68             llvm::cast<RankedTensorType>(input2Copy.getType()).getRank())
69       return rewriter.notifyMatchFailure(
70           loc, "the reshaped type doesn't agrees with the ranked output type");
71   }
72 
73   input1 = input1Copy;
74   input2 = input2Copy;
75 
76   return success();
77 }
78 
79 template <typename OpTy>
80 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
81   using OpRewritePattern<OpTy>::OpRewritePattern;
82 
83   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
84                                 PatternRewriter &rewriter) const override {
85 
86     Value input1 = tosaBinaryOp.getInput1();
87     Value input2 = tosaBinaryOp.getInput2();
88     Value output = tosaBinaryOp.getResult();
89 
90     auto outputType = dyn_cast<RankedTensorType>(output.getType());
91     if (!outputType)
92       return failure();
93 
94     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
95                              input1, input2)
96             .failed())
97       return failure();
98 
99     rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
100 
101     return success();
102   }
103 };
104 
105 // The MulOp has an extra parameter 'shift' not present in other elementwise
106 // binary ops, that necessitates special handling of its builder.
107 template <>
108 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
109   using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
110 
111   LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
112                                 PatternRewriter &rewriter) const override {
113 
114     Value input1 = tosaBinaryOp.getInput1();
115     Value input2 = tosaBinaryOp.getInput2();
116     Value shift = tosaBinaryOp.getShift();
117     Value output = tosaBinaryOp.getResult();
118     auto outputType = dyn_cast<RankedTensorType>(output.getType());
119     if (!outputType)
120       return failure();
121 
122     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
123                              input1, input2)
124             .failed())
125       return failure();
126 
127     rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
128                                              input2, shift);
129 
130     return success();
131   }
132 };
133 
134 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
135 // other elementwise binary ops, that necessitates special handling of its
136 // builder.
137 template <>
138 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
139     : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
140   using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
141 
142   LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
143                                 PatternRewriter &rewriter) const override {
144 
145     Value input1 = tosaBinaryOp.getInput1();
146     Value input2 = tosaBinaryOp.getInput2();
147     int32_t round = tosaBinaryOp.getRound();
148     Value output = tosaBinaryOp.getResult();
149     auto outputType = dyn_cast<RankedTensorType>(output.getType());
150     if (!outputType)
151       return failure();
152 
153     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
154                              input1, input2)
155             .failed())
156       return failure();
157 
158     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
159         tosaBinaryOp, outputType, input1, input2, round);
160 
161     return success();
162   }
163 };
164 
165 template <>
166 struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
167   using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
168 
169   LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
170                                 PatternRewriter &rewriter) const override {
171 
172     Value input1 = tosaOp.getPred();
173     Value input2 = tosaOp.getOnTrue();
174     Value input3 = tosaOp.getOnFalse();
175     Value output = tosaOp.getResult();
176 
177     auto outputType = dyn_cast<RankedTensorType>(output.getType());
178     if (!outputType)
179       return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
180 
181     // Apply broadcasting to each pair of inputs separately, and chain them as
182     // compound as below so that the broadcasting happens all at once.
183     bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
184                                           input1, input2)
185                          .succeeded();
186 
187     bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
188                                           input1, input3)
189                          .succeeded();
190 
191     bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
192                                           input2, input3)
193                          .succeeded();
194 
195     if (!reshaped1 && !reshaped2 && !reshaped3)
196       return rewriter.notifyMatchFailure(
197           tosaOp,
198           "cannot rewrite as the rank of all operands is already aligned");
199 
200     int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
201     int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
202     int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
203     int32_t outputRank = outputType.getRank();
204 
205     if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
206         (result1Rank != outputRank))
207       return rewriter.notifyMatchFailure(
208           tosaOp, "not all ranks are aligned with each other");
209 
210     rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
211                                                 input2, input3);
212 
213     return success();
214   }
215 };
216 } // namespace
217 
218 namespace {
219 /// Pass that enables broadcast by making all input arrays have the same
220 /// number of dimensions. Insert RESHAPE operations to lower rank operand
221 struct TosaMakeBroadcastable
222     : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
223 public:
224   void runOnOperation() override {
225     auto func = getOperation();
226     RewritePatternSet patterns(func.getContext());
227     MLIRContext *ctx = func.getContext();
228     // Add the generated patterns to the list.
229     patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
230     patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
231     patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
232     patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
233     patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
234     patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
235     patterns.add<ConvertTosaOp<tosa::IntDivOp>>(ctx);
236     patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
237     patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
238     patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
239     patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
240     patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
241     patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
242     patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
243     patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
244     patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
245     patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
246     patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
247     patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
248     patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
249     (void)applyPatternsGreedily(func, std::move(patterns));
250   }
251 };
252 } // namespace
253 
254 std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
255   return std::make_unique<TosaMakeBroadcastable>();
256 }
257