13fef2d26SRiver Riddle //===- TestMemRefDependenceCheck.cpp - Test dep analysis ------------------===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle //
93fef2d26SRiver Riddle // This file implements a pass to run pair-wise memref access dependence checks.
103fef2d26SRiver Riddle //
113fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
123fef2d26SRiver Riddle
13755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
14755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
15755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/Utils.h"
163fef2d26SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
173fef2d26SRiver Riddle #include "mlir/IR/Builders.h"
183fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
193fef2d26SRiver Riddle #include "llvm/Support/Debug.h"
203fef2d26SRiver Riddle
213fef2d26SRiver Riddle #define DEBUG_TYPE "test-memref-dependence-check"
223fef2d26SRiver Riddle
233fef2d26SRiver Riddle using namespace mlir;
24*4c48f016SMatthias Springer using namespace mlir::affine;
253fef2d26SRiver Riddle
263fef2d26SRiver Riddle namespace {
273fef2d26SRiver Riddle
283fef2d26SRiver Riddle // TODO: Add common surrounding loop depth-wise dependence checks.
293fef2d26SRiver Riddle /// Checks dependences between all pairs of memref accesses in a Function.
303fef2d26SRiver Riddle struct TestMemRefDependenceCheck
3187d6bf37SRiver Riddle : public PassWrapper<TestMemRefDependenceCheck, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonec729be30111::TestMemRefDependenceCheck325e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefDependenceCheck)
335e50dd04SRiver Riddle
34b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-memref-dependence-check"; }
getDescription__anonec729be30111::TestMemRefDependenceCheck35b5e22e6dSMehdi Amini StringRef getDescription() const final {
36b5e22e6dSMehdi Amini return "Checks dependences between all pairs of memref accesses.";
37b5e22e6dSMehdi Amini }
383fef2d26SRiver Riddle SmallVector<Operation *, 4> loadsAndStores;
3941574554SRiver Riddle void runOnOperation() override;
403fef2d26SRiver Riddle };
413fef2d26SRiver Riddle
42be0a7e9fSMehdi Amini } // namespace
433fef2d26SRiver Riddle
443fef2d26SRiver Riddle // Returns a result string which represents the direction vector (if there was
453fef2d26SRiver Riddle // a dependence), returns the string "false" otherwise.
463fef2d26SRiver Riddle static std::string
getDirectionVectorStr(bool ret,unsigned numCommonLoops,unsigned loopNestDepth,ArrayRef<DependenceComponent> dependenceComponents)473fef2d26SRiver Riddle getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth,
483fef2d26SRiver Riddle ArrayRef<DependenceComponent> dependenceComponents) {
493fef2d26SRiver Riddle if (!ret)
503fef2d26SRiver Riddle return "false";
513fef2d26SRiver Riddle if (dependenceComponents.empty() || loopNestDepth > numCommonLoops)
523fef2d26SRiver Riddle return "true";
533fef2d26SRiver Riddle std::string result;
54e5639b3fSMehdi Amini for (const auto &dependenceComponent : dependenceComponents) {
553fef2d26SRiver Riddle std::string lbStr = "-inf";
56491d2701SKazu Hirata if (dependenceComponent.lb.has_value() &&
57cbb09813SFangrui Song *dependenceComponent.lb != std::numeric_limits<int64_t>::min())
58cbb09813SFangrui Song lbStr = std::to_string(*dependenceComponent.lb);
593fef2d26SRiver Riddle
603fef2d26SRiver Riddle std::string ubStr = "+inf";
61491d2701SKazu Hirata if (dependenceComponent.ub.has_value() &&
62cbb09813SFangrui Song *dependenceComponent.ub != std::numeric_limits<int64_t>::max())
63cbb09813SFangrui Song ubStr = std::to_string(*dependenceComponent.ub);
643fef2d26SRiver Riddle
653fef2d26SRiver Riddle result += "[" + lbStr + ", " + ubStr + "]";
663fef2d26SRiver Riddle }
673fef2d26SRiver Riddle return result;
683fef2d26SRiver Riddle }
693fef2d26SRiver Riddle
703fef2d26SRiver Riddle // For each access in 'loadsAndStores', runs a dependence check between this
713fef2d26SRiver Riddle // "source" access and all subsequent "destination" accesses in
723fef2d26SRiver Riddle // 'loadsAndStores'. Emits the result of the dependence check as a note with
733fef2d26SRiver Riddle // the source access.
checkDependences(ArrayRef<Operation * > loadsAndStores)743fef2d26SRiver Riddle static void checkDependences(ArrayRef<Operation *> loadsAndStores) {
753fef2d26SRiver Riddle for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
763fef2d26SRiver Riddle auto *srcOpInst = loadsAndStores[i];
773fef2d26SRiver Riddle MemRefAccess srcAccess(srcOpInst);
783fef2d26SRiver Riddle for (unsigned j = 0; j < e; ++j) {
793fef2d26SRiver Riddle auto *dstOpInst = loadsAndStores[j];
803fef2d26SRiver Riddle MemRefAccess dstAccess(dstOpInst);
813fef2d26SRiver Riddle
823fef2d26SRiver Riddle unsigned numCommonLoops =
833fef2d26SRiver Riddle getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
843fef2d26SRiver Riddle for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
853fef2d26SRiver Riddle SmallVector<DependenceComponent, 2> dependenceComponents;
863fef2d26SRiver Riddle DependenceResult result = checkMemrefAccessDependence(
87ee7c4741SMatthias Springer srcAccess, dstAccess, d, /*dependenceConstraints=*/nullptr,
883fef2d26SRiver Riddle &dependenceComponents);
891d541bd9SKai Sasaki if (result.value == DependenceResult::Failure) {
901d541bd9SKai Sasaki srcOpInst->emitError("dependence check failed");
911d541bd9SKai Sasaki } else {
923fef2d26SRiver Riddle bool ret = hasDependence(result);
933fef2d26SRiver Riddle // TODO: Print dependence type (i.e. RAW, etc) and print
943fef2d26SRiver Riddle // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
953fef2d26SRiver Riddle // vectors from ([1, 1], [3, 3]) to (1, 3).
963fef2d26SRiver Riddle srcOpInst->emitRemark("dependence from ")
973fef2d26SRiver Riddle << i << " to " << j << " at depth " << d << " = "
983fef2d26SRiver Riddle << getDirectionVectorStr(ret, numCommonLoops, d,
993fef2d26SRiver Riddle dependenceComponents);
1003fef2d26SRiver Riddle }
1013fef2d26SRiver Riddle }
1023fef2d26SRiver Riddle }
1033fef2d26SRiver Riddle }
1041d541bd9SKai Sasaki }
1053fef2d26SRiver Riddle
10687d6bf37SRiver Riddle /// Walks the operation adding load and store ops to 'loadsAndStores'. Runs
10787d6bf37SRiver Riddle /// pair-wise dependence checks.
runOnOperation()10841574554SRiver Riddle void TestMemRefDependenceCheck::runOnOperation() {
1093fef2d26SRiver Riddle // Collect the loads and stores within the function.
1103fef2d26SRiver Riddle loadsAndStores.clear();
11187d6bf37SRiver Riddle getOperation()->walk([&](Operation *op) {
1123fef2d26SRiver Riddle if (isa<AffineLoadOp, AffineStoreOp>(op))
1133fef2d26SRiver Riddle loadsAndStores.push_back(op);
1143fef2d26SRiver Riddle });
1153fef2d26SRiver Riddle
1163fef2d26SRiver Riddle checkDependences(loadsAndStores);
1173fef2d26SRiver Riddle }
1183fef2d26SRiver Riddle
1193fef2d26SRiver Riddle namespace mlir {
1203fef2d26SRiver Riddle namespace test {
registerTestMemRefDependenceCheck()1213fef2d26SRiver Riddle void registerTestMemRefDependenceCheck() {
122b5e22e6dSMehdi Amini PassRegistration<TestMemRefDependenceCheck>();
1233fef2d26SRiver Riddle }
1243fef2d26SRiver Riddle } // namespace test
1253fef2d26SRiver Riddle } // namespace mlir
126