//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

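// The GEN_PASS_DEF_* macros below select which TableGen'ed pass base classes
// (e.g. impl::SparseAssemblerBase) are pulled in from Passes.h.inc; every
// pass implementation in this file derives from one of these bases.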
namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//
struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseAssembler(patterns, directOut);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    scope = options.scope;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseReinterpretMap(patterns, scope);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                  enableRuntimeLibrary);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateStageSparseOperationsPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
                                            enableConvert);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerForeachToSCFPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

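// Unlike the purely rewrite-based passes above, this pass (and the
// conversion passes below) uses the dialect conversion framework: a type
// converter plus a ConversionTarget that declares which ops are legal,
// applied as a partial conversion that signals pass failure if an illegal
// op survives.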
struct LowerSparseIterationToSCFPass
    : public impl::LowerSparseIterationToSCFBase<
          LowerSparseIterationToSCFPass> {
  LowerSparseIterationToSCFPass() = default;
  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseIterationTypeConverter converter;
    ConversionTarget target(*ctx);

    // The actual conversion.
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           memref::MemRefDialect, scf::SCFDialect,
                           sparse_tensor::SparseTensorDialect>();
    target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
                        IterateOp>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    populateLowerSparseIterationToSCFPatterns(converter, patterns);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

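// Lowers sparse tensor types and operations into calls to a runtime support
// library, mapping each sparse tensor to an opaque pointer (see the
// SparseTensorTypeToPtrConverter used below).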
struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting,
    // provided that all sparse tensor types have been fully rewritten.
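    // For example, a func::FuncOp stays illegal (and hence keeps being
    // rewritten) as long as its signature still contains a sparse tensor
    // type that the converter maps to a different type.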
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();

    // Populate the conversion patterns and apply them.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

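// In contrast to the conversion pass above, this pass lowers sparse tensors
// directly into buffer-level code (see the SparseTensorTypeToBufferConverter
// used below) rather than into runtime library calls.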
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortOp>();
    target.addLegalOp<PushBackOp>();
    // Storage specifiers outlive the sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp may be yielded after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // various tensor and bufferization operations as legal output of the
    // rewriting, provided that all sparse tensor types have been fully
    // rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate the conversion patterns and apply them.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
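    // An explicit vector length is required; reject a zero length outright.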
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
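    // A thread count of zero selects the library-based path, which generates
    // calls into a sparse GPU runtime; otherwise GPU code is generated
    // directly, parallelized over the requested number of threads.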
    if (numThreads == 0)
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
    else
      populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

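// Lowers the storage specifier ops and type that the codegen pass above
// deliberately left legal into LLVM dialect constructs.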
struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//
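// A typical way to use these factory methods is to schedule the created
// passes on a pass manager. A minimal sketch (module setup elided):
//
//   mlir::PassManager pm(module->getContext());
//   pm.addPass(mlir::createPreSparsificationRewritePass());
//   pm.addPass(mlir::createSparsificationPass());
//   if (failed(pm.run(module)))
//     ...  // handle pass failure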

std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
  return std::make_unique<LowerSparseIterationToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
                                                       bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}