//===- CodegenEnv.cpp - Code generation environment class ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenEnv.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Code generation environment helper functions
//===----------------------------------------------------------------------===//

/// Returns true if the tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<tensor::EmptyOp>() ||
         val.getDefiningOp<bufferization::AllocTensorOp>();
}

/// Sorts the dependent loops such that they are ordered in the same sequence
/// in which the loops will be generated.
static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
  std::sort(target.begin(), target.end(),
            [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
              assert(std::addressof(l) == std::addressof(r) || l != r);
              return l.first < r.first;
            });
}

//===----------------------------------------------------------------------===//
// Code generation environment constructor and general methods
//===----------------------------------------------------------------------===//

CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                       unsigned numTensors, unsigned numLoops, unsigned maxRank)
    : linalgOp(linop), sparseOptions(opts),
      latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
      sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
      redCustom(detail::kInvalidId), redValidLexInsert() {}

LogicalResult CodegenEnv::initTensorExp() {
  // Builds the tensor expression for the Linalg operation in SSA form.
  std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op());
  if (!optExp || !isAdmissibleTensorExp(*optExp))
    return failure();

  tensorExp = *optExp;
  return success();
}

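/// Starts the code emission phase: seeds the insertion chain when there is a
/// sparse output, sorts the dependent loops for every tensor level, and
/// initializes the loop emitter with the kernel's operands.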
void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
  assert(insChain == nullptr && "must only start emitting once");
  if (sparseOut) {
    insChain = sparseOut->get();
    latticeMerger.setHasSparseOut(true);
  }

  // Sort the related loop array such that the loops are in the same order as
  // they appear in the topoOrder.
  // TODO: since we only handle affine addition for slice-based codegen, and
  // addition is associative, the order in which we evaluate the expression
  // does not matter. However, to support multiplication, the order of the
  // loop indices should match the evaluation order of the affine expression
  // AST.

  // Initialize loop emitter.
  SmallVector<Value> tensors; // input tensors passed to loop emitter
  for (OpOperand &t : linalgOp->getOpOperands()) {
    tensors.push_back(t.get());
    const TensorId tid = makeTensorId(t.getOperandNumber());
    const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
    const auto enc = getSparseTensorEncoding(t.get().getType());
    (void)enc;
    assert(!enc || lvlRank == enc.getLvlRank());
    for (Level lvl = 0; lvl < lvlRank; lvl++)
      sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
  }
  loopEmitter.initialize(
      tensors,
      StringAttr::get(linalgOp.getContext(),
                      linalg::GenericOp::getOperationName()),
      /*hasOutput=*/true,
      /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
      // TODO: compute the map and pass it to loop emitter directly instead of
      // passing in a callback.
      /*dependentLvlGetter=*/
      [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
        return merger().getDependentLoops(t, lvl);
      },
      emitStrategy);
}

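/// Wraps the callback invocation at a loop boundary: collects the current
/// loop-carried values (reduction value, lexicographic-insert flag, expansion
/// count, and insertion chain) into a parameter array, invokes the callback
/// (which may update them), and writes the updated values back into the
/// environment.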
std::optional<Operation *> CodegenEnv::genLoopBoundary(
    function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
        callback) {
  SmallVector<Value> params;
  if (isReduc()) {
    params.push_back(redVal);
    if (isValidLexInsert())
      params.push_back(redValidLexInsert);
  } else {
    assert(!isValidLexInsert());
  }
  if (isExpand())
    params.push_back(expCount);
  if (insChain != nullptr)
    params.push_back(insChain);
  auto r = callback(params); // may update parameters
  unsigned i = 0;
  if (isReduc()) {
    updateReduc(params[i++]);
    if (isValidLexInsert())
      updateValidLexInsert(params[i++]);
  }
  if (isExpand())
    updateExpandCount(params[i++]);
  if (insChain != nullptr)
    updateInsertionChain(params[i]);
  return r;
}

//===----------------------------------------------------------------------===//
// Code generation environment verify functions.
//===----------------------------------------------------------------------===//

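/// Returns true if the tensor expression can be sparsified: reductions that
/// negate the output are rejected; otherwise the expression is admissible if
/// the output is all-dense, if it is "simply dynamic" (values change but not
/// the nonzero structure), or if the output materializes uninitialized so
/// that insertions can be generated (in which case the sparse output and the
/// outer parallel nest are recorded).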
bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
  // We reject any expression that makes a reduction from `-outTensor`, as
  // those expressions create a dependency between the current iteration (i)
  // and the previous iteration (i-1). It would require iterating over the
  // whole coordinate space, which prevents exploiting sparsity for faster
  // code.
  for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) {
    if (it == utils::IteratorType::reduction) {
      if (latticeMerger.hasNegateOnOut(exp))
        return false;
      break;
    }
  }

  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
  const TensorId tensor = makeTensorId(lhs->getOperandNumber());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (getSparseTensorType(lhs->get()).isAllDense())
    return true;

  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without a special environment.
  if (latticeMerger.isSingleCondition(tensor, exp))
    return true;

  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
  sparseOut = lhs;

  // Find the outermost parallel nest to determine whether compress/expand is
  // needed.
  outerParNest = 0;
  const auto iteratorTypes = linalgOp.getIteratorTypesArray();
  for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
    if (linalg::isReductionIterator(iteratorTypes[i]))
      break; // terminate at first reduction
    outerParNest++;
  }

  // Inadmissible kernels should have already been rejected by the previous
  // path during loop scheduling.
  assert(static_cast<int64_t>(outerParNest) >=
         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
  return isMaterializing(lhs->get());
}

//===----------------------------------------------------------------------===//
// Code generation environment topological sort methods
//===----------------------------------------------------------------------===//

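/// Returns the induction value of loop `i`, as tracked by the loop emitter.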
Value CodegenEnv::getLoopVar(LoopId i) const {
  return loopEmitter.getLoopIV(i);
}

//===----------------------------------------------------------------------===//
// Code generation environment sparse tensor output and expansion methods
//===----------------------------------------------------------------------===//

void CodegenEnv::updateInsertionChain(Value chain) {
  assert(sparseOut != nullptr && insChain != nullptr);
  insChain = chain;
}

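/// Returns true if the given operand is the sparse output and loop `n` is
/// where the expanded access pattern applies, i.e., the innermost loop of the
/// outer parallel nest (outerParNest == rank - 1 == n).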
bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const {
  return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) &&
         outerParNest == n;
}

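/// Enters the expanded access pattern for the sparse output, recording the
/// SSA values for the expanded values/filled/added buffers and the initial
/// count.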
void CodegenEnv::startExpand(Value values, Value filled, Value added,
                             Value count) {
  assert(sparseOut != nullptr && expValues == nullptr);
  expValues = values;
  expFilled = filled;
  expAdded = added;
  expCount = count;
}

void CodegenEnv::updateExpandCount(Value count) {
  assert(sparseOut != nullptr && expValues != nullptr);
  expCount = count;
}

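/// Exits the expanded access pattern by clearing all expansion-related values.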
void CodegenEnv::endExpand() {
  assert(sparseOut != nullptr && expValues != nullptr);
  expValues = expFilled = expAdded = expCount = Value();
}

//===----------------------------------------------------------------------===//
// Code generation environment reduction methods
//===----------------------------------------------------------------------===//

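/// Starts a reduction for the given tensor expression with the given initial
/// value, and registers that value with the lattice merger.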
void CodegenEnv::startReduc(ExprId exp, Value val) {
  assert(!isReduc() && exp != detail::kInvalidId && val);
  redExp = exp;
  redVal = val;
  latticeMerger.setExprValue(exp, val);
}

void CodegenEnv::updateReduc(Value val) {
  assert(isReduc() && val);
  redVal = val;
  latticeMerger.clearExprValue(redExp);
  latticeMerger.setExprValue(redExp, val);
}

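/// Ends the current reduction and returns its final value, clearing the
/// cached reduction value and expression.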
Value CodegenEnv::endReduc() {
  assert(isReduc());
  Value val = redVal;
  redVal = Value(); // clear the cached reduction value
  latticeMerger.clearExprValue(redExp);
  redExp = detail::kInvalidId;
  return val;
}

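/// Starts tracking the valid-lexicographic-insert flag, which is carried
/// alongside the current reduction.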
void CodegenEnv::startValidLexInsert(Value val) {
  assert(!isValidLexInsert() && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::updateValidLexInsert(Value val) {
  assert(redValidLexInsert && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::endValidLexInsert() {
  assert(isValidLexInsert() && !isReduc());
  redValidLexInsert = Value();
}

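/// Marks the given tensor expression as the active custom reduction.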
void CodegenEnv::startCustomReduc(ExprId exp) {
  assert(!isCustomReduc() && exp != detail::kInvalidId);
  redCustom = exp;
}

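/// Returns the identity value of the active custom reduction, taken from its
/// associated sparse_tensor.reduce operation.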
Value CodegenEnv::getCustomRedId() const {
  assert(isCustomReduc());
  return cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}

void CodegenEnv::endCustomReduc() {
  assert(isCustomReduc());
  redCustom = detail::kInvalidId;
}