//===- IterationGraphSorter.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "IterationGraphSorter.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// A helper class that visits an affine expression and tries to find an
/// AffineDimExpr whose corresponding iterator from a GenericOp matches the
/// desired iterator type. If there is no matching iterator type, the visitor
/// keeps the first DimExpr encountered in the expression.
class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
public:
  explicit AffineDimFinder(ArrayRef<utils::IteratorType> itTypes)
      : iterTypes(itTypes) {}

  /// Overrides the visit method from AffineExprVisitor.
  void visitDimExpr(AffineDimExpr expr) {
    if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()])
      pickedDim = expr;
  }

  /// Sets the desired iterator type that we want to pick.
  void setPickedIterType(utils::IteratorType iterType) {
    pickIterType = iterType;
  }

  /// Gets the desired AffineDimExpr.
  AffineDimExpr getDimExpr() const {
    return llvm::cast<AffineDimExpr>(pickedDim);
  }

  /// Walks the expression in post order to find the dim expr.
  void walkPostOrder(AffineExpr expr) {
    pickedDim = nullptr;
    AffineExprVisitor<AffineDimFinder>::walkPostOrder(expr);
  }

private:
  /// The picked AffineDimExpr after the visit.
  AffineExpr pickedDim;
  /// The iterator type that we want.
  utils::IteratorType pickIterType;
  /// The mapping between levels and iterator types.
  ArrayRef<utils::IteratorType> iterTypes;
};

/// Flattens an affine expression into a list of AffineDimExprs.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  // Overrides the visit method from AffineExprVisitor.
  void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); }
  SmallVector<AffineDimExpr> dims;
};

} // namespace

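/// Returns true if the two sort masks have any flag in common.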
inline static bool includesAny(SortMask mask1, SortMask mask2) {
  return static_cast<unsigned>(mask1) & static_cast<unsigned>(mask2);
}

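/// Returns true if the sort mask requests that dense inputs be included.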
inline static bool includesDenseInput(SortMask mask) {
  return includesAny(mask, SortMask::kIncludeDenseInput);
}

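/// Returns true if the sort mask requests that the dense output be included.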
inline static bool includesDenseOutput(SortMask mask) {
  return includesAny(mask, SortMask::kIncludeDenseOutput);
}

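/// Computes a topological order of the iteration graph with a worklist
/// algorithm that always prefers parallel loops over reduction loops, so
/// that reductions are scheduled as late as possible. Returns the order as
/// a permutation map, or an empty AffineMap if the graph contains a cycle.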
AffineMap IterationGraphSorter::topoSort() {
  // The sorted result will put the first reduction iterator at the latest
  // possible position.
  std::vector<unsigned> redIt; // reduction iterators with in-degree 0
  std::vector<unsigned> parIt; // parallel iterators with in-degree 0
  const unsigned numLoops = getNumLoops();
  for (unsigned i = 0; i < numLoops; i++) {
    if (inDegree[i] == 0) {
      if (iterTypes[i] == utils::IteratorType::reduction)
        redIt.push_back(i);
      else
        parIt.push_back(i);
    }
  }

  SmallVector<unsigned> loopOrder;
  while (!redIt.empty() || !parIt.empty()) {
    // We always prefer a parallel loop over a reduction loop because putting
    // a reduction loop early might make the loop sequence inadmissible.
    auto &it = !parIt.empty() ? parIt : redIt;
    auto src = it.back();
    loopOrder.push_back(src);
    it.pop_back();
    // Update in-degrees, and push any new 0-degree nodes onto the worklists.
    for (unsigned dst = 0; dst < numLoops; dst++) {
      if (itGraph[src][dst] && --inDegree[dst] == 0) {
        if (iterTypes[dst] == utils::IteratorType::reduction)
          redIt.push_back(dst);
        else
          parIt.push_back(dst);
      }
    }
  }

  // Return the topological sort on success.
  if (loopOrder.size() == numLoops)
    return AffineMap::getPermutationMap(loopOrder, out.getContext());

  // Cycle detected.
  return AffineMap();
}

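/// Creates an IterationGraphSorter for a demapped sparse linalg.generic op
/// with a single output, using the op's indexing maps and iterator types.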
IterationGraphSorter
IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
  // Must be a demapped sparse kernel.
  assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
         hasAnySparseOperandOrResult(genericOp) &&
         genericOp.getNumDpsInits() == 1);

  SmallVector<AffineMap> loopMap = genericOp.getIndexingMapsArray();
  SmallVector<Value> ins = genericOp.getDpsInputs();

  AffineMap outMap = loopMap.back();
  loopMap.pop_back();

  Value out = genericOp.getDpsInitOperand(0)->get();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();

  return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
                              std::move(iterTypes));
}

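/// Constructs the sorter from the per-tensor loop-to-level maps and the
/// loop iterator types, and allocates the iteration graph data structures.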
IterationGraphSorter::IterationGraphSorter(
    SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
    AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
    : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
      loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
  // The parameters have been moved into the members above, so the sanity
  // checks below must inspect the members.
  // One map per tensor.
  assert(this->loop2InsLvl.size() == this->ins.size());
  // All the affine maps have the same number of dimensions (loops).
  assert(llvm::all_equal(llvm::map_range(
      this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
  // The number of results of each map should match the rank of the tensor.
  assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
    auto [m, v] = mvPair;
    return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
  }));

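  // Allocate the adjacency matrix and the in-degree vector for all loops.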
  itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
  inDegree.resize(getNumLoops());
}

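/// Builds the iteration graph from the loop-to-level maps of the sparse
/// (and, if requested by the mask, dense) tensors, skipping the `ignored`
/// tensor, and returns a topological order of the loops (empty for cyclic).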
AffineMap IterationGraphSorter::sort(SortMask mask, Value ignored) {
  // Reset the adjacency matrix that represents the iteration graph.
  for (auto &row : itGraph)
    std::fill(row.begin(), row.end(), false);

  // Reset the in-degrees.
  std::fill(inDegree.begin(), inDegree.end(), 0);

  // Add the constraints for the loop-to-level maps of the inputs.
  for (auto [in, map] : llvm::zip(ins, loop2InsLvl)) {
    // Get the sparse tensor encoding, if present.
    const auto enc = getSparseTensorEncoding(in.getType());
    // Skip dense inputs when not requested, as well as the ignored tensor.
    if ((!enc && !includesDenseInput(mask)) || in == ignored)
      continue;
    addConstraints(in, map);
  }

  // Add the constraints for the output map.
  const auto enc = getSparseTensorEncoding(out.getType());
  if ((enc || includesDenseOutput(mask)) && out != ignored)
    addConstraints(out, loop2OutLvl);

  // Return the topological sort (empty for cyclic).
  return topoSort();
}

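/// Adds the ordering constraints implied by the loop-to-level map of a
/// single tensor to the iteration graph, so that the tensor's levels can be
/// generated in order.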
void IterationGraphSorter::addConstraints(Value t, AffineMap loop2LvlMap) {
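  // Adds an edge f -> t to the iteration graph (loop f must be generated
  // before loop t) and updates the in-degree, unless the edge already exists
  // or would form a self-loop.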
  auto addIterOrdering = [this](unsigned f, unsigned t) {
    if (!itGraph[f][t] && f != t) {
      itGraph[f][t] = true;
      inDegree[t]++;
    }
  };

  // Set up a reduction finder.
  AffineDimFinder finder(iterTypes);
  finder.setPickedIterType(utils::IteratorType::reduction);

  // To compute the iteration graph for tensor[d0 + d1 + d3, d4 + d5 + d6],
  // we require that there exist d_x \in {d0, d1, d3} and d_y \in {d4, d5, d6}
  // such that d_x > d_y and {d0, d1, d3} - d_x > {d4, d5, d6} - d_y.
  const Level lvlRank = loop2LvlMap.getNumResults();
  for (Level lvl = 1; lvl < lvlRank; lvl++) {
    const AffineExpr fa = loop2LvlMap.getResult(lvl - 1);
    const AffineExpr ta = loop2LvlMap.getResult(lvl);

    if (llvm::isa<AffineDimExpr>(fa) || llvm::isa<AffineDimExpr>(ta)) {
      // Special case when at least one loop2LvlExpr is a simple AffineDimExpr
      // (say, d0), in which case we require d0 > {d1, d2, ...} or
      // {d1, d2, ...} > d0.
      AffineDimCollector fCollector;
      fCollector.walkPostOrder(fa);
      AffineDimCollector tCollector;
      tCollector.walkPostOrder(ta);

      for (auto fd : fCollector.dims) {
        for (auto td : tCollector.dims) {
          const unsigned f = fd.getPosition();
          const unsigned t = td.getPosition();
          addIterOrdering(f, t);
        }
      }
      continue;
    }

    // When both loop2LvlExprs are compound, we pick an arbitrary reduction
    // loop from the lhs and rhs and use them as d_x and d_y.
    finder.walkPostOrder(fa);
    const AffineDimExpr fexp = finder.getDimExpr();
    const unsigned fldx = fexp.getPosition();

    finder.walkPostOrder(ta);
    const AffineDimExpr texp = finder.getDimExpr();
    const unsigned tldx = texp.getPosition();

    // d_x > d_y.
    addIterOrdering(fldx, tldx);

    AffineDimCollector fCollector;
    fCollector.walkPostOrder(fa);
    AffineDimCollector tCollector;
    tCollector.walkPostOrder(ta);

    // Make sure d_x and d_y come last.
    for (auto fd : fCollector.dims) {
      const unsigned f = fd.getPosition();
      addIterOrdering(f, fldx);
    }
    for (auto td : tCollector.dims) {
      const unsigned t = td.getPosition();
      addIterOrdering(t, tldx);
    }
    // {d0, d1, d3} - d_x > {d4, d5, d6} - d_y
    // This is to ensure that the affine expressions are reduced in sparse
    // tensor level ordering.
    for (auto fd : fCollector.dims) {
      const unsigned f = fd.getPosition();
      if (f == fldx) // skip d_x
        continue;
      for (auto td : tCollector.dims) {
        const unsigned t = td.getPosition();
        if (t == tldx) // skip d_y
          continue;
        addIterOrdering(f, t);
      }
    }
  }
}