1223c54c4SSlava Zakharin //===- MathToFuncs.cpp - Math to outlined implementation conversion -------===// 2223c54c4SSlava Zakharin // 3223c54c4SSlava Zakharin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4223c54c4SSlava Zakharin // See https://llvm.org/LICENSE.txt for license information. 5223c54c4SSlava Zakharin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6223c54c4SSlava Zakharin // 7223c54c4SSlava Zakharin //===----------------------------------------------------------------------===// 8223c54c4SSlava Zakharin 9223c54c4SSlava Zakharin #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" 1067d0d7acSMichele Scuttari 11abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 12223c54c4SSlava Zakharin #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 13223c54c4SSlava Zakharin #include "mlir/Dialect/Func/IR/FuncOps.h" 14223c54c4SSlava Zakharin #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15223c54c4SSlava Zakharin #include "mlir/Dialect/Math/IR/Math.h" 16bfbccfa1SJeremy Kun #include "mlir/Dialect/SCF/IR/SCF.h" 17223c54c4SSlava Zakharin #include "mlir/Dialect/Utils/IndexingUtils.h" 18223c54c4SSlava Zakharin #include "mlir/Dialect/Vector/IR/VectorOps.h" 19223c54c4SSlava Zakharin #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 20223c54c4SSlava Zakharin #include "mlir/IR/ImplicitLocOpBuilder.h" 21223c54c4SSlava Zakharin #include "mlir/IR/TypeUtilities.h" 2267d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h" 23223c54c4SSlava Zakharin #include "mlir/Transforms/DialectConversion.h" 24223c54c4SSlava Zakharin #include "llvm/ADT/DenseMap.h" 25223c54c4SSlava Zakharin #include "llvm/ADT/TypeSwitch.h" 26bfbccfa1SJeremy Kun #include "llvm/Support/Debug.h" 27223c54c4SSlava Zakharin 2867d0d7acSMichele Scuttari namespace mlir { 2967d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTMATHTOFUNCS 3067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 3167d0d7acSMichele Scuttari } // namespace mlir 
3267d0d7acSMichele Scuttari 33223c54c4SSlava Zakharin using namespace mlir; 34223c54c4SSlava Zakharin 35bfbccfa1SJeremy Kun #define DEBUG_TYPE "math-to-funcs" 36bfbccfa1SJeremy Kun #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 37bfbccfa1SJeremy Kun 38223c54c4SSlava Zakharin namespace { 39223c54c4SSlava Zakharin // Pattern to convert vector operations to scalar operations. 40223c54c4SSlava Zakharin template <typename Op> 41223c54c4SSlava Zakharin struct VecOpToScalarOp : public OpRewritePattern<Op> { 42223c54c4SSlava Zakharin public: 43223c54c4SSlava Zakharin using OpRewritePattern<Op>::OpRewritePattern; 44223c54c4SSlava Zakharin 45223c54c4SSlava Zakharin LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 46223c54c4SSlava Zakharin }; 47223c54c4SSlava Zakharin 48223c54c4SSlava Zakharin // Callback type for getting pre-generated FuncOp implementing 49bfbccfa1SJeremy Kun // an operation of the given type. 50bfbccfa1SJeremy Kun using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>; 51223c54c4SSlava Zakharin 52223c54c4SSlava Zakharin // Pattern to convert scalar IPowIOp into a call of outlined 53223c54c4SSlava Zakharin // software implementation. 5422702cc7SSlava Zakharin class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> { 55223c54c4SSlava Zakharin public: 56bfbccfa1SJeremy Kun IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb) 57223c54c4SSlava Zakharin : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {} 58223c54c4SSlava Zakharin 59223c54c4SSlava Zakharin /// Convert IPowI into a call to a local function implementing 60223c54c4SSlava Zakharin /// the power operation. The local function computes a scalar result, 61223c54c4SSlava Zakharin /// so vector forms of IPowI are linearized. 
62223c54c4SSlava Zakharin LogicalResult matchAndRewrite(math::IPowIOp op, 63223c54c4SSlava Zakharin PatternRewriter &rewriter) const final; 6422702cc7SSlava Zakharin 6522702cc7SSlava Zakharin private: 66bfbccfa1SJeremy Kun GetFuncCallbackTy getFuncOpCallback; 6722702cc7SSlava Zakharin }; 6822702cc7SSlava Zakharin 6922702cc7SSlava Zakharin // Pattern to convert scalar FPowIOp into a call of outlined 7022702cc7SSlava Zakharin // software implementation. 7122702cc7SSlava Zakharin class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> { 7222702cc7SSlava Zakharin public: 73bfbccfa1SJeremy Kun FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb) 7422702cc7SSlava Zakharin : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {} 7522702cc7SSlava Zakharin 7622702cc7SSlava Zakharin /// Convert FPowI into a call to a local function implementing 7722702cc7SSlava Zakharin /// the power operation. The local function computes a scalar result, 7822702cc7SSlava Zakharin /// so vector forms of FPowI are linearized. 7922702cc7SSlava Zakharin LogicalResult matchAndRewrite(math::FPowIOp op, 8022702cc7SSlava Zakharin PatternRewriter &rewriter) const final; 8122702cc7SSlava Zakharin 8222702cc7SSlava Zakharin private: 83bfbccfa1SJeremy Kun GetFuncCallbackTy getFuncOpCallback; 84bfbccfa1SJeremy Kun }; 85bfbccfa1SJeremy Kun 86bfbccfa1SJeremy Kun // Pattern to convert scalar ctlz into a call of outlined software 87bfbccfa1SJeremy Kun // implementation. 88bfbccfa1SJeremy Kun class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> { 89bfbccfa1SJeremy Kun public: 90bfbccfa1SJeremy Kun CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb) 91bfbccfa1SJeremy Kun : OpRewritePattern<math::CountLeadingZerosOp>(context), 92bfbccfa1SJeremy Kun getFuncOpCallback(cb) {} 93bfbccfa1SJeremy Kun 94bfbccfa1SJeremy Kun /// Convert ctlz into a call to a local function implementing 95bfbccfa1SJeremy Kun /// the count leading zeros operation. 
96bfbccfa1SJeremy Kun LogicalResult matchAndRewrite(math::CountLeadingZerosOp op, 97bfbccfa1SJeremy Kun PatternRewriter &rewriter) const final; 98bfbccfa1SJeremy Kun 99bfbccfa1SJeremy Kun private: 100bfbccfa1SJeremy Kun GetFuncCallbackTy getFuncOpCallback; 101223c54c4SSlava Zakharin }; 102223c54c4SSlava Zakharin } // namespace 103223c54c4SSlava Zakharin 104223c54c4SSlava Zakharin template <typename Op> 105223c54c4SSlava Zakharin LogicalResult 106223c54c4SSlava Zakharin VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 107223c54c4SSlava Zakharin Type opType = op.getType(); 108223c54c4SSlava Zakharin Location loc = op.getLoc(); 1095550c821STres Popp auto vecType = dyn_cast<VectorType>(opType); 110223c54c4SSlava Zakharin 111223c54c4SSlava Zakharin if (!vecType) 112223c54c4SSlava Zakharin return rewriter.notifyMatchFailure(op, "not a vector operation"); 113223c54c4SSlava Zakharin if (!vecType.hasRank()) 114223c54c4SSlava Zakharin return rewriter.notifyMatchFailure(op, "unknown vector rank"); 115223c54c4SSlava Zakharin ArrayRef<int64_t> shape = vecType.getShape(); 116223c54c4SSlava Zakharin int64_t numElements = vecType.getNumElements(); 117223c54c4SSlava Zakharin 11822702cc7SSlava Zakharin Type resultElementType = vecType.getElementType(); 11922702cc7SSlava Zakharin Attribute initValueAttr; 1205550c821STres Popp if (isa<FloatType>(resultElementType)) 12122702cc7SSlava Zakharin initValueAttr = FloatAttr::get(resultElementType, 0.0); 12222702cc7SSlava Zakharin else 12322702cc7SSlava Zakharin initValueAttr = IntegerAttr::get(resultElementType, 0); 124223c54c4SSlava Zakharin Value result = rewriter.create<arith::ConstantOp>( 12522702cc7SSlava Zakharin loc, DenseElementsAttr::get(vecType, initValueAttr)); 1267a69a9d7SNicolas Vasilache SmallVector<int64_t> strides = computeStrides(shape); 127223c54c4SSlava Zakharin for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) { 128203fad47SNicolas Vasilache SmallVector<int64_t> 
positions = delinearize(linearIndex, strides); 129223c54c4SSlava Zakharin SmallVector<Value> operands; 130223c54c4SSlava Zakharin for (Value input : op->getOperands()) 131223c54c4SSlava Zakharin operands.push_back( 132223c54c4SSlava Zakharin rewriter.create<vector::ExtractOp>(loc, input, positions)); 133223c54c4SSlava Zakharin Value scalarOp = 134223c54c4SSlava Zakharin rewriter.create<Op>(loc, vecType.getElementType(), operands); 135223c54c4SSlava Zakharin result = 136223c54c4SSlava Zakharin rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); 137223c54c4SSlava Zakharin } 138223c54c4SSlava Zakharin rewriter.replaceOp(op, result); 139223c54c4SSlava Zakharin return success(); 140223c54c4SSlava Zakharin } 141223c54c4SSlava Zakharin 14222702cc7SSlava Zakharin static FunctionType getElementalFuncTypeForOp(Operation *op) { 14322702cc7SSlava Zakharin SmallVector<Type, 1> resultTys(op->getNumResults()); 14422702cc7SSlava Zakharin SmallVector<Type, 2> inputTys(op->getNumOperands()); 14522702cc7SSlava Zakharin std::transform(op->result_type_begin(), op->result_type_end(), 14622702cc7SSlava Zakharin resultTys.begin(), 14722702cc7SSlava Zakharin [](Type ty) { return getElementTypeOrSelf(ty); }); 14822702cc7SSlava Zakharin std::transform(op->operand_type_begin(), op->operand_type_end(), 14922702cc7SSlava Zakharin inputTys.begin(), 15022702cc7SSlava Zakharin [](Type ty) { return getElementTypeOrSelf(ty); }); 15122702cc7SSlava Zakharin return FunctionType::get(op->getContext(), inputTys, resultTys); 15222702cc7SSlava Zakharin } 15322702cc7SSlava Zakharin 154223c54c4SSlava Zakharin /// Create linkonce_odr function to implement the power function with 15522702cc7SSlava Zakharin /// the given \p elementType type inside \p module. The \p elementType 15622702cc7SSlava Zakharin /// must be IntegerType, an the created function has 157223c54c4SSlava Zakharin /// 'IntegerType (*)(IntegerType, IntegerType)' function type. 
///
///   template <typename T>
///   T __mlir_math_ipowi_*(T b, T p) {
///     if (p == T(0))
///       return T(1);
///     if (p < T(0)) {
///       if (b == T(0))
///         return T(1) / T(0); // trigger div-by-zero
///       if (b == T(1))
///         return T(1);
///       if (b == T(-1)) {
///         if (p & T(1))
///           return T(-1);
///         return T(1);
///       }
///       return T(0);
///     }
///     T result = T(1);
///     while (true) {
///       if (p & T(1))
///         result *= b;
///       p >>= T(1);
///       if (p == T(0))
///         return result;
///       b *= b;
///     }
///   }
///
/// The function is appended at the end of the module body, is named
/// "__mlir_math_ipowi_<elementType>", is marked private and carries an
/// "llvm.linkage" attribute set to linkonce_odr. The body is built as an
/// explicit CFG of cf.br / cf.cond_br blocks; the order of the createBlock
/// calls below determines the order of the blocks in the function.
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
  assert(isa<IntegerType>(elementType) &&
         "non-integer element type for IPowIOp");

  ImplicitLocOpBuilder builder =
      ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());

  // The element type is streamed into the name, giving one function per
  // integer type.
  std::string funcName("__mlir_math_ipowi");
  llvm::raw_string_ostream nameOS(funcName);
  nameOS << '_' << elementType;

  FunctionType funcType = FunctionType::get(
      builder.getContext(), {elementType, elementType}, elementType);
  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
  Attribute linkage =
      LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
  funcOp->setAttr("llvm.linkage", linkage);
  funcOp.setPrivate();

  Block *entryBlock = funcOp.addEntryBlock();
  Region *funcBody = entryBlock->getParent();

  // Argument 0 is the base 'b', argument 1 is the power 'p'.
  Value bArg = funcOp.getArgument(0);
  Value pArg = funcOp.getArgument(1);
  builder.setInsertionPointToEnd(entryBlock);
  // Constants 0, 1 and -1 of the element type, all created in the entry
  // block and reused by the blocks built below.
  Value zeroValue = builder.create<arith::ConstantOp>(
      elementType, builder.getIntegerAttr(elementType, 0));
  Value oneValue = builder.create<arith::ConstantOp>(
      elementType, builder.getIntegerAttr(elementType, 1));
  Value minusOneValue = builder.create<arith::ConstantOp>(
      elementType,
      builder.getIntegerAttr(elementType,
                             APInt(elementType.getIntOrFloatBitWidth(), -1ULL,
                                   /*isSigned=*/true)));

  // if (p == T(0))
  //   return T(1);
  auto pIsZero =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue);
  // The return below is emitted into the block that was just created.
  Block *thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(oneValue);
  Block *fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (p == T(0)).
  builder.setInsertionPointToEnd(pIsZero->getBlock());
  builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);

  // if (p < T(0)) {
  //
  // The predicate used is 'sle'; p == 0 has already been branched to the
  // return above, so on this path it is equivalent to a strict less-than.
  builder.setInsertionPointToEnd(fallthroughBlock);
  auto pIsNeg =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue);
  //   if (b == T(0))
  builder.createBlock(funcBody);
  auto bIsZero =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue);
  //     return T(1) / T(0);
  thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(
      builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult());
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (b == T(0)).
  builder.setInsertionPointToEnd(bIsZero->getBlock());
  builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock);

  //   if (b == T(1))
  builder.setInsertionPointToEnd(fallthroughBlock);
  auto bIsOne =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue);
  //     return T(1);
  thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(oneValue);
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (b == T(1)).
  builder.setInsertionPointToEnd(bIsOne->getBlock());
  builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock);

  //   if (b == T(-1)) {
  builder.setInsertionPointToEnd(fallthroughBlock);
  auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                                   bArg, minusOneValue);
  //     if (p & T(1))
  builder.createBlock(funcBody);
  auto pIsOdd = builder.create<arith::CmpIOp>(
      arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue),
      zeroValue);
  //       return T(-1);
  thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(minusOneValue);
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (p & T(1)).
  builder.setInsertionPointToEnd(pIsOdd->getBlock());
  builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock);

  //     return T(1);
  //   } // b == T(-1)
  builder.setInsertionPointToEnd(fallthroughBlock);
  builder.create<func::ReturnOp>(oneValue);
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (b == T(-1)).
  builder.setInsertionPointToEnd(bIsMinusOne->getBlock());
  builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(),
                                   fallthroughBlock);

  //   return T(0);
  // } // (p < T(0))
  builder.setInsertionPointToEnd(fallthroughBlock);
  builder.create<func::ReturnOp>(zeroValue);
  // The loop header takes three block arguments: 'result', 'b' and 'p'.
  Block *loopHeader = builder.createBlock(
      funcBody, funcBody->end(), {elementType, elementType, elementType},
      {builder.getLoc(), builder.getLoc(), builder.getLoc()});
  // Set up conditional branch for (p < T(0)).
  builder.setInsertionPointToEnd(pIsNeg->getBlock());
  // Set initial values of 'result', 'b' and 'p' for the loop.
  builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader,
                                   ValueRange{oneValue, bArg, pArg});

  // T result = T(1);
  // while (true) {
  //   if (p & T(1))
  //     result *= b;
  //   p >>= T(1);
  //   if (p == T(0))
  //     return result;
  //   b *= b;
  // }
  Value resultTmp = loopHeader->getArgument(0);
  Value baseTmp = loopHeader->getArgument(1);
  Value powerTmp = loopHeader->getArgument(2);
  builder.setInsertionPointToEnd(loopHeader);

  // if (p & T(1))
  auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
      arith::CmpIPredicate::ne,
      builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue);
  thenBlock = builder.createBlock(funcBody);
  //   result *= b;
  Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp);
  // The merge block receives 'result' as a block argument: the multiplied
  // value from the then-block, or the unchanged value from the loop header.
  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType,
                                         builder.getLoc());
  builder.setInsertionPointToEnd(thenBlock);
  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
  // Set up conditional branch for (p & T(1)).
  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
  builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
                                   resultTmp);
  // Merged 'result'.
  newResultTmp = fallthroughBlock->getArgument(0);

  // p >>= T(1);
  //
  // An unsigned shift is used; negative powers were branched away before
  // the loop was entered.
  builder.setInsertionPointToEnd(fallthroughBlock);
  Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue);

  // if (p == T(0))
  auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                                      newPowerTmp, zeroValue);
  //   return result;
  thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(newResultTmp);
  fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (p == T(0)).
  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
  builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock);

  // b *= b;
  // }
  builder.setInsertionPointToEnd(fallthroughBlock);
  Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp);
  // Pass new values for 'result', 'b' and 'p' to the loop header.
  builder.create<cf::BranchOp>(
      ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);
  return funcOp;
}

/// Convert IPowI into a call to a local function implementing
/// the power operation. The local function computes a scalar result,
/// so vector forms of IPowI are linearized.
///
/// This pattern itself only matches when the first operand has IntegerType;
/// any other operand type (including vectors) is reported as a match
/// failure. It also fails when the callback returns a null FuncOp.
LogicalResult
IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
                                 PatternRewriter &rewriter) const {
  auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());

  if (!baseType)
    return rewriter.notifyMatchFailure(op, "non-integer base operand");

  // The outlined software implementation must have been already
  // generated.
  func::FuncOp elementFunc = getFuncOpCallback(op, baseType);
  if (!elementFunc)
    return rewriter.notifyMatchFailure(op, "missing software implementation");

  // The call takes the original operands unchanged.
  rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands());
  return success();
}

/// Create linkonce_odr function to implement the power function with
/// the given \p funcType type inside \p module. The \p funcType must be
/// 'FloatType (*)(FloatType, IntegerType)' function type.
///
///   template <typename Tb, typename Tp>
///   Tb __mlir_math_fpowi_*(Tb b, Tp p) {
///     if (p == Tp{0})
///       return Tb{1};
///     bool isNegativePower{p < Tp{0}};
///     bool isMin{p == std::numeric_limits<Tp>::min()};
///     if (isMin) {
///       p = std::numeric_limits<Tp>::max();
///     } else if (isNegativePower) {
///       p = -p;
///     }
///     Tb result = Tb{1};
///     Tb origBase = Tb{b};
///     while (true) {
///       if (p & Tp{1})
///         result *= b;
///       p >>= Tp{1};
///       if (p == Tp{0})
///         break;
///       b *= b;
///     }
///     if (isMin) {
///       result *= origBase;
///     }
///     if (isNegativePower) {
///       result = Tb{1} / result;
///     }
///     return result;
///   }
///
/// The function is appended at the end of the module body, is named
/// "__mlir_math_fpowi_<baseType>_<powType>", is marked private and carries
/// an "llvm.linkage" attribute set to linkonce_odr. The body is built as an
/// explicit CFG of cf.br / cf.cond_br blocks; the order of the createBlock
/// calls below determines the order of the blocks in the function.
static func::FuncOp createElementFPowIFunc(ModuleOp *module,
                                           FunctionType funcType) {
  // cast<> asserts on mismatch: input 0 must be a float type and input 1
  // an integer type.
  auto baseType = cast<FloatType>(funcType.getInput(0));
  auto powType = cast<IntegerType>(funcType.getInput(1));
  ImplicitLocOpBuilder builder =
      ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());

  // Both the base and the power type are streamed into the name.
  std::string funcName("__mlir_math_fpowi");
  llvm::raw_string_ostream nameOS(funcName);
  nameOS << '_' << baseType;
  nameOS << '_' << powType;
  auto funcOp = builder.create<func::FuncOp>(funcName, funcType);
  LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR;
  Attribute linkage =
      LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
  funcOp->setAttr("llvm.linkage", linkage);
  funcOp.setPrivate();

  Block *entryBlock = funcOp.addEntryBlock();
  Region *funcBody = entryBlock->getParent();

  // Argument 0 is the base 'b', argument 1 is the power 'p'.
  Value bArg = funcOp.getArgument(0);
  Value pArg = funcOp.getArgument(1);
  builder.setInsertionPointToEnd(entryBlock);
  // Constants created in the entry block: 1.0 of the base type, and
  // 0, 1, signed-min and signed-max of the power type.
  Value oneBValue = builder.create<arith::ConstantOp>(
      baseType, builder.getFloatAttr(baseType, 1.0));
  Value zeroPValue = builder.create<arith::ConstantOp>(
      powType, builder.getIntegerAttr(powType, 0));
  Value onePValue = builder.create<arith::ConstantOp>(
      powType, builder.getIntegerAttr(powType, 1));
  Value minPValue = builder.create<arith::ConstantOp>(
      powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue(
                                                   powType.getWidth())));
  Value maxPValue = builder.create<arith::ConstantOp>(
      powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue(
                                                   powType.getWidth())));

  // if (p == Tp{0})
  //   return Tb{1};
  auto pIsZero =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue);
  // The return below is emitted into the block that was just created.
  Block *thenBlock = builder.createBlock(funcBody);
  builder.create<func::ReturnOp>(oneBValue);
  Block *fallthroughBlock = builder.createBlock(funcBody);
  // Set up conditional branch for (p == Tp{0}).
  builder.setInsertionPointToEnd(pIsZero->getBlock());
  builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock);

  builder.setInsertionPointToEnd(fallthroughBlock);
  // bool isNegativePower{p < Tp{0}};
  //
  // The predicate used is 'sle'; p == 0 has already been branched to the
  // return above, so on this path it is equivalent to a strict less-than.
  auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg,
                                              zeroPValue);
  // bool isMin{p == std::numeric_limits<Tp>::min()};
  auto pIsMin =
      builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue);

  // if (isMin) {
  //   p = std::numeric_limits<Tp>::max();
  // } else if (isNegativePower) {
  //   p = -p;
  // }
  //
  // Expressed as two selects: the negation (0 - p) is always computed, and
  // the isMin select is applied last so it overrides the negated value.
  Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg);
  auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg);
  pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit);

  // Tb result = Tb{1};
  // Tb origBase = Tb{b};
  // while (true) {
  //   if (p & Tp{1})
  //     result *= b;
  //   p >>= Tp{1};
  //   if (p == Tp{0})
  //     break;
  //   b *= b;
  // }
  //
  // The loop header takes three block arguments: 'result', 'b' and 'p'.
  Block *loopHeader = builder.createBlock(
      funcBody, funcBody->end(), {baseType, baseType, powType},
      {builder.getLoc(), builder.getLoc(), builder.getLoc()});
  // Set initial values of 'result', 'b' and 'p' for the loop.
  builder.setInsertionPointToEnd(pInit->getBlock());
  builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit});

  // Create loop body.
  Value resultTmp = loopHeader->getArgument(0);
  Value baseTmp = loopHeader->getArgument(1);
  Value powerTmp = loopHeader->getArgument(2);
  builder.setInsertionPointToEnd(loopHeader);

  // if (p & Tp{1})
  auto powerTmpIsOdd = builder.create<arith::CmpIOp>(
      arith::CmpIPredicate::ne,
      builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue);
  thenBlock = builder.createBlock(funcBody);
  //   result *= b;
  Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp);
  // The merge block receives 'result' as a block argument: the multiplied
  // value from the then-block, or the unchanged value from the loop header.
  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
                                         builder.getLoc());
  builder.setInsertionPointToEnd(thenBlock);
  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);
  // Set up conditional branch for (p & Tp{1}).
  builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock());
  builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock,
                                   resultTmp);
  // Merged 'result'.
  newResultTmp = fallthroughBlock->getArgument(0);

  // p >>= Tp{1};
  builder.setInsertionPointToEnd(fallthroughBlock);
  Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue);

  // if (p == Tp{0})
  auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
                                                      newPowerTmp, zeroPValue);
  //   break;
  //
  // The conditional branch is finalized below with a jump to
  // the loop exit block.
  fallthroughBlock = builder.createBlock(funcBody);

  // b *= b;
  // }
  builder.setInsertionPointToEnd(fallthroughBlock);
  Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp);
  // Pass new values for 'result', 'b' and 'p' to the loop header.
  builder.create<cf::BranchOp>(
      ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader);

  // Set up conditional branch for early loop exit:
  // if (p == Tp{0})
  //   break;
  //
  // The exit block receives the final 'result' as its block argument.
  Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType,
                                        builder.getLoc());
  builder.setInsertionPointToEnd(newPowerIsZero->getBlock());
  builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp,
                                   fallthroughBlock, ValueRange{});

  // if (isMin) {
  //   result *= origBase;
  // }
  //
  // For the minimum power the loop ran with the maximum value instead, so
  // one extra multiplication by the original base argument is applied here.
  newResultTmp = loopExit->getArgument(0);
  thenBlock = builder.createBlock(funcBody);
  fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
                                         builder.getLoc());
  builder.setInsertionPointToEnd(loopExit);
  builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock,
                                   newResultTmp);
  builder.setInsertionPointToEnd(thenBlock);
  newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg);
  builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock);

  // if (isNegativePower) {
  //   result = Tb{1} / result;
  // }
  newResultTmp = fallthroughBlock->getArgument(0);
  thenBlock = builder.createBlock(funcBody);
  Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType,
                                           builder.getLoc());
  builder.setInsertionPointToEnd(fallthroughBlock);
  builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock,
                                   newResultTmp);
  builder.setInsertionPointToEnd(thenBlock);
  newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp);
  builder.create<cf::BranchOp>(newResultTmp, returnBlock);

  // return result;
  builder.setInsertionPointToEnd(returnBlock);
  builder.create<func::ReturnOp>(returnBlock->getArgument(0));

  return funcOp;
}

/// Convert FPowI into a call to a local function implementing
/// the power operation. The local function computes a scalar result,
/// so vector forms of FPowI are linearized.
LogicalResult
FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
                                 PatternRewriter &rewriter) const {
  if (dyn_cast<VectorType>(op.getType()))
    return rewriter.notifyMatchFailure(op, "non-scalar operation");

  FunctionType funcType = getElementalFuncTypeForOp(op);

  // The outlined software implementation must have been already
  // generated.
596bfbccfa1SJeremy Kun func::FuncOp elementFunc = getFuncOpCallback(op, funcType); 59722702cc7SSlava Zakharin if (!elementFunc) 59822702cc7SSlava Zakharin return rewriter.notifyMatchFailure(op, "missing software implementation"); 59922702cc7SSlava Zakharin 60022702cc7SSlava Zakharin rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands()); 60122702cc7SSlava Zakharin return success(); 60222702cc7SSlava Zakharin } 60322702cc7SSlava Zakharin 604bfbccfa1SJeremy Kun /// Create function to implement the ctlz function the given \p elementType type 605bfbccfa1SJeremy Kun /// inside \p module. The \p elementType must be IntegerType, an the created 606bfbccfa1SJeremy Kun /// function has 'IntegerType (*)(IntegerType)' function type. 607bfbccfa1SJeremy Kun /// 608bfbccfa1SJeremy Kun /// template <typename T> 609bfbccfa1SJeremy Kun /// T __mlir_math_ctlz_*(T x) { 610bfbccfa1SJeremy Kun /// bits = sizeof(x) * 8; 611bfbccfa1SJeremy Kun /// if (x == 0) 612bfbccfa1SJeremy Kun /// return bits; 613bfbccfa1SJeremy Kun /// 614bfbccfa1SJeremy Kun /// uint32_t n = 0; 615bfbccfa1SJeremy Kun /// for (int i = 1; i < bits; ++i) { 616bfbccfa1SJeremy Kun /// if (x < 0) continue; 617bfbccfa1SJeremy Kun /// n++; 618bfbccfa1SJeremy Kun /// x <<= 1; 619bfbccfa1SJeremy Kun /// } 620bfbccfa1SJeremy Kun /// return n; 621bfbccfa1SJeremy Kun /// } 622bfbccfa1SJeremy Kun /// 623bfbccfa1SJeremy Kun /// Converts to (for i32): 624bfbccfa1SJeremy Kun /// 625bfbccfa1SJeremy Kun /// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 { 626bfbccfa1SJeremy Kun /// %c_32 = arith.constant 32 : index 627bfbccfa1SJeremy Kun /// %c_0 = arith.constant 0 : i32 628bfbccfa1SJeremy Kun /// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1 629bfbccfa1SJeremy Kun /// %out = scf.if %arg_eq_zero { 630bfbccfa1SJeremy Kun /// scf.yield %c_32 : i32 631bfbccfa1SJeremy Kun /// } else { 632bfbccfa1SJeremy Kun /// %c_1index = arith.constant 1 : index 633bfbccfa1SJeremy Kun /// %c_1i32 = arith.constant 1 : i32 
634bfbccfa1SJeremy Kun /// %n = arith.constant 0 : i32 635bfbccfa1SJeremy Kun /// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index 636bfbccfa1SJeremy Kun /// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) { 637bfbccfa1SJeremy Kun /// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32 638bfbccfa1SJeremy Kun /// %yield_val = scf.if %cond { 639bfbccfa1SJeremy Kun /// scf.yield %arg_iter, %n_iter : i32, i32 640bfbccfa1SJeremy Kun /// } else { 641bfbccfa1SJeremy Kun /// %arg_next = arith.shli %arg_iter, %c_1i32 : i32 642bfbccfa1SJeremy Kun /// %n_next = arith.addi %n_iter, %c_1i32 : i32 643bfbccfa1SJeremy Kun /// scf.yield %arg_next, %n_next : i32, i32 644bfbccfa1SJeremy Kun /// } 645bfbccfa1SJeremy Kun /// scf.yield %yield_val: i32, i32 646bfbccfa1SJeremy Kun /// } 647bfbccfa1SJeremy Kun /// scf.yield %n_out : i32 648bfbccfa1SJeremy Kun /// } 649bfbccfa1SJeremy Kun /// return %out: i32 650bfbccfa1SJeremy Kun /// } 651bfbccfa1SJeremy Kun static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { 6525550c821STres Popp if (!isa<IntegerType>(elementType)) { 653bfbccfa1SJeremy Kun LLVM_DEBUG({ 654bfbccfa1SJeremy Kun DBGS() << "non-integer element type for CtlzFunc; type was: "; 655bfbccfa1SJeremy Kun elementType.print(llvm::dbgs()); 656bfbccfa1SJeremy Kun }); 657bfbccfa1SJeremy Kun llvm_unreachable("non-integer element type"); 658bfbccfa1SJeremy Kun } 659bfbccfa1SJeremy Kun int64_t bitWidth = elementType.getIntOrFloatBitWidth(); 660bfbccfa1SJeremy Kun 661bfbccfa1SJeremy Kun Location loc = module->getLoc(); 662bfbccfa1SJeremy Kun ImplicitLocOpBuilder builder = 663bfbccfa1SJeremy Kun ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody()); 664bfbccfa1SJeremy Kun 665bfbccfa1SJeremy Kun std::string funcName("__mlir_math_ctlz"); 666bfbccfa1SJeremy Kun llvm::raw_string_ostream nameOS(funcName); 667bfbccfa1SJeremy Kun nameOS << '_' << elementType; 668bfbccfa1SJeremy Kun FunctionType funcType = 669bfbccfa1SJeremy Kun 
FunctionType::get(builder.getContext(), {elementType}, elementType); 670bfbccfa1SJeremy Kun auto funcOp = builder.create<func::FuncOp>(funcName, funcType); 671bfbccfa1SJeremy Kun 672bfbccfa1SJeremy Kun // LinkonceODR ensures that there is only one implementation of this function 673bfbccfa1SJeremy Kun // across all math.ctlz functions that are lowered in this way. 674bfbccfa1SJeremy Kun LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; 675bfbccfa1SJeremy Kun Attribute linkage = 676bfbccfa1SJeremy Kun LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); 677bfbccfa1SJeremy Kun funcOp->setAttr("llvm.linkage", linkage); 678bfbccfa1SJeremy Kun funcOp.setPrivate(); 679bfbccfa1SJeremy Kun 680bfbccfa1SJeremy Kun // set the insertion point to the start of the function 681bfbccfa1SJeremy Kun Block *funcBody = funcOp.addEntryBlock(); 682bfbccfa1SJeremy Kun builder.setInsertionPointToStart(funcBody); 683bfbccfa1SJeremy Kun 684bfbccfa1SJeremy Kun Value arg = funcOp.getArgument(0); 685bfbccfa1SJeremy Kun Type indexType = builder.getIndexType(); 686bfbccfa1SJeremy Kun Value bitWidthValue = builder.create<arith::ConstantOp>( 687bfbccfa1SJeremy Kun elementType, builder.getIntegerAttr(elementType, bitWidth)); 688bfbccfa1SJeremy Kun Value zeroValue = builder.create<arith::ConstantOp>( 689bfbccfa1SJeremy Kun elementType, builder.getIntegerAttr(elementType, 0)); 690bfbccfa1SJeremy Kun 691bfbccfa1SJeremy Kun Value inputEqZero = 692bfbccfa1SJeremy Kun builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue); 693bfbccfa1SJeremy Kun 694bfbccfa1SJeremy Kun // if input == 0, return bit width, else enter loop. 
695bfbccfa1SJeremy Kun scf::IfOp ifOp = builder.create<scf::IfOp>( 696bfbccfa1SJeremy Kun elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true); 697bfbccfa1SJeremy Kun ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue); 698bfbccfa1SJeremy Kun 699bfbccfa1SJeremy Kun auto elseBuilder = 700bfbccfa1SJeremy Kun ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front()); 701bfbccfa1SJeremy Kun 702bfbccfa1SJeremy Kun Value oneIndex = elseBuilder.create<arith::ConstantOp>( 703bfbccfa1SJeremy Kun indexType, elseBuilder.getIndexAttr(1)); 704bfbccfa1SJeremy Kun Value oneValue = elseBuilder.create<arith::ConstantOp>( 705bfbccfa1SJeremy Kun elementType, elseBuilder.getIntegerAttr(elementType, 1)); 706bfbccfa1SJeremy Kun Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>( 707bfbccfa1SJeremy Kun indexType, elseBuilder.getIndexAttr(bitWidth)); 708bfbccfa1SJeremy Kun Value nValue = elseBuilder.create<arith::ConstantOp>( 709bfbccfa1SJeremy Kun elementType, elseBuilder.getIntegerAttr(elementType, 0)); 710bfbccfa1SJeremy Kun 711bfbccfa1SJeremy Kun auto loop = elseBuilder.create<scf::ForOp>( 712bfbccfa1SJeremy Kun oneIndex, bitWidthIndex, oneIndex, 713bfbccfa1SJeremy Kun // Initial values for two loop induction variables, the arg which is being 714bfbccfa1SJeremy Kun // shifted left in each iteration, and the n value which tracks the count 715bfbccfa1SJeremy Kun // of leading zeros. 
716bfbccfa1SJeremy Kun ValueRange{arg, nValue}, 717bfbccfa1SJeremy Kun // Callback to build the body of the for loop 718bfbccfa1SJeremy Kun // if (arg < 0) { 719bfbccfa1SJeremy Kun // continue; 720bfbccfa1SJeremy Kun // } else { 721bfbccfa1SJeremy Kun // n++; 722bfbccfa1SJeremy Kun // arg <<= 1; 723bfbccfa1SJeremy Kun // } 724bfbccfa1SJeremy Kun [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { 725bfbccfa1SJeremy Kun Value argIter = args[0]; 726bfbccfa1SJeremy Kun Value nIter = args[1]; 727bfbccfa1SJeremy Kun 728bfbccfa1SJeremy Kun Value argIsNonNegative = b.create<arith::CmpIOp>( 729bfbccfa1SJeremy Kun loc, arith::CmpIPredicate::slt, argIter, zeroValue); 730bfbccfa1SJeremy Kun scf::IfOp ifOp = b.create<scf::IfOp>( 731bfbccfa1SJeremy Kun loc, argIsNonNegative, 732bfbccfa1SJeremy Kun [&](OpBuilder &b, Location loc) { 733bfbccfa1SJeremy Kun // If arg is negative, continue (effectively, break) 734bfbccfa1SJeremy Kun b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter}); 735bfbccfa1SJeremy Kun }, 736bfbccfa1SJeremy Kun [&](OpBuilder &b, Location loc) { 737bfbccfa1SJeremy Kun // Otherwise, increment n and shift arg left. 738bfbccfa1SJeremy Kun Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue); 739bfbccfa1SJeremy Kun Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue); 740bfbccfa1SJeremy Kun b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext}); 741bfbccfa1SJeremy Kun }); 742bfbccfa1SJeremy Kun b.create<scf::YieldOp>(loc, ifOp.getResults()); 743bfbccfa1SJeremy Kun }); 744bfbccfa1SJeremy Kun elseBuilder.create<scf::YieldOp>(loop.getResult(1)); 745bfbccfa1SJeremy Kun 746bfbccfa1SJeremy Kun builder.create<func::ReturnOp>(ifOp.getResult(0)); 747bfbccfa1SJeremy Kun return funcOp; 748bfbccfa1SJeremy Kun } 749bfbccfa1SJeremy Kun 750bfbccfa1SJeremy Kun /// Convert ctlz into a call to a local function implementing the ctlz 751bfbccfa1SJeremy Kun /// operation. 
752bfbccfa1SJeremy Kun LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op, 753bfbccfa1SJeremy Kun PatternRewriter &rewriter) const { 7545550c821STres Popp if (dyn_cast<VectorType>(op.getType())) 755bfbccfa1SJeremy Kun return rewriter.notifyMatchFailure(op, "non-scalar operation"); 756bfbccfa1SJeremy Kun 757bfbccfa1SJeremy Kun Type type = getElementTypeOrSelf(op.getResult().getType()); 758bfbccfa1SJeremy Kun func::FuncOp elementFunc = getFuncOpCallback(op, type); 759bfbccfa1SJeremy Kun if (!elementFunc) 760bfbccfa1SJeremy Kun return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { 761bfbccfa1SJeremy Kun diag << "Missing software implementation for op " << op->getName() 762bfbccfa1SJeremy Kun << " and type " << type; 763bfbccfa1SJeremy Kun }); 764bfbccfa1SJeremy Kun 765bfbccfa1SJeremy Kun rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand()); 766bfbccfa1SJeremy Kun return success(); 767bfbccfa1SJeremy Kun } 768bfbccfa1SJeremy Kun 769223c54c4SSlava Zakharin namespace { 770223c54c4SSlava Zakharin struct ConvertMathToFuncsPass 77167d0d7acSMichele Scuttari : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> { 772039b969bSMichele Scuttari ConvertMathToFuncsPass() = default; 77322702cc7SSlava Zakharin ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options) 77422702cc7SSlava Zakharin : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {} 775223c54c4SSlava Zakharin 776223c54c4SSlava Zakharin void runOnOperation() override; 777223c54c4SSlava Zakharin 778223c54c4SSlava Zakharin private: 77922702cc7SSlava Zakharin // Return true, if this FPowI operation must be converted 78022702cc7SSlava Zakharin // because the width of its exponent's type is greater than 78122702cc7SSlava Zakharin // or equal to minWidthOfFPowIExponent option value. 
78222702cc7SSlava Zakharin bool isFPowIConvertible(math::FPowIOp op); 78322702cc7SSlava Zakharin 784*7ad63c0eSLongsheng Mou // Reture true, if operation is integer type. 785*7ad63c0eSLongsheng Mou bool isConvertible(Operation *op); 786*7ad63c0eSLongsheng Mou 787223c54c4SSlava Zakharin // Generate outlined implementations for power operations 788bfbccfa1SJeremy Kun // and store them in funcImpls map. 789bfbccfa1SJeremy Kun void generateOpImplementations(); 790223c54c4SSlava Zakharin 791bfbccfa1SJeremy Kun // A map between pairs of (operation, type) deduced from operations that this 792bfbccfa1SJeremy Kun // pass will convert, and the corresponding outlined software implementations 793bfbccfa1SJeremy Kun // of these operations for the given type. 794bfbccfa1SJeremy Kun DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls; 795223c54c4SSlava Zakharin }; 796223c54c4SSlava Zakharin } // namespace 797223c54c4SSlava Zakharin 79822702cc7SSlava Zakharin bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) { 79922702cc7SSlava Zakharin auto expTy = 8005550c821STres Popp dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType())); 80122702cc7SSlava Zakharin return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent); 80222702cc7SSlava Zakharin } 80322702cc7SSlava Zakharin 804*7ad63c0eSLongsheng Mou bool ConvertMathToFuncsPass::isConvertible(Operation *op) { 805*7ad63c0eSLongsheng Mou return isa<IntegerType>(getElementTypeOrSelf(op->getResult(0).getType())); 806*7ad63c0eSLongsheng Mou } 807*7ad63c0eSLongsheng Mou 808bfbccfa1SJeremy Kun void ConvertMathToFuncsPass::generateOpImplementations() { 809223c54c4SSlava Zakharin ModuleOp module = getOperation(); 810223c54c4SSlava Zakharin 811223c54c4SSlava Zakharin module.walk([&](Operation *op) { 81222702cc7SSlava Zakharin TypeSwitch<Operation *>(op) 813bfbccfa1SJeremy Kun .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) { 814*7ad63c0eSLongsheng Mou if (!convertCtlz || 
!isConvertible(op)) 815ce47090dSSlava Zakharin return; 816bfbccfa1SJeremy Kun Type resultType = getElementTypeOrSelf(op.getResult().getType()); 817bfbccfa1SJeremy Kun 818bfbccfa1SJeremy Kun // Generate the software implementation of this operation, 819bfbccfa1SJeremy Kun // if it has not been generated yet. 820bfbccfa1SJeremy Kun auto key = std::pair(op->getName(), resultType); 821bfbccfa1SJeremy Kun auto entry = funcImpls.try_emplace(key, func::FuncOp{}); 822bfbccfa1SJeremy Kun if (entry.second) 823bfbccfa1SJeremy Kun entry.first->second = createCtlzFunc(&module, resultType); 824bfbccfa1SJeremy Kun }) 82522702cc7SSlava Zakharin .Case<math::IPowIOp>([&](math::IPowIOp op) { 826*7ad63c0eSLongsheng Mou if (!isConvertible(op)) 827*7ad63c0eSLongsheng Mou return; 828*7ad63c0eSLongsheng Mou 829223c54c4SSlava Zakharin Type resultType = getElementTypeOrSelf(op.getResult().getType()); 830223c54c4SSlava Zakharin 831223c54c4SSlava Zakharin // Generate the software implementation of this operation, 832223c54c4SSlava Zakharin // if it has not been generated yet. 833bfbccfa1SJeremy Kun auto key = std::pair(op->getName(), resultType); 834bfbccfa1SJeremy Kun auto entry = funcImpls.try_emplace(key, func::FuncOp{}); 835223c54c4SSlava Zakharin if (entry.second) 836223c54c4SSlava Zakharin entry.first->second = createElementIPowIFunc(&module, resultType); 83722702cc7SSlava Zakharin }) 83822702cc7SSlava Zakharin .Case<math::FPowIOp>([&](math::FPowIOp op) { 83922702cc7SSlava Zakharin if (!isFPowIConvertible(op)) 84022702cc7SSlava Zakharin return; 84122702cc7SSlava Zakharin 84222702cc7SSlava Zakharin FunctionType funcType = getElementalFuncTypeForOp(op); 84322702cc7SSlava Zakharin 84422702cc7SSlava Zakharin // Generate the software implementation of this operation, 84522702cc7SSlava Zakharin // if it has not been generated yet. 84622702cc7SSlava Zakharin // FPowI implementations are mapped via the FunctionType 84722702cc7SSlava Zakharin // created from the operation's result and operands. 
848bfbccfa1SJeremy Kun auto key = std::pair(op->getName(), funcType); 849bfbccfa1SJeremy Kun auto entry = funcImpls.try_emplace(key, func::FuncOp{}); 85022702cc7SSlava Zakharin if (entry.second) 85122702cc7SSlava Zakharin entry.first->second = createElementFPowIFunc(&module, funcType); 852223c54c4SSlava Zakharin }); 853223c54c4SSlava Zakharin }); 854223c54c4SSlava Zakharin } 855223c54c4SSlava Zakharin 856223c54c4SSlava Zakharin void ConvertMathToFuncsPass::runOnOperation() { 857223c54c4SSlava Zakharin ModuleOp module = getOperation(); 858223c54c4SSlava Zakharin 859223c54c4SSlava Zakharin // Create outlined implementations for power operations. 860bfbccfa1SJeremy Kun generateOpImplementations(); 861223c54c4SSlava Zakharin 862223c54c4SSlava Zakharin RewritePatternSet patterns(&getContext()); 863bfbccfa1SJeremy Kun patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>, 864bfbccfa1SJeremy Kun VecOpToScalarOp<math::CountLeadingZerosOp>>( 86522702cc7SSlava Zakharin patterns.getContext()); 866223c54c4SSlava Zakharin 867bfbccfa1SJeremy Kun // For the given Type Returns FuncOp stored in funcImpls map. 
868bfbccfa1SJeremy Kun auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp { 869bfbccfa1SJeremy Kun auto it = funcImpls.find(std::pair(op->getName(), type)); 870bfbccfa1SJeremy Kun if (it == funcImpls.end()) 871223c54c4SSlava Zakharin return {}; 872223c54c4SSlava Zakharin 873223c54c4SSlava Zakharin return it->second; 874223c54c4SSlava Zakharin }; 87522702cc7SSlava Zakharin patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(), 876bfbccfa1SJeremy Kun getFuncOpByType); 877bfbccfa1SJeremy Kun 878bfbccfa1SJeremy Kun if (convertCtlz) 879bfbccfa1SJeremy Kun patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType); 880223c54c4SSlava Zakharin 881223c54c4SSlava Zakharin ConversionTarget target(getContext()); 882abc362a1SJakub Kuderski target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect, 883bfbccfa1SJeremy Kun func::FuncDialect, scf::SCFDialect, 884bfbccfa1SJeremy Kun vector::VectorDialect>(); 885bfbccfa1SJeremy Kun 886*7ad63c0eSLongsheng Mou target.addDynamicallyLegalOp<math::IPowIOp>( 887*7ad63c0eSLongsheng Mou [this](math::IPowIOp op) { return !isConvertible(op); }); 888*7ad63c0eSLongsheng Mou if (convertCtlz) { 889*7ad63c0eSLongsheng Mou target.addDynamicallyLegalOp<math::CountLeadingZerosOp>( 890*7ad63c0eSLongsheng Mou [this](math::CountLeadingZerosOp op) { return !isConvertible(op); }); 891*7ad63c0eSLongsheng Mou } 89222702cc7SSlava Zakharin target.addDynamicallyLegalOp<math::FPowIOp>( 89322702cc7SSlava Zakharin [this](math::FPowIOp op) { return !isFPowIConvertible(op); }); 894223c54c4SSlava Zakharin if (failed(applyPartialConversion(module, target, std::move(patterns)))) 895223c54c4SSlava Zakharin signalPassFailure(); 896223c54c4SSlava Zakharin } 897