xref: /llvm-project/mlir/test/lib/Analysis/TestMemRefDependenceCheck.cpp (revision 4c48f016effde67d500fc95290096aec9f3bdb70)
1 //===- TestMemRefDependenceCheck.cpp - Test dep analysis ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to run pair-wise memref access dependence checks.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
14 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
15 #include "mlir/Dialect/Affine/Analysis/Utils.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Pass/Pass.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "test-memref-dependence-check"
22 
23 using namespace mlir;
24 using namespace mlir::affine;
25 
26 namespace {
27 
28 // TODO: Add common surrounding loop depth-wise dependence checks.
29 /// Checks dependences between all pairs of memref accesses in a Function.
30 struct TestMemRefDependenceCheck
31     : public PassWrapper<TestMemRefDependenceCheck, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonec729be30111::TestMemRefDependenceCheck32   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefDependenceCheck)
33 
34   StringRef getArgument() const final { return "test-memref-dependence-check"; }
getDescription__anonec729be30111::TestMemRefDependenceCheck35   StringRef getDescription() const final {
36     return "Checks dependences between all pairs of memref accesses.";
37   }
38   SmallVector<Operation *, 4> loadsAndStores;
39   void runOnOperation() override;
40 };
41 
42 } // namespace
43 
44 // Returns a result string which represents the direction vector (if there was
45 // a dependence), returns the string "false" otherwise.
46 static std::string
getDirectionVectorStr(bool ret,unsigned numCommonLoops,unsigned loopNestDepth,ArrayRef<DependenceComponent> dependenceComponents)47 getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth,
48                       ArrayRef<DependenceComponent> dependenceComponents) {
49   if (!ret)
50     return "false";
51   if (dependenceComponents.empty() || loopNestDepth > numCommonLoops)
52     return "true";
53   std::string result;
54   for (const auto &dependenceComponent : dependenceComponents) {
55     std::string lbStr = "-inf";
56     if (dependenceComponent.lb.has_value() &&
57         *dependenceComponent.lb != std::numeric_limits<int64_t>::min())
58       lbStr = std::to_string(*dependenceComponent.lb);
59 
60     std::string ubStr = "+inf";
61     if (dependenceComponent.ub.has_value() &&
62         *dependenceComponent.ub != std::numeric_limits<int64_t>::max())
63       ubStr = std::to_string(*dependenceComponent.ub);
64 
65     result += "[" + lbStr + ", " + ubStr + "]";
66   }
67   return result;
68 }
69 
70 // For each access in 'loadsAndStores', runs a dependence check between this
71 // "source" access and all subsequent "destination" accesses in
72 // 'loadsAndStores'. Emits the result of the dependence check as a note with
73 // the source access.
checkDependences(ArrayRef<Operation * > loadsAndStores)74 static void checkDependences(ArrayRef<Operation *> loadsAndStores) {
75   for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
76     auto *srcOpInst = loadsAndStores[i];
77     MemRefAccess srcAccess(srcOpInst);
78     for (unsigned j = 0; j < e; ++j) {
79       auto *dstOpInst = loadsAndStores[j];
80       MemRefAccess dstAccess(dstOpInst);
81 
82       unsigned numCommonLoops =
83           getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
84       for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
85         SmallVector<DependenceComponent, 2> dependenceComponents;
86         DependenceResult result = checkMemrefAccessDependence(
87             srcAccess, dstAccess, d, /*dependenceConstraints=*/nullptr,
88             &dependenceComponents);
89         if (result.value == DependenceResult::Failure) {
90           srcOpInst->emitError("dependence check failed");
91         } else {
92           bool ret = hasDependence(result);
93           // TODO: Print dependence type (i.e. RAW, etc) and print
94           // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
95           // vectors from ([1, 1], [3, 3]) to (1, 3).
96           srcOpInst->emitRemark("dependence from ")
97               << i << " to " << j << " at depth " << d << " = "
98               << getDirectionVectorStr(ret, numCommonLoops, d,
99                                        dependenceComponents);
100         }
101       }
102     }
103   }
104 }
105 
106 /// Walks the operation adding load and store ops to 'loadsAndStores'. Runs
107 /// pair-wise dependence checks.
runOnOperation()108 void TestMemRefDependenceCheck::runOnOperation() {
109   // Collect the loads and stores within the function.
110   loadsAndStores.clear();
111   getOperation()->walk([&](Operation *op) {
112     if (isa<AffineLoadOp, AffineStoreOp>(op))
113       loadsAndStores.push_back(op);
114   });
115 
116   checkDependences(loadsAndStores);
117 }
118 
119 namespace mlir {
120 namespace test {
registerTestMemRefDependenceCheck()121 void registerTestMemRefDependenceCheck() {
122   PassRegistration<TestMemRefDependenceCheck>();
123 }
124 } // namespace test
125 } // namespace mlir
126