xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp (revision 4a653b4df5d84c4d2df8f6d4040ef46413ac3816)
1365777ecSAart Bik //===- CodegenEnv.cpp -  Code generation environment class ----------------===//
2365777ecSAart Bik //
3365777ecSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4365777ecSAart Bik // See https://llvm.org/LICENSE.txt for license information.
5365777ecSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6365777ecSAart Bik //
7365777ecSAart Bik //===----------------------------------------------------------------------===//
8365777ecSAart Bik 
9365777ecSAart Bik #include "CodegenEnv.h"
10365777ecSAart Bik 
11365777ecSAart Bik #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12365777ecSAart Bik #include "mlir/Dialect/Linalg/Utils/Utils.h"
13365777ecSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
14365777ecSAart Bik #include "mlir/Dialect/Tensor/IR/Tensor.h"
15365777ecSAart Bik 
16365777ecSAart Bik #include <optional>
17365777ecSAart Bik 
18365777ecSAart Bik using namespace mlir;
19365777ecSAart Bik using namespace mlir::sparse_tensor;
20365777ecSAart Bik 
21365777ecSAart Bik //===----------------------------------------------------------------------===//
22365777ecSAart Bik // Code generation environment helper functions
23365777ecSAart Bik //===----------------------------------------------------------------------===//
24365777ecSAart Bik 
25365777ecSAart Bik /// Returns true if tensor materializes uninitialized into the computation.
isMaterializing(Value val)26365777ecSAart Bik static bool isMaterializing(Value val) {
27365777ecSAart Bik   return val.getDefiningOp<tensor::EmptyOp>() ||
28365777ecSAart Bik          val.getDefiningOp<bufferization::AllocTensorOp>();
29365777ecSAart Bik }
30365777ecSAart Bik 
31365777ecSAart Bik /// Sorts the dependent loops such that it is ordered in the same sequence in
32365777ecSAart Bik /// which loops will be generated.
sortDependentLoops(std::vector<LoopCoeffPair> & target)33365777ecSAart Bik static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
34365777ecSAart Bik   std::sort(target.begin(), target.end(),
35365777ecSAart Bik             [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
36365777ecSAart Bik               assert(std::addressof(l) == std::addressof(r) || l != r);
37365777ecSAart Bik               return l.first < r.first;
38365777ecSAart Bik             });
39365777ecSAart Bik }
40365777ecSAart Bik //===----------------------------------------------------------------------===//
41365777ecSAart Bik // Code generation environment constructor and general methods
42365777ecSAart Bik //===----------------------------------------------------------------------===//
43365777ecSAart Bik 
CodegenEnv(linalg::GenericOp linop,SparsificationOptions opts,unsigned numTensors,unsigned numLoops,unsigned maxRank)44365777ecSAart Bik CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
45365777ecSAart Bik                        unsigned numTensors, unsigned numLoops, unsigned maxRank)
46365777ecSAart Bik     : linalgOp(linop), sparseOptions(opts),
47365777ecSAart Bik       latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
48365777ecSAart Bik       sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
49365777ecSAart Bik       expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
50365777ecSAart Bik       redCustom(detail::kInvalidId), redValidLexInsert() {}
51365777ecSAart Bik 
initTensorExp()52365777ecSAart Bik LogicalResult CodegenEnv::initTensorExp() {
53365777ecSAart Bik   // Builds the tensor expression for the Linalg operation in SSA form.
54365777ecSAart Bik   std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op());
55365777ecSAart Bik   if (!optExp || !isAdmissibleTensorExp(*optExp))
56365777ecSAart Bik     return failure();
57365777ecSAart Bik 
58365777ecSAart Bik   tensorExp = *optExp;
59365777ecSAart Bik   return success();
60365777ecSAart Bik }
61365777ecSAart Bik 
startEmit(SparseEmitStrategy emitStrategy)62*4a653b4dSPeiming Liu void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
63365777ecSAart Bik   assert(insChain == nullptr && "must only start emitting once");
64365777ecSAart Bik   if (sparseOut) {
65365777ecSAart Bik     insChain = sparseOut->get();
66365777ecSAart Bik     latticeMerger.setHasSparseOut(true);
67365777ecSAart Bik   }
68365777ecSAart Bik 
69365777ecSAart Bik   // Sort the related loop array such that they are in the same order as they
70365777ecSAart Bik   // appears on the topoOrder.
71365777ecSAart Bik   // TODO: since we only handle affine addition for slice based codegen, and
72365777ecSAart Bik   // addition is assoicative, the order how we evaluate the expression does
73365777ecSAart Bik   // not matter. However, to support multiplication, the order of the loop
74365777ecSAart Bik   // index should match the evaluation order to the affine expression AST.
75365777ecSAart Bik 
76365777ecSAart Bik   // Initialize loop emitter.
77365777ecSAart Bik   SmallVector<Value> tensors; // input tensors passed to loop emitter
78365777ecSAart Bik   for (OpOperand &t : linalgOp->getOpOperands()) {
79365777ecSAart Bik     tensors.push_back(t.get());
80365777ecSAart Bik     const TensorId tid = makeTensorId(t.getOperandNumber());
81365777ecSAart Bik     const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
82365777ecSAart Bik     const auto enc = getSparseTensorEncoding(t.get().getType());
83365777ecSAart Bik     (void)enc;
84365777ecSAart Bik     assert(!enc || lvlRank == enc.getLvlRank());
85365777ecSAart Bik     for (Level lvl = 0; lvl < lvlRank; lvl++)
86365777ecSAart Bik       sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
87365777ecSAart Bik   }
88365777ecSAart Bik   loopEmitter.initialize(
89365777ecSAart Bik       tensors,
90365777ecSAart Bik       StringAttr::get(linalgOp.getContext(),
91365777ecSAart Bik                       linalg::GenericOp::getOperationName()),
92365777ecSAart Bik       /*hasOutput=*/true,
93365777ecSAart Bik       /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
94365777ecSAart Bik       // TODO: compute the map and pass it to loop emitter directly instead of
95365777ecSAart Bik       // passing in a callback.
96365777ecSAart Bik       /*dependentLvlGetter=*/
97cf4dd911SPeiming Liu       [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
98cf4dd911SPeiming Liu         return merger().getDependentLoops(t, lvl);
99*4a653b4dSPeiming Liu       },
100*4a653b4dSPeiming Liu       emitStrategy);
101365777ecSAart Bik }
102365777ecSAart Bik 
genLoopBoundary(function_ref<std::optional<Operation * > (MutableArrayRef<Value> parameters)> callback)103365777ecSAart Bik std::optional<Operation *> CodegenEnv::genLoopBoundary(
104365777ecSAart Bik     function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
105365777ecSAart Bik         callback) {
106365777ecSAart Bik   SmallVector<Value> params;
107365777ecSAart Bik   if (isReduc()) {
108365777ecSAart Bik     params.push_back(redVal);
109365777ecSAart Bik     if (isValidLexInsert())
110365777ecSAart Bik       params.push_back(redValidLexInsert);
111365777ecSAart Bik   } else {
112365777ecSAart Bik     assert(!isValidLexInsert());
113365777ecSAart Bik   }
114365777ecSAart Bik   if (isExpand())
115365777ecSAart Bik     params.push_back(expCount);
116365777ecSAart Bik   if (insChain != nullptr)
117365777ecSAart Bik     params.push_back(insChain);
118365777ecSAart Bik   auto r = callback(params); // may update parameters
119365777ecSAart Bik   unsigned i = 0;
120365777ecSAart Bik   if (isReduc()) {
121365777ecSAart Bik     updateReduc(params[i++]);
122365777ecSAart Bik     if (isValidLexInsert())
123365777ecSAart Bik       updateValidLexInsert(params[i++]);
124365777ecSAart Bik   }
125365777ecSAart Bik   if (isExpand())
126365777ecSAart Bik     updateExpandCount(params[i++]);
127365777ecSAart Bik   if (insChain != nullptr)
128365777ecSAart Bik     updateInsertionChain(params[i]);
129365777ecSAart Bik   return r;
130365777ecSAart Bik }
131365777ecSAart Bik 
132365777ecSAart Bik //===----------------------------------------------------------------------===//
133365777ecSAart Bik // Code generation environment verify functions.
134365777ecSAart Bik //===----------------------------------------------------------------------===//
135365777ecSAart Bik 
isAdmissibleTensorExp(ExprId exp)136365777ecSAart Bik bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
137365777ecSAart Bik   // We reject any expression that makes a reduction from `-outTensor`, as those
138365777ecSAart Bik   // expressions create a dependency between the current iteration (i) and the
139365777ecSAart Bik   // previous iteration (i-1). It would require iterating over the whole
140365777ecSAart Bik   // coordinate space, which prevent exploiting sparsity for faster code.
141365777ecSAart Bik   for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) {
142365777ecSAart Bik     if (it == utils::IteratorType::reduction) {
143365777ecSAart Bik       if (latticeMerger.hasNegateOnOut(exp))
144365777ecSAart Bik         return false;
145365777ecSAart Bik       break;
146365777ecSAart Bik     }
147365777ecSAart Bik   }
148365777ecSAart Bik 
149365777ecSAart Bik   OpOperand *lhs = linalgOp.getDpsInitOperand(0);
150365777ecSAart Bik   const TensorId tensor = makeTensorId(lhs->getOperandNumber());
151365777ecSAart Bik   // An non-annotated output tensor is assumed dense, and becomes a random
152365777ecSAart Bik   // access n-dim memref. Admissible since insertions cannot occur.
153365777ecSAart Bik   if (getSparseTensorType(lhs->get()).isAllDense())
154365777ecSAart Bik     return true;
155365777ecSAart Bik 
156365777ecSAart Bik   // A tensor expression with a sparse output tensor that changes its values
157365777ecSAart Bik   // but not its nonzero structure, an operation called "simply dynamic" in
158365777ecSAart Bik   // [Bik96,Ch9], is also admissible without special env.
159365777ecSAart Bik   if (latticeMerger.isSingleCondition(tensor, exp))
160365777ecSAart Bik     return true;
161365777ecSAart Bik 
162365777ecSAart Bik   // Accept "truly dynamic" if the output tensor materializes uninitialized
163365777ecSAart Bik   // into the computation and insertions occur in lexicographic index order.
164365777ecSAart Bik   sparseOut = lhs;
165365777ecSAart Bik 
166365777ecSAart Bik   // Find the outermost parallel nest to determine whether compress/expand is
167365777ecSAart Bik   // needed.
168365777ecSAart Bik   outerParNest = 0;
169365777ecSAart Bik   const auto iteratorTypes = linalgOp.getIteratorTypesArray();
170365777ecSAart Bik   for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
171365777ecSAart Bik     if (linalg::isReductionIterator(iteratorTypes[i]))
172365777ecSAart Bik       break; // terminate at first reduction
173365777ecSAart Bik     outerParNest++;
174365777ecSAart Bik   }
175365777ecSAart Bik 
176365777ecSAart Bik   // Inadmissible kernel should have already been rejected by the previous
177365777ecSAart Bik   // path during loop scheduling.
178365777ecSAart Bik   assert(static_cast<int64_t>(outerParNest) >=
179365777ecSAart Bik          linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
180365777ecSAart Bik   return isMaterializing(lhs->get());
181365777ecSAart Bik }
182365777ecSAart Bik 
183365777ecSAart Bik //===----------------------------------------------------------------------===//
184365777ecSAart Bik // Code generation environment topological sort methods
185365777ecSAart Bik //===----------------------------------------------------------------------===//
186365777ecSAart Bik 
getLoopVar(LoopId i) const187365777ecSAart Bik Value CodegenEnv::getLoopVar(LoopId i) const {
188365777ecSAart Bik   return loopEmitter.getLoopIV(i);
189365777ecSAart Bik }
190365777ecSAart Bik 
191365777ecSAart Bik //===----------------------------------------------------------------------===//
192365777ecSAart Bik // Code generation environment sparse tensor output and expansion methods
193365777ecSAart Bik //===----------------------------------------------------------------------===//
194365777ecSAart Bik 
updateInsertionChain(Value chain)195365777ecSAart Bik void CodegenEnv::updateInsertionChain(Value chain) {
196365777ecSAart Bik   assert(sparseOut != nullptr && insChain != nullptr);
197365777ecSAart Bik   insChain = chain;
198365777ecSAart Bik }
199365777ecSAart Bik 
atExpandLevel(OpOperand * o,unsigned rank,LoopId n) const200365777ecSAart Bik bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const {
201365777ecSAart Bik   return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) &&
202365777ecSAart Bik          outerParNest == n;
203365777ecSAart Bik }
204365777ecSAart Bik 
startExpand(Value values,Value filled,Value added,Value count)205365777ecSAart Bik void CodegenEnv::startExpand(Value values, Value filled, Value added,
206365777ecSAart Bik                              Value count) {
207365777ecSAart Bik   assert(sparseOut != nullptr && expValues == nullptr);
208365777ecSAart Bik   expValues = values;
209365777ecSAart Bik   expFilled = filled;
210365777ecSAart Bik   expAdded = added;
211365777ecSAart Bik   expCount = count;
212365777ecSAart Bik }
213365777ecSAart Bik 
updateExpandCount(Value count)214365777ecSAart Bik void CodegenEnv::updateExpandCount(Value count) {
215365777ecSAart Bik   assert(sparseOut != nullptr && expValues != nullptr);
216365777ecSAart Bik   expCount = count;
217365777ecSAart Bik }
218365777ecSAart Bik 
endExpand()219365777ecSAart Bik void CodegenEnv::endExpand() {
220365777ecSAart Bik   assert(sparseOut != nullptr && expValues != nullptr);
221365777ecSAart Bik   expValues = expFilled = expAdded = expCount = Value();
222365777ecSAart Bik }
223365777ecSAart Bik 
224365777ecSAart Bik //===----------------------------------------------------------------------===//
225365777ecSAart Bik // Code generation environment reduction methods
226365777ecSAart Bik //===----------------------------------------------------------------------===//
227365777ecSAart Bik 
startReduc(ExprId exp,Value val)228365777ecSAart Bik void CodegenEnv::startReduc(ExprId exp, Value val) {
229365777ecSAart Bik   assert(!isReduc() && exp != detail::kInvalidId && val);
230365777ecSAart Bik   redExp = exp;
231365777ecSAart Bik   redVal = val;
232365777ecSAart Bik   latticeMerger.setExprValue(exp, val);
233365777ecSAart Bik }
234365777ecSAart Bik 
updateReduc(Value val)235365777ecSAart Bik void CodegenEnv::updateReduc(Value val) {
236365777ecSAart Bik   assert(isReduc() && val);
237365777ecSAart Bik   redVal = val;
238365777ecSAart Bik   latticeMerger.clearExprValue(redExp);
239365777ecSAart Bik   latticeMerger.setExprValue(redExp, val);
240365777ecSAart Bik }
241365777ecSAart Bik 
endReduc()242365777ecSAart Bik Value CodegenEnv::endReduc() {
243365777ecSAart Bik   assert(isReduc());
244365777ecSAart Bik   Value val = redVal;
245365777ecSAart Bik   redVal = val;
246365777ecSAart Bik   latticeMerger.clearExprValue(redExp);
247365777ecSAart Bik   redExp = detail::kInvalidId;
248365777ecSAart Bik   return val;
249365777ecSAart Bik }
250365777ecSAart Bik 
startValidLexInsert(Value val)251365777ecSAart Bik void CodegenEnv::startValidLexInsert(Value val) {
252365777ecSAart Bik   assert(!isValidLexInsert() && isReduc() && val);
253365777ecSAart Bik   redValidLexInsert = val;
254365777ecSAart Bik }
255365777ecSAart Bik 
updateValidLexInsert(Value val)256365777ecSAart Bik void CodegenEnv::updateValidLexInsert(Value val) {
257365777ecSAart Bik   assert(redValidLexInsert && isReduc() && val);
258365777ecSAart Bik   redValidLexInsert = val;
259365777ecSAart Bik }
260365777ecSAart Bik 
endValidLexInsert()261365777ecSAart Bik void CodegenEnv::endValidLexInsert() {
262365777ecSAart Bik   assert(isValidLexInsert() && !isReduc());
263365777ecSAart Bik   redValidLexInsert = Value();
264365777ecSAart Bik }
265365777ecSAart Bik 
startCustomReduc(ExprId exp)266365777ecSAart Bik void CodegenEnv::startCustomReduc(ExprId exp) {
267365777ecSAart Bik   assert(!isCustomReduc() && exp != detail::kInvalidId);
268365777ecSAart Bik   redCustom = exp;
269365777ecSAart Bik }
270365777ecSAart Bik 
getCustomRedId() const271365777ecSAart Bik Value CodegenEnv::getCustomRedId() const {
272365777ecSAart Bik   assert(isCustomReduc());
273365777ecSAart Bik   return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
274365777ecSAart Bik }
275365777ecSAart Bik 
endCustomReduc()276365777ecSAart Bik void CodegenEnv::endCustomReduc() {
277365777ecSAart Bik   assert(isCustomReduc());
278365777ecSAart Bik   redCustom = detail::kInvalidId;
279365777ecSAart Bik }
280