xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp (revision a58e774fba42e13aa00667d644e96b783fc914b4)
1b2812113SSuraj Sudhir //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2b2812113SSuraj Sudhir //
3b2812113SSuraj Sudhir // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b2812113SSuraj Sudhir // See https://llvm.org/LICENSE.txt for license information.
5b2812113SSuraj Sudhir // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b2812113SSuraj Sudhir //
7b2812113SSuraj Sudhir //===----------------------------------------------------------------------===//
8b2812113SSuraj Sudhir //
9b2812113SSuraj Sudhir // Insert reshape to binary op's input if needed to match rank
10b2812113SSuraj Sudhir //
11b2812113SSuraj Sudhir //===----------------------------------------------------------------------===//
12b2812113SSuraj Sudhir 
1367d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
148dea784bSRob Suderman #include "mlir/Dialect/Tensor/IR/Tensor.h"
1567d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16b2812113SSuraj Sudhir #include "mlir/Dialect/Tosa/Transforms/Passes.h"
17e0537d1aSTai Ly #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
18b2812113SSuraj Sudhir #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19b2812113SSuraj Sudhir #include "mlir/Pass/Pass.h"
20b2812113SSuraj Sudhir #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21b2812113SSuraj Sudhir 
2267d0d7acSMichele Scuttari namespace mlir {
2367d0d7acSMichele Scuttari namespace tosa {
2467d0d7acSMichele Scuttari #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
2567d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
2667d0d7acSMichele Scuttari } // namespace tosa
2767d0d7acSMichele Scuttari } // namespace mlir
2867d0d7acSMichele Scuttari 
29b2812113SSuraj Sudhir using namespace mlir;
30b2812113SSuraj Sudhir using namespace mlir::tosa;
31b2812113SSuraj Sudhir 
32e0537d1aSTai Ly namespace {
33b2812113SSuraj Sudhir 
342d870a2fSMarius Brehler /// Common code to create the reshape op where necessary to make the rank of the
35936819bfSTatWai Chong /// operations equal. input1 and input2 will be updated when the rank has
36936819bfSTatWai Chong /// changed. The caller is expected to use these to rewrite the original
37936819bfSTatWai Chong /// operator with the RESHAPE now in the graph.
38e0537d1aSTai Ly /// return failure when (1) no reshape needed, or (2) output_type is specified
39e0537d1aSTai Ly /// and it has different rank
40e0537d1aSTai Ly LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
41e0537d1aSTai Ly                                    RankedTensorType outputType, Value &input1,
42e0537d1aSTai Ly                                    Value &input2) {
435550c821STres Popp   auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
445550c821STres Popp   auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
45b2812113SSuraj Sudhir 
46936819bfSTatWai Chong   if (!input1Ty || !input2Ty) {
47936819bfSTatWai Chong     return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
48936819bfSTatWai Chong   }
49286e7bddSRob Suderman 
50286e7bddSRob Suderman   int64_t input1Rank = input1Ty.getRank();
51286e7bddSRob Suderman   int64_t input2Rank = input2Ty.getRank();
52b2812113SSuraj Sudhir 
539f310becSDavid Blaikie   if (input1Rank == input2Rank)
54936819bfSTatWai Chong     return rewriter.notifyMatchFailure(loc,
55936819bfSTatWai Chong                                        "cannot rewrite as its already correct");
569f310becSDavid Blaikie 
573745e708STai Ly   Value input1Copy = input1;
583745e708STai Ly   Value input2Copy = input2;
593745e708STai Ly   if (EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) {
60e0537d1aSTai Ly     return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
61b2812113SSuraj Sudhir   }
62b2812113SSuraj Sudhir 
63286e7bddSRob Suderman   // Verify the rank agrees with the output type if the output type is ranked.
64286e7bddSRob Suderman   if (outputType) {
65e0537d1aSTai Ly     if (outputType.getRank() !=
663745e708STai Ly             llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() ||
67e0537d1aSTai Ly         outputType.getRank() !=
683745e708STai Ly             llvm::cast<RankedTensorType>(input2Copy.getType()).getRank())
69936819bfSTatWai Chong       return rewriter.notifyMatchFailure(
70936819bfSTatWai Chong           loc, "the reshaped type doesn't agrees with the ranked output type");
71286e7bddSRob Suderman   }
72286e7bddSRob Suderman 
733745e708STai Ly   input1 = input1Copy;
743745e708STai Ly   input2 = input2Copy;
75b2812113SSuraj Sudhir 
76286e7bddSRob Suderman   return success();
77b2812113SSuraj Sudhir }
78b2812113SSuraj Sudhir 
79936819bfSTatWai Chong template <typename OpTy>
80936819bfSTatWai Chong struct ConvertTosaOp : public OpRewritePattern<OpTy> {
81b2812113SSuraj Sudhir   using OpRewritePattern<OpTy>::OpRewritePattern;
82b2812113SSuraj Sudhir 
83b2812113SSuraj Sudhir   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
849f310becSDavid Blaikie                                 PatternRewriter &rewriter) const override {
85b2812113SSuraj Sudhir 
8613448db0SJacques Pienaar     Value input1 = tosaBinaryOp.getInput1();
8713448db0SJacques Pienaar     Value input2 = tosaBinaryOp.getInput2();
88b2812113SSuraj Sudhir     Value output = tosaBinaryOp.getResult();
89286e7bddSRob Suderman 
905550c821STres Popp     auto outputType = dyn_cast<RankedTensorType>(output.getType());
91a35f54c3SRob Suderman     if (!outputType)
92a35f54c3SRob Suderman       return failure();
93b2812113SSuraj Sudhir 
94b2812113SSuraj Sudhir     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
95936819bfSTatWai Chong                              input1, input2)
96286e7bddSRob Suderman             .failed())
97b2812113SSuraj Sudhir       return failure();
98b2812113SSuraj Sudhir 
99936819bfSTatWai Chong     rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
100b2812113SSuraj Sudhir 
101b2812113SSuraj Sudhir     return success();
102b2812113SSuraj Sudhir   }
103b2812113SSuraj Sudhir };
104b2812113SSuraj Sudhir 
105b2812113SSuraj Sudhir // The MulOp has an extra parameter 'shift' not present in other elementwise
106b2812113SSuraj Sudhir // binary ops, that necessitates special handling of its builder.
107b2812113SSuraj Sudhir template <>
108b2812113SSuraj Sudhir struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
109b2812113SSuraj Sudhir   using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
110b2812113SSuraj Sudhir 
111b2812113SSuraj Sudhir   LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
1129f310becSDavid Blaikie                                 PatternRewriter &rewriter) const override {
113b2812113SSuraj Sudhir 
11413448db0SJacques Pienaar     Value input1 = tosaBinaryOp.getInput1();
11513448db0SJacques Pienaar     Value input2 = tosaBinaryOp.getInput2();
116*a58e774fSJack Frankland     Value shift = tosaBinaryOp.getShift();
117b2812113SSuraj Sudhir     Value output = tosaBinaryOp.getResult();
1185550c821STres Popp     auto outputType = dyn_cast<RankedTensorType>(output.getType());
119a35f54c3SRob Suderman     if (!outputType)
120a35f54c3SRob Suderman       return failure();
121b2812113SSuraj Sudhir 
122b2812113SSuraj Sudhir     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
123936819bfSTatWai Chong                              input1, input2)
124286e7bddSRob Suderman             .failed())
125b2812113SSuraj Sudhir       return failure();
126b2812113SSuraj Sudhir 
127936819bfSTatWai Chong     rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
128936819bfSTatWai Chong                                              input2, shift);
129b2812113SSuraj Sudhir 
130b2812113SSuraj Sudhir     return success();
131b2812113SSuraj Sudhir   }
132b2812113SSuraj Sudhir };
133b2812113SSuraj Sudhir 
134b2812113SSuraj Sudhir // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
135b2812113SSuraj Sudhir // other elementwise binary ops, that necessitates special handling of its
136b2812113SSuraj Sudhir // builder.
137b2812113SSuraj Sudhir template <>
138b2812113SSuraj Sudhir struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
139b2812113SSuraj Sudhir     : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
140b2812113SSuraj Sudhir   using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
141b2812113SSuraj Sudhir 
142b2812113SSuraj Sudhir   LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
1439f310becSDavid Blaikie                                 PatternRewriter &rewriter) const override {
144b2812113SSuraj Sudhir 
14513448db0SJacques Pienaar     Value input1 = tosaBinaryOp.getInput1();
14613448db0SJacques Pienaar     Value input2 = tosaBinaryOp.getInput2();
14713448db0SJacques Pienaar     int32_t round = tosaBinaryOp.getRound();
148b2812113SSuraj Sudhir     Value output = tosaBinaryOp.getResult();
1495550c821STres Popp     auto outputType = dyn_cast<RankedTensorType>(output.getType());
150a35f54c3SRob Suderman     if (!outputType)
151a35f54c3SRob Suderman       return failure();
152b2812113SSuraj Sudhir 
153b2812113SSuraj Sudhir     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
154936819bfSTatWai Chong                              input1, input2)
155286e7bddSRob Suderman             .failed())
156b2812113SSuraj Sudhir       return failure();
157b2812113SSuraj Sudhir 
158b2812113SSuraj Sudhir     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
159936819bfSTatWai Chong         tosaBinaryOp, outputType, input1, input2, round);
160936819bfSTatWai Chong 
161936819bfSTatWai Chong     return success();
162936819bfSTatWai Chong   }
163936819bfSTatWai Chong };
164936819bfSTatWai Chong 
165936819bfSTatWai Chong template <>
166936819bfSTatWai Chong struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
167936819bfSTatWai Chong   using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
168936819bfSTatWai Chong 
169936819bfSTatWai Chong   LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
170936819bfSTatWai Chong                                 PatternRewriter &rewriter) const override {
171936819bfSTatWai Chong 
172936819bfSTatWai Chong     Value input1 = tosaOp.getPred();
173936819bfSTatWai Chong     Value input2 = tosaOp.getOnTrue();
174936819bfSTatWai Chong     Value input3 = tosaOp.getOnFalse();
175936819bfSTatWai Chong     Value output = tosaOp.getResult();
176936819bfSTatWai Chong 
1775550c821STres Popp     auto outputType = dyn_cast<RankedTensorType>(output.getType());
178936819bfSTatWai Chong     if (!outputType)
179936819bfSTatWai Chong       return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
180936819bfSTatWai Chong 
181936819bfSTatWai Chong     // Apply broadcasting to each pair of inputs separately, and chain them as
182936819bfSTatWai Chong     // compound as below so that the broadcasting happens all at once.
183936819bfSTatWai Chong     bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
184936819bfSTatWai Chong                                           input1, input2)
185936819bfSTatWai Chong                          .succeeded();
186936819bfSTatWai Chong 
187936819bfSTatWai Chong     bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
188936819bfSTatWai Chong                                           input1, input3)
189936819bfSTatWai Chong                          .succeeded();
190936819bfSTatWai Chong 
191936819bfSTatWai Chong     bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
192936819bfSTatWai Chong                                           input2, input3)
193936819bfSTatWai Chong                          .succeeded();
194936819bfSTatWai Chong 
195936819bfSTatWai Chong     if (!reshaped1 && !reshaped2 && !reshaped3)
196936819bfSTatWai Chong       return rewriter.notifyMatchFailure(
197936819bfSTatWai Chong           tosaOp,
198936819bfSTatWai Chong           "cannot rewrite as the rank of all operands is already aligned");
199936819bfSTatWai Chong 
2005550c821STres Popp     int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
2015550c821STres Popp     int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
2025550c821STres Popp     int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
203e0537d1aSTai Ly     int32_t outputRank = outputType.getRank();
204936819bfSTatWai Chong 
205e0537d1aSTai Ly     if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
206e0537d1aSTai Ly         (result1Rank != outputRank))
207936819bfSTatWai Chong       return rewriter.notifyMatchFailure(
208936819bfSTatWai Chong           tosaOp, "not all ranks are aligned with each other");
209936819bfSTatWai Chong 
210936819bfSTatWai Chong     rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
211936819bfSTatWai Chong                                                 input2, input3);
212b2812113SSuraj Sudhir 
213b2812113SSuraj Sudhir     return success();
214b2812113SSuraj Sudhir   }
215b2812113SSuraj Sudhir };
216be0a7e9fSMehdi Amini } // namespace
217b2812113SSuraj Sudhir 
218b2812113SSuraj Sudhir namespace {
219b2812113SSuraj Sudhir /// Pass that enables broadcast by making all input arrays have the same
220b2812113SSuraj Sudhir /// number of dimensions. Insert RESHAPE operations to lower rank operand
221039b969bSMichele Scuttari struct TosaMakeBroadcastable
22267d0d7acSMichele Scuttari     : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
223b2812113SSuraj Sudhir public:
22441574554SRiver Riddle   void runOnOperation() override {
22541574554SRiver Riddle     auto func = getOperation();
226dc4e913bSChris Lattner     RewritePatternSet patterns(func.getContext());
227b2812113SSuraj Sudhir     MLIRContext *ctx = func.getContext();
228b2812113SSuraj Sudhir     // Add the generated patterns to the list.
22938ff7e11Snatashaknk     patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
23038ff7e11Snatashaknk     patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
23138ff7e11Snatashaknk     patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
232dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
233dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
234dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
23582383d5fSTai Ly     patterns.add<ConvertTosaOp<tosa::IntDivOp>>(ctx);
236dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
237dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
238dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
239dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
240dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
241dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
242dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
243dc4e913bSChris Lattner     patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
24438ff7e11Snatashaknk     patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
24538ff7e11Snatashaknk     patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
24638ff7e11Snatashaknk     patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
247936819bfSTatWai Chong     patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
24838ff7e11Snatashaknk     patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
24909dfc571SJacques Pienaar     (void)applyPatternsGreedily(func, std::move(patterns));
250b2812113SSuraj Sudhir   }
251b2812113SSuraj Sudhir };
252be0a7e9fSMehdi Amini } // namespace
253039b969bSMichele Scuttari 
254039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
255039b969bSMichele Scuttari   return std::make_unique<TosaMakeBroadcastable>();
256039b969bSMichele Scuttari }
257