xref: /llvm-project/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp (revision 7ad63c0e44ef277591497a176991e7723165611e)
1 //===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Utils/IndexingUtils.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "math-to-funcs"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37 
38 namespace {
39 // Pattern to convert vector operations to scalar operations.
40 template <typename Op>
41 struct VecOpToScalarOp : public OpRewritePattern<Op> {
42 public:
43   using OpRewritePattern<Op>::OpRewritePattern;
44 
45   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46 };
47 
48 // Callback type for getting pre-generated FuncOp implementing
49 // an operation of the given type.
50 using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
51 
52 // Pattern to convert scalar IPowIOp into a call of outlined
53 // software implementation.
54 class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
55 public:
56   IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
57       : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
58 
59   /// Convert IPowI into a call to a local function implementing
60   /// the power operation. The local function computes a scalar result,
61   /// so vector forms of IPowI are linearized.
62   LogicalResult matchAndRewrite(math::IPowIOp op,
63                                 PatternRewriter &rewriter) const final;
64 
65 private:
66   GetFuncCallbackTy getFuncOpCallback;
67 };
68 
69 // Pattern to convert scalar FPowIOp into a call of outlined
70 // software implementation.
71 class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
72 public:
73   FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
74       : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
75 
76   /// Convert FPowI into a call to a local function implementing
77   /// the power operation. The local function computes a scalar result,
78   /// so vector forms of FPowI are linearized.
79   LogicalResult matchAndRewrite(math::FPowIOp op,
80                                 PatternRewriter &rewriter) const final;
81 
82 private:
83   GetFuncCallbackTy getFuncOpCallback;
84 };
85 
86 // Pattern to convert scalar ctlz into a call of outlined software
87 // implementation.
88 class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
89 public:
90   CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
91       : OpRewritePattern<math::CountLeadingZerosOp>(context),
92         getFuncOpCallback(cb) {}
93 
94   /// Convert ctlz into a call to a local function implementing
95   /// the count leading zeros operation.
96   LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
97                                 PatternRewriter &rewriter) const final;
98 
99 private:
100   GetFuncCallbackTy getFuncOpCallback;
101 };
102 } // namespace
103 
104 template <typename Op>
105 LogicalResult
106 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
107   Type opType = op.getType();
108   Location loc = op.getLoc();
109   auto vecType = dyn_cast<VectorType>(opType);
110 
111   if (!vecType)
112     return rewriter.notifyMatchFailure(op, "not a vector operation");
113   if (!vecType.hasRank())
114     return rewriter.notifyMatchFailure(op, "unknown vector rank");
115   ArrayRef<int64_t> shape = vecType.getShape();
116   int64_t numElements = vecType.getNumElements();
117 
118   Type resultElementType = vecType.getElementType();
119   Attribute initValueAttr;
120   if (isa<FloatType>(resultElementType))
121     initValueAttr = FloatAttr::get(resultElementType, 0.0);
122   else
123     initValueAttr = IntegerAttr::get(resultElementType, 0);
124   Value result = rewriter.create<arith::ConstantOp>(
125       loc, DenseElementsAttr::get(vecType, initValueAttr));
126   SmallVector<int64_t> strides = computeStrides(shape);
127   for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
128     SmallVector<int64_t> positions = delinearize(linearIndex, strides);
129     SmallVector<Value> operands;
130     for (Value input : op->getOperands())
131       operands.push_back(
132           rewriter.create<vector::ExtractOp>(loc, input, positions));
133     Value scalarOp =
134         rewriter.create<Op>(loc, vecType.getElementType(), operands);
135     result =
136         rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
137   }
138   rewriter.replaceOp(op, result);
139   return success();
140 }
141 
142 static FunctionType getElementalFuncTypeForOp(Operation *op) {
143   SmallVector<Type, 1> resultTys(op->getNumResults());
144   SmallVector<Type, 2> inputTys(op->getNumOperands());
145   std::transform(op->result_type_begin(), op->result_type_end(),
146                  resultTys.begin(),
147                  [](Type ty) { return getElementTypeOrSelf(ty); });
148   std::transform(op->operand_type_begin(), op->operand_type_end(),
149                  inputTys.begin(),
150                  [](Type ty) { return getElementTypeOrSelf(ty); });
151   return FunctionType::get(op->getContext(), inputTys, resultTys);
152 }
153 
154 /// Create linkonce_odr function to implement the power function with
155 /// the given \p elementType type inside \p module. The \p elementType
156 /// must be IntegerType, an the created function has
157 /// 'IntegerType (*)(IntegerType, IntegerType)' function type.
158 ///
159 /// template <typename T>
160 /// T __mlir_math_ipowi_*(T b, T p) {
161 ///   if (p == T(0))
162 ///     return T(1);
163 ///   if (p < T(0)) {
164 ///     if (b == T(0))
165 ///       return T(1) / T(0); // trigger div-by-zero
166 ///     if (b == T(1))
167 ///       return T(1);
168 ///     if (b == T(-1)) {
169 ///       if (p & T(1))
170 ///         return T(-1);
171 ///       return T(1);
172 ///     }
173 ///     return T(0);
174 ///   }
175 ///   T result = T(1);
176 ///   while (true) {
177 ///     if (p & T(1))
178 ///       result *= b;
179 ///     p >>= T(1);
180 ///     if (p == T(0))
181 ///       return result;
182 ///     b *= b;
183 ///   }
184 /// }
185 static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
186   assert(isa<IntegerType>(elementType) &&
187          "non-integer element type for IPowIOp");
188 
189   ImplicitLocOpBuilder builder =
190       ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
191 
192   std::string funcName("__mlir_math_ipowi");
193   llvm::raw_string_ostream nameOS(funcName);
194   nameOS << '_' << elementType;
195 
196   FunctionType funcType = FunctionType::get(
197       builder.getContext(), {elementType, elementType}, elementType);
198   auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
199   LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
200   Attribute linkage =
201       LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
202   funcOp->setAttr("llvm.linkage", linkage);
203   funcOp.setPrivate();
204 
205   Block *entryBlock = funcOp.addEntryBlock();
206   Region *funcBody = entryBlock->getParent();
207 
208   Value bArg = funcOp.getArgument(0);
209   Value pArg = funcOp.getArgument(1);
210   builder.setInsertionPointToEnd(entryBlock);
211   Value zeroValue = builder.create<arith::ConstantOp>(
212       elementType, builder.getIntegerAttr(elementType, 0));
213   Value oneValue = builder.create<arith::ConstantOp>(
214       elementType, builder.getIntegerAttr(elementType, 1));
215   Value minusOneValue = builder.create<arith::ConstantOp>(
216       elementType,
217       builder.getIntegerAttr(elementType,
218                              APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
219                                    /*isSigned=*/true)));
220 
221   // if (p == T(0))
222   //   return T(1);
223   auto pIsZero =
224       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
225   Block *thenBlock = builder.createBlock(funcBody);
226   builder.create<func::ReturnOp>(oneValue);
227   Block *fallthroughBlock = builder.createBlock(funcBody);
228   // Set up conditional branch for (p == T(0)).
229   builder.setInsertionPointToEnd(pIsZero->getBlock());
230   builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
231 
232   // if (p < T(0)) {
233   builder.setInsertionPointToEnd(fallthroughBlock);
234   auto pIsNeg =
235       builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
236   //   if (b == T(0))
237   builder.createBlock(funcBody);
238   auto bIsZero =
239       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
240   //     return T(1) / T(0);
241   thenBlock = builder.createBlock(funcBody);
242   builder.create<func::ReturnOp>(
243       builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
244   fallthroughBlock = builder.createBlock(funcBody);
245   // Set up conditional branch for (b == T(0)).
246   builder.setInsertionPointToEnd(bIsZero->getBlock());
247   builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
248 
249   //   if (b == T(1))
250   builder.setInsertionPointToEnd(fallthroughBlock);
251   auto bIsOne =
252       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
253   //    return T(1);
254   thenBlock = builder.createBlock(funcBody);
255   builder.create<func::ReturnOp>(oneValue);
256   fallthroughBlock = builder.createBlock(funcBody);
257   // Set up conditional branch for (b == T(1)).
258   builder.setInsertionPointToEnd(bIsOne->getBlock());
259   builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
260 
261   //   if (b == T(-1)) {
262   builder.setInsertionPointToEnd(fallthroughBlock);
263   auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
264                                                    bArg, minusOneValue);
265   //     if (p & T(1))
266   builder.createBlock(funcBody);
267   auto pIsOdd = builder.create<arith::CmpIOp>(
268       arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
269       zeroValue);
270   //       return T(-1);
271   thenBlock = builder.createBlock(funcBody);
272   builder.create<func::ReturnOp>(minusOneValue);
273   fallthroughBlock = builder.createBlock(funcBody);
274   // Set up conditional branch for (p & T(1)).
275   builder.setInsertionPointToEnd(pIsOdd->getBlock());
276   builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
277 
278   //     return T(1);
279   //   } // b == T(-1)
280   builder.setInsertionPointToEnd(fallthroughBlock);
281   builder.create<func::ReturnOp>(oneValue);
282   fallthroughBlock = builder.createBlock(funcBody);
283   // Set up conditional branch for (b == T(-1)).
284   builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
285   builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
286                                    fallthroughBlock);
287 
288   //   return T(0);
289   // } // (p < T(0))
290   builder.setInsertionPointToEnd(fallthroughBlock);
291   builder.create<func::ReturnOp>(zeroValue);
292   Block *loopHeader = builder.createBlock(
293       funcBody, funcBody->end(), {elementType, elementType, elementType},
294       {builder.getLoc(), builder.getLoc(), builder.getLoc()});
295   // Set up conditional branch for (p < T(0)).
296   builder.setInsertionPointToEnd(pIsNeg->getBlock());
297   // Set initial values of 'result', 'b' and 'p' for the loop.
298   builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
299                                    ValueRange{oneValue, bArg, pArg});
300 
301   // T result = T(1);
302   // while (true) {
303   //   if (p & T(1))
304   //     result *= b;
305   //   p >>= T(1);
306   //   if (p == T(0))
307   //     return result;
308   //   b *= b;
309   // }
310   Value resultTmp = loopHeader->getArgument(0);
311   Value baseTmp = loopHeader->getArgument(1);
312   Value powerTmp = loopHeader->getArgument(2);
313   builder.setInsertionPointToEnd(loopHeader);
314 
315   //   if (p & T(1))
316   auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
317       arith::CmpIPredicate::ne,
318       builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
319   thenBlock = builder.createBlock(funcBody);
320   //     result *= b;
321   Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
322   fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
323                                          builder.getLoc());
324   builder.setInsertionPointToEnd(thenBlock);
325   builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
326   // Set up conditional branch for (p & T(1)).
327   builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
328   builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
329                                    resultTmp);
330   // Merged 'result'.
331   newResultTmp = fallthroughBlock->getArgument(0);
332 
333   //   p >>= T(1);
334   builder.setInsertionPointToEnd(fallthroughBlock);
335   Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
336 
337   //   if (p == T(0))
338   auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
339                                                       newPowerTmp, zeroValue);
340   //     return result;
341   thenBlock = builder.createBlock(funcBody);
342   builder.create<func::ReturnOp>(newResultTmp);
343   fallthroughBlock = builder.createBlock(funcBody);
344   // Set up conditional branch for (p == T(0)).
345   builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
346   builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
347 
348   //   b *= b;
349   // }
350   builder.setInsertionPointToEnd(fallthroughBlock);
351   Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
352   // Pass new values for 'result', 'b' and 'p' to the loop header.
353   builder.create<cf::BranchOp>(
354       ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
355   return funcOp;
356 }
357 
358 /// Convert IPowI into a call to a local function implementing
359 /// the power operation. The local function computes a scalar result,
360 /// so vector forms of IPowI are linearized.
361 LogicalResult
362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
363                                  PatternRewriter &rewriter) const {
364   auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
365 
366   if (!baseType)
367     return rewriter.notifyMatchFailure(op, "non-integer base operand");
368 
369   // The outlined software implementation must have been already
370   // generated.
371   func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
372   if (!elementFunc)
373     return rewriter.notifyMatchFailure(op, "missing software implementation");
374 
375   rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
376   return success();
377 }
378 
379 /// Create linkonce_odr function to implement the power function with
380 /// the given \p funcType type inside \p module. The \p funcType must be
381 /// 'FloatType (*)(FloatType, IntegerType)' function type.
382 ///
383 /// template <typename T>
384 /// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
385 ///   if (p == Tp{0})
386 ///     return Tb{1};
387 ///   bool isNegativePower{p < Tp{0}}
388 ///   bool isMin{p == std::numeric_limits<Tp>::min()};
389 ///   if (isMin) {
390 ///     p = std::numeric_limits<Tp>::max();
391 ///   } else if (isNegativePower) {
392 ///     p = -p;
393 ///   }
394 ///   Tb result = Tb{1};
395 ///   Tb origBase = Tb{b};
396 ///   while (true) {
397 ///     if (p & Tp{1})
398 ///       result *= b;
399 ///     p >>= Tp{1};
400 ///     if (p == Tp{0})
401 ///       break;
402 ///     b *= b;
403 ///   }
404 ///   if (isMin) {
405 ///     result *= origBase;
406 ///   }
407 ///   if (isNegativePower) {
408 ///     result = Tb{1} / result;
409 ///   }
410 ///   return result;
411 /// }
412 static func::FuncOp createElementFPowIFunc(ModuleOp *module,
413                                            FunctionType funcType) {
414   auto baseType = cast<FloatType>(funcType.getInput(0));
415   auto powType = cast<IntegerType>(funcType.getInput(1));
416   ImplicitLocOpBuilder builder =
417       ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
418 
419   std::string funcName("__mlir_math_fpowi");
420   llvm::raw_string_ostream nameOS(funcName);
421   nameOS << '_' << baseType;
422   nameOS << '_' << powType;
423   auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
424   LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
425   Attribute linkage =
426       LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
427   funcOp->setAttr("llvm.linkage", linkage);
428   funcOp.setPrivate();
429 
430   Block *entryBlock = funcOp.addEntryBlock();
431   Region *funcBody = entryBlock->getParent();
432 
433   Value bArg = funcOp.getArgument(0);
434   Value pArg = funcOp.getArgument(1);
435   builder.setInsertionPointToEnd(entryBlock);
436   Value oneBValue = builder.create<arith::ConstantOp>(
437       baseType, builder.getFloatAttr(baseType, 1.0));
438   Value zeroPValue = builder.create<arith::ConstantOp>(
439       powType, builder.getIntegerAttr(powType, 0));
440   Value onePValue = builder.create<arith::ConstantOp>(
441       powType, builder.getIntegerAttr(powType, 1));
442   Value minPValue = builder.create<arith::ConstantOp>(
443       powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
444                                                    powType.getWidth())));
445   Value maxPValue = builder.create<arith::ConstantOp>(
446       powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
447                                                    powType.getWidth())));
448 
449   // if (p == Tp{0})
450   //   return Tb{1};
451   auto pIsZero =
452       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
453   Block *thenBlock = builder.createBlock(funcBody);
454   builder.create<func::ReturnOp>(oneBValue);
455   Block *fallthroughBlock = builder.createBlock(funcBody);
456   // Set up conditional branch for (p == Tp{0}).
457   builder.setInsertionPointToEnd(pIsZero->getBlock());
458   builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
459 
460   builder.setInsertionPointToEnd(fallthroughBlock);
461   // bool isNegativePower{p < Tp{0}}
462   auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
463                                               zeroPValue);
464   // bool isMin{p == std::numeric_limits<Tp>::min()};
465   auto pIsMin =
466       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
467 
468   // if (isMin) {
469   //   p = std::numeric_limits<Tp>::max();
470   // } else if (isNegativePower) {
471   //   p = -p;
472   // }
473   Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
474   auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
475   pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);
476 
477   // Tb result = Tb{1};
478   // Tb origBase = Tb{b};
479   // while (true) {
480   //   if (p & Tp{1})
481   //     result *= b;
482   //   p >>= Tp{1};
483   //   if (p == Tp{0})
484   //     break;
485   //   b *= b;
486   // }
487   Block *loopHeader = builder.createBlock(
488       funcBody, funcBody->end(), {baseType, baseType, powType},
489       {builder.getLoc(), builder.getLoc(), builder.getLoc()});
490   // Set initial values of 'result', 'b' and 'p' for the loop.
491   builder.setInsertionPointToEnd(pInit->getBlock());
492   builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});
493 
494   // Create loop body.
495   Value resultTmp = loopHeader->getArgument(0);
496   Value baseTmp = loopHeader->getArgument(1);
497   Value powerTmp = loopHeader->getArgument(2);
498   builder.setInsertionPointToEnd(loopHeader);
499 
500   //   if (p & Tp{1})
501   auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
502       arith::CmpIPredicate::ne,
503       builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
504   thenBlock = builder.createBlock(funcBody);
505   //     result *= b;
506   Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
507   fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
508                                          builder.getLoc());
509   builder.setInsertionPointToEnd(thenBlock);
510   builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
511   // Set up conditional branch for (p & Tp{1}).
512   builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
513   builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
514                                    resultTmp);
515   // Merged 'result'.
516   newResultTmp = fallthroughBlock->getArgument(0);
517 
518   //   p >>= Tp{1};
519   builder.setInsertionPointToEnd(fallthroughBlock);
520   Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);
521 
522   //   if (p == Tp{0})
523   auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
524                                                       newPowerTmp, zeroPValue);
525   //     break;
526   //
527   // The conditional branch is finalized below with a jump to
528   // the loop exit block.
529   fallthroughBlock = builder.createBlock(funcBody);
530 
531   //   b *= b;
532   // }
533   builder.setInsertionPointToEnd(fallthroughBlock);
534   Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
535   // Pass new values for 'result', 'b' and 'p' to the loop header.
536   builder.create<cf::BranchOp>(
537       ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
538 
539   // Set up conditional branch for early loop exit:
540   //   if (p == Tp{0})
541   //     break;
542   Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
543                                         builder.getLoc());
544   builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
545   builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
546                                    fallthroughBlock, ValueRange{});
547 
548   // if (isMin) {
549   //   result *= origBase;
550   // }
551   newResultTmp = loopExit->getArgument(0);
552   thenBlock = builder.createBlock(funcBody);
553   fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
554                                          builder.getLoc());
555   builder.setInsertionPointToEnd(loopExit);
556   builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
557                                    newResultTmp);
558   builder.setInsertionPointToEnd(thenBlock);
559   newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
560   builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
561 
562   /// if (isNegativePower) {
563   ///   result = Tb{1} / result;
564   /// }
565   newResultTmp = fallthroughBlock->getArgument(0);
566   thenBlock = builder.createBlock(funcBody);
567   Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
568                                            builder.getLoc());
569   builder.setInsertionPointToEnd(fallthroughBlock);
570   builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
571                                    newResultTmp);
572   builder.setInsertionPointToEnd(thenBlock);
573   newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
574   builder.create<cf::BranchOp>(newResultTmp, returnBlock);
575 
576   // return result;
577   builder.setInsertionPointToEnd(returnBlock);
578   builder.create<func::ReturnOp>(returnBlock->getArgument(0));
579 
580   return funcOp;
581 }
582 
583 /// Convert FPowI into a call to a local function implementing
584 /// the power operation. The local function computes a scalar result,
585 /// so vector forms of FPowI are linearized.
586 LogicalResult
587 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
588                                  PatternRewriter &rewriter) const {
589   if (dyn_cast<VectorType>(op.getType()))
590     return rewriter.notifyMatchFailure(op, "non-scalar operation");
591 
592   FunctionType funcType = getElementalFuncTypeForOp(op);
593 
594   // The outlined software implementation must have been already
595   // generated.
596   func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
597   if (!elementFunc)
598     return rewriter.notifyMatchFailure(op, "missing software implementation");
599 
600   rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
601   return success();
602 }
603 
604 /// Create function to implement the ctlz function the given \p elementType type
605 /// inside \p module. The \p elementType must be IntegerType, an the created
606 /// function has 'IntegerType (*)(IntegerType)' function type.
607 ///
608 /// template <typename T>
609 /// T __mlir_math_ctlz_*(T x) {
610 ///     bits = sizeof(x) * 8;
611 ///     if (x == 0)
612 ///       return bits;
613 ///
614 ///     uint32_t n = 0;
615 ///     for (int i = 1; i < bits; ++i) {
616 ///         if (x < 0) continue;
617 ///         n++;
618 ///         x <<= 1;
619 ///     }
620 ///     return n;
621 /// }
622 ///
623 /// Converts to (for i32):
624 ///
625 /// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
626 ///   %c_32 = arith.constant 32 : index
627 ///   %c_0 = arith.constant 0 : i32
628 ///   %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
629 ///   %out = scf.if %arg_eq_zero {
630 ///     scf.yield %c_32 : i32
631 ///   } else {
632 ///     %c_1index = arith.constant 1 : index
633 ///     %c_1i32 = arith.constant 1 : i32
634 ///     %n = arith.constant 0 : i32
635 ///     %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
636 ///         iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
637 ///       %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
638 ///       %yield_val = scf.if %cond {
639 ///         scf.yield %arg_iter, %n_iter : i32, i32
640 ///       } else {
641 ///         %arg_next = arith.shli %arg_iter, %c_1i32 : i32
642 ///         %n_next = arith.addi %n_iter, %c_1i32 : i32
643 ///         scf.yield %arg_next, %n_next : i32, i32
644 ///       }
645 ///       scf.yield %yield_val: i32, i32
646 ///     }
647 ///     scf.yield %n_out : i32
648 ///   }
649 ///   return %out: i32
650 /// }
651 static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
652   if (!isa<IntegerType>(elementType)) {
653     LLVM_DEBUG({
654       DBGS() << "non-integer element type for CtlzFunc; type was: ";
655       elementType.print(llvm::dbgs());
656     });
657     llvm_unreachable("non-integer element type");
658   }
659   int64_t bitWidth = elementType.getIntOrFloatBitWidth();
660 
661   Location loc = module->getLoc();
662   ImplicitLocOpBuilder builder =
663       ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody());
664 
665   std::string funcName("__mlir_math_ctlz");
666   llvm::raw_string_ostream nameOS(funcName);
667   nameOS << '_' << elementType;
668   FunctionType funcType =
669       FunctionType::get(builder.getContext(), {elementType}, elementType);
670   auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
671 
672   // LinkonceODR ensures that there is only one implementation of this function
673   // across all math.ctlz functions that are lowered in this way.
674   LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
675   Attribute linkage =
676       LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
677   funcOp->setAttr("llvm.linkage", linkage);
678   funcOp.setPrivate();
679 
680   // set the insertion point to the start of the function
681   Block *funcBody = funcOp.addEntryBlock();
682   builder.setInsertionPointToStart(funcBody);
683 
684   Value arg = funcOp.getArgument(0);
685   Type indexType = builder.getIndexType();
686   Value bitWidthValue = builder.create<arith::ConstantOp>(
687       elementType, builder.getIntegerAttr(elementType, bitWidth));
688   Value zeroValue = builder.create<arith::ConstantOp>(
689       elementType, builder.getIntegerAttr(elementType, 0));
690 
691   Value inputEqZero =
692       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
693 
694   // if input == 0, return bit width, else enter loop.
695   scf::IfOp ifOp = builder.create<scf::IfOp>(
696       elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true);
697   ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
698 
699   auto elseBuilder =
700       ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
701 
702   Value oneIndex = elseBuilder.create<arith::ConstantOp>(
703       indexType, elseBuilder.getIndexAttr(1));
704   Value oneValue = elseBuilder.create<arith::ConstantOp>(
705       elementType, elseBuilder.getIntegerAttr(elementType, 1));
706   Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
707       indexType, elseBuilder.getIndexAttr(bitWidth));
708   Value nValue = elseBuilder.create<arith::ConstantOp>(
709       elementType, elseBuilder.getIntegerAttr(elementType, 0));
710 
711   auto loop = elseBuilder.create<scf::ForOp>(
712       oneIndex, bitWidthIndex, oneIndex,
713       // Initial values for two loop induction variables, the arg which is being
714       // shifted left in each iteration, and the n value which tracks the count
715       // of leading zeros.
716       ValueRange{arg, nValue},
717       // Callback to build the body of the for loop
718       //   if (arg < 0) {
719       //     continue;
720       //   } else {
721       //     n++;
722       //     arg <<= 1;
723       //   }
724       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
725         Value argIter = args[0];
726         Value nIter = args[1];
727 
728         Value argIsNonNegative = b.create<arith::CmpIOp>(
729             loc, arith::CmpIPredicate::slt, argIter, zeroValue);
730         scf::IfOp ifOp = b.create<scf::IfOp>(
731             loc, argIsNonNegative,
732             [&](OpBuilder &b, Location loc) {
733               // If arg is negative, continue (effectively, break)
734               b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter});
735             },
736             [&](OpBuilder &b, Location loc) {
737               // Otherwise, increment n and shift arg left.
738               Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
739               Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
740               b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext});
741             });
742         b.create<scf::YieldOp>(loc, ifOp.getResults());
743       });
744   elseBuilder.create<scf::YieldOp>(loop.getResult(1));
745 
746   builder.create<func::ReturnOp>(ifOp.getResult(0));
747   return funcOp;
748 }
749 
750 /// Convert ctlz into a call to a local function implementing the ctlz
751 /// operation.
752 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
753                                               PatternRewriter &rewriter) const {
754   if (dyn_cast<VectorType>(op.getType()))
755     return rewriter.notifyMatchFailure(op, "non-scalar operation");
756 
757   Type type = getElementTypeOrSelf(op.getResult().getType());
758   func::FuncOp elementFunc = getFuncOpCallback(op, type);
759   if (!elementFunc)
760     return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
761       diag << "Missing software implementation for op " << op->getName()
762            << " and type " << type;
763     });
764 
765   rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand());
766   return success();
767 }
768 
769 namespace {
770 struct ConvertMathToFuncsPass
771     : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
772   ConvertMathToFuncsPass() = default;
773   ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
774       : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
775 
776   void runOnOperation() override;
777 
778 private:
779   // Return true, if this FPowI operation must be converted
780   // because the width of its exponent's type is greater than
781   // or equal to minWidthOfFPowIExponent option value.
782   bool isFPowIConvertible(math::FPowIOp op);
783 
784   // Reture true, if operation is integer type.
785   bool isConvertible(Operation *op);
786 
787   // Generate outlined implementations for power operations
788   // and store them in funcImpls map.
789   void generateOpImplementations();
790 
791   // A map between pairs of (operation, type) deduced from operations that this
792   // pass will convert, and the corresponding outlined software implementations
793   // of these operations for the given type.
794   DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
795 };
796 } // namespace
797 
798 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
799   auto expTy =
800       dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
801   return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
802 }
803 
804 bool ConvertMathToFuncsPass::isConvertible(Operation *op) {
805   return isa<IntegerType>(getElementTypeOrSelf(op->getResult(0).getType()));
806 }
807 
808 void ConvertMathToFuncsPass::generateOpImplementations() {
809   ModuleOp module = getOperation();
810 
811   module.walk([&](Operation *op) {
812     TypeSwitch<Operation *>(op)
813         .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
814           if (!convertCtlz || !isConvertible(op))
815             return;
816           Type resultType = getElementTypeOrSelf(op.getResult().getType());
817 
818           // Generate the software implementation of this operation,
819           // if it has not been generated yet.
820           auto key = std::pair(op->getName(), resultType);
821           auto entry = funcImpls.try_emplace(key, func::FuncOp{});
822           if (entry.second)
823             entry.first->second = createCtlzFunc(&module, resultType);
824         })
825         .Case<math::IPowIOp>([&](math::IPowIOp op) {
826           if (!isConvertible(op))
827             return;
828 
829           Type resultType = getElementTypeOrSelf(op.getResult().getType());
830 
831           // Generate the software implementation of this operation,
832           // if it has not been generated yet.
833           auto key = std::pair(op->getName(), resultType);
834           auto entry = funcImpls.try_emplace(key, func::FuncOp{});
835           if (entry.second)
836             entry.first->second = createElementIPowIFunc(&module, resultType);
837         })
838         .Case<math::FPowIOp>([&](math::FPowIOp op) {
839           if (!isFPowIConvertible(op))
840             return;
841 
842           FunctionType funcType = getElementalFuncTypeForOp(op);
843 
844           // Generate the software implementation of this operation,
845           // if it has not been generated yet.
846           // FPowI implementations are mapped via the FunctionType
847           // created from the operation's result and operands.
848           auto key = std::pair(op->getName(), funcType);
849           auto entry = funcImpls.try_emplace(key, func::FuncOp{});
850           if (entry.second)
851             entry.first->second = createElementFPowIFunc(&module, funcType);
852         });
853   });
854 }
855 
856 void ConvertMathToFuncsPass::runOnOperation() {
857   ModuleOp module = getOperation();
858 
859   // Create outlined implementations for power operations.
860   generateOpImplementations();
861 
862   RewritePatternSet patterns(&getContext());
863   patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
864                VecOpToScalarOp<math::CountLeadingZerosOp>>(
865       patterns.getContext());
866 
867   // For the given Type Returns FuncOp stored in funcImpls map.
868   auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
869     auto it = funcImpls.find(std::pair(op->getName(), type));
870     if (it == funcImpls.end())
871       return {};
872 
873     return it->second;
874   };
875   patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
876                                                  getFuncOpByType);
877 
878   if (convertCtlz)
879     patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
880 
881   ConversionTarget target(getContext());
882   target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
883                          func::FuncDialect, scf::SCFDialect,
884                          vector::VectorDialect>();
885 
886   target.addDynamicallyLegalOp<math::IPowIOp>(
887       [this](math::IPowIOp op) { return !isConvertible(op); });
888   if (convertCtlz) {
889     target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
890         [this](math::CountLeadingZerosOp op) { return !isConvertible(op); });
891   }
892   target.addDynamicallyLegalOp<math::FPowIOp>(
893       [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
894   if (failed(applyPartialConversion(module, target, std::move(patterns))))
895     signalPassFailure();
896 }
897