1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to test various loop fusion utility functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/Analysis/Utils.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Affine/LoopFusionUtils.h" 16 #include "mlir/Dialect/Affine/LoopUtils.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Pass/Pass.h" 19 20 #define DEBUG_TYPE "test-loop-fusion" 21 22 using namespace mlir; 23 using namespace mlir::affine; 24 25 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 26 27 static llvm::cl::opt<bool> clTestDependenceCheck( 28 "test-loop-fusion-dependence-check", 29 llvm::cl::desc("Enable testing of loop fusion dependence check"), 30 llvm::cl::cat(clOptionsCategory)); 31 32 static llvm::cl::opt<bool> clTestSliceComputation( 33 "test-loop-fusion-slice-computation", 34 llvm::cl::desc("Enable testing of loop fusion slice computation"), 35 llvm::cl::cat(clOptionsCategory)); 36 37 static llvm::cl::opt<bool> clTestLoopFusionTransformation( 38 "test-loop-fusion-transformation", 39 llvm::cl::desc("Enable testing of loop fusion transformation"), 40 llvm::cl::cat(clOptionsCategory)); 41 42 namespace { 43 44 struct TestLoopFusion 45 : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> { 46 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion) 47 48 StringRef getArgument() const final { return "test-loop-fusion"; } 49 StringRef getDescription() const final { 50 return "Tests loop fusion utility functions."; 51 } 52 void runOnOperation() override; 53 }; 54 55 } // namespace 56 57 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths 58 // in range ['loopDepth' + 1, 'maxLoopDepth']. 59 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists. 60 // Returns false as IR is not transformed. 61 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp, 62 unsigned i, unsigned j, unsigned loopDepth, 63 unsigned maxLoopDepth) { 64 affine::ComputationSliceState sliceUnion; 65 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 66 FusionResult result = 67 affine::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion); 68 if (result.value == FusionResult::FailBlockDependence) { 69 srcForOp->emitRemark("block-level dependence preventing" 70 " fusion of loop nest ") 71 << i << " into loop nest " << j << " at depth " << loopDepth; 72 } 73 } 74 return false; 75 } 76 77 // Returns the index of 'op' in its block. 78 static unsigned getBlockIndex(Operation &op) { 79 unsigned index = 0; 80 for (auto &opX : *op.getBlock()) { 81 if (&op == &opX) 82 break; 83 ++index; 84 } 85 return index; 86 } 87 88 // Returns a string representation of 'sliceUnion'. 89 static std::string 90 getSliceStr(const affine::ComputationSliceState &sliceUnion) { 91 std::string result; 92 llvm::raw_string_ostream os(result); 93 // Slice insertion point format [loop-depth, operation-block-index] 94 unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint); 95 unsigned ipb = getBlockIndex(*sliceUnion.insertPoint); 96 os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb) 97 << ")"; 98 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size()); 99 os << " loop bounds: "; 100 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) { 101 os << '['; 102 sliceUnion.lbs[k].print(os); 103 os << ", "; 104 sliceUnion.ubs[k].print(os); 105 os << "] "; 106 } 107 return os.str(); 108 } 109 110 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths 111 /// in range ['loopDepth' + 1, 'maxLoopDepth']. 112 /// Emits a string representation of the slice union as a remark on 'loops[j]' 113 /// and marks this as incorrect slice if the slice is invalid. Returns false as 114 /// IR is not transformed. 115 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB, 116 unsigned i, unsigned j, unsigned loopDepth, 117 unsigned maxLoopDepth) { 118 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 119 affine::ComputationSliceState sliceUnion; 120 FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion); 121 if (result.value == FusionResult::Success) { 122 forOpB->emitRemark("slice (") 123 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d 124 << " : " << getSliceStr(sliceUnion) << ")"; 125 } else if (result.value == FusionResult::FailIncorrectSlice) { 126 forOpB->emitRemark("Incorrect slice (") 127 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d 128 << " : " << getSliceStr(sliceUnion) << ")"; 129 } 130 } 131 return false; 132 } 133 134 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range 135 // ['loopDepth' + 1, 'maxLoopDepth']. 136 // Returns true if loops were successfully fused, false otherwise. 137 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB, 138 unsigned i, unsigned j, 139 unsigned loopDepth, 140 unsigned maxLoopDepth) { 141 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) { 142 affine::ComputationSliceState sliceUnion; 143 FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion); 144 if (result.value == FusionResult::Success) { 145 affine::fuseLoops(forOpA, forOpB, sliceUnion); 146 // Note: 'forOpA' is removed to simplify test output. A proper loop 147 // fusion pass should check the data dependence graph and run memref 148 // region analysis to ensure removing 'forOpA' is safe. 149 forOpA.erase(); 150 return true; 151 } 152 } 153 return false; 154 } 155 156 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned, 157 unsigned, unsigned)>; 158 159 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. 160 // If 'return_on_change' is true, returns on first invocation of 'fn' which 161 // returns true. 162 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops, 163 LoopFunc fn, bool returnOnChange = false) { 164 bool changed = false; 165 for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end; 166 ++loopDepth) { 167 auto &loops = depthToLoops[loopDepth]; 168 unsigned numLoops = loops.size(); 169 for (unsigned j = 0; j < numLoops; ++j) { 170 for (unsigned k = 0; k < numLoops; ++k) { 171 if (j != k) 172 changed |= 173 fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size()); 174 if (changed && returnOnChange) 175 return true; 176 } 177 } 178 } 179 return changed; 180 } 181 182 void TestLoopFusion::runOnOperation() { 183 std::vector<SmallVector<AffineForOp, 2>> depthToLoops; 184 if (clTestLoopFusionTransformation) { 185 // Run loop fusion until a fixed point is reached. 186 do { 187 depthToLoops.clear(); 188 // Gather all AffineForOps by loop depth. 189 gatherLoops(getOperation(), depthToLoops); 190 191 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'. 192 } while (iterateLoops(depthToLoops, testLoopFusionTransformation, 193 /*returnOnChange=*/true)); 194 return; 195 } 196 197 // Gather all AffineForOps by loop depth. 198 gatherLoops(getOperation(), depthToLoops); 199 200 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'. 201 if (clTestDependenceCheck) 202 iterateLoops(depthToLoops, testDependenceCheck); 203 if (clTestSliceComputation) 204 iterateLoops(depthToLoops, testSliceComputation); 205 } 206 207 namespace mlir { 208 namespace test { 209 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); } 210 } // namespace test 211 } // namespace mlir 212