1119545f4SAlex Zinenko //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===// 2119545f4SAlex Zinenko // 3119545f4SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4119545f4SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5119545f4SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6119545f4SAlex Zinenko // 7119545f4SAlex Zinenko //===----------------------------------------------------------------------===// 8119545f4SAlex Zinenko // 9119545f4SAlex Zinenko // This file implements a pass to convert scf.parallel operations into OpenMP 10119545f4SAlex Zinenko // parallel loops. 11119545f4SAlex Zinenko // 12119545f4SAlex Zinenko //===----------------------------------------------------------------------===// 13119545f4SAlex Zinenko 14119545f4SAlex Zinenko #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" 1567d0d7acSMichele Scuttari 16755dc07dSRiver Riddle #include "mlir/Analysis/SliceAnalysis.h" 17755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" 18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 191ce752b7SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 2078fb4f9dSWilliam S. Moses #include "mlir/Dialect/MemRef/IR/MemRef.h" 21119545f4SAlex Zinenko #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 228b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 231ce752b7SAlex Zinenko #include "mlir/IR/ImplicitLocOpBuilder.h" 241ce752b7SAlex Zinenko #include "mlir/IR/SymbolTable.h" 2567d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h" 26119545f4SAlex Zinenko #include "mlir/Transforms/DialectConversion.h" 27119545f4SAlex Zinenko 2867d0d7acSMichele Scuttari namespace mlir { 297c4e45ecSMarkus Böck #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS 3067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 3167d0d7acSMichele Scuttari } // namespace mlir 3267d0d7acSMichele Scuttari 33119545f4SAlex Zinenko using namespace mlir; 34119545f4SAlex Zinenko 351ce752b7SAlex Zinenko /// Matches a block containing a "simple" reduction. The expected shape of the 361ce752b7SAlex Zinenko /// block is as follows. 371ce752b7SAlex Zinenko /// 381ce752b7SAlex Zinenko /// ^bb(%arg0, %arg1): 391ce752b7SAlex Zinenko /// %0 = OpTy(%arg0, %arg1) 401ce752b7SAlex Zinenko /// scf.reduce.return %0 411ce752b7SAlex Zinenko template <typename... OpTy> 421ce752b7SAlex Zinenko static bool matchSimpleReduction(Block &block) { 431ce752b7SAlex Zinenko if (block.empty() || llvm::hasSingleElement(block) || 441ce752b7SAlex Zinenko std::next(block.begin(), 2) != block.end()) 451ce752b7SAlex Zinenko return false; 462a876a71SDiego Caballero 472a876a71SDiego Caballero if (block.getNumArguments() != 2) 482a876a71SDiego Caballero return false; 492a876a71SDiego Caballero 502a876a71SDiego Caballero SmallVector<Operation *, 4> combinerOps; 512a876a71SDiego Caballero Value reducedVal = matchReduction({block.getArguments()[1]}, 522a876a71SDiego Caballero /*redPos=*/0, combinerOps); 532a876a71SDiego Caballero 545550c821STres Popp if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1) 552a876a71SDiego Caballero return false; 562a876a71SDiego Caballero 572a876a71SDiego Caballero return isa<OpTy...>(combinerOps[0]) && 581ce752b7SAlex Zinenko isa<scf::ReduceReturnOp>(block.back()) && 592a876a71SDiego Caballero block.front().getOperands() == block.getArguments(); 601ce752b7SAlex Zinenko } 611ce752b7SAlex Zinenko 621ce752b7SAlex Zinenko /// Matches a block containing a select-based min/max reduction. The types of 631ce752b7SAlex Zinenko /// select and compare operations are provided as template arguments. The 641ce752b7SAlex Zinenko /// comparison predicates suitable for min and max are provided as function 651ce752b7SAlex Zinenko /// arguments. If a reduction is matched, `ifMin` will be set if the reduction 661ce752b7SAlex Zinenko /// compute the minimum and unset if it computes the maximum, otherwise it 671ce752b7SAlex Zinenko /// remains unmodified. The expected shape of the block is as follows. 681ce752b7SAlex Zinenko /// 691ce752b7SAlex Zinenko /// ^bb(%arg0, %arg1): 701ce752b7SAlex Zinenko /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1) 711ce752b7SAlex Zinenko /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here. 721ce752b7SAlex Zinenko /// scf.reduce.return %1 731ce752b7SAlex Zinenko template < 741ce752b7SAlex Zinenko typename CompareOpTy, typename SelectOpTy, 7562fea88bSJacques Pienaar typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())> 761ce752b7SAlex Zinenko static bool 771ce752b7SAlex Zinenko matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates, 781ce752b7SAlex Zinenko ArrayRef<Predicate> greaterThanPredicates, bool &isMin) { 79dec8af70SRiver Riddle static_assert( 80dec8af70SRiver Riddle llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value, 81dec8af70SRiver Riddle "only arithmetic and llvm select ops are supported"); 821ce752b7SAlex Zinenko 831ce752b7SAlex Zinenko // Expect exactly three operations in the block. 841ce752b7SAlex Zinenko if (block.empty() || llvm::hasSingleElement(block) || 851ce752b7SAlex Zinenko std::next(block.begin(), 2) == block.end() || 861ce752b7SAlex Zinenko std::next(block.begin(), 3) != block.end()) 871ce752b7SAlex Zinenko return false; 881ce752b7SAlex Zinenko 891ce752b7SAlex Zinenko // Check op kinds. 901ce752b7SAlex Zinenko auto compare = dyn_cast<CompareOpTy>(block.front()); 911ce752b7SAlex Zinenko auto select = dyn_cast<SelectOpTy>(block.front().getNextNode()); 921ce752b7SAlex Zinenko auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back()); 931ce752b7SAlex Zinenko if (!compare || !select || !terminator) 941ce752b7SAlex Zinenko return false; 951ce752b7SAlex Zinenko 961ce752b7SAlex Zinenko // Block arguments must be compared. 971ce752b7SAlex Zinenko if (compare->getOperands() != block.getArguments()) 981ce752b7SAlex Zinenko return false; 991ce752b7SAlex Zinenko 1001ce752b7SAlex Zinenko // Detect whether the comparison is less-than or greater-than, otherwise bail. 1011ce752b7SAlex Zinenko bool isLess; 102ad44495aSjacquesguan if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) { 1031ce752b7SAlex Zinenko isLess = true; 104ad44495aSjacquesguan } else if (llvm::is_contained(greaterThanPredicates, 105ad44495aSjacquesguan compare.getPredicate())) { 1061ce752b7SAlex Zinenko isLess = false; 1071ce752b7SAlex Zinenko } else { 1081ce752b7SAlex Zinenko return false; 1091ce752b7SAlex Zinenko } 1101ce752b7SAlex Zinenko 111cfb72fd3SJacques Pienaar if (select.getCondition() != compare.getResult()) 1121ce752b7SAlex Zinenko return false; 1131ce752b7SAlex Zinenko 1141ce752b7SAlex Zinenko // Detect if the operands are swapped between cmpf and select. Match the 1151ce752b7SAlex Zinenko // comparison type with the requested type or with the opposite of the 1161ce752b7SAlex Zinenko // requested type if the operands are swapped. Use generic accessors because 1171ce752b7SAlex Zinenko // std and LLVM versions of select have different operand names but identical 1181ce752b7SAlex Zinenko // positions. 1191ce752b7SAlex Zinenko constexpr unsigned kTrueValue = 1; 1201ce752b7SAlex Zinenko constexpr unsigned kFalseValue = 2; 121cfb72fd3SJacques Pienaar bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() && 122cfb72fd3SJacques Pienaar select.getOperand(kFalseValue) == compare.getRhs(); 123cfb72fd3SJacques Pienaar bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() && 124cfb72fd3SJacques Pienaar select.getOperand(kFalseValue) == compare.getLhs(); 1251ce752b7SAlex Zinenko if (!sameOperands && !swappedOperands) 1261ce752b7SAlex Zinenko return false; 1271ce752b7SAlex Zinenko 128c0342a2dSJacques Pienaar if (select.getResult() != terminator.getResult()) 1291ce752b7SAlex Zinenko return false; 1301ce752b7SAlex Zinenko 1311ce752b7SAlex Zinenko // The reduction is a min if it uses less-than predicates with same operands 1321ce752b7SAlex Zinenko // or greather-than predicates with swapped operands. Similarly for max. 1331ce752b7SAlex Zinenko isMin = (isLess && sameOperands) || (!isLess && swappedOperands); 1341ce752b7SAlex Zinenko return isMin || (isLess & swappedOperands) || (!isLess && sameOperands); 1351ce752b7SAlex Zinenko } 1361ce752b7SAlex Zinenko 1371ce752b7SAlex Zinenko /// Returns the float semantics for the given float type. 1381ce752b7SAlex Zinenko static const llvm::fltSemantics &fltSemanticsForType(FloatType type) { 1391ce752b7SAlex Zinenko if (type.isF16()) 1401ce752b7SAlex Zinenko return llvm::APFloat::IEEEhalf(); 1411ce752b7SAlex Zinenko if (type.isF32()) 1421ce752b7SAlex Zinenko return llvm::APFloat::IEEEsingle(); 1431ce752b7SAlex Zinenko if (type.isF64()) 1441ce752b7SAlex Zinenko return llvm::APFloat::IEEEdouble(); 1451ce752b7SAlex Zinenko if (type.isF128()) 1461ce752b7SAlex Zinenko return llvm::APFloat::IEEEquad(); 1471ce752b7SAlex Zinenko if (type.isBF16()) 1481ce752b7SAlex Zinenko return llvm::APFloat::BFloat(); 1491ce752b7SAlex Zinenko if (type.isF80()) 1501ce752b7SAlex Zinenko return llvm::APFloat::x87DoubleExtended(); 1511ce752b7SAlex Zinenko llvm_unreachable("unknown float type"); 1521ce752b7SAlex Zinenko } 1531ce752b7SAlex Zinenko 1541ce752b7SAlex Zinenko /// Returns an attribute with the minimum (if `min` is set) or the maximum value 1551ce752b7SAlex Zinenko /// (otherwise) for the given float type. 1561ce752b7SAlex Zinenko static Attribute minMaxValueForFloat(Type type, bool min) { 1575550c821STres Popp auto fltType = cast<FloatType>(type); 1581ce752b7SAlex Zinenko return FloatAttr::get( 1591ce752b7SAlex Zinenko type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min)); 1601ce752b7SAlex Zinenko } 1611ce752b7SAlex Zinenko 1621ce752b7SAlex Zinenko /// Returns an attribute with the signed integer minimum (if `min` is set) or 1631ce752b7SAlex Zinenko /// the maximum value (otherwise) for the given integer type, regardless of its 1641ce752b7SAlex Zinenko /// signedness semantics (only the width is considered). 1651ce752b7SAlex Zinenko static Attribute minMaxValueForSignedInt(Type type, bool min) { 1665550c821STres Popp auto intType = cast<IntegerType>(type); 1671ce752b7SAlex Zinenko unsigned bitwidth = intType.getWidth(); 1681ce752b7SAlex Zinenko return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth) 1691ce752b7SAlex Zinenko : llvm::APInt::getSignedMaxValue(bitwidth)); 1701ce752b7SAlex Zinenko } 1711ce752b7SAlex Zinenko 1721ce752b7SAlex Zinenko /// Returns an attribute with the unsigned integer minimum (if `min` is set) or 1731ce752b7SAlex Zinenko /// the maximum value (otherwise) for the given integer type, regardless of its 1741ce752b7SAlex Zinenko /// signedness semantics (only the width is considered). 1751ce752b7SAlex Zinenko static Attribute minMaxValueForUnsignedInt(Type type, bool min) { 1765550c821STres Popp auto intType = cast<IntegerType>(type); 1771ce752b7SAlex Zinenko unsigned bitwidth = intType.getWidth(); 1784a05edd4SKazu Hirata return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth) 179b7ffd968SKazu Hirata : llvm::APInt::getAllOnes(bitwidth)); 1801ce752b7SAlex Zinenko } 1811ce752b7SAlex Zinenko 1821ce752b7SAlex Zinenko /// Creates an OpenMP reduction declaration and inserts it into the provided 1831ce752b7SAlex Zinenko /// symbol table. The declaration has a constant initializer with the neutral 18410056c82SMatthias Springer /// value `initValue`, and the `reductionIndex`-th reduction combiner carried 18510056c82SMatthias Springer /// over from `reduce`. 186d84252e0SSergio Afonso static omp::DeclareReductionOp 18710056c82SMatthias Springer createDecl(PatternRewriter &builder, SymbolTable &symbolTable, 18810056c82SMatthias Springer scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) { 1891ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(builder); 19010056c82SMatthias Springer Type type = reduce.getOperands()[reductionIndex].getType(); 191d84252e0SSergio Afonso auto decl = builder.create<omp::DeclareReductionOp>(reduce.getLoc(), 19210056c82SMatthias Springer "__scf_reduction", type); 1931ce752b7SAlex Zinenko symbolTable.insert(decl); 1941ce752b7SAlex Zinenko 1954fb4e12bSRiver Riddle builder.createBlock(&decl.getInitializerRegion(), 1964fb4e12bSRiver Riddle decl.getInitializerRegion().end(), {type}, 19710056c82SMatthias Springer {reduce.getOperands()[reductionIndex].getLoc()}); 1984fb4e12bSRiver Riddle builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); 1991ce752b7SAlex Zinenko Value init = 2001ce752b7SAlex Zinenko builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue); 2011ce752b7SAlex Zinenko builder.create<omp::YieldOp>(reduce.getLoc(), init); 2021ce752b7SAlex Zinenko 20310056c82SMatthias Springer Operation *terminator = 20410056c82SMatthias Springer &reduce.getReductions()[reductionIndex].front().back(); 2051ce752b7SAlex Zinenko assert(isa<scf::ReduceReturnOp>(terminator) && 2061ce752b7SAlex Zinenko "expected reduce op to be terminated by redure return"); 2071ce752b7SAlex Zinenko builder.setInsertionPoint(terminator); 2081ce752b7SAlex Zinenko builder.replaceOpWithNewOp<omp::YieldOp>(terminator, 2091ce752b7SAlex Zinenko terminator->getOperands()); 21010056c82SMatthias Springer builder.inlineRegionBefore(reduce.getReductions()[reductionIndex], 21110056c82SMatthias Springer decl.getReductionRegion(), 2124fb4e12bSRiver Riddle decl.getReductionRegion().end()); 2131ce752b7SAlex Zinenko return decl; 2141ce752b7SAlex Zinenko } 2151ce752b7SAlex Zinenko 2161ce752b7SAlex Zinenko /// Adds an atomic reduction combiner to the given OpenMP reduction declaration 2171ce752b7SAlex Zinenko /// using llvm.atomicrmw of the given kind. 218d84252e0SSergio Afonso static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, 2191ce752b7SAlex Zinenko LLVM::AtomicBinOp atomicKind, 220d84252e0SSergio Afonso omp::DeclareReductionOp decl, 22110056c82SMatthias Springer scf::ReduceOp reduce, 22210056c82SMatthias Springer int64_t reductionIndex) { 2231ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(builder); 224e2564b27SChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); 22510056c82SMatthias Springer Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc(); 2264fb4e12bSRiver Riddle builder.createBlock(&decl.getAtomicReductionRegion(), 2274fb4e12bSRiver Riddle decl.getAtomicReductionRegion().end(), {ptrType, ptrType}, 228e084679fSRiver Riddle {reduceOperandLoc, reduceOperandLoc}); 2294fb4e12bSRiver Riddle Block *atomicBlock = &decl.getAtomicReductionRegion().back(); 2301ce752b7SAlex Zinenko builder.setInsertionPointToEnd(atomicBlock); 2317c4e45ecSMarkus Böck Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(), 2321ce752b7SAlex Zinenko atomicBlock->getArgument(1)); 2337f97895fSTobias Gysi builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind, 2341ce752b7SAlex Zinenko atomicBlock->getArgument(0), loaded, 2351ce752b7SAlex Zinenko LLVM::AtomicOrdering::monotonic); 2361ce752b7SAlex Zinenko builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>()); 2371ce752b7SAlex Zinenko return decl; 2381ce752b7SAlex Zinenko } 2391ce752b7SAlex Zinenko 2401ce752b7SAlex Zinenko /// Creates an OpenMP reduction declaration that corresponds to the given SCF 2411ce752b7SAlex Zinenko /// reduction and returns it. Recognizes common reductions in order to identify 2421ce752b7SAlex Zinenko /// the neutral value, necessary for the OpenMP declaration. If the reduction 2431ce752b7SAlex Zinenko /// cannot be recognized, returns null. 244d84252e0SSergio Afonso static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, 24510056c82SMatthias Springer scf::ReduceOp reduce, 24610056c82SMatthias Springer int64_t reductionIndex) { 2471ce752b7SAlex Zinenko Operation *container = SymbolTable::getNearestSymbolTable(reduce); 2481ce752b7SAlex Zinenko SymbolTable symbolTable(container); 2491ce752b7SAlex Zinenko 2501ce752b7SAlex Zinenko // Insert reduction declarations in the symbol-table ancestor before the 2511ce752b7SAlex Zinenko // ancestor of the current insertion point. 2521ce752b7SAlex Zinenko Operation *insertionPoint = reduce; 2531ce752b7SAlex Zinenko while (insertionPoint->getParentOp() != container) 2541ce752b7SAlex Zinenko insertionPoint = insertionPoint->getParentOp(); 2551ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(builder); 2561ce752b7SAlex Zinenko builder.setInsertionPoint(insertionPoint); 2571ce752b7SAlex Zinenko 25810056c82SMatthias Springer assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) && 2591ce752b7SAlex Zinenko "expected reduction region to have a single element"); 2601ce752b7SAlex Zinenko 2611ce752b7SAlex Zinenko // Match simple binary reductions that can be expressed with atomicrmw. 26210056c82SMatthias Springer Type type = reduce.getOperands()[reductionIndex].getType(); 26310056c82SMatthias Springer Block &reduction = reduce.getReductions()[reductionIndex].front(); 264a54f4eaeSMogball if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) { 265d84252e0SSergio Afonso omp::DeclareReductionOp decl = 26610056c82SMatthias Springer createDecl(builder, symbolTable, reduce, reductionIndex, 2671ce752b7SAlex Zinenko builder.getFloatAttr(type, 0.0)); 26810056c82SMatthias Springer return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce, 26910056c82SMatthias Springer reductionIndex); 2701ce752b7SAlex Zinenko } 271a54f4eaeSMogball if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) { 272d84252e0SSergio Afonso omp::DeclareReductionOp decl = 27310056c82SMatthias Springer createDecl(builder, symbolTable, reduce, reductionIndex, 2741ce752b7SAlex Zinenko builder.getIntegerAttr(type, 0)); 27510056c82SMatthias Springer return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce, 27610056c82SMatthias Springer reductionIndex); 2771ce752b7SAlex Zinenko } 278a54f4eaeSMogball if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) { 279d84252e0SSergio Afonso omp::DeclareReductionOp decl = 28010056c82SMatthias Springer createDecl(builder, symbolTable, reduce, reductionIndex, 2811ce752b7SAlex Zinenko builder.getIntegerAttr(type, 0)); 28210056c82SMatthias Springer return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce, 28310056c82SMatthias Springer reductionIndex); 2841ce752b7SAlex Zinenko } 285a54f4eaeSMogball if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) { 286d84252e0SSergio Afonso omp::DeclareReductionOp decl = 28710056c82SMatthias Springer createDecl(builder, symbolTable, reduce, reductionIndex, 2881ce752b7SAlex Zinenko builder.getIntegerAttr(type, 0)); 28910056c82SMatthias Springer return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce, 29010056c82SMatthias Springer reductionIndex); 2911ce752b7SAlex Zinenko } 292a54f4eaeSMogball if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) { 293d84252e0SSergio Afonso omp::DeclareReductionOp decl = createDecl( 29410056c82SMatthias Springer builder, symbolTable, reduce, reductionIndex, 2951ce752b7SAlex Zinenko builder.getIntegerAttr( 296b7ffd968SKazu Hirata type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth()))); 29710056c82SMatthias Springer return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce, 29810056c82SMatthias Springer reductionIndex); 2991ce752b7SAlex Zinenko } 3001ce752b7SAlex Zinenko 3011ce752b7SAlex Zinenko // Match simple binary reductions that cannot be expressed with atomicrmw. 3021ce752b7SAlex Zinenko // TODO: add atomic region using cmpxchg (which needs atomic load to be 3031ce752b7SAlex Zinenko // available as an op). 304a54f4eaeSMogball if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) { 30510056c82SMatthias Springer return createDecl(builder, symbolTable, reduce, reductionIndex, 3061ce752b7SAlex Zinenko builder.getFloatAttr(type, 1.0)); 3071ce752b7SAlex Zinenko } 308c1125ae5SKiran Chandramohan if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) { 30910056c82SMatthias Springer return createDecl(builder, symbolTable, reduce, reductionIndex, 310c1125ae5SKiran Chandramohan builder.getIntegerAttr(type, 1)); 311c1125ae5SKiran Chandramohan } 3121ce752b7SAlex Zinenko 3131ce752b7SAlex Zinenko // Match select-based min/max reductions. 3141ce752b7SAlex Zinenko bool isMin; 315dec8af70SRiver Riddle if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>( 316a54f4eaeSMogball reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, 317a54f4eaeSMogball {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || 3181ce752b7SAlex Zinenko matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>( 3191ce752b7SAlex Zinenko reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, 3201ce752b7SAlex Zinenko {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { 32110056c82SMatthias Springer return createDecl(builder, symbolTable, reduce, reductionIndex, 3221ce752b7SAlex Zinenko minMaxValueForFloat(type, !isMin)); 3231ce752b7SAlex Zinenko } 324dec8af70SRiver Riddle if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>( 325a54f4eaeSMogball reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, 326a54f4eaeSMogball {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || 3271ce752b7SAlex Zinenko matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( 3281ce752b7SAlex Zinenko reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, 3291ce752b7SAlex Zinenko {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { 330d84252e0SSergio Afonso omp::DeclareReductionOp decl = 33110056c82SMatthias Springer createDecl(builder, symbolTable, reduce, reductionIndex, 33210056c82SMatthias Springer minMaxValueForSignedInt(type, !isMin)); 3331ce752b7SAlex Zinenko return addAtomicRMW(builder, 3341ce752b7SAlex Zinenko isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, 33510056c82SMatthias Springer decl, reduce, reductionIndex); 3361ce752b7SAlex Zinenko } 337dec8af70SRiver Riddle if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>( 338a54f4eaeSMogball reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, 339a54f4eaeSMogball {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || 3401ce752b7SAlex Zinenko matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( 3411ce752b7SAlex Zinenko reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, 3421ce752b7SAlex Zinenko {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { 343d84252e0SSergio Afonso omp::DeclareReductionOp decl = 34410056c82SMatthias Springer createDecl(builder, symbolTable, reduce, reductionIndex, 34510056c82SMatthias Springer minMaxValueForUnsignedInt(type, !isMin)); 3461ce752b7SAlex Zinenko return addAtomicRMW( 3471ce752b7SAlex Zinenko builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax, 34810056c82SMatthias Springer decl, reduce, reductionIndex); 3491ce752b7SAlex Zinenko } 3501ce752b7SAlex Zinenko 3511ce752b7SAlex Zinenko return nullptr; 3521ce752b7SAlex Zinenko } 3531ce752b7SAlex Zinenko 354119545f4SAlex Zinenko namespace { 355119545f4SAlex Zinenko 356119545f4SAlex Zinenko struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { 3577f4f75c1SPablo Antonio Martinez static constexpr unsigned kUseOpenMPDefaultNumThreads = 0; 3587f4f75c1SPablo Antonio Martinez unsigned numThreads; 3597c4e45ecSMarkus Böck 3607f4f75c1SPablo Antonio Martinez ParallelOpLowering(MLIRContext *context, 3617f4f75c1SPablo Antonio Martinez unsigned numThreads = kUseOpenMPDefaultNumThreads) 3627f4f75c1SPablo Antonio Martinez : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {} 363119545f4SAlex Zinenko 364119545f4SAlex Zinenko LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, 365119545f4SAlex Zinenko PatternRewriter &rewriter) const override { 3661ce752b7SAlex Zinenko // Declare reductions. 3671ce752b7SAlex Zinenko // TODO: consider checking it here is already a compatible reduction 3681ce752b7SAlex Zinenko // declaration and use it instead of redeclaring. 369fdfeea5bSSergio Afonso SmallVector<Attribute> reductionSyms; 370d84252e0SSergio Afonso SmallVector<omp::DeclareReductionOp> ompReductionDecls; 37110056c82SMatthias Springer auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator()); 37210056c82SMatthias Springer for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) { 373d84252e0SSergio Afonso omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i); 374be9f8ffdSDavid Truby ompReductionDecls.push_back(decl); 3751ce752b7SAlex Zinenko if (!decl) 3761ce752b7SAlex Zinenko return failure(); 377fdfeea5bSSergio Afonso reductionSyms.push_back( 3784fb4e12bSRiver Riddle SymbolRefAttr::get(rewriter.getContext(), decl.getSymName())); 3791ce752b7SAlex Zinenko } 3801ce752b7SAlex Zinenko 3811ce752b7SAlex Zinenko // Allocate reduction variables. Make sure the we don't overflow the stack 3821ce752b7SAlex Zinenko // with local `alloca`s by saving and restoring the stack pointer. 3831ce752b7SAlex Zinenko Location loc = parallelOp.getLoc(); 3841ce752b7SAlex Zinenko Value one = rewriter.create<LLVM::ConstantOp>( 3851ce752b7SAlex Zinenko loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); 3861ce752b7SAlex Zinenko SmallVector<Value> reductionVariables; 3871ce752b7SAlex Zinenko reductionVariables.reserve(parallelOp.getNumReductions()); 388e2564b27SChristian Ulmann auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext()); 389c0342a2dSJacques Pienaar for (Value init : parallelOp.getInitVals()) { 3901ce752b7SAlex Zinenko assert((LLVM::isCompatibleType(init.getType()) || 3915550c821STres Popp isa<LLVM::PointerElementTypeInterface>(init.getType())) && 3921ce752b7SAlex Zinenko "cannot create a reduction variable if the type is not an LLVM " 3931ce752b7SAlex Zinenko "pointer element"); 394e2564b27SChristian Ulmann Value storage = 395e2564b27SChristian Ulmann rewriter.create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0); 3961ce752b7SAlex Zinenko rewriter.create<LLVM::StoreOp>(loc, init, storage); 3971ce752b7SAlex Zinenko reductionVariables.push_back(storage); 3981ce752b7SAlex Zinenko } 3991ce752b7SAlex Zinenko 4001ce752b7SAlex Zinenko // Replace the reduction operations contained in this loop. Must be done 4011ce752b7SAlex Zinenko // here rather than in a separate pattern to have access to the list of 4021ce752b7SAlex Zinenko // reduction variables. 403be9f8ffdSDavid Truby for (auto [x, y, rD] : llvm::zip_equal( 404be9f8ffdSDavid Truby reductionVariables, reduce.getOperands(), ompReductionDecls)) { 4051ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(rewriter); 40610056c82SMatthias Springer rewriter.setInsertionPoint(reduce); 407be9f8ffdSDavid Truby Region &redRegion = rD.getReductionRegion(); 408be9f8ffdSDavid Truby // The SCF dialect by definition contains only structured operations 409be9f8ffdSDavid Truby // and hence the SCF reduction region will contain a single block. 410be9f8ffdSDavid Truby // The ompReductionDecls region is a copy of the SCF reduction region 411be9f8ffdSDavid Truby // and hence has the same property. 412be9f8ffdSDavid Truby assert(redRegion.hasOneBlock() && 413be9f8ffdSDavid Truby "expect reduction region to have one block"); 414be9f8ffdSDavid Truby Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc); 415be9f8ffdSDavid Truby Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(), 416be9f8ffdSDavid Truby rD.getType(), pvtRedVar); 417be9f8ffdSDavid Truby // Make a copy of the reduction combiner region in the body 418be9f8ffdSDavid Truby mlir::OpBuilder builder(rewriter.getContext()); 419be9f8ffdSDavid Truby builder.setInsertionPoint(reduce); 420be9f8ffdSDavid Truby mlir::IRMapping mapper; 421be9f8ffdSDavid Truby assert(redRegion.getNumArguments() == 2 && 422be9f8ffdSDavid Truby "expect reduction region to have two arguments"); 423be9f8ffdSDavid Truby mapper.map(redRegion.getArgument(0), pvtRedVal); 424be9f8ffdSDavid Truby mapper.map(redRegion.getArgument(1), y); 425be9f8ffdSDavid Truby for (auto &op : redRegion.getOps()) { 426be9f8ffdSDavid Truby Operation *cloneOp = builder.clone(op, mapper); 427be9f8ffdSDavid Truby if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) { 428be9f8ffdSDavid Truby assert(yieldOp && yieldOp.getResults().size() == 1 && 429be9f8ffdSDavid Truby "expect YieldOp in reduction region to return one result"); 430be9f8ffdSDavid Truby Value redVal = yieldOp.getResults()[0]; 431be9f8ffdSDavid Truby rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar); 432be9f8ffdSDavid Truby rewriter.eraseOp(yieldOp); 433be9f8ffdSDavid Truby break; 434be9f8ffdSDavid Truby } 435be9f8ffdSDavid Truby } 4361ce752b7SAlex Zinenko } 43710056c82SMatthias Springer rewriter.eraseOp(reduce); 4381ce752b7SAlex Zinenko 4397f4f75c1SPablo Antonio Martinez Value numThreadsVar; 4407f4f75c1SPablo Antonio Martinez if (numThreads > 0) { 4417f4f75c1SPablo Antonio Martinez numThreadsVar = rewriter.create<LLVM::ConstantOp>( 4427f4f75c1SPablo Antonio Martinez loc, rewriter.getI32IntegerAttr(numThreads)); 4437f4f75c1SPablo Antonio Martinez } 4441ce752b7SAlex Zinenko // Create the parallel wrapper. 4457f4f75c1SPablo Antonio Martinez auto ompParallel = rewriter.create<omp::ParallelOp>( 4467f4f75c1SPablo Antonio Martinez loc, 4477f4f75c1SPablo Antonio Martinez /* allocate_vars = */ llvm::SmallVector<Value>{}, 448fdfeea5bSSergio Afonso /* allocator_vars = */ llvm::SmallVector<Value>{}, 449b3b46963SSergio Afonso /* if_expr = */ Value{}, 450b3b46963SSergio Afonso /* num_threads = */ numThreadsVar, 451b3b46963SSergio Afonso /* private_vars = */ ValueRange(), 452b3b46963SSergio Afonso /* private_syms = */ nullptr, 453b3b46963SSergio Afonso /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{}, 454*afcbcae6SAnchu Rajendran S /* reduction_mod = */ nullptr, 4557f4f75c1SPablo Antonio Martinez /* reduction_vars = */ llvm::SmallVector<Value>{}, 456fdfeea5bSSergio Afonso /* reduction_byref = */ DenseBoolArrayAttr{}, 457b3b46963SSergio Afonso /* reduction_syms = */ ArrayAttr{}); 4581ce752b7SAlex Zinenko { 45978fb4f9dSWilliam S. Moses 4601ce752b7SAlex Zinenko OpBuilder::InsertionGuard guard(rewriter); 4614fb4e12bSRiver Riddle rewriter.createBlock(&ompParallel.getRegion()); 4621ce752b7SAlex Zinenko 463119545f4SAlex Zinenko // Replace the loop. 464bf6477ebSWilliam S. Moses { 465bf6477ebSWilliam S. Moses OpBuilder::InsertionGuard allocaGuard(rewriter); 4668843d541SSergio Afonso // Create worksharing loop wrapper. 4678843d541SSergio Afonso auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc()); 4688843d541SSergio Afonso if (!reductionVariables.empty()) { 469fdfeea5bSSergio Afonso wsloopOp.setReductionSymsAttr( 470fdfeea5bSSergio Afonso ArrayAttr::get(rewriter.getContext(), reductionSyms)); 4718843d541SSergio Afonso wsloopOp.getReductionVarsMutable().append(reductionVariables); 472fdfeea5bSSergio Afonso llvm::SmallVector<bool> reductionByRef; 47374a87548STom Eccles // false because these reductions always reduce scalars and so do 47474a87548STom Eccles // not need to pass by reference 475fdfeea5bSSergio Afonso reductionByRef.resize(reductionVariables.size(), false); 476fdfeea5bSSergio Afonso wsloopOp.setReductionByref( 477fdfeea5bSSergio Afonso DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef)); 4788843d541SSergio Afonso } 4798843d541SSergio Afonso rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator. 4808843d541SSergio Afonso 4818843d541SSergio Afonso // The wrapper's entry block arguments will define the reduction 4828843d541SSergio Afonso // variables. 4838843d541SSergio Afonso llvm::SmallVector<mlir::Type> reductionTypes; 4848843d541SSergio Afonso reductionTypes.reserve(reductionVariables.size()); 4858843d541SSergio Afonso llvm::transform(reductionVariables, std::back_inserter(reductionTypes), 4868843d541SSergio Afonso [](mlir::Value v) { return v.getType(); }); 4878843d541SSergio Afonso rewriter.createBlock( 4888843d541SSergio Afonso &wsloopOp.getRegion(), {}, reductionTypes, 4898843d541SSergio Afonso llvm::SmallVector<mlir::Location>(reductionVariables.size(), 4908843d541SSergio Afonso parallelOp.getLoc())); 4918843d541SSergio Afonso 4928843d541SSergio Afonso // Create loop nest and populate region with contents of scf.parallel. 4938843d541SSergio Afonso auto loopOp = rewriter.create<omp::LoopNestOp>( 494c0342a2dSJacques Pienaar parallelOp.getLoc(), parallelOp.getLowerBound(), 495c0342a2dSJacques Pienaar parallelOp.getUpperBound(), parallelOp.getStep()); 4961ce752b7SAlex Zinenko 4978843d541SSergio Afonso rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(), 4988843d541SSergio Afonso loopOp.getRegion().begin()); 499bf6477ebSWilliam S. Moses 5008843d541SSergio Afonso // Remove reduction-related block arguments from omp.loop_nest and 5018843d541SSergio Afonso // redirect uses to the corresponding omp.wsloop block argument. 5028843d541SSergio Afonso mlir::Block &loopOpEntryBlock = loopOp.getRegion().front(); 5038843d541SSergio Afonso unsigned numLoops = parallelOp.getNumLoops(); 5048843d541SSergio Afonso rewriter.replaceAllUsesWith( 5058843d541SSergio Afonso loopOpEntryBlock.getArguments().drop_front(numLoops), 5068843d541SSergio Afonso wsloopOp.getRegion().getArguments()); 5078843d541SSergio Afonso loopOpEntryBlock.eraseArguments( 5088843d541SSergio Afonso numLoops, loopOpEntryBlock.getNumArguments() - numLoops); 509bf6477ebSWilliam S. Moses 5108843d541SSergio Afonso Block *ops = 5118843d541SSergio Afonso rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin()); 5128843d541SSergio Afonso rewriter.setInsertionPointToStart(&loopOpEntryBlock); 513bf6477ebSWilliam S. Moses 514bf6477ebSWilliam S. Moses auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(), 515bf6477ebSWilliam S. Moses TypeRange()); 516bf6477ebSWilliam S. Moses rewriter.create<omp::YieldOp>(loc, ValueRange()); 517bf6477ebSWilliam S. Moses Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion()); 518bf6477ebSWilliam S. Moses rewriter.mergeBlocks(ops, scopeBlock); 519bf6477ebSWilliam S. Moses rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin()); 52010056c82SMatthias Springer rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange()); 5211ce752b7SAlex Zinenko } 52278fb4f9dSWilliam S. Moses } 523973cb2c3SWilliam S. Moses 5241ce752b7SAlex Zinenko // Load loop results. 5251ce752b7SAlex Zinenko SmallVector<Value> results; 5261ce752b7SAlex Zinenko results.reserve(reductionVariables.size()); 5277c4e45ecSMarkus Böck for (auto [variable, type] : 5287c4e45ecSMarkus Böck llvm::zip(reductionVariables, parallelOp.getResultTypes())) { 5297c4e45ecSMarkus Böck Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable); 5301ce752b7SAlex Zinenko results.push_back(res); 5311ce752b7SAlex Zinenko } 5321ce752b7SAlex Zinenko rewriter.replaceOp(parallelOp, results); 5331ce752b7SAlex Zinenko 534119545f4SAlex Zinenko return success(); 535119545f4SAlex Zinenko } 536119545f4SAlex Zinenko }; 537119545f4SAlex Zinenko 538119545f4SAlex Zinenko /// Applies the conversion patterns in the given function. 5397f4f75c1SPablo Antonio Martinez static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) { 5401ce752b7SAlex Zinenko ConversionTarget target(*module.getContext()); 5411ce752b7SAlex Zinenko target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(); 54278fb4f9dSWilliam S. Moses target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect, 54378fb4f9dSWilliam S. Moses memref::MemRefDialect>(); 544119545f4SAlex Zinenko 5451ce752b7SAlex Zinenko RewritePatternSet patterns(module.getContext()); 5467f4f75c1SPablo Antonio Martinez patterns.add<ParallelOpLowering>(module.getContext(), numThreads); 54779d7f618SChris Lattner FrozenRewritePatternSet frozen(std::move(patterns)); 5481ce752b7SAlex Zinenko return applyPartialConversion(module, target, frozen); 549119545f4SAlex Zinenko } 550119545f4SAlex Zinenko 551119545f4SAlex Zinenko /// A pass converting SCF operations to OpenMP operations. 5527c4e45ecSMarkus Böck struct SCFToOpenMPPass 5537c4e45ecSMarkus Böck : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> { 5547c4e45ecSMarkus Böck 5557c4e45ecSMarkus Böck using Base::Base; 5567c4e45ecSMarkus Böck 557119545f4SAlex Zinenko /// Pass entry point. 5581ce752b7SAlex Zinenko void runOnOperation() override { 5597f4f75c1SPablo Antonio Martinez if (failed(applyPatterns(getOperation(), numThreads))) 560119545f4SAlex Zinenko signalPassFailure(); 561119545f4SAlex Zinenko } 562119545f4SAlex Zinenko }; 563119545f4SAlex Zinenko 564be0a7e9fSMehdi Amini } // namespace 565