xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (revision f607102a0d6be0e2aebc1bfaed2ed0a6ae020145)
1 //===- LoopEmitter.h --------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_
10 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_
11 
12 #include <vector>
13 
14 #include "SparseTensorIterator.h"
15 
16 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
20 #include "mlir/IR/PatternMatch.h"
21 
22 namespace mlir {
23 namespace sparse_tensor {
24 
25 // A compressed <tensor id, level> pair.
26 using TensorLevel = unsigned;
27 
//
// The LoopEmitter class manages sparse tensors and helps to generate
// loop structures to (co)-iterate sparse tensors.
//
// An example usage:
// To generate the following loops over T1<?x?> and T2<?x?>
//
// for i : TENSOR_1_0 {
//   for j : TENSOR_2_0 {
//     for k : TENSOR_1_1 {}
//     for k : TENSOR_2_1 {}
//   }
// }
//
// One can use
//
// LoopEmitter loopEmitter({T1, T2});
// loopEmitter.initializeLoopEmit();
// loopEmitter.enterLoopOverTensorAtLvl(T1, 0);
// loopEmitter.enterLoopOverTensorAtLvl(T2, 0);
// loopEmitter.enterLoopOverTensorAtLvl(T1, 1);
// loopEmitter.exitCurrentLoop();
// loopEmitter.enterLoopOverTensorAtLvl(T2, 1);
// loopEmitter.exitCurrentLoop(); // exit k
// loopEmitter.exitCurrentLoop(); // exit j
// loopEmitter.exitCurrentLoop(); // exit i
//
54 //
class LoopEmitter {
public:
  /// Optional callback function to setup dense output tensors when
  /// initializing the loop emitter (e.g., to fill a dense output with zeros).
  using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                           Value memref, Value tensor)>;

  /// Optional callback function to set the bound for the synthetic tensor,
  /// which essentially is the dense loop bound.
  using SynTensorBoundSetter =
      function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;

  // Map from [tid, lvl] to a list of dependent [LoopId, coefficient] pairs
  // for subscript expressions on sparse tensors.
  //
  // E.g., for affine index (2 * d0 + d1), it depends on loop d0 and d1 (for
  // affine expression reduction) and uses 2 and 1 for coefficients on d0, d1
  // respectively. If the list is empty, it means that there is no affine
  // expression on the input [tid, lvl].
  //
  // NOTE: LoopEmitter assumes that the loop id is consistent with the loop
  // order, i.e., loop `d0` will be generated before loop `d1`.
  using DependentLvlGetter =
      function_ref<std::vector<std::pair<LoopId, unsigned>>(TensorId, Level)>;

  LoopEmitter() = default;

  /// Takes an array of input tensors, which the generated loops will
  /// iterate over.  Each tensor is given a `TensorId` (numerically equal
  /// to the position of that tensor `Value` in the array).  Setting
  /// `isSparseOut` indicates that the sparse output tensor is empty,
  /// so the loop emitter will generate loops over it according to the
  /// level-sizes.
  void
  initialize(ValueRange tensors, StringAttr loopTag = nullptr,
             bool hasOutput = false, bool isSparseOut = false,
             unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
             SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);

  /// Constructor variant of `initialize()`; see that method for the meaning
  /// of the arguments.
  explicit LoopEmitter(
      ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
      bool isSparseOut = false, unsigned numLoops = 0,
      DependentLvlGetter getter = nullptr,
      SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);

  /// Starts a loop emitting session by generating all the buffers needed
  /// for iterating over the tensors.
  void initializeLoopEmit(OpBuilder &builder, Location loc,
                          OutputUpdater updater = nullptr,
                          SynTensorBoundSetter synSetter = nullptr);

  /// Generates code to compute an affine expression whose variables are
  /// `LoopId`s (i.e., `a.cast<AffineDimExpr>().getPosition()` is a valid
  /// `LoopId`).
  Value genAffine(OpBuilder &builder, Location loc, AffineExpr a);

  /// Enters a new loop sequence; the loops within the same sequence start
  /// from the break points of the previous loop instead of starting over
  /// from 0. E.g.,
  /// {
  ///   // loop sequence start.
  ///   p0 = while(xxx)
  ///     ...
  ///     break p0
  ///
  ///   // Starts loop from p0
  ///   for (i = p0; i < end; i++)
  ///     ...
  ///   // loop sequence end.
  /// }
  void enterNewLoopSeq(OpBuilder &builder, Location loc,
                       ArrayRef<TensorLevel> tidLvls);

  /// Exits the current loop sequence, this will reset universal index to 0.
  void exitCurrentLoopSeq(OpBuilder &builder, Location loc);

  /// Emits the address for a dense level based on the value evaluated by the
  /// provided affine expression.
  void locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                TensorLevel tidLvl, AffineExpr lvlExpr);

  // TODO: Get rid of `lvls` in the argument list? Track the level we
  // are currently at internally. Then it would be enterNextLvlForTensor.
  // Still need a way to specify the lvl for non-annotated tensors though,
  // as those can be accessed out of order.
  //
  /// Emits a co-iteration loop over a set of tensors.
  /// Emits loop over tensor_tid_lvl, it assumes that loops between
  /// tensor_tid_[0, lvl - 1] have already been generated.
  /// The function will also perform in-place update on the `reduc` vector to
  /// return the reduction variable used inside the generated loop.
  Operation *enterCoIterationOverTensorsAtLvls(
      OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
      unsigned numCases, MutableArrayRef<Value> reduc = {},
      bool isParallel = false, bool needsUniv = false);

  /// Enters the case region identified by `caseBit`/`caseIdx` of the
  /// current co-iteration construct; returns the region holding the case
  /// body. `reduc` is updated in place with the case-local reduction values.
  Region *enterCurrentCoIterationCase(OpBuilder &builder, Location loc,
                                      I64BitSet caseBit, unsigned caseIdx,
                                      MutableArrayRef<Value> reduc);

  /// Generates code to exit the current loop (e.g., generates yields, forwards
  /// loop induction variables, etc).
  void exitCurrentLoop(RewriterBase &rewriter, Location loc,
                       MutableArrayRef<Value> reduc = {});

  /// Gets the range of values for all induction variables.
  auto getLoopIVsRange() const {
    return llvm::map_range(loopStack, [](const LoopInfo &li) { return li.iv; });
  }

  /// Returns the loop induction variables for all loops in the current
  /// loop-stack, outermost first.
  SmallVector<Value> getLoopIVs() const {
    return llvm::to_vector(getLoopIVsRange());
  }

  /// Gets the current depth of the loop-stack (i.e., the number of loops
  /// currently alive).
  LoopId getCurrentDepth() const { return llvm::range_size(getLoopIVsRange()); }

  /// Gets the loop induction variable for the given loop, or a null `Value`
  /// if `n` is out of range of the current loop-stack.
  Value getLoopIV(LoopId n) const {
    if (n >= getCurrentDepth())
      return Value();
    auto it = getLoopIVsRange().begin();
    std::advance(it, n);
    return *it;
  }

  /// Gets the total number of manifest tensors (excluding the synthetic
  /// tensor).
  unsigned getNumManifestTensors() const { return tensors.size(); }

  /// Gets the total number of tensors that loopEmitter is operating on.
  unsigned getNumTensors() const {
    // Manifest tensors with one synthetic tensor at the end.
    return getNumManifestTensors() + 1;
  }

  /// Gets the TensorId for the synthetic tensor (one past the last manifest
  /// tensor).
  TensorId getSynTensorId() const { return tensors.size(); }

  /// Gets the TensorId for the output tensor (by convention the last
  /// manifest tensor); only valid when `hasOutput` is set.
  TensorId getOutTensorId() const {
    assert(hasOutput);
    return getNumManifestTensors() - 1;
  }

  /// Compresses a TensorId and Level into a TensorLevel.
  TensorLevel makeTensorLevel(TensorId t, Level l) const {
    return l * getNumTensors() + t;
  }

  /// De-compresses a TensorLevel back to a pair of TensorId and Level.
  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
    unsigned nt = getNumTensors();
    return std::make_pair(tidLvl % nt, tidLvl / nt);
  }

  /// Converts a range of TensorLevel to a range of std::pair<TensorId, Level>
  template <class ContainerTy>
  auto unpackTensorLevelRange(ContainerTy &&c) const {
    using EltTy = decltype(*c.begin());
    static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>, TensorLevel>,
                  "Must be unpacking a TensorLevel range");
    return llvm::map_range(std::forward<ContainerTy>(c), [this](EltTy tl) {
      return this->unpackTensorLevel(tl);
    });
  }

  ///
  /// Getters.
  ///
  SmallVector<Value> getValPosits(TensorId tid) const {
    // Returns the iterator if we are generating sparse (co)iterate-based loops.
    if (emitStrategy == SparseEmitStrategy::kSparseIterator)
      return {spIterVals[tid].back()};

    // Returns {[batch coords], last-level position}.
    SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
    Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
    batchCrds.push_back(lastLvlPos);
    return batchCrds;
  };
  /// Returns the coordinate of the current iterator on `tensors[tid]` at
  /// level `lvl`.
  Value getCoord(TensorId tid, Level lvl) const {
    return getCurIterator(tid, lvl).getCrd();
  };
  /// Returns the value buffers (one per tensor) backing the loop emitter.
  const std::vector<Value> &getValBuffer() const { return valBuffer; };

  /// Name of the attribute used to tag loops generated by the loop emitter;
  /// see `loopTag` below.
  constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
    return llvm::StringLiteral("Emitted from");
  }

private:
  ///
  /// Structure definitions that hold different kinds of loops information.
  ///

  // LoopInfo stores information of a loop generated by LoopEmitter. E.g.,
  // the set of tensors levels that the loop is iterating over.
  struct LoopInfo final {
    LoopInfo(ArrayRef<TensorLevel> tidLvls, Operation *loop, Block *userBlock,
             Value iv, StringAttr loopTag)
        : tidLvls(tidLvls), loop(loop), userCodeBlock(userBlock), iv(iv) {
      // Attaches a special tag to the loop-emitter-generated loop.
      if (loopTag)
        loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
    }
    // The set of <tensor, lvl>, with *only* trivial index expressions, that are
    // used as the condition for the generated loop. Extra information is
    // required for levels with non-trivial index expressions, which is
    // maintained by the slice-driven fields below (e.g., `sliceMeta`;
    // NOTE(review): original comment referenced a `sliceDrivenInfo` array
    // that no longer exists in this header).
    const llvm::SmallVector<TensorLevel> tidLvls;
    Operation *loop;            // the loop operation
    Block *const userCodeBlock; // the block holding users' generated code.
    Value iv;                   // the induction variable for the loop
  };

  /// Splits the iterators for `tidLvls` into random-access iterators
  /// (`raIters`) and sparse iterators (`spIters`).
  void categorizeIterators(ArrayRef<TensorLevel> tidLvls,
                           SmallVectorImpl<SparseIterator *> &raIters,
                           SmallVectorImpl<SparseIterator *> &spIters);
  ///
  /// LoopEmitter internal helper functions.
  ///

  using LoopBodyBuilder = llvm::function_ref<void(OpBuilder &, Location, Value,
                                                  MutableArrayRef<Value>)>;

  /// Whether the list of sparse conditions should be iterated by a for loop.
  bool shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters);

  /// Generates instructions to compute the coordinate of tensors[tid][lvl]
  /// under the current loop context.  The final argument is the
  /// collapsed-output level, whereas this function handles converting
  /// that to the uncollapsed-input level
  Value genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
                     Level dstLvl);

  /// Whether `tid` refers to the synthetic tensor.
  bool isSynTensor(TensorId tid) const { return tid == getSynTensorId(); }

  /// Whether `tid` refers to the output tensor.
  bool isOutputTensor(TensorId tid) const {
    return hasOutput && tid == getOutTensorId();
  }

  /// Whether `tid` refers to a sparse output tensor.
  bool isSparseOutput(TensorId tid) const {
    return isOutputTensor(tid) && isSparseOut;
  }

  /// Whether [tid, lvl] is a valid tensor-level pair for this emitter.
  bool isValidLevel(TensorId tid, Level lvl) const {
    return tid < lvls.size() && lvl < lvls[tid].size();
  }

  /// Prepares loop for iterating over `tensor[lvl]`, under the assumption
  /// that `tensor[0...lvl-1]` loops have already been set up.
  void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                  TensorId tid, Level lvl);

  /// Emits a for loop to iterate over a tensor level with the provided
  /// lower bound `lo` and upper bound `hi`. Apart from iterating just
  /// single tensor level, for loops can be used for slice-driven loop on
  /// dense level too.
  /// Returns a pair: the loop generated and the value for the induction
  /// variable.
  std::pair<Operation *, Value>
  emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                             SparseIterator &iter, MutableArrayRef<Value> reduc,
                             bool isParallel);

  /// Emits a while loop to co-iterate over a list of sparse condition, or
  /// (complex) single sparse condition that can not be handled by for loop
  /// (e.g., index reduction loop).
  /// Returns a pair: the loop generated and the value for the induction
  /// variable (which is the minimum coordinate of all the tensor that being
  /// iterated).
  std::pair<Operation *, Value>
  emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc,
                                 ArrayRef<SparseIterator *> iters,
                                 MutableArrayRef<Value> reduc, bool needsUniv);

  /// Exits a for loop, returns the reduction results, e.g.,
  /// For sequential for loops:
  /// %ret = for () {
  ///   ...
  ///   %val = addi %args, %c
  ///   yield %val
  /// }
  /// For parallel loops, the following generated code by users:
  /// %ret = parallel () init(%args) {
  ///   ...
  ///   %val = op %args, %c
  /// }
  /// will be transformed into
  /// %ret = parallel () init(%args) {
  ///   ...
  ///   scf.reduce(%c) bb0(%0, %1){
  ///     %val = op %0, %1
  ///     scf.reduce.return %val
  ///   }
  /// }
  /// NOTE: only one instruction will be moved into reduce block,
  /// transformation will fail if multiple instructions are used to compute
  /// the reduction value. Return %ret to user, while %val is provided by
  /// users (`reduc`).
  void exitForLoop(RewriterBase &rewriter, Location loc,
                   MutableArrayRef<Value> reduc);

  /// Exits a while loop, returns the reduction results.
  void exitWhileLoop(OpBuilder &builder, Location loc,
                     MutableArrayRef<Value> reduc);

  //
  // Slice-driven loop related methods.
  //

  /// Initializes the subsection iterators used for slice-driven loops.
  void initSubSectIterator(OpBuilder &builder, Location loc);

  /// Get the reduced number of constraints on tensor[tid][lvl].
  unsigned redDepOnLevel(TensorId tid, Level lvl) const {
    return levelReducedDep[tid][lvl];
  };

  /// Returns the iterator currently in use for tensor[tid][lvl]: the last
  /// (innermost) iterator when the level has no dependent-level map entries,
  /// otherwise the iterator selected by the reduced-dependency count.
  SparseIterator &getCurIterator(TensorId tid, Level lvl) const {
    if (dependentLvlMap[tid][lvl].empty())
      return *iters[tid][lvl].back();

    assert(redDepOnLevel(tid, lvl) >= 1);
    return *iters[tid][lvl][redDepOnLevel(tid, lvl) - 1];
  }

  /// Creates the iterator for tensor[tid] at level `l`.
  std::unique_ptr<SparseIterator>
  makeLevelIterator(OpBuilder &builder, Location loc, TensorId tid, Level l);

  /// An optional string attribute that should be attached to the loop
  /// generated by the loop emitter; it might help subsequent passes to
  /// identify loops that operate on sparse tensors more easily.
  StringAttr loopTag;
  /// Whether the loop emitter needs to treat the last tensor as the output
  /// tensor.
  bool hasOutput;
  /// Whether the output tensor is sparse (only meaningful with `hasOutput`).
  bool isSparseOut;
  /// The strategy used when emitting sparse loops.
  SparseEmitStrategy emitStrategy;

  //
  // Fields which have `numTensor` many entries.
  //

  /// Input and (optional) output tensors.
  std::vector<Value> tensors;
  /// Loop upper bounds, indexed by loop id.
  std::vector<Value> loopHighs;
  /// Per-tensor, per-level storage abstractions.
  std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
  /// Per-tensor, per-level list of iterators.
  std::vector<std::vector<std::vector<std::unique_ptr<SparseIterator>>>> iters;
  std::vector<Value> valBuffer; // to_value

  // Map from [tid, level] to a list of dependent [tidlevel, coefficient].
  // See comments for `DependentLvlGetter`.
  std::vector<std::vector<std::vector<std::pair<LoopId, unsigned>>>>
      dependentLvlMap;

  // The (size, stride) for each conceptual slice used for index reduction
  // loops.
  std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta;

  // The number of reduced dependencies on a tensor level so far.
  std::vector<std::vector<unsigned>> levelReducedDep;

  //
  // Fields which have at most `numLoops` many entries.
  //

  /// Loop Stack, stores the information of all the nested loops that are
  /// alive.
  std::vector<LoopInfo> loopStack;

  // Loop Sequence Stack, stores the universal index for the current loop
  // sequence, and a list of tensor levels that the loop sequence traverses.
  std::vector<std::pair<Value, std::vector<TensorLevel>>> loopSeqStack;

  //
  // EXPERIMENTAL:
  // Fields for generating sparse-iterator-based loop.
  //

  /// Per-tensor stack of values produced by sparse-iterate loops.
  std::vector<std::vector<Value>> spIterVals;
};
438 
//
// Utility functions to generate sparse loops.
//
442 
// Generates a while loop that co-iterates over a set of iterators.
444 std::pair<Operation *, Value> genCoIteration(OpBuilder &builder, Location loc,
445                                              ArrayRef<SparseIterator *> iters,
446                                              MutableArrayRef<Value> reduc,
447                                              Value uniIdx,
448                                              bool userReducFirst = false);
449 
450 } // namespace sparse_tensor
451 } // namespace mlir
452 
453 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_
454