xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (revision 2b5b3cf60d9e9e0c597bad1be1207b167ef15c9f)
1221856f5Swren romano //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
2a2c9d4bbSAart Bik //
3a2c9d4bbSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a2c9d4bbSAart Bik // See https://llvm.org/LICENSE.txt for license information.
5a2c9d4bbSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a2c9d4bbSAart Bik //
7a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
8a2c9d4bbSAart Bik 
9eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
10abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1228b6d412SAart Bik #include "mlir/Dialect/Complex/IR/Complex.h"
131f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1519466ebcSAart Bik #include "mlir/Dialect/GPU/IR/GPUDialect.h"
16a2c9d4bbSAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17a2c9d4bbSAart Bik #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
184a6b31b8SAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h"
19a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
20a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
21eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h"
22a2c9d4bbSAart Bik #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23a2c9d4bbSAart Bik 
2467d0d7acSMichele Scuttari namespace mlir {
2533b463adSAart Bik #define GEN_PASS_DEF_SPARSEASSEMBLER
267cfac1beSAart Bik #define GEN_PASS_DEF_SPARSEREINTERPRETMAP
27f81f0cb7Sbixia1 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
2867d0d7acSMichele Scuttari #define GEN_PASS_DEF_SPARSIFICATIONPASS
29c42bbda4SPeiming Liu #define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
30f82bee13SPeiming Liu #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
31f82bee13SPeiming Liu #define GEN_PASS_DEF_LOWERFOREACHTOSCF
3267d0d7acSMichele Scuttari #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
3367d0d7acSMichele Scuttari #define GEN_PASS_DEF_SPARSETENSORCODEGEN
34062e515bSbixia1 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE
3599b3849dSAart Bik #define GEN_PASS_DEF_SPARSEVECTORIZATION
3619466ebcSAart Bik #define GEN_PASS_DEF_SPARSEGPUCODEGEN
3706374400SPeiming Liu #define GEN_PASS_DEF_STAGESPARSEOPERATIONS
38083ddffeSPeiming Liu #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
3967d0d7acSMichele Scuttari #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
4067d0d7acSMichele Scuttari } // namespace mlir
4167d0d7acSMichele Scuttari 
42a2c9d4bbSAart Bik using namespace mlir;
4396a23911SAart Bik using namespace mlir::sparse_tensor;
44a2c9d4bbSAart Bik 
45a2c9d4bbSAart Bik namespace {
46a2c9d4bbSAart Bik 
47a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
48a2c9d4bbSAart Bik // Passes implementation.
49a2c9d4bbSAart Bik //===----------------------------------------------------------------------===//
50a2c9d4bbSAart Bik 
5133b463adSAart Bik struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
5233b463adSAart Bik   SparseAssembler() = default;
5333b463adSAart Bik   SparseAssembler(const SparseAssembler &pass) = default;
545122a2c2SAart Bik   SparseAssembler(bool dO) { directOut = dO; }
5533b463adSAart Bik 
5633b463adSAart Bik   void runOnOperation() override {
5733b463adSAart Bik     auto *ctx = &getContext();
5833b463adSAart Bik     RewritePatternSet patterns(ctx);
595122a2c2SAart Bik     populateSparseAssembler(patterns, directOut);
6009dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
6133b463adSAart Bik   }
6233b463adSAart Bik };
6333b463adSAart Bik 
647cfac1beSAart Bik struct SparseReinterpretMap
657cfac1beSAart Bik     : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
667cfac1beSAart Bik   SparseReinterpretMap() = default;
677cfac1beSAart Bik   SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
686a93da99SPeiming Liu   SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
696a93da99SPeiming Liu     scope = options.scope;
706a93da99SPeiming Liu   }
717cfac1beSAart Bik 
727cfac1beSAart Bik   void runOnOperation() override {
737cfac1beSAart Bik     auto *ctx = &getContext();
747cfac1beSAart Bik     RewritePatternSet patterns(ctx);
756a93da99SPeiming Liu     populateSparseReinterpretMap(patterns, scope);
7609dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
777cfac1beSAart Bik   }
787cfac1beSAart Bik };
797cfac1beSAart Bik 
80f81f0cb7Sbixia1 struct PreSparsificationRewritePass
81f81f0cb7Sbixia1     : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
82f81f0cb7Sbixia1   PreSparsificationRewritePass() = default;
83f81f0cb7Sbixia1   PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
84f81f0cb7Sbixia1       default;
85779dcd2eSAart Bik 
86779dcd2eSAart Bik   void runOnOperation() override {
87779dcd2eSAart Bik     auto *ctx = &getContext();
88779dcd2eSAart Bik     RewritePatternSet patterns(ctx);
89f81f0cb7Sbixia1     populatePreSparsificationRewriting(patterns);
9009dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
91779dcd2eSAart Bik   }
92779dcd2eSAart Bik };
93779dcd2eSAart Bik 
9467d0d7acSMichele Scuttari struct SparsificationPass
9567d0d7acSMichele Scuttari     : public impl::SparsificationPassBase<SparsificationPass> {
96a2c9d4bbSAart Bik   SparsificationPass() = default;
97abb336d2SMehdi Amini   SparsificationPass(const SparsificationPass &pass) = default;
98b85ed4e0Swren romano   SparsificationPass(const SparsificationOptions &options) {
9930ceb783SNick Kreeger     parallelization = options.parallelizationStrategy;
1004a653b4dSPeiming Liu     sparseEmitStrategy = options.sparseEmitStrategy;
101ee42e236SAart Bik     enableRuntimeLibrary = options.enableRuntimeLibrary;
102a2c9d4bbSAart Bik   }
103a2c9d4bbSAart Bik 
104a2c9d4bbSAart Bik   void runOnOperation() override {
105a2c9d4bbSAart Bik     auto *ctx = &getContext();
106a2c9d4bbSAart Bik     // Translate strategy flags to strategy options.
1074a653b4dSPeiming Liu     SparsificationOptions options(parallelization, sparseEmitStrategy,
1084a653b4dSPeiming Liu                                   enableRuntimeLibrary);
1095f32bcfbSAart Bik     // Apply sparsification and cleanup rewriting.
11028ebb0b6SAart Bik     RewritePatternSet patterns(ctx);
111a2c9d4bbSAart Bik     populateSparsificationPatterns(patterns, options);
11230b550f1SAart Bik     scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
11309dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
114a2c9d4bbSAart Bik   }
115a2c9d4bbSAart Bik };
116a2c9d4bbSAart Bik 
11706374400SPeiming Liu struct StageSparseOperationsPass
11806374400SPeiming Liu     : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
11906374400SPeiming Liu   StageSparseOperationsPass() = default;
12006374400SPeiming Liu   StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
12106374400SPeiming Liu   void runOnOperation() override {
12206374400SPeiming Liu     auto *ctx = &getContext();
12306374400SPeiming Liu     RewritePatternSet patterns(ctx);
12406374400SPeiming Liu     populateStageSparseOperationsPatterns(patterns);
12509dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
12606374400SPeiming Liu   }
12706374400SPeiming Liu };
12806374400SPeiming Liu 
129f82bee13SPeiming Liu struct LowerSparseOpsToForeachPass
130f82bee13SPeiming Liu     : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
131f82bee13SPeiming Liu   LowerSparseOpsToForeachPass() = default;
132f82bee13SPeiming Liu   LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
133f81f0cb7Sbixia1       default;
134f82bee13SPeiming Liu   LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
135f81f0cb7Sbixia1     enableRuntimeLibrary = enableRT;
136f81f0cb7Sbixia1     enableConvert = convert;
137f81f0cb7Sbixia1   }
138f81f0cb7Sbixia1 
139f81f0cb7Sbixia1   void runOnOperation() override {
140f81f0cb7Sbixia1     auto *ctx = &getContext();
141f81f0cb7Sbixia1     RewritePatternSet patterns(ctx);
142f82bee13SPeiming Liu     populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
143f82bee13SPeiming Liu                                             enableConvert);
14409dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
145f82bee13SPeiming Liu   }
146f82bee13SPeiming Liu };
147f82bee13SPeiming Liu 
148f82bee13SPeiming Liu struct LowerForeachToSCFPass
149f82bee13SPeiming Liu     : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
150f82bee13SPeiming Liu   LowerForeachToSCFPass() = default;
151f82bee13SPeiming Liu   LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;
152f82bee13SPeiming Liu 
153f82bee13SPeiming Liu   void runOnOperation() override {
154f82bee13SPeiming Liu     auto *ctx = &getContext();
155f82bee13SPeiming Liu     RewritePatternSet patterns(ctx);
156f82bee13SPeiming Liu     populateLowerForeachToSCFPatterns(patterns);
15709dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
158f81f0cb7Sbixia1   }
159f81f0cb7Sbixia1 };
160f81f0cb7Sbixia1 
161c42bbda4SPeiming Liu struct LowerSparseIterationToSCFPass
162c42bbda4SPeiming Liu     : public impl::LowerSparseIterationToSCFBase<
163c42bbda4SPeiming Liu           LowerSparseIterationToSCFPass> {
164c42bbda4SPeiming Liu   LowerSparseIterationToSCFPass() = default;
165c42bbda4SPeiming Liu   LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
166c42bbda4SPeiming Liu       default;
167c42bbda4SPeiming Liu 
168c42bbda4SPeiming Liu   void runOnOperation() override {
169c42bbda4SPeiming Liu     auto *ctx = &getContext();
170c42bbda4SPeiming Liu     RewritePatternSet patterns(ctx);
171c42bbda4SPeiming Liu     SparseIterationTypeConverter converter;
172c42bbda4SPeiming Liu     ConversionTarget target(*ctx);
173c42bbda4SPeiming Liu 
174c42bbda4SPeiming Liu     // The actual conversion.
175*2b5b3cf6SMatthias Springer     target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
176*2b5b3cf6SMatthias Springer                            memref::MemRefDialect, scf::SCFDialect,
177*2b5b3cf6SMatthias Springer                            sparse_tensor::SparseTensorDialect>();
178*2b5b3cf6SMatthias Springer     target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
179*2b5b3cf6SMatthias Springer                         IterateOp>();
180*2b5b3cf6SMatthias Springer     target.addLegalOp<UnrealizedConversionCastOp>();
181c42bbda4SPeiming Liu     populateLowerSparseIterationToSCFPatterns(converter, patterns);
182c42bbda4SPeiming Liu 
183*2b5b3cf6SMatthias Springer     if (failed(applyPartialConversion(getOperation(), target,
184c42bbda4SPeiming Liu                                       std::move(patterns))))
185c42bbda4SPeiming Liu       signalPassFailure();
186c42bbda4SPeiming Liu   }
187c42bbda4SPeiming Liu };
188c42bbda4SPeiming Liu 
189a2c9d4bbSAart Bik struct SparseTensorConversionPass
19067d0d7acSMichele Scuttari     : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
191c7e24db4Swren romano   SparseTensorConversionPass() = default;
192c7e24db4Swren romano   SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
193c7e24db4Swren romano 
194a2c9d4bbSAart Bik   void runOnOperation() override {
195a2c9d4bbSAart Bik     auto *ctx = &getContext();
19696a23911SAart Bik     RewritePatternSet patterns(ctx);
19786b22d31SAart Bik     SparseTensorTypeToPtrConverter converter;
198a2c9d4bbSAart Bik     ConversionTarget target(*ctx);
1991b15160eSAart Bik     // Everything in the sparse dialect must go!
2001b15160eSAart Bik     target.addIllegalDialect<SparseTensorDialect>();
201fde04aeeSAart Bik     // All dynamic rules below accept new function, call, return, and various
202fde04aeeSAart Bik     // tensor and bufferization operations as legal output of the rewriting
203fde04aeeSAart Bik     // provided that all sparse tensor types have been fully rewritten.
20458ceae95SRiver Riddle     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
2054a3460a7SRiver Riddle       return converter.isSignatureLegal(op.getFunctionType());
2064a3460a7SRiver Riddle     });
20723aa5a74SRiver Riddle     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
20896a23911SAart Bik       return converter.isSignatureLegal(op.getCalleeType());
20996a23911SAart Bik     });
21023aa5a74SRiver Riddle     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
21123aa5a74SRiver Riddle       return converter.isLegal(op.getOperandTypes());
21223aa5a74SRiver Riddle     });
213236a9080SAart Bik     target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
214236a9080SAart Bik       return converter.isLegal(op.getOperandTypes());
215236a9080SAart Bik     });
2161b15160eSAart Bik     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
217136d746eSJacques Pienaar       return converter.isLegal(op.getSource().getType()) &&
218136d746eSJacques Pienaar              converter.isLegal(op.getDest().getType());
2196d8e2f1eSAart Bik     });
2206d8e2f1eSAart Bik     target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
2216d8e2f1eSAart Bik         [&](tensor::ExpandShapeOp op) {
222136d746eSJacques Pienaar           return converter.isLegal(op.getSrc().getType()) &&
223136d746eSJacques Pienaar                  converter.isLegal(op.getResult().getType());
2246d8e2f1eSAart Bik         });
2256d8e2f1eSAart Bik     target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
2266d8e2f1eSAart Bik         [&](tensor::CollapseShapeOp op) {
227136d746eSJacques Pienaar           return converter.isLegal(op.getSrc().getType()) &&
228136d746eSJacques Pienaar                  converter.isLegal(op.getResult().getType());
2291b15160eSAart Bik         });
230fde04aeeSAart Bik     target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
231fde04aeeSAart Bik         [&](bufferization::AllocTensorOp op) {
232fde04aeeSAart Bik           return converter.isLegal(op.getType());
233fde04aeeSAart Bik         });
23427a431f5SMatthias Springer     target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
23527a431f5SMatthias Springer         [&](bufferization::DeallocTensorOp op) {
23627a431f5SMatthias Springer           return converter.isLegal(op.getTensor().getType());
23727a431f5SMatthias Springer         });
238236a9080SAart Bik     // The following operations and dialects may be introduced by the
239236a9080SAart Bik     // rewriting rules, and are therefore marked as legal.
2402ddfacd9SAart Bik     target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
241af8428c0SAart Bik                       linalg::YieldOp, tensor::ExtractOp,
242af8428c0SAart Bik                       tensor::FromElementsOp>();
243faa00c13SAart Bik     target.addLegalDialect<
244abc362a1SJakub Kuderski         arith::ArithDialect, bufferization::BufferizationDialect,
245faa00c13SAart Bik         LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
246f248d0b2SPeiming Liu 
247236a9080SAart Bik     // Populate with rules and apply rewriting rules.
24858ceae95SRiver Riddle     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
2497ceffae1SRiver Riddle                                                                    converter);
25096a23911SAart Bik     populateCallOpTypeConversionPattern(patterns, converter);
251bf59cd32SPeiming Liu     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
252bf59cd32SPeiming Liu                                                          target);
253f248d0b2SPeiming Liu     populateSparseTensorConversionPatterns(converter, patterns);
254a2c9d4bbSAart Bik     if (failed(applyPartialConversion(getOperation(), target,
25596a23911SAart Bik                                       std::move(patterns))))
256a2c9d4bbSAart Bik       signalPassFailure();
257a2c9d4bbSAart Bik   }
258a2c9d4bbSAart Bik };
259a2c9d4bbSAart Bik 
26086b22d31SAart Bik struct SparseTensorCodegenPass
26167d0d7acSMichele Scuttari     : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
26286b22d31SAart Bik   SparseTensorCodegenPass() = default;
26386b22d31SAart Bik   SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
264c44d307cSPeiming Liu   SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
265c44d307cSPeiming Liu     createSparseDeallocs = createDeallocs;
2667276b643Sbixia1     enableBufferInitialization = enableInit;
2677276b643Sbixia1   }
26886b22d31SAart Bik 
26986b22d31SAart Bik   void runOnOperation() override {
27086b22d31SAart Bik     auto *ctx = &getContext();
27186b22d31SAart Bik     RewritePatternSet patterns(ctx);
27286b22d31SAart Bik     SparseTensorTypeToBufferConverter converter;
27386b22d31SAart Bik     ConversionTarget target(*ctx);
2744d068619SAart Bik     // Most ops in the sparse dialect must go!
27586b22d31SAart Bik     target.addIllegalDialect<SparseTensorDialect>();
2760083f833SPeiming Liu     target.addLegalOp<SortOp>();
277654bbbdeSbixia1     target.addLegalOp<PushBackOp>();
278083ddffeSPeiming Liu     // Storage specifier outlives sparse tensor pipeline.
279083ddffeSPeiming Liu     target.addLegalOp<GetStorageSpecifierOp>();
280083ddffeSPeiming Liu     target.addLegalOp<SetStorageSpecifierOp>();
281083ddffeSPeiming Liu     target.addLegalOp<StorageSpecifierInitOp>();
282098f46dcSPeiming Liu     // Note that tensor::FromElementsOp might be yield after lowering unpack.
283098f46dcSPeiming Liu     target.addLegalOp<tensor::FromElementsOp>();
284083ddffeSPeiming Liu     // All dynamic rules below accept new function, call, return, and
285083ddffeSPeiming Liu     // various tensor and bufferization operations as legal output of the
286083ddffeSPeiming Liu     // rewriting provided that all sparse tensor types have been fully
287083ddffeSPeiming Liu     // rewritten.
28886b22d31SAart Bik     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
28986b22d31SAart Bik       return converter.isSignatureLegal(op.getFunctionType());
29086b22d31SAart Bik     });
29186b22d31SAart Bik     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
29286b22d31SAart Bik       return converter.isSignatureLegal(op.getCalleeType());
29386b22d31SAart Bik     });
29486b22d31SAart Bik     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
29586b22d31SAart Bik       return converter.isLegal(op.getOperandTypes());
29686b22d31SAart Bik     });
2970c7abd39SAart Bik     target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
2980c7abd39SAart Bik         [&](bufferization::AllocTensorOp op) {
2990c7abd39SAart Bik           return converter.isLegal(op.getType());
3000c7abd39SAart Bik         });
3012ddfacd9SAart Bik     target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
3022ddfacd9SAart Bik         [&](bufferization::DeallocTensorOp op) {
3032ddfacd9SAart Bik           return converter.isLegal(op.getTensor().getType());
3042ddfacd9SAart Bik         });
3058a583bd5Sbixia1     // The following operations and dialects may be introduced by the
3068a583bd5Sbixia1     // codegen rules, and are therefore marked as legal.
307a4c47055SMatthias Springer     target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
308ea4be70cSbixia1     target.addLegalDialect<
309ea4be70cSbixia1         arith::ArithDialect, bufferization::BufferizationDialect,
310ea4be70cSbixia1         complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
311928b5b06SPeiming Liu     target.addLegalOp<UnrealizedConversionCastOp>();
312ca01c996SPeiming Liu     // Populate with rules and apply rewriting rules.
313ca01c996SPeiming Liu     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
314ca01c996SPeiming Liu                                                                    converter);
315ca01c996SPeiming Liu     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
316ca01c996SPeiming Liu                                                          target);
317c44d307cSPeiming Liu     populateSparseTensorCodegenPatterns(
318c44d307cSPeiming Liu         converter, patterns, createSparseDeallocs, enableBufferInitialization);
319ca01c996SPeiming Liu     if (failed(applyPartialConversion(getOperation(), target,
320ca01c996SPeiming Liu                                       std::move(patterns))))
321ca01c996SPeiming Liu       signalPassFailure();
322ca01c996SPeiming Liu   }
323ca01c996SPeiming Liu };
324ca01c996SPeiming Liu 
325062e515bSbixia1 struct SparseBufferRewritePass
326062e515bSbixia1     : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
327062e515bSbixia1   SparseBufferRewritePass() = default;
328062e515bSbixia1   SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
3295618d2beSbixia1   SparseBufferRewritePass(bool enableInit) {
3305618d2beSbixia1     enableBufferInitialization = enableInit;
3315618d2beSbixia1   }
332062e515bSbixia1 
333062e515bSbixia1   void runOnOperation() override {
334062e515bSbixia1     auto *ctx = &getContext();
335062e515bSbixia1     RewritePatternSet patterns(ctx);
3365618d2beSbixia1     populateSparseBufferRewriting(patterns, enableBufferInitialization);
33709dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
338062e515bSbixia1   }
339062e515bSbixia1 };
340062e515bSbixia1 
34199b3849dSAart Bik struct SparseVectorizationPass
34299b3849dSAart Bik     : public impl::SparseVectorizationBase<SparseVectorizationPass> {
34399b3849dSAart Bik   SparseVectorizationPass() = default;
34499b3849dSAart Bik   SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
34599b3849dSAart Bik   SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
34699b3849dSAart Bik     vectorLength = vl;
34799b3849dSAart Bik     enableVLAVectorization = vla;
34899b3849dSAart Bik     enableSIMDIndex32 = sidx32;
34999b3849dSAart Bik   }
35099b3849dSAart Bik 
35199b3849dSAart Bik   void runOnOperation() override {
35214c0317fSAart Bik     if (vectorLength == 0)
35314c0317fSAart Bik       return signalPassFailure();
35499b3849dSAart Bik     auto *ctx = &getContext();
35599b3849dSAart Bik     RewritePatternSet patterns(ctx);
35699b3849dSAart Bik     populateSparseVectorizationPatterns(
35799b3849dSAart Bik         patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
35899b3849dSAart Bik     vector::populateVectorToVectorCanonicalizationPatterns(patterns);
35909dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
36099b3849dSAart Bik   }
36199b3849dSAart Bik };
36299b3849dSAart Bik 
36319466ebcSAart Bik struct SparseGPUCodegenPass
36419466ebcSAart Bik     : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
36519466ebcSAart Bik   SparseGPUCodegenPass() = default;
36619466ebcSAart Bik   SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
3675f32bcfbSAart Bik   SparseGPUCodegenPass(unsigned nT, bool enableRT) {
3685f32bcfbSAart Bik     numThreads = nT;
3695f32bcfbSAart Bik     enableRuntimeLibrary = enableRT;
3705f32bcfbSAart Bik   }
37119466ebcSAart Bik 
37219466ebcSAart Bik   void runOnOperation() override {
37319466ebcSAart Bik     auto *ctx = &getContext();
37419466ebcSAart Bik     RewritePatternSet patterns(ctx);
3755f32bcfbSAart Bik     if (numThreads == 0)
3765f32bcfbSAart Bik       populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
3775f32bcfbSAart Bik     else
37819466ebcSAart Bik       populateSparseGPUCodegenPatterns(patterns, numThreads);
37909dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
38019466ebcSAart Bik   }
38119466ebcSAart Bik };
38219466ebcSAart Bik 
383083ddffeSPeiming Liu struct StorageSpecifierToLLVMPass
384083ddffeSPeiming Liu     : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
385083ddffeSPeiming Liu   StorageSpecifierToLLVMPass() = default;
386083ddffeSPeiming Liu 
387083ddffeSPeiming Liu   void runOnOperation() override {
388083ddffeSPeiming Liu     auto *ctx = &getContext();
389083ddffeSPeiming Liu     ConversionTarget target(*ctx);
390083ddffeSPeiming Liu     RewritePatternSet patterns(ctx);
391083ddffeSPeiming Liu     StorageSpecifierToLLVMTypeConverter converter;
392083ddffeSPeiming Liu 
393083ddffeSPeiming Liu     // All ops in the sparse dialect must go!
394083ddffeSPeiming Liu     target.addIllegalDialect<SparseTensorDialect>();
395083ddffeSPeiming Liu     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
396083ddffeSPeiming Liu       return converter.isSignatureLegal(op.getFunctionType());
397083ddffeSPeiming Liu     });
398083ddffeSPeiming Liu     target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
399083ddffeSPeiming Liu       return converter.isSignatureLegal(op.getCalleeType());
400083ddffeSPeiming Liu     });
401083ddffeSPeiming Liu     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
402083ddffeSPeiming Liu       return converter.isLegal(op.getOperandTypes());
403083ddffeSPeiming Liu     });
404083ddffeSPeiming Liu     target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
405083ddffeSPeiming Liu 
406083ddffeSPeiming Liu     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
407083ddffeSPeiming Liu                                                                    converter);
408083ddffeSPeiming Liu     populateCallOpTypeConversionPattern(patterns, converter);
409083ddffeSPeiming Liu     populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
410083ddffeSPeiming Liu     populateReturnOpTypeConversionPattern(patterns, converter);
411083ddffeSPeiming Liu     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
412083ddffeSPeiming Liu                                                          target);
413083ddffeSPeiming Liu     populateStorageSpecifierToLLVMPatterns(converter, patterns);
414083ddffeSPeiming Liu     if (failed(applyPartialConversion(getOperation(), target,
415083ddffeSPeiming Liu                                       std::move(patterns))))
416083ddffeSPeiming Liu       signalPassFailure();
417083ddffeSPeiming Liu   }
418083ddffeSPeiming Liu };
419083ddffeSPeiming Liu 
420be0a7e9fSMehdi Amini } // namespace
421a2c9d4bbSAart Bik 
42286b22d31SAart Bik //===----------------------------------------------------------------------===//
42386b22d31SAart Bik // Pass creation methods.
42486b22d31SAart Bik //===----------------------------------------------------------------------===//
42586b22d31SAart Bik 
42633b463adSAart Bik std::unique_ptr<Pass> mlir::createSparseAssembler() {
42733b463adSAart Bik   return std::make_unique<SparseAssembler>();
42833b463adSAart Bik }
42933b463adSAart Bik 
4307cfac1beSAart Bik std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
4317cfac1beSAart Bik   return std::make_unique<SparseReinterpretMap>();
4327cfac1beSAart Bik }
4337cfac1beSAart Bik 
4346a93da99SPeiming Liu std::unique_ptr<Pass>
4356a93da99SPeiming Liu mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
4366a93da99SPeiming Liu   SparseReinterpretMapOptions options;
4376a93da99SPeiming Liu   options.scope = scope;
4386a93da99SPeiming Liu   return std::make_unique<SparseReinterpretMap>(options);
4396a93da99SPeiming Liu }
4406a93da99SPeiming Liu 
441f81f0cb7Sbixia1 std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
442f81f0cb7Sbixia1   return std::make_unique<PreSparsificationRewritePass>();
443779dcd2eSAart Bik }
444779dcd2eSAart Bik 
445a2c9d4bbSAart Bik std::unique_ptr<Pass> mlir::createSparsificationPass() {
446a2c9d4bbSAart Bik   return std::make_unique<SparsificationPass>();
447a2c9d4bbSAart Bik }
448a2c9d4bbSAart Bik 
449b85ed4e0Swren romano std::unique_ptr<Pass>
450b85ed4e0Swren romano mlir::createSparsificationPass(const SparsificationOptions &options) {
451b85ed4e0Swren romano   return std::make_unique<SparsificationPass>(options);
452b85ed4e0Swren romano }
453b85ed4e0Swren romano 
45406374400SPeiming Liu std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
45506374400SPeiming Liu   return std::make_unique<StageSparseOperationsPass>();
45606374400SPeiming Liu }
45706374400SPeiming Liu 
458f82bee13SPeiming Liu std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
459f82bee13SPeiming Liu   return std::make_unique<LowerSparseOpsToForeachPass>();
460f81f0cb7Sbixia1 }
461f81f0cb7Sbixia1 
462f81f0cb7Sbixia1 std::unique_ptr<Pass>
463f82bee13SPeiming Liu mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
464f82bee13SPeiming Liu   return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
465f82bee13SPeiming Liu }
466f82bee13SPeiming Liu 
467f82bee13SPeiming Liu std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
468f82bee13SPeiming Liu   return std::make_unique<LowerForeachToSCFPass>();
469f81f0cb7Sbixia1 }
470f81f0cb7Sbixia1 
471c42bbda4SPeiming Liu std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
472c42bbda4SPeiming Liu   return std::make_unique<LowerSparseIterationToSCFPass>();
473c42bbda4SPeiming Liu }
474c42bbda4SPeiming Liu 
475a2c9d4bbSAart Bik std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
476a2c9d4bbSAart Bik   return std::make_unique<SparseTensorConversionPass>();
477a2c9d4bbSAart Bik }
478c7e24db4Swren romano 
479d1da6f23SAart Bik std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
480d1da6f23SAart Bik   return std::make_unique<SparseTensorCodegenPass>();
481d1da6f23SAart Bik }
482d1da6f23SAart Bik 
4837276b643Sbixia1 std::unique_ptr<Pass>
484c44d307cSPeiming Liu mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
485c44d307cSPeiming Liu                                     bool enableBufferInitialization) {
486c44d307cSPeiming Liu   return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
487c44d307cSPeiming Liu                                                    enableBufferInitialization);
48886b22d31SAart Bik }
489062e515bSbixia1 
490d1da6f23SAart Bik std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
491d1da6f23SAart Bik   return std::make_unique<SparseBufferRewritePass>();
492d1da6f23SAart Bik }
493d1da6f23SAart Bik 
4945618d2beSbixia1 std::unique_ptr<Pass>
4955618d2beSbixia1 mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
4965618d2beSbixia1   return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
497062e515bSbixia1 }
49899b3849dSAart Bik 
49999b3849dSAart Bik std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
50099b3849dSAart Bik   return std::make_unique<SparseVectorizationPass>();
50199b3849dSAart Bik }
50299b3849dSAart Bik 
50399b3849dSAart Bik std::unique_ptr<Pass>
50499b3849dSAart Bik mlir::createSparseVectorizationPass(unsigned vectorLength,
50599b3849dSAart Bik                                     bool enableVLAVectorization,
50699b3849dSAart Bik                                     bool enableSIMDIndex32) {
50799b3849dSAart Bik   return std::make_unique<SparseVectorizationPass>(
50899b3849dSAart Bik       vectorLength, enableVLAVectorization, enableSIMDIndex32);
50999b3849dSAart Bik }
510083ddffeSPeiming Liu 
51119466ebcSAart Bik std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
51219466ebcSAart Bik   return std::make_unique<SparseGPUCodegenPass>();
51319466ebcSAart Bik }
51419466ebcSAart Bik 
5155f32bcfbSAart Bik std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
5165f32bcfbSAart Bik                                                        bool enableRT) {
5175f32bcfbSAart Bik   return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
51819466ebcSAart Bik }
51919466ebcSAart Bik 
520083ddffeSPeiming Liu std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
521083ddffeSPeiming Liu   return std::make_unique<StorageSpecifierToLLVMPass>();
522083ddffeSPeiming Liu }
523