xref: /llvm-project/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (revision 8c4bc1e75de27adfbaead34b895b0efbaf17bd02)
1 //===- Passes.h - Sparse tensor pass entry points ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines prototypes of all sparse tensor passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
15 
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "mlir/Transforms/OneToNTypeConversion.h"
20 
21 //===----------------------------------------------------------------------===//
22 // Include the generated pass header (which needs some early definitions).
23 //===----------------------------------------------------------------------===//
24 
25 namespace mlir {
26 
// Forward declaration only: the sparsification/bufferization mini-pipeline
// declarations further down name this type (as a return value and as a const
// reference parameter), which does not require its full definition here.
namespace bufferization {
struct OneShotBufferizationOptions;
} // namespace bufferization
30 
/// Defines a parallelization strategy. Any independent loop is a candidate
/// for parallelization. The loop is made parallel if (1) allowed by the
/// strategy (e.g., kAnyStorageOuterLoop considers either a dense or sparse
/// outermost loop only), and (2) the generated code is an actual for-loop
/// (and not a co-iterating while-loop).
///
/// Apart from kNone, each enumerator name pairs the eligible loop storage
/// ("Dense" vs. "AnyStorage", i.e. dense or sparse) with the eligible loop
/// position ("OuterLoop" = outermost only vs. "AnyLoop"). Only
/// kAnyStorageOuterLoop is described explicitly above; the others are read by
/// analogy from their names -- confirm against the sparsifier.
///
/// Underlying values are implicit (assigned in declaration order).
enum class SparseParallelizationStrategy {
  kNone,
  kDenseOuterLoop,
  kAnyStorageOuterLoop,
  kDenseAnyLoop,
  kAnyStorageAnyLoop
};
43 
/// Defines the scope of the reinterpret map pass, i.e. which operations it is
/// allowed to reinterpret. Underlying values are implicit (declaration order).
enum class ReinterpretMapScope {
  kAll,           // reinterprets all applicable operations
  kGenericOnly,   // reinterprets only linalg.generic
  kExceptGeneric, // reinterprets operations other than linalg.generic
};
50 
/// Defines the strategy used to emit code for sparse iteration (selected via
/// the `sparseEmitStrategy` field of SparsificationOptions). Underlying values
/// are implicit (declaration order).
enum class SparseEmitStrategy {
  kFunctional,     // generate fully inlined (and functional) sparse iteration
  kSparseIterator, // generate (experimental) loop using sparse iterator
  kDebugInterface, // generate only place-holder for sparse iteration
};
57 
58 #define GEN_PASS_DECL
59 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
60 
61 //===----------------------------------------------------------------------===//
62 // The SparseAssembler pass.
63 //===----------------------------------------------------------------------===//
64 
/// Adds the SparseAssembler rewrite patterns to `patterns`. The exact effect
/// of `directOut` is defined by the implementation (presumably it controls how
/// sparse outputs are returned -- TODO confirm in the implementation).
void populateSparseAssembler(RewritePatternSet &patterns, bool directOut);

/// Creates a SparseAssembler pass; ownership passes to the caller. The
/// zero-argument overload presumably uses the pass's declared default for
/// `directOut` -- TODO confirm.
std::unique_ptr<Pass> createSparseAssembler();
std::unique_ptr<Pass> createSparseAssembler(bool directOut);
69 
70 //===----------------------------------------------------------------------===//
71 // The SparseReinterpretMap pass.
72 //===----------------------------------------------------------------------===//
73 
/// Adds the SparseReinterpretMap rewrite patterns to `patterns`, restricted to
/// the operations selected by `scope` (see ReinterpretMapScope).
void populateSparseReinterpretMap(RewritePatternSet &patterns,
                                  ReinterpretMapScope scope);

/// Creates a SparseReinterpretMap pass; ownership passes to the caller. The
/// zero-argument overload presumably uses the pass's declared default scope
/// -- TODO confirm.
std::unique_ptr<Pass> createSparseReinterpretMapPass();
std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
79 
80 //===----------------------------------------------------------------------===//
81 // The PreSparsificationRewriting pass.
82 //===----------------------------------------------------------------------===//
83 
/// Adds the rewrite patterns of the PreSparsificationRewriting pass to
/// `patterns`.
void populatePreSparsificationRewriting(RewritePatternSet &patterns);

/// Creates the PreSparsificationRewriting pass; ownership passes to the
/// caller.
std::unique_ptr<Pass> createPreSparsificationRewritePass();
87 
88 //===----------------------------------------------------------------------===//
89 // The Sparsification pass.
90 //===----------------------------------------------------------------------===//
91 
92 /// Options for the Sparsification pass.
93 struct SparsificationOptions {
94   SparsificationOptions(SparseParallelizationStrategy p, SparseEmitStrategy d,
95                         bool enableRT)
96       : parallelizationStrategy(p), sparseEmitStrategy(d),
97         enableRuntimeLibrary(enableRT) {}
98 
99   SparsificationOptions(SparseParallelizationStrategy p, bool enableRT)
100       : SparsificationOptions(p, SparseEmitStrategy::kFunctional, enableRT) {}
101 
102   SparsificationOptions()
103       : SparsificationOptions(SparseParallelizationStrategy::kNone,
104                               SparseEmitStrategy::kFunctional, true) {}
105 
106   SparseParallelizationStrategy parallelizationStrategy;
107   SparseEmitStrategy sparseEmitStrategy;
108   bool enableRuntimeLibrary;
109 };
110 
/// Sets up sparsification rewriting rules with the given options, adding them
/// to `patterns`. When `options` is omitted, a default-constructed
/// SparsificationOptions is used.
void populateSparsificationPatterns(
    RewritePatternSet &patterns,
    const SparsificationOptions &options = SparsificationOptions());

/// Creates the Sparsification pass; ownership passes to the caller. The
/// overload taking `options` configures the pass from them; the zero-argument
/// overload presumably uses the pass's declared defaults -- TODO confirm.
std::unique_ptr<Pass> createSparsificationPass();
std::unique_ptr<Pass>
createSparsificationPass(const SparsificationOptions &options);
119 
120 //===----------------------------------------------------------------------===//
121 // The StageSparseOperations pass.
122 //===----------------------------------------------------------------------===//
123 
/// Sets up StageSparseOperations rewriting rules by adding them to
/// `patterns`.
void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);

/// Creates the StageSparseOperations pass; ownership passes to the caller.
std::unique_ptr<Pass> createStageSparseOperationsPass();
128 
129 //===----------------------------------------------------------------------===//
130 // The LowerSparseOpsToForeach pass.
131 //===----------------------------------------------------------------------===//
132 
/// Adds the LowerSparseOpsToForeach rewrite patterns to `patterns`.
/// `enableRT` and `enableConvert` are feature flags whose exact effect is
/// defined by the implementation (presumably: whether the runtime library is
/// targeted, and whether conversion ops are lowered -- TODO confirm).
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                             bool enableRT, bool enableConvert);

/// Creates the LowerSparseOpsToForeach pass; ownership passes to the caller.
/// The two-argument overload takes the same flags as the populate function
/// above (presumably forwarded to it).
std::unique_ptr<Pass> createLowerSparseOpsToForeachPass();
std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(bool enableRT,
                                                        bool enableConvert);
139 
140 //===----------------------------------------------------------------------===//
141 // The LowerForeachToSCF pass.
142 //===----------------------------------------------------------------------===//
143 
/// Adds the LowerForeachToSCF rewrite patterns to `patterns`.
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);

/// Creates the LowerForeachToSCF pass; ownership passes to the caller.
std::unique_ptr<Pass> createLowerForeachToSCFPass();
147 
148 //===----------------------------------------------------------------------===//
149 // The LowerSparseIterationToSCF pass.
150 //===----------------------------------------------------------------------===//
151 
/// Type converter for iter_space and iterator. Its only member is the
/// out-of-line constructor, which presumably registers the conversions.
struct SparseIterationTypeConverter : public TypeConverter {
  SparseIterationTypeConverter();
};

/// Adds the LowerSparseIterationToSCF patterns to `patterns`, converting types
/// with `converter` (presumably a SparseIterationTypeConverter; the signature
/// accepts any TypeConverter).
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter,
                                               RewritePatternSet &patterns);

/// Creates the LowerSparseIterationToSCF pass; ownership passes to the caller.
std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
161 
162 //===----------------------------------------------------------------------===//
163 // The SparseTensorConversion pass.
164 //===----------------------------------------------------------------------===//
165 
/// Sparse tensor type converter into an opaque pointer. Only the constructor
/// is declared here; it is defined out of line.
class SparseTensorTypeToPtrConverter : public TypeConverter {
public:
  SparseTensorTypeToPtrConverter();
};

/// Sets up sparse tensor conversion rules, adding them to `patterns` and
/// converting types with `typeConverter` (presumably a
/// SparseTensorTypeToPtrConverter).
void populateSparseTensorConversionPatterns(const TypeConverter &typeConverter,
                                            RewritePatternSet &patterns);

/// Creates the SparseTensorConversion pass; ownership passes to the caller.
std::unique_ptr<Pass> createSparseTensorConversionPass();
177 
178 //===----------------------------------------------------------------------===//
179 // The SparseTensorCodegen pass.
180 //===----------------------------------------------------------------------===//
181 
/// Sparse tensor type converter into an actual buffer. Only the constructor
/// is declared here; it is defined out of line.
class SparseTensorTypeToBufferConverter : public TypeConverter {
public:
  SparseTensorTypeToBufferConverter();
};

/// Sets up sparse tensor codegen rules, adding them to `patterns` and
/// converting types with `typeConverter` (presumably a
/// SparseTensorTypeToBufferConverter). `createSparseDeallocs` and
/// `enableBufferInitialization` are feature flags whose exact effect is
/// defined by the implementation -- see the SparseTensorCodegen sources.
void populateSparseTensorCodegenPatterns(const TypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         bool createSparseDeallocs,
                                         bool enableBufferInitialization);

/// Creates the SparseTensorCodegen pass; ownership passes to the caller. The
/// two-argument overload takes the same flags as
/// populateSparseTensorCodegenPatterns.
std::unique_ptr<Pass> createSparseTensorCodegenPass();
std::unique_ptr<Pass>
createSparseTensorCodegenPass(bool createSparseDeallocs,
                              bool enableBufferInitialization);
197                               bool enableBufferInitialization);
198 
199 //===----------------------------------------------------------------------===//
200 // The SparseBufferRewrite pass.
201 //===----------------------------------------------------------------------===//
202 
/// Adds the SparseBufferRewrite patterns to `patterns`.
/// `enableBufferInitialization` is a feature flag whose exact effect is
/// defined by the implementation (presumably whether newly allocated buffers
/// are initialized -- TODO confirm).
void populateSparseBufferRewriting(RewritePatternSet &patterns,
                                   bool enableBufferInitialization);

/// Creates the SparseBufferRewrite pass; ownership passes to the caller. The
/// one-argument overload takes the same flag as the populate function above.
std::unique_ptr<Pass> createSparseBufferRewritePass();
std::unique_ptr<Pass>
createSparseBufferRewritePass(bool enableBufferInitialization);
208 createSparseBufferRewritePass(bool enableBufferInitialization);
209 
210 //===----------------------------------------------------------------------===//
211 // The SparseVectorization pass.
212 //===----------------------------------------------------------------------===//
213 
/// Adds the SparseVectorization patterns to `patterns`. Parameters (meanings
/// inferred from names -- confirm against the implementation):
///   vectorLength           - target vector length (presumably in elements).
///   enableVLAVectorization - presumably enables vector-length-agnostic
///                            (scalable) vectorization.
///   enableSIMDIndex32      - presumably permits 32-bit SIMD index vectors.
void populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                         unsigned vectorLength,
                                         bool enableVLAVectorization,
                                         bool enableSIMDIndex32);

/// Creates the SparseVectorization pass; ownership passes to the caller. The
/// three-argument overload takes the same settings as the populate function.
std::unique_ptr<Pass> createSparseVectorizationPass();
std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
                                                    bool enableVLAVectorization,
                                                    bool enableSIMDIndex32);
222                                                     bool enableSIMDIndex32);
223 
224 //===----------------------------------------------------------------------===//
225 // The SparseGPU pass.
226 //===----------------------------------------------------------------------===//
227 
/// Adds the sparse GPU code generation patterns to `patterns`; `numThreads`
/// is the thread count used for generated code (presumably per kernel launch
/// -- confirm).
void populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                      unsigned numThreads);

/// Adds the sparse GPU library-generation patterns to `patterns`; `enableRT`
/// is a runtime-library flag whose exact effect is defined by the
/// implementation.
void populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                     bool enableRT);

/// Creates the SparseGPUCodegen pass; ownership passes to the caller. The
/// two-argument overload takes both `numThreads` and `enableRT`; how they
/// select between the codegen and libgen patterns above is not visible here
/// -- confirm in the implementation.
std::unique_ptr<Pass> createSparseGPUCodegenPass();
std::unique_ptr<Pass> createSparseGPUCodegenPass(unsigned numThreads,
                                                 bool enableRT);
236                                                  bool enableRT);
237 
238 //===----------------------------------------------------------------------===//
239 // The SparseStorageSpecifierToLLVM pass.
240 //===----------------------------------------------------------------------===//
241 
/// Type converter used by the StorageSpecifierToLLVM lowering. Only the
/// constructor is declared here; it is defined out of line.
class StorageSpecifierToLLVMTypeConverter : public TypeConverter {
public:
  StorageSpecifierToLLVMTypeConverter();
};

/// Adds the StorageSpecifierToLLVM patterns to `patterns`, converting types
/// with `converter` (presumably a StorageSpecifierToLLVMTypeConverter).
void populateStorageSpecifierToLLVMPatterns(const TypeConverter &converter,
                                            RewritePatternSet &patterns);
/// Creates the StorageSpecifierToLLVM pass; ownership passes to the caller.
std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();
250 
251 //===----------------------------------------------------------------------===//
252 // The mini-pipeline for sparsification and bufferization.
253 //===----------------------------------------------------------------------===//
254 
/// Returns the one-shot bufferization options used by the sparsification
/// mini-pipeline. `analysisOnly` presumably limits bufferization to its
/// analysis phase -- TODO confirm.
bufferization::OneShotBufferizationOptions
getBufferizationOptionsForSparsification(bool analysisOnly);

/// Creates the combined sparsification + bufferization pass with default
/// settings; ownership passes to the caller.
std::unique_ptr<Pass> createSparsificationAndBufferizationPass();

/// Creates the combined sparsification + bufferization pass with explicit
/// settings; ownership passes to the caller.
/// NOTE(review): `sparsificationOptions` already carries a parallelization
/// strategy, an emit strategy and a runtime-library flag, which are also
/// passed separately as `parallelizationStrategy`, `emitStrategy` and
/// `enableRuntimeLibrary`. Which copy takes precedence is not visible from
/// this header -- confirm in the implementation before relying on either.
std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
    const bufferization::OneShotBufferizationOptions &bufferizationOptions,
    const SparsificationOptions &sparsificationOptions,
    bool createSparseDeallocs, bool enableRuntimeLibrary,
    bool enableBufferInitialization, unsigned vectorLength,
    bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
    SparseEmitStrategy emitStrategy,
    SparseParallelizationStrategy parallelizationStrategy);
268 
269 //===----------------------------------------------------------------------===//
270 // Sparse Iteration Transform Passes
271 //===----------------------------------------------------------------------===//
272 
/// Creates the SparseSpaceCollapse pass; ownership passes to the caller.
std::unique_ptr<Pass> createSparseSpaceCollapsePass();
274 
275 //===----------------------------------------------------------------------===//
276 // Registration.
277 //===----------------------------------------------------------------------===//
278 
279 /// Generate the code for registering passes.
280 #define GEN_PASS_REGISTRATION
281 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
282 
283 } // namespace mlir
284 
285 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
286