xref: /llvm-project/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (revision e692af85966903614d470a7742ed89d124baf1a6)
1 //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Arith dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToArith/TosaToArith.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 using namespace tosa;
22 
23 namespace {
24 
25 class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
26 public:
27   using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
28 
29   LogicalResult matchAndRewrite(tosa::ConstOp op,
30                                 PatternRewriter &rewriter) const final {
31     rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
32     return success();
33   }
34 };
35 
36 Type matchContainerType(Type element, Type container) {
37   if (auto shapedTy = dyn_cast<ShapedType>(container))
38     return shapedTy.clone(element);
39 
40   return element;
41 }
42 
43 TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
44   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
45     Type eTy = shapedTy.getElementType();
46     APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
47     return DenseIntElementsAttr::get(shapedTy, valueInt);
48   }
49 
50   return rewriter.getIntegerAttr(type, value);
51 }
52 
53 Value getConstantValue(Location loc, Type type, int64_t value,
54                        PatternRewriter &rewriter) {
55   return rewriter.create<arith::ConstantOp>(
56       loc, getConstantAttr(type, value, rewriter));
57 }
58 
59 // This converts the TOSA ApplyScale operator to a set of arithmetic ops,
60 // using 64-bit operations to perform the necessary multiply, bias, and shift.
61 class ApplyScaleGenericOpConverter
62     : public OpRewritePattern<tosa::ApplyScaleOp> {
63 public:
64   using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
65 
66   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
67                                 PatternRewriter &rewriter) const final {
68     Location loc = op.getLoc();
69     Value value = op.getValue();
70     Value multiplier32 = op.getMultiplier();
71 
72     Type resultTy = op.getType();
73     Type valueTy = value.getType();
74     Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
75     Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
76 
77     Value zero = getConstantValue(loc, valueTy, 0, rewriter);
78     Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
79     Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);
80 
81     Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
82 
83     // Compute the multiplication in 64-bits then select the high / low parts.
84     Value value64 = value;
85     if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
86       value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
87     Value multiplier64 =
88         rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
89     Value multiply64 =
90         rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
91 
92     // Apply normal rounding.
93     Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
94     Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
95     round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
96     multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
97 
98     // Apply double rounding if necessary.
99     if (op.getDoubleRound()) {
100       int64_t roundInt = 1 << 30;
101       Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
102       Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
103       Value positive = rewriter.create<arith::CmpIOp>(
104           loc, arith::CmpIPredicate::sge, value, zero);
105       Value dir =
106           rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
107       Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
108       Value valid = rewriter.create<arith::CmpIOp>(
109           loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
110       multiply64 =
111           rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
112     }
113 
114     Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
115     Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
116 
117     rewriter.replaceOp(op, result32);
118     return success();
119   }
120 };
121 
122 class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
123 public:
124   using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
125 
126   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
127                                 PatternRewriter &rewriter) const final {
128     Location loc = op.getLoc();
129 
130     Type resultTy = op.getType();
131     Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
132 
133     Value value = op.getValue();
134     if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
135       return failure();
136     }
137 
138     Value value32 = op.getValue();
139     Value multiplier32 = op.getMultiplier();
140     Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
141 
142     // Constants used during the scaling operation.
143     Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);
144     Value one32 = getConstantValue(loc, i32Ty, 1, rewriter);
145     Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
146     Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
147     Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
148 
149     // Compute the multiplication in 64-bits then select the high / low parts.
150     // Grab out the high/low of the computation
151     auto value64 =
152         rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
153     Value low32 = value64.getLow();
154     Value high32 = value64.getHigh();
155 
156     // Determine the direction and amount to shift the high bits.
157     Value shiftOver32 = rewriter.create<arith::CmpIOp>(
158         loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
159     Value roundHighBits = rewriter.create<arith::CmpIOp>(
160         loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
161 
162     Value shiftHighL =
163         rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
164     Value shiftHighR =
165         rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
166 
167     shiftHighL =
168         rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
169     shiftHighR =
170         rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
171 
172     // Conditionally perform our double round.
173     if (op.getDoubleRound()) {
174       Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
175       Value valuePositive = rewriter.create<arith::CmpIOp>(
176           loc, arith::CmpIPredicate::sge, value32, zero32);
177 
178       Value roundDir =
179           rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
180       roundDir =
181           rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
182 
183       Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
184       Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
185       Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
186 
187       Value shiftRound =
188           rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
189 
190       low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
191       high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
192     }
193 
194     // Conditionally apply rounding in the low bits.
195     {
196       Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
197       Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
198       roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
199                                                   roundBit);
200 
201       Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
202       Value wasRounded = rewriter.create<arith::CmpIOp>(
203           loc, arith::CmpIPredicate::ugt, low32, newLow32);
204       low32 = newLow32;
205 
206       Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
207       high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
208     }
209 
210     // Conditionally apply rounding in the high bits.
211     {
212       Value shiftSubOne =
213           rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
214       Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
215       roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
216                                                   zero32);
217       high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
218     }
219 
220     // Combine the correct high/low bits into the final rescale result.
221     high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
222     high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
223     low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
224     low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
225 
226     // Apply the rounding behavior and shift to the final alignment.
227     Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
228 
229     // Truncate if necessary.
230     if (!getElementTypeOrSelf(resultTy).isInteger(32)) {
231       result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
232     }
233 
234     rewriter.replaceOp(op, result);
235     return success();
236   }
237 };
238 
239 } // namespace
240 
241 void mlir::tosa::populateTosaToArithConversionPatterns(
242     RewritePatternSet *patterns) {
243   patterns->add<ConstOpConverter>(patterns->getContext());
244 }
245 
246 void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
247     RewritePatternSet *patterns, bool include32Bit) {
248   patterns->add<ApplyScaleGenericOpConverter>(patterns->getContext(), 100);
249   if (include32Bit) {
250     patterns->add<ApplyScale32BitOpConverter>(patterns->getContext(), 200);
251   }
252 }
253