//===- CodegenEnv.cpp - Code generation environment class ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenEnv.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Code generation environment helper functions
//===----------------------------------------------------------------------===//

/// Returns true if the tensor materializes uninitialized into the
/// computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<tensor::EmptyOp>() ||
         val.getDefiningOp<bufferization::AllocTensorOp>();
}

/// Sorts the dependent loops such that they are ordered in the same
/// sequence in which the loops will be generated.
static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
  std::sort(target.begin(), target.end(),
            [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
              assert(std::addressof(l) == std::addressof(r) || l != r);
              return l.first < r.first;
            });
}

//===----------------------------------------------------------------------===//
// Code generation environment constructor and general methods
//===----------------------------------------------------------------------===//

CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                       unsigned numTensors, unsigned numLoops, unsigned maxRank)
    : linalgOp(linop), sparseOptions(opts),
      latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
      sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
      redCustom(detail::kInvalidId), redValidLexInsert() {}

LogicalResult CodegenEnv::initTensorExp() {
  // Builds the tensor expression for the Linalg operation in SSA form.
  std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op());
  if (!optExp || !isAdmissibleTensorExp(*optExp))
    return failure();

  tensorExp = *optExp;
  return success();
}

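// Implementation note (summarized from the body below): startEmit() records
// the insertion chain when the kernel has a sparse output, sorts the
// dependent loops of every tensor level, and initializes the underlying
// loop emitter.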
void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
  assert(insChain == nullptr && "must only start emitting once");
  if (sparseOut) {
    insChain = sparseOut->get();
    latticeMerger.setHasSparseOut(true);
  }

  // Sort the related loop arrays such that they are in the same order as
  // they appear in the topological order.
  // TODO: since we only handle affine addition for slice-based codegen, and
  // addition is associative, the order in which we evaluate the expression
  // does not matter. However, to support multiplication, the order of the
  // loop indices should match the evaluation order of the affine expression
  // AST.

  // Initialize loop emitter.
  SmallVector<Value> tensors; // input tensors passed to loop emitter
  for (OpOperand &t : linalgOp->getOpOperands()) {
    tensors.push_back(t.get());
    const TensorId tid = makeTensorId(t.getOperandNumber());
    const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
    const auto enc = getSparseTensorEncoding(t.get().getType());
    (void)enc;
    assert(!enc || lvlRank == enc.getLvlRank());
    for (Level lvl = 0; lvl < lvlRank; lvl++)
      sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
  }
  loopEmitter.initialize(
      tensors,
      StringAttr::get(linalgOp.getContext(),
                      linalg::GenericOp::getOperationName()),
      /*hasOutput=*/true,
      /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
      // TODO: compute the map and pass it to the loop emitter directly
      // instead of passing in a callback.
      /*dependentLvlGetter=*/
      [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
        return merger().getDependentLoops(t, lvl);
      },
      emitStrategy);
}

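// Implementation note: the SSA values that must live across a loop boundary
// (the reduction value, the valid-lex-insert flag, the expansion count, and
// the insertion chain, in that order when present) are packed into `params`,
// handed to the callback, and written back in the same order afterwards.
// A hypothetical call site (illustrative sketch only, not code from this
// file):
//
//   env.genLoopBoundary([&](MutableArrayRef<Value> params) {
//     auto forOp = builder.create<scf::ForOp>(loc, lo, hi, step, params);
//     // ... move the insertion point into the loop body and remap params
//     // to the loop's region iter arguments before returning ...
//     return std::optional<Operation *>(forOp.getOperation());
//   });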
std::optional<Operation *> CodegenEnv::genLoopBoundary(
    function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
        callback) {
  SmallVector<Value> params;
  if (isReduc()) {
    params.push_back(redVal);
    if (isValidLexInsert())
      params.push_back(redValidLexInsert);
  } else {
    assert(!isValidLexInsert());
  }
  if (isExpand())
    params.push_back(expCount);
  if (insChain != nullptr)
    params.push_back(insChain);
  auto r = callback(params); // may update parameters
  unsigned i = 0;
  if (isReduc()) {
    updateReduc(params[i++]);
    if (isValidLexInsert())
      updateValidLexInsert(params[i++]);
  }
  if (isExpand())
    updateExpandCount(params[i++]);
  if (insChain != nullptr)
    updateInsertionChain(params[i]);
  return r;
}

//===----------------------------------------------------------------------===//
// Code generation environment verify functions.
//===----------------------------------------------------------------------===//

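// Implementation note: three kinds of output tensors are accepted, as
// detailed in the comments below: (1) an all-dense (non-annotated) output,
// (2) a "simply dynamic" sparse output whose nonzero structure is preserved,
// and (3) a materializing sparse output that admits insertions in
// lexicographic index order.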
bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
  // We reject any expression that makes a reduction from `-outTensor`, as
  // those expressions create a dependency between the current iteration (i)
  // and the previous iteration (i-1). It would require iterating over the
  // whole coordinate space, which prevents exploiting sparsity for faster
  // code.
  for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) {
    if (it == utils::IteratorType::reduction) {
      if (latticeMerger.hasNegateOnOut(exp))
        return false;
      break;
    }
  }

  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
  const TensorId tensor = makeTensorId(lhs->getOperandNumber());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (getSparseTensorType(lhs->get()).isAllDense())
    return true;

  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without a special environment.
  if (latticeMerger.isSingleCondition(tensor, exp))
    return true;

  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
  sparseOut = lhs;

  // Find the outermost parallel nest to determine whether compress/expand is
  // needed.
  outerParNest = 0;
  const auto iteratorTypes = linalgOp.getIteratorTypesArray();
  for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
    if (linalg::isReductionIterator(iteratorTypes[i]))
      break; // terminate at first reduction
    outerParNest++;
  }

  // An inadmissible kernel should have already been rejected by the previous
  // pass during loop scheduling.
  assert(static_cast<int64_t>(outerParNest) >=
         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
  return isMaterializing(lhs->get());
}

//===----------------------------------------------------------------------===//
// Code generation environment topological sort methods
//===----------------------------------------------------------------------===//

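// Returns the induction variable of loop `i`, as maintained by the
// underlying loop emitter.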
Value CodegenEnv::getLoopVar(LoopId i) const {
  return loopEmitter.getLoopIV(i);
}

//===----------------------------------------------------------------------===//
// Code generation environment sparse tensor output and expansion methods
//===----------------------------------------------------------------------===//

void CodegenEnv::updateInsertionChain(Value chain) {
  assert(sparseOut != nullptr && insChain != nullptr);
  insChain = chain;
}

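// Access pattern expansion applies only when `o` is the sparse output and
// the queried loop `n` equals the outermost parallel nest depth, which must
// itself equal the output tensor rank minus one.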
bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const {
  return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) &&
         outerParNest == n;
}

void CodegenEnv::startExpand(Value values, Value filled, Value added,
                             Value count) {
  assert(sparseOut != nullptr && expValues == nullptr);
  expValues = values;
  expFilled = filled;
  expAdded = added;
  expCount = count;
}

void CodegenEnv::updateExpandCount(Value count) {
  assert(sparseOut != nullptr && expValues != nullptr);
  expCount = count;
}

void CodegenEnv::endExpand() {
  assert(sparseOut != nullptr && expValues != nullptr);
  expValues = expFilled = expAdded = expCount = Value();
}

//===----------------------------------------------------------------------===//
// Code generation environment reduction methods
//===----------------------------------------------------------------------===//

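// The scalarized reduction value is maintained as an SSA chain: startReduc()
// seeds it, updateReduc() replaces it as loop boundaries are crossed (see
// genLoopBoundary() above), and endReduc() returns the final value while
// clearing the bookkeeping in the lattice merger.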
void CodegenEnv::startReduc(ExprId exp, Value val) {
  assert(!isReduc() && exp != detail::kInvalidId && val);
  redExp = exp;
  redVal = val;
  latticeMerger.setExprValue(exp, val);
}

void CodegenEnv::updateReduc(Value val) {
  assert(isReduc() && val);
  redVal = val;
  latticeMerger.clearExprValue(redExp);
  latticeMerger.setExprValue(redExp, val);
}

Value CodegenEnv::endReduc() {
  assert(isReduc());
  Value val = redVal;
  redVal = Value(); // end the reduction chain
  latticeMerger.clearExprValue(redExp);
  redExp = detail::kInvalidId;
  return val;
}

void CodegenEnv::startValidLexInsert(Value val) {
  assert(!isValidLexInsert() && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::updateValidLexInsert(Value val) {
  assert(redValidLexInsert && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::endValidLexInsert() {
  assert(isValidLexInsert() && !isReduc());
  redValidLexInsert = Value();
}

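// Custom reductions are tracked by the expression id of the enclosing
// sparse_tensor::ReduceOp, whose `identity` operand is retrieved via
// getCustomRedId() below.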
void CodegenEnv::startCustomReduc(ExprId exp) {
  assert(!isCustomReduc() && exp != detail::kInvalidId);
  redCustom = exp;
}

Value CodegenEnv::getCustomRedId() const {
  assert(isCustomReduc());
  return cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}

void CodegenEnv::endCustomReduc() {
  assert(isCustomReduc());
  redCustom = detail::kInvalidId;
}