xref: /llvm-project/mlir/lib/IR/AffineMap.cpp (revision fcb1591b46f12b8908a8cdb252611708820102f8)
1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineMap.h"
10 #include "AffineMapDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinAttributes.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/SmallBitVector.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/Support/MathExtras.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include <iterator>
23 #include <numeric>
24 #include <optional>
25 #include <type_traits>
26 
27 using namespace mlir;
28 
29 using llvm::divideCeilSigned;
30 using llvm::divideFloorSigned;
31 using llvm::mod;
32 
33 namespace {
34 
35 // AffineExprConstantFolder evaluates an affine expression using constant
36 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute
37 // representing the constant value of the affine expression evaluated on
38 // constant 'operandConsts', or nullptr if it can't be folded.
39 class AffineExprConstantFolder {
40 public:
41   AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
42       : numDims(numDims), operandConsts(operandConsts) {}
43 
44   /// Attempt to constant fold the specified affine expr, or return null on
45   /// failure.
46   IntegerAttr constantFold(AffineExpr expr) {
47     if (auto result = constantFoldImpl(expr))
48       return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
49     return nullptr;
50   }
51 
52   bool hasPoison() const { return hasPoison_; }
53 
54 private:
55   std::optional<int64_t> constantFoldImpl(AffineExpr expr) {
56     switch (expr.getKind()) {
57     case AffineExprKind::Add:
58       return constantFoldBinExpr(
59           expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
60     case AffineExprKind::Mul:
61       return constantFoldBinExpr(
62           expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
63     case AffineExprKind::Mod:
64       return constantFoldBinExpr(
65           expr, [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
66             if (rhs < 1) {
67               hasPoison_ = true;
68               return std::nullopt;
69             }
70             return mod(lhs, rhs);
71           });
72     case AffineExprKind::FloorDiv:
73       return constantFoldBinExpr(
74           expr, [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
75             if (rhs == 0) {
76               hasPoison_ = true;
77               return std::nullopt;
78             }
79             return divideFloorSigned(lhs, rhs);
80           });
81     case AffineExprKind::CeilDiv:
82       return constantFoldBinExpr(
83           expr, [this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
84             if (rhs == 0) {
85               hasPoison_ = true;
86               return std::nullopt;
87             }
88             return divideCeilSigned(lhs, rhs);
89           });
90     case AffineExprKind::Constant:
91       return cast<AffineConstantExpr>(expr).getValue();
92     case AffineExprKind::DimId:
93       if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
94               operandConsts[cast<AffineDimExpr>(expr).getPosition()]))
95         return attr.getInt();
96       return std::nullopt;
97     case AffineExprKind::SymbolId:
98       if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
99               operandConsts[numDims +
100                             cast<AffineSymbolExpr>(expr).getPosition()]))
101         return attr.getInt();
102       return std::nullopt;
103     }
104     llvm_unreachable("Unknown AffineExpr");
105   }
106 
107   // TODO: Change these to operate on APInts too.
108   std::optional<int64_t> constantFoldBinExpr(
109       AffineExpr expr,
110       llvm::function_ref<std::optional<int64_t>(int64_t, int64_t)> op) {
111     auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
112     if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
113       if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
114         return op(*lhs, *rhs);
115     return std::nullopt;
116   }
117 
118   // The number of dimension operands in AffineMap containing this expression.
119   unsigned numDims;
120   // The constant valued operands used to evaluate this AffineExpr.
121   ArrayRef<Attribute> operandConsts;
122   bool hasPoison_{false};
123 };
124 
125 } // namespace
126 
127 /// Returns a single constant result affine map.
128 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
129   return get(/*dimCount=*/0, /*symbolCount=*/0,
130              {getAffineConstantExpr(val, context)});
131 }
132 
133 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most
134 /// minor dimensions.
135 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
136                                          MLIRContext *context) {
137   assert(dims >= results && "Dimension mismatch");
138   auto id = AffineMap::getMultiDimIdentityMap(dims, context);
139   return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
140 }
141 
142 AffineMap AffineMap::getFilteredIdentityMap(
143     MLIRContext *ctx, unsigned numDims,
144     llvm::function_ref<bool(AffineDimExpr)> keepDimFilter) {
145   auto identityMap = getMultiDimIdentityMap(numDims, ctx);
146 
147   // Apply filter to results.
148   llvm::SmallBitVector dropDimResults(numDims);
149   for (auto [idx, resultExpr] : llvm::enumerate(identityMap.getResults()))
150     dropDimResults[idx] = !keepDimFilter(cast<AffineDimExpr>(resultExpr));
151 
152   return identityMap.dropResults(dropDimResults);
153 }
154 
155 bool AffineMap::isMinorIdentity() const {
156   return getNumDims() >= getNumResults() &&
157          *this ==
158              getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
159 }
160 
161 SmallVector<unsigned> AffineMap::getBroadcastDims() const {
162   SmallVector<unsigned> broadcastedDims;
163   for (const auto &[resIdx, expr] : llvm::enumerate(getResults())) {
164     if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
165       if (constExpr.getValue() != 0)
166         continue;
167       broadcastedDims.push_back(resIdx);
168     }
169   }
170 
171   return broadcastedDims;
172 }
173 
174 /// Returns true if this affine map is a minor identity up to broadcasted
175 /// dimensions which are indicated by value 0 in the result.
176 bool AffineMap::isMinorIdentityWithBroadcasting(
177     SmallVectorImpl<unsigned> *broadcastedDims) const {
178   if (broadcastedDims)
179     broadcastedDims->clear();
180   if (getNumDims() < getNumResults())
181     return false;
182   unsigned suffixStart = getNumDims() - getNumResults();
183   for (const auto &idxAndExpr : llvm::enumerate(getResults())) {
184     unsigned resIdx = idxAndExpr.index();
185     AffineExpr expr = idxAndExpr.value();
186     if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
187       // Each result may be either a constant 0 (broadcasted dimension).
188       if (constExpr.getValue() != 0)
189         return false;
190       if (broadcastedDims)
191         broadcastedDims->push_back(resIdx);
192     } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
193       // Or it may be the input dimension corresponding to this result position.
194       if (dimExpr.getPosition() != suffixStart + resIdx)
195         return false;
196     } else {
197       return false;
198     }
199   }
200   return true;
201 }
202 
203 /// Return true if this affine map can be converted to a minor identity with
204 /// broadcast by doing a permute. Return a permutation (there may be
205 /// several) to apply to get to a minor identity with broadcasts.
206 /// Ex:
207 ///  * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with
208 ///  perm = [1, 0] and broadcast d2
209 ///  * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by
210 ///  permutation + broadcast
211 ///  * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3)
212 ///  with perm = [1, 0, 2] and broadcast d2
213 ///  * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra
214 ///  leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with
215 ///  perm = [3, 0, 1, 2]
216 bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting(
217     SmallVectorImpl<unsigned> &permutedDims) const {
218   unsigned projectionStart =
219       getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0;
220   permutedDims.clear();
221   SmallVector<unsigned> broadcastDims;
222   permutedDims.resize(getNumResults(), 0);
223   // If there are more results than input dimensions we want the new map to
224   // start with broadcast dimensions in order to be a minor identity with
225   // broadcasting.
226   unsigned leadingBroadcast =
227       getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0;
228   llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()),
229                                 false);
230   for (const auto &idxAndExpr : llvm::enumerate(getResults())) {
231     unsigned resIdx = idxAndExpr.index();
232     AffineExpr expr = idxAndExpr.value();
233     // Each result may be either a constant 0 (broadcast dimension) or a
234     // dimension.
235     if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
236       if (constExpr.getValue() != 0)
237         return false;
238       broadcastDims.push_back(resIdx);
239     } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
240       if (dimExpr.getPosition() < projectionStart)
241         return false;
242       unsigned newPosition =
243           dimExpr.getPosition() - projectionStart + leadingBroadcast;
244       permutedDims[resIdx] = newPosition;
245       dimFound[newPosition] = true;
246     } else {
247       return false;
248     }
249   }
250   // Find a permuation for the broadcast dimension. Since they are broadcasted
251   // any valid permutation is acceptable. We just permute the dim into a slot
252   // without an existing dimension.
253   unsigned pos = 0;
254   for (auto dim : broadcastDims) {
255     while (pos < dimFound.size() && dimFound[pos]) {
256       pos++;
257     }
258     permutedDims[dim] = pos++;
259   }
260   return true;
261 }
262 
263 /// Returns an AffineMap representing a permutation.
264 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
265                                        MLIRContext *context) {
266   assert(!permutation.empty() &&
267          "Cannot create permutation map from empty permutation vector");
268   const auto *m = llvm::max_element(permutation);
269   auto permutationMap = getMultiDimMapWithTargets(*m + 1, permutation, context);
270   assert(permutationMap.isPermutation() && "Invalid permutation vector");
271   return permutationMap;
272 }
273 AffineMap AffineMap::getPermutationMap(ArrayRef<int64_t> permutation,
274                                        MLIRContext *context) {
275   SmallVector<unsigned> perm = llvm::map_to_vector(
276       permutation, [](int64_t i) { return static_cast<unsigned>(i); });
277   return AffineMap::getPermutationMap(perm, context);
278 }
279 
280 AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
281                                                ArrayRef<unsigned> targets,
282                                                MLIRContext *context) {
283   SmallVector<AffineExpr, 4> affExprs;
284   for (unsigned t : targets)
285     affExprs.push_back(getAffineDimExpr(t, context));
286   AffineMap result = AffineMap::get(/*dimCount=*/numDims, /*symbolCount=*/0,
287                                     affExprs, context);
288   return result;
289 }
290 
291 /// Creates an affine map each for each list of AffineExpr's in `exprsList`
292 /// while inferring the right number of dimensional and symbolic inputs needed
293 /// based on the maximum dimensional and symbolic identifier appearing in the
294 /// expressions.
295 template <typename AffineExprContainer>
296 static SmallVector<AffineMap, 4>
297 inferFromExprList(ArrayRef<AffineExprContainer> exprsList,
298                   MLIRContext *context) {
299   if (exprsList.empty())
300     return {};
301   int64_t maxDim = -1, maxSym = -1;
302   getMaxDimAndSymbol(exprsList, maxDim, maxSym);
303   SmallVector<AffineMap, 4> maps;
304   maps.reserve(exprsList.size());
305   for (const auto &exprs : exprsList)
306     maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
307                                   /*symbolCount=*/maxSym + 1, exprs, context));
308   return maps;
309 }
310 
311 SmallVector<AffineMap, 4>
312 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList,
313                              MLIRContext *context) {
314   return ::inferFromExprList(exprsList, context);
315 }
316 
317 SmallVector<AffineMap, 4>
318 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList,
319                              MLIRContext *context) {
320   return ::inferFromExprList(exprsList, context);
321 }
322 
323 uint64_t AffineMap::getLargestKnownDivisorOfMapExprs() {
324   uint64_t gcd = 0;
325   for (AffineExpr resultExpr : getResults()) {
326     uint64_t thisGcd = resultExpr.getLargestKnownDivisor();
327     gcd = std::gcd(gcd, thisGcd);
328   }
329   if (gcd == 0)
330     gcd = std::numeric_limits<uint64_t>::max();
331   return gcd;
332 }
333 
334 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
335                                             MLIRContext *context) {
336   SmallVector<AffineExpr, 4> dimExprs;
337   dimExprs.reserve(numDims);
338   for (unsigned i = 0; i < numDims; ++i)
339     dimExprs.push_back(mlir::getAffineDimExpr(i, context));
340   return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context);
341 }
342 
343 MLIRContext *AffineMap::getContext() const { return map->context; }
344 
345 bool AffineMap::isIdentity() const {
346   if (getNumDims() != getNumResults())
347     return false;
348   ArrayRef<AffineExpr> results = getResults();
349   for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
350     auto expr = dyn_cast<AffineDimExpr>(results[i]);
351     if (!expr || expr.getPosition() != i)
352       return false;
353   }
354   return true;
355 }
356 
357 bool AffineMap::isSymbolIdentity() const {
358   if (getNumSymbols() != getNumResults())
359     return false;
360   ArrayRef<AffineExpr> results = getResults();
361   for (unsigned i = 0, numSymbols = getNumSymbols(); i < numSymbols; ++i) {
362     auto expr = dyn_cast<AffineDimExpr>(results[i]);
363     if (!expr || expr.getPosition() != i)
364       return false;
365   }
366   return true;
367 }
368 
369 bool AffineMap::isEmpty() const {
370   return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
371 }
372 
373 bool AffineMap::isSingleConstant() const {
374   return getNumResults() == 1 && isa<AffineConstantExpr>(getResult(0));
375 }
376 
377 bool AffineMap::isConstant() const {
378   return llvm::all_of(getResults(), llvm::IsaPred<AffineConstantExpr>);
379 }
380 
381 int64_t AffineMap::getSingleConstantResult() const {
382   assert(isSingleConstant() && "map must have a single constant result");
383   return cast<AffineConstantExpr>(getResult(0)).getValue();
384 }
385 
386 SmallVector<int64_t> AffineMap::getConstantResults() const {
387   assert(isConstant() && "map must have only constant results");
388   SmallVector<int64_t> result;
389   for (auto expr : getResults())
390     result.emplace_back(cast<AffineConstantExpr>(expr).getValue());
391   return result;
392 }
393 
394 unsigned AffineMap::getNumDims() const {
395   assert(map && "uninitialized map storage");
396   return map->numDims;
397 }
398 unsigned AffineMap::getNumSymbols() const {
399   assert(map && "uninitialized map storage");
400   return map->numSymbols;
401 }
402 unsigned AffineMap::getNumResults() const { return getResults().size(); }
403 unsigned AffineMap::getNumInputs() const {
404   assert(map && "uninitialized map storage");
405   return map->numDims + map->numSymbols;
406 }
407 ArrayRef<AffineExpr> AffineMap::getResults() const {
408   assert(map && "uninitialized map storage");
409   return map->results();
410 }
411 AffineExpr AffineMap::getResult(unsigned idx) const {
412   return getResults()[idx];
413 }
414 
415 unsigned AffineMap::getDimPosition(unsigned idx) const {
416   return cast<AffineDimExpr>(getResult(idx)).getPosition();
417 }
418 
419 std::optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const {
420   if (!isa<AffineDimExpr>(input))
421     return std::nullopt;
422 
423   for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) {
424     if (getResult(i) == input)
425       return i;
426   }
427 
428   return std::nullopt;
429 }
430 
431 /// Folds the results of the application of an affine map on the provided
432 /// operands to a constant if possible. Returns false if the folding happens,
433 /// true otherwise.
434 LogicalResult AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
435                                       SmallVectorImpl<Attribute> &results,
436                                       bool *hasPoison) const {
437   // Attempt partial folding.
438   SmallVector<int64_t, 2> integers;
439   partialConstantFold(operandConstants, &integers, hasPoison);
440 
441   // If all expressions folded to a constant, populate results with attributes
442   // containing those constants.
443   if (integers.empty())
444     return failure();
445 
446   auto range = llvm::map_range(integers, [this](int64_t i) {
447     return IntegerAttr::get(IndexType::get(getContext()), i);
448   });
449   results.append(range.begin(), range.end());
450   return success();
451 }
452 
453 AffineMap AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
454                                          SmallVectorImpl<int64_t> *results,
455                                          bool *hasPoison) const {
456   assert(getNumInputs() == operandConstants.size());
457 
458   // Fold each of the result expressions.
459   AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
460   SmallVector<AffineExpr, 4> exprs;
461   exprs.reserve(getNumResults());
462 
463   for (auto expr : getResults()) {
464     auto folded = exprFolder.constantFold(expr);
465     if (exprFolder.hasPoison() && hasPoison) {
466       *hasPoison = true;
467       return {};
468     }
469     // If did not fold to a constant, keep the original expression, and clear
470     // the integer results vector.
471     if (folded) {
472       exprs.push_back(
473           getAffineConstantExpr(folded.getInt(), folded.getContext()));
474       if (results)
475         results->push_back(folded.getInt());
476     } else {
477       exprs.push_back(expr);
478       if (results) {
479         results->clear();
480         results = nullptr;
481       }
482     }
483   }
484 
485   return get(getNumDims(), getNumSymbols(), exprs, getContext());
486 }
487 
488 /// Walk all of the AffineExpr's in this mapping. Each node in an expression
489 /// tree is visited in postorder.
490 void AffineMap::walkExprs(llvm::function_ref<void(AffineExpr)> callback) const {
491   for (auto expr : getResults())
492     expr.walk(callback);
493 }
494 
495 /// This method substitutes any uses of dimensions and symbols (e.g.
496 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
497 /// expression mapping.  Because this can be used to eliminate dims and
498 /// symbols, the client needs to specify the number of dims and symbols in
499 /// the result.  The returned map always has the same number of results.
500 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
501                                            ArrayRef<AffineExpr> symReplacements,
502                                            unsigned numResultDims,
503                                            unsigned numResultSyms) const {
504   SmallVector<AffineExpr, 8> results;
505   results.reserve(getNumResults());
506   for (auto expr : getResults())
507     results.push_back(
508         expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
509   return get(numResultDims, numResultSyms, results, getContext());
510 }
511 
512 /// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) to
513 /// each of the results and return a new AffineMap with the new results and
514 /// with the specified number of dims and symbols.
515 AffineMap AffineMap::replace(AffineExpr expr, AffineExpr replacement,
516                              unsigned numResultDims,
517                              unsigned numResultSyms) const {
518   SmallVector<AffineExpr, 4> newResults;
519   newResults.reserve(getNumResults());
520   for (AffineExpr e : getResults())
521     newResults.push_back(e.replace(expr, replacement));
522   return AffineMap::get(numResultDims, numResultSyms, newResults, getContext());
523 }
524 
525 /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the
526 /// results and return a new AffineMap with the new results and with the
527 /// specified number of dims and symbols.
528 AffineMap AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map,
529                              unsigned numResultDims,
530                              unsigned numResultSyms) const {
531   SmallVector<AffineExpr, 4> newResults;
532   newResults.reserve(getNumResults());
533   for (AffineExpr e : getResults())
534     newResults.push_back(e.replace(map));
535   return AffineMap::get(numResultDims, numResultSyms, newResults, getContext());
536 }
537 
538 AffineMap
539 AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
540   SmallVector<AffineExpr, 4> newResults;
541   newResults.reserve(getNumResults());
542   for (AffineExpr e : getResults())
543     newResults.push_back(e.replace(map));
544   return AffineMap::inferFromExprList(newResults, getContext()).front();
545 }
546 
547 AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const {
548   auto exprs = llvm::to_vector<4>(getResults());
549   // TODO: this is a pretty terrible API .. is there anything better?
550   for (auto pos = positions.find_last(); pos != -1;
551        pos = positions.find_prev(pos))
552     exprs.erase(exprs.begin() + pos);
553   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
554 }
555 
556 AffineMap AffineMap::compose(AffineMap map) const {
557   assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
558   // Prepare `map` by concatenating the symbols and rewriting its exprs.
559   unsigned numDims = map.getNumDims();
560   unsigned numSymbolsThisMap = getNumSymbols();
561   unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
562   SmallVector<AffineExpr, 8> newDims(numDims);
563   for (unsigned idx = 0; idx < numDims; ++idx) {
564     newDims[idx] = getAffineDimExpr(idx, getContext());
565   }
566   SmallVector<AffineExpr, 8> newSymbols(numSymbols - numSymbolsThisMap);
567   for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
568     newSymbols[idx - numSymbolsThisMap] =
569         getAffineSymbolExpr(idx, getContext());
570   }
571   auto newMap =
572       map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols);
573   SmallVector<AffineExpr, 8> exprs;
574   exprs.reserve(getResults().size());
575   for (auto expr : getResults())
576     exprs.push_back(expr.compose(newMap));
577   return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
578 }
579 
580 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
581   assert(getNumSymbols() == 0 && "Expected symbol-less map");
582   SmallVector<AffineExpr, 4> exprs;
583   exprs.reserve(values.size());
584   MLIRContext *ctx = getContext();
585   for (auto v : values)
586     exprs.push_back(getAffineConstantExpr(v, ctx));
587   auto resMap = compose(AffineMap::get(0, 0, exprs, ctx));
588   SmallVector<int64_t, 4> res;
589   res.reserve(resMap.getNumResults());
590   for (auto e : resMap.getResults())
591     res.push_back(cast<AffineConstantExpr>(e).getValue());
592   return res;
593 }
594 
595 size_t AffineMap::getNumOfZeroResults() const {
596   size_t res = 0;
597   for (auto expr : getResults()) {
598     auto constExpr = dyn_cast<AffineConstantExpr>(expr);
599     if (constExpr && constExpr.getValue() == 0)
600       res++;
601   }
602 
603   return res;
604 }
605 
606 AffineMap AffineMap::dropZeroResults() {
607   auto exprs = llvm::to_vector(getResults());
608   SmallVector<AffineExpr> newExprs;
609 
610   for (auto expr : getResults()) {
611     auto constExpr = dyn_cast<AffineConstantExpr>(expr);
612     if (!constExpr || constExpr.getValue() != 0)
613       newExprs.push_back(expr);
614   }
615   return AffineMap::get(getNumDims(), getNumSymbols(), newExprs, getContext());
616 }
617 
618 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
619   if (getNumSymbols() > 0)
620     return false;
621 
622   // Having more results than inputs means that results have duplicated dims or
623   // zeros that can't be mapped to input dims.
624   if (getNumResults() > getNumInputs())
625     return false;
626 
627   SmallVector<bool, 8> seen(getNumInputs(), false);
628   // A projected permutation can have, at most, only one instance of each input
629   // dimension in the result expressions. Zeros are allowed as long as the
630   // number of result expressions is lower or equal than the number of input
631   // expressions.
632   for (auto expr : getResults()) {
633     if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
634       if (seen[dim.getPosition()])
635         return false;
636       seen[dim.getPosition()] = true;
637     } else {
638       auto constExpr = dyn_cast<AffineConstantExpr>(expr);
639       if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
640         return false;
641     }
642   }
643 
644   // Results are either dims or zeros and zeros can be mapped to input dims.
645   return true;
646 }
647 
648 bool AffineMap::isPermutation() const {
649   if (getNumDims() != getNumResults())
650     return false;
651   return isProjectedPermutation();
652 }
653 
654 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const {
655   SmallVector<AffineExpr, 4> exprs;
656   exprs.reserve(resultPos.size());
657   for (auto idx : resultPos)
658     exprs.push_back(getResult(idx));
659   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
660 }
661 
662 AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const {
663   return AffineMap::get(getNumDims(), getNumSymbols(),
664                         getResults().slice(start, length), getContext());
665 }
666 
667 AffineMap AffineMap::getMajorSubMap(unsigned numResults) const {
668   if (numResults == 0)
669     return AffineMap();
670   if (numResults > getNumResults())
671     return *this;
672   return getSliceMap(0, numResults);
673 }
674 
675 AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
676   if (numResults == 0)
677     return AffineMap();
678   if (numResults > getNumResults())
679     return *this;
680   return getSliceMap(getNumResults() - numResults, numResults);
681 }
682 
683 /// Implementation detail to compress multiple affine maps with a compressionFun
684 /// that is expected to be either compressUnusedDims or compressUnusedSymbols.
685 /// The implementation keeps track of num dims and symbols across the different
686 /// affine maps.
687 static SmallVector<AffineMap> compressUnusedListImpl(
688     ArrayRef<AffineMap> maps,
689     llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
690   if (maps.empty())
691     return SmallVector<AffineMap>();
692   SmallVector<AffineExpr> allExprs;
693   allExprs.reserve(maps.size() * maps.front().getNumResults());
694   unsigned numDims = maps.front().getNumDims(),
695            numSymbols = maps.front().getNumSymbols();
696   for (auto m : maps) {
697     assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
698            "expected maps with same num dims and symbols");
699     llvm::append_range(allExprs, m.getResults());
700   }
701   AffineMap unifiedMap = compressionFun(
702       AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
703   unsigned unifiedNumDims = unifiedMap.getNumDims(),
704            unifiedNumSymbols = unifiedMap.getNumSymbols();
705   ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults();
706   SmallVector<AffineMap> res;
707   res.reserve(maps.size());
708   for (auto m : maps) {
709     res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols,
710                                  unifiedResults.take_front(m.getNumResults()),
711                                  m.getContext()));
712     unifiedResults = unifiedResults.drop_front(m.getNumResults());
713   }
714   return res;
715 }
716 
717 AffineMap mlir::compressDims(AffineMap map,
718                              const llvm::SmallBitVector &unusedDims) {
719   return projectDims(map, unusedDims, /*compressDimsFlag=*/true);
720 }
721 
722 AffineMap mlir::compressUnusedDims(AffineMap map) {
723   return compressDims(map, getUnusedDimsBitVector({map}));
724 }
725 
726 SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
727   return compressUnusedListImpl(
728       maps, [](AffineMap m) { return compressUnusedDims(m); });
729 }
730 
731 AffineMap mlir::compressSymbols(AffineMap map,
732                                 const llvm::SmallBitVector &unusedSymbols) {
733   return projectSymbols(map, unusedSymbols, /*compressSymbolsFlag=*/true);
734 }
735 
736 AffineMap mlir::compressUnusedSymbols(AffineMap map) {
737   return compressSymbols(map, getUnusedSymbolsBitVector({map}));
738 }
739 
740 SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
741   return compressUnusedListImpl(
742       maps, [](AffineMap m) { return compressUnusedSymbols(m); });
743 }
744 
745 AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
746                                       ArrayRef<OpFoldResult> operands,
747                                       SmallVector<Value> &remainingValues) {
748   SmallVector<AffineExpr> dimReplacements, symReplacements;
749   int64_t numDims = 0;
750   for (int64_t i = 0; i < map.getNumDims(); ++i) {
751     if (auto attr = operands[i].dyn_cast<Attribute>()) {
752       dimReplacements.push_back(
753           b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
754     } else {
755       dimReplacements.push_back(b.getAffineDimExpr(numDims++));
756       remainingValues.push_back(cast<Value>(operands[i]));
757     }
758   }
759   int64_t numSymbols = 0;
760   for (int64_t i = 0; i < map.getNumSymbols(); ++i) {
761     if (auto attr = operands[i + map.getNumDims()].dyn_cast<Attribute>()) {
762       symReplacements.push_back(
763           b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
764     } else {
765       symReplacements.push_back(b.getAffineSymbolExpr(numSymbols++));
766       remainingValues.push_back(cast<Value>(operands[i + map.getNumDims()]));
767     }
768   }
769   return map.replaceDimsAndSymbols(dimReplacements, symReplacements, numDims,
770                                    numSymbols);
771 }
772 
773 AffineMap mlir::simplifyAffineMap(AffineMap map) {
774   SmallVector<AffineExpr, 8> exprs;
775   for (auto e : map.getResults()) {
776     exprs.push_back(
777         simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
778   }
779   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
780                         map.getContext());
781 }
782 
783 AffineMap mlir::removeDuplicateExprs(AffineMap map) {
784   auto results = map.getResults();
785   SmallVector<AffineExpr, 4> uniqueExprs(results);
786   uniqueExprs.erase(llvm::unique(uniqueExprs), uniqueExprs.end());
787   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs,
788                         map.getContext());
789 }
790 
791 AffineMap mlir::inversePermutation(AffineMap map) {
792   if (map.isEmpty())
793     return map;
794   assert(map.getNumSymbols() == 0 && "expected map without symbols");
795   SmallVector<AffineExpr, 4> exprs(map.getNumDims());
796   for (const auto &en : llvm::enumerate(map.getResults())) {
797     auto expr = en.value();
798     // Skip non-permutations.
799     if (auto d = dyn_cast<AffineDimExpr>(expr)) {
800       if (exprs[d.getPosition()])
801         continue;
802       exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
803     }
804   }
805   SmallVector<AffineExpr, 4> seenExprs;
806   seenExprs.reserve(map.getNumDims());
807   for (auto expr : exprs)
808     if (expr)
809       seenExprs.push_back(expr);
810   if (seenExprs.size() != map.getNumInputs())
811     return AffineMap();
812   return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
813 }
814 
815 AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
816   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true));
817   MLIRContext *context = map.getContext();
818   AffineExpr zero = mlir::getAffineConstantExpr(0, context);
819   // Start with all the results as 0.
820   SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
821   for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
822     // Skip zeros from input map. 'exprs' is already initialized to zero.
823     if (auto constExpr = dyn_cast<AffineConstantExpr>(map.getResult(i))) {
824       assert(constExpr.getValue() == 0 &&
825              "Unexpected constant in projected permutation");
826       (void)constExpr;
827       continue;
828     }
829 
830     // Reverse each dimension existing in the original map result.
831     exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context);
832   }
833   return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
834 }
835 
836 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps,
837                                  MLIRContext *context) {
838   if (maps.empty())
839     return AffineMap::get(context);
840   unsigned numResults = 0, numDims = 0, numSymbols = 0;
841   for (auto m : maps)
842     numResults += m.getNumResults();
843   SmallVector<AffineExpr, 8> results;
844   results.reserve(numResults);
845   for (auto m : maps) {
846     for (auto res : m.getResults())
847       results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
848 
849     numSymbols += m.getNumSymbols();
850     numDims = std::max(m.getNumDims(), numDims);
851   }
852   return AffineMap::get(numDims, numSymbols, results, context);
853 }
854 
855 /// Common implementation to project out dimensions or symbols from an affine
856 /// map based on the template type.
857 /// Additionally, if 'compress' is true, the projected out dimensions or symbols
858 /// are also dropped from the resulting map.
859 template <typename AffineDimOrSymExpr>
860 static AffineMap projectCommonImpl(AffineMap map,
861                                    const llvm::SmallBitVector &toProject,
862                                    bool compress) {
863   static_assert(llvm::is_one_of<AffineDimOrSymExpr, AffineDimExpr,
864                                 AffineSymbolExpr>::value,
865                 "expected AffineDimExpr or AffineSymbolExpr");
866 
867   constexpr bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
868   int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols();
869   SmallVector<AffineExpr> replacements;
870   replacements.reserve(numDimOrSym);
871 
872   auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr;
873 
874   using replace_fn_ty =
875       std::function<AffineExpr(AffineExpr, ArrayRef<AffineExpr>)>;
876   replace_fn_ty replaceDims = [](AffineExpr e,
877                                  ArrayRef<AffineExpr> replacements) {
878     return e.replaceDims(replacements);
879   };
880   replace_fn_ty replaceSymbols = [](AffineExpr e,
881                                     ArrayRef<AffineExpr> replacements) {
882     return e.replaceSymbols(replacements);
883   };
884   replace_fn_ty replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols;
885 
886   MLIRContext *context = map.getContext();
887   int64_t newNumDimOrSym = 0;
888   for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
889     if (toProject.test(dimOrSym)) {
890       replacements.push_back(getAffineConstantExpr(0, context));
891       continue;
892     }
893     int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
894     replacements.push_back(createNewDimOrSym(newPos, context));
895   }
896   SmallVector<AffineExpr> resultExprs;
897   resultExprs.reserve(map.getNumResults());
898   for (auto e : map.getResults())
899     resultExprs.push_back(replaceNewDimOrSym(e, replacements));
900 
901   int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims();
902   int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols();
903   return AffineMap::get(numDims, numSyms, resultExprs, context);
904 }
905 
906 AffineMap mlir::projectDims(AffineMap map,
907                             const llvm::SmallBitVector &projectedDimensions,
908                             bool compressDimsFlag) {
909   return projectCommonImpl<AffineDimExpr>(map, projectedDimensions,
910                                           compressDimsFlag);
911 }
912 
913 AffineMap mlir::projectSymbols(AffineMap map,
914                                const llvm::SmallBitVector &projectedSymbols,
915                                bool compressSymbolsFlag) {
916   return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols,
917                                              compressSymbolsFlag);
918 }
919 
920 AffineMap mlir::getProjectedMap(AffineMap map,
921                                 const llvm::SmallBitVector &projectedDimensions,
922                                 bool compressDimsFlag,
923                                 bool compressSymbolsFlag) {
924   map = projectDims(map, projectedDimensions, compressDimsFlag);
925   if (compressSymbolsFlag)
926     map = compressUnusedSymbols(map);
927   return map;
928 }
929 
930 llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
931   unsigned numDims = maps[0].getNumDims();
932   llvm::SmallBitVector numDimsBitVector(numDims, true);
933   for (AffineMap m : maps) {
934     for (unsigned i = 0; i < numDims; ++i) {
935       if (m.isFunctionOfDim(i))
936         numDimsBitVector.reset(i);
937     }
938   }
939   return numDimsBitVector;
940 }
941 
942 llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) {
943   unsigned numSymbols = maps[0].getNumSymbols();
944   llvm::SmallBitVector numSymbolsBitVector(numSymbols, true);
945   for (AffineMap m : maps) {
946     for (unsigned i = 0; i < numSymbols; ++i) {
947       if (m.isFunctionOfSymbol(i))
948         numSymbolsBitVector.reset(i);
949     }
950   }
951   return numSymbolsBitVector;
952 }
953 
954 AffineMap
955 mlir::expandDimsToRank(AffineMap map, int64_t rank,
956                        const llvm::SmallBitVector &projectedDimensions) {
957   auto id = AffineMap::getMultiDimIdentityMap(rank, map.getContext());
958   AffineMap proj = id.dropResults(projectedDimensions);
959   return map.compose(proj);
960 }
961 
962 //===----------------------------------------------------------------------===//
963 // MutableAffineMap.
964 //===----------------------------------------------------------------------===//
965 
966 MutableAffineMap::MutableAffineMap(AffineMap map)
967     : results(map.getResults()), numDims(map.getNumDims()),
968       numSymbols(map.getNumSymbols()), context(map.getContext()) {}
969 
970 void MutableAffineMap::reset(AffineMap map) {
971   results.clear();
972   numDims = map.getNumDims();
973   numSymbols = map.getNumSymbols();
974   context = map.getContext();
975   llvm::append_range(results, map.getResults());
976 }
977 
978 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
979   return results[idx].isMultipleOf(factor);
980 }
981 
982 // Simplifies the result affine expressions of this map. The expressions
983 // have to be pure for the simplification implemented.
984 void MutableAffineMap::simplify() {
985   // Simplify each of the results if possible.
986   // TODO: functional-style map
987   for (unsigned i = 0, e = getNumResults(); i < e; i++) {
988     results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
989   }
990 }
991 
992 AffineMap MutableAffineMap::getAffineMap() const {
993   return AffineMap::get(numDims, numSymbols, results, context);
994 }
995