//===- CodegenEnv.cpp - Code generation environment class ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenEnv.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Code generation environment helper functions
//===----------------------------------------------------------------------===//

/// Returns true if the tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<tensor::EmptyOp>() ||
         val.getDefiningOp<bufferization::AllocTensorOp>();
}

/// Sorts the dependent loops such that they are ordered in the same sequence
/// in which the loops will be generated.
static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
  std::sort(target.begin(), target.end(),
            [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
              assert(std::addressof(l) == std::addressof(r) || l != r);
              return l.first < r.first;
            });
}

//===----------------------------------------------------------------------===//
// Code generation environment constructor and general methods
//===----------------------------------------------------------------------===//

CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                       unsigned numTensors, unsigned numLoops, unsigned maxRank)
    : linalgOp(linop), sparseOptions(opts),
      latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
      sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
      redCustom(detail::kInvalidId), redValidLexInsert() {}

LogicalResult CodegenEnv::initTensorExp() {
  // Builds the tensor expression for the Linalg operation in SSA form.
  std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op());
  if (!optExp || !isAdmissibleTensorExp(*optExp))
    return failure();

  tensorExp = *optExp;
  return success();
}

void CodegenEnv::startEmit() {
  assert(insChain == nullptr && "must only start emitting once");
  if (sparseOut) {
    insChain = sparseOut->get();
    latticeMerger.setHasSparseOut(true);
  }

  // Sort the related loop arrays such that they are in the same order as the
  // loops appear in the topoOrder.
  // TODO: since we only handle affine addition for slice-based codegen, and
  // addition is associative, the order in which we evaluate the expression
  // does not matter. However, to support multiplication, the order of the
  // loop indices should match the evaluation order of the affine expression
  // AST.

  // Initialize the loop emitter.
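  // For illustration (hypothetical kernel, not taken from this file): given
  // a GenericOp with inputs %a, %b and init %c, the `tensors` vector below
  // becomes [%a, %b, %c], since the DPS init operand comes last; passing
  // `hasOutput=true` presumably tells the emitter to treat that trailing
  // tensor as the output.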
  SmallVector<Value> tensors; // input tensors passed to loop emitter
  for (OpOperand &t : linalgOp->getOpOperands()) {
    tensors.push_back(t.get());
    const TensorId tid = makeTensorId(t.getOperandNumber());
    const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
    const auto enc = getSparseTensorEncoding(t.get().getType());
    (void)enc;
    assert(!enc || lvlRank == enc.getLvlRank());
    for (Level lvl = 0; lvl < lvlRank; lvl++)
      sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
  }
  loopEmitter.initialize(
      tensors,
      StringAttr::get(linalgOp.getContext(),
                      linalg::GenericOp::getOperationName()),
      /*hasOutput=*/true,
      /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
      // TODO: compute the map and pass it to the loop emitter directly,
      // instead of passing in a callback.
      /*dependentLvlGetter=*/
      [this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
        return merger().getDependentLoops(t, lvl);
      });
}

std::optional<Operation *> CodegenEnv::genLoopBoundary(
    function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
        callback) {
  SmallVector<Value> params;
  if (isReduc()) {
    params.push_back(redVal);
    if (isValidLexInsert())
      params.push_back(redValidLexInsert);
  } else {
    assert(!isValidLexInsert());
  }
  if (isExpand())
    params.push_back(expCount);
  if (insChain != nullptr)
    params.push_back(insChain);
  auto r = callback(params); // may update the parameters
  unsigned i = 0;
  if (isReduc()) {
    updateReduc(params[i++]);
    if (isValidLexInsert())
      updateValidLexInsert(params[i++]);
  }
  if (isExpand())
    updateExpandCount(params[i++]);
  if (insChain != nullptr)
    updateInsertionChain(params[i]);
  return r;
}

//===----------------------------------------------------------------------===//
// Code generation environment verify functions.
//===----------------------------------------------------------------------===//

bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
  // We reject any expression that makes a reduction from `-outTensor`, as
  // those expressions create a dependency between the current iteration (i)
  // and the previous iteration (i-1). It would require iterating over the
  // whole coordinate space, which prevents exploiting sparsity for faster
  // code.
  for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) {
    if (it == utils::IteratorType::reduction) {
      if (latticeMerger.hasNegateOnOut(exp))
        return false;
      break;
    }
  }

  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
  const TensorId tensor = makeTensorId(lhs->getOperandNumber());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (getSparseTensorType(lhs->get()).isAllDense())
    return true;

  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without a special environment.
  if (latticeMerger.isSingleCondition(tensor, exp))
    return true;

  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
  sparseOut = lhs;

  // Find the outermost parallel nest to determine whether compress/expand is
  // needed.
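  // For example (hypothetical iterator ordering): for iterator types
  // (parallel, parallel, reduction, parallel), the loop below stops at the
  // first reduction and yields outerParNest == 2.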
  outerParNest = 0;
  const auto iteratorTypes = linalgOp.getIteratorTypesArray();
  for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
    if (linalg::isReductionIterator(iteratorTypes[i]))
      break; // terminate at first reduction
    outerParNest++;
  }

  // An inadmissible kernel should have already been rejected by the previous
  // pass during loop scheduling.
  assert(static_cast<int64_t>(outerParNest) >=
         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
  return isMaterializing(lhs->get());
}

//===----------------------------------------------------------------------===//
// Code generation environment topological sort methods
//===----------------------------------------------------------------------===//

Value CodegenEnv::getLoopVar(LoopId i) const {
  return loopEmitter.getLoopIV(i);
}

//===----------------------------------------------------------------------===//
// Code generation environment sparse tensor output and expansion methods
//===----------------------------------------------------------------------===//

void CodegenEnv::updateInsertionChain(Value chain) {
  assert(sparseOut != nullptr && insChain != nullptr);
  insChain = chain;
}

bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const {
  return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) &&
         outerParNest == n;
}

void CodegenEnv::startExpand(Value values, Value filled, Value added,
                             Value count) {
  assert(sparseOut != nullptr && expValues == nullptr);
  expValues = values;
  expFilled = filled;
  expAdded = added;
  expCount = count;
}

void CodegenEnv::updateExpandCount(Value count) {
  assert(sparseOut != nullptr && expValues != nullptr);
  expCount = count;
}

void CodegenEnv::endExpand() {
  assert(sparseOut != nullptr && expValues != nullptr);
  expValues = expFilled = expAdded = expCount = Value();
}

//===----------------------------------------------------------------------===//
// Code generation environment reduction methods
//===----------------------------------------------------------------------===//
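// A rough sketch of the intended usage (illustrative only; the method names
// match the definitions below, the surrounding control flow is hypothetical):
//
//   env.startReduc(exp, initVal);   // enter a reduction scope
//   env.genLoopBoundary([&](MutableArrayRef<Value> params) {
//     // params[0] carries the running reduction value across the loop
//     // boundary; the callback may replace it with an updated value.
//     ...
//   });
//   Value result = env.endReduc();  // leave the scope, fetch the final value
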
void CodegenEnv::startReduc(ExprId exp, Value val) {
  assert(!isReduc() && exp != detail::kInvalidId && val);
  redExp = exp;
  redVal = val;
  latticeMerger.setExprValue(exp, val);
}

void CodegenEnv::updateReduc(Value val) {
  assert(isReduc() && val);
  redVal = val;
  latticeMerger.clearExprValue(redExp);
  latticeMerger.setExprValue(redExp, val);
}

Value CodegenEnv::endReduc() {
  assert(isReduc());
  Value val = redVal;
  // Clear the stashed reduction value so that isReduc() no longer holds.
  redVal = Value();
  latticeMerger.clearExprValue(redExp);
  redExp = detail::kInvalidId;
  return val;
}

void CodegenEnv::startValidLexInsert(Value val) {
  assert(!isValidLexInsert() && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::updateValidLexInsert(Value val) {
  assert(redValidLexInsert && isReduc() && val);
  redValidLexInsert = val;
}

void CodegenEnv::endValidLexInsert() {
  assert(isValidLexInsert() && !isReduc());
  redValidLexInsert = Value();
}

void CodegenEnv::startCustomReduc(ExprId exp) {
  assert(!isCustomReduc() && exp != detail::kInvalidId);
  redCustom = exp;
}

Value CodegenEnv::getCustomRedId() const {
  assert(isCustomReduc());
  return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}

void CodegenEnv::endCustomReduc() {
  assert(isCustomReduc());
  redCustom = detail::kInvalidId;
}