1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/AffineMap.h" 10 #include "AffineMapDetail.h" 11 #include "mlir/IR/AffineExpr.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/BuiltinAttributes.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 #include "llvm/ADT/STLExtras.h" 16 #include "llvm/ADT/SmallBitVector.h" 17 #include "llvm/ADT/SmallSet.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/StringRef.h" 20 #include "llvm/Support/MathExtras.h" 21 #include "llvm/Support/raw_ostream.h" 22 #include <iterator> 23 #include <numeric> 24 #include <optional> 25 #include <type_traits> 26 27 using namespace mlir; 28 29 using llvm::divideCeilSigned; 30 using llvm::divideFloorSigned; 31 using llvm::mod; 32 33 namespace { 34 35 // AffineExprConstantFolder evaluates an affine expression using constant 36 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute 37 // representing the constant value of the affine expression evaluated on 38 // constant 'operandConsts', or nullptr if it can't be folded. 39 class AffineExprConstantFolder { 40 public: 41 AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts) 42 : numDims(numDims), operandConsts(operandConsts) {} 43 44 /// Attempt to constant fold the specified affine expr, or return null on 45 /// failure. 46 IntegerAttr constantFold(AffineExpr expr) { 47 if (auto result = constantFoldImpl(expr)) 48 return IntegerAttr::get(IndexType::get(expr.getContext()), *result); 49 return nullptr; 50 } 51 52 bool hasPoison() const { return hasPoison_; } 53 54 private: 55 std::optional<int64_t> constantFoldImpl(AffineExpr expr) { 56 switch (expr.getKind()) { 57 case AffineExprKind::Add: 58 return constantFoldBinExpr( 59 expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; }); 60 case AffineExprKind::Mul: 61 return constantFoldBinExpr( 62 expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; }); 63 case AffineExprKind::Mod: 64 return constantFoldBinExpr( 65 expr, [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> { 66 if (rhs < 1) { 67 hasPoison_ = true; 68 return std::nullopt; 69 } 70 return mod(lhs, rhs); 71 }); 72 case AffineExprKind::FloorDiv: 73 return constantFoldBinExpr( 74 expr, [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> { 75 if (rhs == 0) { 76 hasPoison_ = true; 77 return std::nullopt; 78 } 79 return divideFloorSigned(lhs, rhs); 80 }); 81 case AffineExprKind::CeilDiv: 82 return constantFoldBinExpr( 83 expr, [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> { 84 if (rhs == 0) { 85 hasPoison_ = true; 86 return std::nullopt; 87 } 88 return divideCeilSigned(lhs, rhs); 89 }); 90 case AffineExprKind::Constant: 91 return cast<AffineConstantExpr>(expr).getValue(); 92 case AffineExprKind::DimId: 93 if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>( 94 operandConsts[cast<AffineDimExpr>(expr).getPosition()])) 95 return attr.getInt(); 96 return std::nullopt; 97 case AffineExprKind::SymbolId: 98 if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>( 99 operandConsts[numDims + 100 cast<AffineSymbolExpr>(expr).getPosition()])) 101 return attr.getInt(); 102 return std::nullopt; 103 } 104 llvm_unreachable("Unknown AffineExpr"); 105 } 106 107 // TODO: Change these to operate on APInts too. 108 std::optional<int64_t> constantFoldBinExpr( 109 AffineExpr expr, 110 llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) { 111 auto binOpExpr = cast<AffineBinaryOpExpr>(expr); 112 if (auto lhs = constantFoldImpl(binOpExpr.getLHS())) 113 if (auto rhs = constantFoldImpl(binOpExpr.getRHS())) 114 return op(*lhs, *rhs); 115 return std::nullopt; 116 } 117 118 // The number of dimension operands in AffineMap containing this expression. 119 unsigned numDims; 120 // The constant valued operands used to evaluate this AffineExpr. 121 ArrayRef<Attribute> operandConsts; 122 bool hasPoison_{false}; 123 }; 124 125 } // namespace 126 127 /// Returns a single constant result affine map. 128 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) { 129 return get(/*dimCount=*/0, /*symbolCount=*/0, 130 {getAffineConstantExpr(val, context)}); 131 } 132 133 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most 134 /// minor dimensions. 135 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results, 136 MLIRContext *context) { 137 assert(dims >= results && "Dimension mismatch"); 138 auto id = AffineMap::getMultiDimIdentityMap(dims, context); 139 return AffineMap::get(dims, 0, id.getResults().take_back(results), context); 140 } 141 142 AffineMap AffineMap::getFilteredIdentityMap( 143 MLIRContext *ctx, unsigned numDims, 144 llvm::function_ref<bool(AffineDimExpr)> keepDimFilter) { 145 auto identityMap = getMultiDimIdentityMap(numDims, ctx); 146 147 // Apply filter to results. 148 llvm::SmallBitVector dropDimResults(numDims); 149 for (auto [idx, resultExpr] : llvm::enumerate(identityMap.getResults())) 150 dropDimResults[idx] = !keepDimFilter(cast<AffineDimExpr>(resultExpr)); 151 152 return identityMap.dropResults(dropDimResults); 153 } 154 155 bool AffineMap::isMinorIdentity() const { 156 return getNumDims() >= getNumResults() && 157 *this == 158 getMinorIdentityMap(getNumDims(), getNumResults(), getContext()); 159 } 160 161 SmallVector<unsigned> AffineMap::getBroadcastDims() const { 162 SmallVector<unsigned> broadcastedDims; 163 for (const auto &[resIdx, expr] : llvm::enumerate(getResults())) { 164 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) { 165 if (constExpr.getValue() != 0) 166 continue; 167 broadcastedDims.push_back(resIdx); 168 } 169 } 170 171 return broadcastedDims; 172 } 173 174 /// Returns true if this affine map is a minor identity up to broadcasted 175 /// dimensions which are indicated by value 0 in the result. 176 bool AffineMap::isMinorIdentityWithBroadcasting( 177 SmallVectorImpl<unsigned> *broadcastedDims) const { 178 if (broadcastedDims) 179 broadcastedDims->clear(); 180 if (getNumDims() < getNumResults()) 181 return false; 182 unsigned suffixStart = getNumDims() - getNumResults(); 183 for (const auto &idxAndExpr : llvm::enumerate(getResults())) { 184 unsigned resIdx = idxAndExpr.index(); 185 AffineExpr expr = idxAndExpr.value(); 186 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) { 187 // Each result may be either a constant 0 (broadcasted dimension). 188 if (constExpr.getValue() != 0) 189 return false; 190 if (broadcastedDims) 191 broadcastedDims->push_back(resIdx); 192 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { 193 // Or it may be the input dimension corresponding to this result position. 194 if (dimExpr.getPosition() != suffixStart + resIdx) 195 return false; 196 } else { 197 return false; 198 } 199 } 200 return true; 201 } 202 203 /// Return true if this affine map can be converted to a minor identity with 204 /// broadcast by doing a permute. Return a permutation (there may be 205 /// several) to apply to get to a minor identity with broadcasts. 206 /// Ex: 207 /// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with 208 /// perm = [1, 0] and broadcast d2 209 /// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by 210 /// permutation + broadcast 211 /// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3) 212 /// with perm = [1, 0, 2] and broadcast d2 213 /// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra 214 /// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with 215 /// perm = [3, 0, 1, 2] 216 bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting( 217 SmallVectorImpl<unsigned> &permutedDims) const { 218 unsigned projectionStart = 219 getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0; 220 permutedDims.clear(); 221 SmallVector<unsigned> broadcastDims; 222 permutedDims.resize(getNumResults(), 0); 223 // If there are more results than input dimensions we want the new map to 224 // start with broadcast dimensions in order to be a minor identity with 225 // broadcasting. 226 unsigned leadingBroadcast = 227 getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0; 228 llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()), 229 false); 230 for (const auto &idxAndExpr : llvm::enumerate(getResults())) { 231 unsigned resIdx = idxAndExpr.index(); 232 AffineExpr expr = idxAndExpr.value(); 233 // Each result may be either a constant 0 (broadcast dimension) or a 234 // dimension. 235 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) { 236 if (constExpr.getValue() != 0) 237 return false; 238 broadcastDims.push_back(resIdx); 239 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { 240 if (dimExpr.getPosition() < projectionStart) 241 return false; 242 unsigned newPosition = 243 dimExpr.getPosition() - projectionStart + leadingBroadcast; 244 permutedDims[resIdx] = newPosition; 245 dimFound[newPosition] = true; 246 } else { 247 return false; 248 } 249 } 250 // Find a permuation for the broadcast dimension. Since they are broadcasted 251 // any valid permutation is acceptable. We just permute the dim into a slot 252 // without an existing dimension. 253 unsigned pos = 0; 254 for (auto dim : broadcastDims) { 255 while (pos < dimFound.size() && dimFound[pos]) { 256 pos++; 257 } 258 permutedDims[dim] = pos++; 259 } 260 return true; 261 } 262 263 /// Returns an AffineMap representing a permutation. 264 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation, 265 MLIRContext *context) { 266 assert(!permutation.empty() && 267 "Cannot create permutation map from empty permutation vector"); 268 const auto *m = llvm::max_element(permutation); 269 auto permutationMap = getMultiDimMapWithTargets(*m + 1, permutation, context); 270 assert(permutationMap.isPermutation() && "Invalid permutation vector"); 271 return permutationMap; 272 } 273 AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation, 274 MLIRContext *context) { 275 SmallVector<unsigned> perm = llvm::map_to_vector( 276 permutation, [](int64_t i) { return static_cast<unsigned>(i); }); 277 return AffineMap::getPermutationMap(perm, context); 278 } 279 280 AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims, 281 ArrayRef<unsigned> targets, 282 MLIRContext *context) { 283 SmallVector<AffineExpr, 4> affExprs; 284 for (unsigned t : targets) 285 affExprs.push_back(getAffineDimExpr(t, context)); 286 AffineMap result = AffineMap::get(/*dimCount=*/numDims, /*symbolCount=*/0, 287 affExprs, context); 288 return result; 289 } 290 291 /// Creates an affine map each for each list of AffineExpr's in `exprsList` 292 /// while inferring the right number of dimensional and symbolic inputs needed 293 /// based on the maximum dimensional and symbolic identifier appearing in the 294 /// expressions. 295 template <typename AffineExprContainer> 296 static SmallVector<AffineMap, 4> 297 inferFromExprList(ArrayRef<AffineExprContainer> exprsList, 298 MLIRContext *context) { 299 if (exprsList.empty()) 300 return {}; 301 int64_t maxDim = -1, maxSym = -1; 302 getMaxDimAndSymbol(exprsList, maxDim, maxSym); 303 SmallVector<AffineMap, 4> maps; 304 maps.reserve(exprsList.size()); 305 for (const auto &exprs : exprsList) 306 maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1, 307 /*symbolCount=*/maxSym + 1, exprs, context)); 308 return maps; 309 } 310 311 SmallVector<AffineMap, 4> 312 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList, 313 MLIRContext *context) { 314 return ::inferFromExprList(exprsList, context); 315 } 316 317 SmallVector<AffineMap, 4> 318 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList, 319 MLIRContext *context) { 320 return ::inferFromExprList(exprsList, context); 321 } 322 323 uint64_t AffineMap::getLargestKnownDivisorOfMapExprs() { 324 uint64_t gcd = 0; 325 for (AffineExpr resultExpr : getResults()) { 326 uint64_t thisGcd = resultExpr.getLargestKnownDivisor(); 327 gcd = std::gcd(gcd, thisGcd); 328 } 329 if (gcd == 0) 330 gcd = std::numeric_limits<uint64_t>::max(); 331 return gcd; 332 } 333 334 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, 335 MLIRContext *context) { 336 SmallVector<AffineExpr, 4> dimExprs; 337 dimExprs.reserve(numDims); 338 for (unsigned i = 0; i < numDims; ++i) 339 dimExprs.push_back(mlir::getAffineDimExpr(i, context)); 340 return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context); 341 } 342 343 MLIRContext *AffineMap::getContext() const { return map->context; } 344 345 bool AffineMap::isIdentity() const { 346 if (getNumDims() != getNumResults()) 347 return false; 348 ArrayRef<AffineExpr> results = getResults(); 349 for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) { 350 auto expr = dyn_cast<AffineDimExpr>(results[i]); 351 if (!expr || expr.getPosition() != i) 352 return false; 353 } 354 return true; 355 } 356 357 bool AffineMap::isSymbolIdentity() const { 358 if (getNumSymbols() != getNumResults()) 359 return false; 360 ArrayRef<AffineExpr> results = getResults(); 361 for (unsigned i = 0, numSymbols = getNumSymbols(); i < numSymbols; ++i) { 362 auto expr = dyn_cast<AffineDimExpr>(results[i]); 363 if (!expr || expr.getPosition() != i) 364 return false; 365 } 366 return true; 367 } 368 369 bool AffineMap::isEmpty() const { 370 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0; 371 } 372 373 bool AffineMap::isSingleConstant() const { 374 return getNumResults() == 1 && isa<AffineConstantExpr>(getResult(0)); 375 } 376 377 bool AffineMap::isConstant() const { 378 return llvm::all_of(getResults(), llvm::IsaPred<AffineConstantExpr>); 379 } 380 381 int64_t AffineMap::getSingleConstantResult() const { 382 assert(isSingleConstant() && "map must have a single constant result"); 383 return cast<AffineConstantExpr>(getResult(0)).getValue(); 384 } 385 386 SmallVector<int64_t> AffineMap::getConstantResults() const { 387 assert(isConstant() && "map must have only constant results"); 388 SmallVector<int64_t> result; 389 for (auto expr : getResults()) 390 result.emplace_back(cast<AffineConstantExpr>(expr).getValue()); 391 return result; 392 } 393 394 unsigned AffineMap::getNumDims() const { 395 assert(map && "uninitialized map storage"); 396 return map->numDims; 397 } 398 unsigned AffineMap::getNumSymbols() const { 399 assert(map && "uninitialized map storage"); 400 return map->numSymbols; 401 } 402 unsigned AffineMap::getNumResults() const { return getResults().size(); } 403 unsigned AffineMap::getNumInputs() const { 404 assert(map && "uninitialized map storage"); 405 return map->numDims + map->numSymbols; 406 } 407 ArrayRef<AffineExpr> AffineMap::getResults() const { 408 assert(map && "uninitialized map storage"); 409 return map->results(); 410 } 411 AffineExpr AffineMap::getResult(unsigned idx) const { 412 return getResults()[idx]; 413 } 414 415 unsigned AffineMap::getDimPosition(unsigned idx) const { 416 return cast<AffineDimExpr>(getResult(idx)).getPosition(); 417 } 418 419 std::optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const { 420 if (!isa<AffineDimExpr>(input)) 421 return std::nullopt; 422 423 for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) { 424 if (getResult(i) == input) 425 return i; 426 } 427 428 return std::nullopt; 429 } 430 431 /// Folds the results of the application of an affine map on the provided 432 /// operands to a constant if possible. Returns false if the folding happens, 433 /// true otherwise. 434 LogicalResult AffineMap::constantFold(ArrayRef<Attribute> operandConstants, 435 SmallVectorImpl<Attribute> &results, 436 bool *hasPoison) const { 437 // Attempt partial folding. 438 SmallVector<int64_t, 2> integers; 439 partialConstantFold(operandConstants, &integers, hasPoison); 440 441 // If all expressions folded to a constant, populate results with attributes 442 // containing those constants. 443 if (integers.empty()) 444 return failure(); 445 446 auto range = llvm::map_range(integers, [this](int64_t i) { 447 return IntegerAttr::get(IndexType::get(getContext()), i); 448 }); 449 results.append(range.begin(), range.end()); 450 return success(); 451 } 452 453 AffineMap AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants, 454 SmallVectorImpl<int64_t> *results, 455 bool *hasPoison) const { 456 assert(getNumInputs() == operandConstants.size()); 457 458 // Fold each of the result expressions. 459 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); 460 SmallVector<AffineExpr, 4> exprs; 461 exprs.reserve(getNumResults()); 462 463 for (auto expr : getResults()) { 464 auto folded = exprFolder.constantFold(expr); 465 if (exprFolder.hasPoison() && hasPoison) { 466 *hasPoison = true; 467 return {}; 468 } 469 // If did not fold to a constant, keep the original expression, and clear 470 // the integer results vector. 471 if (folded) { 472 exprs.push_back( 473 getAffineConstantExpr(folded.getInt(), folded.getContext())); 474 if (results) 475 results->push_back(folded.getInt()); 476 } else { 477 exprs.push_back(expr); 478 if (results) { 479 results->clear(); 480 results = nullptr; 481 } 482 } 483 } 484 485 return get(getNumDims(), getNumSymbols(), exprs, getContext()); 486 } 487 488 /// Walk all of the AffineExpr's in this mapping. Each node in an expression 489 /// tree is visited in postorder. 490 void AffineMap::walkExprs(llvm::function_ref<void(AffineExpr)> callback) const { 491 for (auto expr : getResults()) 492 expr.walk(callback); 493 } 494 495 /// This method substitutes any uses of dimensions and symbols (e.g. 496 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified 497 /// expression mapping. Because this can be used to eliminate dims and 498 /// symbols, the client needs to specify the number of dims and symbols in 499 /// the result. The returned map always has the same number of results. 500 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements, 501 ArrayRef<AffineExpr> symReplacements, 502 unsigned numResultDims, 503 unsigned numResultSyms) const { 504 SmallVector<AffineExpr, 8> results; 505 results.reserve(getNumResults()); 506 for (auto expr : getResults()) 507 results.push_back( 508 expr.replaceDimsAndSymbols(dimReplacements, symReplacements)); 509 return get(numResultDims, numResultSyms, results, getContext()); 510 } 511 512 /// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) to 513 /// each of the results and return a new AffineMap with the new results and 514 /// with the specified number of dims and symbols. 515 AffineMap AffineMap::replace(AffineExpr expr, AffineExpr replacement, 516 unsigned numResultDims, 517 unsigned numResultSyms) const { 518 SmallVector<AffineExpr, 4> newResults; 519 newResults.reserve(getNumResults()); 520 for (AffineExpr e : getResults()) 521 newResults.push_back(e.replace(expr, replacement)); 522 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext()); 523 } 524 525 /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the 526 /// results and return a new AffineMap with the new results and with the 527 /// specified number of dims and symbols. 528 AffineMap AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map, 529 unsigned numResultDims, 530 unsigned numResultSyms) const { 531 SmallVector<AffineExpr, 4> newResults; 532 newResults.reserve(getNumResults()); 533 for (AffineExpr e : getResults()) 534 newResults.push_back(e.replace(map)); 535 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext()); 536 } 537 538 AffineMap 539 AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const { 540 SmallVector<AffineExpr, 4> newResults; 541 newResults.reserve(getNumResults()); 542 for (AffineExpr e : getResults()) 543 newResults.push_back(e.replace(map)); 544 return AffineMap::inferFromExprList(newResults, getContext()).front(); 545 } 546 547 AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const { 548 auto exprs = llvm::to_vector<4>(getResults()); 549 // TODO: this is a pretty terrible API .. is there anything better? 550 for (auto pos = positions.find_last(); pos != -1; 551 pos = positions.find_prev(pos)) 552 exprs.erase(exprs.begin() + pos); 553 return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); 554 } 555 556 AffineMap AffineMap::compose(AffineMap map) const { 557 assert(getNumDims() == map.getNumResults() && "Number of results mismatch"); 558 // Prepare `map` by concatenating the symbols and rewriting its exprs. 559 unsigned numDims = map.getNumDims(); 560 unsigned numSymbolsThisMap = getNumSymbols(); 561 unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols(); 562 SmallVector<AffineExpr, 8> newDims(numDims); 563 for (unsigned idx = 0; idx < numDims; ++idx) { 564 newDims[idx] = getAffineDimExpr(idx, getContext()); 565 } 566 SmallVector<AffineExpr, 8> newSymbols(numSymbols - numSymbolsThisMap); 567 for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) { 568 newSymbols[idx - numSymbolsThisMap] = 569 getAffineSymbolExpr(idx, getContext()); 570 } 571 auto newMap = 572 map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols); 573 SmallVector<AffineExpr, 8> exprs; 574 exprs.reserve(getResults().size()); 575 for (auto expr : getResults()) 576 exprs.push_back(expr.compose(newMap)); 577 return AffineMap::get(numDims, numSymbols, exprs, map.getContext()); 578 } 579 580 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const { 581 assert(getNumSymbols() == 0 && "Expected symbol-less map"); 582 SmallVector<AffineExpr, 4> exprs; 583 exprs.reserve(values.size()); 584 MLIRContext *ctx = getContext(); 585 for (auto v : values) 586 exprs.push_back(getAffineConstantExpr(v, ctx)); 587 auto resMap = compose(AffineMap::get(0, 0, exprs, ctx)); 588 SmallVector<int64_t, 4> res; 589 res.reserve(resMap.getNumResults()); 590 for (auto e : resMap.getResults()) 591 res.push_back(cast<AffineConstantExpr>(e).getValue()); 592 return res; 593 } 594 595 size_t AffineMap::getNumOfZeroResults() const { 596 size_t res = 0; 597 for (auto expr : getResults()) { 598 auto constExpr = dyn_cast<AffineConstantExpr>(expr); 599 if (constExpr && constExpr.getValue() == 0) 600 res++; 601 } 602 603 return res; 604 } 605 606 AffineMap AffineMap::dropZeroResults() { 607 auto exprs = llvm::to_vector(getResults()); 608 SmallVector<AffineExpr> newExprs; 609 610 for (auto expr : getResults()) { 611 auto constExpr = dyn_cast<AffineConstantExpr>(expr); 612 if (!constExpr || constExpr.getValue() != 0) 613 newExprs.push_back(expr); 614 } 615 return AffineMap::get(getNumDims(), getNumSymbols(), newExprs, getContext()); 616 } 617 618 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const { 619 if (getNumSymbols() > 0) 620 return false; 621 622 // Having more results than inputs means that results have duplicated dims or 623 // zeros that can't be mapped to input dims. 624 if (getNumResults() > getNumInputs()) 625 return false; 626 627 SmallVector<bool, 8> seen(getNumInputs(), false); 628 // A projected permutation can have, at most, only one instance of each input 629 // dimension in the result expressions. Zeros are allowed as long as the 630 // number of result expressions is lower or equal than the number of input 631 // expressions. 632 for (auto expr : getResults()) { 633 if (auto dim = dyn_cast<AffineDimExpr>(expr)) { 634 if (seen[dim.getPosition()]) 635 return false; 636 seen[dim.getPosition()] = true; 637 } else { 638 auto constExpr = dyn_cast<AffineConstantExpr>(expr); 639 if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0) 640 return false; 641 } 642 } 643 644 // Results are either dims or zeros and zeros can be mapped to input dims. 645 return true; 646 } 647 648 bool AffineMap::isPermutation() const { 649 if (getNumDims() != getNumResults()) 650 return false; 651 return isProjectedPermutation(); 652 } 653 654 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const { 655 SmallVector<AffineExpr, 4> exprs; 656 exprs.reserve(resultPos.size()); 657 for (auto idx : resultPos) 658 exprs.push_back(getResult(idx)); 659 return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); 660 } 661 662 AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const { 663 return AffineMap::get(getNumDims(), getNumSymbols(), 664 getResults().slice(start, length), getContext()); 665 } 666 667 AffineMap AffineMap::getMajorSubMap(unsigned numResults) const { 668 if (numResults == 0) 669 return AffineMap(); 670 if (numResults > getNumResults()) 671 return *this; 672 return getSliceMap(0, numResults); 673 } 674 675 AffineMap AffineMap::getMinorSubMap(unsigned numResults) const { 676 if (numResults == 0) 677 return AffineMap(); 678 if (numResults > getNumResults()) 679 return *this; 680 return getSliceMap(getNumResults() - numResults, numResults); 681 } 682 683 /// Implementation detail to compress multiple affine maps with a compressionFun 684 /// that is expected to be either compressUnusedDims or compressUnusedSymbols. 685 /// The implementation keeps track of num dims and symbols across the different 686 /// affine maps. 687 static SmallVector<AffineMap> compressUnusedListImpl( 688 ArrayRef<AffineMap> maps, 689 llvm::function_ref<AffineMap(AffineMap)> compressionFun) { 690 if (maps.empty()) 691 return SmallVector<AffineMap>(); 692 SmallVector<AffineExpr> allExprs; 693 allExprs.reserve(maps.size() * maps.front().getNumResults()); 694 unsigned numDims = maps.front().getNumDims(), 695 numSymbols = maps.front().getNumSymbols(); 696 for (auto m : maps) { 697 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() && 698 "expected maps with same num dims and symbols"); 699 llvm::append_range(allExprs, m.getResults()); 700 } 701 AffineMap unifiedMap = compressionFun( 702 AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext())); 703 unsigned unifiedNumDims = unifiedMap.getNumDims(), 704 unifiedNumSymbols = unifiedMap.getNumSymbols(); 705 ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults(); 706 SmallVector<AffineMap> res; 707 res.reserve(maps.size()); 708 for (auto m : maps) { 709 res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols, 710 unifiedResults.take_front(m.getNumResults()), 711 m.getContext())); 712 unifiedResults = unifiedResults.drop_front(m.getNumResults()); 713 } 714 return res; 715 } 716 717 AffineMap mlir::compressDims(AffineMap map, 718 const llvm::SmallBitVector &unusedDims) { 719 return projectDims(map, unusedDims, /*compressDimsFlag=*/true); 720 } 721 722 AffineMap mlir::compressUnusedDims(AffineMap map) { 723 return compressDims(map, getUnusedDimsBitVector({map})); 724 } 725 726 SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) { 727 return compressUnusedListImpl( 728 maps, [](AffineMap m) { return compressUnusedDims(m); }); 729 } 730 731 AffineMap mlir::compressSymbols(AffineMap map, 732 const llvm::SmallBitVector &unusedSymbols) { 733 return projectSymbols(map, unusedSymbols, /*compressSymbolsFlag=*/true); 734 } 735 736 AffineMap mlir::compressUnusedSymbols(AffineMap map) { 737 return compressSymbols(map, getUnusedSymbolsBitVector({map})); 738 } 739 740 SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) { 741 return compressUnusedListImpl( 742 maps, [](AffineMap m) { return compressUnusedSymbols(m); }); 743 } 744 745 AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map, 746 ArrayRef<OpFoldResult> operands, 747 SmallVector<Value> &remainingValues) { 748 SmallVector<AffineExpr> dimReplacements, symReplacements; 749 int64_t numDims = 0; 750 for (int64_t i = 0; i < map.getNumDims(); ++i) { 751 if (auto attr = operands[i].dyn_cast<Attribute>()) { 752 dimReplacements.push_back( 753 b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt())); 754 } else { 755 dimReplacements.push_back(b.getAffineDimExpr(numDims++)); 756 remainingValues.push_back(cast<Value>(operands[i])); 757 } 758 } 759 int64_t numSymbols = 0; 760 for (int64_t i = 0; i < map.getNumSymbols(); ++i) { 761 if (auto attr = operands[i + map.getNumDims()].dyn_cast<Attribute>()) { 762 symReplacements.push_back( 763 b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt())); 764 } else { 765 symReplacements.push_back(b.getAffineSymbolExpr(numSymbols++)); 766 remainingValues.push_back(cast<Value>(operands[i + map.getNumDims()])); 767 } 768 } 769 return map.replaceDimsAndSymbols(dimReplacements, symReplacements, numDims, 770 numSymbols); 771 } 772 773 AffineMap mlir::simplifyAffineMap(AffineMap map) { 774 SmallVector<AffineExpr, 8> exprs; 775 for (auto e : map.getResults()) { 776 exprs.push_back( 777 simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); 778 } 779 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, 780 map.getContext()); 781 } 782 783 AffineMap mlir::removeDuplicateExprs(AffineMap map) { 784 auto results = map.getResults(); 785 SmallVector<AffineExpr, 4> uniqueExprs(results); 786 uniqueExprs.erase(llvm::unique(uniqueExprs), uniqueExprs.end()); 787 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs, 788 map.getContext()); 789 } 790 791 AffineMap mlir::inversePermutation(AffineMap map) { 792 if (map.isEmpty()) 793 return map; 794 assert(map.getNumSymbols() == 0 && "expected map without symbols"); 795 SmallVector<AffineExpr, 4> exprs(map.getNumDims()); 796 for (const auto &en : llvm::enumerate(map.getResults())) { 797 auto expr = en.value(); 798 // Skip non-permutations. 799 if (auto d = dyn_cast<AffineDimExpr>(expr)) { 800 if (exprs[d.getPosition()]) 801 continue; 802 exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext()); 803 } 804 } 805 SmallVector<AffineExpr, 4> seenExprs; 806 seenExprs.reserve(map.getNumDims()); 807 for (auto expr : exprs) 808 if (expr) 809 seenExprs.push_back(expr); 810 if (seenExprs.size() != map.getNumInputs()) 811 return AffineMap(); 812 return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext()); 813 } 814 815 AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) { 816 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true)); 817 MLIRContext *context = map.getContext(); 818 AffineExpr zero = mlir::getAffineConstantExpr(0, context); 819 // Start with all the results as 0. 820 SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero); 821 for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) { 822 // Skip zeros from input map. 'exprs' is already initialized to zero. 823 if (auto constExpr = dyn_cast<AffineConstantExpr>(map.getResult(i))) { 824 assert(constExpr.getValue() == 0 && 825 "Unexpected constant in projected permutation"); 826 (void)constExpr; 827 continue; 828 } 829 830 // Reverse each dimension existing in the original map result. 831 exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context); 832 } 833 return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context); 834 } 835 836 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps, 837 MLIRContext *context) { 838 if (maps.empty()) 839 return AffineMap::get(context); 840 unsigned numResults = 0, numDims = 0, numSymbols = 0; 841 for (auto m : maps) 842 numResults += m.getNumResults(); 843 SmallVector<AffineExpr, 8> results; 844 results.reserve(numResults); 845 for (auto m : maps) { 846 for (auto res : m.getResults()) 847 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols)); 848 849 numSymbols += m.getNumSymbols(); 850 numDims = std::max(m.getNumDims(), numDims); 851 } 852 return AffineMap::get(numDims, numSymbols, results, context); 853 } 854 855 /// Common implementation to project out dimensions or symbols from an affine 856 /// map based on the template type. 857 /// Additionally, if 'compress' is true, the projected out dimensions or symbols 858 /// are also dropped from the resulting map. 859 template <typename AffineDimOrSymExpr> 860 static AffineMap projectCommonImpl(AffineMap map, 861 const llvm::SmallBitVector &toProject, 862 bool compress) { 863 static_assert(llvm::is_one_of<AffineDimOrSymExpr, AffineDimExpr, 864 AffineSymbolExpr>::value, 865 "expected AffineDimExpr or AffineSymbolExpr"); 866 867 constexpr bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value; 868 int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols(); 869 SmallVector<AffineExpr> replacements; 870 replacements.reserve(numDimOrSym); 871 872 auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr; 873 874 using replace_fn_ty = 875 std::function<AffineExpr(AffineExpr, ArrayRef<AffineExpr>)>; 876 replace_fn_ty replaceDims = [](AffineExpr e, 877 ArrayRef<AffineExpr> replacements) { 878 return e.replaceDims(replacements); 879 }; 880 replace_fn_ty replaceSymbols = [](AffineExpr e, 881 ArrayRef<AffineExpr> replacements) { 882 return e.replaceSymbols(replacements); 883 }; 884 replace_fn_ty replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols; 885 886 MLIRContext *context = map.getContext(); 887 int64_t newNumDimOrSym = 0; 888 for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) { 889 if (toProject.test(dimOrSym)) { 890 replacements.push_back(getAffineConstantExpr(0, context)); 891 continue; 892 } 893 int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym; 894 replacements.push_back(createNewDimOrSym(newPos, context)); 895 } 896 SmallVector<AffineExpr> resultExprs; 897 resultExprs.reserve(map.getNumResults()); 898 for (auto e : map.getResults()) 899 resultExprs.push_back(replaceNewDimOrSym(e, replacements)); 900 901 int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims(); 902 int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols(); 903 return AffineMap::get(numDims, numSyms, resultExprs, context); 904 } 905 906 AffineMap mlir::projectDims(AffineMap map, 907 const llvm::SmallBitVector &projectedDimensions, 908 bool compressDimsFlag) { 909 return projectCommonImpl<AffineDimExpr>(map, projectedDimensions, 910 compressDimsFlag); 911 } 912 913 AffineMap mlir::projectSymbols(AffineMap map, 914 const llvm::SmallBitVector &projectedSymbols, 915 bool compressSymbolsFlag) { 916 return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols, 917 compressSymbolsFlag); 918 } 919 920 AffineMap mlir::getProjectedMap(AffineMap map, 921 const llvm::SmallBitVector &projectedDimensions, 922 bool compressDimsFlag, 923 bool compressSymbolsFlag) { 924 map = projectDims(map, projectedDimensions, compressDimsFlag); 925 if (compressSymbolsFlag) 926 map = compressUnusedSymbols(map); 927 return map; 928 } 929 930 llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) { 931 unsigned numDims = maps[0].getNumDims(); 932 llvm::SmallBitVector numDimsBitVector(numDims, true); 933 for (AffineMap m : maps) { 934 for (unsigned i = 0; i < numDims; ++i) { 935 if (m.isFunctionOfDim(i)) 936 numDimsBitVector.reset(i); 937 } 938 } 939 return numDimsBitVector; 940 } 941 942 llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) { 943 unsigned numSymbols = maps[0].getNumSymbols(); 944 llvm::SmallBitVector numSymbolsBitVector(numSymbols, true); 945 for (AffineMap m : maps) { 946 for (unsigned i = 0; i < numSymbols; ++i) { 947 if (m.isFunctionOfSymbol(i)) 948 numSymbolsBitVector.reset(i); 949 } 950 } 951 return numSymbolsBitVector; 952 } 953 954 AffineMap 955 mlir::expandDimsToRank(AffineMap map, int64_t rank, 956 const llvm::SmallBitVector &projectedDimensions) { 957 auto id = AffineMap::getMultiDimIdentityMap(rank, map.getContext()); 958 AffineMap proj = id.dropResults(projectedDimensions); 959 return map.compose(proj); 960 } 961 962 //===----------------------------------------------------------------------===// 963 // MutableAffineMap. 964 //===----------------------------------------------------------------------===// 965 966 MutableAffineMap::MutableAffineMap(AffineMap map) 967 : results(map.getResults()), numDims(map.getNumDims()), 968 numSymbols(map.getNumSymbols()), context(map.getContext()) {} 969 970 void MutableAffineMap::reset(AffineMap map) { 971 results.clear(); 972 numDims = map.getNumDims(); 973 numSymbols = map.getNumSymbols(); 974 context = map.getContext(); 975 llvm::append_range(results, map.getResults()); 976 } 977 978 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { 979 return results[idx].isMultipleOf(factor); 980 } 981 982 // Simplifies the result affine expressions of this map. The expressions 983 // have to be pure for the simplification implemented. 984 void MutableAffineMap::simplify() { 985 // Simplify each of the results if possible. 986 // TODO: functional-style map 987 for (unsigned i = 0, e = getNumResults(); i < e; i++) { 988 results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols); 989 } 990 } 991 992 AffineMap MutableAffineMap::getAffineMap() const { 993 return AffineMap::get(numDims, numSymbols, results, context); 994 } 995