xref: /llvm-project/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp (revision 7ad63c0e44ef277591497a176991e7723165611e)
1223c54c4SSlava Zakharin //===- MathToFuncs.cpp - Math to outlined implementation conversion -------===//
2223c54c4SSlava Zakharin //
3223c54c4SSlava Zakharin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4223c54c4SSlava Zakharin // See https://llvm.org/LICENSE.txt for license information.
5223c54c4SSlava Zakharin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6223c54c4SSlava Zakharin //
7223c54c4SSlava Zakharin //===----------------------------------------------------------------------===//
8223c54c4SSlava Zakharin 
9223c54c4SSlava Zakharin #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
1067d0d7acSMichele Scuttari 
11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
12223c54c4SSlava Zakharin #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
13223c54c4SSlava Zakharin #include "mlir/Dialect/Func/IR/FuncOps.h"
14223c54c4SSlava Zakharin #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15223c54c4SSlava Zakharin #include "mlir/Dialect/Math/IR/Math.h"
16bfbccfa1SJeremy Kun #include "mlir/Dialect/SCF/IR/SCF.h"
17223c54c4SSlava Zakharin #include "mlir/Dialect/Utils/IndexingUtils.h"
18223c54c4SSlava Zakharin #include "mlir/Dialect/Vector/IR/VectorOps.h"
19223c54c4SSlava Zakharin #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
20223c54c4SSlava Zakharin #include "mlir/IR/ImplicitLocOpBuilder.h"
21223c54c4SSlava Zakharin #include "mlir/IR/TypeUtilities.h"
2267d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
23223c54c4SSlava Zakharin #include "mlir/Transforms/DialectConversion.h"
24223c54c4SSlava Zakharin #include "llvm/ADT/DenseMap.h"
25223c54c4SSlava Zakharin #include "llvm/ADT/TypeSwitch.h"
26bfbccfa1SJeremy Kun #include "llvm/Support/Debug.h"
27223c54c4SSlava Zakharin 
2867d0d7acSMichele Scuttari namespace mlir {
2967d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTMATHTOFUNCS
3067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
3167d0d7acSMichele Scuttari } // namespace mlir
3267d0d7acSMichele Scuttari 
33223c54c4SSlava Zakharin using namespace mlir;
34223c54c4SSlava Zakharin 
35bfbccfa1SJeremy Kun #define DEBUG_TYPE "math-to-funcs"
36bfbccfa1SJeremy Kun #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37bfbccfa1SJeremy Kun 
38223c54c4SSlava Zakharin namespace {
39223c54c4SSlava Zakharin // Pattern to convert vector operations to scalar operations.
40223c54c4SSlava Zakharin template <typename Op>
41223c54c4SSlava Zakharin struct VecOpToScalarOp : public OpRewritePattern<Op> {
42223c54c4SSlava Zakharin public:
43223c54c4SSlava Zakharin   using OpRewritePattern<Op>::OpRewritePattern;
44223c54c4SSlava Zakharin 
45223c54c4SSlava Zakharin   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
46223c54c4SSlava Zakharin };
47223c54c4SSlava Zakharin 
48223c54c4SSlava Zakharin // Callback type for getting pre-generated FuncOp implementing
49bfbccfa1SJeremy Kun // an operation of the given type.
50bfbccfa1SJeremy Kun using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>;
51223c54c4SSlava Zakharin 
52223c54c4SSlava Zakharin // Pattern to convert scalar IPowIOp into a call of outlined
53223c54c4SSlava Zakharin // software implementation.
5422702cc7SSlava Zakharin class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> {
55223c54c4SSlava Zakharin public:
56bfbccfa1SJeremy Kun   IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
57223c54c4SSlava Zakharin       : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {}
58223c54c4SSlava Zakharin 
59223c54c4SSlava Zakharin   /// Convert IPowI into a call to a local function implementing
60223c54c4SSlava Zakharin   /// the power operation. The local function computes a scalar result,
61223c54c4SSlava Zakharin   /// so vector forms of IPowI are linearized.
62223c54c4SSlava Zakharin   LogicalResult matchAndRewrite(math::IPowIOp op,
63223c54c4SSlava Zakharin                                 PatternRewriter &rewriter) const final;
6422702cc7SSlava Zakharin 
6522702cc7SSlava Zakharin private:
66bfbccfa1SJeremy Kun   GetFuncCallbackTy getFuncOpCallback;
6722702cc7SSlava Zakharin };
6822702cc7SSlava Zakharin 
6922702cc7SSlava Zakharin // Pattern to convert scalar FPowIOp into a call of outlined
7022702cc7SSlava Zakharin // software implementation.
7122702cc7SSlava Zakharin class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> {
7222702cc7SSlava Zakharin public:
73bfbccfa1SJeremy Kun   FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
7422702cc7SSlava Zakharin       : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {}
7522702cc7SSlava Zakharin 
7622702cc7SSlava Zakharin   /// Convert FPowI into a call to a local function implementing
7722702cc7SSlava Zakharin   /// the power operation. The local function computes a scalar result,
7822702cc7SSlava Zakharin   /// so vector forms of FPowI are linearized.
7922702cc7SSlava Zakharin   LogicalResult matchAndRewrite(math::FPowIOp op,
8022702cc7SSlava Zakharin                                 PatternRewriter &rewriter) const final;
8122702cc7SSlava Zakharin 
8222702cc7SSlava Zakharin private:
83bfbccfa1SJeremy Kun   GetFuncCallbackTy getFuncOpCallback;
84bfbccfa1SJeremy Kun };
85bfbccfa1SJeremy Kun 
86bfbccfa1SJeremy Kun // Pattern to convert scalar ctlz into a call of outlined software
87bfbccfa1SJeremy Kun // implementation.
88bfbccfa1SJeremy Kun class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> {
89bfbccfa1SJeremy Kun public:
90bfbccfa1SJeremy Kun   CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb)
91bfbccfa1SJeremy Kun       : OpRewritePattern<math::CountLeadingZerosOp>(context),
92bfbccfa1SJeremy Kun         getFuncOpCallback(cb) {}
93bfbccfa1SJeremy Kun 
94bfbccfa1SJeremy Kun   /// Convert ctlz into a call to a local function implementing
95bfbccfa1SJeremy Kun   /// the count leading zeros operation.
96bfbccfa1SJeremy Kun   LogicalResult matchAndRewrite(math::CountLeadingZerosOp op,
97bfbccfa1SJeremy Kun                                 PatternRewriter &rewriter) const final;
98bfbccfa1SJeremy Kun 
99bfbccfa1SJeremy Kun private:
100bfbccfa1SJeremy Kun   GetFuncCallbackTy getFuncOpCallback;
101223c54c4SSlava Zakharin };
102223c54c4SSlava Zakharin } // namespace
103223c54c4SSlava Zakharin 
104223c54c4SSlava Zakharin template <typename Op>
105223c54c4SSlava Zakharin LogicalResult
106223c54c4SSlava Zakharin VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
107223c54c4SSlava Zakharin   Type opType = op.getType();
108223c54c4SSlava Zakharin   Location loc = op.getLoc();
1095550c821STres Popp   auto vecType = dyn_cast<VectorType>(opType);
110223c54c4SSlava Zakharin 
111223c54c4SSlava Zakharin   if (!vecType)
112223c54c4SSlava Zakharin     return rewriter.notifyMatchFailure(op, "not a vector operation");
113223c54c4SSlava Zakharin   if (!vecType.hasRank())
114223c54c4SSlava Zakharin     return rewriter.notifyMatchFailure(op, "unknown vector rank");
115223c54c4SSlava Zakharin   ArrayRef<int64_t> shape = vecType.getShape();
116223c54c4SSlava Zakharin   int64_t numElements = vecType.getNumElements();
117223c54c4SSlava Zakharin 
11822702cc7SSlava Zakharin   Type resultElementType = vecType.getElementType();
11922702cc7SSlava Zakharin   Attribute initValueAttr;
1205550c821STres Popp   if (isa<FloatType>(resultElementType))
12122702cc7SSlava Zakharin     initValueAttr = FloatAttr::get(resultElementType, 0.0);
12222702cc7SSlava Zakharin   else
12322702cc7SSlava Zakharin     initValueAttr = IntegerAttr::get(resultElementType, 0);
124223c54c4SSlava Zakharin   Value result = rewriter.create<arith::ConstantOp>(
12522702cc7SSlava Zakharin       loc, DenseElementsAttr::get(vecType, initValueAttr));
1267a69a9d7SNicolas Vasilache   SmallVector<int64_t> strides = computeStrides(shape);
127223c54c4SSlava Zakharin   for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
128203fad47SNicolas Vasilache     SmallVector<int64_t> positions = delinearize(linearIndex, strides);
129223c54c4SSlava Zakharin     SmallVector<Value> operands;
130223c54c4SSlava Zakharin     for (Value input : op->getOperands())
131223c54c4SSlava Zakharin       operands.push_back(
132223c54c4SSlava Zakharin           rewriter.create<vector::ExtractOp>(loc, input, positions));
133223c54c4SSlava Zakharin     Value scalarOp =
134223c54c4SSlava Zakharin         rewriter.create<Op>(loc, vecType.getElementType(), operands);
135223c54c4SSlava Zakharin     result =
136223c54c4SSlava Zakharin         rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
137223c54c4SSlava Zakharin   }
138223c54c4SSlava Zakharin   rewriter.replaceOp(op, result);
139223c54c4SSlava Zakharin   return success();
140223c54c4SSlava Zakharin }
141223c54c4SSlava Zakharin 
14222702cc7SSlava Zakharin static FunctionType getElementalFuncTypeForOp(Operation *op) {
14322702cc7SSlava Zakharin   SmallVector<Type, 1> resultTys(op->getNumResults());
14422702cc7SSlava Zakharin   SmallVector<Type, 2> inputTys(op->getNumOperands());
14522702cc7SSlava Zakharin   std::transform(op->result_type_begin(), op->result_type_end(),
14622702cc7SSlava Zakharin                  resultTys.begin(),
14722702cc7SSlava Zakharin                  [](Type ty) { return getElementTypeOrSelf(ty); });
14822702cc7SSlava Zakharin   std::transform(op->operand_type_begin(), op->operand_type_end(),
14922702cc7SSlava Zakharin                  inputTys.begin(),
15022702cc7SSlava Zakharin                  [](Type ty) { return getElementTypeOrSelf(ty); });
15122702cc7SSlava Zakharin   return FunctionType::get(op->getContext(), inputTys, resultTys);
15222702cc7SSlava Zakharin }
15322702cc7SSlava Zakharin 
154223c54c4SSlava Zakharin /// Create linkonce_odr function to implement the power function with
15522702cc7SSlava Zakharin /// the given \p elementType type inside \p module. The \p elementType
15622702cc7SSlava Zakharin /// must be IntegerType, an the created function has
157223c54c4SSlava Zakharin /// 'IntegerType (*)(IntegerType, IntegerType)' function type.
158223c54c4SSlava Zakharin ///
159223c54c4SSlava Zakharin /// template <typename T>
160223c54c4SSlava Zakharin /// T __mlir_math_ipowi_*(T b, T p) {
161223c54c4SSlava Zakharin ///   if (p == T(0))
162223c54c4SSlava Zakharin ///     return T(1);
163223c54c4SSlava Zakharin ///   if (p < T(0)) {
164223c54c4SSlava Zakharin ///     if (b == T(0))
165223c54c4SSlava Zakharin ///       return T(1) / T(0); // trigger div-by-zero
166223c54c4SSlava Zakharin ///     if (b == T(1))
167223c54c4SSlava Zakharin ///       return T(1);
168223c54c4SSlava Zakharin ///     if (b == T(-1)) {
169223c54c4SSlava Zakharin ///       if (p & T(1))
170223c54c4SSlava Zakharin ///         return T(-1);
171223c54c4SSlava Zakharin ///       return T(1);
172223c54c4SSlava Zakharin ///     }
173223c54c4SSlava Zakharin ///     return T(0);
174223c54c4SSlava Zakharin ///   }
175223c54c4SSlava Zakharin ///   T result = T(1);
176223c54c4SSlava Zakharin ///   while (true) {
177223c54c4SSlava Zakharin ///     if (p & T(1))
178223c54c4SSlava Zakharin ///       result *= b;
179223c54c4SSlava Zakharin ///     p >>= T(1);
180223c54c4SSlava Zakharin ///     if (p == T(0))
181223c54c4SSlava Zakharin ///       return result;
182223c54c4SSlava Zakharin ///     b *= b;
183223c54c4SSlava Zakharin ///   }
184223c54c4SSlava Zakharin /// }
185223c54c4SSlava Zakharin static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
1865550c821STres Popp   assert(isa<IntegerType>(elementType) &&
187223c54c4SSlava Zakharin          "non-integer element type for IPowIOp");
188223c54c4SSlava Zakharin 
189223c54c4SSlava Zakharin   ImplicitLocOpBuilder builder =
190223c54c4SSlava Zakharin       ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
191223c54c4SSlava Zakharin 
192223c54c4SSlava Zakharin   std::string funcName("__mlir_math_ipowi");
193223c54c4SSlava Zakharin   llvm::raw_string_ostream nameOS(funcName);
194223c54c4SSlava Zakharin   nameOS << '_' << elementType;
195223c54c4SSlava Zakharin 
196223c54c4SSlava Zakharin   FunctionType funcType = FunctionType::get(
197223c54c4SSlava Zakharin       builder.getContext(), {elementType, elementType}, elementType);
198223c54c4SSlava Zakharin   auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
199223c54c4SSlava Zakharin   LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
200223c54c4SSlava Zakharin   Attribute linkage =
201223c54c4SSlava Zakharin       LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
202223c54c4SSlava Zakharin   funcOp->setAttr("llvm.linkage", linkage);
203223c54c4SSlava Zakharin   funcOp.setPrivate();
204223c54c4SSlava Zakharin 
205223c54c4SSlava Zakharin   Block *entryBlock = funcOp.addEntryBlock();
206223c54c4SSlava Zakharin   Region *funcBody = entryBlock->getParent();
207223c54c4SSlava Zakharin 
208223c54c4SSlava Zakharin   Value bArg = funcOp.getArgument(0);
209223c54c4SSlava Zakharin   Value pArg = funcOp.getArgument(1);
210223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(entryBlock);
211223c54c4SSlava Zakharin   Value zeroValue = builder.create<arith::ConstantOp>(
212223c54c4SSlava Zakharin       elementType, builder.getIntegerAttr(elementType, 0));
213223c54c4SSlava Zakharin   Value oneValue = builder.create<arith::ConstantOp>(
214223c54c4SSlava Zakharin       elementType, builder.getIntegerAttr(elementType, 1));
215223c54c4SSlava Zakharin   Value minusOneValue = builder.create<arith::ConstantOp>(
216223c54c4SSlava Zakharin       elementType,
217223c54c4SSlava Zakharin       builder.getIntegerAttr(elementType,
218223c54c4SSlava Zakharin                              APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
219223c54c4SSlava Zakharin                                    /*isSigned=*/true)));
220223c54c4SSlava Zakharin 
221223c54c4SSlava Zakharin   // if (p == T(0))
222223c54c4SSlava Zakharin   //   return T(1);
223223c54c4SSlava Zakharin   auto pIsZero =
224223c54c4SSlava Zakharin       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
225223c54c4SSlava Zakharin   Block *thenBlock = builder.createBlock(funcBody);
226223c54c4SSlava Zakharin   builder.create<func::ReturnOp>(oneValue);
227223c54c4SSlava Zakharin   Block *fallthroughBlock = builder.createBlock(funcBody);
228223c54c4SSlava Zakharin   // Set up conditional branch for (p == T(0)).
229223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(pIsZero->getBlock());
230223c54c4SSlava Zakharin   builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
231223c54c4SSlava Zakharin 
232223c54c4SSlava Zakharin   // if (p < T(0)) {
233223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
234223c54c4SSlava Zakharin   auto pIsNeg =
235223c54c4SSlava Zakharin       builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
236223c54c4SSlava Zakharin   //   if (b == T(0))
237223c54c4SSlava Zakharin   builder.createBlock(funcBody);
238223c54c4SSlava Zakharin   auto bIsZero =
239223c54c4SSlava Zakharin       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
240223c54c4SSlava Zakharin   //     return T(1) / T(0);
241223c54c4SSlava Zakharin   thenBlock = builder.createBlock(funcBody);
242223c54c4SSlava Zakharin   builder.create<func::ReturnOp>(
243223c54c4SSlava Zakharin       builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
244223c54c4SSlava Zakharin   fallthroughBlock = builder.createBlock(funcBody);
245223c54c4SSlava Zakharin   // Set up conditional branch for (b == T(0)).
246223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(bIsZero->getBlock());
247223c54c4SSlava Zakharin   builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);
248223c54c4SSlava Zakharin 
249223c54c4SSlava Zakharin   //   if (b == T(1))
250223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
251223c54c4SSlava Zakharin   auto bIsOne =
252223c54c4SSlava Zakharin       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
253223c54c4SSlava Zakharin   //    return T(1);
254223c54c4SSlava Zakharin   thenBlock = builder.createBlock(funcBody);
255223c54c4SSlava Zakharin   builder.create<func::ReturnOp>(oneValue);
256223c54c4SSlava Zakharin   fallthroughBlock = builder.createBlock(funcBody);
257223c54c4SSlava Zakharin   // Set up conditional branch for (b == T(1)).
258223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(bIsOne->getBlock());
259223c54c4SSlava Zakharin   builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);
260223c54c4SSlava Zakharin 
261223c54c4SSlava Zakharin   //   if (b == T(-1)) {
262223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
263223c54c4SSlava Zakharin   auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
264223c54c4SSlava Zakharin                                                    bArg, minusOneValue);
265223c54c4SSlava Zakharin   //     if (p & T(1))
266223c54c4SSlava Zakharin   builder.createBlock(funcBody);
267223c54c4SSlava Zakharin   auto pIsOdd = builder.create<arith::CmpIOp>(
268223c54c4SSlava Zakharin       arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
269223c54c4SSlava Zakharin       zeroValue);
270223c54c4SSlava Zakharin   //       return T(-1);
271223c54c4SSlava Zakharin   thenBlock = builder.createBlock(funcBody);
272223c54c4SSlava Zakharin   builder.create<func::ReturnOp>(minusOneValue);
273223c54c4SSlava Zakharin   fallthroughBlock = builder.createBlock(funcBody);
274223c54c4SSlava Zakharin   // Set up conditional branch for (p & T(1)).
275223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(pIsOdd->getBlock());
276223c54c4SSlava Zakharin   builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);
277223c54c4SSlava Zakharin 
278223c54c4SSlava Zakharin   //     return T(1);
279223c54c4SSlava Zakharin   //   } // b == T(-1)
280223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
281223c54c4SSlava Zakharin   builder.create<func::ReturnOp>(oneValue);
282223c54c4SSlava Zakharin   fallthroughBlock = builder.createBlock(funcBody);
283223c54c4SSlava Zakharin   // Set up conditional branch for (b == T(-1)).
284223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
285223c54c4SSlava Zakharin   builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
286223c54c4SSlava Zakharin                                    fallthroughBlock);
287223c54c4SSlava Zakharin 
288223c54c4SSlava Zakharin   //   return T(0);
289223c54c4SSlava Zakharin   // } // (p < T(0))
290223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
291223c54c4SSlava Zakharin   builder.create<func::ReturnOp>(zeroValue);
292223c54c4SSlava Zakharin   Block *loopHeader = builder.createBlock(
293223c54c4SSlava Zakharin       funcBody, funcBody->end(), {elementType, elementType, elementType},
294223c54c4SSlava Zakharin       {builder.getLoc(), builder.getLoc(), builder.getLoc()});
295223c54c4SSlava Zakharin   // Set up conditional branch for (p < T(0)).
296223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(pIsNeg->getBlock());
297223c54c4SSlava Zakharin   // Set initial values of 'result', 'b' and 'p' for the loop.
298223c54c4SSlava Zakharin   builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
299223c54c4SSlava Zakharin                                    ValueRange{oneValue, bArg, pArg});
300223c54c4SSlava Zakharin 
301223c54c4SSlava Zakharin   // T result = T(1);
302223c54c4SSlava Zakharin   // while (true) {
303223c54c4SSlava Zakharin   //   if (p & T(1))
304223c54c4SSlava Zakharin   //     result *= b;
305223c54c4SSlava Zakharin   //   p >>= T(1);
306223c54c4SSlava Zakharin   //   if (p == T(0))
307223c54c4SSlava Zakharin   //     return result;
308223c54c4SSlava Zakharin   //   b *= b;
309223c54c4SSlava Zakharin   // }
310223c54c4SSlava Zakharin   Value resultTmp = loopHeader->getArgument(0);
311223c54c4SSlava Zakharin   Value baseTmp = loopHeader->getArgument(1);
312223c54c4SSlava Zakharin   Value powerTmp = loopHeader->getArgument(2);
313223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(loopHeader);
314223c54c4SSlava Zakharin 
315223c54c4SSlava Zakharin   //   if (p & T(1))
316223c54c4SSlava Zakharin   auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
317223c54c4SSlava Zakharin       arith::CmpIPredicate::ne,
318223c54c4SSlava Zakharin       builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
319223c54c4SSlava Zakharin   thenBlock = builder.createBlock(funcBody);
320223c54c4SSlava Zakharin   //     result *= b;
321223c54c4SSlava Zakharin   Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
322223c54c4SSlava Zakharin   fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
323223c54c4SSlava Zakharin                                          builder.getLoc());
324223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(thenBlock);
325223c54c4SSlava Zakharin   builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
326223c54c4SSlava Zakharin   // Set up conditional branch for (p & T(1)).
327223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
328223c54c4SSlava Zakharin   builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
329223c54c4SSlava Zakharin                                    resultTmp);
330223c54c4SSlava Zakharin   // Merged 'result'.
331223c54c4SSlava Zakharin   newResultTmp = fallthroughBlock->getArgument(0);
332223c54c4SSlava Zakharin 
333223c54c4SSlava Zakharin   //   p >>= T(1);
334223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
335223c54c4SSlava Zakharin   Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);
336223c54c4SSlava Zakharin 
337223c54c4SSlava Zakharin   //   if (p == T(0))
338223c54c4SSlava Zakharin   auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
339223c54c4SSlava Zakharin                                                       newPowerTmp, zeroValue);
340223c54c4SSlava Zakharin   //     return result;
341223c54c4SSlava Zakharin   thenBlock = builder.createBlock(funcBody);
342223c54c4SSlava Zakharin   builder.create<func::ReturnOp>(newResultTmp);
343223c54c4SSlava Zakharin   fallthroughBlock = builder.createBlock(funcBody);
344223c54c4SSlava Zakharin   // Set up conditional branch for (p == T(0)).
345223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
346223c54c4SSlava Zakharin   builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);
347223c54c4SSlava Zakharin 
348223c54c4SSlava Zakharin   //   b *= b;
349223c54c4SSlava Zakharin   // }
350223c54c4SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
351223c54c4SSlava Zakharin   Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
352223c54c4SSlava Zakharin   // Pass new values for 'result', 'b' and 'p' to the loop header.
353223c54c4SSlava Zakharin   builder.create<cf::BranchOp>(
354223c54c4SSlava Zakharin       ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
355223c54c4SSlava Zakharin   return funcOp;
356223c54c4SSlava Zakharin }
357223c54c4SSlava Zakharin 
358223c54c4SSlava Zakharin /// Convert IPowI into a call to a local function implementing
359223c54c4SSlava Zakharin /// the power operation. The local function computes a scalar result,
360223c54c4SSlava Zakharin /// so vector forms of IPowI are linearized.
361223c54c4SSlava Zakharin LogicalResult
362223c54c4SSlava Zakharin IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
363223c54c4SSlava Zakharin                                  PatternRewriter &rewriter) const {
3645550c821STres Popp   auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
365223c54c4SSlava Zakharin 
366223c54c4SSlava Zakharin   if (!baseType)
367223c54c4SSlava Zakharin     return rewriter.notifyMatchFailure(op, "non-integer base operand");
368223c54c4SSlava Zakharin 
369223c54c4SSlava Zakharin   // The outlined software implementation must have been already
370223c54c4SSlava Zakharin   // generated.
371bfbccfa1SJeremy Kun   func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
372223c54c4SSlava Zakharin   if (!elementFunc)
373223c54c4SSlava Zakharin     return rewriter.notifyMatchFailure(op, "missing software implementation");
374223c54c4SSlava Zakharin 
375223c54c4SSlava Zakharin   rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
376223c54c4SSlava Zakharin   return success();
377223c54c4SSlava Zakharin }
378223c54c4SSlava Zakharin 
37922702cc7SSlava Zakharin /// Create linkonce_odr function to implement the power function with
38022702cc7SSlava Zakharin /// the given \p funcType type inside \p module. The \p funcType must be
38122702cc7SSlava Zakharin /// 'FloatType (*)(FloatType, IntegerType)' function type.
38222702cc7SSlava Zakharin ///
38322702cc7SSlava Zakharin /// template <typename T>
38422702cc7SSlava Zakharin /// Tb __mlir_math_fpowi_*(Tb b, Tp p) {
38522702cc7SSlava Zakharin ///   if (p == Tp{0})
38622702cc7SSlava Zakharin ///     return Tb{1};
38722702cc7SSlava Zakharin ///   bool isNegativePower{p < Tp{0}}
38822702cc7SSlava Zakharin ///   bool isMin{p == std::numeric_limits<Tp>::min()};
38922702cc7SSlava Zakharin ///   if (isMin) {
39022702cc7SSlava Zakharin ///     p = std::numeric_limits<Tp>::max();
39122702cc7SSlava Zakharin ///   } else if (isNegativePower) {
39222702cc7SSlava Zakharin ///     p = -p;
39322702cc7SSlava Zakharin ///   }
39422702cc7SSlava Zakharin ///   Tb result = Tb{1};
39522702cc7SSlava Zakharin ///   Tb origBase = Tb{b};
39622702cc7SSlava Zakharin ///   while (true) {
39722702cc7SSlava Zakharin ///     if (p & Tp{1})
39822702cc7SSlava Zakharin ///       result *= b;
39922702cc7SSlava Zakharin ///     p >>= Tp{1};
40022702cc7SSlava Zakharin ///     if (p == Tp{0})
40122702cc7SSlava Zakharin ///       break;
40222702cc7SSlava Zakharin ///     b *= b;
40322702cc7SSlava Zakharin ///   }
40422702cc7SSlava Zakharin ///   if (isMin) {
40522702cc7SSlava Zakharin ///     result *= origBase;
40622702cc7SSlava Zakharin ///   }
40722702cc7SSlava Zakharin ///   if (isNegativePower) {
40822702cc7SSlava Zakharin ///     result = Tb{1} / result;
40922702cc7SSlava Zakharin ///   }
41022702cc7SSlava Zakharin ///   return result;
41122702cc7SSlava Zakharin /// }
41222702cc7SSlava Zakharin static func::FuncOp createElementFPowIFunc(ModuleOp *module,
41322702cc7SSlava Zakharin                                            FunctionType funcType) {
4145550c821STres Popp   auto baseType = cast<FloatType>(funcType.getInput(0));
4155550c821STres Popp   auto powType = cast<IntegerType>(funcType.getInput(1));
41622702cc7SSlava Zakharin   ImplicitLocOpBuilder builder =
41722702cc7SSlava Zakharin       ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
41822702cc7SSlava Zakharin 
41922702cc7SSlava Zakharin   std::string funcName("__mlir_math_fpowi");
42022702cc7SSlava Zakharin   llvm::raw_string_ostream nameOS(funcName);
42122702cc7SSlava Zakharin   nameOS << '_' << baseType;
42222702cc7SSlava Zakharin   nameOS << '_' << powType;
42322702cc7SSlava Zakharin   auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
42422702cc7SSlava Zakharin   LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
42522702cc7SSlava Zakharin   Attribute linkage =
42622702cc7SSlava Zakharin       LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
42722702cc7SSlava Zakharin   funcOp->setAttr("llvm.linkage", linkage);
42822702cc7SSlava Zakharin   funcOp.setPrivate();
42922702cc7SSlava Zakharin 
43022702cc7SSlava Zakharin   Block *entryBlock = funcOp.addEntryBlock();
43122702cc7SSlava Zakharin   Region *funcBody = entryBlock->getParent();
43222702cc7SSlava Zakharin 
43322702cc7SSlava Zakharin   Value bArg = funcOp.getArgument(0);
43422702cc7SSlava Zakharin   Value pArg = funcOp.getArgument(1);
43522702cc7SSlava Zakharin   builder.setInsertionPointToEnd(entryBlock);
43622702cc7SSlava Zakharin   Value oneBValue = builder.create<arith::ConstantOp>(
43722702cc7SSlava Zakharin       baseType, builder.getFloatAttr(baseType, 1.0));
43822702cc7SSlava Zakharin   Value zeroPValue = builder.create<arith::ConstantOp>(
43922702cc7SSlava Zakharin       powType, builder.getIntegerAttr(powType, 0));
44022702cc7SSlava Zakharin   Value onePValue = builder.create<arith::ConstantOp>(
44122702cc7SSlava Zakharin       powType, builder.getIntegerAttr(powType, 1));
44222702cc7SSlava Zakharin   Value minPValue = builder.create<arith::ConstantOp>(
44322702cc7SSlava Zakharin       powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
44422702cc7SSlava Zakharin                                                    powType.getWidth())));
44522702cc7SSlava Zakharin   Value maxPValue = builder.create<arith::ConstantOp>(
44622702cc7SSlava Zakharin       powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
44722702cc7SSlava Zakharin                                                    powType.getWidth())));
44822702cc7SSlava Zakharin 
44922702cc7SSlava Zakharin   // if (p == Tp{0})
45022702cc7SSlava Zakharin   //   return Tb{1};
45122702cc7SSlava Zakharin   auto pIsZero =
45222702cc7SSlava Zakharin       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
45322702cc7SSlava Zakharin   Block *thenBlock = builder.createBlock(funcBody);
45422702cc7SSlava Zakharin   builder.create<func::ReturnOp>(oneBValue);
45522702cc7SSlava Zakharin   Block *fallthroughBlock = builder.createBlock(funcBody);
45622702cc7SSlava Zakharin   // Set up conditional branch for (p == Tp{0}).
45722702cc7SSlava Zakharin   builder.setInsertionPointToEnd(pIsZero->getBlock());
45822702cc7SSlava Zakharin   builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);
45922702cc7SSlava Zakharin 
46022702cc7SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
46122702cc7SSlava Zakharin   // bool isNegativePower{p < Tp{0}}
46222702cc7SSlava Zakharin   auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
46322702cc7SSlava Zakharin                                               zeroPValue);
46422702cc7SSlava Zakharin   // bool isMin{p == std::numeric_limits<Tp>::min()};
46522702cc7SSlava Zakharin   auto pIsMin =
46622702cc7SSlava Zakharin       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);
46722702cc7SSlava Zakharin 
46822702cc7SSlava Zakharin   // if (isMin) {
46922702cc7SSlava Zakharin   //   p = std::numeric_limits<Tp>::max();
47022702cc7SSlava Zakharin   // } else if (isNegativePower) {
47122702cc7SSlava Zakharin   //   p = -p;
47222702cc7SSlava Zakharin   // }
47322702cc7SSlava Zakharin   Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
47422702cc7SSlava Zakharin   auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
47522702cc7SSlava Zakharin   pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);
47622702cc7SSlava Zakharin 
47722702cc7SSlava Zakharin   // Tb result = Tb{1};
47822702cc7SSlava Zakharin   // Tb origBase = Tb{b};
47922702cc7SSlava Zakharin   // while (true) {
48022702cc7SSlava Zakharin   //   if (p & Tp{1})
48122702cc7SSlava Zakharin   //     result *= b;
48222702cc7SSlava Zakharin   //   p >>= Tp{1};
48322702cc7SSlava Zakharin   //   if (p == Tp{0})
48422702cc7SSlava Zakharin   //     break;
48522702cc7SSlava Zakharin   //   b *= b;
48622702cc7SSlava Zakharin   // }
48722702cc7SSlava Zakharin   Block *loopHeader = builder.createBlock(
48822702cc7SSlava Zakharin       funcBody, funcBody->end(), {baseType, baseType, powType},
48922702cc7SSlava Zakharin       {builder.getLoc(), builder.getLoc(), builder.getLoc()});
49022702cc7SSlava Zakharin   // Set initial values of 'result', 'b' and 'p' for the loop.
49122702cc7SSlava Zakharin   builder.setInsertionPointToEnd(pInit->getBlock());
49222702cc7SSlava Zakharin   builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});
49322702cc7SSlava Zakharin 
49422702cc7SSlava Zakharin   // Create loop body.
49522702cc7SSlava Zakharin   Value resultTmp = loopHeader->getArgument(0);
49622702cc7SSlava Zakharin   Value baseTmp = loopHeader->getArgument(1);
49722702cc7SSlava Zakharin   Value powerTmp = loopHeader->getArgument(2);
49822702cc7SSlava Zakharin   builder.setInsertionPointToEnd(loopHeader);
49922702cc7SSlava Zakharin 
50022702cc7SSlava Zakharin   //   if (p & Tp{1})
50122702cc7SSlava Zakharin   auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
50222702cc7SSlava Zakharin       arith::CmpIPredicate::ne,
50322702cc7SSlava Zakharin       builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
50422702cc7SSlava Zakharin   thenBlock = builder.createBlock(funcBody);
50522702cc7SSlava Zakharin   //     result *= b;
50622702cc7SSlava Zakharin   Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
50722702cc7SSlava Zakharin   fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
50822702cc7SSlava Zakharin                                          builder.getLoc());
50922702cc7SSlava Zakharin   builder.setInsertionPointToEnd(thenBlock);
51022702cc7SSlava Zakharin   builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
51122702cc7SSlava Zakharin   // Set up conditional branch for (p & Tp{1}).
51222702cc7SSlava Zakharin   builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
51322702cc7SSlava Zakharin   builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
51422702cc7SSlava Zakharin                                    resultTmp);
51522702cc7SSlava Zakharin   // Merged 'result'.
51622702cc7SSlava Zakharin   newResultTmp = fallthroughBlock->getArgument(0);
51722702cc7SSlava Zakharin 
51822702cc7SSlava Zakharin   //   p >>= Tp{1};
51922702cc7SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
52022702cc7SSlava Zakharin   Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);
52122702cc7SSlava Zakharin 
52222702cc7SSlava Zakharin   //   if (p == Tp{0})
52322702cc7SSlava Zakharin   auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
52422702cc7SSlava Zakharin                                                       newPowerTmp, zeroPValue);
52522702cc7SSlava Zakharin   //     break;
52622702cc7SSlava Zakharin   //
52722702cc7SSlava Zakharin   // The conditional branch is finalized below with a jump to
52822702cc7SSlava Zakharin   // the loop exit block.
52922702cc7SSlava Zakharin   fallthroughBlock = builder.createBlock(funcBody);
53022702cc7SSlava Zakharin 
53122702cc7SSlava Zakharin   //   b *= b;
53222702cc7SSlava Zakharin   // }
53322702cc7SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
53422702cc7SSlava Zakharin   Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
53522702cc7SSlava Zakharin   // Pass new values for 'result', 'b' and 'p' to the loop header.
53622702cc7SSlava Zakharin   builder.create<cf::BranchOp>(
53722702cc7SSlava Zakharin       ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
53822702cc7SSlava Zakharin 
53922702cc7SSlava Zakharin   // Set up conditional branch for early loop exit:
54022702cc7SSlava Zakharin   //   if (p == Tp{0})
54122702cc7SSlava Zakharin   //     break;
54222702cc7SSlava Zakharin   Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
54322702cc7SSlava Zakharin                                         builder.getLoc());
54422702cc7SSlava Zakharin   builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
54522702cc7SSlava Zakharin   builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
54622702cc7SSlava Zakharin                                    fallthroughBlock, ValueRange{});
54722702cc7SSlava Zakharin 
54822702cc7SSlava Zakharin   // if (isMin) {
54922702cc7SSlava Zakharin   //   result *= origBase;
55022702cc7SSlava Zakharin   // }
55122702cc7SSlava Zakharin   newResultTmp = loopExit->getArgument(0);
55222702cc7SSlava Zakharin   thenBlock = builder.createBlock(funcBody);
55322702cc7SSlava Zakharin   fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
55422702cc7SSlava Zakharin                                          builder.getLoc());
55522702cc7SSlava Zakharin   builder.setInsertionPointToEnd(loopExit);
55622702cc7SSlava Zakharin   builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
55722702cc7SSlava Zakharin                                    newResultTmp);
55822702cc7SSlava Zakharin   builder.setInsertionPointToEnd(thenBlock);
55922702cc7SSlava Zakharin   newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
56022702cc7SSlava Zakharin   builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
56122702cc7SSlava Zakharin 
56222702cc7SSlava Zakharin   /// if (isNegativePower) {
56322702cc7SSlava Zakharin   ///   result = Tb{1} / result;
56422702cc7SSlava Zakharin   /// }
56522702cc7SSlava Zakharin   newResultTmp = fallthroughBlock->getArgument(0);
56622702cc7SSlava Zakharin   thenBlock = builder.createBlock(funcBody);
56722702cc7SSlava Zakharin   Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
56822702cc7SSlava Zakharin                                            builder.getLoc());
56922702cc7SSlava Zakharin   builder.setInsertionPointToEnd(fallthroughBlock);
57022702cc7SSlava Zakharin   builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
57122702cc7SSlava Zakharin                                    newResultTmp);
57222702cc7SSlava Zakharin   builder.setInsertionPointToEnd(thenBlock);
57322702cc7SSlava Zakharin   newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
57422702cc7SSlava Zakharin   builder.create<cf::BranchOp>(newResultTmp, returnBlock);
57522702cc7SSlava Zakharin 
57622702cc7SSlava Zakharin   // return result;
57722702cc7SSlava Zakharin   builder.setInsertionPointToEnd(returnBlock);
57822702cc7SSlava Zakharin   builder.create<func::ReturnOp>(returnBlock->getArgument(0));
57922702cc7SSlava Zakharin 
58022702cc7SSlava Zakharin   return funcOp;
58122702cc7SSlava Zakharin }
58222702cc7SSlava Zakharin 
58322702cc7SSlava Zakharin /// Convert FPowI into a call to a local function implementing
58422702cc7SSlava Zakharin /// the power operation. The local function computes a scalar result,
58522702cc7SSlava Zakharin /// so vector forms of FPowI are linearized.
58622702cc7SSlava Zakharin LogicalResult
58722702cc7SSlava Zakharin FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
58822702cc7SSlava Zakharin                                  PatternRewriter &rewriter) const {
5895550c821STres Popp   if (dyn_cast<VectorType>(op.getType()))
59022702cc7SSlava Zakharin     return rewriter.notifyMatchFailure(op, "non-scalar operation");
59122702cc7SSlava Zakharin 
59222702cc7SSlava Zakharin   FunctionType funcType = getElementalFuncTypeForOp(op);
59322702cc7SSlava Zakharin 
59422702cc7SSlava Zakharin   // The outlined software implementation must have been already
59522702cc7SSlava Zakharin   // generated.
596bfbccfa1SJeremy Kun   func::FuncOp elementFunc = getFuncOpCallback(op, funcType);
59722702cc7SSlava Zakharin   if (!elementFunc)
59822702cc7SSlava Zakharin     return rewriter.notifyMatchFailure(op, "missing software implementation");
59922702cc7SSlava Zakharin 
60022702cc7SSlava Zakharin   rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
60122702cc7SSlava Zakharin   return success();
60222702cc7SSlava Zakharin }
60322702cc7SSlava Zakharin 
604bfbccfa1SJeremy Kun /// Create function to implement the ctlz function the given \p elementType type
605bfbccfa1SJeremy Kun /// inside \p module. The \p elementType must be IntegerType, an the created
606bfbccfa1SJeremy Kun /// function has 'IntegerType (*)(IntegerType)' function type.
607bfbccfa1SJeremy Kun ///
608bfbccfa1SJeremy Kun /// template <typename T>
609bfbccfa1SJeremy Kun /// T __mlir_math_ctlz_*(T x) {
610bfbccfa1SJeremy Kun ///     bits = sizeof(x) * 8;
611bfbccfa1SJeremy Kun ///     if (x == 0)
612bfbccfa1SJeremy Kun ///       return bits;
613bfbccfa1SJeremy Kun ///
614bfbccfa1SJeremy Kun ///     uint32_t n = 0;
615bfbccfa1SJeremy Kun ///     for (int i = 1; i < bits; ++i) {
616bfbccfa1SJeremy Kun ///         if (x < 0) continue;
617bfbccfa1SJeremy Kun ///         n++;
618bfbccfa1SJeremy Kun ///         x <<= 1;
619bfbccfa1SJeremy Kun ///     }
620bfbccfa1SJeremy Kun ///     return n;
621bfbccfa1SJeremy Kun /// }
622bfbccfa1SJeremy Kun ///
623bfbccfa1SJeremy Kun /// Converts to (for i32):
624bfbccfa1SJeremy Kun ///
625bfbccfa1SJeremy Kun /// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 {
626bfbccfa1SJeremy Kun ///   %c_32 = arith.constant 32 : index
627bfbccfa1SJeremy Kun ///   %c_0 = arith.constant 0 : i32
628bfbccfa1SJeremy Kun ///   %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1
629bfbccfa1SJeremy Kun ///   %out = scf.if %arg_eq_zero {
630bfbccfa1SJeremy Kun ///     scf.yield %c_32 : i32
631bfbccfa1SJeremy Kun ///   } else {
632bfbccfa1SJeremy Kun ///     %c_1index = arith.constant 1 : index
633bfbccfa1SJeremy Kun ///     %c_1i32 = arith.constant 1 : i32
634bfbccfa1SJeremy Kun ///     %n = arith.constant 0 : i32
635bfbccfa1SJeremy Kun ///     %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index
636bfbccfa1SJeremy Kun ///         iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) {
637bfbccfa1SJeremy Kun ///       %cond = arith.cmpi slt, %arg_iter, %c_0 : i32
638bfbccfa1SJeremy Kun ///       %yield_val = scf.if %cond {
639bfbccfa1SJeremy Kun ///         scf.yield %arg_iter, %n_iter : i32, i32
640bfbccfa1SJeremy Kun ///       } else {
641bfbccfa1SJeremy Kun ///         %arg_next = arith.shli %arg_iter, %c_1i32 : i32
642bfbccfa1SJeremy Kun ///         %n_next = arith.addi %n_iter, %c_1i32 : i32
643bfbccfa1SJeremy Kun ///         scf.yield %arg_next, %n_next : i32, i32
644bfbccfa1SJeremy Kun ///       }
645bfbccfa1SJeremy Kun ///       scf.yield %yield_val: i32, i32
646bfbccfa1SJeremy Kun ///     }
647bfbccfa1SJeremy Kun ///     scf.yield %n_out : i32
648bfbccfa1SJeremy Kun ///   }
649bfbccfa1SJeremy Kun ///   return %out: i32
650bfbccfa1SJeremy Kun /// }
651bfbccfa1SJeremy Kun static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
6525550c821STres Popp   if (!isa<IntegerType>(elementType)) {
653bfbccfa1SJeremy Kun     LLVM_DEBUG({
654bfbccfa1SJeremy Kun       DBGS() << "non-integer element type for CtlzFunc; type was: ";
655bfbccfa1SJeremy Kun       elementType.print(llvm::dbgs());
656bfbccfa1SJeremy Kun     });
657bfbccfa1SJeremy Kun     llvm_unreachable("non-integer element type");
658bfbccfa1SJeremy Kun   }
659bfbccfa1SJeremy Kun   int64_t bitWidth = elementType.getIntOrFloatBitWidth();
660bfbccfa1SJeremy Kun 
661bfbccfa1SJeremy Kun   Location loc = module->getLoc();
662bfbccfa1SJeremy Kun   ImplicitLocOpBuilder builder =
663bfbccfa1SJeremy Kun       ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody());
664bfbccfa1SJeremy Kun 
665bfbccfa1SJeremy Kun   std::string funcName("__mlir_math_ctlz");
666bfbccfa1SJeremy Kun   llvm::raw_string_ostream nameOS(funcName);
667bfbccfa1SJeremy Kun   nameOS << '_' << elementType;
668bfbccfa1SJeremy Kun   FunctionType funcType =
669bfbccfa1SJeremy Kun       FunctionType::get(builder.getContext(), {elementType}, elementType);
670bfbccfa1SJeremy Kun   auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
671bfbccfa1SJeremy Kun 
672bfbccfa1SJeremy Kun   // LinkonceODR ensures that there is only one implementation of this function
673bfbccfa1SJeremy Kun   // across all math.ctlz functions that are lowered in this way.
674bfbccfa1SJeremy Kun   LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
675bfbccfa1SJeremy Kun   Attribute linkage =
676bfbccfa1SJeremy Kun       LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
677bfbccfa1SJeremy Kun   funcOp->setAttr("llvm.linkage", linkage);
678bfbccfa1SJeremy Kun   funcOp.setPrivate();
679bfbccfa1SJeremy Kun 
680bfbccfa1SJeremy Kun   // set the insertion point to the start of the function
681bfbccfa1SJeremy Kun   Block *funcBody = funcOp.addEntryBlock();
682bfbccfa1SJeremy Kun   builder.setInsertionPointToStart(funcBody);
683bfbccfa1SJeremy Kun 
684bfbccfa1SJeremy Kun   Value arg = funcOp.getArgument(0);
685bfbccfa1SJeremy Kun   Type indexType = builder.getIndexType();
686bfbccfa1SJeremy Kun   Value bitWidthValue = builder.create<arith::ConstantOp>(
687bfbccfa1SJeremy Kun       elementType, builder.getIntegerAttr(elementType, bitWidth));
688bfbccfa1SJeremy Kun   Value zeroValue = builder.create<arith::ConstantOp>(
689bfbccfa1SJeremy Kun       elementType, builder.getIntegerAttr(elementType, 0));
690bfbccfa1SJeremy Kun 
691bfbccfa1SJeremy Kun   Value inputEqZero =
692bfbccfa1SJeremy Kun       builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue);
693bfbccfa1SJeremy Kun 
694bfbccfa1SJeremy Kun   // if input == 0, return bit width, else enter loop.
695bfbccfa1SJeremy Kun   scf::IfOp ifOp = builder.create<scf::IfOp>(
696bfbccfa1SJeremy Kun       elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true);
697bfbccfa1SJeremy Kun   ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
698bfbccfa1SJeremy Kun 
699bfbccfa1SJeremy Kun   auto elseBuilder =
700bfbccfa1SJeremy Kun       ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
701bfbccfa1SJeremy Kun 
702bfbccfa1SJeremy Kun   Value oneIndex = elseBuilder.create<arith::ConstantOp>(
703bfbccfa1SJeremy Kun       indexType, elseBuilder.getIndexAttr(1));
704bfbccfa1SJeremy Kun   Value oneValue = elseBuilder.create<arith::ConstantOp>(
705bfbccfa1SJeremy Kun       elementType, elseBuilder.getIntegerAttr(elementType, 1));
706bfbccfa1SJeremy Kun   Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>(
707bfbccfa1SJeremy Kun       indexType, elseBuilder.getIndexAttr(bitWidth));
708bfbccfa1SJeremy Kun   Value nValue = elseBuilder.create<arith::ConstantOp>(
709bfbccfa1SJeremy Kun       elementType, elseBuilder.getIntegerAttr(elementType, 0));
710bfbccfa1SJeremy Kun 
711bfbccfa1SJeremy Kun   auto loop = elseBuilder.create<scf::ForOp>(
712bfbccfa1SJeremy Kun       oneIndex, bitWidthIndex, oneIndex,
713bfbccfa1SJeremy Kun       // Initial values for two loop induction variables, the arg which is being
714bfbccfa1SJeremy Kun       // shifted left in each iteration, and the n value which tracks the count
715bfbccfa1SJeremy Kun       // of leading zeros.
716bfbccfa1SJeremy Kun       ValueRange{arg, nValue},
717bfbccfa1SJeremy Kun       // Callback to build the body of the for loop
718bfbccfa1SJeremy Kun       //   if (arg < 0) {
719bfbccfa1SJeremy Kun       //     continue;
720bfbccfa1SJeremy Kun       //   } else {
721bfbccfa1SJeremy Kun       //     n++;
722bfbccfa1SJeremy Kun       //     arg <<= 1;
723bfbccfa1SJeremy Kun       //   }
724bfbccfa1SJeremy Kun       [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
725bfbccfa1SJeremy Kun         Value argIter = args[0];
726bfbccfa1SJeremy Kun         Value nIter = args[1];
727bfbccfa1SJeremy Kun 
728bfbccfa1SJeremy Kun         Value argIsNonNegative = b.create<arith::CmpIOp>(
729bfbccfa1SJeremy Kun             loc, arith::CmpIPredicate::slt, argIter, zeroValue);
730bfbccfa1SJeremy Kun         scf::IfOp ifOp = b.create<scf::IfOp>(
731bfbccfa1SJeremy Kun             loc, argIsNonNegative,
732bfbccfa1SJeremy Kun             [&](OpBuilder &b, Location loc) {
733bfbccfa1SJeremy Kun               // If arg is negative, continue (effectively, break)
734bfbccfa1SJeremy Kun               b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter});
735bfbccfa1SJeremy Kun             },
736bfbccfa1SJeremy Kun             [&](OpBuilder &b, Location loc) {
737bfbccfa1SJeremy Kun               // Otherwise, increment n and shift arg left.
738bfbccfa1SJeremy Kun               Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue);
739bfbccfa1SJeremy Kun               Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue);
740bfbccfa1SJeremy Kun               b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext});
741bfbccfa1SJeremy Kun             });
742bfbccfa1SJeremy Kun         b.create<scf::YieldOp>(loc, ifOp.getResults());
743bfbccfa1SJeremy Kun       });
744bfbccfa1SJeremy Kun   elseBuilder.create<scf::YieldOp>(loop.getResult(1));
745bfbccfa1SJeremy Kun 
746bfbccfa1SJeremy Kun   builder.create<func::ReturnOp>(ifOp.getResult(0));
747bfbccfa1SJeremy Kun   return funcOp;
748bfbccfa1SJeremy Kun }
749bfbccfa1SJeremy Kun 
750bfbccfa1SJeremy Kun /// Convert ctlz into a call to a local function implementing the ctlz
751bfbccfa1SJeremy Kun /// operation.
752bfbccfa1SJeremy Kun LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
753bfbccfa1SJeremy Kun                                               PatternRewriter &rewriter) const {
7545550c821STres Popp   if (dyn_cast<VectorType>(op.getType()))
755bfbccfa1SJeremy Kun     return rewriter.notifyMatchFailure(op, "non-scalar operation");
756bfbccfa1SJeremy Kun 
757bfbccfa1SJeremy Kun   Type type = getElementTypeOrSelf(op.getResult().getType());
758bfbccfa1SJeremy Kun   func::FuncOp elementFunc = getFuncOpCallback(op, type);
759bfbccfa1SJeremy Kun   if (!elementFunc)
760bfbccfa1SJeremy Kun     return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
761bfbccfa1SJeremy Kun       diag << "Missing software implementation for op " << op->getName()
762bfbccfa1SJeremy Kun            << " and type " << type;
763bfbccfa1SJeremy Kun     });
764bfbccfa1SJeremy Kun 
765bfbccfa1SJeremy Kun   rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand());
766bfbccfa1SJeremy Kun   return success();
767bfbccfa1SJeremy Kun }
768bfbccfa1SJeremy Kun 
769223c54c4SSlava Zakharin namespace {
770223c54c4SSlava Zakharin struct ConvertMathToFuncsPass
77167d0d7acSMichele Scuttari     : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> {
772039b969bSMichele Scuttari   ConvertMathToFuncsPass() = default;
77322702cc7SSlava Zakharin   ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options)
77422702cc7SSlava Zakharin       : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {}
775223c54c4SSlava Zakharin 
776223c54c4SSlava Zakharin   void runOnOperation() override;
777223c54c4SSlava Zakharin 
778223c54c4SSlava Zakharin private:
77922702cc7SSlava Zakharin   // Return true, if this FPowI operation must be converted
78022702cc7SSlava Zakharin   // because the width of its exponent's type is greater than
78122702cc7SSlava Zakharin   // or equal to minWidthOfFPowIExponent option value.
78222702cc7SSlava Zakharin   bool isFPowIConvertible(math::FPowIOp op);
78322702cc7SSlava Zakharin 
784*7ad63c0eSLongsheng Mou   // Reture true, if operation is integer type.
785*7ad63c0eSLongsheng Mou   bool isConvertible(Operation *op);
786*7ad63c0eSLongsheng Mou 
787223c54c4SSlava Zakharin   // Generate outlined implementations for power operations
788bfbccfa1SJeremy Kun   // and store them in funcImpls map.
789bfbccfa1SJeremy Kun   void generateOpImplementations();
790223c54c4SSlava Zakharin 
791bfbccfa1SJeremy Kun   // A map between pairs of (operation, type) deduced from operations that this
792bfbccfa1SJeremy Kun   // pass will convert, and the corresponding outlined software implementations
793bfbccfa1SJeremy Kun   // of these operations for the given type.
794bfbccfa1SJeremy Kun   DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls;
795223c54c4SSlava Zakharin };
796223c54c4SSlava Zakharin } // namespace
797223c54c4SSlava Zakharin 
79822702cc7SSlava Zakharin bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
79922702cc7SSlava Zakharin   auto expTy =
8005550c821STres Popp       dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
80122702cc7SSlava Zakharin   return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
80222702cc7SSlava Zakharin }
80322702cc7SSlava Zakharin 
804*7ad63c0eSLongsheng Mou bool ConvertMathToFuncsPass::isConvertible(Operation *op) {
805*7ad63c0eSLongsheng Mou   return isa<IntegerType>(getElementTypeOrSelf(op->getResult(0).getType()));
806*7ad63c0eSLongsheng Mou }
807*7ad63c0eSLongsheng Mou 
808bfbccfa1SJeremy Kun void ConvertMathToFuncsPass::generateOpImplementations() {
809223c54c4SSlava Zakharin   ModuleOp module = getOperation();
810223c54c4SSlava Zakharin 
811223c54c4SSlava Zakharin   module.walk([&](Operation *op) {
81222702cc7SSlava Zakharin     TypeSwitch<Operation *>(op)
813bfbccfa1SJeremy Kun         .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) {
814*7ad63c0eSLongsheng Mou           if (!convertCtlz || !isConvertible(op))
815ce47090dSSlava Zakharin             return;
816bfbccfa1SJeremy Kun           Type resultType = getElementTypeOrSelf(op.getResult().getType());
817bfbccfa1SJeremy Kun 
818bfbccfa1SJeremy Kun           // Generate the software implementation of this operation,
819bfbccfa1SJeremy Kun           // if it has not been generated yet.
820bfbccfa1SJeremy Kun           auto key = std::pair(op->getName(), resultType);
821bfbccfa1SJeremy Kun           auto entry = funcImpls.try_emplace(key, func::FuncOp{});
822bfbccfa1SJeremy Kun           if (entry.second)
823bfbccfa1SJeremy Kun             entry.first->second = createCtlzFunc(&module, resultType);
824bfbccfa1SJeremy Kun         })
82522702cc7SSlava Zakharin         .Case<math::IPowIOp>([&](math::IPowIOp op) {
826*7ad63c0eSLongsheng Mou           if (!isConvertible(op))
827*7ad63c0eSLongsheng Mou             return;
828*7ad63c0eSLongsheng Mou 
829223c54c4SSlava Zakharin           Type resultType = getElementTypeOrSelf(op.getResult().getType());
830223c54c4SSlava Zakharin 
831223c54c4SSlava Zakharin           // Generate the software implementation of this operation,
832223c54c4SSlava Zakharin           // if it has not been generated yet.
833bfbccfa1SJeremy Kun           auto key = std::pair(op->getName(), resultType);
834bfbccfa1SJeremy Kun           auto entry = funcImpls.try_emplace(key, func::FuncOp{});
835223c54c4SSlava Zakharin           if (entry.second)
836223c54c4SSlava Zakharin             entry.first->second = createElementIPowIFunc(&module, resultType);
83722702cc7SSlava Zakharin         })
83822702cc7SSlava Zakharin         .Case<math::FPowIOp>([&](math::FPowIOp op) {
83922702cc7SSlava Zakharin           if (!isFPowIConvertible(op))
84022702cc7SSlava Zakharin             return;
84122702cc7SSlava Zakharin 
84222702cc7SSlava Zakharin           FunctionType funcType = getElementalFuncTypeForOp(op);
84322702cc7SSlava Zakharin 
84422702cc7SSlava Zakharin           // Generate the software implementation of this operation,
84522702cc7SSlava Zakharin           // if it has not been generated yet.
84622702cc7SSlava Zakharin           // FPowI implementations are mapped via the FunctionType
84722702cc7SSlava Zakharin           // created from the operation's result and operands.
848bfbccfa1SJeremy Kun           auto key = std::pair(op->getName(), funcType);
849bfbccfa1SJeremy Kun           auto entry = funcImpls.try_emplace(key, func::FuncOp{});
85022702cc7SSlava Zakharin           if (entry.second)
85122702cc7SSlava Zakharin             entry.first->second = createElementFPowIFunc(&module, funcType);
852223c54c4SSlava Zakharin         });
853223c54c4SSlava Zakharin   });
854223c54c4SSlava Zakharin }
855223c54c4SSlava Zakharin 
856223c54c4SSlava Zakharin void ConvertMathToFuncsPass::runOnOperation() {
857223c54c4SSlava Zakharin   ModuleOp module = getOperation();
858223c54c4SSlava Zakharin 
859223c54c4SSlava Zakharin   // Create outlined implementations for power operations.
860bfbccfa1SJeremy Kun   generateOpImplementations();
861223c54c4SSlava Zakharin 
862223c54c4SSlava Zakharin   RewritePatternSet patterns(&getContext());
863bfbccfa1SJeremy Kun   patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>,
864bfbccfa1SJeremy Kun                VecOpToScalarOp<math::CountLeadingZerosOp>>(
86522702cc7SSlava Zakharin       patterns.getContext());
866223c54c4SSlava Zakharin 
867bfbccfa1SJeremy Kun   // For the given Type Returns FuncOp stored in funcImpls map.
868bfbccfa1SJeremy Kun   auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp {
869bfbccfa1SJeremy Kun     auto it = funcImpls.find(std::pair(op->getName(), type));
870bfbccfa1SJeremy Kun     if (it == funcImpls.end())
871223c54c4SSlava Zakharin       return {};
872223c54c4SSlava Zakharin 
873223c54c4SSlava Zakharin     return it->second;
874223c54c4SSlava Zakharin   };
87522702cc7SSlava Zakharin   patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(),
876bfbccfa1SJeremy Kun                                                  getFuncOpByType);
877bfbccfa1SJeremy Kun 
878bfbccfa1SJeremy Kun   if (convertCtlz)
879bfbccfa1SJeremy Kun     patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType);
880223c54c4SSlava Zakharin 
881223c54c4SSlava Zakharin   ConversionTarget target(getContext());
882abc362a1SJakub Kuderski   target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
883bfbccfa1SJeremy Kun                          func::FuncDialect, scf::SCFDialect,
884bfbccfa1SJeremy Kun                          vector::VectorDialect>();
885bfbccfa1SJeremy Kun 
886*7ad63c0eSLongsheng Mou   target.addDynamicallyLegalOp<math::IPowIOp>(
887*7ad63c0eSLongsheng Mou       [this](math::IPowIOp op) { return !isConvertible(op); });
888*7ad63c0eSLongsheng Mou   if (convertCtlz) {
889*7ad63c0eSLongsheng Mou     target.addDynamicallyLegalOp<math::CountLeadingZerosOp>(
890*7ad63c0eSLongsheng Mou         [this](math::CountLeadingZerosOp op) { return !isConvertible(op); });
891*7ad63c0eSLongsheng Mou   }
89222702cc7SSlava Zakharin   target.addDynamicallyLegalOp<math::FPowIOp>(
89322702cc7SSlava Zakharin       [this](math::FPowIOp op) { return !isFPowIConvertible(op); });
894223c54c4SSlava Zakharin   if (failed(applyPartialConversion(module, target, std::move(patterns))))
895223c54c4SSlava Zakharin     signalPassFailure();
896223c54c4SSlava Zakharin }
897