1221856f5Swren romano //===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===// 2a2c9d4bbSAart Bik // 3a2c9d4bbSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a2c9d4bbSAart Bik // See https://llvm.org/LICENSE.txt for license information. 5a2c9d4bbSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a2c9d4bbSAart Bik // 7a2c9d4bbSAart Bik //===----------------------------------------------------------------------===// 8a2c9d4bbSAart Bik 9eda6f907SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h" 10abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 1228b6d412SAart Bik #include "mlir/Dialect/Complex/IR/Complex.h" 131f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 1519466ebcSAart Bik #include "mlir/Dialect/GPU/IR/GPUDialect.h" 16a2c9d4bbSAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17a2c9d4bbSAart Bik #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 184a6b31b8SAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h" 19a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 20a2c9d4bbSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 21eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h" 22a2c9d4bbSAart Bik #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23a2c9d4bbSAart Bik 2467d0d7acSMichele Scuttari namespace mlir { 2533b463adSAart Bik #define GEN_PASS_DEF_SPARSEASSEMBLER 267cfac1beSAart Bik #define GEN_PASS_DEF_SPARSEREINTERPRETMAP 27f81f0cb7Sbixia1 #define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE 2867d0d7acSMichele Scuttari #define GEN_PASS_DEF_SPARSIFICATIONPASS 29c42bbda4SPeiming Liu #define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF 30f82bee13SPeiming Liu #define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH 31f82bee13SPeiming Liu #define GEN_PASS_DEF_LOWERFOREACHTOSCF 3267d0d7acSMichele Scuttari #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS 3367d0d7acSMichele Scuttari #define GEN_PASS_DEF_SPARSETENSORCODEGEN 34062e515bSbixia1 #define GEN_PASS_DEF_SPARSEBUFFERREWRITE 3599b3849dSAart Bik #define GEN_PASS_DEF_SPARSEVECTORIZATION 3619466ebcSAart Bik #define GEN_PASS_DEF_SPARSEGPUCODEGEN 3706374400SPeiming Liu #define GEN_PASS_DEF_STAGESPARSEOPERATIONS 38083ddffeSPeiming Liu #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM 3967d0d7acSMichele Scuttari #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 4067d0d7acSMichele Scuttari } // namespace mlir 4167d0d7acSMichele Scuttari 42a2c9d4bbSAart Bik using namespace mlir; 4396a23911SAart Bik using namespace mlir::sparse_tensor; 44a2c9d4bbSAart Bik 45a2c9d4bbSAart Bik namespace { 46a2c9d4bbSAart Bik 47a2c9d4bbSAart Bik //===----------------------------------------------------------------------===// 48a2c9d4bbSAart Bik // Passes implementation. 49a2c9d4bbSAart Bik //===----------------------------------------------------------------------===// 50a2c9d4bbSAart Bik 5133b463adSAart Bik struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> { 5233b463adSAart Bik SparseAssembler() = default; 5333b463adSAart Bik SparseAssembler(const SparseAssembler &pass) = default; 545122a2c2SAart Bik SparseAssembler(bool dO) { directOut = dO; } 5533b463adSAart Bik 5633b463adSAart Bik void runOnOperation() override { 5733b463adSAart Bik auto *ctx = &getContext(); 5833b463adSAart Bik RewritePatternSet patterns(ctx); 595122a2c2SAart Bik populateSparseAssembler(patterns, directOut); 6009dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 6133b463adSAart Bik } 6233b463adSAart Bik }; 6333b463adSAart Bik 647cfac1beSAart Bik struct SparseReinterpretMap 657cfac1beSAart Bik : public impl::SparseReinterpretMapBase<SparseReinterpretMap> { 667cfac1beSAart Bik SparseReinterpretMap() = default; 677cfac1beSAart Bik SparseReinterpretMap(const SparseReinterpretMap &pass) = default; 686a93da99SPeiming Liu SparseReinterpretMap(const SparseReinterpretMapOptions &options) { 696a93da99SPeiming Liu scope = options.scope; 706a93da99SPeiming Liu } 717cfac1beSAart Bik 727cfac1beSAart Bik void runOnOperation() override { 737cfac1beSAart Bik auto *ctx = &getContext(); 747cfac1beSAart Bik RewritePatternSet patterns(ctx); 756a93da99SPeiming Liu populateSparseReinterpretMap(patterns, scope); 7609dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 777cfac1beSAart Bik } 787cfac1beSAart Bik }; 797cfac1beSAart Bik 80f81f0cb7Sbixia1 struct PreSparsificationRewritePass 81f81f0cb7Sbixia1 : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> { 82f81f0cb7Sbixia1 PreSparsificationRewritePass() = default; 83f81f0cb7Sbixia1 PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) = 84f81f0cb7Sbixia1 default; 85779dcd2eSAart Bik 86779dcd2eSAart Bik void runOnOperation() override { 87779dcd2eSAart Bik auto *ctx = &getContext(); 88779dcd2eSAart Bik RewritePatternSet patterns(ctx); 89f81f0cb7Sbixia1 populatePreSparsificationRewriting(patterns); 9009dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 91779dcd2eSAart Bik } 92779dcd2eSAart Bik }; 93779dcd2eSAart Bik 9467d0d7acSMichele Scuttari struct SparsificationPass 9567d0d7acSMichele Scuttari : public impl::SparsificationPassBase<SparsificationPass> { 96a2c9d4bbSAart Bik SparsificationPass() = default; 97abb336d2SMehdi Amini SparsificationPass(const SparsificationPass &pass) = default; 98b85ed4e0Swren romano SparsificationPass(const SparsificationOptions &options) { 9930ceb783SNick Kreeger parallelization = options.parallelizationStrategy; 1004a653b4dSPeiming Liu sparseEmitStrategy = options.sparseEmitStrategy; 101ee42e236SAart Bik enableRuntimeLibrary = options.enableRuntimeLibrary; 102a2c9d4bbSAart Bik } 103a2c9d4bbSAart Bik 104a2c9d4bbSAart Bik void runOnOperation() override { 105a2c9d4bbSAart Bik auto *ctx = &getContext(); 106a2c9d4bbSAart Bik // Translate strategy flags to strategy options. 1074a653b4dSPeiming Liu SparsificationOptions options(parallelization, sparseEmitStrategy, 1084a653b4dSPeiming Liu enableRuntimeLibrary); 1095f32bcfbSAart Bik // Apply sparsification and cleanup rewriting. 11028ebb0b6SAart Bik RewritePatternSet patterns(ctx); 111a2c9d4bbSAart Bik populateSparsificationPatterns(patterns, options); 11230b550f1SAart Bik scf::ForOp::getCanonicalizationPatterns(patterns, ctx); 11309dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 114a2c9d4bbSAart Bik } 115a2c9d4bbSAart Bik }; 116a2c9d4bbSAart Bik 11706374400SPeiming Liu struct StageSparseOperationsPass 11806374400SPeiming Liu : public impl::StageSparseOperationsBase<StageSparseOperationsPass> { 11906374400SPeiming Liu StageSparseOperationsPass() = default; 12006374400SPeiming Liu StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default; 12106374400SPeiming Liu void runOnOperation() override { 12206374400SPeiming Liu auto *ctx = &getContext(); 12306374400SPeiming Liu RewritePatternSet patterns(ctx); 12406374400SPeiming Liu populateStageSparseOperationsPatterns(patterns); 12509dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 12606374400SPeiming Liu } 12706374400SPeiming Liu }; 12806374400SPeiming Liu 129f82bee13SPeiming Liu struct LowerSparseOpsToForeachPass 130f82bee13SPeiming Liu : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> { 131f82bee13SPeiming Liu LowerSparseOpsToForeachPass() = default; 132f82bee13SPeiming Liu LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) = 133f81f0cb7Sbixia1 default; 134f82bee13SPeiming Liu LowerSparseOpsToForeachPass(bool enableRT, bool convert) { 135f81f0cb7Sbixia1 enableRuntimeLibrary = enableRT; 136f81f0cb7Sbixia1 enableConvert = convert; 137f81f0cb7Sbixia1 } 138f81f0cb7Sbixia1 139f81f0cb7Sbixia1 void runOnOperation() override { 140f81f0cb7Sbixia1 auto *ctx = &getContext(); 141f81f0cb7Sbixia1 RewritePatternSet patterns(ctx); 142f82bee13SPeiming Liu populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary, 143f82bee13SPeiming Liu enableConvert); 14409dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 145f82bee13SPeiming Liu } 146f82bee13SPeiming Liu }; 147f82bee13SPeiming Liu 148f82bee13SPeiming Liu struct LowerForeachToSCFPass 149f82bee13SPeiming Liu : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> { 150f82bee13SPeiming Liu LowerForeachToSCFPass() = default; 151f82bee13SPeiming Liu LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default; 152f82bee13SPeiming Liu 153f82bee13SPeiming Liu void runOnOperation() override { 154f82bee13SPeiming Liu auto *ctx = &getContext(); 155f82bee13SPeiming Liu RewritePatternSet patterns(ctx); 156f82bee13SPeiming Liu populateLowerForeachToSCFPatterns(patterns); 15709dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 158f81f0cb7Sbixia1 } 159f81f0cb7Sbixia1 }; 160f81f0cb7Sbixia1 161c42bbda4SPeiming Liu struct LowerSparseIterationToSCFPass 162c42bbda4SPeiming Liu : public impl::LowerSparseIterationToSCFBase< 163c42bbda4SPeiming Liu LowerSparseIterationToSCFPass> { 164c42bbda4SPeiming Liu LowerSparseIterationToSCFPass() = default; 165c42bbda4SPeiming Liu LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) = 166c42bbda4SPeiming Liu default; 167c42bbda4SPeiming Liu 168c42bbda4SPeiming Liu void runOnOperation() override { 169c42bbda4SPeiming Liu auto *ctx = &getContext(); 170c42bbda4SPeiming Liu RewritePatternSet patterns(ctx); 171c42bbda4SPeiming Liu SparseIterationTypeConverter converter; 172c42bbda4SPeiming Liu ConversionTarget target(*ctx); 173c42bbda4SPeiming Liu 174c42bbda4SPeiming Liu // The actual conversion. 175*2b5b3cf6SMatthias Springer target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect, 176*2b5b3cf6SMatthias Springer memref::MemRefDialect, scf::SCFDialect, 177*2b5b3cf6SMatthias Springer sparse_tensor::SparseTensorDialect>(); 178*2b5b3cf6SMatthias Springer target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp, 179*2b5b3cf6SMatthias Springer IterateOp>(); 180*2b5b3cf6SMatthias Springer target.addLegalOp<UnrealizedConversionCastOp>(); 181c42bbda4SPeiming Liu populateLowerSparseIterationToSCFPatterns(converter, patterns); 182c42bbda4SPeiming Liu 183*2b5b3cf6SMatthias Springer if (failed(applyPartialConversion(getOperation(), target, 184c42bbda4SPeiming Liu std::move(patterns)))) 185c42bbda4SPeiming Liu signalPassFailure(); 186c42bbda4SPeiming Liu } 187c42bbda4SPeiming Liu }; 188c42bbda4SPeiming Liu 189a2c9d4bbSAart Bik struct SparseTensorConversionPass 19067d0d7acSMichele Scuttari : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> { 191c7e24db4Swren romano SparseTensorConversionPass() = default; 192c7e24db4Swren romano SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default; 193c7e24db4Swren romano 194a2c9d4bbSAart Bik void runOnOperation() override { 195a2c9d4bbSAart Bik auto *ctx = &getContext(); 19696a23911SAart Bik RewritePatternSet patterns(ctx); 19786b22d31SAart Bik SparseTensorTypeToPtrConverter converter; 198a2c9d4bbSAart Bik ConversionTarget target(*ctx); 1991b15160eSAart Bik // Everything in the sparse dialect must go! 2001b15160eSAart Bik target.addIllegalDialect<SparseTensorDialect>(); 201fde04aeeSAart Bik // All dynamic rules below accept new function, call, return, and various 202fde04aeeSAart Bik // tensor and bufferization operations as legal output of the rewriting 203fde04aeeSAart Bik // provided that all sparse tensor types have been fully rewritten. 20458ceae95SRiver Riddle target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 2054a3460a7SRiver Riddle return converter.isSignatureLegal(op.getFunctionType()); 2064a3460a7SRiver Riddle }); 20723aa5a74SRiver Riddle target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 20896a23911SAart Bik return converter.isSignatureLegal(op.getCalleeType()); 20996a23911SAart Bik }); 21023aa5a74SRiver Riddle target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 21123aa5a74SRiver Riddle return converter.isLegal(op.getOperandTypes()); 21223aa5a74SRiver Riddle }); 213236a9080SAart Bik target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) { 214236a9080SAart Bik return converter.isLegal(op.getOperandTypes()); 215236a9080SAart Bik }); 2161b15160eSAart Bik target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) { 217136d746eSJacques Pienaar return converter.isLegal(op.getSource().getType()) && 218136d746eSJacques Pienaar converter.isLegal(op.getDest().getType()); 2196d8e2f1eSAart Bik }); 2206d8e2f1eSAart Bik target.addDynamicallyLegalOp<tensor::ExpandShapeOp>( 2216d8e2f1eSAart Bik [&](tensor::ExpandShapeOp op) { 222136d746eSJacques Pienaar return converter.isLegal(op.getSrc().getType()) && 223136d746eSJacques Pienaar converter.isLegal(op.getResult().getType()); 2246d8e2f1eSAart Bik }); 2256d8e2f1eSAart Bik target.addDynamicallyLegalOp<tensor::CollapseShapeOp>( 2266d8e2f1eSAart Bik [&](tensor::CollapseShapeOp op) { 227136d746eSJacques Pienaar return converter.isLegal(op.getSrc().getType()) && 228136d746eSJacques Pienaar converter.isLegal(op.getResult().getType()); 2291b15160eSAart Bik }); 230fde04aeeSAart Bik target.addDynamicallyLegalOp<bufferization::AllocTensorOp>( 231fde04aeeSAart Bik [&](bufferization::AllocTensorOp op) { 232fde04aeeSAart Bik return converter.isLegal(op.getType()); 233fde04aeeSAart Bik }); 23427a431f5SMatthias Springer target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>( 23527a431f5SMatthias Springer [&](bufferization::DeallocTensorOp op) { 23627a431f5SMatthias Springer return converter.isLegal(op.getTensor().getType()); 23727a431f5SMatthias Springer }); 238236a9080SAart Bik // The following operations and dialects may be introduced by the 239236a9080SAart Bik // rewriting rules, and are therefore marked as legal. 2402ddfacd9SAart Bik target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp, 241af8428c0SAart Bik linalg::YieldOp, tensor::ExtractOp, 242af8428c0SAart Bik tensor::FromElementsOp>(); 243faa00c13SAart Bik target.addLegalDialect< 244abc362a1SJakub Kuderski arith::ArithDialect, bufferization::BufferizationDialect, 245faa00c13SAart Bik LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>(); 246f248d0b2SPeiming Liu 247236a9080SAart Bik // Populate with rules and apply rewriting rules. 24858ceae95SRiver Riddle populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 2497ceffae1SRiver Riddle converter); 25096a23911SAart Bik populateCallOpTypeConversionPattern(patterns, converter); 251bf59cd32SPeiming Liu scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, 252bf59cd32SPeiming Liu target); 253f248d0b2SPeiming Liu populateSparseTensorConversionPatterns(converter, patterns); 254a2c9d4bbSAart Bik if (failed(applyPartialConversion(getOperation(), target, 25596a23911SAart Bik std::move(patterns)))) 256a2c9d4bbSAart Bik signalPassFailure(); 257a2c9d4bbSAart Bik } 258a2c9d4bbSAart Bik }; 259a2c9d4bbSAart Bik 26086b22d31SAart Bik struct SparseTensorCodegenPass 26167d0d7acSMichele Scuttari : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> { 26286b22d31SAart Bik SparseTensorCodegenPass() = default; 26386b22d31SAart Bik SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default; 264c44d307cSPeiming Liu SparseTensorCodegenPass(bool createDeallocs, bool enableInit) { 265c44d307cSPeiming Liu createSparseDeallocs = createDeallocs; 2667276b643Sbixia1 enableBufferInitialization = enableInit; 2677276b643Sbixia1 } 26886b22d31SAart Bik 26986b22d31SAart Bik void runOnOperation() override { 27086b22d31SAart Bik auto *ctx = &getContext(); 27186b22d31SAart Bik RewritePatternSet patterns(ctx); 27286b22d31SAart Bik SparseTensorTypeToBufferConverter converter; 27386b22d31SAart Bik ConversionTarget target(*ctx); 2744d068619SAart Bik // Most ops in the sparse dialect must go! 27586b22d31SAart Bik target.addIllegalDialect<SparseTensorDialect>(); 2760083f833SPeiming Liu target.addLegalOp<SortOp>(); 277654bbbdeSbixia1 target.addLegalOp<PushBackOp>(); 278083ddffeSPeiming Liu // Storage specifier outlives sparse tensor pipeline. 279083ddffeSPeiming Liu target.addLegalOp<GetStorageSpecifierOp>(); 280083ddffeSPeiming Liu target.addLegalOp<SetStorageSpecifierOp>(); 281083ddffeSPeiming Liu target.addLegalOp<StorageSpecifierInitOp>(); 282098f46dcSPeiming Liu // Note that tensor::FromElementsOp might be yield after lowering unpack. 283098f46dcSPeiming Liu target.addLegalOp<tensor::FromElementsOp>(); 284083ddffeSPeiming Liu // All dynamic rules below accept new function, call, return, and 285083ddffeSPeiming Liu // various tensor and bufferization operations as legal output of the 286083ddffeSPeiming Liu // rewriting provided that all sparse tensor types have been fully 287083ddffeSPeiming Liu // rewritten. 28886b22d31SAart Bik target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 28986b22d31SAart Bik return converter.isSignatureLegal(op.getFunctionType()); 29086b22d31SAart Bik }); 29186b22d31SAart Bik target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 29286b22d31SAart Bik return converter.isSignatureLegal(op.getCalleeType()); 29386b22d31SAart Bik }); 29486b22d31SAart Bik target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 29586b22d31SAart Bik return converter.isLegal(op.getOperandTypes()); 29686b22d31SAart Bik }); 2970c7abd39SAart Bik target.addDynamicallyLegalOp<bufferization::AllocTensorOp>( 2980c7abd39SAart Bik [&](bufferization::AllocTensorOp op) { 2990c7abd39SAart Bik return converter.isLegal(op.getType()); 3000c7abd39SAart Bik }); 3012ddfacd9SAart Bik target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>( 3022ddfacd9SAart Bik [&](bufferization::DeallocTensorOp op) { 3032ddfacd9SAart Bik return converter.isLegal(op.getTensor().getType()); 3042ddfacd9SAart Bik }); 3058a583bd5Sbixia1 // The following operations and dialects may be introduced by the 3068a583bd5Sbixia1 // codegen rules, and are therefore marked as legal. 307a4c47055SMatthias Springer target.addLegalOp<linalg::FillOp, linalg::YieldOp>(); 308ea4be70cSbixia1 target.addLegalDialect< 309ea4be70cSbixia1 arith::ArithDialect, bufferization::BufferizationDialect, 310ea4be70cSbixia1 complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>(); 311928b5b06SPeiming Liu target.addLegalOp<UnrealizedConversionCastOp>(); 312ca01c996SPeiming Liu // Populate with rules and apply rewriting rules. 313ca01c996SPeiming Liu populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 314ca01c996SPeiming Liu converter); 315ca01c996SPeiming Liu scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, 316ca01c996SPeiming Liu target); 317c44d307cSPeiming Liu populateSparseTensorCodegenPatterns( 318c44d307cSPeiming Liu converter, patterns, createSparseDeallocs, enableBufferInitialization); 319ca01c996SPeiming Liu if (failed(applyPartialConversion(getOperation(), target, 320ca01c996SPeiming Liu std::move(patterns)))) 321ca01c996SPeiming Liu signalPassFailure(); 322ca01c996SPeiming Liu } 323ca01c996SPeiming Liu }; 324ca01c996SPeiming Liu 325062e515bSbixia1 struct SparseBufferRewritePass 326062e515bSbixia1 : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> { 327062e515bSbixia1 SparseBufferRewritePass() = default; 328062e515bSbixia1 SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default; 3295618d2beSbixia1 SparseBufferRewritePass(bool enableInit) { 3305618d2beSbixia1 enableBufferInitialization = enableInit; 3315618d2beSbixia1 } 332062e515bSbixia1 333062e515bSbixia1 void runOnOperation() override { 334062e515bSbixia1 auto *ctx = &getContext(); 335062e515bSbixia1 RewritePatternSet patterns(ctx); 3365618d2beSbixia1 populateSparseBufferRewriting(patterns, enableBufferInitialization); 33709dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 338062e515bSbixia1 } 339062e515bSbixia1 }; 340062e515bSbixia1 34199b3849dSAart Bik struct SparseVectorizationPass 34299b3849dSAart Bik : public impl::SparseVectorizationBase<SparseVectorizationPass> { 34399b3849dSAart Bik SparseVectorizationPass() = default; 34499b3849dSAart Bik SparseVectorizationPass(const SparseVectorizationPass &pass) = default; 34599b3849dSAart Bik SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) { 34699b3849dSAart Bik vectorLength = vl; 34799b3849dSAart Bik enableVLAVectorization = vla; 34899b3849dSAart Bik enableSIMDIndex32 = sidx32; 34999b3849dSAart Bik } 35099b3849dSAart Bik 35199b3849dSAart Bik void runOnOperation() override { 35214c0317fSAart Bik if (vectorLength == 0) 35314c0317fSAart Bik return signalPassFailure(); 35499b3849dSAart Bik auto *ctx = &getContext(); 35599b3849dSAart Bik RewritePatternSet patterns(ctx); 35699b3849dSAart Bik populateSparseVectorizationPatterns( 35799b3849dSAart Bik patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32); 35899b3849dSAart Bik vector::populateVectorToVectorCanonicalizationPatterns(patterns); 35909dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 36099b3849dSAart Bik } 36199b3849dSAart Bik }; 36299b3849dSAart Bik 36319466ebcSAart Bik struct SparseGPUCodegenPass 36419466ebcSAart Bik : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> { 36519466ebcSAart Bik SparseGPUCodegenPass() = default; 36619466ebcSAart Bik SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default; 3675f32bcfbSAart Bik SparseGPUCodegenPass(unsigned nT, bool enableRT) { 3685f32bcfbSAart Bik numThreads = nT; 3695f32bcfbSAart Bik enableRuntimeLibrary = enableRT; 3705f32bcfbSAart Bik } 37119466ebcSAart Bik 37219466ebcSAart Bik void runOnOperation() override { 37319466ebcSAart Bik auto *ctx = &getContext(); 37419466ebcSAart Bik RewritePatternSet patterns(ctx); 3755f32bcfbSAart Bik if (numThreads == 0) 3765f32bcfbSAart Bik populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary); 3775f32bcfbSAart Bik else 37819466ebcSAart Bik populateSparseGPUCodegenPatterns(patterns, numThreads); 37909dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 38019466ebcSAart Bik } 38119466ebcSAart Bik }; 38219466ebcSAart Bik 383083ddffeSPeiming Liu struct StorageSpecifierToLLVMPass 384083ddffeSPeiming Liu : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> { 385083ddffeSPeiming Liu StorageSpecifierToLLVMPass() = default; 386083ddffeSPeiming Liu 387083ddffeSPeiming Liu void runOnOperation() override { 388083ddffeSPeiming Liu auto *ctx = &getContext(); 389083ddffeSPeiming Liu ConversionTarget target(*ctx); 390083ddffeSPeiming Liu RewritePatternSet patterns(ctx); 391083ddffeSPeiming Liu StorageSpecifierToLLVMTypeConverter converter; 392083ddffeSPeiming Liu 393083ddffeSPeiming Liu // All ops in the sparse dialect must go! 394083ddffeSPeiming Liu target.addIllegalDialect<SparseTensorDialect>(); 395083ddffeSPeiming Liu target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 396083ddffeSPeiming Liu return converter.isSignatureLegal(op.getFunctionType()); 397083ddffeSPeiming Liu }); 398083ddffeSPeiming Liu target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 399083ddffeSPeiming Liu return converter.isSignatureLegal(op.getCalleeType()); 400083ddffeSPeiming Liu }); 401083ddffeSPeiming Liu target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 402083ddffeSPeiming Liu return converter.isLegal(op.getOperandTypes()); 403083ddffeSPeiming Liu }); 404083ddffeSPeiming Liu target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>(); 405083ddffeSPeiming Liu 406083ddffeSPeiming Liu populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 407083ddffeSPeiming Liu converter); 408083ddffeSPeiming Liu populateCallOpTypeConversionPattern(patterns, converter); 409083ddffeSPeiming Liu populateBranchOpInterfaceTypeConversionPattern(patterns, converter); 410083ddffeSPeiming Liu populateReturnOpTypeConversionPattern(patterns, converter); 411083ddffeSPeiming Liu scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, 412083ddffeSPeiming Liu target); 413083ddffeSPeiming Liu populateStorageSpecifierToLLVMPatterns(converter, patterns); 414083ddffeSPeiming Liu if (failed(applyPartialConversion(getOperation(), target, 415083ddffeSPeiming Liu std::move(patterns)))) 416083ddffeSPeiming Liu signalPassFailure(); 417083ddffeSPeiming Liu } 418083ddffeSPeiming Liu }; 419083ddffeSPeiming Liu 420be0a7e9fSMehdi Amini } // namespace 421a2c9d4bbSAart Bik 42286b22d31SAart Bik //===----------------------------------------------------------------------===// 42386b22d31SAart Bik // Pass creation methods. 42486b22d31SAart Bik //===----------------------------------------------------------------------===// 42586b22d31SAart Bik 42633b463adSAart Bik std::unique_ptr<Pass> mlir::createSparseAssembler() { 42733b463adSAart Bik return std::make_unique<SparseAssembler>(); 42833b463adSAart Bik } 42933b463adSAart Bik 4307cfac1beSAart Bik std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() { 4317cfac1beSAart Bik return std::make_unique<SparseReinterpretMap>(); 4327cfac1beSAart Bik } 4337cfac1beSAart Bik 4346a93da99SPeiming Liu std::unique_ptr<Pass> 4356a93da99SPeiming Liu mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) { 4366a93da99SPeiming Liu SparseReinterpretMapOptions options; 4376a93da99SPeiming Liu options.scope = scope; 4386a93da99SPeiming Liu return std::make_unique<SparseReinterpretMap>(options); 4396a93da99SPeiming Liu } 4406a93da99SPeiming Liu 441f81f0cb7Sbixia1 std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() { 442f81f0cb7Sbixia1 return std::make_unique<PreSparsificationRewritePass>(); 443779dcd2eSAart Bik } 444779dcd2eSAart Bik 445a2c9d4bbSAart Bik std::unique_ptr<Pass> mlir::createSparsificationPass() { 446a2c9d4bbSAart Bik return std::make_unique<SparsificationPass>(); 447a2c9d4bbSAart Bik } 448a2c9d4bbSAart Bik 449b85ed4e0Swren romano std::unique_ptr<Pass> 450b85ed4e0Swren romano mlir::createSparsificationPass(const SparsificationOptions &options) { 451b85ed4e0Swren romano return std::make_unique<SparsificationPass>(options); 452b85ed4e0Swren romano } 453b85ed4e0Swren romano 45406374400SPeiming Liu std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() { 45506374400SPeiming Liu return std::make_unique<StageSparseOperationsPass>(); 45606374400SPeiming Liu } 45706374400SPeiming Liu 458f82bee13SPeiming Liu std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() { 459f82bee13SPeiming Liu return std::make_unique<LowerSparseOpsToForeachPass>(); 460f81f0cb7Sbixia1 } 461f81f0cb7Sbixia1 462f81f0cb7Sbixia1 std::unique_ptr<Pass> 463f82bee13SPeiming Liu mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) { 464f82bee13SPeiming Liu return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert); 465f82bee13SPeiming Liu } 466f82bee13SPeiming Liu 467f82bee13SPeiming Liu std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() { 468f82bee13SPeiming Liu return std::make_unique<LowerForeachToSCFPass>(); 469f81f0cb7Sbixia1 } 470f81f0cb7Sbixia1 471c42bbda4SPeiming Liu std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() { 472c42bbda4SPeiming Liu return std::make_unique<LowerSparseIterationToSCFPass>(); 473c42bbda4SPeiming Liu } 474c42bbda4SPeiming Liu 475a2c9d4bbSAart Bik std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() { 476a2c9d4bbSAart Bik return std::make_unique<SparseTensorConversionPass>(); 477a2c9d4bbSAart Bik } 478c7e24db4Swren romano 479d1da6f23SAart Bik std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() { 480d1da6f23SAart Bik return std::make_unique<SparseTensorCodegenPass>(); 481d1da6f23SAart Bik } 482d1da6f23SAart Bik 4837276b643Sbixia1 std::unique_ptr<Pass> 484c44d307cSPeiming Liu mlir::createSparseTensorCodegenPass(bool createSparseDeallocs, 485c44d307cSPeiming Liu bool enableBufferInitialization) { 486c44d307cSPeiming Liu return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs, 487c44d307cSPeiming Liu enableBufferInitialization); 48886b22d31SAart Bik } 489062e515bSbixia1 490d1da6f23SAart Bik std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() { 491d1da6f23SAart Bik return std::make_unique<SparseBufferRewritePass>(); 492d1da6f23SAart Bik } 493d1da6f23SAart Bik 4945618d2beSbixia1 std::unique_ptr<Pass> 4955618d2beSbixia1 mlir::createSparseBufferRewritePass(bool enableBufferInitialization) { 4965618d2beSbixia1 return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization); 497062e515bSbixia1 } 49899b3849dSAart Bik 49999b3849dSAart Bik std::unique_ptr<Pass> mlir::createSparseVectorizationPass() { 50099b3849dSAart Bik return std::make_unique<SparseVectorizationPass>(); 50199b3849dSAart Bik } 50299b3849dSAart Bik 50399b3849dSAart Bik std::unique_ptr<Pass> 50499b3849dSAart Bik mlir::createSparseVectorizationPass(unsigned vectorLength, 50599b3849dSAart Bik bool enableVLAVectorization, 50699b3849dSAart Bik bool enableSIMDIndex32) { 50799b3849dSAart Bik return std::make_unique<SparseVectorizationPass>( 50899b3849dSAart Bik vectorLength, enableVLAVectorization, enableSIMDIndex32); 50999b3849dSAart Bik } 510083ddffeSPeiming Liu 51119466ebcSAart Bik std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() { 51219466ebcSAart Bik return std::make_unique<SparseGPUCodegenPass>(); 51319466ebcSAart Bik } 51419466ebcSAart Bik 5155f32bcfbSAart Bik std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads, 5165f32bcfbSAart Bik bool enableRT) { 5175f32bcfbSAart Bik return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT); 51819466ebcSAart Bik } 51919466ebcSAart Bik 520083ddffeSPeiming Liu std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() { 521083ddffeSPeiming Liu return std::make_unique<StorageSpecifierToLLVMPass>(); 522083ddffeSPeiming Liu } 523