1 //===- CodegenEnv.h - Code generation environment class ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines the code generation environment class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_ 15 16 #include "CodegenUtils.h" 17 #include "LoopEmitter.h" 18 19 #include "mlir/Dialect/Linalg/IR/Linalg.h" 20 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 21 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 22 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 23 #include <optional> 24 25 namespace mlir { 26 namespace sparse_tensor { 27 28 /// The code generation environment class aggregates a number of data 29 /// structures that are needed during the code generation phase of 30 /// sparsification. This environment simplifies passing around such 31 /// data during sparsification (rather than passing around all the 32 /// individual compoments where needed). Furthermore, it provides 33 /// convience methods that keep implementation details transparent 34 /// to sparsification while asserting on internal consistency. 35 class CodegenEnv { 36 public: 37 /// Constructs a code generation environment which can be 38 /// passed around during sparsification for bookkeeping 39 /// together with some consistency asserts. 40 CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, 41 unsigned numTensors, unsigned numLoops, unsigned maxRank); 42 43 // 44 // General methods. 45 // 46 47 LogicalResult initTensorExp(); 48 ExprId getExprId() const { return tensorExp; } 49 50 linalg::GenericOp op() const { return linalgOp; } 51 const SparsificationOptions &options() const { return sparseOptions; } 52 bool generatingSparseIterator() const { 53 return sparseOptions.sparseEmitStrategy == 54 SparseEmitStrategy::kSparseIterator; 55 } 56 Merger &merger() { return latticeMerger; } 57 LoopEmitter &emitter() { return loopEmitter; } 58 59 void startEmit(SparseEmitStrategy emitStrategy); 60 61 /// Generates loop boundary statements (entering/exiting loops). The function 62 /// passes and updates the passed-in parameters. 63 std::optional<Operation *> 64 genLoopBoundary(function_ref< 65 std::optional<Operation *>(MutableArrayRef<Value> parameters)> 66 callback); 67 68 // 69 // Merger delegates. 70 // 71 72 constexpr TensorId makeTensorId(unsigned t) const { 73 return latticeMerger.makeTensorId(t); 74 } 75 constexpr LoopId makeLoopId(unsigned i) const { 76 return latticeMerger.makeLoopId(i); 77 } 78 constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const { 79 return latticeMerger.makeTensorLoopId(t, i); 80 } 81 const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); } 82 const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); } 83 ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); } 84 LevelType lt(TensorId t, LoopId i) const { 85 return latticeMerger.getLvlType(t, i); 86 } 87 LevelType lt(TensorLoopId b) const { return latticeMerger.getLvlType(b); } 88 89 unsigned getLoopNum() const { return latticeMerger.getNumLoops(); } 90 91 // 92 // LoopEmitter delegates. 93 // 94 95 TensorLevel makeTensorLevel(TensorId t, Level l) const { 96 // Make sure LoopEmitter, GenericOp, and Merger agree on the number of 97 // tensors. 98 assert(loopEmitter.getNumManifestTensors() == linalgOp->getNumOperands() && 99 loopEmitter.getNumTensors() == latticeMerger.getNumTensors() && 100 loopEmitter.getOutTensorId() == latticeMerger.getOutTensorID() && 101 loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID()); 102 return loopEmitter.makeTensorLevel(t, l); 103 } 104 TensorLevel makeTensorLevel(std::pair<TensorId, Level> tlPair) const { 105 return makeTensorLevel(tlPair.first, tlPair.second); 106 } 107 std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const { 108 return loopEmitter.unpackTensorLevel(tl); 109 } 110 template <class ContainerTy> 111 auto unpackTensorLevelRange(ContainerTy &&c) const { 112 return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c)); 113 } 114 115 unsigned getCurrentDepth() const { return loopEmitter.getCurrentDepth(); } 116 117 // 118 // Code generation environment verify functions. 119 // 120 121 /// Whether the tensor expression is admissible for codegen. 122 /// It also sets the sparseOut if the output tensor is sparse. 123 bool isAdmissibleTensorExp(ExprId e); 124 125 /// Returns the induction-variable for the given loop. 126 Value getLoopVar(LoopId i) const; 127 128 // 129 // Sparse tensor output and expansion methods. 130 // 131 132 bool hasSparseOutput() const { return sparseOut != nullptr; } 133 bool isSparseOutput(OpOperand *o) const { return sparseOut == o; } 134 135 Value getInsertionChain() const { return insChain; } 136 void updateInsertionChain(Value chain); 137 138 bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const; 139 void startExpand(Value values, Value filled, Value added, Value count); 140 bool isExpand() const { return expValues != nullptr; } 141 void updateExpandCount(Value count); 142 Value getExpandValues() const { return expValues; } 143 Value getExpandFilled() const { return expFilled; } 144 Value getExpandAdded() const { return expAdded; } 145 Value getExpandCount() const { return expCount; } 146 void endExpand(); 147 148 // 149 // Reduction methods. 150 // 151 152 void startReduc(ExprId exp, Value val); 153 bool isReduc() const { return redExp != detail::kInvalidId; } 154 void updateReduc(Value val); 155 Value getReduc() const { return redVal; } 156 Value endReduc(); 157 158 void startValidLexInsert(Value val); 159 bool isValidLexInsert() const { return redValidLexInsert != nullptr; } 160 void updateValidLexInsert(Value val); 161 Value getValidLexInsert() const { return redValidLexInsert; } 162 void endValidLexInsert(); 163 164 void startCustomReduc(ExprId exp); 165 bool isCustomReduc() const { return redCustom != detail::kInvalidId; } 166 Value getCustomRedId() const; 167 void endCustomReduc(); 168 169 private: 170 // Linalg operation. 171 linalg::GenericOp linalgOp; 172 173 // Sparsification options. 174 SparsificationOptions sparseOptions; 175 176 // Merger helper class. 177 Merger latticeMerger; 178 179 // Loop emitter helper class. 180 LoopEmitter loopEmitter; 181 182 // Sparse tensor as output. Implemented either through direct injective 183 // insertion in lexicographic index order or through access pattern 184 // expansion in the innermost loop nest (`expValues` through `expCount`). 185 OpOperand *sparseOut; 186 // The count of outer non-filter loops, as defined by `isAdmissibleTopoOrder`. 187 LoopId outerParNest; 188 Value insChain; 189 Value expValues; 190 Value expFilled; 191 Value expAdded; 192 Value expCount; 193 194 // Bookkeeping for reductions (up-to-date value of the reduction, and indices 195 // into the merger's expression tree. When the indices of a tensor reduction 196 // expression are exhausted, all inner loops can use a scalarized reduction. 197 Value redVal; 198 ExprId redExp; 199 ExprId redCustom; 200 201 // Bookkeeping for lex insertion during reductions. Holds the runtime boolean 202 // value of whether any reduction occurred. This is only set during a 203 // reduction and cleared once the reduction is finished. 204 Value redValidLexInsert; 205 206 // The root tensor expression of the kernel. 207 ExprId tensorExp; 208 }; 209 210 } // namespace sparse_tensor 211 } // namespace mlir 212 213 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_ 214