1 //===- Passes.h - Sparse tensor pass entry points ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines prototypes of all sparse tensor passes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_ 15 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "mlir/Transforms/OneToNTypeConversion.h" 20 21 //===----------------------------------------------------------------------===// 22 // Include the generated pass header (which needs some early definitions). 23 //===----------------------------------------------------------------------===// 24 25 namespace mlir { 26 27 namespace bufferization { 28 struct OneShotBufferizationOptions; 29 } // namespace bufferization 30 31 /// Defines a parallelization strategy. Any independent loop is a candidate 32 /// for parallelization. The loop is made parallel if (1) allowed by the 33 /// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse 34 /// outermost loop only), and (2) the generated code is an actual for-loop 35 /// (and not a co-iterating while-loop). 36 enum class SparseParallelizationStrategy { 37 kNone, 38 kDenseOuterLoop, 39 kAnyStorageOuterLoop, 40 kDenseAnyLoop, 41 kAnyStorageAnyLoop 42 }; 43 44 /// Defines a scope for reinterpret map pass. 45 enum class ReinterpretMapScope { 46 kAll, // reinterprets all applicable operations 47 kGenericOnly, // reinterprets only linalg.generic 48 kExceptGeneric, // reinterprets operation other than linalg.generic 49 }; 50 51 /// Defines a scope for reinterpret map pass. 52 enum class SparseEmitStrategy { 53 kFunctional, // generate fully inlined (and functional) sparse iteration 54 kSparseIterator, // generate (experimental) loop using sparse iterator. 55 kDebugInterface, // generate only place-holder for sparse iteration 56 }; 57 58 #define GEN_PASS_DECL 59 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 60 61 //===----------------------------------------------------------------------===// 62 // The SparseAssembler pass. 63 //===----------------------------------------------------------------------===// 64 65 void populateSparseAssembler(RewritePatternSet &patterns, bool directOut); 66 67 std::unique_ptr<Pass> createSparseAssembler(); 68 std::unique_ptr<Pass> createSparseAssembler(bool directOut); 69 70 //===----------------------------------------------------------------------===// 71 // The SparseReinterpretMap pass. 72 //===----------------------------------------------------------------------===// 73 74 void populateSparseReinterpretMap(RewritePatternSet &patterns, 75 ReinterpretMapScope scope); 76 77 std::unique_ptr<Pass> createSparseReinterpretMapPass(); 78 std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope); 79 80 //===----------------------------------------------------------------------===// 81 // The PreSparsificationRewriting pass. 82 //===----------------------------------------------------------------------===// 83 84 void populatePreSparsificationRewriting(RewritePatternSet &patterns); 85 86 std::unique_ptr<Pass> createPreSparsificationRewritePass(); 87 88 //===----------------------------------------------------------------------===// 89 // The Sparsification pass. 90 //===----------------------------------------------------------------------===// 91 92 /// Options for the Sparsification pass. 93 struct SparsificationOptions { 94 SparsificationOptions(SparseParallelizationStrategy p, SparseEmitStrategy d, 95 bool enableRT) 96 : parallelizationStrategy(p), sparseEmitStrategy(d), 97 enableRuntimeLibrary(enableRT) {} 98 99 SparsificationOptions(SparseParallelizationStrategy p, bool enableRT) 100 : SparsificationOptions(p, SparseEmitStrategy::kFunctional, enableRT) {} 101 102 SparsificationOptions() 103 : SparsificationOptions(SparseParallelizationStrategy::kNone, 104 SparseEmitStrategy::kFunctional, true) {} 105 106 SparseParallelizationStrategy parallelizationStrategy; 107 SparseEmitStrategy sparseEmitStrategy; 108 bool enableRuntimeLibrary; 109 }; 110 111 /// Sets up sparsification rewriting rules with the given options. 112 void populateSparsificationPatterns( 113 RewritePatternSet &patterns, 114 const SparsificationOptions &options = SparsificationOptions()); 115 116 std::unique_ptr<Pass> createSparsificationPass(); 117 std::unique_ptr<Pass> 118 createSparsificationPass(const SparsificationOptions &options); 119 120 //===----------------------------------------------------------------------===// 121 // The StageSparseOperations pass. 122 //===----------------------------------------------------------------------===// 123 124 /// Sets up StageSparseOperation rewriting rules. 125 void populateStageSparseOperationsPatterns(RewritePatternSet &patterns); 126 127 std::unique_ptr<Pass> createStageSparseOperationsPass(); 128 129 //===----------------------------------------------------------------------===// 130 // The LowerSparseOpsToForeach pass. 131 //===----------------------------------------------------------------------===// 132 133 void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, 134 bool enableRT, bool enableConvert); 135 136 std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(); 137 std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(bool enableRT, 138 bool enableConvert); 139 140 //===----------------------------------------------------------------------===// 141 // The LowerForeachToSCF pass. 142 //===----------------------------------------------------------------------===// 143 144 void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns); 145 146 std::unique_ptr<Pass> createLowerForeachToSCFPass(); 147 148 //===----------------------------------------------------------------------===// 149 // The LowerSparseIterationToSCF pass. 150 //===----------------------------------------------------------------------===// 151 152 /// Type converter for iter_space and iterator. 153 struct SparseIterationTypeConverter : public TypeConverter { 154 SparseIterationTypeConverter(); 155 }; 156 157 void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, 158 RewritePatternSet &patterns); 159 160 std::unique_ptr<Pass> createLowerSparseIterationToSCFPass(); 161 162 //===----------------------------------------------------------------------===// 163 // The SparseTensorConversion pass. 164 //===----------------------------------------------------------------------===// 165 166 /// Sparse tensor type converter into an opaque pointer. 167 class SparseTensorTypeToPtrConverter : public TypeConverter { 168 public: 169 SparseTensorTypeToPtrConverter(); 170 }; 171 172 /// Sets up sparse tensor conversion rules. 173 void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter, 174 RewritePatternSet &patterns); 175 176 std::unique_ptr<Pass> createSparseTensorConversionPass(); 177 178 //===----------------------------------------------------------------------===// 179 // The SparseTensorCodegen pass. 180 //===----------------------------------------------------------------------===// 181 182 /// Sparse tensor type converter into an actual buffer. 183 class SparseTensorTypeToBufferConverter : public TypeConverter { 184 public: 185 SparseTensorTypeToBufferConverter(); 186 }; 187 188 /// Sets up sparse tensor codegen rules. 189 void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter, 190 RewritePatternSet &patterns, 191 bool createSparseDeallocs, 192 bool enableBufferInitialization); 193 194 std::unique_ptr<Pass> createSparseTensorCodegenPass(); 195 std::unique_ptr<Pass> 196 createSparseTensorCodegenPass(bool createSparseDeallocs, 197 bool enableBufferInitialization); 198 199 //===----------------------------------------------------------------------===// 200 // The SparseBufferRewrite pass. 201 //===----------------------------------------------------------------------===// 202 203 void populateSparseBufferRewriting(RewritePatternSet &patterns, 204 bool enableBufferInitialization); 205 206 std::unique_ptr<Pass> createSparseBufferRewritePass(); 207 std::unique_ptr<Pass> 208 createSparseBufferRewritePass(bool enableBufferInitialization); 209 210 //===----------------------------------------------------------------------===// 211 // The SparseVectorization pass. 212 //===----------------------------------------------------------------------===// 213 214 void populateSparseVectorizationPatterns(RewritePatternSet &patterns, 215 unsigned vectorLength, 216 bool enableVLAVectorization, 217 bool enableSIMDIndex32); 218 219 std::unique_ptr<Pass> createSparseVectorizationPass(); 220 std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength, 221 bool enableVLAVectorization, 222 bool enableSIMDIndex32); 223 224 //===----------------------------------------------------------------------===// 225 // The SparseGPU pass. 226 //===----------------------------------------------------------------------===// 227 228 void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns, 229 unsigned numThreads); 230 231 void populateSparseGPULibgenPatterns(RewritePatternSet &patterns, 232 bool enableRT); 233 234 std::unique_ptr<Pass> createSparseGPUCodegenPass(); 235 std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads, 236 bool enableRT); 237 238 //===----------------------------------------------------------------------===// 239 // The SparseStorageSpecifierToLLVM pass. 240 //===----------------------------------------------------------------------===// 241 242 class StorageSpecifierToLLVMTypeConverter : public TypeConverter { 243 public: 244 StorageSpecifierToLLVMTypeConverter(); 245 }; 246 247 void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter, 248 RewritePatternSet &patterns); 249 std::unique_ptr<Pass> createStorageSpecifierToLLVMPass(); 250 251 //===----------------------------------------------------------------------===// 252 // The mini-pipeline for sparsification and bufferization. 253 //===----------------------------------------------------------------------===// 254 255 bufferization::OneShotBufferizationOptions 256 getBufferizationOptionsForSparsification(bool analysisOnly); 257 258 std::unique_ptr<Pass> createSparsificationAndBufferizationPass(); 259 260 std::unique_ptr<Pass> createSparsificationAndBufferizationPass( 261 const bufferization::OneShotBufferizationOptions &bufferizationOptions, 262 const SparsificationOptions &sparsificationOptions, 263 bool createSparseDeallocs, bool enableRuntimeLibrary, 264 bool enableBufferInitialization, unsigned vectorLength, 265 bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen, 266 SparseEmitStrategy emitStrategy, 267 SparseParallelizationStrategy parallelizationStrategy); 268 269 //===----------------------------------------------------------------------===// 270 // Sparse Iteration Transform Passes 271 //===----------------------------------------------------------------------===// 272 273 std::unique_ptr<Pass> createSparseSpaceCollapsePass(); 274 275 //===----------------------------------------------------------------------===// 276 // Registration. 277 //===----------------------------------------------------------------------===// 278 279 /// Generate the code for registering passes. 280 #define GEN_PASS_REGISTRATION 281 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 282 283 } // namespace mlir 284 285 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_ 286