//===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test various loop fusion utility functions.
//
//===----------------------------------------------------------------------===//
12a70aa7bbSRiver Riddle
13a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Analysis/Utils.h"
14a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
15a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/LoopUtils.h"
1736550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
18a70aa7bbSRiver Riddle #include "mlir/Pass/Pass.h"
19a70aa7bbSRiver Riddle
20a70aa7bbSRiver Riddle #define DEBUG_TYPE "test-loop-fusion"
21a70aa7bbSRiver Riddle
22a70aa7bbSRiver Riddle using namespace mlir;
234c48f016SMatthias Springer using namespace mlir::affine;
24a70aa7bbSRiver Riddle
25a70aa7bbSRiver Riddle namespace {
26a70aa7bbSRiver Riddle
27a70aa7bbSRiver Riddle struct TestLoopFusion
2858ceae95SRiver Riddle : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon472513560111::TestLoopFusion295e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
305e50dd04SRiver Riddle
31a70aa7bbSRiver Riddle StringRef getArgument() const final { return "test-loop-fusion"; }
getDescription__anon472513560111::TestLoopFusion32a70aa7bbSRiver Riddle StringRef getDescription() const final {
33a70aa7bbSRiver Riddle return "Tests loop fusion utility functions.";
34a70aa7bbSRiver Riddle }
35a70aa7bbSRiver Riddle void runOnOperation() override;
36*608a663cSPhilip Lassen
37*608a663cSPhilip Lassen TestLoopFusion() = default;
TestLoopFusion__anon472513560111::TestLoopFusion38*608a663cSPhilip Lassen TestLoopFusion(const TestLoopFusion &pass) : PassWrapper(pass){};
39*608a663cSPhilip Lassen
40*608a663cSPhilip Lassen Option<bool> clTestDependenceCheck{
41*608a663cSPhilip Lassen *this, "test-loop-fusion-dependence-check",
42*608a663cSPhilip Lassen llvm::cl::desc("Enable testing of loop fusion dependence check"),
43*608a663cSPhilip Lassen llvm::cl::init(false)};
44*608a663cSPhilip Lassen
45*608a663cSPhilip Lassen Option<bool> clTestSliceComputation{
46*608a663cSPhilip Lassen *this, "test-loop-fusion-slice-computation",
47*608a663cSPhilip Lassen llvm::cl::desc("Enable testing of loop fusion slice computation"),
48*608a663cSPhilip Lassen llvm::cl::init(false)};
49*608a663cSPhilip Lassen
50*608a663cSPhilip Lassen Option<bool> clTestLoopFusionTransformation{
51*608a663cSPhilip Lassen *this, "test-loop-fusion-transformation",
52*608a663cSPhilip Lassen llvm::cl::desc("Enable testing of loop fusion transformation"),
53*608a663cSPhilip Lassen llvm::cl::init(false)};
54a70aa7bbSRiver Riddle };
55a70aa7bbSRiver Riddle
56a70aa7bbSRiver Riddle } // namespace
57a70aa7bbSRiver Riddle
58a70aa7bbSRiver Riddle // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
59a70aa7bbSRiver Riddle // in range ['loopDepth' + 1, 'maxLoopDepth'].
60a70aa7bbSRiver Riddle // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
61a70aa7bbSRiver Riddle // Returns false as IR is not transformed.
testDependenceCheck(AffineForOp srcForOp,AffineForOp dstForOp,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)62a70aa7bbSRiver Riddle static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
63a70aa7bbSRiver Riddle unsigned i, unsigned j, unsigned loopDepth,
64a70aa7bbSRiver Riddle unsigned maxLoopDepth) {
654c48f016SMatthias Springer affine::ComputationSliceState sliceUnion;
66a70aa7bbSRiver Riddle for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
67a70aa7bbSRiver Riddle FusionResult result =
684c48f016SMatthias Springer affine::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
69a70aa7bbSRiver Riddle if (result.value == FusionResult::FailBlockDependence) {
70a70aa7bbSRiver Riddle srcForOp->emitRemark("block-level dependence preventing"
71a70aa7bbSRiver Riddle " fusion of loop nest ")
72a70aa7bbSRiver Riddle << i << " into loop nest " << j << " at depth " << loopDepth;
73a70aa7bbSRiver Riddle }
74a70aa7bbSRiver Riddle }
75a70aa7bbSRiver Riddle return false;
76a70aa7bbSRiver Riddle }
77a70aa7bbSRiver Riddle
78a70aa7bbSRiver Riddle // Returns the index of 'op' in its block.
getBlockIndex(Operation & op)79a70aa7bbSRiver Riddle static unsigned getBlockIndex(Operation &op) {
80a70aa7bbSRiver Riddle unsigned index = 0;
81a70aa7bbSRiver Riddle for (auto &opX : *op.getBlock()) {
82a70aa7bbSRiver Riddle if (&op == &opX)
83a70aa7bbSRiver Riddle break;
84a70aa7bbSRiver Riddle ++index;
85a70aa7bbSRiver Riddle }
86a70aa7bbSRiver Riddle return index;
87a70aa7bbSRiver Riddle }
88a70aa7bbSRiver Riddle
89a70aa7bbSRiver Riddle // Returns a string representation of 'sliceUnion'.
904c48f016SMatthias Springer static std::string
getSliceStr(const affine::ComputationSliceState & sliceUnion)914c48f016SMatthias Springer getSliceStr(const affine::ComputationSliceState &sliceUnion) {
92a70aa7bbSRiver Riddle std::string result;
93a70aa7bbSRiver Riddle llvm::raw_string_ostream os(result);
94a70aa7bbSRiver Riddle // Slice insertion point format [loop-depth, operation-block-index]
95a70aa7bbSRiver Riddle unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
96a70aa7bbSRiver Riddle unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
97a70aa7bbSRiver Riddle os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
98a70aa7bbSRiver Riddle << ")";
99a70aa7bbSRiver Riddle assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
100a70aa7bbSRiver Riddle os << " loop bounds: ";
101a70aa7bbSRiver Riddle for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
102a70aa7bbSRiver Riddle os << '[';
103a70aa7bbSRiver Riddle sliceUnion.lbs[k].print(os);
104a70aa7bbSRiver Riddle os << ", ";
105a70aa7bbSRiver Riddle sliceUnion.ubs[k].print(os);
106a70aa7bbSRiver Riddle os << "] ";
107a70aa7bbSRiver Riddle }
108a70aa7bbSRiver Riddle return os.str();
109a70aa7bbSRiver Riddle }
110a70aa7bbSRiver Riddle
111a70aa7bbSRiver Riddle /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
112a70aa7bbSRiver Riddle /// in range ['loopDepth' + 1, 'maxLoopDepth'].
113a70aa7bbSRiver Riddle /// Emits a string representation of the slice union as a remark on 'loops[j]'
114a70aa7bbSRiver Riddle /// and marks this as incorrect slice if the slice is invalid. Returns false as
115a70aa7bbSRiver Riddle /// IR is not transformed.
testSliceComputation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)116a70aa7bbSRiver Riddle static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
117a70aa7bbSRiver Riddle unsigned i, unsigned j, unsigned loopDepth,
118a70aa7bbSRiver Riddle unsigned maxLoopDepth) {
119a70aa7bbSRiver Riddle for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
1204c48f016SMatthias Springer affine::ComputationSliceState sliceUnion;
1214c48f016SMatthias Springer FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
122a70aa7bbSRiver Riddle if (result.value == FusionResult::Success) {
123a70aa7bbSRiver Riddle forOpB->emitRemark("slice (")
124a70aa7bbSRiver Riddle << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
125a70aa7bbSRiver Riddle << " : " << getSliceStr(sliceUnion) << ")";
126a70aa7bbSRiver Riddle } else if (result.value == FusionResult::FailIncorrectSlice) {
127a70aa7bbSRiver Riddle forOpB->emitRemark("Incorrect slice (")
128a70aa7bbSRiver Riddle << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
129a70aa7bbSRiver Riddle << " : " << getSliceStr(sliceUnion) << ")";
130a70aa7bbSRiver Riddle }
131a70aa7bbSRiver Riddle }
132a70aa7bbSRiver Riddle return false;
133a70aa7bbSRiver Riddle }
134a70aa7bbSRiver Riddle
135a70aa7bbSRiver Riddle // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
136a70aa7bbSRiver Riddle // ['loopDepth' + 1, 'maxLoopDepth'].
137a70aa7bbSRiver Riddle // Returns true if loops were successfully fused, false otherwise.
testLoopFusionTransformation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)138a70aa7bbSRiver Riddle static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
139a70aa7bbSRiver Riddle unsigned i, unsigned j,
140a70aa7bbSRiver Riddle unsigned loopDepth,
141a70aa7bbSRiver Riddle unsigned maxLoopDepth) {
142a70aa7bbSRiver Riddle for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
1434c48f016SMatthias Springer affine::ComputationSliceState sliceUnion;
1444c48f016SMatthias Springer FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
145a70aa7bbSRiver Riddle if (result.value == FusionResult::Success) {
1464c48f016SMatthias Springer affine::fuseLoops(forOpA, forOpB, sliceUnion);
147a70aa7bbSRiver Riddle // Note: 'forOpA' is removed to simplify test output. A proper loop
148a70aa7bbSRiver Riddle // fusion pass should check the data dependence graph and run memref
149a70aa7bbSRiver Riddle // region analysis to ensure removing 'forOpA' is safe.
150a70aa7bbSRiver Riddle forOpA.erase();
151a70aa7bbSRiver Riddle return true;
152a70aa7bbSRiver Riddle }
153a70aa7bbSRiver Riddle }
154a70aa7bbSRiver Riddle return false;
155a70aa7bbSRiver Riddle }
156a70aa7bbSRiver Riddle
157a70aa7bbSRiver Riddle using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
158a70aa7bbSRiver Riddle unsigned, unsigned)>;
159a70aa7bbSRiver Riddle
160a70aa7bbSRiver Riddle // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
161a70aa7bbSRiver Riddle // If 'return_on_change' is true, returns on first invocation of 'fn' which
162a70aa7bbSRiver Riddle // returns true.
iterateLoops(ArrayRef<SmallVector<AffineForOp,2>> depthToLoops,LoopFunc fn,bool returnOnChange=false)163a70aa7bbSRiver Riddle static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
164a70aa7bbSRiver Riddle LoopFunc fn, bool returnOnChange = false) {
165a70aa7bbSRiver Riddle bool changed = false;
166a70aa7bbSRiver Riddle for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
167a70aa7bbSRiver Riddle ++loopDepth) {
168a70aa7bbSRiver Riddle auto &loops = depthToLoops[loopDepth];
169a70aa7bbSRiver Riddle unsigned numLoops = loops.size();
170a70aa7bbSRiver Riddle for (unsigned j = 0; j < numLoops; ++j) {
171a70aa7bbSRiver Riddle for (unsigned k = 0; k < numLoops; ++k) {
172a70aa7bbSRiver Riddle if (j != k)
173a70aa7bbSRiver Riddle changed |=
174a70aa7bbSRiver Riddle fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
175a70aa7bbSRiver Riddle if (changed && returnOnChange)
176a70aa7bbSRiver Riddle return true;
177a70aa7bbSRiver Riddle }
178a70aa7bbSRiver Riddle }
179a70aa7bbSRiver Riddle }
180a70aa7bbSRiver Riddle return changed;
181a70aa7bbSRiver Riddle }
182a70aa7bbSRiver Riddle
runOnOperation()183a70aa7bbSRiver Riddle void TestLoopFusion::runOnOperation() {
184a70aa7bbSRiver Riddle std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
185a70aa7bbSRiver Riddle if (clTestLoopFusionTransformation) {
186a70aa7bbSRiver Riddle // Run loop fusion until a fixed point is reached.
187a70aa7bbSRiver Riddle do {
188a70aa7bbSRiver Riddle depthToLoops.clear();
189a70aa7bbSRiver Riddle // Gather all AffineForOps by loop depth.
190a70aa7bbSRiver Riddle gatherLoops(getOperation(), depthToLoops);
191a70aa7bbSRiver Riddle
192a70aa7bbSRiver Riddle // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
193a70aa7bbSRiver Riddle } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
194a70aa7bbSRiver Riddle /*returnOnChange=*/true));
195a70aa7bbSRiver Riddle return;
196a70aa7bbSRiver Riddle }
197a70aa7bbSRiver Riddle
198a70aa7bbSRiver Riddle // Gather all AffineForOps by loop depth.
199a70aa7bbSRiver Riddle gatherLoops(getOperation(), depthToLoops);
200a70aa7bbSRiver Riddle
201a70aa7bbSRiver Riddle // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
202a70aa7bbSRiver Riddle if (clTestDependenceCheck)
203a70aa7bbSRiver Riddle iterateLoops(depthToLoops, testDependenceCheck);
204a70aa7bbSRiver Riddle if (clTestSliceComputation)
205a70aa7bbSRiver Riddle iterateLoops(depthToLoops, testSliceComputation);
206a70aa7bbSRiver Riddle }
207a70aa7bbSRiver Riddle
208a70aa7bbSRiver Riddle namespace mlir {
209a70aa7bbSRiver Riddle namespace test {
registerTestLoopFusion()210a70aa7bbSRiver Riddle void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
211a70aa7bbSRiver Riddle } // namespace test
212a70aa7bbSRiver Riddle } // namespace mlir
213