xref: /llvm-project/mlir/test/lib/Analysis/TestAliasAnalysis.cpp (revision 830b9b072d8458ee89c48f00d4de59456c9f467f)
1 //===- TestAliasAnalysis.cpp - Test alias analysis results ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains test passes for constructing and testing alias analysis
10 // results.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestAliasAnalysis.h"
15 #include "mlir/Analysis/AliasAnalysis.h"
16 #include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
17 #include "mlir/Interfaces/FunctionInterfaces.h"
18 #include "mlir/Pass/Pass.h"
19 
20 using namespace mlir;
21 
22 /// Print a value that is used as an operand of an alias query.
printAliasOperand(Operation * op)23 static void printAliasOperand(Operation *op) {
24   llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr").getValue();
25 }
printAliasOperand(Value value)26 static void printAliasOperand(Value value) {
27   if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
28     Region *region = arg.getParentRegion();
29     unsigned parentBlockNumber =
30         std::distance(region->begin(), arg.getOwner()->getIterator());
31     llvm::errs() << region->getParentOp()
32                         ->getAttrOfType<StringAttr>("test.ptr")
33                         .getValue()
34                  << ".region" << region->getRegionNumber();
35     if (parentBlockNumber != 0)
36       llvm::errs() << ".block" << parentBlockNumber;
37     llvm::errs() << "#" << arg.getArgNumber();
38     return;
39   }
40   OpResult result = cast<OpResult>(value);
41   printAliasOperand(result.getOwner());
42   llvm::errs() << "#" << result.getResultNumber();
43 }
44 
45 namespace mlir {
46 namespace test {
printAliasResult(AliasResult result,Value lhs,Value rhs)47 void printAliasResult(AliasResult result, Value lhs, Value rhs) {
48   printAliasOperand(lhs);
49   llvm::errs() << " <-> ";
50   printAliasOperand(rhs);
51   llvm::errs() << ": " << result << "\n";
52 }
53 
54 /// Print the result of an alias query.
printModRefResult(ModRefResult result,Operation * op,Value location)55 void printModRefResult(ModRefResult result, Operation *op, Value location) {
56   printAliasOperand(op);
57   llvm::errs() << " -> ";
58   printAliasOperand(location);
59   llvm::errs() << ": " << result << "\n";
60 }
61 
runAliasAnalysisOnOperation(Operation * op,AliasAnalysis & aliasAnalysis)62 void TestAliasAnalysisBase::runAliasAnalysisOnOperation(
63     Operation *op, AliasAnalysis &aliasAnalysis) {
64   llvm::errs() << "Testing : " << *op->getInherentAttr("sym_name") << "\n";
65 
66   // Collect all of the values to check for aliasing behavior.
67   SmallVector<Value, 32> valsToCheck;
68   op->walk([&](Operation *op) {
69     if (!op->getDiscardableAttr("test.ptr"))
70       return;
71     valsToCheck.append(op->result_begin(), op->result_end());
72     for (Region &region : op->getRegions())
73       for (Block &block : region)
74         valsToCheck.append(block.args_begin(), block.args_end());
75   });
76 
77   // Check for aliasing behavior between each of the values.
78   for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it)
79     for (auto *innerIt = valsToCheck.begin(); innerIt != it; ++innerIt)
80       printAliasResult(aliasAnalysis.alias(*innerIt, *it), *innerIt, *it);
81 }
82 
runAliasAnalysisOnOperation(Operation * op,AliasAnalysis & aliasAnalysis)83 void TestAliasAnalysisModRefBase::runAliasAnalysisOnOperation(
84     Operation *op, AliasAnalysis &aliasAnalysis) {
85   llvm::errs() << "Testing : " << *op->getInherentAttr("sym_name") << "\n";
86 
87   // Collect all of the values to check for aliasing behavior.
88   SmallVector<Value, 32> valsToCheck;
89   op->walk([&](Operation *op) {
90     if (!op->getDiscardableAttr("test.ptr"))
91       return;
92     valsToCheck.append(op->result_begin(), op->result_end());
93     for (Region &region : op->getRegions())
94       for (Block &block : region)
95         valsToCheck.append(block.args_begin(), block.args_end());
96   });
97 
98   // Check for aliasing behavior between each of the values.
99   for (auto &it : valsToCheck) {
100     op->walk([&](Operation *op) {
101       if (!op->getDiscardableAttr("test.ptr"))
102         return;
103       printModRefResult(aliasAnalysis.getModRef(op, it), op, it);
104     });
105   }
106 }
107 
108 } // namespace test
109 } // namespace mlir
110 
111 //===----------------------------------------------------------------------===//
112 // Testing AliasResult
113 //===----------------------------------------------------------------------===//
114 
115 namespace {
116 struct TestAliasAnalysisPass
117     : public test::TestAliasAnalysisBase,
118       PassWrapper<TestAliasAnalysisPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonabeb314c0411::TestAliasAnalysisPass119   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisPass)
120 
121   StringRef getArgument() const final { return "test-alias-analysis"; }
getDescription__anonabeb314c0411::TestAliasAnalysisPass122   StringRef getDescription() const final {
123     return "Test alias analysis results.";
124   }
runOnOperation__anonabeb314c0411::TestAliasAnalysisPass125   void runOnOperation() override {
126     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
127     runAliasAnalysisOnOperation(getOperation(), aliasAnalysis);
128   }
129 };
130 } // namespace
131 
132 //===----------------------------------------------------------------------===//
133 // Testing ModRefResult
134 //===----------------------------------------------------------------------===//
135 
136 namespace {
137 struct TestAliasAnalysisModRefPass
138     : public test::TestAliasAnalysisModRefBase,
139       PassWrapper<TestAliasAnalysisModRefPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonabeb314c0511::TestAliasAnalysisModRefPass140   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisModRefPass)
141 
142   StringRef getArgument() const final { return "test-alias-analysis-modref"; }
getDescription__anonabeb314c0511::TestAliasAnalysisModRefPass143   StringRef getDescription() const final {
144     return "Test alias analysis ModRef results.";
145   }
runOnOperation__anonabeb314c0511::TestAliasAnalysisModRefPass146   void runOnOperation() override {
147     AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
148     runAliasAnalysisOnOperation(getOperation(), aliasAnalysis);
149   }
150 };
151 } // namespace
152 
153 //===----------------------------------------------------------------------===//
154 // Testing LocalAliasAnalysis extending
155 //===----------------------------------------------------------------------===//
156 
157 /// Check if value is function argument.
isFuncArg(Value val)158 static bool isFuncArg(Value val) {
159   auto blockArg = dyn_cast<BlockArgument>(val);
160   if (!blockArg)
161     return false;
162 
163   return mlir::isa_and_nonnull<FunctionOpInterface>(
164       blockArg.getOwner()->getParentOp());
165 }
166 
167 /// Check if value has "restrict" attribute. Value must be a function argument.
isRestrict(Value val)168 static bool isRestrict(Value val) {
169   auto blockArg = cast<BlockArgument>(val);
170   auto func =
171       mlir::cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp());
172   return !!func.getArgAttr(blockArg.getArgNumber(),
173                            "local_alias_analysis.restrict");
174 }
175 
176 namespace {
177 /// LocalAliasAnalysis extended to support "restrict" attreibute.
178 class LocalAliasAnalysisRestrict : public LocalAliasAnalysis {
179 protected:
aliasImpl(Value lhs,Value rhs)180   AliasResult aliasImpl(Value lhs, Value rhs) override {
181     if (lhs == rhs)
182       return AliasResult::MustAlias;
183 
184     // Assume no aliasing if both values are function arguments and any of them
185     // have restrict attr.
186     if (isFuncArg(lhs) && isFuncArg(rhs))
187       if (isRestrict(lhs) || isRestrict(rhs))
188         return AliasResult::NoAlias;
189 
190     return LocalAliasAnalysis::aliasImpl(lhs, rhs);
191   }
192 };
193 
194 /// This pass tests adding additional analysis impls to the AliasAnalysis.
195 struct TestAliasAnalysisExtendingPass
196     : public test::TestAliasAnalysisBase,
197       PassWrapper<TestAliasAnalysisExtendingPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonabeb314c0611::TestAliasAnalysisExtendingPass198   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisExtendingPass)
199 
200   StringRef getArgument() const final {
201     return "test-alias-analysis-extending";
202   }
getDescription__anonabeb314c0611::TestAliasAnalysisExtendingPass203   StringRef getDescription() const final {
204     return "Test alias analysis extending.";
205   }
runOnOperation__anonabeb314c0611::TestAliasAnalysisExtendingPass206   void runOnOperation() override {
207     AliasAnalysis aliasAnalysis(getOperation());
208     aliasAnalysis.addAnalysisImplementation(LocalAliasAnalysisRestrict());
209     runAliasAnalysisOnOperation(getOperation(), aliasAnalysis);
210   }
211 };
212 } // namespace
213 
214 //===----------------------------------------------------------------------===//
215 // Pass Registration
216 //===----------------------------------------------------------------------===//
217 
218 namespace mlir {
219 namespace test {
registerTestAliasAnalysisPass()220 void registerTestAliasAnalysisPass() {
221   PassRegistration<TestAliasAnalysisExtendingPass>();
222   PassRegistration<TestAliasAnalysisModRefPass>();
223   PassRegistration<TestAliasAnalysisPass>();
224 }
225 } // namespace test
226 } // namespace mlir
227