1 //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // These rewriters lower from the Tosa to the Arith dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/TosaToArith/TosaToArith.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/IR/TypeUtilities.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 20 using namespace mlir; 21 using namespace tosa; 22 23 namespace { 24 25 class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> { 26 public: 27 using OpRewritePattern<tosa::ConstOp>::OpRewritePattern; 28 29 LogicalResult matchAndRewrite(tosa::ConstOp op, 30 PatternRewriter &rewriter) const final { 31 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue()); 32 return success(); 33 } 34 }; 35 36 Type matchContainerType(Type element, Type container) { 37 if (auto shapedTy = dyn_cast<ShapedType>(container)) 38 return shapedTy.clone(element); 39 40 return element; 41 } 42 43 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { 44 if (auto shapedTy = dyn_cast<ShapedType>(type)) { 45 Type eTy = shapedTy.getElementType(); 46 APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true); 47 return DenseIntElementsAttr::get(shapedTy, valueInt); 48 } 49 50 return rewriter.getIntegerAttr(type, value); 51 } 52 53 Value getConstantValue(Location loc, Type type, int64_t value, 54 PatternRewriter &rewriter) { 55 return rewriter.create<arith::ConstantOp>( 56 loc, getConstantAttr(type, value, rewriter)); 57 } 58 59 // This converts the TOSA ApplyScale operator to a set of arithmetic ops, 60 // using 64-bit operations to perform the necessary multiply, bias, and shift. 61 class ApplyScaleGenericOpConverter 62 : public OpRewritePattern<tosa::ApplyScaleOp> { 63 public: 64 using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern; 65 66 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, 67 PatternRewriter &rewriter) const final { 68 Location loc = op.getLoc(); 69 Value value = op.getValue(); 70 Value multiplier32 = op.getMultiplier(); 71 72 Type resultTy = op.getType(); 73 Type valueTy = value.getType(); 74 Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); 75 Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy); 76 77 Value zero = getConstantValue(loc, valueTy, 0, rewriter); 78 Value one64 = getConstantValue(loc, i64Ty, 1, rewriter); 79 Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter); 80 81 Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); 82 83 // Compute the multiplication in 64-bits then select the high / low parts. 84 Value value64 = value; 85 if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type()) 86 value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value); 87 Value multiplier64 = 88 rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32); 89 Value multiply64 = 90 rewriter.create<arith::MulIOp>(loc, value64, multiplier64); 91 92 // Apply normal rounding. 93 Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32); 94 Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64); 95 round = rewriter.create<arith::ShRUIOp>(loc, round, one64); 96 multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round); 97 98 // Apply double rounding if necessary. 99 if (op.getDoubleRound()) { 100 int64_t roundInt = 1 << 30; 101 Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter); 102 Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter); 103 Value positive = rewriter.create<arith::CmpIOp>( 104 loc, arith::CmpIPredicate::sge, value, zero); 105 Value dir = 106 rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown); 107 Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64); 108 Value valid = rewriter.create<arith::CmpIOp>( 109 loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); 110 multiply64 = 111 rewriter.create<arith::SelectOp>(loc, valid, val, multiply64); 112 } 113 114 Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64); 115 Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64); 116 117 rewriter.replaceOp(op, result32); 118 return success(); 119 } 120 }; 121 122 class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> { 123 public: 124 using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern; 125 126 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, 127 PatternRewriter &rewriter) const final { 128 Location loc = op.getLoc(); 129 130 Type resultTy = op.getType(); 131 Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); 132 133 Value value = op.getValue(); 134 if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) { 135 return failure(); 136 } 137 138 Value value32 = op.getValue(); 139 Value multiplier32 = op.getMultiplier(); 140 Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); 141 142 // Constants used during the scaling operation. 143 Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter); 144 Value one32 = getConstantValue(loc, i32Ty, 1, rewriter); 145 Value two32 = getConstantValue(loc, i32Ty, 2, rewriter); 146 Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter); 147 Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter); 148 149 // Compute the multiplication in 64-bits then select the high / low parts. 150 // Grab out the high/low of the computation 151 auto value64 = 152 rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32); 153 Value low32 = value64.getLow(); 154 Value high32 = value64.getHigh(); 155 156 // Determine the direction and amount to shift the high bits. 157 Value shiftOver32 = rewriter.create<arith::CmpIOp>( 158 loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); 159 Value roundHighBits = rewriter.create<arith::CmpIOp>( 160 loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); 161 162 Value shiftHighL = 163 rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32); 164 Value shiftHighR = 165 rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32); 166 167 shiftHighL = 168 rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL); 169 shiftHighR = 170 rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32); 171 172 // Conditionally perform our double round. 173 if (op.getDoubleRound()) { 174 Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter); 175 Value valuePositive = rewriter.create<arith::CmpIOp>( 176 loc, arith::CmpIPredicate::sge, value32, zero32); 177 178 Value roundDir = 179 rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32); 180 roundDir = 181 rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32); 182 183 Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32); 184 Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir); 185 Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32); 186 187 Value shiftRound = 188 rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32); 189 190 low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound); 191 high32 = rewriter.create<arith::AddIOp>(loc, high32, carry); 192 } 193 194 // Conditionally apply rounding in the low bits. 195 { 196 Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32); 197 Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); 198 roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32, 199 roundBit); 200 201 Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit); 202 Value wasRounded = rewriter.create<arith::CmpIOp>( 203 loc, arith::CmpIPredicate::ugt, low32, newLow32); 204 low32 = newLow32; 205 206 Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded); 207 high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32); 208 } 209 210 // Conditionally apply rounding in the high bits. 211 { 212 Value shiftSubOne = 213 rewriter.create<arith::SubIOp>(loc, shiftHighR, one32); 214 Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); 215 roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit, 216 zero32); 217 high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit); 218 } 219 220 // Combine the correct high/low bits into the final rescale result. 221 high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL); 222 high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR); 223 low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32); 224 low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32); 225 226 // Apply the rounding behavior and shift to the final alignment. 227 Value result = rewriter.create<arith::AddIOp>(loc, low32, high32); 228 229 // Truncate if necessary. 230 if (!getElementTypeOrSelf(resultTy).isInteger(32)) { 231 result = rewriter.create<arith::TruncIOp>(loc, resultTy, result); 232 } 233 234 rewriter.replaceOp(op, result); 235 return success(); 236 } 237 }; 238 239 } // namespace 240 241 void mlir::tosa::populateTosaToArithConversionPatterns( 242 RewritePatternSet *patterns) { 243 patterns->add<ConstOpConverter>(patterns->getContext()); 244 } 245 246 void mlir::tosa::populateTosaRescaleToArithConversionPatterns( 247 RewritePatternSet *patterns, bool include32Bit) { 248 patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100); 249 if (include32Bit) { 250 patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200); 251 } 252 } 253