xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp (revision eb6c4197d5263ed2e086925b2b2f032a19442d2b)
1b85ed4e0Swren romano //===- SparseTensorPipelines.cpp - Pipelines for sparse tensor code -------===//
2b85ed4e0Swren romano //
3b85ed4e0Swren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b85ed4e0Swren romano // See https://llvm.org/LICENSE.txt for license information.
5b85ed4e0Swren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b85ed4e0Swren romano //
7b85ed4e0Swren romano //===----------------------------------------------------------------------===//
8b85ed4e0Swren romano 
9b85ed4e0Swren romano #include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
10b85ed4e0Swren romano 
11b03a09e7SMatthias Springer #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12*eb6c4197SMatthias Springer #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1361c1ed86SAart Bik #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
14b85ed4e0Swren romano #include "mlir/Conversion/Passes.h"
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Transforms/Passes.h"
16c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
17c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
18b85ed4e0Swren romano #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1936550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
2061c1ed86SAart Bik #include "mlir/Dialect/GPU/IR/GPUDialect.h"
2161c1ed86SAart Bik #include "mlir/Dialect/GPU/Transforms/Passes.h"
2261c1ed86SAart Bik #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
23b85ed4e0Swren romano #include "mlir/Dialect/Linalg/Passes.h"
24786cbb09SQuentin Colombet #include "mlir/Dialect/MemRef/Transforms/Passes.h"
25b85ed4e0Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26b85ed4e0Swren romano #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
27b85ed4e0Swren romano #include "mlir/Pass/PassManager.h"
28eb65327fSPeiming Liu #include "mlir/Transforms/Passes.h"
29b85ed4e0Swren romano 
30b85ed4e0Swren romano //===----------------------------------------------------------------------===//
31b85ed4e0Swren romano // Pipeline implementation.
32b85ed4e0Swren romano //===----------------------------------------------------------------------===//
33b85ed4e0Swren romano 
34dce7a7cfSTim Harvey void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
35dce7a7cfSTim Harvey                                           const SparsifierOptions &options) {
3619498561SAart Bik   // Rewrite named linalg ops into generic ops and apply fusion.
371e98d488SQuinn Dawkins   pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
3819498561SAart Bik   pm.addNestedPass<func::FuncOp>(createLinalgElementwiseOpFusionPass());
395f32bcfbSAart Bik 
405f32bcfbSAart Bik   // Sparsification and bufferization mini-pipeline.
41c1fef4e8SMatthias Springer   pm.addPass(createSparsificationAndBufferizationPass(
428154494eSAart Bik       getBufferizationOptionsForSparsification(
438154494eSAart Bik           options.testBufferizationAnalysisOnly),
44f248d0b2SPeiming Liu       options.sparsificationOptions(), options.createSparseDeallocs,
45f248d0b2SPeiming Liu       options.enableRuntimeLibrary, options.enableBufferInitialization,
46f248d0b2SPeiming Liu       options.vectorLength,
4716aa4e4bSAart Bik       /*enableVLAVectorization=*/options.armSVE,
485f32bcfbSAart Bik       /*enableSIMDIndex32=*/options.force32BitVectorIndices,
491ba2768cSPeiming Liu       options.enableGPULibgen,
508f0c014bSYinying Li       options.sparsificationOptions().sparseEmitStrategy,
518f0c014bSYinying Li       options.sparsificationOptions().parallelizationStrategy));
525f32bcfbSAart Bik 
535f32bcfbSAart Bik   // Bail-early for test setup.
54c66303c2SMatthias Springer   if (options.testBufferizationAnalysisOnly)
55c66303c2SMatthias Springer     return;
5648a73bc4SMatthias Springer 
575f32bcfbSAart Bik   // Storage specifier lowering and bufferization wrap-up.
5848a73bc4SMatthias Springer   pm.addPass(createStorageSpecifierToLLVMPass());
59e5cb0ee3SPeiming Liu   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
6061c1ed86SAart Bik 
6161c1ed86SAart Bik   // GPU code generation.
6261c1ed86SAart Bik   const bool gpuCodegen = options.gpuTriple.hasValue();
6361c1ed86SAart Bik   if (gpuCodegen) {
6461c1ed86SAart Bik     pm.addPass(createSparseGPUCodegenPass());
6561c1ed86SAart Bik     pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
6661c1ed86SAart Bik     pm.addNestedPass<gpu::GPUModuleOp>(createConvertSCFToCFPass());
67119c489cSFabian Mora     pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps());
6861c1ed86SAart Bik   }
6961c1ed86SAart Bik 
704241e847SAart Bik   // Progressively lower to LLVM. Note that the convert-vector-to-llvm
714241e847SAart Bik   // pass is repeated on purpose.
72c66303c2SMatthias Springer   // TODO(springerm): Add sparse support to the BufferDeallocation pass and add
73c66303c2SMatthias Springer   // it to this pipeline.
7458ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
7558ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass());
768037deb7SMartin Erhart   pm.addNestedPass<func::FuncOp>(memref::createExpandReallocPass());
77039b969bSMichele Scuttari   pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
78786cbb09SQuentin Colombet   pm.addPass(memref::createExpandStridedMetadataPass());
79039b969bSMichele Scuttari   pm.addPass(createLowerAffinePass());
80cb9267f0SHugo Trachino   pm.addPass(
81cb9267f0SHugo Trachino       createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
82cb4ccd38SQuentin Colombet   pm.addPass(createFinalizeMemRefToLLVMConversionPass());
83736c1b66SAart Bik   pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
8461c1ed86SAart Bik   pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
8558ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
865b122a73SAart Bik   pm.addPass(createConvertMathToLibmPass());
87a9e354c8SAart Bik   pm.addPass(createConvertComplexToLibmPass());
88cb9267f0SHugo Trachino   pm.addPass(
89cb9267f0SHugo Trachino       createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
90a9e354c8SAart Bik   pm.addPass(createConvertComplexToLLVMPass());
91cb9267f0SHugo Trachino   pm.addPass(
92cb9267f0SHugo Trachino       createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
935a7b9194SRiver Riddle   pm.addPass(createConvertFuncToLLVMPass());
94b03a09e7SMatthias Springer   pm.addPass(createArithToLLVMConversionPass());
95*eb6c4197SMatthias Springer   pm.addPass(createConvertControlFlowToLLVMPass());
9661c1ed86SAart Bik 
9761c1ed86SAart Bik   // Finalize GPU code generation.
9861c1ed86SAart Bik   if (gpuCodegen) {
991828deb7SFabian Mora     GpuNVVMAttachTargetOptions nvvmTargetOptions;
1001828deb7SFabian Mora     nvvmTargetOptions.triple = options.gpuTriple;
1011828deb7SFabian Mora     nvvmTargetOptions.chip = options.gpuChip;
1021828deb7SFabian Mora     nvvmTargetOptions.features = options.gpuFeatures;
1031828deb7SFabian Mora     pm.addPass(createGpuNVVMAttachTarget(nvvmTargetOptions));
10461c1ed86SAart Bik     pm.addPass(createGpuToLLVMConversionPass());
1055093413aSFabian Mora     GpuModuleToBinaryPassOptions gpuModuleToBinaryPassOptions;
1065093413aSFabian Mora     gpuModuleToBinaryPassOptions.compilationTarget = options.gpuFormat;
1075093413aSFabian Mora     pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions));
10861c1ed86SAart Bik   }
10961c1ed86SAart Bik 
1104241e847SAart Bik   // Ensure all casts are realized.
111b85ed4e0Swren romano   pm.addPass(createReconcileUnrealizedCastsPass());
112b85ed4e0Swren romano }
113b85ed4e0Swren romano 
114b85ed4e0Swren romano //===----------------------------------------------------------------------===//
115b85ed4e0Swren romano // Pipeline registration.
116b85ed4e0Swren romano //===----------------------------------------------------------------------===//
117b85ed4e0Swren romano 
118b85ed4e0Swren romano void mlir::sparse_tensor::registerSparseTensorPipelines() {
119dce7a7cfSTim Harvey   PassPipelineRegistration<SparsifierOptions>(
120dce7a7cfSTim Harvey       "sparsifier",
121b85ed4e0Swren romano       "The standard pipeline for taking sparsity-agnostic IR using the"
122b85ed4e0Swren romano       " sparse-tensor type, and lowering it to LLVM IR with concrete"
123b85ed4e0Swren romano       " representations and algorithms for sparse tensors.",
124dce7a7cfSTim Harvey       buildSparsifier);
125b85ed4e0Swren romano }
126