1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert scf.parallel operations into OpenMP 10 // parallel loops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" 15 #include "../PassDetail.h" 16 #include "mlir/Analysis/LoopAnalysis.h" 17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 20 #include "mlir/Dialect/SCF/SCF.h" 21 #include "mlir/Dialect/StandardOps/IR/Ops.h" 22 #include "mlir/IR/ImplicitLocOpBuilder.h" 23 #include "mlir/IR/SymbolTable.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 26 using namespace mlir; 27 28 /// Matches a block containing a "simple" reduction. The expected shape of the 29 /// block is as follows. 30 /// 31 /// ^bb(%arg0, %arg1): 32 /// %0 = OpTy(%arg0, %arg1) 33 /// scf.reduce.return %0 34 template <typename... OpTy> 35 static bool matchSimpleReduction(Block &block) { 36 if (block.empty() || llvm::hasSingleElement(block) || 37 std::next(block.begin(), 2) != block.end()) 38 return false; 39 40 if (block.getNumArguments() != 2) 41 return false; 42 43 SmallVector<Operation *, 4> combinerOps; 44 Value reducedVal = matchReduction({block.getArguments()[1]}, 45 /*redPos=*/0, combinerOps); 46 47 if (!reducedVal || !reducedVal.isa<BlockArgument>() || 48 combinerOps.size() != 1) 49 return false; 50 51 return isa<OpTy...>(combinerOps[0]) && 52 isa<scf::ReduceReturnOp>(block.back()) && 53 block.front().getOperands() == block.getArguments(); 54 } 55 56 /// Matches a block containing a select-based min/max reduction. The types of 57 /// select and compare operations are provided as template arguments. The 58 /// comparison predicates suitable for min and max are provided as function 59 /// arguments. If a reduction is matched, `ifMin` will be set if the reduction 60 /// compute the minimum and unset if it computes the maximum, otherwise it 61 /// remains unmodified. The expected shape of the block is as follows. 62 /// 63 /// ^bb(%arg0, %arg1): 64 /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1) 65 /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here. 66 /// scf.reduce.return %1 67 template < 68 typename CompareOpTy, typename SelectOpTy, 69 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())> 70 static bool 71 matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates, 72 ArrayRef<Predicate> greaterThanPredicates, bool &isMin) { 73 static_assert(llvm::is_one_of<SelectOpTy, SelectOp, LLVM::SelectOp>::value, 74 "only std and llvm select ops are supported"); 75 76 // Expect exactly three operations in the block. 77 if (block.empty() || llvm::hasSingleElement(block) || 78 std::next(block.begin(), 2) == block.end() || 79 std::next(block.begin(), 3) != block.end()) 80 return false; 81 82 // Check op kinds. 83 auto compare = dyn_cast<CompareOpTy>(block.front()); 84 auto select = dyn_cast<SelectOpTy>(block.front().getNextNode()); 85 auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back()); 86 if (!compare || !select || !terminator) 87 return false; 88 89 // Block arguments must be compared. 90 if (compare->getOperands() != block.getArguments()) 91 return false; 92 93 // Detect whether the comparison is less-than or greater-than, otherwise bail. 94 bool isLess; 95 if (llvm::find(lessThanPredicates, compare.getPredicate()) != 96 lessThanPredicates.end()) { 97 isLess = true; 98 } else if (llvm::find(greaterThanPredicates, compare.getPredicate()) != 99 greaterThanPredicates.end()) { 100 isLess = false; 101 } else { 102 return false; 103 } 104 105 if (select.getCondition() != compare.getResult()) 106 return false; 107 108 // Detect if the operands are swapped between cmpf and select. Match the 109 // comparison type with the requested type or with the opposite of the 110 // requested type if the operands are swapped. Use generic accessors because 111 // std and LLVM versions of select have different operand names but identical 112 // positions. 113 constexpr unsigned kTrueValue = 1; 114 constexpr unsigned kFalseValue = 2; 115 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() && 116 select.getOperand(kFalseValue) == compare.getRhs(); 117 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() && 118 select.getOperand(kFalseValue) == compare.getLhs(); 119 if (!sameOperands && !swappedOperands) 120 return false; 121 122 if (select.getResult() != terminator.result()) 123 return false; 124 125 // The reduction is a min if it uses less-than predicates with same operands 126 // or greather-than predicates with swapped operands. Similarly for max. 127 isMin = (isLess && sameOperands) || (!isLess && swappedOperands); 128 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands); 129 } 130 131 /// Returns the float semantics for the given float type. 132 static const llvm::fltSemantics &fltSemanticsForType(FloatType type) { 133 if (type.isF16()) 134 return llvm::APFloat::IEEEhalf(); 135 if (type.isF32()) 136 return llvm::APFloat::IEEEsingle(); 137 if (type.isF64()) 138 return llvm::APFloat::IEEEdouble(); 139 if (type.isF128()) 140 return llvm::APFloat::IEEEquad(); 141 if (type.isBF16()) 142 return llvm::APFloat::BFloat(); 143 if (type.isF80()) 144 return llvm::APFloat::x87DoubleExtended(); 145 llvm_unreachable("unknown float type"); 146 } 147 148 /// Returns an attribute with the minimum (if `min` is set) or the maximum value 149 /// (otherwise) for the given float type. 150 static Attribute minMaxValueForFloat(Type type, bool min) { 151 auto fltType = type.cast<FloatType>(); 152 return FloatAttr::get( 153 type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min)); 154 } 155 156 /// Returns an attribute with the signed integer minimum (if `min` is set) or 157 /// the maximum value (otherwise) for the given integer type, regardless of its 158 /// signedness semantics (only the width is considered). 159 static Attribute minMaxValueForSignedInt(Type type, bool min) { 160 auto intType = type.cast<IntegerType>(); 161 unsigned bitwidth = intType.getWidth(); 162 return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth) 163 : llvm::APInt::getSignedMaxValue(bitwidth)); 164 } 165 166 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or 167 /// the maximum value (otherwise) for the given integer type, regardless of its 168 /// signedness semantics (only the width is considered). 169 static Attribute minMaxValueForUnsignedInt(Type type, bool min) { 170 auto intType = type.cast<IntegerType>(); 171 unsigned bitwidth = intType.getWidth(); 172 return IntegerAttr::get(type, min ? llvm::APInt::getNullValue(bitwidth) 173 : llvm::APInt::getAllOnesValue(bitwidth)); 174 } 175 176 /// Creates an OpenMP reduction declaration and inserts it into the provided 177 /// symbol table. The declaration has a constant initializer with the neutral 178 /// value `initValue`, and the reduction combiner carried over from `reduce`. 179 static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, 180 SymbolTable &symbolTable, 181 scf::ReduceOp reduce, 182 Attribute initValue) { 183 OpBuilder::InsertionGuard guard(builder); 184 auto decl = builder.create<omp::ReductionDeclareOp>( 185 reduce.getLoc(), "__scf_reduction", reduce.operand().getType()); 186 symbolTable.insert(decl); 187 188 Type type = reduce.operand().getType(); 189 builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(), 190 {type}); 191 builder.setInsertionPointToEnd(&decl.initializerRegion().back()); 192 Value init = 193 builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue); 194 builder.create<omp::YieldOp>(reduce.getLoc(), init); 195 196 Operation *terminator = &reduce.getRegion().front().back(); 197 assert(isa<scf::ReduceReturnOp>(terminator) && 198 "expected reduce op to be terminated by redure return"); 199 builder.setInsertionPoint(terminator); 200 builder.replaceOpWithNewOp<omp::YieldOp>(terminator, 201 terminator->getOperands()); 202 builder.inlineRegionBefore(reduce.getRegion(), decl.reductionRegion(), 203 decl.reductionRegion().end()); 204 return decl; 205 } 206 207 /// Adds an atomic reduction combiner to the given OpenMP reduction declaration 208 /// using llvm.atomicrmw of the given kind. 209 static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, 210 LLVM::AtomicBinOp atomicKind, 211 omp::ReductionDeclareOp decl, 212 scf::ReduceOp reduce) { 213 OpBuilder::InsertionGuard guard(builder); 214 Type type = reduce.operand().getType(); 215 Type ptrType = LLVM::LLVMPointerType::get(type); 216 builder.createBlock(&decl.atomicReductionRegion(), 217 decl.atomicReductionRegion().end(), {ptrType, ptrType}); 218 Block *atomicBlock = &decl.atomicReductionRegion().back(); 219 builder.setInsertionPointToEnd(atomicBlock); 220 Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), 221 atomicBlock->getArgument(1)); 222 builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), type, atomicKind, 223 atomicBlock->getArgument(0), loaded, 224 LLVM::AtomicOrdering::monotonic); 225 builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>()); 226 return decl; 227 } 228 229 /// Creates an OpenMP reduction declaration that corresponds to the given SCF 230 /// reduction and returns it. Recognizes common reductions in order to identify 231 /// the neutral value, necessary for the OpenMP declaration. If the reduction 232 /// cannot be recognized, returns null. 233 static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, 234 scf::ReduceOp reduce) { 235 Operation *container = SymbolTable::getNearestSymbolTable(reduce); 236 SymbolTable symbolTable(container); 237 238 // Insert reduction declarations in the symbol-table ancestor before the 239 // ancestor of the current insertion point. 240 Operation *insertionPoint = reduce; 241 while (insertionPoint->getParentOp() != container) 242 insertionPoint = insertionPoint->getParentOp(); 243 OpBuilder::InsertionGuard guard(builder); 244 builder.setInsertionPoint(insertionPoint); 245 246 assert(llvm::hasSingleElement(reduce.getRegion()) && 247 "expected reduction region to have a single element"); 248 249 // Match simple binary reductions that can be expressed with atomicrmw. 250 Type type = reduce.operand().getType(); 251 Block &reduction = reduce.getRegion().front(); 252 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) { 253 omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 254 builder.getFloatAttr(type, 0.0)); 255 return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce); 256 } 257 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) { 258 omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 259 builder.getIntegerAttr(type, 0)); 260 return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce); 261 } 262 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) { 263 omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 264 builder.getIntegerAttr(type, 0)); 265 return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce); 266 } 267 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) { 268 omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, 269 builder.getIntegerAttr(type, 0)); 270 return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce); 271 } 272 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) { 273 omp::ReductionDeclareOp decl = createDecl( 274 builder, symbolTable, reduce, 275 builder.getIntegerAttr( 276 type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth()))); 277 return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce); 278 } 279 280 // Match simple binary reductions that cannot be expressed with atomicrmw. 281 // TODO: add atomic region using cmpxchg (which needs atomic load to be 282 // available as an op). 283 if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) { 284 return createDecl(builder, symbolTable, reduce, 285 builder.getFloatAttr(type, 1.0)); 286 } 287 288 // Match select-based min/max reductions. 289 bool isMin; 290 if (matchSelectReduction<arith::CmpFOp, SelectOp>( 291 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, 292 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || 293 matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>( 294 reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, 295 {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { 296 return createDecl(builder, symbolTable, reduce, 297 minMaxValueForFloat(type, !isMin)); 298 } 299 if (matchSelectReduction<arith::CmpIOp, SelectOp>( 300 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, 301 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || 302 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( 303 reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, 304 {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { 305 omp::ReductionDeclareOp decl = createDecl( 306 builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin)); 307 return addAtomicRMW(builder, 308 isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, 309 decl, reduce); 310 } 311 if (matchSelectReduction<arith::CmpIOp, SelectOp>( 312 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, 313 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || 314 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( 315 reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, 316 {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { 317 omp::ReductionDeclareOp decl = createDecl( 318 builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin)); 319 return addAtomicRMW( 320 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax, 321 decl, reduce); 322 } 323 324 return nullptr; 325 } 326 327 namespace { 328 329 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { 330 using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; 331 332 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, 333 PatternRewriter &rewriter) const override { 334 // Replace SCF yield with OpenMP yield. 335 { 336 OpBuilder::InsertionGuard guard(rewriter); 337 rewriter.setInsertionPointToEnd(parallelOp.getBody()); 338 assert(llvm::hasSingleElement(parallelOp.region()) && 339 "expected scf.parallel to have one block"); 340 rewriter.replaceOpWithNewOp<omp::YieldOp>( 341 parallelOp.getBody()->getTerminator(), ValueRange()); 342 } 343 344 // Declare reductions. 345 // TODO: consider checking it here is already a compatible reduction 346 // declaration and use it instead of redeclaring. 347 SmallVector<Attribute> reductionDeclSymbols; 348 for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) { 349 omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce); 350 if (!decl) 351 return failure(); 352 reductionDeclSymbols.push_back( 353 SymbolRefAttr::get(rewriter.getContext(), decl.sym_name())); 354 } 355 356 // Allocate reduction variables. Make sure the we don't overflow the stack 357 // with local `alloca`s by saving and restoring the stack pointer. 358 Location loc = parallelOp.getLoc(); 359 Value one = rewriter.create<LLVM::ConstantOp>( 360 loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); 361 SmallVector<Value> reductionVariables; 362 reductionVariables.reserve(parallelOp.getNumReductions()); 363 Value token = rewriter.create<LLVM::StackSaveOp>( 364 loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8))); 365 for (Value init : parallelOp.initVals()) { 366 assert((LLVM::isCompatibleType(init.getType()) || 367 init.getType().isa<LLVM::PointerElementTypeInterface>()) && 368 "cannot create a reduction variable if the type is not an LLVM " 369 "pointer element"); 370 Value storage = rewriter.create<LLVM::AllocaOp>( 371 loc, LLVM::LLVMPointerType::get(init.getType()), one, 0); 372 rewriter.create<LLVM::StoreOp>(loc, init, storage); 373 reductionVariables.push_back(storage); 374 } 375 376 // Replace the reduction operations contained in this loop. Must be done 377 // here rather than in a separate pattern to have access to the list of 378 // reduction variables. 379 for (auto pair : 380 llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) { 381 OpBuilder::InsertionGuard guard(rewriter); 382 scf::ReduceOp reduceOp = std::get<0>(pair); 383 rewriter.setInsertionPoint(reduceOp); 384 rewriter.replaceOpWithNewOp<omp::ReductionOp>( 385 reduceOp, reduceOp.operand(), std::get<1>(pair)); 386 } 387 388 // Create the parallel wrapper. 389 auto ompParallel = rewriter.create<omp::ParallelOp>(loc); 390 { 391 OpBuilder::InsertionGuard guard(rewriter); 392 rewriter.createBlock(&ompParallel.region()); 393 394 // Replace SCF yield with OpenMP yield. 395 { 396 OpBuilder::InsertionGuard innerGuard(rewriter); 397 rewriter.setInsertionPointToEnd(parallelOp.getBody()); 398 assert(llvm::hasSingleElement(parallelOp.region()) && 399 "expected scf.parallel to have one block"); 400 rewriter.replaceOpWithNewOp<omp::YieldOp>( 401 parallelOp.getBody()->getTerminator(), ValueRange()); 402 } 403 404 // Replace the loop. 405 auto loop = rewriter.create<omp::WsLoopOp>( 406 parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(), 407 parallelOp.step()); 408 rewriter.create<omp::TerminatorOp>(loc); 409 410 rewriter.inlineRegionBefore(parallelOp.region(), loop.region(), 411 loop.region().begin()); 412 if (!reductionVariables.empty()) { 413 loop.reductionsAttr( 414 ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); 415 loop.reduction_varsMutable().append(reductionVariables); 416 } 417 } 418 419 // Load loop results. 420 SmallVector<Value> results; 421 results.reserve(reductionVariables.size()); 422 for (Value variable : reductionVariables) { 423 Value res = rewriter.create<LLVM::LoadOp>(loc, variable); 424 results.push_back(res); 425 } 426 rewriter.replaceOp(parallelOp, results); 427 428 rewriter.create<LLVM::StackRestoreOp>(loc, token); 429 return success(); 430 } 431 }; 432 433 /// Applies the conversion patterns in the given function. 434 static LogicalResult applyPatterns(ModuleOp module) { 435 ConversionTarget target(*module.getContext()); 436 target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(); 437 target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect>(); 438 439 RewritePatternSet patterns(module.getContext()); 440 patterns.add<ParallelOpLowering>(module.getContext()); 441 FrozenRewritePatternSet frozen(std::move(patterns)); 442 return applyPartialConversion(module, target, frozen); 443 } 444 445 /// A pass converting SCF operations to OpenMP operations. 446 struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> { 447 /// Pass entry point. 448 void runOnOperation() override { 449 if (failed(applyPatterns(getOperation()))) 450 signalPassFailure(); 451 } 452 }; 453 454 } // namespace 455 456 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() { 457 return std::make_unique<SCFToOpenMPPass>(); 458 } 459