//===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace arith {
#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace arith
} // namespace mlir

using namespace mlir;

/// Create an integer or index constant. For shaped types, a splat constant
/// with the given element value is created.
static Value createConst(Location loc, Type type, int value,
                         PatternRewriter &rewriter) {
  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
    return rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(shapedTy, attr));
  }

  return rewriter.create<arith::ConstantOp>(loc, attr);
}

namespace {

/// Expands CeilDivUIOp (n, m) into
///   n == 0 ? 0 : ((n-1) / m) + 1
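/// For example (an illustrative case, not taken from the tests): with n = 7
/// and m = 2, n != 0, so the result is ((7-1)/2)+1 = 4 = ceil(7/2). With
/// n = 0 the select picks the constant 0, so the value computed from the
/// wrapped (0-1) does not affect the result.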
struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value a = op.getLhs();
    Value b = op.getRhs();
    Value zero = createConst(loc, a.getType(), 0, rewriter);
    Value compare =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
    Value one = createConst(loc, a.getType(), 1, rewriter);
    Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
    Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
    Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
    return success();
  }
};

/// Expands CeilDivSIOp (n, m) into
///   1) x = (m > 0) ? -1 : 1
///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
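/// For example (illustrative values only): with n = 7, m = 2, x = -1 and
/// n*m > 0, so the result is ((7-1)/2)+1 = 4 = ceil(3.5). With n = -7, m = 2,
/// n*m < 0, so the result is -((7)/2) = -3 = ceil(-3.5).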
struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Type type = op.getType();
    Value a = op.getLhs();
    Value b = op.getRhs();
    Value plusOne = createConst(loc, type, 1, rewriter);
    Value zero = createConst(loc, type, 0, rewriter);
    Value minusOne = createConst(loc, type, -1, rewriter);
    // Compute x = (b>0) ? -1 : 1.
    Value compare =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
    Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
    // Compute positive res: 1 + ((x+a)/b).
    Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
    Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
    Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
    // Compute negative res: - ((-a)/b).
    Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
    Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
    Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
    // Result is (a*b>0) ? pos result : neg result.
    // Note: we want to avoid computing a*b directly because of possible
    // overflow. The cases that matter are a>0, a==0, a<0 combined with b>0
    // and b<0. We do not particularly care whether `a*b>0` is true or false
    // when b is zero, as that results in an illegal divide anyway. So `a*b>0`
    // can be reformulated as `(a<0 && b<0) || (a>0 && b>0)` or
    // `(a<0 && b<0) || (a>0 && b>=0)`. We pick the first expression here.
    Value aNeg =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
    Value aPos =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
    Value bNeg =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
    Value bPos =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
    Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
    Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
    Value compareRes =
        rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
    // Perform substitution and return success.
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
                                                 negRes);
    return success();
  }
};

/// Expands FloorDivSIOp (x, y) into
///   z = x / y
///   if (z * y != x && (x < 0) != (y < 0)) {
///     return z - 1;
///   } else {
///     return z;
///   }
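/// For example (illustrative values only): with x = -7, y = 2, z = -7/2 = -3
/// (signed division truncates toward zero), z*y = -6 != -7 and the signs
/// differ, so the result is z - 1 = -4 = floor(-3.5). With x = 7, y = 2 the
/// signs match and the result is simply z = 3.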
struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Type type = op.getType();
    Value a = op.getLhs();
    Value b = op.getRhs();

    Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
    Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
    Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne, a, product);
    Value zero = createConst(loc, type, 0, rewriter);

    Value aNeg =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
    Value bNeg =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);

    Value signOpposite = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne, aNeg, bNeg);
    Value cond =
        rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);

    Value minusOne = createConst(loc, type, -1, rewriter);
    Value quotientMinusOne =
        rewriter.create<arith::AddIOp>(loc, quotient, minusOne);

    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
                                                 quotient);
    return success();
  }
};

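/// Expands the signed/unsigned integer max/min ops (arith.maxsi, arith.maxui,
/// arith.minsi, arith.minui) into a compare-and-select:
///   max/min(lhs, rhs)  ->  select(cmpi <pred>(lhs, rhs), lhs, rhs)
/// where <pred> is the predicate supplied as a template argument.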
template <typename OpTy, arith::CmpIPredicate pred>
struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const final {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
    return success();
  }
};

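/// Expands arith.maximumf / arith.minimumf into compare-and-select sequences
/// that preserve the NaN-propagating semantics of those ops: if either
/// operand is NaN, the result is NaN.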
template <typename OpTy, arith::CmpFPredicate pred>
struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const final {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    Location loc = op.getLoc();
    // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
    static_assert(pred == arith::CmpFPredicate::UGT ||
                      pred == arith::CmpFPredicate::ULT,
                  "pred must be either UGT or ULT");
    Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
    Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);

    // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
    Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
                                                 rhs, rhs);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
    return success();
  }
};

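/// Expands arith.maxnumf / arith.minnumf into compare-and-select sequences
/// with "number-preferring" semantics: if exactly one operand is NaN, the
/// other (non-NaN) operand is returned.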
template <typename OpTy, arith::CmpFPredicate pred>
struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const final {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    Location loc = op.getLoc();
    // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
    static_assert(pred == arith::CmpFPredicate::UGT ||
                      pred == arith::CmpFPredicate::ULT,
                  "pred must be either UGT or ULT");
    Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
    Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);

    // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
    Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
                                                 lhs, lhs);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
    return success();
  }
};

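/// Expands arith.extf from bf16 to f32. Since bf16 is essentially the upper
/// 16 bits of an f32, the extension reinterprets the bf16 bits as an i16,
/// zero-extends to i32, shifts left by 16, and bitcasts the result to f32.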
struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);

    if (!operandETy.isBF16() || !resultETy.isF32()) {
      return rewriter.notifyMatchFailure(op, "not an ext of bf16 to f32.");
    }

    Type i16Ty = b.getI16Type();
    Type i32Ty = b.getI32Type();
    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
      i16Ty = shapedTy.clone(i16Ty);
      i32Ty = shapedTy.clone(i32Ty);
    }

    Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);

    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
    Value shl = b.create<arith::ShLIOp>(exti, c16);
    Value result = b.create<arith::BitcastOp>(resultTy, shl);

    rewriter.replaceOp(op, result);
    return success();
  }
};

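/// Expands arith.truncf from f32 to bf16 (default rounding mode only) by
/// operating on the underlying bits: the f32 value is rounded to the nearest
/// bf16 value, with ties broken to even, and NaN inputs are mapped to a quiet
/// NaN.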
struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);

    if (!operandETy.isF32() || !resultETy.isBF16()) {
      return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
    }

    if (op.getRoundingmodeAttr()) {
      return rewriter.notifyMatchFailure(
          op, "only applicable to default rounding mode.");
    }

    Type i16Ty = b.getI16Type();
    Type i32Ty = b.getI32Type();
    Type f32Ty = b.getF32Type();
    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
      i16Ty = shapedTy.clone(i16Ty);
      i32Ty = shapedTy.clone(i32Ty);
      f32Ty = shapedTy.clone(f32Ty);
    }

    // Algorithm borrowed from this excellent code:
    // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
    // The key idea is to let the addition of the rounding bias to the
    // mantissa simply overflow into the exponent bits. It's a somewhat
    // aggressive, obfuscating optimization, but it is well-tested code, and
    // it results in more concise and efficient IR.
    // The case of NaN is handled separately (see isNan and the final select).
    // The case of infinities is NOT handled separately, which deserves an
    // explanation. As the encoding of infinities has a zero mantissa, the
    // rounding-bias addition never carries into the exponent, so the low bits
    // just get truncated away, and since bfloat16 and float32 have the same
    // number of exponent bits, that simple truncation is the desired outcome
    // for infinities.
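    // As an illustrative example (not taken from the tests): the f32 value
    // 0x3F808000 (1.00390625) has bit 16 equal to 0, so the bias is 0x7fff;
    // 0x3F808000 + 0x7fff = 0x3F80FFFF, and shifting right by 16 yields the
    // bf16 value 0x3F80 (1.0), i.e. the halfway case rounds to even.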
    Value isNan =
        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
    // Constant used to make the rounding bias.
    Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
    // Constant used to generate a quiet NaN.
    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
    // Small constants used to address bits.
    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
    Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
    // Reinterpret the input f32 value as bits.
    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
    // Read bit 16 as a value in {0,1}.
    Value bit16 =
        b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
    // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
    // on bit 16, implementing the tie-breaking "to nearest even".
    Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
    // Add the rounding bias. Generally we want it added to the mantissa, but
    // nothing prevents it from carrying into the exponent bits. That might
    // look like a bug, but it is the magic trick here: when that happens, the
    // mantissa gets reset to zero and the exponent gets incremented by the
    // carry, which is exactly what we want.
    Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
    // Now that the rounding bias has been added, truncating the low bits
    // yields the correctly rounded result.
    Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
    Value normalCaseResult_i16 =
        b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
    // Select either the above-computed result, or a quiet NaN constant
    // if the input was NaN.
    Value select =
        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
    Value result = b.create<arith::BitcastOp>(resultTy, select);
    rewriter.replaceOp(op, result);
    return success();
  }
};

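/// Pass that legalizes arith ops for LLVM lowering by expanding the ops
/// marked illegal below into sequences of simpler arith ops. Expansion of the
/// bf16 <-> f32 extf/truncf conversions is only enabled when the
/// `includeBf16` option is set.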
struct ArithExpandOpsPass
    : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    ConversionTarget target(getContext());

    arith::populateArithExpandOpsPatterns(patterns);

    target.addLegalDialect<arith::ArithDialect>();
    // clang-format off
    target.addIllegalOp<
      arith::CeilDivSIOp,
      arith::CeilDivUIOp,
      arith::FloorDivSIOp,
      arith::MaxSIOp,
      arith::MaxUIOp,
      arith::MinSIOp,
      arith::MinUIOp,
      arith::MaximumFOp,
      arith::MinimumFOp,
      arith::MaxNumFOp,
      arith::MinNumFOp
    >();

    if (includeBf16) {
      arith::populateExpandBFloat16Patterns(patterns);
      target.addDynamicallyLegalOp<arith::ExtFOp>(
        [](arith::ExtFOp op) {
          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
          Type outETy = getElementTypeOrSelf(op.getType());
          return !(inETy.isBF16() && outETy.isF32());
        });

      target.addDynamicallyLegalOp<arith::TruncFOp>(
        [](arith::TruncFOp op) {
          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
          Type outETy = getElementTypeOrSelf(op.getType());
          return !(inETy.isF32() && outETy.isBF16());
        });
    }

    // clang-format on
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

void mlir::arith::populateCeilFloorDivExpandOpsPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
          patterns.getContext());
}

void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
      patterns.getContext());
}

void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
  populateCeilFloorDivExpandOpsPatterns(patterns);
  // clang-format off
  patterns.add<
    MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
    MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
    MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
    MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
    MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
    MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
    MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
  >(patterns.getContext());
  // clang-format on
}