1b2812113SSuraj Sudhir //===- TosaMakeBroadcastable.cpp ------------------------------------------===// 2b2812113SSuraj Sudhir // 3b2812113SSuraj Sudhir // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b2812113SSuraj Sudhir // See https://llvm.org/LICENSE.txt for license information. 5b2812113SSuraj Sudhir // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b2812113SSuraj Sudhir // 7b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 8b2812113SSuraj Sudhir // 9b2812113SSuraj Sudhir // Insert reshape to binary op's input if needed to match rank 10b2812113SSuraj Sudhir // 11b2812113SSuraj Sudhir //===----------------------------------------------------------------------===// 12b2812113SSuraj Sudhir 1367d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h" 148dea784bSRob Suderman #include "mlir/Dialect/Tensor/IR/Tensor.h" 1567d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/IR/TosaOps.h" 16b2812113SSuraj Sudhir #include "mlir/Dialect/Tosa/Transforms/Passes.h" 17e0537d1aSTai Ly #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 18b2812113SSuraj Sudhir #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 19b2812113SSuraj Sudhir #include "mlir/Pass/Pass.h" 20b2812113SSuraj Sudhir #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21b2812113SSuraj Sudhir 2267d0d7acSMichele Scuttari namespace mlir { 2367d0d7acSMichele Scuttari namespace tosa { 2467d0d7acSMichele Scuttari #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE 2567d0d7acSMichele Scuttari #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 2667d0d7acSMichele Scuttari } // namespace tosa 2767d0d7acSMichele Scuttari } // namespace mlir 2867d0d7acSMichele Scuttari 29b2812113SSuraj Sudhir using namespace mlir; 30b2812113SSuraj Sudhir using namespace mlir::tosa; 31b2812113SSuraj Sudhir 32e0537d1aSTai Ly namespace { 33b2812113SSuraj Sudhir 342d870a2fSMarius Brehler /// Common code to create the reshape op where necessary to make the rank of the 35936819bfSTatWai Chong /// operations equal. input1 and input2 will be updated when the rank has 36936819bfSTatWai Chong /// changed. The caller is expected to use these to rewrite the original 37936819bfSTatWai Chong /// operator with the RESHAPE now in the graph. 38e0537d1aSTai Ly /// return failure when (1) no reshape needed, or (2) output_type is specified 39e0537d1aSTai Ly /// and it has different rank 40e0537d1aSTai Ly LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, 41e0537d1aSTai Ly RankedTensorType outputType, Value &input1, 42e0537d1aSTai Ly Value &input2) { 435550c821STres Popp auto input1Ty = dyn_cast<RankedTensorType>(input1.getType()); 445550c821STres Popp auto input2Ty = dyn_cast<RankedTensorType>(input2.getType()); 45b2812113SSuraj Sudhir 46936819bfSTatWai Chong if (!input1Ty || !input2Ty) { 47936819bfSTatWai Chong return rewriter.notifyMatchFailure(loc, "input not a ranked tensor"); 48936819bfSTatWai Chong } 49286e7bddSRob Suderman 50286e7bddSRob Suderman int64_t input1Rank = input1Ty.getRank(); 51286e7bddSRob Suderman int64_t input2Rank = input2Ty.getRank(); 52b2812113SSuraj Sudhir 539f310becSDavid Blaikie if (input1Rank == input2Rank) 54936819bfSTatWai Chong return rewriter.notifyMatchFailure(loc, 55936819bfSTatWai Chong "cannot rewrite as its already correct"); 569f310becSDavid Blaikie 573745e708STai Ly Value input1Copy = input1; 583745e708STai Ly Value input2Copy = input2; 593745e708STai Ly if (EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) { 60e0537d1aSTai Ly return rewriter.notifyMatchFailure(loc, "failed to reshape inputs"); 61b2812113SSuraj Sudhir } 62b2812113SSuraj Sudhir 63286e7bddSRob Suderman // Verify the rank agrees with the output type if the output type is ranked. 64286e7bddSRob Suderman if (outputType) { 65e0537d1aSTai Ly if (outputType.getRank() != 663745e708STai Ly llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() || 67e0537d1aSTai Ly outputType.getRank() != 683745e708STai Ly llvm::cast<RankedTensorType>(input2Copy.getType()).getRank()) 69936819bfSTatWai Chong return rewriter.notifyMatchFailure( 70936819bfSTatWai Chong loc, "the reshaped type doesn't agrees with the ranked output type"); 71286e7bddSRob Suderman } 72286e7bddSRob Suderman 733745e708STai Ly input1 = input1Copy; 743745e708STai Ly input2 = input2Copy; 75b2812113SSuraj Sudhir 76286e7bddSRob Suderman return success(); 77b2812113SSuraj Sudhir } 78b2812113SSuraj Sudhir 79936819bfSTatWai Chong template <typename OpTy> 80936819bfSTatWai Chong struct ConvertTosaOp : public OpRewritePattern<OpTy> { 81b2812113SSuraj Sudhir using OpRewritePattern<OpTy>::OpRewritePattern; 82b2812113SSuraj Sudhir 83b2812113SSuraj Sudhir LogicalResult matchAndRewrite(OpTy tosaBinaryOp, 849f310becSDavid Blaikie PatternRewriter &rewriter) const override { 85b2812113SSuraj Sudhir 8613448db0SJacques Pienaar Value input1 = tosaBinaryOp.getInput1(); 8713448db0SJacques Pienaar Value input2 = tosaBinaryOp.getInput2(); 88b2812113SSuraj Sudhir Value output = tosaBinaryOp.getResult(); 89286e7bddSRob Suderman 905550c821STres Popp auto outputType = dyn_cast<RankedTensorType>(output.getType()); 91a35f54c3SRob Suderman if (!outputType) 92a35f54c3SRob Suderman return failure(); 93b2812113SSuraj Sudhir 94b2812113SSuraj Sudhir if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, 95936819bfSTatWai Chong input1, input2) 96286e7bddSRob Suderman .failed()) 97b2812113SSuraj Sudhir return failure(); 98b2812113SSuraj Sudhir 99936819bfSTatWai Chong rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2); 100b2812113SSuraj Sudhir 101b2812113SSuraj Sudhir return success(); 102b2812113SSuraj Sudhir } 103b2812113SSuraj Sudhir }; 104b2812113SSuraj Sudhir 105b2812113SSuraj Sudhir // The MulOp has an extra parameter 'shift' not present in other elementwise 106b2812113SSuraj Sudhir // binary ops, that necessitates special handling of its builder. 107b2812113SSuraj Sudhir template <> 108b2812113SSuraj Sudhir struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> { 109b2812113SSuraj Sudhir using OpRewritePattern<tosa::MulOp>::OpRewritePattern; 110b2812113SSuraj Sudhir 111b2812113SSuraj Sudhir LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp, 1129f310becSDavid Blaikie PatternRewriter &rewriter) const override { 113b2812113SSuraj Sudhir 11413448db0SJacques Pienaar Value input1 = tosaBinaryOp.getInput1(); 11513448db0SJacques Pienaar Value input2 = tosaBinaryOp.getInput2(); 116*a58e774fSJack Frankland Value shift = tosaBinaryOp.getShift(); 117b2812113SSuraj Sudhir Value output = tosaBinaryOp.getResult(); 1185550c821STres Popp auto outputType = dyn_cast<RankedTensorType>(output.getType()); 119a35f54c3SRob Suderman if (!outputType) 120a35f54c3SRob Suderman return failure(); 121b2812113SSuraj Sudhir 122b2812113SSuraj Sudhir if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, 123936819bfSTatWai Chong input1, input2) 124286e7bddSRob Suderman .failed()) 125b2812113SSuraj Sudhir return failure(); 126b2812113SSuraj Sudhir 127936819bfSTatWai Chong rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1, 128936819bfSTatWai Chong input2, shift); 129b2812113SSuraj Sudhir 130b2812113SSuraj Sudhir return success(); 131b2812113SSuraj Sudhir } 132b2812113SSuraj Sudhir }; 133b2812113SSuraj Sudhir 134b2812113SSuraj Sudhir // The ArithmeticRightShiftOp has an extra parameter 'round' not present in 135b2812113SSuraj Sudhir // other elementwise binary ops, that necessitates special handling of its 136b2812113SSuraj Sudhir // builder. 137b2812113SSuraj Sudhir template <> 138b2812113SSuraj Sudhir struct ConvertTosaOp<tosa::ArithmeticRightShiftOp> 139b2812113SSuraj Sudhir : public OpRewritePattern<tosa::ArithmeticRightShiftOp> { 140b2812113SSuraj Sudhir using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern; 141b2812113SSuraj Sudhir 142b2812113SSuraj Sudhir LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp, 1439f310becSDavid Blaikie PatternRewriter &rewriter) const override { 144b2812113SSuraj Sudhir 14513448db0SJacques Pienaar Value input1 = tosaBinaryOp.getInput1(); 14613448db0SJacques Pienaar Value input2 = tosaBinaryOp.getInput2(); 14713448db0SJacques Pienaar int32_t round = tosaBinaryOp.getRound(); 148b2812113SSuraj Sudhir Value output = tosaBinaryOp.getResult(); 1495550c821STres Popp auto outputType = dyn_cast<RankedTensorType>(output.getType()); 150a35f54c3SRob Suderman if (!outputType) 151a35f54c3SRob Suderman return failure(); 152b2812113SSuraj Sudhir 153b2812113SSuraj Sudhir if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, 154936819bfSTatWai Chong input1, input2) 155286e7bddSRob Suderman .failed()) 156b2812113SSuraj Sudhir return failure(); 157b2812113SSuraj Sudhir 158b2812113SSuraj Sudhir rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>( 159936819bfSTatWai Chong tosaBinaryOp, outputType, input1, input2, round); 160936819bfSTatWai Chong 161936819bfSTatWai Chong return success(); 162936819bfSTatWai Chong } 163936819bfSTatWai Chong }; 164936819bfSTatWai Chong 165936819bfSTatWai Chong template <> 166936819bfSTatWai Chong struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> { 167936819bfSTatWai Chong using OpRewritePattern<tosa::SelectOp>::OpRewritePattern; 168936819bfSTatWai Chong 169936819bfSTatWai Chong LogicalResult matchAndRewrite(tosa::SelectOp tosaOp, 170936819bfSTatWai Chong PatternRewriter &rewriter) const override { 171936819bfSTatWai Chong 172936819bfSTatWai Chong Value input1 = tosaOp.getPred(); 173936819bfSTatWai Chong Value input2 = tosaOp.getOnTrue(); 174936819bfSTatWai Chong Value input3 = tosaOp.getOnFalse(); 175936819bfSTatWai Chong Value output = tosaOp.getResult(); 176936819bfSTatWai Chong 1775550c821STres Popp auto outputType = dyn_cast<RankedTensorType>(output.getType()); 178936819bfSTatWai Chong if (!outputType) 179936819bfSTatWai Chong return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor"); 180936819bfSTatWai Chong 181936819bfSTatWai Chong // Apply broadcasting to each pair of inputs separately, and chain them as 182936819bfSTatWai Chong // compound as below so that the broadcasting happens all at once. 183936819bfSTatWai Chong bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, 184936819bfSTatWai Chong input1, input2) 185936819bfSTatWai Chong .succeeded(); 186936819bfSTatWai Chong 187936819bfSTatWai Chong bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, 188936819bfSTatWai Chong input1, input3) 189936819bfSTatWai Chong .succeeded(); 190936819bfSTatWai Chong 191936819bfSTatWai Chong bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, 192936819bfSTatWai Chong input2, input3) 193936819bfSTatWai Chong .succeeded(); 194936819bfSTatWai Chong 195936819bfSTatWai Chong if (!reshaped1 && !reshaped2 && !reshaped3) 196936819bfSTatWai Chong return rewriter.notifyMatchFailure( 197936819bfSTatWai Chong tosaOp, 198936819bfSTatWai Chong "cannot rewrite as the rank of all operands is already aligned"); 199936819bfSTatWai Chong 2005550c821STres Popp int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank(); 2015550c821STres Popp int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank(); 2025550c821STres Popp int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank(); 203e0537d1aSTai Ly int32_t outputRank = outputType.getRank(); 204936819bfSTatWai Chong 205e0537d1aSTai Ly if ((result1Rank != result2Rank) || (result2Rank != result3Rank) || 206e0537d1aSTai Ly (result1Rank != outputRank)) 207936819bfSTatWai Chong return rewriter.notifyMatchFailure( 208936819bfSTatWai Chong tosaOp, "not all ranks are aligned with each other"); 209936819bfSTatWai Chong 210936819bfSTatWai Chong rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1, 211936819bfSTatWai Chong input2, input3); 212b2812113SSuraj Sudhir 213b2812113SSuraj Sudhir return success(); 214b2812113SSuraj Sudhir } 215b2812113SSuraj Sudhir }; 216be0a7e9fSMehdi Amini } // namespace 217b2812113SSuraj Sudhir 218b2812113SSuraj Sudhir namespace { 219b2812113SSuraj Sudhir /// Pass that enables broadcast by making all input arrays have the same 220b2812113SSuraj Sudhir /// number of dimensions. Insert RESHAPE operations to lower rank operand 221039b969bSMichele Scuttari struct TosaMakeBroadcastable 22267d0d7acSMichele Scuttari : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> { 223b2812113SSuraj Sudhir public: 22441574554SRiver Riddle void runOnOperation() override { 22541574554SRiver Riddle auto func = getOperation(); 226dc4e913bSChris Lattner RewritePatternSet patterns(func.getContext()); 227b2812113SSuraj Sudhir MLIRContext *ctx = func.getContext(); 228b2812113SSuraj Sudhir // Add the generated patterns to the list. 22938ff7e11Snatashaknk patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx); 23038ff7e11Snatashaknk patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx); 23138ff7e11Snatashaknk patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx); 232dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx); 233dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx); 234dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx); 23582383d5fSTai Ly patterns.add<ConvertTosaOp<tosa::IntDivOp>>(ctx); 236dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx); 237dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx); 238dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx); 239dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx); 240dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx); 241dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx); 242dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx); 243dc4e913bSChris Lattner patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx); 24438ff7e11Snatashaknk patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx); 24538ff7e11Snatashaknk patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx); 24638ff7e11Snatashaknk patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx); 247936819bfSTatWai Chong patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx); 24838ff7e11Snatashaknk patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx); 24909dfc571SJacques Pienaar (void)applyPatternsGreedily(func, std::move(patterns)); 250b2812113SSuraj Sudhir } 251b2812113SSuraj Sudhir }; 252be0a7e9fSMehdi Amini } // namespace 253039b969bSMichele Scuttari 254039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() { 255039b969bSMichele Scuttari return std::make_unique<TosaMakeBroadcastable>(); 256039b969bSMichele Scuttari } 257