xref: /llvm-project/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (revision 30badf96bbaa5ddfd8049442e573fd270a89ddc8)
1 //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/IR/ImplicitLocOpBuilder.h"
14 #include "mlir/IR/TypeUtilities.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 namespace mlir {
18 namespace arith {
19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21 } // namespace arith
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 /// Create an integer or index constant.
createConst(Location loc,Type type,int value,PatternRewriter & rewriter)27 static Value createConst(Location loc, Type type, int value,
28                          PatternRewriter &rewriter) {
29   auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31     return rewriter.create<arith::ConstantOp>(
32         loc, DenseElementsAttr::get(shapedTy, attr));
33   }
34 
35   return rewriter.create<arith::ConstantOp>(loc, attr);
36 }
37 
38 namespace {
39 
40 /// Expands CeilDivUIOp (n, m) into
41 ///  n == 0 ? 0 : ((n-1) / m) + 1
42 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
43   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anon521e2a780111::CeilDivUIOpConverter44   LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
45                                 PatternRewriter &rewriter) const final {
46     Location loc = op.getLoc();
47     Value a = op.getLhs();
48     Value b = op.getRhs();
49     Value zero = createConst(loc, a.getType(), 0, rewriter);
50     Value compare =
51         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
52     Value one = createConst(loc, a.getType(), 1, rewriter);
53     Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54     Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55     Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
56     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
57     return success();
58   }
59 };
60 
61 /// Expands CeilDivSIOp (n, m) into
62 ///   1) x = (m > 0) ? -1 : 1
63 ///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
64 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
65   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anon521e2a780111::CeilDivSIOpConverter66   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
67                                 PatternRewriter &rewriter) const final {
68     Location loc = op.getLoc();
69     Type type = op.getType();
70     Value a = op.getLhs();
71     Value b = op.getRhs();
72     Value plusOne = createConst(loc, type, 1, rewriter);
73     Value zero = createConst(loc, type, 0, rewriter);
74     Value minusOne = createConst(loc, type, -1, rewriter);
75     // Compute x = (b>0) ? -1 : 1.
76     Value compare =
77         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
78     Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
79     // Compute positive res: 1 + ((x+a)/b).
80     Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
81     Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
82     Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
83     // Compute negative res: - ((-a)/b).
84     Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85     Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86     Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
87     // Result is (a*b>0) ? pos result : neg result.
88     // Note, we want to avoid using a*b because of possible overflow.
89     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
90     // not particuliarly care if a*b<0 is true or false when b is zero
91     // as this will result in an illegal divide. So `a*b<0` can be reformulated
92     // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
93     // We pick the first expression here.
94     Value aNeg =
95         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
96     Value aPos =
97         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
98     Value bNeg =
99         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
100     Value bPos =
101         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
102     Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
103     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
104     Value compareRes =
105         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
106     // Perform substitution and return success.
107     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
108                                                  negRes);
109     return success();
110   }
111 };
112 
113 /// Expands FloorDivSIOp (x, y) into
114 /// z = x / y
115 /// if (z * y != x && (x < 0) != (y < 0)) {
116 ///   return  z - 1;
117 /// } else {
118 ///   return z;
119 /// }
120 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
121   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anon521e2a780111::FloorDivSIOpConverter122   LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
123                                 PatternRewriter &rewriter) const final {
124     Location loc = op.getLoc();
125     Type type = op.getType();
126     Value a = op.getLhs();
127     Value b = op.getRhs();
128 
129     Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
130     Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
131     Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
132         loc, arith::CmpIPredicate::ne, a, product);
133     Value zero = createConst(loc, type, 0, rewriter);
134 
135     Value aNeg =
136         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
137     Value bNeg =
138         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
139 
140     Value signOpposite = rewriter.create<arith::CmpIOp>(
141         loc, arith::CmpIPredicate::ne, aNeg, bNeg);
142     Value cond =
143         rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
144 
145     Value minusOne = createConst(loc, type, -1, rewriter);
146     Value quotientMinusOne =
147         rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
148 
149     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
150                                                  quotient);
151     return success();
152   }
153 };
154 
155 template <typename OpTy, arith::CmpIPredicate pred>
156 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
157 public:
158   using OpRewritePattern<OpTy>::OpRewritePattern;
159 
matchAndRewrite__anon521e2a780111::MaxMinIOpConverter160   LogicalResult matchAndRewrite(OpTy op,
161                                 PatternRewriter &rewriter) const final {
162     Value lhs = op.getLhs();
163     Value rhs = op.getRhs();
164 
165     Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
166     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
167     return success();
168   }
169 };
170 
171 template <typename OpTy, arith::CmpFPredicate pred>
172 struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
173 public:
174   using OpRewritePattern<OpTy>::OpRewritePattern;
175 
matchAndRewrite__anon521e2a780111::MaximumMinimumFOpConverter176   LogicalResult matchAndRewrite(OpTy op,
177                                 PatternRewriter &rewriter) const final {
178     Value lhs = op.getLhs();
179     Value rhs = op.getRhs();
180 
181     Location loc = op.getLoc();
182     // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
183     static_assert(pred == arith::CmpFPredicate::UGT ||
184                       pred == arith::CmpFPredicate::ULT,
185                   "pred must be either UGT or ULT");
186     Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
187     Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
188 
189     // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
190     Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
191                                                  rhs, rhs);
192     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
193     return success();
194   }
195 };
196 
197 template <typename OpTy, arith::CmpFPredicate pred>
198 struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
199 public:
200   using OpRewritePattern<OpTy>::OpRewritePattern;
201 
matchAndRewrite__anon521e2a780111::MaxNumMinNumFOpConverter202   LogicalResult matchAndRewrite(OpTy op,
203                                 PatternRewriter &rewriter) const final {
204     Value lhs = op.getLhs();
205     Value rhs = op.getRhs();
206 
207     Location loc = op.getLoc();
208     // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
209     static_assert(pred == arith::CmpFPredicate::UGT ||
210                       pred == arith::CmpFPredicate::ULT,
211                   "pred must be either UGT or ULT");
212     Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
213     Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
214 
215     // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
216     Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
217                                                  lhs, lhs);
218     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
219     return success();
220   }
221 };
222 
223 struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
224   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anon521e2a780111::BFloat16ExtFOpConverter225   LogicalResult matchAndRewrite(arith::ExtFOp op,
226                                 PatternRewriter &rewriter) const final {
227     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
228     auto operand = op.getOperand();
229     Type operandTy = operand.getType();
230     Type resultTy = op.getType();
231     Type operandETy = getElementTypeOrSelf(operandTy);
232     Type resultETy = getElementTypeOrSelf(resultTy);
233 
234     if (!operandETy.isBF16() || !resultETy.isF32()) {
235       return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
236     }
237 
238     Type i16Ty = b.getI16Type();
239     Type i32Ty = b.getI32Type();
240     if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
241       i16Ty = shapedTy.clone(i16Ty);
242       i32Ty = shapedTy.clone(i32Ty);
243     }
244 
245     Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
246     Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
247 
248     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
249     Value shl = b.create<arith::ShLIOp>(exti, c16);
250     Value result = b.create<arith::BitcastOp>(resultTy, shl);
251 
252     rewriter.replaceOp(op, result);
253     return success();
254   }
255 };
256 
257 struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
258   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anon521e2a780111::BFloat16TruncFOpConverter259   LogicalResult matchAndRewrite(arith::TruncFOp op,
260                                 PatternRewriter &rewriter) const final {
261     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
262     auto operand = op.getOperand();
263     Type operandTy = operand.getType();
264     Type resultTy = op.getType();
265     Type operandETy = getElementTypeOrSelf(operandTy);
266     Type resultETy = getElementTypeOrSelf(resultTy);
267 
268     if (!operandETy.isF32() || !resultETy.isBF16()) {
269       return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
270     }
271 
272     if (op.getRoundingmodeAttr()) {
273       return rewriter.notifyMatchFailure(
274           op, "only applicable to default rounding mode.");
275     }
276 
277     Type i16Ty = b.getI16Type();
278     Type i32Ty = b.getI32Type();
279     Type f32Ty = b.getF32Type();
280     if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
281       i16Ty = shapedTy.clone(i16Ty);
282       i32Ty = shapedTy.clone(i32Ty);
283       f32Ty = shapedTy.clone(f32Ty);
284     }
285 
286     // Algorithm borrowed from this excellent code:
287     // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
288     // There is a magic idea there, to let the addition of the rounding_bias to
289     // the mantissa simply overflow into the exponent bits. It's a bit of an
290     // aggressive, obfuscating optimization, but it is well-tested code, and it
291     // results in more concise and efficient IR.
292     // The case of NaN is handled separately (see isNaN and the final select).
293     // The case of infinities is NOT handled separately, which deserves an
294     // explanation. As the encoding of infinities has zero mantissa, the
295     // rounding-bias addition never carries into the exponent so that just gets
296     // truncated away, and as bfloat16 and float32 have the same number of
297     // exponent bits, that simple truncation is the desired outcome for
298     // infinities.
299     Value isNan =
300         b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
301     // Constant used to make the rounding bias.
302     Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
303     // Constant used to generate a quiet NaN.
304     Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
305     // Small constants used to address bits.
306     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
307     Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
308     // Reinterpret the input f32 value as bits.
309     Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
310     // Read bit 16 as a value in {0,1}.
311     Value bit16 =
312         b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
313     // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
314     // on bit 16, implementing the tie-breaking "to nearest even".
315     Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
316     // Add the rounding bias. Generally we want this to be added to the
317     // mantissa, but nothing prevents this to from carrying into the exponent
318     // bits, which would feel like a bug, but this is the magic trick here:
319     // when that happens, the mantissa gets reset to zero and the exponent
320     // gets incremented by the carry... which is actually exactly what we
321     // want.
322     Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
323     // Now that the rounding-bias has been added, truncating the low bits
324     // yields the correctly rounded result.
325     Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
326     Value normalCaseResult_i16 =
327         b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
328     // Select either the above-computed result, or a quiet NaN constant
329     // if the input was NaN.
330     Value select =
331         b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
332     Value result = b.create<arith::BitcastOp>(resultTy, select);
333     rewriter.replaceOp(op, result);
334     return success();
335   }
336 };
337 
338 struct ArithExpandOpsPass
339     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
340   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
341 
runOnOperation__anon521e2a780111::ArithExpandOpsPass342   void runOnOperation() override {
343     RewritePatternSet patterns(&getContext());
344     ConversionTarget target(getContext());
345 
346     arith::populateArithExpandOpsPatterns(patterns);
347 
348     target.addLegalDialect<arith::ArithDialect>();
349     // clang-format off
350     target.addIllegalOp<
351       arith::CeilDivSIOp,
352       arith::CeilDivUIOp,
353       arith::FloorDivSIOp,
354       arith::MaxSIOp,
355       arith::MaxUIOp,
356       arith::MinSIOp,
357       arith::MinUIOp,
358       arith::MaximumFOp,
359       arith::MinimumFOp,
360       arith::MaxNumFOp,
361       arith::MinNumFOp
362     >();
363 
364     if (includeBf16) {
365       arith::populateExpandBFloat16Patterns(patterns);
366       target.addDynamicallyLegalOp<arith::ExtFOp>(
367         [](arith::ExtFOp op) {
368           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
369           Type outETy = getElementTypeOrSelf(op.getType());
370           return !(inETy.isBF16() && outETy.isF32());
371         });
372 
373       target.addDynamicallyLegalOp<arith::TruncFOp>(
374         [](arith::TruncFOp op)  {
375           Type inETy = getElementTypeOrSelf(op.getOperand().getType());
376           Type outETy = getElementTypeOrSelf(op.getType());
377           return !(inETy.isF32() && outETy.isBF16());
378         });
379     }
380 
381     // clang-format on
382     if (failed(applyPartialConversion(getOperation(), target,
383                                       std::move(patterns))))
384       signalPassFailure();
385   }
386 };
387 
388 } // namespace
389 
populateCeilFloorDivExpandOpsPatterns(RewritePatternSet & patterns)390 void mlir::arith::populateCeilFloorDivExpandOpsPatterns(
391     RewritePatternSet &patterns) {
392   patterns
393       .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
394           patterns.getContext());
395 }
396 
populateExpandBFloat16Patterns(RewritePatternSet & patterns)397 void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
398   patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
399       patterns.getContext());
400 }
401 
populateArithExpandOpsPatterns(RewritePatternSet & patterns)402 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
403   populateCeilFloorDivExpandOpsPatterns(patterns);
404   // clang-format off
405   patterns.add<
406     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
407     MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
408     MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
409     MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
410     MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
411     MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
412     MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
413     MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
414    >(patterns.getContext());
415   // clang-format on
416 }
417