xref: /llvm-project/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (revision b7ffd9686ddf6bc4d8006f76acd55ac0a43a8ec7)
1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.parallel operations into OpenMP
10 // parallel loops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
15 
16 #include "mlir/Analysis/SliceAnalysis.h"
17 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "mlir/IR/SymbolTable.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 /// Matches a block containing a "simple" reduction. The expected shape of the
36 /// block is as follows.
37 ///
38 ///   ^bb(%arg0, %arg1):
39 ///     %0 = OpTy(%arg0, %arg1)
40 ///     scf.reduce.return %0
41 template <typename... OpTy>
42 static bool matchSimpleReduction(Block &block) {
43   if (block.empty() || llvm::hasSingleElement(block) ||
44       std::next(block.begin(), 2) != block.end())
45     return false;
46 
47   if (block.getNumArguments() != 2)
48     return false;
49 
50   SmallVector<Operation *, 4> combinerOps;
51   Value reducedVal = matchReduction({block.getArguments()[1]},
52                                     /*redPos=*/0, combinerOps);
53 
54   if (!reducedVal || !reducedVal.isa<BlockArgument>() ||
55       combinerOps.size() != 1)
56     return false;
57 
58   return isa<OpTy...>(combinerOps[0]) &&
59          isa<scf::ReduceReturnOp>(block.back()) &&
60          block.front().getOperands() == block.getArguments();
61 }
62 
63 /// Matches a block containing a select-based min/max reduction. The types of
64 /// select and compare operations are provided as template arguments. The
65 /// comparison predicates suitable for min and max are provided as function
66 /// arguments. If a reduction is matched, `ifMin` will be set if the reduction
67 /// compute the minimum and unset if it computes the maximum, otherwise it
68 /// remains unmodified. The expected shape of the block is as follows.
69 ///
70 ///   ^bb(%arg0, %arg1):
71 ///     %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
72 ///     %1 = SelectOpTy(%0, %arg0, %arg1)  // %arg0, %arg1 may be swapped here.
73 ///     scf.reduce.return %1
74 template <
75     typename CompareOpTy, typename SelectOpTy,
76     typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
77 static bool
78 matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
79                      ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
80   static_assert(
81       llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
82       "only arithmetic and llvm select ops are supported");
83 
84   // Expect exactly three operations in the block.
85   if (block.empty() || llvm::hasSingleElement(block) ||
86       std::next(block.begin(), 2) == block.end() ||
87       std::next(block.begin(), 3) != block.end())
88     return false;
89 
90   // Check op kinds.
91   auto compare = dyn_cast<CompareOpTy>(block.front());
92   auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
93   auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back());
94   if (!compare || !select || !terminator)
95     return false;
96 
97   // Block arguments must be compared.
98   if (compare->getOperands() != block.getArguments())
99     return false;
100 
101   // Detect whether the comparison is less-than or greater-than, otherwise bail.
102   bool isLess;
103   if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
104     isLess = true;
105   } else if (llvm::is_contained(greaterThanPredicates,
106                                 compare.getPredicate())) {
107     isLess = false;
108   } else {
109     return false;
110   }
111 
112   if (select.getCondition() != compare.getResult())
113     return false;
114 
115   // Detect if the operands are swapped between cmpf and select. Match the
116   // comparison type with the requested type or with the opposite of the
117   // requested type if the operands are swapped. Use generic accessors because
118   // std and LLVM versions of select have different operand names but identical
119   // positions.
120   constexpr unsigned kTrueValue = 1;
121   constexpr unsigned kFalseValue = 2;
122   bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
123                       select.getOperand(kFalseValue) == compare.getRhs();
124   bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
125                          select.getOperand(kFalseValue) == compare.getLhs();
126   if (!sameOperands && !swappedOperands)
127     return false;
128 
129   if (select.getResult() != terminator.getResult())
130     return false;
131 
132   // The reduction is a min if it uses less-than predicates with same operands
133   // or greather-than predicates with swapped operands. Similarly for max.
134   isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
135   return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
136 }
137 
138 /// Returns the float semantics for the given float type.
139 static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
140   if (type.isF16())
141     return llvm::APFloat::IEEEhalf();
142   if (type.isF32())
143     return llvm::APFloat::IEEEsingle();
144   if (type.isF64())
145     return llvm::APFloat::IEEEdouble();
146   if (type.isF128())
147     return llvm::APFloat::IEEEquad();
148   if (type.isBF16())
149     return llvm::APFloat::BFloat();
150   if (type.isF80())
151     return llvm::APFloat::x87DoubleExtended();
152   llvm_unreachable("unknown float type");
153 }
154 
155 /// Returns an attribute with the minimum (if `min` is set) or the maximum value
156 /// (otherwise) for the given float type.
157 static Attribute minMaxValueForFloat(Type type, bool min) {
158   auto fltType = type.cast<FloatType>();
159   return FloatAttr::get(
160       type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
161 }
162 
163 /// Returns an attribute with the signed integer minimum (if `min` is set) or
164 /// the maximum value (otherwise) for the given integer type, regardless of its
165 /// signedness semantics (only the width is considered).
166 static Attribute minMaxValueForSignedInt(Type type, bool min) {
167   auto intType = type.cast<IntegerType>();
168   unsigned bitwidth = intType.getWidth();
169   return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
170                                     : llvm::APInt::getSignedMaxValue(bitwidth));
171 }
172 
173 /// Returns an attribute with the unsigned integer minimum (if `min` is set) or
174 /// the maximum value (otherwise) for the given integer type, regardless of its
175 /// signedness semantics (only the width is considered).
176 static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
177   auto intType = type.cast<IntegerType>();
178   unsigned bitwidth = intType.getWidth();
179   return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
180                                     : llvm::APInt::getAllOnes(bitwidth));
181 }
182 
183 /// Creates an OpenMP reduction declaration and inserts it into the provided
184 /// symbol table. The declaration has a constant initializer with the neutral
185 /// value `initValue`, and the reduction combiner carried over from `reduce`.
186 static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
187                                           SymbolTable &symbolTable,
188                                           scf::ReduceOp reduce,
189                                           Attribute initValue) {
190   OpBuilder::InsertionGuard guard(builder);
191   auto decl = builder.create<omp::ReductionDeclareOp>(
192       reduce.getLoc(), "__scf_reduction", reduce.getOperand().getType());
193   symbolTable.insert(decl);
194 
195   Type type = reduce.getOperand().getType();
196   builder.createBlock(&decl.getInitializerRegion(),
197                       decl.getInitializerRegion().end(), {type},
198                       {reduce.getOperand().getLoc()});
199   builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
200   Value init =
201       builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue);
202   builder.create<omp::YieldOp>(reduce.getLoc(), init);
203 
204   Operation *terminator = &reduce.getRegion().front().back();
205   assert(isa<scf::ReduceReturnOp>(terminator) &&
206          "expected reduce op to be terminated by redure return");
207   builder.setInsertionPoint(terminator);
208   builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
209                                            terminator->getOperands());
210   builder.inlineRegionBefore(reduce.getRegion(), decl.getReductionRegion(),
211                              decl.getReductionRegion().end());
212   return decl;
213 }
214 
215 /// Returns an LLVM pointer type with the given element type, or an opaque
216 /// pointer if 'useOpaquePointers' is true.
217 static LLVM::LLVMPointerType getPointerType(Type elementType,
218                                             bool useOpaquePointers) {
219   if (useOpaquePointers)
220     return LLVM::LLVMPointerType::get(elementType.getContext());
221   return LLVM::LLVMPointerType::get(elementType);
222 }
223 
224 /// Adds an atomic reduction combiner to the given OpenMP reduction declaration
225 /// using llvm.atomicrmw of the given kind.
226 static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
227                                             LLVM::AtomicBinOp atomicKind,
228                                             omp::ReductionDeclareOp decl,
229                                             scf::ReduceOp reduce,
230                                             bool useOpaquePointers) {
231   OpBuilder::InsertionGuard guard(builder);
232   Type type = reduce.getOperand().getType();
233   Type ptrType = getPointerType(type, useOpaquePointers);
234   Location reduceOperandLoc = reduce.getOperand().getLoc();
235   builder.createBlock(&decl.getAtomicReductionRegion(),
236                       decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
237                       {reduceOperandLoc, reduceOperandLoc});
238   Block *atomicBlock = &decl.getAtomicReductionRegion().back();
239   builder.setInsertionPointToEnd(atomicBlock);
240   Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
241                                               atomicBlock->getArgument(1));
242   builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
243                                     atomicBlock->getArgument(0), loaded,
244                                     LLVM::AtomicOrdering::monotonic);
245   builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>());
246   return decl;
247 }
248 
249 /// Creates an OpenMP reduction declaration that corresponds to the given SCF
250 /// reduction and returns it. Recognizes common reductions in order to identify
251 /// the neutral value, necessary for the OpenMP declaration. If the reduction
252 /// cannot be recognized, returns null.
253 static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
254                                                 scf::ReduceOp reduce,
255                                                 bool useOpaquePointers) {
256   Operation *container = SymbolTable::getNearestSymbolTable(reduce);
257   SymbolTable symbolTable(container);
258 
259   // Insert reduction declarations in the symbol-table ancestor before the
260   // ancestor of the current insertion point.
261   Operation *insertionPoint = reduce;
262   while (insertionPoint->getParentOp() != container)
263     insertionPoint = insertionPoint->getParentOp();
264   OpBuilder::InsertionGuard guard(builder);
265   builder.setInsertionPoint(insertionPoint);
266 
267   assert(llvm::hasSingleElement(reduce.getRegion()) &&
268          "expected reduction region to have a single element");
269 
270   // Match simple binary reductions that can be expressed with atomicrmw.
271   Type type = reduce.getOperand().getType();
272   Block &reduction = reduce.getRegion().front();
273   if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
274     omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
275                                               builder.getFloatAttr(type, 0.0));
276     return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
277                         useOpaquePointers);
278   }
279   if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
280     omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
281                                               builder.getIntegerAttr(type, 0));
282     return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
283                         useOpaquePointers);
284   }
285   if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
286     omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
287                                               builder.getIntegerAttr(type, 0));
288     return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
289                         useOpaquePointers);
290   }
291   if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
292     omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
293                                               builder.getIntegerAttr(type, 0));
294     return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
295                         useOpaquePointers);
296   }
297   if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
298     omp::ReductionDeclareOp decl = createDecl(
299         builder, symbolTable, reduce,
300         builder.getIntegerAttr(
301             type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
302     return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
303                         useOpaquePointers);
304   }
305 
306   // Match simple binary reductions that cannot be expressed with atomicrmw.
307   // TODO: add atomic region using cmpxchg (which needs atomic load to be
308   // available as an op).
309   if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
310     return createDecl(builder, symbolTable, reduce,
311                       builder.getFloatAttr(type, 1.0));
312   }
313 
314   // Match select-based min/max reductions.
315   bool isMin;
316   if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
317           reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
318           {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
319       matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
320           reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
321           {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
322     return createDecl(builder, symbolTable, reduce,
323                       minMaxValueForFloat(type, !isMin));
324   }
325   if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
326           reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
327           {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
328       matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
329           reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
330           {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
331     omp::ReductionDeclareOp decl = createDecl(
332         builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
333     return addAtomicRMW(builder,
334                         isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
335                         decl, reduce, useOpaquePointers);
336   }
337   if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
338           reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
339           {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
340       matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
341           reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
342           {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
343     omp::ReductionDeclareOp decl = createDecl(
344         builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
345     return addAtomicRMW(
346         builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
347         decl, reduce, useOpaquePointers);
348   }
349 
350   return nullptr;
351 }
352 
353 namespace {
354 
355 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
356 
357   bool useOpaquePointers;
358 
359   ParallelOpLowering(MLIRContext *context, bool useOpaquePointers)
360       : OpRewritePattern<scf::ParallelOp>(context),
361         useOpaquePointers(useOpaquePointers) {}
362 
363   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
364                                 PatternRewriter &rewriter) const override {
365     // Declare reductions.
366     // TODO: consider checking it here is already a compatible reduction
367     // declaration and use it instead of redeclaring.
368     SmallVector<Attribute> reductionDeclSymbols;
369     for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
370       omp::ReductionDeclareOp decl =
371           declareReduction(rewriter, reduce, useOpaquePointers);
372       if (!decl)
373         return failure();
374       reductionDeclSymbols.push_back(
375           SymbolRefAttr::get(rewriter.getContext(), decl.getSymName()));
376     }
377 
378     // Allocate reduction variables. Make sure the we don't overflow the stack
379     // with local `alloca`s by saving and restoring the stack pointer.
380     Location loc = parallelOp.getLoc();
381     Value one = rewriter.create<LLVM::ConstantOp>(
382         loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1));
383     SmallVector<Value> reductionVariables;
384     reductionVariables.reserve(parallelOp.getNumReductions());
385     for (Value init : parallelOp.getInitVals()) {
386       assert((LLVM::isCompatibleType(init.getType()) ||
387               init.getType().isa<LLVM::PointerElementTypeInterface>()) &&
388              "cannot create a reduction variable if the type is not an LLVM "
389              "pointer element");
390       Value storage = rewriter.create<LLVM::AllocaOp>(
391           loc, getPointerType(init.getType(), useOpaquePointers),
392           init.getType(), one, 0);
393       rewriter.create<LLVM::StoreOp>(loc, init, storage);
394       reductionVariables.push_back(storage);
395     }
396 
397     // Replace the reduction operations contained in this loop. Must be done
398     // here rather than in a separate pattern to have access to the list of
399     // reduction variables.
400     for (auto pair :
401          llvm::zip(parallelOp.getOps<scf::ReduceOp>(), reductionVariables)) {
402       OpBuilder::InsertionGuard guard(rewriter);
403       scf::ReduceOp reduceOp = std::get<0>(pair);
404       rewriter.setInsertionPoint(reduceOp);
405       rewriter.replaceOpWithNewOp<omp::ReductionOp>(
406           reduceOp, reduceOp.getOperand(), std::get<1>(pair));
407     }
408 
409     // Create the parallel wrapper.
410     auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
411     {
412 
413       OpBuilder::InsertionGuard guard(rewriter);
414       rewriter.createBlock(&ompParallel.getRegion());
415 
416       // Replace the loop.
417       {
418         OpBuilder::InsertionGuard allocaGuard(rewriter);
419         auto loop = rewriter.create<omp::WsLoopOp>(
420             parallelOp.getLoc(), parallelOp.getLowerBound(),
421             parallelOp.getUpperBound(), parallelOp.getStep());
422         rewriter.create<omp::TerminatorOp>(loc);
423 
424         rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
425                                     loop.getRegion().begin());
426 
427         Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
428                                          loop.getRegion().begin()->begin());
429 
430         rewriter.setInsertionPointToStart(&*loop.getRegion().begin());
431 
432         auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
433                                                             TypeRange());
434         rewriter.create<omp::YieldOp>(loc, ValueRange());
435         Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
436         rewriter.mergeBlocks(ops, scopeBlock);
437         auto oldYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
438         rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
439         rewriter.replaceOpWithNewOp<memref::AllocaScopeReturnOp>(
440             oldYield, oldYield->getOperands());
441         if (!reductionVariables.empty()) {
442           loop.setReductionsAttr(
443               ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
444           loop.getReductionVarsMutable().append(reductionVariables);
445         }
446       }
447     }
448 
449     // Load loop results.
450     SmallVector<Value> results;
451     results.reserve(reductionVariables.size());
452     for (auto [variable, type] :
453          llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
454       Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
455       results.push_back(res);
456     }
457     rewriter.replaceOp(parallelOp, results);
458 
459     return success();
460   }
461 };
462 
463 /// Applies the conversion patterns in the given function.
464 static LogicalResult applyPatterns(ModuleOp module, bool useOpaquePointers) {
465   ConversionTarget target(*module.getContext());
466   target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
467   target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
468                          memref::MemRefDialect>();
469 
470   RewritePatternSet patterns(module.getContext());
471   patterns.add<ParallelOpLowering>(module.getContext(), useOpaquePointers);
472   FrozenRewritePatternSet frozen(std::move(patterns));
473   return applyPartialConversion(module, target, frozen);
474 }
475 
476 /// A pass converting SCF operations to OpenMP operations.
477 struct SCFToOpenMPPass
478     : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
479 
480   using Base::Base;
481 
482   /// Pass entry point.
483   void runOnOperation() override {
484     if (failed(applyPatterns(getOperation(), useOpaquePointers)))
485       signalPassFailure();
486   }
487 };
488 
489 } // namespace
490