//===- CodegenEnv.h - Code generation environment class ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines the code generation environment class.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_

#include "CodegenUtils.h"
#include "LoopEmitter.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

/// The code generation environment class aggregates a number of data
/// structures that are needed during the code generation phase of
/// sparsification. This environment simplifies passing around such
/// data during sparsification (rather than passing around all the
/// individual components where needed). Furthermore, it provides
/// convenience methods that keep implementation details transparent
/// to sparsification while asserting on internal consistency.
class CodegenEnv {
public:
  /// Constructs a code generation environment which can be
  /// passed around during sparsification for bookkeeping
  /// together with some consistency asserts.
  CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
             unsigned numTensors, unsigned numLoops, unsigned maxRank);

  //
  // General methods.
  //

  LogicalResult initTensorExp();
  ExprId getExprId() const { return tensorExp; }

  linalg::GenericOp op() const { return linalgOp; }
  const SparsificationOptions &options() const { return sparseOptions; }
  bool generatingSparseIterator() const {
    return sparseOptions.sparseEmitStrategy ==
           SparseEmitStrategy::kSparseIterator;
  }
  Merger &merger() { return latticeMerger; }
  LoopEmitter &emitter() { return loopEmitter; }

  void startEmit(SparseEmitStrategy emitStrategy);

  /// Generates loop boundary statements (entering/exiting loops). The
  /// callback receives the loop-carried parameters and may update them
  /// in place.
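  ///
  /// Example (a sketch; the lambda body and `loopOp` are hypothetical):
  ///
  ///   env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
  ///     // ... enter or exit the loop, updating `reduc` in place ...
  ///     return std::optional<Operation *>(loopOp);
  ///   });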
  std::optional<Operation *>
  genLoopBoundary(function_ref<
                  std::optional<Operation *>(MutableArrayRef<Value> parameters)>
                      callback);

  //
  // Merger delegates.
  //

  constexpr TensorId makeTensorId(unsigned t) const {
    return latticeMerger.makeTensorId(t);
  }
  constexpr LoopId makeLoopId(unsigned i) const {
    return latticeMerger.makeLoopId(i);
  }
  constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
    return latticeMerger.makeTensorLoopId(t, i);
  }
  const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }
  const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }
  ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }
  LevelType lt(TensorId t, LoopId i) const {
    return latticeMerger.getLvlType(t, i);
  }
  LevelType lt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }

  unsigned getLoopNum() const { return latticeMerger.getNumLoops(); }

  //
  // LoopEmitter delegates.
  //

  TensorLevel makeTensorLevel(TensorId t, Level l) const {
    // Make sure LoopEmitter, GenericOp, and Merger agree on the number of
    // tensors.
    assert(loopEmitter.getNumManifestTensors() == linalgOp->getNumOperands() &&
           loopEmitter.getNumTensors() == latticeMerger.getNumTensors() &&
           loopEmitter.getOutTensorId() == latticeMerger.getOutTensorID() &&
           loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
    return loopEmitter.makeTensorLevel(t, l);
  }
  TensorLevel makeTensorLevel(std::pair<TensorId, Level> tlPair) const {
    return makeTensorLevel(tlPair.first, tlPair.second);
  }
  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
    return loopEmitter.unpackTensorLevel(tl);
  }
  template <class ContainerTy>
  auto unpackTensorLevelRange(ContainerTy &&c) const {
    return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c));
  }
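
  // Round-trip sketch for the packed TensorLevel encoding (`t` and `l`
  // are hypothetical placeholders):
  //
  //   TensorLevel tl = env.makeTensorLevel(t, l);
  //   auto [tid, lvl] = env.unpackTensorLevel(tl); // tid == t, lvl == l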

  unsigned getCurrentDepth() const { return loopEmitter.getCurrentDepth(); }

  //
  // Code generation environment verify functions.
  //

  /// Returns true if the tensor expression is admissible for codegen.
  /// It also sets `sparseOut` if the output tensor is sparse.
  bool isAdmissibleTensorExp(ExprId e);

  /// Returns the induction variable for the given loop.
  Value getLoopVar(LoopId i) const;

  //
  // Sparse tensor output and expansion methods.
  //

  bool hasSparseOutput() const { return sparseOut != nullptr; }
  bool isSparseOutput(OpOperand *o) const { return sparseOut == o; }

  Value getInsertionChain() const { return insChain; }
  void updateInsertionChain(Value chain);

  bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const;
  void startExpand(Value values, Value filled, Value added, Value count);
  bool isExpand() const { return expValues != nullptr; }
  void updateExpandCount(Value count);
  Value getExpandValues() const { return expValues; }
  Value getExpandFilled() const { return expFilled; }
  Value getExpandAdded() const { return expAdded; }
  Value getExpandCount() const { return expCount; }
  void endExpand();
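
  // Typical access-pattern-expansion protocol (a sketch; the four SSA
  // values would come from a sparse_tensor.expand op in the generated IR):
  //
  //   env.startExpand(values, filled, added, count);
  //   ...                               // innermost loop fills the buffers
  //   env.updateExpandCount(newCount);  // e.g., after a compress step
  //   env.endExpand();                  // expansion done for this output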

  //
  // Reduction methods.
  //

  void startReduc(ExprId exp, Value val);
  bool isReduc() const { return redExp != detail::kInvalidId; }
  void updateReduc(Value val);
  Value getReduc() const { return redVal; }
  Value endReduc();
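
  // Typical scalarized-reduction protocol (a sketch):
  //
  //   env.startReduc(exp, initVal);  // enter the reduction with a seed value
  //   env.updateReduc(newVal);       // update per iteration in the loop body
  //   Value total = env.endReduc();  // exit, yielding the final value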

  void startValidLexInsert(Value val);
  bool isValidLexInsert() const { return redValidLexInsert != nullptr; }
  void updateValidLexInsert(Value val);
  Value getValidLexInsert() const { return redValidLexInsert; }
  void endValidLexInsert();

  void startCustomReduc(ExprId exp);
  bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
  Value getCustomRedId() const;
  void endCustomReduc();

private:
  // Linalg operation.
  linalg::GenericOp linalgOp;

  // Sparsification options.
  SparsificationOptions sparseOptions;

  // Merger helper class.
  Merger latticeMerger;

  // Loop emitter helper class.
  LoopEmitter loopEmitter;

  // Sparse tensor as output. Implemented either through direct injective
  // insertion in lexicographic index order or through access pattern
  // expansion in the innermost loop nest (`expValues` through `expCount`).
  OpOperand *sparseOut;
  // The count of outer non-filter loops, as defined by `isAdmissibleTopoOrder`.
  LoopId outerParNest;
  Value insChain;
  Value expValues;
  Value expFilled;
  Value expAdded;
  Value expCount;

  // Bookkeeping for reductions (the up-to-date value of the reduction, and
  // indices into the merger's expression tree). When the indices of a tensor
  // reduction expression are exhausted, all inner loops can use a scalarized
  // reduction.
  Value redVal;
  ExprId redExp;
  ExprId redCustom;

  // Bookkeeping for lex insertion during reductions. Holds the runtime boolean
  // value of whether any reduction occurred. This is only set during a
  // reduction and cleared once the reduction is finished.
  Value redValidLexInsert;

  // The root tensor expression of the kernel.
  ExprId tensorExp;
};

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_