xref: /llvm-project/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp (revision fe04aafe6c27f32ad4ba38e552d06d14431cb2de)
1755dc07dSRiver Riddle //===- LoopAnalysis.cpp - Misc loop analysis routines //-------------------===//
2755dc07dSRiver Riddle //
3755dc07dSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4755dc07dSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5755dc07dSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6755dc07dSRiver Riddle //
7755dc07dSRiver Riddle //===----------------------------------------------------------------------===//
8755dc07dSRiver Riddle //
9755dc07dSRiver Riddle // This file implements miscellaneous loop analysis routines.
10755dc07dSRiver Riddle //
11755dc07dSRiver Riddle //===----------------------------------------------------------------------===//
12755dc07dSRiver Riddle 
13755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
14755dc07dSRiver Riddle 
15d1e9c7b6Smax #include "mlir/Analysis/SliceAnalysis.h"
16755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
17755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
18755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
19755dc07dSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
20755dc07dSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
210fb216fbSRamkumar Ramachandra #include "llvm/Support/MathExtras.h"
22755dc07dSRiver Riddle 
23d1e9c7b6Smax #include "llvm/ADT/DenseSet.h"
24755dc07dSRiver Riddle #include "llvm/ADT/SmallPtrSet.h"
25755dc07dSRiver Riddle #include "llvm/ADT/SmallString.h"
26*fe04aafeSUday Bondhugula #include "llvm/Support/Debug.h"
277a617fdfSKazu Hirata #include <numeric>
28e3915e6bSKazu Hirata #include <optional>
29755dc07dSRiver Riddle #include <type_traits>
30755dc07dSRiver Riddle 
31755dc07dSRiver Riddle using namespace mlir;
324c48f016SMatthias Springer using namespace mlir::affine;
33755dc07dSRiver Riddle 
34*fe04aafeSUday Bondhugula #define DEBUG_TYPE "affine-loop-analysis"
35*fe04aafeSUday Bondhugula 
36d1e9c7b6Smax /// Returns the trip count of the loop as an affine expression if the latter is
37d1e9c7b6Smax /// expressible as an affine expression, and nullptr otherwise. The trip count
38d1e9c7b6Smax /// expression is simplified before returning. This method only utilizes map
39d1e9c7b6Smax /// composition to construct lower and upper bounds before computing the trip
40d1e9c7b6Smax /// count expressions.
41d1e9c7b6Smax void mlir::affine::getTripCountMapAndOperands(
42d1e9c7b6Smax     AffineForOp forOp, AffineMap *tripCountMap,
43d1e9c7b6Smax     SmallVectorImpl<Value> *tripCountOperands) {
44d1e9c7b6Smax   MLIRContext *context = forOp.getContext();
45d1e9c7b6Smax   int64_t step = forOp.getStepAsInt();
46d1e9c7b6Smax   int64_t loopSpan;
47d1e9c7b6Smax   if (forOp.hasConstantBounds()) {
48d1e9c7b6Smax     int64_t lb = forOp.getConstantLowerBound();
49d1e9c7b6Smax     int64_t ub = forOp.getConstantUpperBound();
50d1e9c7b6Smax     loopSpan = ub - lb;
51d1e9c7b6Smax     if (loopSpan < 0)
52d1e9c7b6Smax       loopSpan = 0;
530fb216fbSRamkumar Ramachandra     *tripCountMap = AffineMap::getConstantMap(
540fb216fbSRamkumar Ramachandra         llvm::divideCeilSigned(loopSpan, step), context);
55d1e9c7b6Smax     tripCountOperands->clear();
56d1e9c7b6Smax     return;
57d1e9c7b6Smax   }
58d1e9c7b6Smax   auto lbMap = forOp.getLowerBoundMap();
59d1e9c7b6Smax   auto ubMap = forOp.getUpperBoundMap();
60d1e9c7b6Smax   if (lbMap.getNumResults() != 1) {
61d1e9c7b6Smax     *tripCountMap = AffineMap();
62d1e9c7b6Smax     return;
63d1e9c7b6Smax   }
64d1e9c7b6Smax 
65d1e9c7b6Smax   // Difference of each upper bound expression from the single lower bound
66d1e9c7b6Smax   // expression (divided by the step) provides the expressions for the trip
67d1e9c7b6Smax   // count map.
68d1e9c7b6Smax   AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
69d1e9c7b6Smax 
70d1e9c7b6Smax   SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
71d1e9c7b6Smax                                          lbMap.getResult(0));
72d1e9c7b6Smax   auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
73d1e9c7b6Smax                                    lbSplatExpr, context);
74d1e9c7b6Smax   AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
75d1e9c7b6Smax 
76d1e9c7b6Smax   AffineValueMap tripCountValueMap;
77d1e9c7b6Smax   AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
78d1e9c7b6Smax   for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
79d1e9c7b6Smax     tripCountValueMap.setResult(i,
80d1e9c7b6Smax                                 tripCountValueMap.getResult(i).ceilDiv(step));
81d1e9c7b6Smax 
82d1e9c7b6Smax   *tripCountMap = tripCountValueMap.getAffineMap();
83d1e9c7b6Smax   tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
84d1e9c7b6Smax                             tripCountValueMap.getOperands().end());
85d1e9c7b6Smax }
86d1e9c7b6Smax 
87d1e9c7b6Smax /// Returns the trip count of the loop if it's a constant, std::nullopt
88d1e9c7b6Smax /// otherwise. This method uses affine expression analysis (in turn using
89d1e9c7b6Smax /// getTripCount) and is able to determine constant trip count in non-trivial
90d1e9c7b6Smax /// cases.
91d1e9c7b6Smax std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
92d1e9c7b6Smax   SmallVector<Value, 4> operands;
93d1e9c7b6Smax   AffineMap map;
94d1e9c7b6Smax   getTripCountMapAndOperands(forOp, &map, &operands);
95d1e9c7b6Smax 
96d1e9c7b6Smax   if (!map)
97d1e9c7b6Smax     return std::nullopt;
98d1e9c7b6Smax 
99d1e9c7b6Smax   // Take the min if all trip counts are constant.
100d1e9c7b6Smax   std::optional<uint64_t> tripCount;
101d1e9c7b6Smax   for (auto resultExpr : map.getResults()) {
102d1e9c7b6Smax     if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
103d1e9c7b6Smax       if (tripCount.has_value())
104d1e9c7b6Smax         tripCount =
105d1e9c7b6Smax             std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
106d1e9c7b6Smax       else
107d1e9c7b6Smax         tripCount = constExpr.getValue();
108d1e9c7b6Smax     } else
109d1e9c7b6Smax       return std::nullopt;
110d1e9c7b6Smax   }
111d1e9c7b6Smax   return tripCount;
112d1e9c7b6Smax }
113d1e9c7b6Smax 
114755dc07dSRiver Riddle /// Returns the greatest known integral divisor of the trip count. Affine
115755dc07dSRiver Riddle /// expression analysis is used (indirectly through getTripCount), and
116755dc07dSRiver Riddle /// this method is thus able to determine non-trivial divisors.
1174c48f016SMatthias Springer uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
118755dc07dSRiver Riddle   SmallVector<Value, 4> operands;
119755dc07dSRiver Riddle   AffineMap map;
120755dc07dSRiver Riddle   getTripCountMapAndOperands(forOp, &map, &operands);
121755dc07dSRiver Riddle 
122755dc07dSRiver Riddle   if (!map)
123755dc07dSRiver Riddle     return 1;
124755dc07dSRiver Riddle 
125755dc07dSRiver Riddle   // The largest divisor of the trip count is the GCD of the individual largest
126755dc07dSRiver Riddle   // divisors.
127755dc07dSRiver Riddle   assert(map.getNumResults() >= 1 && "expected one or more results");
128e3915e6bSKazu Hirata   std::optional<uint64_t> gcd;
129755dc07dSRiver Riddle   for (auto resultExpr : map.getResults()) {
130755dc07dSRiver Riddle     uint64_t thisGcd;
1311609f1c2Slong.chen     if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
132755dc07dSRiver Riddle       uint64_t tripCount = constExpr.getValue();
133755dc07dSRiver Riddle       // 0 iteration loops (greatest divisor is 2^64 - 1).
134755dc07dSRiver Riddle       if (tripCount == 0)
135755dc07dSRiver Riddle         thisGcd = std::numeric_limits<uint64_t>::max();
136755dc07dSRiver Riddle       else
137755dc07dSRiver Riddle         // The greatest divisor is the trip count.
138755dc07dSRiver Riddle         thisGcd = tripCount;
139755dc07dSRiver Riddle     } else {
140755dc07dSRiver Riddle       // Trip count is not a known constant; return its largest known divisor.
141755dc07dSRiver Riddle       thisGcd = resultExpr.getLargestKnownDivisor();
142755dc07dSRiver Riddle     }
143491d2701SKazu Hirata     if (gcd.has_value())
1444913e5daSFangrui Song       gcd = std::gcd(*gcd, thisGcd);
145755dc07dSRiver Riddle     else
146755dc07dSRiver Riddle       gcd = thisGcd;
147755dc07dSRiver Riddle   }
148491d2701SKazu Hirata   assert(gcd.has_value() && "value expected per above logic");
1494913e5daSFangrui Song   return *gcd;
150755dc07dSRiver Riddle }
151755dc07dSRiver Riddle 
1521e9bfcd9SUday Bondhugula /// Given an affine.for `iv` and an access `index` of type index, returns `true`
1531e9bfcd9SUday Bondhugula /// if `index` is independent of `iv` and false otherwise.
154755dc07dSRiver Riddle ///
1551e9bfcd9SUday Bondhugula /// Prerequisites: `iv` and `index` of the proper type;
156755dc07dSRiver Riddle static bool isAccessIndexInvariant(Value iv, Value index) {
1571e9bfcd9SUday Bondhugula   assert(isAffineForInductionVar(iv) && "iv must be an affine.for iv");
1581e9bfcd9SUday Bondhugula   assert(isa<IndexType>(index.getType()) && "index must be of 'index' type");
1591e9bfcd9SUday Bondhugula   auto map = AffineMap::getMultiDimIdentityMap(/*numDims=*/1, iv.getContext());
1601e9bfcd9SUday Bondhugula   SmallVector<Value> operands = {index};
1611e9bfcd9SUday Bondhugula   AffineValueMap avm(map, operands);
1621e9bfcd9SUday Bondhugula   avm.composeSimplifyAndCanonicalize();
1631e9bfcd9SUday Bondhugula   return !avm.isFunctionOf(0, iv);
164755dc07dSRiver Riddle }
165755dc07dSRiver Riddle 
1661e9bfcd9SUday Bondhugula // Pre-requisite: Loop bounds should be in canonical form.
1671e9bfcd9SUday Bondhugula template <typename LoadOrStoreOp>
1681e9bfcd9SUday Bondhugula bool mlir::affine::isInvariantAccess(LoadOrStoreOp memOp, AffineForOp forOp) {
1691e9bfcd9SUday Bondhugula   AffineValueMap avm(memOp.getAffineMap(), memOp.getMapOperands());
1701e9bfcd9SUday Bondhugula   avm.composeSimplifyAndCanonicalize();
1711e9bfcd9SUday Bondhugula   return !llvm::is_contained(avm.getOperands(), forOp.getInductionVar());
172755dc07dSRiver Riddle }
173755dc07dSRiver Riddle 
1741e9bfcd9SUday Bondhugula // Explicitly instantiate the template so that the compiler knows we need them.
1751e9bfcd9SUday Bondhugula template bool mlir::affine::isInvariantAccess(AffineReadOpInterface,
1761e9bfcd9SUday Bondhugula                                               AffineForOp);
1771e9bfcd9SUday Bondhugula template bool mlir::affine::isInvariantAccess(AffineWriteOpInterface,
1781e9bfcd9SUday Bondhugula                                               AffineForOp);
1791e9bfcd9SUday Bondhugula template bool mlir::affine::isInvariantAccess(AffineLoadOp, AffineForOp);
1801e9bfcd9SUday Bondhugula template bool mlir::affine::isInvariantAccess(AffineStoreOp, AffineForOp);
181755dc07dSRiver Riddle 
1824c48f016SMatthias Springer DenseSet<Value> mlir::affine::getInvariantAccesses(Value iv,
1834c48f016SMatthias Springer                                                    ArrayRef<Value> indices) {
184755dc07dSRiver Riddle   DenseSet<Value> res;
185755dc07dSRiver Riddle   for (auto val : indices) {
186755dc07dSRiver Riddle     if (isAccessIndexInvariant(iv, val)) {
187755dc07dSRiver Riddle       res.insert(val);
188755dc07dSRiver Riddle     }
189755dc07dSRiver Riddle   }
190755dc07dSRiver Riddle   return res;
191755dc07dSRiver Riddle }
192755dc07dSRiver Riddle 
1932679d379SUday Bondhugula // TODO: check access stride.
194755dc07dSRiver Riddle template <typename LoadOrStoreOp>
1952679d379SUday Bondhugula bool mlir::affine::isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
196755dc07dSRiver Riddle                                       int *memRefDim) {
1972679d379SUday Bondhugula   static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
1982679d379SUday Bondhugula                                 AffineWriteOpInterface>::value,
1992679d379SUday Bondhugula                 "Must be called on either an affine read or write op");
200755dc07dSRiver Riddle   assert(memRefDim && "memRefDim == nullptr");
201755dc07dSRiver Riddle   auto memRefType = memoryOp.getMemRefType();
202755dc07dSRiver Riddle 
203755dc07dSRiver Riddle   if (!memRefType.getLayout().isIdentity())
2042679d379SUday Bondhugula     return memoryOp.emitError("NYI: non-trivial layout map"), false;
205755dc07dSRiver Riddle 
206755dc07dSRiver Riddle   int uniqueVaryingIndexAlongIv = -1;
207755dc07dSRiver Riddle   auto accessMap = memoryOp.getAffineMap();
208755dc07dSRiver Riddle   SmallVector<Value, 4> mapOperands(memoryOp.getMapOperands());
209755dc07dSRiver Riddle   unsigned numDims = accessMap.getNumDims();
210755dc07dSRiver Riddle   for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
2112679d379SUday Bondhugula     // Gather map operands used in result expr 'i' in 'exprOperands'.
212755dc07dSRiver Riddle     SmallVector<Value, 4> exprOperands;
213755dc07dSRiver Riddle     auto resultExpr = accessMap.getResult(i);
214755dc07dSRiver Riddle     resultExpr.walk([&](AffineExpr expr) {
2151609f1c2Slong.chen       if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
216755dc07dSRiver Riddle         exprOperands.push_back(mapOperands[dimExpr.getPosition()]);
2171609f1c2Slong.chen       else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
218755dc07dSRiver Riddle         exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
219755dc07dSRiver Riddle     });
220755dc07dSRiver Riddle     // Check access invariance of each operand in 'exprOperands'.
2212679d379SUday Bondhugula     for (Value exprOperand : exprOperands) {
222755dc07dSRiver Riddle       if (!isAccessIndexInvariant(iv, exprOperand)) {
223755dc07dSRiver Riddle         if (uniqueVaryingIndexAlongIv != -1) {
224755dc07dSRiver Riddle           // 2+ varying indices -> do not vectorize along iv.
225755dc07dSRiver Riddle           return false;
226755dc07dSRiver Riddle         }
227755dc07dSRiver Riddle         uniqueVaryingIndexAlongIv = i;
228755dc07dSRiver Riddle       }
229755dc07dSRiver Riddle     }
230755dc07dSRiver Riddle   }
231755dc07dSRiver Riddle 
232755dc07dSRiver Riddle   if (uniqueVaryingIndexAlongIv == -1)
233755dc07dSRiver Riddle     *memRefDim = -1;
234755dc07dSRiver Riddle   else
235755dc07dSRiver Riddle     *memRefDim = memRefType.getRank() - (uniqueVaryingIndexAlongIv + 1);
236755dc07dSRiver Riddle   return true;
237755dc07dSRiver Riddle }
238755dc07dSRiver Riddle 
2392679d379SUday Bondhugula template bool mlir::affine::isContiguousAccess(Value iv,
2402679d379SUday Bondhugula                                                AffineReadOpInterface loadOp,
2412679d379SUday Bondhugula                                                int *memRefDim);
2422679d379SUday Bondhugula template bool mlir::affine::isContiguousAccess(Value iv,
2432679d379SUday Bondhugula                                                AffineWriteOpInterface loadOp,
2442679d379SUday Bondhugula                                                int *memRefDim);
2452679d379SUday Bondhugula 
246755dc07dSRiver Riddle template <typename LoadOrStoreOp>
247755dc07dSRiver Riddle static bool isVectorElement(LoadOrStoreOp memoryOp) {
248755dc07dSRiver Riddle   auto memRefType = memoryOp.getMemRefType();
2495550c821STres Popp   return isa<VectorType>(memRefType.getElementType());
250755dc07dSRiver Riddle }
251755dc07dSRiver Riddle 
252755dc07dSRiver Riddle using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
253755dc07dSRiver Riddle 
254755dc07dSRiver Riddle static bool
255755dc07dSRiver Riddle isVectorizableLoopBodyWithOpCond(AffineForOp loop,
256755dc07dSRiver Riddle                                  const VectorizableOpFun &isVectorizableOp,
257755dc07dSRiver Riddle                                  NestedPattern &vectorTransferMatcher) {
258755dc07dSRiver Riddle   auto *forOp = loop.getOperation();
259755dc07dSRiver Riddle 
260755dc07dSRiver Riddle   // No vectorization across conditionals for now.
261755dc07dSRiver Riddle   auto conditionals = matcher::If();
262755dc07dSRiver Riddle   SmallVector<NestedMatch, 8> conditionalsMatched;
263755dc07dSRiver Riddle   conditionals.match(forOp, &conditionalsMatched);
264755dc07dSRiver Riddle   if (!conditionalsMatched.empty()) {
265755dc07dSRiver Riddle     return false;
266755dc07dSRiver Riddle   }
267755dc07dSRiver Riddle 
2683974ecb7SJoshua Cao   // No vectorization for ops with operand or result types that are not
2693974ecb7SJoshua Cao   // vectorizable.
2703974ecb7SJoshua Cao   auto types = matcher::Op([](Operation &op) -> bool {
2713974ecb7SJoshua Cao     if (llvm::any_of(op.getOperandTypes(), [](Type type) {
2723974ecb7SJoshua Cao           if (MemRefType t = dyn_cast<MemRefType>(type))
2733974ecb7SJoshua Cao             return !VectorType::isValidElementType(t.getElementType());
2743974ecb7SJoshua Cao           return !VectorType::isValidElementType(type);
2753974ecb7SJoshua Cao         }))
2763974ecb7SJoshua Cao       return true;
2773974ecb7SJoshua Cao     return llvm::any_of(op.getResultTypes(), [](Type type) {
2783974ecb7SJoshua Cao       return !VectorType::isValidElementType(type);
2793974ecb7SJoshua Cao     });
2803974ecb7SJoshua Cao   });
2813974ecb7SJoshua Cao   SmallVector<NestedMatch, 8> opsMatched;
2823974ecb7SJoshua Cao   types.match(forOp, &opsMatched);
2833974ecb7SJoshua Cao   if (!opsMatched.empty()) {
2843974ecb7SJoshua Cao     return false;
2853974ecb7SJoshua Cao   }
2863974ecb7SJoshua Cao 
287755dc07dSRiver Riddle   // No vectorization across unknown regions.
288755dc07dSRiver Riddle   auto regions = matcher::Op([](Operation &op) -> bool {
289755dc07dSRiver Riddle     return op.getNumRegions() != 0 && !isa<AffineIfOp, AffineForOp>(op);
290755dc07dSRiver Riddle   });
291755dc07dSRiver Riddle   SmallVector<NestedMatch, 8> regionsMatched;
292755dc07dSRiver Riddle   regions.match(forOp, &regionsMatched);
293755dc07dSRiver Riddle   if (!regionsMatched.empty()) {
294755dc07dSRiver Riddle     return false;
295755dc07dSRiver Riddle   }
296755dc07dSRiver Riddle 
297755dc07dSRiver Riddle   SmallVector<NestedMatch, 8> vectorTransfersMatched;
298755dc07dSRiver Riddle   vectorTransferMatcher.match(forOp, &vectorTransfersMatched);
299755dc07dSRiver Riddle   if (!vectorTransfersMatched.empty()) {
300755dc07dSRiver Riddle     return false;
301755dc07dSRiver Riddle   }
302755dc07dSRiver Riddle 
303755dc07dSRiver Riddle   auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
304755dc07dSRiver Riddle   SmallVector<NestedMatch, 8> loadAndStoresMatched;
305755dc07dSRiver Riddle   loadAndStores.match(forOp, &loadAndStoresMatched);
306755dc07dSRiver Riddle   for (auto ls : loadAndStoresMatched) {
307755dc07dSRiver Riddle     auto *op = ls.getMatchedOperation();
308755dc07dSRiver Riddle     auto load = dyn_cast<AffineLoadOp>(op);
309755dc07dSRiver Riddle     auto store = dyn_cast<AffineStoreOp>(op);
310755dc07dSRiver Riddle     // Only scalar types are considered vectorizable, all load/store must be
311755dc07dSRiver Riddle     // vectorizable for a loop to qualify as vectorizable.
312755dc07dSRiver Riddle     // TODO: ponder whether we want to be more general here.
313755dc07dSRiver Riddle     bool vector = load ? isVectorElement(load) : isVectorElement(store);
314755dc07dSRiver Riddle     if (vector) {
315755dc07dSRiver Riddle       return false;
316755dc07dSRiver Riddle     }
317755dc07dSRiver Riddle     if (isVectorizableOp && !isVectorizableOp(loop, *op)) {
318755dc07dSRiver Riddle       return false;
319755dc07dSRiver Riddle     }
320755dc07dSRiver Riddle   }
321755dc07dSRiver Riddle   return true;
322755dc07dSRiver Riddle }
323755dc07dSRiver Riddle 
3244c48f016SMatthias Springer bool mlir::affine::isVectorizableLoopBody(
3254c48f016SMatthias Springer     AffineForOp loop, int *memRefDim, NestedPattern &vectorTransferMatcher) {
326755dc07dSRiver Riddle   *memRefDim = -1;
327755dc07dSRiver Riddle   VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) {
328755dc07dSRiver Riddle     auto load = dyn_cast<AffineLoadOp>(op);
329755dc07dSRiver Riddle     auto store = dyn_cast<AffineStoreOp>(op);
330755dc07dSRiver Riddle     int thisOpMemRefDim = -1;
3312679d379SUday Bondhugula     bool isContiguous =
3322679d379SUday Bondhugula         load ? isContiguousAccess(loop.getInductionVar(),
3332679d379SUday Bondhugula                                   cast<AffineReadOpInterface>(*load),
334755dc07dSRiver Riddle                                   &thisOpMemRefDim)
3352679d379SUday Bondhugula              : isContiguousAccess(loop.getInductionVar(),
3362679d379SUday Bondhugula                                   cast<AffineWriteOpInterface>(*store),
337755dc07dSRiver Riddle                                   &thisOpMemRefDim);
338755dc07dSRiver Riddle     if (thisOpMemRefDim != -1) {
339755dc07dSRiver Riddle       // If memory accesses vary across different dimensions then the loop is
340755dc07dSRiver Riddle       // not vectorizable.
341755dc07dSRiver Riddle       if (*memRefDim != -1 && *memRefDim != thisOpMemRefDim)
342755dc07dSRiver Riddle         return false;
343755dc07dSRiver Riddle       *memRefDim = thisOpMemRefDim;
344755dc07dSRiver Riddle     }
345755dc07dSRiver Riddle     return isContiguous;
346755dc07dSRiver Riddle   });
347755dc07dSRiver Riddle   return isVectorizableLoopBodyWithOpCond(loop, fun, vectorTransferMatcher);
348755dc07dSRiver Riddle }
349755dc07dSRiver Riddle 
3504c48f016SMatthias Springer bool mlir::affine::isVectorizableLoopBody(
3514c48f016SMatthias Springer     AffineForOp loop, NestedPattern &vectorTransferMatcher) {
352755dc07dSRiver Riddle   return isVectorizableLoopBodyWithOpCond(loop, nullptr, vectorTransferMatcher);
353755dc07dSRiver Riddle }
354755dc07dSRiver Riddle 
355755dc07dSRiver Riddle /// Checks whether SSA dominance would be violated if a for op's body
356755dc07dSRiver Riddle /// operations are shifted by the specified shifts. This method checks if a
357755dc07dSRiver Riddle /// 'def' and all its uses have the same shift factor.
358755dc07dSRiver Riddle // TODO: extend this to check for memory-based dependence violation when we have
359755dc07dSRiver Riddle // the support.
3604c48f016SMatthias Springer bool mlir::affine::isOpwiseShiftValid(AffineForOp forOp,
3614c48f016SMatthias Springer                                       ArrayRef<uint64_t> shifts) {
362755dc07dSRiver Riddle   auto *forBody = forOp.getBody();
363755dc07dSRiver Riddle   assert(shifts.size() == forBody->getOperations().size());
364755dc07dSRiver Riddle 
365755dc07dSRiver Riddle   // Work backwards over the body of the block so that the shift of a use's
366755dc07dSRiver Riddle   // ancestor operation in the block gets recorded before it's looked up.
367755dc07dSRiver Riddle   DenseMap<Operation *, uint64_t> forBodyShift;
368755dc07dSRiver Riddle   for (const auto &it :
369755dc07dSRiver Riddle        llvm::enumerate(llvm::reverse(forBody->getOperations()))) {
370755dc07dSRiver Riddle     auto &op = it.value();
371755dc07dSRiver Riddle 
372755dc07dSRiver Riddle     // Get the index of the current operation, note that we are iterating in
373755dc07dSRiver Riddle     // reverse so we need to fix it up.
374755dc07dSRiver Riddle     size_t index = shifts.size() - it.index() - 1;
375755dc07dSRiver Riddle 
376755dc07dSRiver Riddle     // Remember the shift of this operation.
377755dc07dSRiver Riddle     uint64_t shift = shifts[index];
378755dc07dSRiver Riddle     forBodyShift.try_emplace(&op, shift);
379755dc07dSRiver Riddle 
380755dc07dSRiver Riddle     // Validate the results of this operation if it were to be shifted.
381755dc07dSRiver Riddle     for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
382755dc07dSRiver Riddle       Value result = op.getResult(i);
383755dc07dSRiver Riddle       for (auto *user : result.getUsers()) {
384755dc07dSRiver Riddle         // If an ancestor operation doesn't lie in the block of forOp,
385755dc07dSRiver Riddle         // there is no shift to check.
386755dc07dSRiver Riddle         if (auto *ancOp = forBody->findAncestorOpInBlock(*user)) {
387755dc07dSRiver Riddle           assert(forBodyShift.count(ancOp) > 0 && "ancestor expected in map");
388755dc07dSRiver Riddle           if (shift != forBodyShift[ancOp])
389755dc07dSRiver Riddle             return false;
390755dc07dSRiver Riddle         }
391755dc07dSRiver Riddle       }
392755dc07dSRiver Riddle     }
393755dc07dSRiver Riddle   }
394755dc07dSRiver Riddle   return true;
395755dc07dSRiver Riddle }
396*fe04aafeSUday Bondhugula 
397*fe04aafeSUday Bondhugula bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
398*fe04aafeSUday Bondhugula   assert(!loops.empty() && "no original loops provided");
399*fe04aafeSUday Bondhugula 
400*fe04aafeSUday Bondhugula   // We first find out all dependences we intend to check.
401*fe04aafeSUday Bondhugula   SmallVector<Operation *, 8> loadAndStoreOps;
402*fe04aafeSUday Bondhugula   loops[0]->walk([&](Operation *op) {
403*fe04aafeSUday Bondhugula     if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
404*fe04aafeSUday Bondhugula       loadAndStoreOps.push_back(op);
405*fe04aafeSUday Bondhugula   });
406*fe04aafeSUday Bondhugula 
407*fe04aafeSUday Bondhugula   unsigned numOps = loadAndStoreOps.size();
408*fe04aafeSUday Bondhugula   unsigned numLoops = loops.size();
409*fe04aafeSUday Bondhugula   for (unsigned d = 1; d <= numLoops + 1; ++d) {
410*fe04aafeSUday Bondhugula     for (unsigned i = 0; i < numOps; ++i) {
411*fe04aafeSUday Bondhugula       Operation *srcOp = loadAndStoreOps[i];
412*fe04aafeSUday Bondhugula       MemRefAccess srcAccess(srcOp);
413*fe04aafeSUday Bondhugula       for (unsigned j = 0; j < numOps; ++j) {
414*fe04aafeSUday Bondhugula         Operation *dstOp = loadAndStoreOps[j];
415*fe04aafeSUday Bondhugula         MemRefAccess dstAccess(dstOp);
416*fe04aafeSUday Bondhugula 
417*fe04aafeSUday Bondhugula         SmallVector<DependenceComponent, 2> depComps;
418*fe04aafeSUday Bondhugula         DependenceResult result = checkMemrefAccessDependence(
419*fe04aafeSUday Bondhugula             srcAccess, dstAccess, d, /*dependenceConstraints=*/nullptr,
420*fe04aafeSUday Bondhugula             &depComps);
421*fe04aafeSUday Bondhugula 
422*fe04aafeSUday Bondhugula         // Skip if there is no dependence in this case.
423*fe04aafeSUday Bondhugula         if (!hasDependence(result))
424*fe04aafeSUday Bondhugula           continue;
425*fe04aafeSUday Bondhugula 
426*fe04aafeSUday Bondhugula         // Check whether there is any negative direction vector in the
427*fe04aafeSUday Bondhugula         // dependence components found above, which means that dependence is
428*fe04aafeSUday Bondhugula         // violated by the default hyper-rect tiling method.
429*fe04aafeSUday Bondhugula         LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
430*fe04aafeSUday Bondhugula                                    "for dependence at depth: "
431*fe04aafeSUday Bondhugula                                 << Twine(d) << " between:\n";);
432*fe04aafeSUday Bondhugula         LLVM_DEBUG(srcAccess.opInst->dump());
433*fe04aafeSUday Bondhugula         LLVM_DEBUG(dstAccess.opInst->dump());
434*fe04aafeSUday Bondhugula         for (const DependenceComponent &depComp : depComps) {
435*fe04aafeSUday Bondhugula           if (depComp.lb.has_value() && depComp.ub.has_value() &&
436*fe04aafeSUday Bondhugula               *depComp.lb < *depComp.ub && *depComp.ub < 0) {
437*fe04aafeSUday Bondhugula             LLVM_DEBUG(llvm::dbgs()
438*fe04aafeSUday Bondhugula                        << "Dependence component lb = " << Twine(*depComp.lb)
439*fe04aafeSUday Bondhugula                        << " ub = " << Twine(*depComp.ub)
440*fe04aafeSUday Bondhugula                        << " is negative  at depth: " << Twine(d)
441*fe04aafeSUday Bondhugula                        << " and thus violates the legality rule.\n");
442*fe04aafeSUday Bondhugula             return false;
443*fe04aafeSUday Bondhugula           }
444*fe04aafeSUday Bondhugula         }
445*fe04aafeSUday Bondhugula       }
446*fe04aafeSUday Bondhugula     }
447*fe04aafeSUday Bondhugula   }
448*fe04aafeSUday Bondhugula 
449*fe04aafeSUday Bondhugula   return true;
450*fe04aafeSUday Bondhugula }
451