//===- LoopEmitter.h --------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_

#include <vector>

#include "SparseTensorIterator.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace sparse_tensor {

// A compressed <tensor id, level> pair.
using TensorLevel = unsigned;

//
// The LoopEmitter class manages sparse tensors and helps generate the
// loop structure to (co)-iterate sparse tensors.
//
// An example usage:
// To generate the following loops over T1<?x?> and T2<?x?>
//
// for i : TENSOR_1_0 {
//   for j : TENSOR_2_0 {
//     for k : TENSOR_1_1 {}
//     for k : TENSOR_2_1 {}
//   }
// }
//
// One can use
//
// LoopEmitter loopEmitter({T1, T2});
// loopEmitter.initializeLoopEmit();
// loopEmitter.enterLoopOverTensorAtLvl(T1, 0);
// loopEmitter.enterLoopOverTensorAtLvl(T2, 0);
// loopEmitter.enterLoopOverTensorAtLvl(T1, 1);
// loopEmitter.exitCurrentLoop();
// loopEmitter.enterLoopOverTensorAtLvl(T2, 1);
// loopEmitter.exitCurrentLoop(); // exit k
// loopEmitter.exitCurrentLoop(); // exit j
// loopEmitter.exitCurrentLoop(); // exit i
//
class LoopEmitter {
public:
  /// Optional callback function to set up dense output tensors when
  /// initializing the loop emitter (e.g., to fill a dense output with zeros).
  using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                           Value memref, Value tensor)>;

  /// Optional callback function to set the bound for the synthetic tensor,
  /// which essentially is the dense loop bound.
  using SynTensorBoundSetter =
      function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;

  // Map from [tid, lvl] to a list of dependent [LoopId, coefficient] pairs
  // for subscript expressions on sparse tensors.
  //
  // E.g., the affine index (2 * d0 + d1) depends on loops d0 and d1 (for
  // affine expression reduction) and uses coefficients 2 and 1 on d0 and d1,
  // respectively. If the list is empty, there is no affine expression on
  // the input [tid, lvl].
  //
  // NOTE: LoopEmitter assumes that the loop id is consistent with the loop
  // order, i.e., loop `d0` will be generated before loop `d1`.
  using DependentLvlGetter =
      function_ref<std::vector<std::pair<LoopId, unsigned>>(TensorId, Level)>;
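  //
  // For illustration (the loop ids `d0` and `d1` are hypothetical), a getter
  // consistent with the example above could look like:
  //
  //   DependentLvlGetter getter = [&](TensorId t, Level l) {
  //     // Level `l` of tensor `t` is subscripted by (2 * d0 + d1);
  //     // levels with trivial index expressions return an empty list.
  //     return std::vector<std::pair<LoopId, unsigned>>{{d0, 2}, {d1, 1}};
  //   };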
  LoopEmitter() = default;

  /// Takes an array of input tensors, which the generated loops will
  /// iterate over. Each tensor is given a `TensorId` (numerically equal
  /// to the position of that tensor `Value` in the array). Setting
  /// `isSparseOut` indicates that the sparse output tensor is empty,
  /// so the loop emitter will generate loops over it according to the
  /// level-sizes.
  void
  initialize(ValueRange tensors, StringAttr loopTag = nullptr,
             bool hasOutput = false, bool isSparseOut = false,
             unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
             SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);

  explicit LoopEmitter(
      ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
      bool isSparseOut = false, unsigned numLoops = 0,
      DependentLvlGetter getter = nullptr,
      SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);

  /// Starts a loop emitting session by generating all the buffers needed
  /// for iterating over the tensors.
  void initializeLoopEmit(OpBuilder &builder, Location loc,
                          OutputUpdater updater = nullptr,
                          SynTensorBoundSetter synSetter = nullptr);

  /// Generates code to compute an affine expression whose variables are
  /// `LoopId`s (i.e., `a.cast<AffineDimExpr>().getPosition()` is a valid
  /// `LoopId`).
  Value genAffine(OpBuilder &builder, Location loc, AffineExpr a);

  /// Enters a new loop sequence; the loops within the same sequence start
  /// from the break points of the previous loop instead of starting over
  /// from 0. E.g.,
  /// {
  ///   // loop sequence start.
  ///   p0 = while(xxx)
  ///     ...
  ///     break p0
  ///
  ///   // Starts the loop from p0.
  ///   for (i = p0; i < end; i++)
  ///     ...
  ///   // loop sequence end.
  /// }
  void enterNewLoopSeq(OpBuilder &builder, Location loc,
                       ArrayRef<TensorLevel> tidLvls);

  /// Exits the current loop sequence; this will reset the universal index
  /// to 0.
  void exitCurrentLoopSeq(OpBuilder &builder, Location loc);

  /// Emits the address for a dense level based on the value evaluated by the
  /// provided affine expression.
  void locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
                                TensorLevel tidLvl, AffineExpr lvlExpr);

  // TODO: Get rid of `lvls` in the argument list? Track the level we
  // are currently at internally. Then it would be enterNextLvlForTensor.
  // Still need a way to specify the lvl for non-annotated tensors though,
  // as those can be accessed out of order.
  //
  /// Emits a co-iteration loop over a set of tensors. The loop is emitted
  /// over tensor_tid_lvl, under the assumption that all loops between
  /// tensor_tid_[0, lvl - 1] have already been generated.
  /// The function also performs an in-place update on the `reduc` vector to
  /// return the reduction variables used inside the generated loop.
  Operation *enterCoIterationOverTensorsAtLvls(
      OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
      unsigned numCases, MutableArrayRef<Value> reduc = {},
      bool isParallel = false, bool needsUniv = false);
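  //
  // A minimal usage sketch (assuming `t0`/`t1` are tensor values and a
  // single-case co-iteration suffices): co-iterate over level 0 of two
  // tensors, with no reduction values:
  //
  //   LoopEmitter emitter({t0, t1});
  //   emitter.initializeLoopEmit(builder, loc);
  //   SmallVector<TensorLevel> tidLvls = {emitter.makeTensorLevel(0, 0),
  //                                       emitter.makeTensorLevel(1, 0)};
  //   emitter.enterNewLoopSeq(builder, loc, tidLvls);
  //   emitter.enterCoIterationOverTensorsAtLvls(builder, loc, tidLvls,
  //                                             /*numCases=*/1);
  //   ...                                      // emit the loop body.
  //   emitter.exitCurrentLoop(rewriter, loc);
  //   emitter.exitCurrentLoopSeq(builder, loc);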
  Region *enterCurrentCoIterationCase(OpBuilder &builder, Location loc,
                                      I64BitSet caseBit, unsigned caseIdx,
                                      MutableArrayRef<Value> reduc);

  /// Generates code to exit the current loop (e.g., generates yields, forwards
  /// loop induction variables, etc).
  void exitCurrentLoop(RewriterBase &rewriter, Location loc,
                       MutableArrayRef<Value> reduc = {});

  /// Gets the range of values for all induction variables.
  auto getLoopIVsRange() const {
    return llvm::map_range(loopStack, [](const LoopInfo &li) { return li.iv; });
  }

  /// Returns the loop induction variables for all loops in the current
  /// loop-stack.
  SmallVector<Value> getLoopIVs() const {
    return llvm::to_vector(getLoopIVsRange());
  }

  /// Gets the current depth of the loop-stack.
  LoopId getCurrentDepth() const { return llvm::range_size(getLoopIVsRange()); }

  /// Gets the loop induction variable for the given loop.
  Value getLoopIV(LoopId n) const {
    if (n >= getCurrentDepth())
      return Value();
    auto it = getLoopIVsRange().begin();
    std::advance(it, n);
    return *it;
  }

  /// Gets the total number of manifest tensors (excluding the synthetic
  /// tensor).
  unsigned getNumManifestTensors() const { return tensors.size(); }

  /// Gets the total number of tensors that the loop emitter is operating on.
  unsigned getNumTensors() const {
    // Manifest tensors with one synthetic tensor at the end.
    return getNumManifestTensors() + 1;
  }

  /// Gets the TensorId for the synthetic tensor.
  TensorId getSynTensorId() const { return tensors.size(); }

  /// Gets the TensorId for the output tensor.
  TensorId getOutTensorId() const {
    assert(hasOutput);
    return getNumManifestTensors() - 1;
  }

  /// Compresses a TensorId and Level into a TensorLevel.
  TensorLevel makeTensorLevel(TensorId t, Level l) const {
    return l * getNumTensors() + t;
  }

  /// De-compresses a TensorLevel back to a pair of TensorId and Level.
  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
    unsigned nt = getNumTensors();
    return std::make_pair(tidLvl % nt, tidLvl / nt);
  }

  /// Converts a range of TensorLevel to a range of std::pair<TensorId, Level>.
  template <class ContainerTy>
  auto unpackTensorLevelRange(ContainerTy &&c) const {
    using EltTy = decltype(*c.begin());
    static_assert(std::is_same_v<llvm::remove_cvref_t<EltTy>, TensorLevel>,
                  "Must be unpacking a TensorLevel range");
    return llvm::map_range(std::forward<ContainerTy>(c), [this](EltTy tl) {
      return this->unpackTensorLevel(tl);
    });
  }
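  //
  // For example, with two manifest tensors (so `getNumTensors()` == 3,
  // counting the synthetic tensor), the pair (tid = 1, lvl = 2) compresses
  // to 2 * 3 + 1 == 7, and `unpackTensorLevel(7)` recovers {1, 2} via
  // (7 % 3, 7 / 3).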
  ///
  /// Getters.
  ///
  SmallVector<Value> getValPosits(TensorId tid) const {
    // Returns the iterator if we are generating sparse (co)iterate-based
    // loops.
    if (emitStrategy == SparseEmitStrategy::kSparseIterator)
      return {spIterVals[tid].back()};

    // Returns {[batch coords], last-level position}.
    SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
    Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
    batchCrds.push_back(lastLvlPos);
    return batchCrds;
  };
  Value getCoord(TensorId tid, Level lvl) const {
    return getCurIterator(tid, lvl).getCrd();
  };
  const std::vector<Value> &getValBuffer() const { return valBuffer; };

  constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
    return llvm::StringLiteral("Emitted from");
  }

private:
  ///
  /// Structure definitions that hold different kinds of loop information.
  ///

  // LoopInfo stores information about a loop generated by LoopEmitter, e.g.,
  // the set of tensor levels that the loop is iterating over.
  struct LoopInfo final {
    LoopInfo(ArrayRef<TensorLevel> tidLvls, Operation *loop, Block *userBlock,
             Value iv, StringAttr loopTag)
        : tidLvls(tidLvls), loop(loop), userCodeBlock(userBlock), iv(iv) {
      // Attach a special tag to loops generated by the loop emitter.
      if (loopTag)
        loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
    }
    // The set of <tensor, lvl> pairs, with *only* trivial index expressions,
    // that are used as the condition for the generated loop. Extra
    // information is required for levels with non-trivial index expressions,
    // which is maintained by the sliceDrivenInfo array below.
    const llvm::SmallVector<TensorLevel> tidLvls;
    Operation *loop;            // the loop operation
    Block *const userCodeBlock; // the block holding users' generated code.
    Value iv;                   // the induction variable for the loop
  };

  void categorizeIterators(ArrayRef<TensorLevel> tidLvls,
                           SmallVectorImpl<SparseIterator *> &raIters,
                           SmallVectorImpl<SparseIterator *> &spIters);

  ///
  /// LoopEmitter internal helper functions.
  ///

  using LoopBodyBuilder = llvm::function_ref<void(OpBuilder &, Location, Value,
                                                  MutableArrayRef<Value>)>;

  /// Whether the list of sparse conditions should be iterated by a for loop.
  bool shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters);

  /// Generates instructions to compute the coordinate of tensors[tid][lvl]
  /// under the current loop context. The final argument is the
  /// collapsed-output level, whereas this function handles converting
  /// that to the uncollapsed-input level.
  Value genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
                     Level dstLvl);

  bool isSynTensor(TensorId tid) const { return tid == getSynTensorId(); }

  bool isOutputTensor(TensorId tid) const {
    return hasOutput && tid == getOutTensorId();
  }

  bool isSparseOutput(TensorId tid) const {
    return isOutputTensor(tid) && isSparseOut;
  }

  bool isValidLevel(TensorId tid, Level lvl) const {
    return tid < lvls.size() && lvl < lvls[tid].size();
  }

  /// Prepares a loop for iterating over `tensor[lvl]`, under the assumption
  /// that `tensor[0...lvl-1]` loops have already been set up.
  void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                                  TensorId tid, Level lvl);

  /// Emits a for loop to iterate over a tensor level with the provided
  /// lower bound `lo` and upper bound `hi`. Apart from iterating over a
  /// single tensor level, for loops can be used for slice-driven loops on
  /// dense levels too.
  /// Returns a pair: the loop generated and the value for the induction
  /// variable.
  std::pair<Operation *, Value>
  emitForLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                             SparseIterator &iter, MutableArrayRef<Value> reduc,
                             bool isParallel);

  /// Emits a while loop to co-iterate over a list of sparse conditions, or
  /// a (complex) single sparse condition that cannot be handled by a for
  /// loop (e.g., an index reduction loop).
  /// Returns a pair: the loop generated and the value for the induction
  /// variable (which is the minimum coordinate of all the tensors being
  /// iterated).
  std::pair<Operation *, Value>
  emitWhileLoopOverTensorsAtLvls(OpBuilder &builder, Location loc,
                                 ArrayRef<SparseIterator *> iters,
                                 MutableArrayRef<Value> reduc, bool needsUniv);

  /// Exits a for loop and returns the reduction results, e.g.,
  /// for sequential for loops:
  ///   %ret = for () {
  ///     ...
  ///     %val = addi %args, %c
  ///     yield %val
  ///   }
  /// For parallel loops, the following code generated by users:
  ///   %ret = parallel () init(%args) {
  ///     ...
  ///     %val = op %args, %c
  ///   }
  /// will be transformed into
  ///   %ret = parallel () init(%args) {
  ///     ...
  ///     scf.reduce(%c) bb0(%0, %1) {
  ///       %val = op %0, %1
  ///       scf.reduce.return %val
  ///     }
  ///   }
  /// NOTE: only one instruction will be moved into the reduce block; the
  /// transformation will fail if multiple instructions are used to compute
  /// the reduction value. %ret is returned to the user, while %val is
  /// provided by the user (`reduc`).
  void exitForLoop(RewriterBase &rewriter, Location loc,
                   MutableArrayRef<Value> reduc);

  /// Exits a while loop and returns the reduction results.
  void exitWhileLoop(OpBuilder &builder, Location loc,
                     MutableArrayRef<Value> reduc);

  //
  // Slice-driven loop related methods.
  //

  void initSubSectIterator(OpBuilder &builder, Location loc);

  /// Gets the reduced number of constraints on tensor[tid][lvl].
  unsigned redDepOnLevel(TensorId tid, Level lvl) const {
    return levelReducedDep[tid][lvl];
  };

  SparseIterator &getCurIterator(TensorId tid, Level lvl) const {
    if (dependentLvlMap[tid][lvl].empty())
      return *iters[tid][lvl].back();

    assert(redDepOnLevel(tid, lvl) >= 1);
    return *iters[tid][lvl][redDepOnLevel(tid, lvl) - 1];
  }
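  //
  // E.g., if `dependentLvlMap[tid][lvl]` holds two dependences and one of
  // them has been reduced so far (`redDepOnLevel(tid, lvl)` == 1), the
  // current iterator is `iters[tid][lvl][0]`; with no dependences at all,
  // it is simply the innermost iterator `iters[tid][lvl].back()`.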
  std::unique_ptr<SparseIterator>
  makeLevelIterator(OpBuilder &builder, Location loc, TensorId tid, Level l);

  /// An optional string attribute that should be attached to the loops
  /// generated by the loop emitter; it might help subsequent passes to
  /// identify loops that operate on sparse tensors more easily.
  StringAttr loopTag;
  /// Whether the loop emitter needs to treat the last tensor as the output
  /// tensor.
  bool hasOutput;
  bool isSparseOut;
  SparseEmitStrategy emitStrategy;

  //
  // Fields which have `numTensor` many entries.
  //

  /// Input and (optional) output tensors.
  std::vector<Value> tensors;
  std::vector<Value> loopHighs;
  std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
  std::vector<std::vector<std::vector<std::unique_ptr<SparseIterator>>>> iters;
  std::vector<Value> valBuffer; // to_value

  // Map from [tid, level] to a list of dependent [tidlevel, coefficient]
  // pairs. See the comments for `DependentLvlGetter`.
  std::vector<std::vector<std::vector<std::pair<LoopId, unsigned>>>>
      dependentLvlMap;

  // The (size, stride) for each conceptual slice used for index reduction
  // loops.
  std::vector<std::vector<std::vector<std::pair<Value, unsigned>>>> sliceMeta;

  // The number of reduced dependencies on a tensor level so far.
  std::vector<std::vector<unsigned>> levelReducedDep;

  //
  // Fields which have at most `numLoops` many entries.
  //

  /// Loop stack, which stores the information of all the nested loops that
  /// are alive.
  std::vector<LoopInfo> loopStack;

  // Loop sequence stack, which stores the universal index for the current
  // loop sequence and a list of the tensor levels that the loop sequence
  // traverses.
  std::vector<std::pair<Value, std::vector<TensorLevel>>> loopSeqStack;

  //
  // EXPERIMENTAL:
  // Fields for generating sparse-iterator-based loops.
  //

  std::vector<std::vector<Value>> spIterVals;
};

//
// Utility functions to generate sparse loops.
//

// Generates a while loop that co-iterates over a set of iterators.
std::pair<Operation *, Value> genCoIteration(OpBuilder &builder, Location loc,
                                             ArrayRef<SparseIterator *> iters,
                                             MutableArrayRef<Value> reduc,
                                             Value uniIdx,
                                             bool userReducFirst = false);
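//
// A minimal usage sketch (`itA` and `itB` are hypothetical, already
// initialized iterators): generate a while loop that co-iterates over both,
// with no universal index and no reduction values:
//
//   SmallVector<SparseIterator *> its = {itA, itB};
//   SmallVector<Value> reduc;
//   auto [loop, iv] =
//       genCoIteration(builder, loc, its, reduc, /*uniIdx=*/Value());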

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_LOOPEMITTER_H_