1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Insert reshape to binary op's input if needed to match rank 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Tensor/IR/Tensor.h" 15 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 16 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 17 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 22 namespace mlir { 23 namespace tosa { 24 #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE 25 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 26 } // namespace tosa 27 } // namespace mlir 28 29 using namespace mlir; 30 using namespace mlir::tosa; 31 32 namespace { 33 34 /// Common code to create the reshape op where necessary to make the rank of the 35 /// operations equal. input1 and input2 will be updated when the rank has 36 /// changed. The caller is expected to use these to rewrite the original 37 /// operator with the RESHAPE now in the graph. 38 /// return failure when (1) no reshape needed, or (2) output_type is specified 39 /// and it has different rank 40 LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, 41 RankedTensorType outputType, Value &input1, 42 Value &input2) { 43 auto input1Ty = dyn_cast<RankedTensorType>(input1.getType()); 44 auto input2Ty = dyn_cast<RankedTensorType>(input2.getType()); 45 46 if (!input1Ty || !input2Ty) { 47 return rewriter.notifyMatchFailure(loc, "input not a ranked tensor"); 48 } 49 50 int64_t input1Rank = input1Ty.getRank(); 51 int64_t input2Rank = input2Ty.getRank(); 52 53 if (input1Rank == input2Rank) 54 return rewriter.notifyMatchFailure(loc, 55 "cannot rewrite as its already correct"); 56 57 Value input1Copy = input1; 58 Value input2Copy = input2; 59 if (EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) { 60 return rewriter.notifyMatchFailure(loc, "failed to reshape inputs"); 61 } 62 63 // Verify the rank agrees with the output type if the output type is ranked. 64 if (outputType) { 65 if (outputType.getRank() != 66 llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() || 67 outputType.getRank() != 68 llvm::cast<RankedTensorType>(input2Copy.getType()).getRank()) 69 return rewriter.notifyMatchFailure( 70 loc, "the reshaped type doesn't agrees with the ranked output type"); 71 } 72 73 input1 = input1Copy; 74 input2 = input2Copy; 75 76 return success(); 77 } 78 79 template <typename OpTy> 80 struct ConvertTosaOp : public OpRewritePattern<OpTy> { 81 using OpRewritePattern<OpTy>::OpRewritePattern; 82 83 LogicalResult matchAndRewrite(OpTy tosaBinaryOp, 84 PatternRewriter &rewriter) const override { 85 86 Value input1 = tosaBinaryOp.getInput1(); 87 Value input2 = tosaBinaryOp.getInput2(); 88 Value output = tosaBinaryOp.getResult(); 89 90 auto outputType = dyn_cast<RankedTensorType>(output.getType()); 91 if (!outputType) 92 return failure(); 93 94 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, 95 input1, input2) 96 .failed()) 97 return failure(); 98 99 rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2); 100 101 return success(); 102 } 103 }; 104 105 // The MulOp has an extra parameter 'shift' not present in other elementwise 106 // binary ops, that necessitates special handling of its builder. 107 template <> 108 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> { 109 using OpRewritePattern<tosa::MulOp>::OpRewritePattern; 110 111 LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp, 112 PatternRewriter &rewriter) const override { 113 114 Value input1 = tosaBinaryOp.getInput1(); 115 Value input2 = tosaBinaryOp.getInput2(); 116 Value shift = tosaBinaryOp.getShift(); 117 Value output = tosaBinaryOp.getResult(); 118 auto outputType = dyn_cast<RankedTensorType>(output.getType()); 119 if (!outputType) 120 return failure(); 121 122 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, 123 input1, input2) 124 .failed()) 125 return failure(); 126 127 rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1, 128 input2, shift); 129 130 return success(); 131 } 132 }; 133 134 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in 135 // other elementwise binary ops, that necessitates special handling of its 136 // builder. 137 template <> 138 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp> 139 : public OpRewritePattern<tosa::ArithmeticRightShiftOp> { 140 using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern; 141 142 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp, 143 PatternRewriter &rewriter) const override { 144 145 Value input1 = tosaBinaryOp.getInput1(); 146 Value input2 = tosaBinaryOp.getInput2(); 147 int32_t round = tosaBinaryOp.getRound(); 148 Value output = tosaBinaryOp.getResult(); 149 auto outputType = dyn_cast<RankedTensorType>(output.getType()); 150 if (!outputType) 151 return failure(); 152 153 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, 154 input1, input2) 155 .failed()) 156 return failure(); 157 158 rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>( 159 tosaBinaryOp, outputType, input1, input2, round); 160 161 return success(); 162 } 163 }; 164 165 template <> 166 struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> { 167 using OpRewritePattern<tosa::SelectOp>::OpRewritePattern; 168 169 LogicalResult matchAndRewrite(tosa::SelectOp tosaOp, 170 PatternRewriter &rewriter) const override { 171 172 Value input1 = tosaOp.getPred(); 173 Value input2 = tosaOp.getOnTrue(); 174 Value input3 = tosaOp.getOnFalse(); 175 Value output = tosaOp.getResult(); 176 177 auto outputType = dyn_cast<RankedTensorType>(output.getType()); 178 if (!outputType) 179 return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor"); 180 181 // Apply broadcasting to each pair of inputs separately, and chain them as 182 // compound as below so that the broadcasting happens all at once. 183 bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, 184 input1, input2) 185 .succeeded(); 186 187 bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, 188 input1, input3) 189 .succeeded(); 190 191 bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, 192 input2, input3) 193 .succeeded(); 194 195 if (!reshaped1 && !reshaped2 && !reshaped3) 196 return rewriter.notifyMatchFailure( 197 tosaOp, 198 "cannot rewrite as the rank of all operands is already aligned"); 199 200 int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank(); 201 int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank(); 202 int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank(); 203 int32_t outputRank = outputType.getRank(); 204 205 if ((result1Rank != result2Rank) || (result2Rank != result3Rank) || 206 (result1Rank != outputRank)) 207 return rewriter.notifyMatchFailure( 208 tosaOp, "not all ranks are aligned with each other"); 209 210 rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1, 211 input2, input3); 212 213 return success(); 214 } 215 }; 216 } // namespace 217 218 namespace { 219 /// Pass that enables broadcast by making all input arrays have the same 220 /// number of dimensions. Insert RESHAPE operations to lower rank operand 221 struct TosaMakeBroadcastable 222 : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> { 223 public: 224 void runOnOperation() override { 225 auto func = getOperation(); 226 RewritePatternSet patterns(func.getContext()); 227 MLIRContext *ctx = func.getContext(); 228 // Add the generated patterns to the list. 229 patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx); 230 patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx); 231 patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx); 232 patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx); 233 patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx); 234 patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx); 235 patterns.add<ConvertTosaOp<tosa::IntDivOp>>(ctx); 236 patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx); 237 patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx); 238 patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx); 239 patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx); 240 patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx); 241 patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx); 242 patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx); 243 patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx); 244 patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx); 245 patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx); 246 patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx); 247 patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx); 248 patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx); 249 (void)applyPatternsGreedily(func, std::move(patterns)); 250 } 251 }; 252 } // namespace 253 254 std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() { 255 return std::make_unique<TosaMakeBroadcastable>(); 256 } 257