//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//
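
// Most rewrite-based passes in this file share the same shape: populate a
// RewritePatternSet and drive it to a fixpoint with the greedy rewriter. A
// minimal sketch of that idiom (populateMyPatterns is a hypothetical helper,
// named here only for illustration):
//
//   void runOnOperation() override {
//     RewritePatternSet patterns(&getContext());
//     populateMyPatterns(patterns);
//     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
//   }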

struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseAssembler(patterns, directOut);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    scope = options.scope;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseReinterpretMap(patterns, scope);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization, sparseEmitStrategy,
                                  enableRuntimeLibrary);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};
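
// A hedged usage sketch: the options constructor above lets clients configure
// the pass programmatically. The enum values below are illustrative; consult
// Passes.td for the authoritative set.
//
//   SparsificationOptions options(SparseParallelizationStrategy::kNone,
//                                 SparseEmitStrategy::kFunctional,
//                                 /*enableRuntimeLibrary=*/true);
//   pm.addPass(createSparsificationPass(options));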

struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateStageSparseOperationsPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
                                            enableConvert);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateLowerForeachToSCFPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct LowerSparseIterationToSCFPass
    : public impl::LowerSparseIterationToSCFBase<
          LowerSparseIterationToSCFPass> {
  LowerSparseIterationToSCFPass() = default;
  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseIterationTypeConverter converter;
    ConversionTarget target(*ctx);

    // The actual conversion.
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           memref::MemRefDialect, scf::SCFDialect,
                           sparse_tensor::SparseTensorDialect>();
    target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
                        IterateOp>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    populateLowerSparseIterationToSCFPatterns(converter, patterns);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
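
// Like LowerSparseIterationToSCFPass above, the conversion passes below
// follow MLIR's dialect-conversion idiom rather than greedy rewriting: a type
// converter rewrites sparse tensor types, a ConversionTarget declares which
// ops must be converted, and applyPartialConversion leaves already-legal ops
// untouched. UnrealizedConversionCastOp stays legal so that materialization
// casts between converted and unconverted values can be resolved later.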

struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting,
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();

    // Populate with conversion rules and apply rewriting.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortOp>();
    target.addLegalOp<PushBackOp>();
    // The storage specifier outlives the sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp might be yielded after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting,
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with conversion rules and apply rewriting.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    // A vector length of zero is invalid for this pass; fail fast.
    if (vectorLength == 0)
      return signalPassFailure();
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};
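
// SparseGPUCodegenPass below dispatches on numThreads: zero selects the
// library-based path (patterns intended to target vendor libraries such as
// cuSPARSE through the GPU runtime), while a nonzero value generates custom
// GPU kernels with that many threads.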

struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (numThreads == 0)
      populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
    else
      populateSparseGPUCodegenPatterns(patterns, numThreads);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;

    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}

std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}

std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
  return std::make_unique<LowerSparseIterationToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}

std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
                                                       bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}
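
// Hedged pipeline sketch: a typical sparse compilation flow chains these
// entry points roughly as pre-rewriting -> sparsification -> lowering to
// loops -> codegen; the authoritative ordering lives in the sparsifier
// pipeline (SparseTensorPipelines). For illustration only:
//
//   pm.addPass(createPreSparsificationRewritePass());
//   pm.addPass(createSparsificationPass());
//   pm.addPass(createLowerSparseOpsToForeachPass());
//   pm.addPass(createLowerForeachToSCFPass());
//   pm.addPass(createSparseTensorCodegenPass());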