//===- LoopAnalysis.cpp - Misc loop analysis routines --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop analysis routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "llvm/Support/MathExtras.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Debug.h"
#include <numeric>
#include <optional>
#include <type_traits>

using namespace mlir;
using namespace mlir::affine;

#define DEBUG_TYPE "affine-loop-analysis"

/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and a null map otherwise. The trip
/// count expression is simplified before returning. This method only utilizes
/// map composition to construct lower and upper bounds before computing the
/// trip count expressions.
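///
/// For example, for a hypothetical loop `affine.for %i = %lb to %ub step 4`,
/// the resulting trip count map would be along the lines of
/// `(d0, d1) -> ((d1 - d0) ceildiv 4)` applied to (%lb, %ub); the exact
/// dim/symbol positions depend on how the bound maps compose.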
void mlir::affine::getTripCountMapAndOperands(
    AffineForOp forOp, AffineMap *tripCountMap,
    SmallVectorImpl<Value> *tripCountOperands) {
  MLIRContext *context = forOp.getContext();
  int64_t step = forOp.getStepAsInt();
  int64_t loopSpan;
  if (forOp.hasConstantBounds()) {
    int64_t lb = forOp.getConstantLowerBound();
    int64_t ub = forOp.getConstantUpperBound();
    loopSpan = ub - lb;
    if (loopSpan < 0)
      loopSpan = 0;
    *tripCountMap = AffineMap::getConstantMap(
        llvm::divideCeilSigned(loopSpan, step), context);
    tripCountOperands->clear();
    return;
  }
  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  if (lbMap.getNumResults() != 1) {
    *tripCountMap = AffineMap();
    return;
  }

  // The difference of each upper bound expression from the single lower bound
  // expression (divided by the step) provides the expressions for the trip
  // count map.
  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());

  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
                                         lbMap.getResult(0));
  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
                                   lbSplatExpr, context);
  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());

  AffineValueMap tripCountValueMap;
  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
    tripCountValueMap.setResult(i,
                                tripCountValueMap.getResult(i).ceilDiv(step));

  *tripCountMap = tripCountValueMap.getAffineMap();
  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
                            tripCountValueMap.getOperands().end());
}

/// Returns the trip count of the loop if it's a constant, std::nullopt
/// otherwise. This method uses affine expression analysis (in turn using
/// getTripCountMapAndOperands) and is able to determine the constant trip
/// count in non-trivial cases.
std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
  SmallVector<Value, 4> operands;
  AffineMap map;
  getTripCountMapAndOperands(forOp, &map, &operands);

  if (!map)
    return std::nullopt;

  // Take the min if all trip counts are constant.
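  // (An upper bound given by a multi-result map means the loop runs to the
  // minimum of those results, so e.g. a trip count map yielding (128, 96)
  // gives a constant trip count of 96.)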
  std::optional<uint64_t> tripCount;
  for (auto resultExpr : map.getResults()) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
      if (tripCount.has_value())
        tripCount =
            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
      else
        tripCount = constExpr.getValue();
    } else
      return std::nullopt;
  }
  return tripCount;
}

/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through
/// getTripCountMapAndOperands), and this method is thus able to determine
/// non-trivial divisors.
uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
  SmallVector<Value, 4> operands;
  AffineMap map;
  getTripCountMapAndOperands(forOp, &map, &operands);

  if (!map)
    return 1;

  // The largest divisor of the trip count is the GCD of the individual largest
  // divisors.
  assert(map.getNumResults() >= 1 && "expected one or more results");
  std::optional<uint64_t> gcd;
  for (auto resultExpr : map.getResults()) {
    uint64_t thisGcd;
    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
      uint64_t tripCount = constExpr.getValue();
      // 0-iteration loops (greatest divisor is 2^64 - 1).
      if (tripCount == 0)
        thisGcd = std::numeric_limits<uint64_t>::max();
      else
        // The greatest divisor is the trip count.
        thisGcd = tripCount;
    } else {
      // Trip count is not a known constant; use its largest known divisor.
      thisGcd = resultExpr.getLargestKnownDivisor();
    }
    if (gcd.has_value())
      gcd = std::gcd(*gcd, thisGcd);
    else
      gcd = thisGcd;
  }
  assert(gcd.has_value() && "value expected per above logic");
  return *gcd;
}

/// Given an affine.for `iv` and an access `index` of type index, returns
/// `true` if `index` is independent of `iv` and `false` otherwise.
///
/// Prerequisites: `iv` and `index` must be of the proper types.
static bool isAccessIndexInvariant(Value iv, Value index) {
  assert(isAffineForInductionVar(iv) && "iv must be an affine.for iv");
  assert(isa<IndexType>(index.getType()) && "index must be of 'index' type");
  auto map = AffineMap::getMultiDimIdentityMap(/*numDims=*/1, iv.getContext());
  SmallVector<Value> operands = {index};
  AffineValueMap avm(map, operands);
  avm.composeSimplifyAndCanonicalize();
  return !avm.isFunctionOf(0, iv);
}

// Prerequisite: loop bounds should be in canonical form.
template <typename LoadOrStoreOp>
bool mlir::affine::isInvariantAccess(LoadOrStoreOp memOp, AffineForOp forOp) {
  AffineValueMap avm(memOp.getAffineMap(), memOp.getMapOperands());
  avm.composeSimplifyAndCanonicalize();
  return !llvm::is_contained(avm.getOperands(), forOp.getInductionVar());
}

// Explicitly instantiate the templates so that the compiler knows we need
// them.
template bool mlir::affine::isInvariantAccess(AffineReadOpInterface,
                                              AffineForOp);
template bool mlir::affine::isInvariantAccess(AffineWriteOpInterface,
                                              AffineForOp);
template bool mlir::affine::isInvariantAccess(AffineLoadOp, AffineForOp);
template bool mlir::affine::isInvariantAccess(AffineStoreOp, AffineForOp);

DenseSet<Value> mlir::affine::getInvariantAccesses(Value iv,
                                                   ArrayRef<Value> indices) {
  DenseSet<Value> res;
  for (auto val : indices) {
    if (isAccessIndexInvariant(iv, val)) {
      res.insert(val);
    }
  }
  return res;
}

// TODO: check access stride.
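// For example, for a 2-d memref access `%A[%i, %j]` in a loop whose induction
// variable is %j, only the innermost result expression varies with the iv, so
// the access is contiguous along that loop and *memRefDim is set to 0 (the
// distance from the innermost memref dimension).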
template <typename LoadOrStoreOp>
bool mlir::affine::isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
                                      int *memRefDim) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "Must be called on either an affine read or write op");
  assert(memRefDim && "memRefDim == nullptr");
  auto memRefType = memoryOp.getMemRefType();

  if (!memRefType.getLayout().isIdentity())
    return memoryOp.emitError("NYI: non-trivial layout map"), false;

  int uniqueVaryingIndexAlongIv = -1;
  auto accessMap = memoryOp.getAffineMap();
  SmallVector<Value, 4> mapOperands(memoryOp.getMapOperands());
  unsigned numDims = accessMap.getNumDims();
  for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
    // Gather map operands used in result expr 'i' in 'exprOperands'.
    SmallVector<Value, 4> exprOperands;
    auto resultExpr = accessMap.getResult(i);
    resultExpr.walk([&](AffineExpr expr) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
        exprOperands.push_back(mapOperands[dimExpr.getPosition()]);
      else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
        exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
    });
    // Check access invariance of each operand in 'exprOperands'.
    for (Value exprOperand : exprOperands) {
      if (!isAccessIndexInvariant(iv, exprOperand)) {
        if (uniqueVaryingIndexAlongIv != -1) {
          // 2+ varying indices -> do not vectorize along iv.
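          // (e.g., an access like %A[%iv, %iv + 1] has more than one result
          // expression varying with %iv and is rejected here.)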
          return false;
        }
        uniqueVaryingIndexAlongIv = i;
      }
    }
  }

  if (uniqueVaryingIndexAlongIv == -1)
    *memRefDim = -1;
  else
    *memRefDim = memRefType.getRank() - (uniqueVaryingIndexAlongIv + 1);
  return true;
}

template bool mlir::affine::isContiguousAccess(Value iv,
                                               AffineReadOpInterface loadOp,
                                               int *memRefDim);
template bool mlir::affine::isContiguousAccess(Value iv,
                                               AffineWriteOpInterface loadOp,
                                               int *memRefDim);

template <typename LoadOrStoreOp>
static bool isVectorElement(LoadOrStoreOp memoryOp) {
  auto memRefType = memoryOp.getMemRefType();
  return isa<VectorType>(memRefType.getElementType());
}

using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;

static bool
isVectorizableLoopBodyWithOpCond(AffineForOp loop,
                                 const VectorizableOpFun &isVectorizableOp,
                                 NestedPattern &vectorTransferMatcher) {
  auto *forOp = loop.getOperation();

  // No vectorization across conditionals for now.
  auto conditionals = matcher::If();
  SmallVector<NestedMatch, 8> conditionalsMatched;
  conditionals.match(forOp, &conditionalsMatched);
  if (!conditionalsMatched.empty()) {
    return false;
  }

  // No vectorization for ops with operand or result types that are not
  // vectorizable.
  auto types = matcher::Op([](Operation &op) -> bool {
    if (llvm::any_of(op.getOperandTypes(), [](Type type) {
          if (MemRefType t = dyn_cast<MemRefType>(type))
            return !VectorType::isValidElementType(t.getElementType());
          return !VectorType::isValidElementType(type);
        }))
      return true;
    return llvm::any_of(op.getResultTypes(), [](Type type) {
      return !VectorType::isValidElementType(type);
    });
  });
  SmallVector<NestedMatch, 8> opsMatched;
  types.match(forOp, &opsMatched);
  if (!opsMatched.empty()) {
    return false;
  }

  // No vectorization across unknown regions.
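  // (Ops holding regions other than affine.if/affine.for, e.g. an scf.for
  // nested in the body, conservatively block vectorization.)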
  auto regions = matcher::Op([](Operation &op) -> bool {
    return op.getNumRegions() != 0 && !isa<AffineIfOp, AffineForOp>(op);
  });
  SmallVector<NestedMatch, 8> regionsMatched;
  regions.match(forOp, &regionsMatched);
  if (!regionsMatched.empty()) {
    return false;
  }

  SmallVector<NestedMatch, 8> vectorTransfersMatched;
  vectorTransferMatcher.match(forOp, &vectorTransfersMatched);
  if (!vectorTransfersMatched.empty()) {
    return false;
  }

  auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
  SmallVector<NestedMatch, 8> loadAndStoresMatched;
  loadAndStores.match(forOp, &loadAndStoresMatched);
  for (auto ls : loadAndStoresMatched) {
    auto *op = ls.getMatchedOperation();
    auto load = dyn_cast<AffineLoadOp>(op);
    auto store = dyn_cast<AffineStoreOp>(op);
    // Only scalar types are considered vectorizable; all loads/stores must be
    // vectorizable for a loop to qualify as vectorizable.
    // TODO: ponder whether we want to be more general here.
    bool vector = load ? isVectorElement(load) : isVectorElement(store);
    if (vector) {
      return false;
    }
    if (isVectorizableOp && !isVectorizableOp(loop, *op)) {
      return false;
    }
  }
  return true;
}

bool mlir::affine::isVectorizableLoopBody(
    AffineForOp loop, int *memRefDim, NestedPattern &vectorTransferMatcher) {
  *memRefDim = -1;
  VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) {
    auto load = dyn_cast<AffineLoadOp>(op);
    auto store = dyn_cast<AffineStoreOp>(op);
    int thisOpMemRefDim = -1;
    bool isContiguous =
        load ? isContiguousAccess(loop.getInductionVar(),
                                  cast<AffineReadOpInterface>(*load),
                                  &thisOpMemRefDim)
             : isContiguousAccess(loop.getInductionVar(),
                                  cast<AffineWriteOpInterface>(*store),
                                  &thisOpMemRefDim);
    if (thisOpMemRefDim != -1) {
      // If memory accesses vary across different dimensions then the loop is
      // not vectorizable.
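      // (e.g., one access strides along the memref's innermost dimension with
      // this loop's iv while another strides along an outer dimension.)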
      if (*memRefDim != -1 && *memRefDim != thisOpMemRefDim)
        return false;
      *memRefDim = thisOpMemRefDim;
    }
    return isContiguous;
  });
  return isVectorizableLoopBodyWithOpCond(loop, fun, vectorTransferMatcher);
}

bool mlir::affine::isVectorizableLoopBody(
    AffineForOp loop, NestedPattern &vectorTransferMatcher) {
  return isVectorizableLoopBodyWithOpCond(loop, nullptr, vectorTransferMatcher);
}

/// Checks whether SSA dominance would be violated if a for op's body
/// operations are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO: extend this to check for memory-based dependence violation when we
// have the support.
bool mlir::affine::isOpwiseShiftValid(AffineForOp forOp,
                                      ArrayRef<uint64_t> shifts) {
  auto *forBody = forOp.getBody();
  assert(shifts.size() == forBody->getOperations().size());

  // Work backwards over the body of the block so that the shift of a use's
  // ancestor operation in the block gets recorded before it's looked up.
  DenseMap<Operation *, uint64_t> forBodyShift;
  for (const auto &it :
       llvm::enumerate(llvm::reverse(forBody->getOperations()))) {
    auto &op = it.value();

    // Get the index of the current operation; note that we are iterating in
    // reverse, so we need to fix it up.
    size_t index = shifts.size() - it.index() - 1;

    // Remember the shift of this operation.
    uint64_t shift = shifts[index];
    forBodyShift.try_emplace(&op, shift);

    // Validate the results of this operation if it were to be shifted.
    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
      Value result = op.getResult(i);
      for (auto *user : result.getUsers()) {
        // If an ancestor operation doesn't lie in the block of forOp,
        // there is no shift to check.
        if (auto *ancOp = forBody->findAncestorOpInBlock(*user)) {
          assert(forBodyShift.count(ancOp) > 0 && "ancestor expected in map");
          if (shift != forBodyShift[ancOp])
            return false;
        }
      }
    }
  }
  return true;
}

bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
  assert(!loops.empty() && "no original loops provided");

  // We first find out all dependences we intend to check.
  SmallVector<Operation *, 8> loadAndStoreOps;
  loops[0]->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
  });

  unsigned numOps = loadAndStoreOps.size();
  unsigned numLoops = loops.size();
  for (unsigned d = 1; d <= numLoops + 1; ++d) {
    for (unsigned i = 0; i < numOps; ++i) {
      Operation *srcOp = loadAndStoreOps[i];
      MemRefAccess srcAccess(srcOp);
      for (unsigned j = 0; j < numOps; ++j) {
        Operation *dstOp = loadAndStoreOps[j];
        MemRefAccess dstAccess(dstOp);

        SmallVector<DependenceComponent, 2> depComps;
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, d, /*dependenceConstraints=*/nullptr,
            &depComps);

        // Skip if there is no dependence in this case.
        if (!hasDependence(result))
          continue;

        // Check whether there is any negative direction vector in the
        // dependence components found above, which means that the dependence
        // is violated by the default hyper-rectangular tiling method.
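        // (For instance, an access pattern like A[i][j] = A[i - 1][j + 1]
        // carries a dependence with direction vector (1, -1); it is such
        // negative components that hyper-rectangular tiling cannot preserve.)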
        LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
                                   "for dependence at depth: "
                                << Twine(d) << " between:\n";);
        LLVM_DEBUG(srcAccess.opInst->dump());
        LLVM_DEBUG(dstAccess.opInst->dump());
        for (const DependenceComponent &depComp : depComps) {
          if (depComp.lb.has_value() && depComp.ub.has_value() &&
              *depComp.lb < *depComp.ub && *depComp.ub < 0) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Dependence component lb = " << Twine(*depComp.lb)
                       << " ub = " << Twine(*depComp.ub)
                       << " is negative at depth: " << Twine(d)
                       << " and thus violates the legality rule.\n");
            return false;
          }
        }
      }
    }
  }

  return true;
}