1 //===- Utils.h - SCF dialect utilities --------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines prototypes for various SCF utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SCF_UTILS_UTILS_H_ 14 #define MLIR_DIALECT_SCF_UTILS_UTILS_H_ 15 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/IR/PatternMatch.h" 18 #include "mlir/Support/LLVM.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include <optional> 21 22 namespace mlir { 23 class Location; 24 class Operation; 25 class OpBuilder; 26 class Region; 27 class RewriterBase; 28 class ValueRange; 29 class Value; 30 31 namespace func { 32 class CallOp; 33 class FuncOp; 34 } // namespace func 35 36 /// Update a perfectly nested loop nest to yield new values from the innermost 37 /// loop and propagating it up through the loop nest. This function 38 /// - Expects `loopNest` to be a perfectly nested loop with outer most loop 39 /// first and innermost loop last. 40 /// - `newIterOperands` are the initialization values to be used for the 41 /// outermost loop 42 /// - `newYielValueFn` is the callback that generates the new values to be 43 /// yielded from within the innermost loop. 44 /// - The original loops are not erased, but are left in a "no-op" state where 45 /// the body of the loop just yields the basic block arguments that correspond 46 /// to the initialization values of a loop. The original loops are dead after 47 /// this method. 48 /// - If `replaceIterOperandsUsesInLoop` is true, all uses of the 49 /// `newIterOperands` within the generated new loop are replaced with the 50 /// corresponding `BlockArgument` in the loop body. 51 SmallVector<scf::ForOp> replaceLoopNestWithNewYields( 52 RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest, 53 ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, 54 bool replaceIterOperandsUsesInLoop = true); 55 56 /// Outline a region with a single block into a new FuncOp. 57 /// Assumes the FuncOp result types is the type of the yielded operands of the 58 /// single block. This constraint makes it easy to determine the result. 59 /// This method also clones the `arith::ConstantIndexOp` at the start of 60 /// `outlinedFuncBody` to alloc simple canonicalizations. 61 /// Creates a new FuncOp and thus cannot be used in a FuncOp pass. 62 /// The client is responsible for providing a unique `funcName` that will not 63 /// collide with another FuncOp name. If `callOp` is provided, it will be set 64 /// to point to the operation that calls the outlined function. 65 // TODO: support more than single-block regions. 66 // TODO: more flexible constant handling. 67 FailureOr<func::FuncOp> 68 outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region ®ion, 69 StringRef funcName, func::CallOp *callOp = nullptr); 70 71 /// Outline the then and/or else regions of `ifOp` as follows: 72 /// - if `thenFn` is not null, `thenFnName` must be specified and the `then` 73 /// region is inlined into a new FuncOp that is captured by the pointer. 74 /// - if `elseFn` is not null, `elseFnName` must be specified and the `else` 75 /// region is inlined into a new FuncOp that is captured by the pointer. 76 /// Creates new FuncOps and thus cannot be used in a FuncOp pass. 77 /// The client is responsible for providing a unique `thenFnName`/`elseFnName` 78 /// that will not collide with another FuncOp name. 79 LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, 80 StringRef thenFnName, func::FuncOp *elseFn, 81 StringRef elseFnName); 82 83 /// Get a list of innermost parallel loops contained in `rootOp`. Innermost 84 /// parallel loops are those that do not contain further parallel loops 85 /// themselves. 86 bool getInnermostParallelLoops(Operation *rootOp, 87 SmallVectorImpl<scf::ParallelOp> &result); 88 89 /// Return the min/max expressions for `value` if it is an induction variable 90 /// from scf.for or scf.parallel loop. 91 /// if `loopFilter` is passed, the filter determines which loop to consider. 92 /// Other induction variables are ignored. 93 std::optional<std::pair<AffineExpr, AffineExpr>> 94 getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims, 95 SmallVectorImpl<Value> &symbols, 96 llvm::function_ref<bool(Operation *)> loopFilter = nullptr); 97 98 /// Replace a perfect nest of "for" loops with a single linearized loop. Assumes 99 /// `loops` contains a list of perfectly nested loops with bounds and steps 100 /// independent of any loop induction variable involved in the nest. 101 LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops); 102 LogicalResult coalesceLoops(RewriterBase &rewriter, 103 MutableArrayRef<scf::ForOp>); 104 105 /// Walk an affine.for to find a band to coalesce. 106 LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op); 107 108 /// Take the ParallelLoop and for each set of dimension indices, combine them 109 /// into a single dimension. combinedDimensions must contain each index into 110 /// loops exactly once. 111 void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, 112 ArrayRef<std::vector<unsigned>> combinedDimensions); 113 114 struct UnrolledLoopInfo { 115 std::optional<scf::ForOp> mainLoopOp = std::nullopt; 116 std::optional<scf::ForOp> epilogueLoopOp = std::nullopt; 117 }; 118 119 /// Unrolls this for operation by the specified unroll factor. Returns the 120 /// unrolled main loop and the eplilog loop, if the loop is unrolled. Otherwise 121 /// returns failure if the loop cannot be unrolled either due to restrictions or 122 /// due to invalid unroll factors. Requires positive loop bounds and step. If 123 /// specified, annotates the Ops in each unrolled iteration by applying 124 /// `annotateFn`. 125 FailureOr<UnrolledLoopInfo> loopUnrollByFactor( 126 scf::ForOp forOp, uint64_t unrollFactor, 127 function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr); 128 129 /// Unrolls and jams this `scf.for` operation by the specified unroll factor. 130 /// Returns failure if the loop cannot be unrolled either due to restrictions or 131 /// due to invalid unroll factors. In case of unroll factor of 1, the function 132 /// bails out without doing anything (returns success). Currently, only constant 133 /// trip count that are divided by the unroll factor is supported. Currently, 134 /// for operations with results are not supported. 135 LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor); 136 137 /// Materialize bounds and step of a zero-based and unit-step loop derived by 138 /// normalizing the specified bounds and step. 139 Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, 140 OpFoldResult lb, OpFoldResult ub, 141 OpFoldResult step); 142 143 /// Get back the original induction variable values after loop normalization. 144 void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, 145 Value normalizedIv, OpFoldResult origLb, 146 OpFoldResult origStep); 147 148 /// Tile a nest of standard for loops rooted at `rootForOp` by finding such 149 /// parametric tile sizes that the outer loops have a fixed number of iterations 150 /// as defined in `sizes`. 151 using Loops = SmallVector<scf::ForOp, 8>; 152 using TileLoops = std::pair<Loops, Loops>; 153 TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef<int64_t> sizes); 154 155 /// Performs tiling fo imperfectly nested loops (with interchange) by 156 /// strip-mining the `forOps` by `sizes` and sinking them, in their order of 157 /// occurrence in `forOps`, under each of the `targets`. 158 /// Returns the new AffineForOps, one per each of (`forOps`, `targets`) pair, 159 /// nested immediately under each of `targets`. 160 SmallVector<Loops, 8> tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes, 161 ArrayRef<scf::ForOp> targets); 162 163 /// Performs tiling (with interchange) by strip-mining the `forOps` by `sizes` 164 /// and sinking them, in their order of occurrence in `forOps`, under `target`. 165 /// Returns the new AffineForOps, one per `forOps`, nested immediately under 166 /// `target`. 167 Loops tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes, 168 scf::ForOp target); 169 170 /// Tile a nest of scf::ForOp loops rooted at `rootForOp` with the given 171 /// (parametric) sizes. Sizes are expected to be strictly positive values at 172 /// runtime. If more sizes than loops are provided, discard the trailing values 173 /// in sizes. Assumes the loop nest is permutable. 174 /// Returns the newly created intra-tile loops. 175 Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes); 176 177 /// Get perfectly nested sequence of loops starting at root of loop nest 178 /// (the first op being another AffineFor, and the second op - a terminator). 179 /// A loop is perfectly nested iff: the first op in the loop's body is another 180 /// AffineForOp, and the second op is a terminator). 181 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops, 182 scf::ForOp root); 183 184 /// Given two scf.forall loops, `target` and `source`, fuses `target` into 185 /// `source`. Assumes that the given loops are siblings and are independent of 186 /// each other. 187 /// 188 /// This function does not perform any legality checks and simply fuses the 189 /// loops. The caller is responsible for ensuring that the loops are legal to 190 /// fuse. 191 scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, 192 scf::ForallOp source, 193 RewriterBase &rewriter); 194 195 /// Given two scf.for loops, `target` and `source`, fuses `target` into 196 /// `source`. Assumes that the given loops are siblings and are independent of 197 /// each other. 198 /// 199 /// This function does not perform any legality checks and simply fuses the 200 /// loops. The caller is responsible for ensuring that the loops are legal to 201 /// fuse. 202 scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, 203 RewriterBase &rewriter); 204 205 /// Normalize an `scf.forall` operation. Returns `failure()`if normalization 206 /// fails. 207 // On `success()` returns the 208 /// newly created operation with all uses of the original operation replaced 209 /// with results of the new operation. 210 FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter, 211 scf::ForallOp forallOp); 212 213 } // namespace mlir 214 215 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_ 216