xref: /llvm-project/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (revision afcbcae668f1d8061974247f2828190173aef742)
1119545f4SAlex Zinenko //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2119545f4SAlex Zinenko //
3119545f4SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4119545f4SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5119545f4SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6119545f4SAlex Zinenko //
7119545f4SAlex Zinenko //===----------------------------------------------------------------------===//
8119545f4SAlex Zinenko //
9119545f4SAlex Zinenko // This file implements a pass to convert scf.parallel operations into OpenMP
10119545f4SAlex Zinenko // parallel loops.
11119545f4SAlex Zinenko //
12119545f4SAlex Zinenko //===----------------------------------------------------------------------===//
13119545f4SAlex Zinenko 
14119545f4SAlex Zinenko #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
1567d0d7acSMichele Scuttari 
16755dc07dSRiver Riddle #include "mlir/Analysis/SliceAnalysis.h"
17755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
191ce752b7SAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2078fb4f9dSWilliam S. Moses #include "mlir/Dialect/MemRef/IR/MemRef.h"
21119545f4SAlex Zinenko #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
228b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
231ce752b7SAlex Zinenko #include "mlir/IR/ImplicitLocOpBuilder.h"
241ce752b7SAlex Zinenko #include "mlir/IR/SymbolTable.h"
2567d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
26119545f4SAlex Zinenko #include "mlir/Transforms/DialectConversion.h"
27119545f4SAlex Zinenko 
2867d0d7acSMichele Scuttari namespace mlir {
297c4e45ecSMarkus Böck #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
3067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
3167d0d7acSMichele Scuttari } // namespace mlir
3267d0d7acSMichele Scuttari 
33119545f4SAlex Zinenko using namespace mlir;
34119545f4SAlex Zinenko 
351ce752b7SAlex Zinenko /// Matches a block containing a "simple" reduction. The expected shape of the
361ce752b7SAlex Zinenko /// block is as follows.
371ce752b7SAlex Zinenko ///
381ce752b7SAlex Zinenko ///   ^bb(%arg0, %arg1):
391ce752b7SAlex Zinenko ///     %0 = OpTy(%arg0, %arg1)
401ce752b7SAlex Zinenko ///     scf.reduce.return %0
411ce752b7SAlex Zinenko template <typename... OpTy>
421ce752b7SAlex Zinenko static bool matchSimpleReduction(Block &block) {
431ce752b7SAlex Zinenko   if (block.empty() || llvm::hasSingleElement(block) ||
441ce752b7SAlex Zinenko       std::next(block.begin(), 2) != block.end())
451ce752b7SAlex Zinenko     return false;
462a876a71SDiego Caballero 
472a876a71SDiego Caballero   if (block.getNumArguments() != 2)
482a876a71SDiego Caballero     return false;
492a876a71SDiego Caballero 
502a876a71SDiego Caballero   SmallVector<Operation *, 4> combinerOps;
512a876a71SDiego Caballero   Value reducedVal = matchReduction({block.getArguments()[1]},
522a876a71SDiego Caballero                                     /*redPos=*/0, combinerOps);
532a876a71SDiego Caballero 
545550c821STres Popp   if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
552a876a71SDiego Caballero     return false;
562a876a71SDiego Caballero 
572a876a71SDiego Caballero   return isa<OpTy...>(combinerOps[0]) &&
581ce752b7SAlex Zinenko          isa<scf::ReduceReturnOp>(block.back()) &&
592a876a71SDiego Caballero          block.front().getOperands() == block.getArguments();
601ce752b7SAlex Zinenko }
611ce752b7SAlex Zinenko 
621ce752b7SAlex Zinenko /// Matches a block containing a select-based min/max reduction. The types of
631ce752b7SAlex Zinenko /// select and compare operations are provided as template arguments. The
641ce752b7SAlex Zinenko /// comparison predicates suitable for min and max are provided as function
651ce752b7SAlex Zinenko /// arguments. If a reduction is matched, `ifMin` will be set if the reduction
661ce752b7SAlex Zinenko /// compute the minimum and unset if it computes the maximum, otherwise it
671ce752b7SAlex Zinenko /// remains unmodified. The expected shape of the block is as follows.
681ce752b7SAlex Zinenko ///
691ce752b7SAlex Zinenko ///   ^bb(%arg0, %arg1):
701ce752b7SAlex Zinenko ///     %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
711ce752b7SAlex Zinenko ///     %1 = SelectOpTy(%0, %arg0, %arg1)  // %arg0, %arg1 may be swapped here.
721ce752b7SAlex Zinenko ///     scf.reduce.return %1
731ce752b7SAlex Zinenko template <
741ce752b7SAlex Zinenko     typename CompareOpTy, typename SelectOpTy,
7562fea88bSJacques Pienaar     typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
761ce752b7SAlex Zinenko static bool
771ce752b7SAlex Zinenko matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
781ce752b7SAlex Zinenko                      ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
79dec8af70SRiver Riddle   static_assert(
80dec8af70SRiver Riddle       llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
81dec8af70SRiver Riddle       "only arithmetic and llvm select ops are supported");
821ce752b7SAlex Zinenko 
831ce752b7SAlex Zinenko   // Expect exactly three operations in the block.
841ce752b7SAlex Zinenko   if (block.empty() || llvm::hasSingleElement(block) ||
851ce752b7SAlex Zinenko       std::next(block.begin(), 2) == block.end() ||
861ce752b7SAlex Zinenko       std::next(block.begin(), 3) != block.end())
871ce752b7SAlex Zinenko     return false;
881ce752b7SAlex Zinenko 
891ce752b7SAlex Zinenko   // Check op kinds.
901ce752b7SAlex Zinenko   auto compare = dyn_cast<CompareOpTy>(block.front());
911ce752b7SAlex Zinenko   auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
921ce752b7SAlex Zinenko   auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
931ce752b7SAlex Zinenko   if (!compare || !select || !terminator)
941ce752b7SAlex Zinenko     return false;
951ce752b7SAlex Zinenko 
961ce752b7SAlex Zinenko   // Block arguments must be compared.
971ce752b7SAlex Zinenko   if (compare->getOperands() != block.getArguments())
981ce752b7SAlex Zinenko     return false;
991ce752b7SAlex Zinenko 
1001ce752b7SAlex Zinenko   // Detect whether the comparison is less-than or greater-than, otherwise bail.
1011ce752b7SAlex Zinenko   bool isLess;
102ad44495aSjacquesguan   if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
1031ce752b7SAlex Zinenko     isLess = true;
104ad44495aSjacquesguan   } else if (llvm::is_contained(greaterThanPredicates,
105ad44495aSjacquesguan                                 compare.getPredicate())) {
1061ce752b7SAlex Zinenko     isLess = false;
1071ce752b7SAlex Zinenko   } else {
1081ce752b7SAlex Zinenko     return false;
1091ce752b7SAlex Zinenko   }
1101ce752b7SAlex Zinenko 
111cfb72fd3SJacques Pienaar   if (select.getCondition() != compare.getResult())
1121ce752b7SAlex Zinenko     return false;
1131ce752b7SAlex Zinenko 
1141ce752b7SAlex Zinenko   // Detect if the operands are swapped between cmpf and select. Match the
1151ce752b7SAlex Zinenko   // comparison type with the requested type or with the opposite of the
1161ce752b7SAlex Zinenko   // requested type if the operands are swapped. Use generic accessors because
1171ce752b7SAlex Zinenko   // std and LLVM versions of select have different operand names but identical
1181ce752b7SAlex Zinenko   // positions.
1191ce752b7SAlex Zinenko   constexpr unsigned kTrueValue = 1;
1201ce752b7SAlex Zinenko   constexpr unsigned kFalseValue = 2;
121cfb72fd3SJacques Pienaar   bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
122cfb72fd3SJacques Pienaar                       select.getOperand(kFalseValue) == compare.getRhs();
123cfb72fd3SJacques Pienaar   bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
124cfb72fd3SJacques Pienaar                          select.getOperand(kFalseValue) == compare.getLhs();
1251ce752b7SAlex Zinenko   if (!sameOperands && !swappedOperands)
1261ce752b7SAlex Zinenko     return false;
1271ce752b7SAlex Zinenko 
128c0342a2dSJacques Pienaar   if (select.getResult() != terminator.getResult())
1291ce752b7SAlex Zinenko     return false;
1301ce752b7SAlex Zinenko 
1311ce752b7SAlex Zinenko   // The reduction is a min if it uses less-than predicates with same operands
1321ce752b7SAlex Zinenko   // or greather-than predicates with swapped operands. Similarly for max.
1331ce752b7SAlex Zinenko   isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
1341ce752b7SAlex Zinenko   return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
1351ce752b7SAlex Zinenko }
1361ce752b7SAlex Zinenko 
1371ce752b7SAlex Zinenko /// Returns the float semantics for the given float type.
1381ce752b7SAlex Zinenko static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
1391ce752b7SAlex Zinenko   if (type.isF16())
1401ce752b7SAlex Zinenko     return llvm::APFloat::IEEEhalf();
1411ce752b7SAlex Zinenko   if (type.isF32())
1421ce752b7SAlex Zinenko     return llvm::APFloat::IEEEsingle();
1431ce752b7SAlex Zinenko   if (type.isF64())
1441ce752b7SAlex Zinenko     return llvm::APFloat::IEEEdouble();
1451ce752b7SAlex Zinenko   if (type.isF128())
1461ce752b7SAlex Zinenko     return llvm::APFloat::IEEEquad();
1471ce752b7SAlex Zinenko   if (type.isBF16())
1481ce752b7SAlex Zinenko     return llvm::APFloat::BFloat();
1491ce752b7SAlex Zinenko   if (type.isF80())
1501ce752b7SAlex Zinenko     return llvm::APFloat::x87DoubleExtended();
1511ce752b7SAlex Zinenko   llvm_unreachable("unknown float type");
1521ce752b7SAlex Zinenko }
1531ce752b7SAlex Zinenko 
1541ce752b7SAlex Zinenko /// Returns an attribute with the minimum (if `min` is set) or the maximum value
1551ce752b7SAlex Zinenko /// (otherwise) for the given float type.
1561ce752b7SAlex Zinenko static Attribute minMaxValueForFloat(Type type, bool min) {
1575550c821STres Popp   auto fltType = cast<FloatType>(type);
1581ce752b7SAlex Zinenko   return FloatAttr::get(
1591ce752b7SAlex Zinenko       type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
1601ce752b7SAlex Zinenko }
1611ce752b7SAlex Zinenko 
1621ce752b7SAlex Zinenko /// Returns an attribute with the signed integer minimum (if `min` is set) or
1631ce752b7SAlex Zinenko /// the maximum value (otherwise) for the given integer type, regardless of its
1641ce752b7SAlex Zinenko /// signedness semantics (only the width is considered).
1651ce752b7SAlex Zinenko static Attribute minMaxValueForSignedInt(Type type, bool min) {
1665550c821STres Popp   auto intType = cast<IntegerType>(type);
1671ce752b7SAlex Zinenko   unsigned bitwidth = intType.getWidth();
1681ce752b7SAlex Zinenko   return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
1691ce752b7SAlex Zinenko                                     : llvm::APInt::getSignedMaxValue(bitwidth));
1701ce752b7SAlex Zinenko }
1711ce752b7SAlex Zinenko 
1721ce752b7SAlex Zinenko /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
1731ce752b7SAlex Zinenko /// the maximum value (otherwise) for the given integer type, regardless of its
1741ce752b7SAlex Zinenko /// signedness semantics (only the width is considered).
1751ce752b7SAlex Zinenko static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
1765550c821STres Popp   auto intType = cast<IntegerType>(type);
1771ce752b7SAlex Zinenko   unsigned bitwidth = intType.getWidth();
1784a05edd4SKazu Hirata   return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
179b7ffd968SKazu Hirata                                     : llvm::APInt::getAllOnes(bitwidth));
1801ce752b7SAlex Zinenko }
1811ce752b7SAlex Zinenko 
1821ce752b7SAlex Zinenko /// Creates an OpenMP reduction declaration and inserts it into the provided
1831ce752b7SAlex Zinenko /// symbol table. The declaration has a constant initializer with the neutral
18410056c82SMatthias Springer /// value `initValue`, and the `reductionIndex`-th reduction combiner carried
18510056c82SMatthias Springer /// over from `reduce`.
186d84252e0SSergio Afonso static omp::DeclareReductionOp
18710056c82SMatthias Springer createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
18810056c82SMatthias Springer            scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
1891ce752b7SAlex Zinenko   OpBuilder::InsertionGuard guard(builder);
19010056c82SMatthias Springer   Type type = reduce.getOperands()[reductionIndex].getType();
191d84252e0SSergio Afonso   auto decl = builder.create<omp::DeclareReductionOp>(reduce.getLoc(),
19210056c82SMatthias Springer                                                       "__scf_reduction", type);
1931ce752b7SAlex Zinenko   symbolTable.insert(decl);
1941ce752b7SAlex Zinenko 
1954fb4e12bSRiver Riddle   builder.createBlock(&decl.getInitializerRegion(),
1964fb4e12bSRiver Riddle                       decl.getInitializerRegion().end(), {type},
19710056c82SMatthias Springer                       {reduce.getOperands()[reductionIndex].getLoc()});
1984fb4e12bSRiver Riddle   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
1991ce752b7SAlex Zinenko   Value init =
2001ce752b7SAlex Zinenko       builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
2011ce752b7SAlex Zinenko   builder.create<omp::YieldOp>(reduce.getLoc(), init);
2021ce752b7SAlex Zinenko 
20310056c82SMatthias Springer   Operation *terminator =
20410056c82SMatthias Springer       &reduce.getReductions()[reductionIndex].front().back();
2051ce752b7SAlex Zinenko   assert(isa<scf::ReduceReturnOp>(terminator) &&
2061ce752b7SAlex Zinenko          "expected reduce op to be terminated by redure return");
2071ce752b7SAlex Zinenko   builder.setInsertionPoint(terminator);
2081ce752b7SAlex Zinenko   builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
2091ce752b7SAlex Zinenko                                            terminator->getOperands());
21010056c82SMatthias Springer   builder.inlineRegionBefore(reduce.getReductions()[reductionIndex],
21110056c82SMatthias Springer                              decl.getReductionRegion(),
2124fb4e12bSRiver Riddle                              decl.getReductionRegion().end());
2131ce752b7SAlex Zinenko   return decl;
2141ce752b7SAlex Zinenko }
2151ce752b7SAlex Zinenko 
2161ce752b7SAlex Zinenko /// Adds an atomic reduction combiner to the given OpenMP reduction declaration
2171ce752b7SAlex Zinenko /// using llvm.atomicrmw of the given kind.
218d84252e0SSergio Afonso static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
2191ce752b7SAlex Zinenko                                             LLVM::AtomicBinOp atomicKind,
220d84252e0SSergio Afonso                                             omp::DeclareReductionOp decl,
22110056c82SMatthias Springer                                             scf::ReduceOp reduce,
22210056c82SMatthias Springer                                             int64_t reductionIndex) {
2231ce752b7SAlex Zinenko   OpBuilder::InsertionGuard guard(builder);
224e2564b27SChristian Ulmann   auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
22510056c82SMatthias Springer   Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc();
2264fb4e12bSRiver Riddle   builder.createBlock(&decl.getAtomicReductionRegion(),
2274fb4e12bSRiver Riddle                       decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
228e084679fSRiver Riddle                       {reduceOperandLoc, reduceOperandLoc});
2294fb4e12bSRiver Riddle   Block *atomicBlock = &decl.getAtomicReductionRegion().back();
2301ce752b7SAlex Zinenko   builder.setInsertionPointToEnd(atomicBlock);
2317c4e45ecSMarkus Böck   Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
2321ce752b7SAlex Zinenko                                               atomicBlock->getArgument(1));
2337f97895fSTobias Gysi   builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
2341ce752b7SAlex Zinenko                                     atomicBlock->getArgument(0), loaded,
2351ce752b7SAlex Zinenko                                     LLVM::AtomicOrdering::monotonic);
2361ce752b7SAlex Zinenko   builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
2371ce752b7SAlex Zinenko   return decl;
2381ce752b7SAlex Zinenko }
2391ce752b7SAlex Zinenko 
2401ce752b7SAlex Zinenko /// Creates an OpenMP reduction declaration that corresponds to the given SCF
2411ce752b7SAlex Zinenko /// reduction and returns it. Recognizes common reductions in order to identify
2421ce752b7SAlex Zinenko /// the neutral value, necessary for the OpenMP declaration. If the reduction
2431ce752b7SAlex Zinenko /// cannot be recognized, returns null.
244d84252e0SSergio Afonso static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
24510056c82SMatthias Springer                                                 scf::ReduceOp reduce,
24610056c82SMatthias Springer                                                 int64_t reductionIndex) {
2471ce752b7SAlex Zinenko   Operation *container = SymbolTable::getNearestSymbolTable(reduce);
2481ce752b7SAlex Zinenko   SymbolTable symbolTable(container);
2491ce752b7SAlex Zinenko 
2501ce752b7SAlex Zinenko   // Insert reduction declarations in the symbol-table ancestor before the
2511ce752b7SAlex Zinenko   // ancestor of the current insertion point.
2521ce752b7SAlex Zinenko   Operation *insertionPoint = reduce;
2531ce752b7SAlex Zinenko   while (insertionPoint->getParentOp() != container)
2541ce752b7SAlex Zinenko     insertionPoint = insertionPoint->getParentOp();
2551ce752b7SAlex Zinenko   OpBuilder::InsertionGuard guard(builder);
2561ce752b7SAlex Zinenko   builder.setInsertionPoint(insertionPoint);
2571ce752b7SAlex Zinenko 
25810056c82SMatthias Springer   assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
2591ce752b7SAlex Zinenko          "expected reduction region to have a single element");
2601ce752b7SAlex Zinenko 
2611ce752b7SAlex Zinenko   // Match simple binary reductions that can be expressed with atomicrmw.
26210056c82SMatthias Springer   Type type = reduce.getOperands()[reductionIndex].getType();
26310056c82SMatthias Springer   Block &reduction = reduce.getReductions()[reductionIndex].front();
264a54f4eaeSMogball   if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
265d84252e0SSergio Afonso     omp::DeclareReductionOp decl =
26610056c82SMatthias Springer         createDecl(builder, symbolTable, reduce, reductionIndex,
2671ce752b7SAlex Zinenko                    builder.getFloatAttr(type, 0.0));
26810056c82SMatthias Springer     return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
26910056c82SMatthias Springer                         reductionIndex);
2701ce752b7SAlex Zinenko   }
271a54f4eaeSMogball   if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
272d84252e0SSergio Afonso     omp::DeclareReductionOp decl =
27310056c82SMatthias Springer         createDecl(builder, symbolTable, reduce, reductionIndex,
2741ce752b7SAlex Zinenko                    builder.getIntegerAttr(type, 0));
27510056c82SMatthias Springer     return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
27610056c82SMatthias Springer                         reductionIndex);
2771ce752b7SAlex Zinenko   }
278a54f4eaeSMogball   if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
279d84252e0SSergio Afonso     omp::DeclareReductionOp decl =
28010056c82SMatthias Springer         createDecl(builder, symbolTable, reduce, reductionIndex,
2811ce752b7SAlex Zinenko                    builder.getIntegerAttr(type, 0));
28210056c82SMatthias Springer     return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
28310056c82SMatthias Springer                         reductionIndex);
2841ce752b7SAlex Zinenko   }
285a54f4eaeSMogball   if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
286d84252e0SSergio Afonso     omp::DeclareReductionOp decl =
28710056c82SMatthias Springer         createDecl(builder, symbolTable, reduce, reductionIndex,
2881ce752b7SAlex Zinenko                    builder.getIntegerAttr(type, 0));
28910056c82SMatthias Springer     return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
29010056c82SMatthias Springer                         reductionIndex);
2911ce752b7SAlex Zinenko   }
292a54f4eaeSMogball   if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
293d84252e0SSergio Afonso     omp::DeclareReductionOp decl = createDecl(
29410056c82SMatthias Springer         builder, symbolTable, reduce, reductionIndex,
2951ce752b7SAlex Zinenko         builder.getIntegerAttr(
296b7ffd968SKazu Hirata             type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
29710056c82SMatthias Springer     return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
29810056c82SMatthias Springer                         reductionIndex);
2991ce752b7SAlex Zinenko   }
3001ce752b7SAlex Zinenko 
3011ce752b7SAlex Zinenko   // Match simple binary reductions that cannot be expressed with atomicrmw.
3021ce752b7SAlex Zinenko   // TODO: add atomic region using cmpxchg (which needs atomic load to be
3031ce752b7SAlex Zinenko   // available as an op).
304a54f4eaeSMogball   if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
30510056c82SMatthias Springer     return createDecl(builder, symbolTable, reduce, reductionIndex,
3061ce752b7SAlex Zinenko                       builder.getFloatAttr(type, 1.0));
3071ce752b7SAlex Zinenko   }
308c1125ae5SKiran Chandramohan   if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
30910056c82SMatthias Springer     return createDecl(builder, symbolTable, reduce, reductionIndex,
310c1125ae5SKiran Chandramohan                       builder.getIntegerAttr(type, 1));
311c1125ae5SKiran Chandramohan   }
3121ce752b7SAlex Zinenko 
3131ce752b7SAlex Zinenko   // Match select-based min/max reductions.
3141ce752b7SAlex Zinenko   bool isMin;
315dec8af70SRiver Riddle   if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
316a54f4eaeSMogball           reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
317a54f4eaeSMogball           {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
3181ce752b7SAlex Zinenko       matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
3191ce752b7SAlex Zinenko           reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
3201ce752b7SAlex Zinenko           {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
32110056c82SMatthias Springer     return createDecl(builder, symbolTable, reduce, reductionIndex,
3221ce752b7SAlex Zinenko                       minMaxValueForFloat(type, !isMin));
3231ce752b7SAlex Zinenko   }
324dec8af70SRiver Riddle   if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
325a54f4eaeSMogball           reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
326a54f4eaeSMogball           {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
3271ce752b7SAlex Zinenko       matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
3281ce752b7SAlex Zinenko           reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
3291ce752b7SAlex Zinenko           {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
330d84252e0SSergio Afonso     omp::DeclareReductionOp decl =
33110056c82SMatthias Springer         createDecl(builder, symbolTable, reduce, reductionIndex,
33210056c82SMatthias Springer                    minMaxValueForSignedInt(type, !isMin));
3331ce752b7SAlex Zinenko     return addAtomicRMW(builder,
3341ce752b7SAlex Zinenko                         isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
33510056c82SMatthias Springer                         decl, reduce, reductionIndex);
3361ce752b7SAlex Zinenko   }
337dec8af70SRiver Riddle   if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
338a54f4eaeSMogball           reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
339a54f4eaeSMogball           {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
3401ce752b7SAlex Zinenko       matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
3411ce752b7SAlex Zinenko           reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
3421ce752b7SAlex Zinenko           {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
343d84252e0SSergio Afonso     omp::DeclareReductionOp decl =
34410056c82SMatthias Springer         createDecl(builder, symbolTable, reduce, reductionIndex,
34510056c82SMatthias Springer                    minMaxValueForUnsignedInt(type, !isMin));
3461ce752b7SAlex Zinenko     return addAtomicRMW(
3471ce752b7SAlex Zinenko         builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
34810056c82SMatthias Springer         decl, reduce, reductionIndex);
3491ce752b7SAlex Zinenko   }
3501ce752b7SAlex Zinenko 
3511ce752b7SAlex Zinenko   return nullptr;
3521ce752b7SAlex Zinenko }
3531ce752b7SAlex Zinenko 
354119545f4SAlex Zinenko namespace {
355119545f4SAlex Zinenko 
356119545f4SAlex Zinenko struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
3577f4f75c1SPablo Antonio Martinez   static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
3587f4f75c1SPablo Antonio Martinez   unsigned numThreads;
3597c4e45ecSMarkus Böck 
3607f4f75c1SPablo Antonio Martinez   ParallelOpLowering(MLIRContext *context,
3617f4f75c1SPablo Antonio Martinez                      unsigned numThreads = kUseOpenMPDefaultNumThreads)
3627f4f75c1SPablo Antonio Martinez       : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
363119545f4SAlex Zinenko 
364119545f4SAlex Zinenko   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
365119545f4SAlex Zinenko                                 PatternRewriter &rewriter) const override {
3661ce752b7SAlex Zinenko     // Declare reductions.
3671ce752b7SAlex Zinenko     // TODO: consider checking it here is already a compatible reduction
3681ce752b7SAlex Zinenko     // declaration and use it instead of redeclaring.
369fdfeea5bSSergio Afonso     SmallVector<Attribute> reductionSyms;
370d84252e0SSergio Afonso     SmallVector<omp::DeclareReductionOp> ompReductionDecls;
37110056c82SMatthias Springer     auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
37210056c82SMatthias Springer     for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
373d84252e0SSergio Afonso       omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
374be9f8ffdSDavid Truby       ompReductionDecls.push_back(decl);
3751ce752b7SAlex Zinenko       if (!decl)
3761ce752b7SAlex Zinenko         return failure();
377fdfeea5bSSergio Afonso       reductionSyms.push_back(
3784fb4e12bSRiver Riddle           SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
3791ce752b7SAlex Zinenko     }
3801ce752b7SAlex Zinenko 
3811ce752b7SAlex Zinenko     // Allocate reduction variables. Make sure the we don't overflow the stack
3821ce752b7SAlex Zinenko     // with local `alloca`s by saving and restoring the stack pointer.
3831ce752b7SAlex Zinenko     Location loc = parallelOp.getLoc();
3841ce752b7SAlex Zinenko     Value one = rewriter.create<LLVM::ConstantOp>(
3851ce752b7SAlex Zinenko         loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
3861ce752b7SAlex Zinenko     SmallVector<Value> reductionVariables;
3871ce752b7SAlex Zinenko     reductionVariables.reserve(parallelOp.getNumReductions());
388e2564b27SChristian Ulmann     auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext());
389c0342a2dSJacques Pienaar     for (Value init : parallelOp.getInitVals()) {
3901ce752b7SAlex Zinenko       assert((LLVM::isCompatibleType(init.getType()) ||
3915550c821STres Popp               isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
3921ce752b7SAlex Zinenko              "cannot create a reduction variable if the type is not an LLVM "
3931ce752b7SAlex Zinenko              "pointer element");
394e2564b27SChristian Ulmann       Value storage =
395e2564b27SChristian Ulmann           rewriter.create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0);
3961ce752b7SAlex Zinenko       rewriter.create<LLVM::StoreOp>(loc, init, storage);
3971ce752b7SAlex Zinenko       reductionVariables.push_back(storage);
3981ce752b7SAlex Zinenko     }
3991ce752b7SAlex Zinenko 
4001ce752b7SAlex Zinenko     // Replace the reduction operations contained in this loop. Must be done
4011ce752b7SAlex Zinenko     // here rather than in a separate pattern to have access to the list of
4021ce752b7SAlex Zinenko     // reduction variables.
403be9f8ffdSDavid Truby     for (auto [x, y, rD] : llvm::zip_equal(
404be9f8ffdSDavid Truby              reductionVariables, reduce.getOperands(), ompReductionDecls)) {
4051ce752b7SAlex Zinenko       OpBuilder::InsertionGuard guard(rewriter);
40610056c82SMatthias Springer       rewriter.setInsertionPoint(reduce);
407be9f8ffdSDavid Truby       Region &redRegion = rD.getReductionRegion();
408be9f8ffdSDavid Truby       // The SCF dialect by definition contains only structured operations
409be9f8ffdSDavid Truby       // and hence the SCF reduction region will contain a single block.
410be9f8ffdSDavid Truby       // The ompReductionDecls region is a copy of the SCF reduction region
411be9f8ffdSDavid Truby       // and hence has the same property.
412be9f8ffdSDavid Truby       assert(redRegion.hasOneBlock() &&
413be9f8ffdSDavid Truby              "expect reduction region to have one block");
414be9f8ffdSDavid Truby       Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc);
415be9f8ffdSDavid Truby       Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(),
416be9f8ffdSDavid Truby                                                       rD.getType(), pvtRedVar);
417be9f8ffdSDavid Truby       // Make a copy of the reduction combiner region in the body
418be9f8ffdSDavid Truby       mlir::OpBuilder builder(rewriter.getContext());
419be9f8ffdSDavid Truby       builder.setInsertionPoint(reduce);
420be9f8ffdSDavid Truby       mlir::IRMapping mapper;
421be9f8ffdSDavid Truby       assert(redRegion.getNumArguments() == 2 &&
422be9f8ffdSDavid Truby              "expect reduction region to have two arguments");
423be9f8ffdSDavid Truby       mapper.map(redRegion.getArgument(0), pvtRedVal);
424be9f8ffdSDavid Truby       mapper.map(redRegion.getArgument(1), y);
425be9f8ffdSDavid Truby       for (auto &op : redRegion.getOps()) {
426be9f8ffdSDavid Truby         Operation *cloneOp = builder.clone(op, mapper);
427be9f8ffdSDavid Truby         if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
428be9f8ffdSDavid Truby           assert(yieldOp && yieldOp.getResults().size() == 1 &&
429be9f8ffdSDavid Truby                  "expect YieldOp in reduction region to return one result");
430be9f8ffdSDavid Truby           Value redVal = yieldOp.getResults()[0];
431be9f8ffdSDavid Truby           rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar);
432be9f8ffdSDavid Truby           rewriter.eraseOp(yieldOp);
433be9f8ffdSDavid Truby           break;
434be9f8ffdSDavid Truby         }
435be9f8ffdSDavid Truby       }
4361ce752b7SAlex Zinenko     }
43710056c82SMatthias Springer     rewriter.eraseOp(reduce);
4381ce752b7SAlex Zinenko 
4397f4f75c1SPablo Antonio Martinez     Value numThreadsVar;
4407f4f75c1SPablo Antonio Martinez     if (numThreads > 0) {
4417f4f75c1SPablo Antonio Martinez       numThreadsVar = rewriter.create<LLVM::ConstantOp>(
4427f4f75c1SPablo Antonio Martinez           loc, rewriter.getI32IntegerAttr(numThreads));
4437f4f75c1SPablo Antonio Martinez     }
4441ce752b7SAlex Zinenko     // Create the parallel wrapper.
4457f4f75c1SPablo Antonio Martinez     auto ompParallel = rewriter.create<omp::ParallelOp>(
4467f4f75c1SPablo Antonio Martinez         loc,
4477f4f75c1SPablo Antonio Martinez         /* allocate_vars = */ llvm::SmallVector<Value>{},
448fdfeea5bSSergio Afonso         /* allocator_vars = */ llvm::SmallVector<Value>{},
449b3b46963SSergio Afonso         /* if_expr = */ Value{},
450b3b46963SSergio Afonso         /* num_threads = */ numThreadsVar,
451b3b46963SSergio Afonso         /* private_vars = */ ValueRange(),
452b3b46963SSergio Afonso         /* private_syms = */ nullptr,
453b3b46963SSergio Afonso         /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
454*afcbcae6SAnchu Rajendran S         /* reduction_mod = */ nullptr,
4557f4f75c1SPablo Antonio Martinez         /* reduction_vars = */ llvm::SmallVector<Value>{},
456fdfeea5bSSergio Afonso         /* reduction_byref = */ DenseBoolArrayAttr{},
457b3b46963SSergio Afonso         /* reduction_syms = */ ArrayAttr{});
4581ce752b7SAlex Zinenko     {
45978fb4f9dSWilliam S. Moses 
4601ce752b7SAlex Zinenko       OpBuilder::InsertionGuard guard(rewriter);
4614fb4e12bSRiver Riddle       rewriter.createBlock(&ompParallel.getRegion());
4621ce752b7SAlex Zinenko 
463119545f4SAlex Zinenko       // Replace the loop.
464bf6477ebSWilliam S. Moses       {
465bf6477ebSWilliam S. Moses         OpBuilder::InsertionGuard allocaGuard(rewriter);
4668843d541SSergio Afonso         // Create worksharing loop wrapper.
4678843d541SSergio Afonso         auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
4688843d541SSergio Afonso         if (!reductionVariables.empty()) {
469fdfeea5bSSergio Afonso           wsloopOp.setReductionSymsAttr(
470fdfeea5bSSergio Afonso               ArrayAttr::get(rewriter.getContext(), reductionSyms));
4718843d541SSergio Afonso           wsloopOp.getReductionVarsMutable().append(reductionVariables);
472fdfeea5bSSergio Afonso           llvm::SmallVector<bool> reductionByRef;
47374a87548STom Eccles           // false because these reductions always reduce scalars and so do
47474a87548STom Eccles           // not need to pass by reference
475fdfeea5bSSergio Afonso           reductionByRef.resize(reductionVariables.size(), false);
476fdfeea5bSSergio Afonso           wsloopOp.setReductionByref(
477fdfeea5bSSergio Afonso               DenseBoolArrayAttr::get(rewriter.getContext(), reductionByRef));
4788843d541SSergio Afonso         }
4798843d541SSergio Afonso         rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
4808843d541SSergio Afonso 
4818843d541SSergio Afonso         // The wrapper's entry block arguments will define the reduction
4828843d541SSergio Afonso         // variables.
4838843d541SSergio Afonso         llvm::SmallVector<mlir::Type> reductionTypes;
4848843d541SSergio Afonso         reductionTypes.reserve(reductionVariables.size());
4858843d541SSergio Afonso         llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
4868843d541SSergio Afonso                         [](mlir::Value v) { return v.getType(); });
4878843d541SSergio Afonso         rewriter.createBlock(
4888843d541SSergio Afonso             &wsloopOp.getRegion(), {}, reductionTypes,
4898843d541SSergio Afonso             llvm::SmallVector<mlir::Location>(reductionVariables.size(),
4908843d541SSergio Afonso                                               parallelOp.getLoc()));
4918843d541SSergio Afonso 
4928843d541SSergio Afonso         // Create loop nest and populate region with contents of scf.parallel.
4938843d541SSergio Afonso         auto loopOp = rewriter.create<omp::LoopNestOp>(
494c0342a2dSJacques Pienaar             parallelOp.getLoc(), parallelOp.getLowerBound(),
495c0342a2dSJacques Pienaar             parallelOp.getUpperBound(), parallelOp.getStep());
4961ce752b7SAlex Zinenko 
4978843d541SSergio Afonso         rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
4988843d541SSergio Afonso                                     loopOp.getRegion().begin());
499bf6477ebSWilliam S. Moses 
5008843d541SSergio Afonso         // Remove reduction-related block arguments from omp.loop_nest and
5018843d541SSergio Afonso         // redirect uses to the corresponding omp.wsloop block argument.
5028843d541SSergio Afonso         mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
5038843d541SSergio Afonso         unsigned numLoops = parallelOp.getNumLoops();
5048843d541SSergio Afonso         rewriter.replaceAllUsesWith(
5058843d541SSergio Afonso             loopOpEntryBlock.getArguments().drop_front(numLoops),
5068843d541SSergio Afonso             wsloopOp.getRegion().getArguments());
5078843d541SSergio Afonso         loopOpEntryBlock.eraseArguments(
5088843d541SSergio Afonso             numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
509bf6477ebSWilliam S. Moses 
5108843d541SSergio Afonso         Block *ops =
5118843d541SSergio Afonso             rewriter.splitBlock(&loopOpEntryBlock, loopOpEntryBlock.begin());
5128843d541SSergio Afonso         rewriter.setInsertionPointToStart(&loopOpEntryBlock);
513bf6477ebSWilliam S. Moses 
514bf6477ebSWilliam S. Moses         auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
515bf6477ebSWilliam S. Moses                                                             TypeRange());
516bf6477ebSWilliam S. Moses         rewriter.create<omp::YieldOp>(loc, ValueRange());
517bf6477ebSWilliam S. Moses         Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
518bf6477ebSWilliam S. Moses         rewriter.mergeBlocks(ops, scopeBlock);
519bf6477ebSWilliam S. Moses         rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
52010056c82SMatthias Springer         rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
5211ce752b7SAlex Zinenko       }
52278fb4f9dSWilliam S. Moses     }
523973cb2c3SWilliam S. Moses 
5241ce752b7SAlex Zinenko     // Load loop results.
5251ce752b7SAlex Zinenko     SmallVector<Value> results;
5261ce752b7SAlex Zinenko     results.reserve(reductionVariables.size());
5277c4e45ecSMarkus Böck     for (auto [variable, type] :
5287c4e45ecSMarkus Böck          llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
5297c4e45ecSMarkus Böck       Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
5301ce752b7SAlex Zinenko       results.push_back(res);
5311ce752b7SAlex Zinenko     }
5321ce752b7SAlex Zinenko     rewriter.replaceOp(parallelOp, results);
5331ce752b7SAlex Zinenko 
534119545f4SAlex Zinenko     return success();
535119545f4SAlex Zinenko   }
536119545f4SAlex Zinenko };
537119545f4SAlex Zinenko 
538119545f4SAlex Zinenko /// Applies the conversion patterns in the given function.
5397f4f75c1SPablo Antonio Martinez static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
5401ce752b7SAlex Zinenko   ConversionTarget target(*module.getContext());
5411ce752b7SAlex Zinenko   target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
54278fb4f9dSWilliam S. Moses   target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
54378fb4f9dSWilliam S. Moses                          memref::MemRefDialect>();
544119545f4SAlex Zinenko 
5451ce752b7SAlex Zinenko   RewritePatternSet patterns(module.getContext());
5467f4f75c1SPablo Antonio Martinez   patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
54779d7f618SChris Lattner   FrozenRewritePatternSet frozen(std::move(patterns));
5481ce752b7SAlex Zinenko   return applyPartialConversion(module, target, frozen);
549119545f4SAlex Zinenko }
550119545f4SAlex Zinenko 
551119545f4SAlex Zinenko /// A pass converting SCF operations to OpenMP operations.
5527c4e45ecSMarkus Böck struct SCFToOpenMPPass
5537c4e45ecSMarkus Böck     : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
5547c4e45ecSMarkus Böck 
5557c4e45ecSMarkus Böck   using Base::Base;
5567c4e45ecSMarkus Böck 
557119545f4SAlex Zinenko   /// Pass entry point.
5581ce752b7SAlex Zinenko   void runOnOperation() override {
5597f4f75c1SPablo Antonio Martinez     if (failed(applyPatterns(getOperation(), numThreads)))
560119545f4SAlex Zinenko       signalPassFailure();
561119545f4SAlex Zinenko   }
562119545f4SAlex Zinenko };
563119545f4SAlex Zinenko 
564be0a7e9fSMehdi Amini } // namespace
565