1 //===- MathToFuncs.cpp - Math to outlined implementation conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" 10 11 #include "mlir/Dialect/Arith/IR/Arith.h" 12 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/Utils/IndexingUtils.h" 18 #include "mlir/Dialect/Vector/IR/VectorOps.h" 19 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 20 #include "mlir/IR/ImplicitLocOpBuilder.h" 21 #include "mlir/IR/TypeUtilities.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 #include "llvm/ADT/DenseMap.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/Support/Debug.h" 27 28 namespace mlir { 29 #define GEN_PASS_DEF_CONVERTMATHTOFUNCS 30 #include "mlir/Conversion/Passes.h.inc" 31 } // namespace mlir 32 33 using namespace mlir; 34 35 #define DEBUG_TYPE "math-to-funcs" 36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 37 38 namespace { 39 // Pattern to convert vector operations to scalar operations. 40 template <typename Op> 41 struct VecOpToScalarOp : public OpRewritePattern<Op> { 42 public: 43 using OpRewritePattern<Op>::OpRewritePattern; 44 45 LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; 46 }; 47 48 // Callback type for getting pre-generated FuncOp implementing 49 // an operation of the given type. 50 using GetFuncCallbackTy = function_ref<func::FuncOp(Operation *, Type)>; 51 52 // Pattern to convert scalar IPowIOp into a call of outlined 53 // software implementation. 54 class IPowIOpLowering : public OpRewritePattern<math::IPowIOp> { 55 public: 56 IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb) 57 : OpRewritePattern<math::IPowIOp>(context), getFuncOpCallback(cb) {} 58 59 /// Convert IPowI into a call to a local function implementing 60 /// the power operation. The local function computes a scalar result, 61 /// so vector forms of IPowI are linearized. 62 LogicalResult matchAndRewrite(math::IPowIOp op, 63 PatternRewriter &rewriter) const final; 64 65 private: 66 GetFuncCallbackTy getFuncOpCallback; 67 }; 68 69 // Pattern to convert scalar FPowIOp into a call of outlined 70 // software implementation. 71 class FPowIOpLowering : public OpRewritePattern<math::FPowIOp> { 72 public: 73 FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb) 74 : OpRewritePattern<math::FPowIOp>(context), getFuncOpCallback(cb) {} 75 76 /// Convert FPowI into a call to a local function implementing 77 /// the power operation. The local function computes a scalar result, 78 /// so vector forms of FPowI are linearized. 79 LogicalResult matchAndRewrite(math::FPowIOp op, 80 PatternRewriter &rewriter) const final; 81 82 private: 83 GetFuncCallbackTy getFuncOpCallback; 84 }; 85 86 // Pattern to convert scalar ctlz into a call of outlined software 87 // implementation. 88 class CtlzOpLowering : public OpRewritePattern<math::CountLeadingZerosOp> { 89 public: 90 CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb) 91 : OpRewritePattern<math::CountLeadingZerosOp>(context), 92 getFuncOpCallback(cb) {} 93 94 /// Convert ctlz into a call to a local function implementing 95 /// the count leading zeros operation. 96 LogicalResult matchAndRewrite(math::CountLeadingZerosOp op, 97 PatternRewriter &rewriter) const final; 98 99 private: 100 GetFuncCallbackTy getFuncOpCallback; 101 }; 102 } // namespace 103 104 template <typename Op> 105 LogicalResult 106 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const { 107 Type opType = op.getType(); 108 Location loc = op.getLoc(); 109 auto vecType = dyn_cast<VectorType>(opType); 110 111 if (!vecType) 112 return rewriter.notifyMatchFailure(op, "not a vector operation"); 113 if (!vecType.hasRank()) 114 return rewriter.notifyMatchFailure(op, "unknown vector rank"); 115 ArrayRef<int64_t> shape = vecType.getShape(); 116 int64_t numElements = vecType.getNumElements(); 117 118 Type resultElementType = vecType.getElementType(); 119 Attribute initValueAttr; 120 if (isa<FloatType>(resultElementType)) 121 initValueAttr = FloatAttr::get(resultElementType, 0.0); 122 else 123 initValueAttr = IntegerAttr::get(resultElementType, 0); 124 Value result = rewriter.create<arith::ConstantOp>( 125 loc, DenseElementsAttr::get(vecType, initValueAttr)); 126 SmallVector<int64_t> strides = computeStrides(shape); 127 for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) { 128 SmallVector<int64_t> positions = delinearize(linearIndex, strides); 129 SmallVector<Value> operands; 130 for (Value input : op->getOperands()) 131 operands.push_back( 132 rewriter.create<vector::ExtractOp>(loc, input, positions)); 133 Value scalarOp = 134 rewriter.create<Op>(loc, vecType.getElementType(), operands); 135 result = 136 rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions); 137 } 138 rewriter.replaceOp(op, result); 139 return success(); 140 } 141 142 static FunctionType getElementalFuncTypeForOp(Operation *op) { 143 SmallVector<Type, 1> resultTys(op->getNumResults()); 144 SmallVector<Type, 2> inputTys(op->getNumOperands()); 145 std::transform(op->result_type_begin(), op->result_type_end(), 146 resultTys.begin(), 147 [](Type ty) { return getElementTypeOrSelf(ty); }); 148 std::transform(op->operand_type_begin(), op->operand_type_end(), 149 inputTys.begin(), 150 [](Type ty) { return getElementTypeOrSelf(ty); }); 151 return FunctionType::get(op->getContext(), inputTys, resultTys); 152 } 153 154 /// Create linkonce_odr function to implement the power function with 155 /// the given \p elementType type inside \p module. The \p elementType 156 /// must be IntegerType, an the created function has 157 /// 'IntegerType (*)(IntegerType, IntegerType)' function type. 158 /// 159 /// template <typename T> 160 /// T __mlir_math_ipowi_*(T b, T p) { 161 /// if (p == T(0)) 162 /// return T(1); 163 /// if (p < T(0)) { 164 /// if (b == T(0)) 165 /// return T(1) / T(0); // trigger div-by-zero 166 /// if (b == T(1)) 167 /// return T(1); 168 /// if (b == T(-1)) { 169 /// if (p & T(1)) 170 /// return T(-1); 171 /// return T(1); 172 /// } 173 /// return T(0); 174 /// } 175 /// T result = T(1); 176 /// while (true) { 177 /// if (p & T(1)) 178 /// result *= b; 179 /// p >>= T(1); 180 /// if (p == T(0)) 181 /// return result; 182 /// b *= b; 183 /// } 184 /// } 185 static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { 186 assert(isa<IntegerType>(elementType) && 187 "non-integer element type for IPowIOp"); 188 189 ImplicitLocOpBuilder builder = 190 ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody()); 191 192 std::string funcName("__mlir_math_ipowi"); 193 llvm::raw_string_ostream nameOS(funcName); 194 nameOS << '_' << elementType; 195 196 FunctionType funcType = FunctionType::get( 197 builder.getContext(), {elementType, elementType}, elementType); 198 auto funcOp = builder.create<func::FuncOp>(funcName, funcType); 199 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; 200 Attribute linkage = 201 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); 202 funcOp->setAttr("llvm.linkage", linkage); 203 funcOp.setPrivate(); 204 205 Block *entryBlock = funcOp.addEntryBlock(); 206 Region *funcBody = entryBlock->getParent(); 207 208 Value bArg = funcOp.getArgument(0); 209 Value pArg = funcOp.getArgument(1); 210 builder.setInsertionPointToEnd(entryBlock); 211 Value zeroValue = builder.create<arith::ConstantOp>( 212 elementType, builder.getIntegerAttr(elementType, 0)); 213 Value oneValue = builder.create<arith::ConstantOp>( 214 elementType, builder.getIntegerAttr(elementType, 1)); 215 Value minusOneValue = builder.create<arith::ConstantOp>( 216 elementType, 217 builder.getIntegerAttr(elementType, 218 APInt(elementType.getIntOrFloatBitWidth(), -1ULL, 219 /*isSigned=*/true))); 220 221 // if (p == T(0)) 222 // return T(1); 223 auto pIsZero = 224 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroValue); 225 Block *thenBlock = builder.createBlock(funcBody); 226 builder.create<func::ReturnOp>(oneValue); 227 Block *fallthroughBlock = builder.createBlock(funcBody); 228 // Set up conditional branch for (p == T(0)). 229 builder.setInsertionPointToEnd(pIsZero->getBlock()); 230 builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock); 231 232 // if (p < T(0)) { 233 builder.setInsertionPointToEnd(fallthroughBlock); 234 auto pIsNeg = 235 builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, zeroValue); 236 // if (b == T(0)) 237 builder.createBlock(funcBody); 238 auto bIsZero = 239 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, zeroValue); 240 // return T(1) / T(0); 241 thenBlock = builder.createBlock(funcBody); 242 builder.create<func::ReturnOp>( 243 builder.create<arith::DivSIOp>(oneValue, zeroValue).getResult()); 244 fallthroughBlock = builder.createBlock(funcBody); 245 // Set up conditional branch for (b == T(0)). 246 builder.setInsertionPointToEnd(bIsZero->getBlock()); 247 builder.create<cf::CondBranchOp>(bIsZero, thenBlock, fallthroughBlock); 248 249 // if (b == T(1)) 250 builder.setInsertionPointToEnd(fallthroughBlock); 251 auto bIsOne = 252 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bArg, oneValue); 253 // return T(1); 254 thenBlock = builder.createBlock(funcBody); 255 builder.create<func::ReturnOp>(oneValue); 256 fallthroughBlock = builder.createBlock(funcBody); 257 // Set up conditional branch for (b == T(1)). 258 builder.setInsertionPointToEnd(bIsOne->getBlock()); 259 builder.create<cf::CondBranchOp>(bIsOne, thenBlock, fallthroughBlock); 260 261 // if (b == T(-1)) { 262 builder.setInsertionPointToEnd(fallthroughBlock); 263 auto bIsMinusOne = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, 264 bArg, minusOneValue); 265 // if (p & T(1)) 266 builder.createBlock(funcBody); 267 auto pIsOdd = builder.create<arith::CmpIOp>( 268 arith::CmpIPredicate::ne, builder.create<arith::AndIOp>(pArg, oneValue), 269 zeroValue); 270 // return T(-1); 271 thenBlock = builder.createBlock(funcBody); 272 builder.create<func::ReturnOp>(minusOneValue); 273 fallthroughBlock = builder.createBlock(funcBody); 274 // Set up conditional branch for (p & T(1)). 275 builder.setInsertionPointToEnd(pIsOdd->getBlock()); 276 builder.create<cf::CondBranchOp>(pIsOdd, thenBlock, fallthroughBlock); 277 278 // return T(1); 279 // } // b == T(-1) 280 builder.setInsertionPointToEnd(fallthroughBlock); 281 builder.create<func::ReturnOp>(oneValue); 282 fallthroughBlock = builder.createBlock(funcBody); 283 // Set up conditional branch for (b == T(-1)). 284 builder.setInsertionPointToEnd(bIsMinusOne->getBlock()); 285 builder.create<cf::CondBranchOp>(bIsMinusOne, pIsOdd->getBlock(), 286 fallthroughBlock); 287 288 // return T(0); 289 // } // (p < T(0)) 290 builder.setInsertionPointToEnd(fallthroughBlock); 291 builder.create<func::ReturnOp>(zeroValue); 292 Block *loopHeader = builder.createBlock( 293 funcBody, funcBody->end(), {elementType, elementType, elementType}, 294 {builder.getLoc(), builder.getLoc(), builder.getLoc()}); 295 // Set up conditional branch for (p < T(0)). 296 builder.setInsertionPointToEnd(pIsNeg->getBlock()); 297 // Set initial values of 'result', 'b' and 'p' for the loop. 298 builder.create<cf::CondBranchOp>(pIsNeg, bIsZero->getBlock(), loopHeader, 299 ValueRange{oneValue, bArg, pArg}); 300 301 // T result = T(1); 302 // while (true) { 303 // if (p & T(1)) 304 // result *= b; 305 // p >>= T(1); 306 // if (p == T(0)) 307 // return result; 308 // b *= b; 309 // } 310 Value resultTmp = loopHeader->getArgument(0); 311 Value baseTmp = loopHeader->getArgument(1); 312 Value powerTmp = loopHeader->getArgument(2); 313 builder.setInsertionPointToEnd(loopHeader); 314 315 // if (p & T(1)) 316 auto powerTmpIsOdd = builder.create<arith::CmpIOp>( 317 arith::CmpIPredicate::ne, 318 builder.create<arith::AndIOp>(powerTmp, oneValue), zeroValue); 319 thenBlock = builder.createBlock(funcBody); 320 // result *= b; 321 Value newResultTmp = builder.create<arith::MulIOp>(resultTmp, baseTmp); 322 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType, 323 builder.getLoc()); 324 builder.setInsertionPointToEnd(thenBlock); 325 builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock); 326 // Set up conditional branch for (p & T(1)). 327 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); 328 builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock, 329 resultTmp); 330 // Merged 'result'. 331 newResultTmp = fallthroughBlock->getArgument(0); 332 333 // p >>= T(1); 334 builder.setInsertionPointToEnd(fallthroughBlock); 335 Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, oneValue); 336 337 // if (p == T(0)) 338 auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, 339 newPowerTmp, zeroValue); 340 // return result; 341 thenBlock = builder.createBlock(funcBody); 342 builder.create<func::ReturnOp>(newResultTmp); 343 fallthroughBlock = builder.createBlock(funcBody); 344 // Set up conditional branch for (p == T(0)). 345 builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); 346 builder.create<cf::CondBranchOp>(newPowerIsZero, thenBlock, fallthroughBlock); 347 348 // b *= b; 349 // } 350 builder.setInsertionPointToEnd(fallthroughBlock); 351 Value newBaseTmp = builder.create<arith::MulIOp>(baseTmp, baseTmp); 352 // Pass new values for 'result', 'b' and 'p' to the loop header. 353 builder.create<cf::BranchOp>( 354 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); 355 return funcOp; 356 } 357 358 /// Convert IPowI into a call to a local function implementing 359 /// the power operation. The local function computes a scalar result, 360 /// so vector forms of IPowI are linearized. 361 LogicalResult 362 IPowIOpLowering::matchAndRewrite(math::IPowIOp op, 363 PatternRewriter &rewriter) const { 364 auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType()); 365 366 if (!baseType) 367 return rewriter.notifyMatchFailure(op, "non-integer base operand"); 368 369 // The outlined software implementation must have been already 370 // generated. 371 func::FuncOp elementFunc = getFuncOpCallback(op, baseType); 372 if (!elementFunc) 373 return rewriter.notifyMatchFailure(op, "missing software implementation"); 374 375 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands()); 376 return success(); 377 } 378 379 /// Create linkonce_odr function to implement the power function with 380 /// the given \p funcType type inside \p module. The \p funcType must be 381 /// 'FloatType (*)(FloatType, IntegerType)' function type. 382 /// 383 /// template <typename T> 384 /// Tb __mlir_math_fpowi_*(Tb b, Tp p) { 385 /// if (p == Tp{0}) 386 /// return Tb{1}; 387 /// bool isNegativePower{p < Tp{0}} 388 /// bool isMin{p == std::numeric_limits<Tp>::min()}; 389 /// if (isMin) { 390 /// p = std::numeric_limits<Tp>::max(); 391 /// } else if (isNegativePower) { 392 /// p = -p; 393 /// } 394 /// Tb result = Tb{1}; 395 /// Tb origBase = Tb{b}; 396 /// while (true) { 397 /// if (p & Tp{1}) 398 /// result *= b; 399 /// p >>= Tp{1}; 400 /// if (p == Tp{0}) 401 /// break; 402 /// b *= b; 403 /// } 404 /// if (isMin) { 405 /// result *= origBase; 406 /// } 407 /// if (isNegativePower) { 408 /// result = Tb{1} / result; 409 /// } 410 /// return result; 411 /// } 412 static func::FuncOp createElementFPowIFunc(ModuleOp *module, 413 FunctionType funcType) { 414 auto baseType = cast<FloatType>(funcType.getInput(0)); 415 auto powType = cast<IntegerType>(funcType.getInput(1)); 416 ImplicitLocOpBuilder builder = 417 ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody()); 418 419 std::string funcName("__mlir_math_fpowi"); 420 llvm::raw_string_ostream nameOS(funcName); 421 nameOS << '_' << baseType; 422 nameOS << '_' << powType; 423 auto funcOp = builder.create<func::FuncOp>(funcName, funcType); 424 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; 425 Attribute linkage = 426 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); 427 funcOp->setAttr("llvm.linkage", linkage); 428 funcOp.setPrivate(); 429 430 Block *entryBlock = funcOp.addEntryBlock(); 431 Region *funcBody = entryBlock->getParent(); 432 433 Value bArg = funcOp.getArgument(0); 434 Value pArg = funcOp.getArgument(1); 435 builder.setInsertionPointToEnd(entryBlock); 436 Value oneBValue = builder.create<arith::ConstantOp>( 437 baseType, builder.getFloatAttr(baseType, 1.0)); 438 Value zeroPValue = builder.create<arith::ConstantOp>( 439 powType, builder.getIntegerAttr(powType, 0)); 440 Value onePValue = builder.create<arith::ConstantOp>( 441 powType, builder.getIntegerAttr(powType, 1)); 442 Value minPValue = builder.create<arith::ConstantOp>( 443 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMinValue( 444 powType.getWidth()))); 445 Value maxPValue = builder.create<arith::ConstantOp>( 446 powType, builder.getIntegerAttr(powType, llvm::APInt::getSignedMaxValue( 447 powType.getWidth()))); 448 449 // if (p == Tp{0}) 450 // return Tb{1}; 451 auto pIsZero = 452 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, zeroPValue); 453 Block *thenBlock = builder.createBlock(funcBody); 454 builder.create<func::ReturnOp>(oneBValue); 455 Block *fallthroughBlock = builder.createBlock(funcBody); 456 // Set up conditional branch for (p == Tp{0}). 457 builder.setInsertionPointToEnd(pIsZero->getBlock()); 458 builder.create<cf::CondBranchOp>(pIsZero, thenBlock, fallthroughBlock); 459 460 builder.setInsertionPointToEnd(fallthroughBlock); 461 // bool isNegativePower{p < Tp{0}} 462 auto pIsNeg = builder.create<arith::CmpIOp>(arith::CmpIPredicate::sle, pArg, 463 zeroPValue); 464 // bool isMin{p == std::numeric_limits<Tp>::min()}; 465 auto pIsMin = 466 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, pArg, minPValue); 467 468 // if (isMin) { 469 // p = std::numeric_limits<Tp>::max(); 470 // } else if (isNegativePower) { 471 // p = -p; 472 // } 473 Value negP = builder.create<arith::SubIOp>(zeroPValue, pArg); 474 auto pInit = builder.create<arith::SelectOp>(pIsNeg, negP, pArg); 475 pInit = builder.create<arith::SelectOp>(pIsMin, maxPValue, pInit); 476 477 // Tb result = Tb{1}; 478 // Tb origBase = Tb{b}; 479 // while (true) { 480 // if (p & Tp{1}) 481 // result *= b; 482 // p >>= Tp{1}; 483 // if (p == Tp{0}) 484 // break; 485 // b *= b; 486 // } 487 Block *loopHeader = builder.createBlock( 488 funcBody, funcBody->end(), {baseType, baseType, powType}, 489 {builder.getLoc(), builder.getLoc(), builder.getLoc()}); 490 // Set initial values of 'result', 'b' and 'p' for the loop. 491 builder.setInsertionPointToEnd(pInit->getBlock()); 492 builder.create<cf::BranchOp>(loopHeader, ValueRange{oneBValue, bArg, pInit}); 493 494 // Create loop body. 495 Value resultTmp = loopHeader->getArgument(0); 496 Value baseTmp = loopHeader->getArgument(1); 497 Value powerTmp = loopHeader->getArgument(2); 498 builder.setInsertionPointToEnd(loopHeader); 499 500 // if (p & Tp{1}) 501 auto powerTmpIsOdd = builder.create<arith::CmpIOp>( 502 arith::CmpIPredicate::ne, 503 builder.create<arith::AndIOp>(powerTmp, onePValue), zeroPValue); 504 thenBlock = builder.createBlock(funcBody); 505 // result *= b; 506 Value newResultTmp = builder.create<arith::MulFOp>(resultTmp, baseTmp); 507 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType, 508 builder.getLoc()); 509 builder.setInsertionPointToEnd(thenBlock); 510 builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock); 511 // Set up conditional branch for (p & Tp{1}). 512 builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); 513 builder.create<cf::CondBranchOp>(powerTmpIsOdd, thenBlock, fallthroughBlock, 514 resultTmp); 515 // Merged 'result'. 516 newResultTmp = fallthroughBlock->getArgument(0); 517 518 // p >>= Tp{1}; 519 builder.setInsertionPointToEnd(fallthroughBlock); 520 Value newPowerTmp = builder.create<arith::ShRUIOp>(powerTmp, onePValue); 521 522 // if (p == Tp{0}) 523 auto newPowerIsZero = builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, 524 newPowerTmp, zeroPValue); 525 // break; 526 // 527 // The conditional branch is finalized below with a jump to 528 // the loop exit block. 529 fallthroughBlock = builder.createBlock(funcBody); 530 531 // b *= b; 532 // } 533 builder.setInsertionPointToEnd(fallthroughBlock); 534 Value newBaseTmp = builder.create<arith::MulFOp>(baseTmp, baseTmp); 535 // Pass new values for 'result', 'b' and 'p' to the loop header. 536 builder.create<cf::BranchOp>( 537 ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); 538 539 // Set up conditional branch for early loop exit: 540 // if (p == Tp{0}) 541 // break; 542 Block *loopExit = builder.createBlock(funcBody, funcBody->end(), baseType, 543 builder.getLoc()); 544 builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); 545 builder.create<cf::CondBranchOp>(newPowerIsZero, loopExit, newResultTmp, 546 fallthroughBlock, ValueRange{}); 547 548 // if (isMin) { 549 // result *= origBase; 550 // } 551 newResultTmp = loopExit->getArgument(0); 552 thenBlock = builder.createBlock(funcBody); 553 fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), baseType, 554 builder.getLoc()); 555 builder.setInsertionPointToEnd(loopExit); 556 builder.create<cf::CondBranchOp>(pIsMin, thenBlock, fallthroughBlock, 557 newResultTmp); 558 builder.setInsertionPointToEnd(thenBlock); 559 newResultTmp = builder.create<arith::MulFOp>(newResultTmp, bArg); 560 builder.create<cf::BranchOp>(newResultTmp, fallthroughBlock); 561 562 /// if (isNegativePower) { 563 /// result = Tb{1} / result; 564 /// } 565 newResultTmp = fallthroughBlock->getArgument(0); 566 thenBlock = builder.createBlock(funcBody); 567 Block *returnBlock = builder.createBlock(funcBody, funcBody->end(), baseType, 568 builder.getLoc()); 569 builder.setInsertionPointToEnd(fallthroughBlock); 570 builder.create<cf::CondBranchOp>(pIsNeg, thenBlock, returnBlock, 571 newResultTmp); 572 builder.setInsertionPointToEnd(thenBlock); 573 newResultTmp = builder.create<arith::DivFOp>(oneBValue, newResultTmp); 574 builder.create<cf::BranchOp>(newResultTmp, returnBlock); 575 576 // return result; 577 builder.setInsertionPointToEnd(returnBlock); 578 builder.create<func::ReturnOp>(returnBlock->getArgument(0)); 579 580 return funcOp; 581 } 582 583 /// Convert FPowI into a call to a local function implementing 584 /// the power operation. The local function computes a scalar result, 585 /// so vector forms of FPowI are linearized. 586 LogicalResult 587 FPowIOpLowering::matchAndRewrite(math::FPowIOp op, 588 PatternRewriter &rewriter) const { 589 if (dyn_cast<VectorType>(op.getType())) 590 return rewriter.notifyMatchFailure(op, "non-scalar operation"); 591 592 FunctionType funcType = getElementalFuncTypeForOp(op); 593 594 // The outlined software implementation must have been already 595 // generated. 596 func::FuncOp elementFunc = getFuncOpCallback(op, funcType); 597 if (!elementFunc) 598 return rewriter.notifyMatchFailure(op, "missing software implementation"); 599 600 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperands()); 601 return success(); 602 } 603 604 /// Create function to implement the ctlz function the given \p elementType type 605 /// inside \p module. The \p elementType must be IntegerType, an the created 606 /// function has 'IntegerType (*)(IntegerType)' function type. 607 /// 608 /// template <typename T> 609 /// T __mlir_math_ctlz_*(T x) { 610 /// bits = sizeof(x) * 8; 611 /// if (x == 0) 612 /// return bits; 613 /// 614 /// uint32_t n = 0; 615 /// for (int i = 1; i < bits; ++i) { 616 /// if (x < 0) continue; 617 /// n++; 618 /// x <<= 1; 619 /// } 620 /// return n; 621 /// } 622 /// 623 /// Converts to (for i32): 624 /// 625 /// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 { 626 /// %c_32 = arith.constant 32 : index 627 /// %c_0 = arith.constant 0 : i32 628 /// %arg_eq_zero = arith.cmpi eq, %arg, %c_0 : i1 629 /// %out = scf.if %arg_eq_zero { 630 /// scf.yield %c_32 : i32 631 /// } else { 632 /// %c_1index = arith.constant 1 : index 633 /// %c_1i32 = arith.constant 1 : i32 634 /// %n = arith.constant 0 : i32 635 /// %arg_out, %n_out = scf.for %i = %c_1index to %c_32 step %c_1index 636 /// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) { 637 /// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32 638 /// %yield_val = scf.if %cond { 639 /// scf.yield %arg_iter, %n_iter : i32, i32 640 /// } else { 641 /// %arg_next = arith.shli %arg_iter, %c_1i32 : i32 642 /// %n_next = arith.addi %n_iter, %c_1i32 : i32 643 /// scf.yield %arg_next, %n_next : i32, i32 644 /// } 645 /// scf.yield %yield_val: i32, i32 646 /// } 647 /// scf.yield %n_out : i32 648 /// } 649 /// return %out: i32 650 /// } 651 static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { 652 if (!isa<IntegerType>(elementType)) { 653 LLVM_DEBUG({ 654 DBGS() << "non-integer element type for CtlzFunc; type was: "; 655 elementType.print(llvm::dbgs()); 656 }); 657 llvm_unreachable("non-integer element type"); 658 } 659 int64_t bitWidth = elementType.getIntOrFloatBitWidth(); 660 661 Location loc = module->getLoc(); 662 ImplicitLocOpBuilder builder = 663 ImplicitLocOpBuilder::atBlockEnd(loc, module->getBody()); 664 665 std::string funcName("__mlir_math_ctlz"); 666 llvm::raw_string_ostream nameOS(funcName); 667 nameOS << '_' << elementType; 668 FunctionType funcType = 669 FunctionType::get(builder.getContext(), {elementType}, elementType); 670 auto funcOp = builder.create<func::FuncOp>(funcName, funcType); 671 672 // LinkonceODR ensures that there is only one implementation of this function 673 // across all math.ctlz functions that are lowered in this way. 674 LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; 675 Attribute linkage = 676 LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); 677 funcOp->setAttr("llvm.linkage", linkage); 678 funcOp.setPrivate(); 679 680 // set the insertion point to the start of the function 681 Block *funcBody = funcOp.addEntryBlock(); 682 builder.setInsertionPointToStart(funcBody); 683 684 Value arg = funcOp.getArgument(0); 685 Type indexType = builder.getIndexType(); 686 Value bitWidthValue = builder.create<arith::ConstantOp>( 687 elementType, builder.getIntegerAttr(elementType, bitWidth)); 688 Value zeroValue = builder.create<arith::ConstantOp>( 689 elementType, builder.getIntegerAttr(elementType, 0)); 690 691 Value inputEqZero = 692 builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, arg, zeroValue); 693 694 // if input == 0, return bit width, else enter loop. 695 scf::IfOp ifOp = builder.create<scf::IfOp>( 696 elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true); 697 ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue); 698 699 auto elseBuilder = 700 ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front()); 701 702 Value oneIndex = elseBuilder.create<arith::ConstantOp>( 703 indexType, elseBuilder.getIndexAttr(1)); 704 Value oneValue = elseBuilder.create<arith::ConstantOp>( 705 elementType, elseBuilder.getIntegerAttr(elementType, 1)); 706 Value bitWidthIndex = elseBuilder.create<arith::ConstantOp>( 707 indexType, elseBuilder.getIndexAttr(bitWidth)); 708 Value nValue = elseBuilder.create<arith::ConstantOp>( 709 elementType, elseBuilder.getIntegerAttr(elementType, 0)); 710 711 auto loop = elseBuilder.create<scf::ForOp>( 712 oneIndex, bitWidthIndex, oneIndex, 713 // Initial values for two loop induction variables, the arg which is being 714 // shifted left in each iteration, and the n value which tracks the count 715 // of leading zeros. 716 ValueRange{arg, nValue}, 717 // Callback to build the body of the for loop 718 // if (arg < 0) { 719 // continue; 720 // } else { 721 // n++; 722 // arg <<= 1; 723 // } 724 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { 725 Value argIter = args[0]; 726 Value nIter = args[1]; 727 728 Value argIsNonNegative = b.create<arith::CmpIOp>( 729 loc, arith::CmpIPredicate::slt, argIter, zeroValue); 730 scf::IfOp ifOp = b.create<scf::IfOp>( 731 loc, argIsNonNegative, 732 [&](OpBuilder &b, Location loc) { 733 // If arg is negative, continue (effectively, break) 734 b.create<scf::YieldOp>(loc, ValueRange{argIter, nIter}); 735 }, 736 [&](OpBuilder &b, Location loc) { 737 // Otherwise, increment n and shift arg left. 738 Value nNext = b.create<arith::AddIOp>(loc, nIter, oneValue); 739 Value argNext = b.create<arith::ShLIOp>(loc, argIter, oneValue); 740 b.create<scf::YieldOp>(loc, ValueRange{argNext, nNext}); 741 }); 742 b.create<scf::YieldOp>(loc, ifOp.getResults()); 743 }); 744 elseBuilder.create<scf::YieldOp>(loop.getResult(1)); 745 746 builder.create<func::ReturnOp>(ifOp.getResult(0)); 747 return funcOp; 748 } 749 750 /// Convert ctlz into a call to a local function implementing the ctlz 751 /// operation. 752 LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op, 753 PatternRewriter &rewriter) const { 754 if (dyn_cast<VectorType>(op.getType())) 755 return rewriter.notifyMatchFailure(op, "non-scalar operation"); 756 757 Type type = getElementTypeOrSelf(op.getResult().getType()); 758 func::FuncOp elementFunc = getFuncOpCallback(op, type); 759 if (!elementFunc) 760 return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { 761 diag << "Missing software implementation for op " << op->getName() 762 << " and type " << type; 763 }); 764 765 rewriter.replaceOpWithNewOp<func::CallOp>(op, elementFunc, op.getOperand()); 766 return success(); 767 } 768 769 namespace { 770 struct ConvertMathToFuncsPass 771 : public impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass> { 772 ConvertMathToFuncsPass() = default; 773 ConvertMathToFuncsPass(const ConvertMathToFuncsOptions &options) 774 : impl::ConvertMathToFuncsBase<ConvertMathToFuncsPass>(options) {} 775 776 void runOnOperation() override; 777 778 private: 779 // Return true, if this FPowI operation must be converted 780 // because the width of its exponent's type is greater than 781 // or equal to minWidthOfFPowIExponent option value. 782 bool isFPowIConvertible(math::FPowIOp op); 783 784 // Reture true, if operation is integer type. 785 bool isConvertible(Operation *op); 786 787 // Generate outlined implementations for power operations 788 // and store them in funcImpls map. 789 void generateOpImplementations(); 790 791 // A map between pairs of (operation, type) deduced from operations that this 792 // pass will convert, and the corresponding outlined software implementations 793 // of these operations for the given type. 794 DenseMap<std::pair<OperationName, Type>, func::FuncOp> funcImpls; 795 }; 796 } // namespace 797 798 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) { 799 auto expTy = 800 dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType())); 801 return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent); 802 } 803 804 bool ConvertMathToFuncsPass::isConvertible(Operation *op) { 805 return isa<IntegerType>(getElementTypeOrSelf(op->getResult(0).getType())); 806 } 807 808 void ConvertMathToFuncsPass::generateOpImplementations() { 809 ModuleOp module = getOperation(); 810 811 module.walk([&](Operation *op) { 812 TypeSwitch<Operation *>(op) 813 .Case<math::CountLeadingZerosOp>([&](math::CountLeadingZerosOp op) { 814 if (!convertCtlz || !isConvertible(op)) 815 return; 816 Type resultType = getElementTypeOrSelf(op.getResult().getType()); 817 818 // Generate the software implementation of this operation, 819 // if it has not been generated yet. 820 auto key = std::pair(op->getName(), resultType); 821 auto entry = funcImpls.try_emplace(key, func::FuncOp{}); 822 if (entry.second) 823 entry.first->second = createCtlzFunc(&module, resultType); 824 }) 825 .Case<math::IPowIOp>([&](math::IPowIOp op) { 826 if (!isConvertible(op)) 827 return; 828 829 Type resultType = getElementTypeOrSelf(op.getResult().getType()); 830 831 // Generate the software implementation of this operation, 832 // if it has not been generated yet. 833 auto key = std::pair(op->getName(), resultType); 834 auto entry = funcImpls.try_emplace(key, func::FuncOp{}); 835 if (entry.second) 836 entry.first->second = createElementIPowIFunc(&module, resultType); 837 }) 838 .Case<math::FPowIOp>([&](math::FPowIOp op) { 839 if (!isFPowIConvertible(op)) 840 return; 841 842 FunctionType funcType = getElementalFuncTypeForOp(op); 843 844 // Generate the software implementation of this operation, 845 // if it has not been generated yet. 846 // FPowI implementations are mapped via the FunctionType 847 // created from the operation's result and operands. 848 auto key = std::pair(op->getName(), funcType); 849 auto entry = funcImpls.try_emplace(key, func::FuncOp{}); 850 if (entry.second) 851 entry.first->second = createElementFPowIFunc(&module, funcType); 852 }); 853 }); 854 } 855 856 void ConvertMathToFuncsPass::runOnOperation() { 857 ModuleOp module = getOperation(); 858 859 // Create outlined implementations for power operations. 860 generateOpImplementations(); 861 862 RewritePatternSet patterns(&getContext()); 863 patterns.add<VecOpToScalarOp<math::IPowIOp>, VecOpToScalarOp<math::FPowIOp>, 864 VecOpToScalarOp<math::CountLeadingZerosOp>>( 865 patterns.getContext()); 866 867 // For the given Type Returns FuncOp stored in funcImpls map. 868 auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp { 869 auto it = funcImpls.find(std::pair(op->getName(), type)); 870 if (it == funcImpls.end()) 871 return {}; 872 873 return it->second; 874 }; 875 patterns.add<IPowIOpLowering, FPowIOpLowering>(patterns.getContext(), 876 getFuncOpByType); 877 878 if (convertCtlz) 879 patterns.add<CtlzOpLowering>(patterns.getContext(), getFuncOpByType); 880 881 ConversionTarget target(getContext()); 882 target.addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect, 883 func::FuncDialect, scf::SCFDialect, 884 vector::VectorDialect>(); 885 886 target.addDynamicallyLegalOp<math::IPowIOp>( 887 [this](math::IPowIOp op) { return !isConvertible(op); }); 888 if (convertCtlz) { 889 target.addDynamicallyLegalOp<math::CountLeadingZerosOp>( 890 [this](math::CountLeadingZerosOp op) { return !isConvertible(op); }); 891 } 892 target.addDynamicallyLegalOp<math::FPowIOp>( 893 [this](math::FPowIOp op) { return !isFPowIConvertible(op); }); 894 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 895 signalPassFailure(); 896 } 897