xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (revision 8f0c014b12663129d8bfe0cc89f06e7a1d8b48c2)
1c1fef4e8SMatthias Springer //===- SparsificationAndBufferizationPass.cpp - Tensor to Memref Lowering -===//
2c1fef4e8SMatthias Springer //
3c1fef4e8SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c1fef4e8SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5c1fef4e8SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c1fef4e8SMatthias Springer //
7c1fef4e8SMatthias Springer //===----------------------------------------------------------------------===//
8c1fef4e8SMatthias Springer 
9c1fef4e8SMatthias Springer #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
10c1fef4e8SMatthias Springer 
118154494eSAart Bik #include "mlir/Dialect/Affine/IR/AffineOps.h"
12c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13a0568eabSbixia1 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
16c7a9e5e5SPeiming Liu #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
17be630f07SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
18c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
19c1fef4e8SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h"
20ee42e236SAart Bik #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21006340baSPeiming Liu #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
228154494eSAart Bik #include "mlir/Dialect/Linalg/IR/Linalg.h"
238154494eSAart Bik #include "mlir/Dialect/MemRef/IR/MemRef.h"
248154494eSAart Bik #include "mlir/Dialect/SCF/IR/SCF.h"
25c1fef4e8SMatthias Springer #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26c1fef4e8SMatthias Springer #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
279f3334e9SMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
28c1fef4e8SMatthias Springer #include "mlir/Pass/PassManager.h"
2916aa4e4bSAart Bik #include "mlir/Transforms/Passes.h"
30c1fef4e8SMatthias Springer 
31c1fef4e8SMatthias Springer using namespace mlir;
32c1fef4e8SMatthias Springer 
33c1fef4e8SMatthias Springer namespace mlir {
348154494eSAart Bik 
358154494eSAart Bik #define GEN_PASS_DEF_SPARSIFICATIONANDBUFFERIZATION
368154494eSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
378154494eSAart Bik 
38c1fef4e8SMatthias Springer namespace sparse_tensor {
39c1fef4e8SMatthias Springer 
40c1fef4e8SMatthias Springer /// Return `true` if one of the given types is a sparse tensor type.
41c1fef4e8SMatthias Springer static bool containsSparseTensor(TypeRange types) {
42c1fef4e8SMatthias Springer   for (Type t : types)
4348a73bc4SMatthias Springer     if (isa<TensorType>(t) && getSparseTensorEncoding(t))
44c1fef4e8SMatthias Springer       return true;
45c1fef4e8SMatthias Springer   return false;
46c1fef4e8SMatthias Springer }
47c1fef4e8SMatthias Springer 
48c1fef4e8SMatthias Springer /// A pass that lowers tensor ops to memref ops, regardless of whether they are
49c1fef4e8SMatthias Springer /// dense or sparse.
50c1fef4e8SMatthias Springer ///
51c1fef4e8SMatthias Springer /// One-Shot Analysis is used to detect RaW conflicts and to insert buffer
52c1fef4e8SMatthias Springer /// copies of the tensor level (`insertTensorCopies`). Afterwards, the lowering
53c1fef4e8SMatthias Springer /// of tensor ops to memref ops follows a different code path depending on
54c1fef4e8SMatthias Springer /// whether the op is sparse or dense:
55c1fef4e8SMatthias Springer ///
56c1fef4e8SMatthias Springer /// * Sparse tensor ops are lowered through Sparsification and follow-up pass
57c1fef4e8SMatthias Springer ///   that lowers sparse_tensor dialect ops.
58c1fef4e8SMatthias Springer /// * Dense tensor ops are lowered through BufferizableOpInterface
59c1fef4e8SMatthias Springer ///   implementations.
60c1fef4e8SMatthias Springer class SparsificationAndBufferizationPass
618154494eSAart Bik     : public impl::SparsificationAndBufferizationBase<
628154494eSAart Bik           SparsificationAndBufferizationPass> {
63c1fef4e8SMatthias Springer public:
64438a7d4cSAart Bik   // Private pass options only.
65c1fef4e8SMatthias Springer   SparsificationAndBufferizationPass(
66c1fef4e8SMatthias Springer       const bufferization::OneShotBufferizationOptions &bufferizationOptions,
67c1fef4e8SMatthias Springer       const SparsificationOptions &sparsificationOptions,
68c44d307cSPeiming Liu       bool createSparseDeallocs, bool enableRuntimeLibrary,
69438a7d4cSAart Bik       bool enableBufferInitialization)
70c1fef4e8SMatthias Springer       : bufferizationOptions(bufferizationOptions),
71c1fef4e8SMatthias Springer         sparsificationOptions(sparsificationOptions),
72c44d307cSPeiming Liu         createSparseDeallocs(createSparseDeallocs),
73c1fef4e8SMatthias Springer         enableRuntimeLibrary(enableRuntimeLibrary),
74438a7d4cSAart Bik         enableBufferInitialization(enableBufferInitialization) {}
75438a7d4cSAart Bik   // Private pass options and visible pass options.
76438a7d4cSAart Bik   SparsificationAndBufferizationPass(
77438a7d4cSAart Bik       const bufferization::OneShotBufferizationOptions &bufferizationOptions,
78438a7d4cSAart Bik       const SparsificationOptions &sparsificationOptions,
79438a7d4cSAart Bik       bool createSparseDeallocs, bool enableRuntimeLibrary,
80438a7d4cSAart Bik       bool enableBufferInitialization, unsigned vl, bool vla, bool index32,
81*8f0c014bSYinying Li       bool gpu, SparseEmitStrategy emitStrategy,
82*8f0c014bSYinying Li       SparseParallelizationStrategy parallelizationStrategy)
83438a7d4cSAart Bik       : bufferizationOptions(bufferizationOptions),
84438a7d4cSAart Bik         sparsificationOptions(sparsificationOptions),
85438a7d4cSAart Bik         createSparseDeallocs(createSparseDeallocs),
86438a7d4cSAart Bik         enableRuntimeLibrary(enableRuntimeLibrary),
87438a7d4cSAart Bik         enableBufferInitialization(enableBufferInitialization) {
88438a7d4cSAart Bik     // Set the visible pass options explicitly.
89438a7d4cSAart Bik     vectorLength = vl;
90438a7d4cSAart Bik     enableVLAVectorization = vla;
91438a7d4cSAart Bik     enableSIMDIndex32 = index32;
92438a7d4cSAart Bik     enableGPULibgen = gpu;
931ba2768cSPeiming Liu     sparseEmitStrategy = emitStrategy;
94*8f0c014bSYinying Li     parallelization = parallelizationStrategy;
955f32bcfbSAart Bik   }
96c1fef4e8SMatthias Springer 
97c1fef4e8SMatthias Springer   /// Bufferize all dense ops. This assumes that no further analysis is needed
98c1fef4e8SMatthias Springer   /// and that all required buffer copies were already inserted by
99c1fef4e8SMatthias Springer   /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops.
100c1fef4e8SMatthias Springer   LogicalResult runDenseBufferization() {
1019d34c052SMatthias Springer     bufferization::OneShotBufferizationOptions updatedOptions =
1029d34c052SMatthias Springer         bufferizationOptions;
1039d34c052SMatthias Springer     // Skip all sparse ops.
1049d34c052SMatthias Springer     updatedOptions.opFilter.denyOperation([&](Operation *op) {
105c1fef4e8SMatthias Springer       if (containsSparseTensor(TypeRange(op->getResults())) ||
106c1fef4e8SMatthias Springer           containsSparseTensor(TypeRange(op->getOperands())))
1079d34c052SMatthias Springer         return true;
108c1fef4e8SMatthias Springer       if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
109c1fef4e8SMatthias Springer         FunctionType funcType = funcOp.getFunctionType();
110c1fef4e8SMatthias Springer         if (containsSparseTensor(funcType.getInputs()) ||
111c1fef4e8SMatthias Springer             containsSparseTensor(funcType.getResults()))
112c1fef4e8SMatthias Springer           return true;
1139d34c052SMatthias Springer       }
1149d34c052SMatthias Springer       return false;
115c1fef4e8SMatthias Springer     });
116c7a9e5e5SPeiming Liu 
11748a73bc4SMatthias Springer     if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
11848a73bc4SMatthias Springer                                                 updatedOptions)))
119c7a9e5e5SPeiming Liu       return failure();
120c7a9e5e5SPeiming Liu 
121c7a9e5e5SPeiming Liu     bufferization::removeBufferizationAttributesInModule(getOperation());
122c7a9e5e5SPeiming Liu     return success();
123c1fef4e8SMatthias Springer   }
124c1fef4e8SMatthias Springer 
125c1fef4e8SMatthias Springer   void runOnOperation() override {
1261ba2768cSPeiming Liu     // Overrides the default emit strategy using user-provided value.
1271ba2768cSPeiming Liu     this->sparsificationOptions.sparseEmitStrategy = sparseEmitStrategy;
1281ba2768cSPeiming Liu 
129*8f0c014bSYinying Li     // Overrides the default parallelization strategy using user-provided value.
130*8f0c014bSYinying Li     this->sparsificationOptions.parallelizationStrategy = parallelization;
131*8f0c014bSYinying Li 
132c1fef4e8SMatthias Springer     // Run enabling transformations.
133b19c40c5SAart Bik     {
134c1fef4e8SMatthias Springer       OpPassManager pm("builtin.module");
135c1fef4e8SMatthias Springer       pm.addPass(createPreSparsificationRewritePass());
136be630f07SMatthias Springer       pm.addNestedPass<func::FuncOp>(
137be630f07SMatthias Springer           bufferization::createEmptyTensorToAllocTensorPass());
138c1fef4e8SMatthias Springer       if (failed(runPipeline(pm, getOperation())))
139c1fef4e8SMatthias Springer         return signalPassFailure();
140c1fef4e8SMatthias Springer     }
141c1fef4e8SMatthias Springer 
142c1fef4e8SMatthias Springer     // Insert tensor copies. This step runs One-Shot Analysis (which analyzes
143c1fef4e8SMatthias Springer     // SSA use-def chains of tensor IR) and decides where buffer copies are
144c1fef4e8SMatthias Springer     // needed and where buffers can be written to in-place. These decisions are
145c1fef4e8SMatthias Springer     // materialized in the IR in the form of `bufferization.alloc_tensor` ops.
146c1fef4e8SMatthias Springer     //
147c1fef4e8SMatthias Springer     // Note: All following steps in this pass must be careful not to modify the
148c1fef4e8SMatthias Springer     // structure of the IR (i.e., tensor use-def chains), as that could
149c1fef4e8SMatthias Springer     // invalidate the results of the analysis. From now on, only small and
150c1fef4e8SMatthias Springer     // localized rewrites are allowed, such as replacing a tensor op with its
151c1fef4e8SMatthias Springer     // memref equivalent.
152c1fef4e8SMatthias Springer     if (failed(bufferization::insertTensorCopies(getOperation(),
153c1fef4e8SMatthias Springer                                                  bufferizationOptions)))
154c1fef4e8SMatthias Springer       return signalPassFailure();
155c1fef4e8SMatthias Springer 
156b19c40c5SAart Bik     // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
157c1fef4e8SMatthias Springer     // OneShotAnalysis are added to the IR via attributes. In that case, do not
158c1fef4e8SMatthias Springer     // continue with the remaining pipeline.
159c1fef4e8SMatthias Springer     if (bufferizationOptions.testAnalysisOnly)
160c1fef4e8SMatthias Springer       return;
161c1fef4e8SMatthias Springer 
162c1fef4e8SMatthias Springer     // Bufferize all sparse ops. No further analysis is needed. All required
163c1fef4e8SMatthias Springer     // buffer copies were already inserted by `insertTensorCopies` in the form
164c1fef4e8SMatthias Springer     // of `bufferization.alloc_tensor` ops.
165c1fef4e8SMatthias Springer     {
166c1fef4e8SMatthias Springer       OpPassManager pm("builtin.module");
1675f32bcfbSAart Bik       if (enableGPULibgen)
1685f32bcfbSAart Bik         pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary));
169c99951d4SPeiming Liu       pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
170c1fef4e8SMatthias Springer       pm.addPass(createSparsificationPass(sparsificationOptions));
171a02010b3SPeiming Liu       if (sparsificationOptions.sparseEmitStrategy ==
172a02010b3SPeiming Liu           SparseEmitStrategy::kSparseIterator) {
173a02010b3SPeiming Liu         pm.addNestedPass<func::FuncOp>(createSparseSpaceCollapsePass());
174a02010b3SPeiming Liu         pm.addNestedPass<func::FuncOp>(createLowerSparseIterationToSCFPass());
175a02010b3SPeiming Liu       }
176a02010b3SPeiming Liu 
177dda3dc5eSPeiming Liu       pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
178f82bee13SPeiming Liu       pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
179f82bee13SPeiming Liu                                                    /*enableConvert=*/true));
180ef100c22SPeiming Liu       pm.addPass(
181ef100c22SPeiming Liu           createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
182f82bee13SPeiming Liu       pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
18316aa4e4bSAart Bik       pm.addPass(mlir::createLoopInvariantCodeMotionPass());
184b6cad75eSPeiming Liu       if (vectorLength > 0) {
18516aa4e4bSAart Bik         pm.addPass(createSparseVectorizationPass(
18616aa4e4bSAart Bik             vectorLength, enableVLAVectorization, enableSIMDIndex32));
18716aa4e4bSAart Bik       }
188c1fef4e8SMatthias Springer       if (enableRuntimeLibrary) {
189f248d0b2SPeiming Liu         pm.addPass(createSparseTensorConversionPass());
190c1fef4e8SMatthias Springer       } else {
191c44d307cSPeiming Liu         pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs,
192c44d307cSPeiming Liu                                                  enableBufferInitialization));
193c1fef4e8SMatthias Springer         pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
194c1fef4e8SMatthias Springer       }
195c1fef4e8SMatthias Springer       if (failed(runPipeline(pm, getOperation())))
196c1fef4e8SMatthias Springer         return signalPassFailure();
197c1fef4e8SMatthias Springer     }
198c1fef4e8SMatthias Springer 
199c1fef4e8SMatthias Springer     // Bufferize all dense ops.
200c1fef4e8SMatthias Springer     if (failed(runDenseBufferization()))
201c1fef4e8SMatthias Springer       signalPassFailure();
202c1fef4e8SMatthias Springer   }
203c1fef4e8SMatthias Springer 
204c1fef4e8SMatthias Springer private:
205c1fef4e8SMatthias Springer   bufferization::OneShotBufferizationOptions bufferizationOptions;
206c1fef4e8SMatthias Springer   SparsificationOptions sparsificationOptions;
207c44d307cSPeiming Liu   bool createSparseDeallocs;
208c1fef4e8SMatthias Springer   bool enableRuntimeLibrary;
209c1fef4e8SMatthias Springer   bool enableBufferInitialization;
210c1fef4e8SMatthias Springer };
21116aa4e4bSAart Bik 
212c1fef4e8SMatthias Springer } // namespace sparse_tensor
213c1fef4e8SMatthias Springer } // namespace mlir
214c1fef4e8SMatthias Springer 
2158154494eSAart Bik mlir::bufferization::OneShotBufferizationOptions
2168154494eSAart Bik mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
2178154494eSAart Bik   using namespace mlir::bufferization;
2188154494eSAart Bik   OneShotBufferizationOptions options;
2198154494eSAart Bik   options.bufferizeFunctionBoundaries = true;
2208154494eSAart Bik   options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
2218154494eSAart Bik   options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
2228154494eSAart Bik                                       const BufferizationOptions &options) {
2238154494eSAart Bik     return getMemRefTypeWithStaticIdentityLayout(
2248154494eSAart Bik         cast<TensorType>(value.getType()), memorySpace);
2258154494eSAart Bik   };
2268154494eSAart Bik   if (analysisOnly) {
2278154494eSAart Bik     options.testAnalysisOnly = true;
2288154494eSAart Bik     options.printConflicts = true;
2298154494eSAart Bik   }
230aec73eadSAart Bik   // Since this mini-pipeline may be used in alternative pipelines (viz.
231aec73eadSAart Bik   // different from the default "sparsifier" pipeline) where unknown ops
232aec73eadSAart Bik   // are handled by alternative bufferization methods that are downstream
233aec73eadSAart Bik   // of this mini-pipeline, we allow unknown ops by default (failure to
234aec73eadSAart Bik   // bufferize is eventually apparent by failing to convert to LLVM IR).
235aec73eadSAart Bik   options.allowUnknownOps = true;
2368154494eSAart Bik   return options;
2378154494eSAart Bik }
2388154494eSAart Bik 
2398154494eSAart Bik std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass() {
2408154494eSAart Bik   SparsificationOptions sparseOptions;
241438a7d4cSAart Bik   return std::make_unique<
242438a7d4cSAart Bik       mlir::sparse_tensor::SparsificationAndBufferizationPass>(
2438154494eSAart Bik       getBufferizationOptionsForSparsification(/*analysisOnly=*/false),
244f248d0b2SPeiming Liu       sparseOptions,
2458154494eSAart Bik       /*createSparseDeallocs=*/false,
2468154494eSAart Bik       /*enableRuntimeLibrary=*/false,
247438a7d4cSAart Bik       /*enableBufferInitialization=*/false);
2488154494eSAart Bik }
2498154494eSAart Bik 
2508154494eSAart Bik std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
251c1fef4e8SMatthias Springer     const bufferization::OneShotBufferizationOptions &bufferizationOptions,
252c1fef4e8SMatthias Springer     const SparsificationOptions &sparsificationOptions,
253c44d307cSPeiming Liu     bool createSparseDeallocs, bool enableRuntimeLibrary,
254c44d307cSPeiming Liu     bool enableBufferInitialization, unsigned vectorLength,
2551ba2768cSPeiming Liu     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
256*8f0c014bSYinying Li     SparseEmitStrategy emitStrategy,
257*8f0c014bSYinying Li     SparseParallelizationStrategy parallelizationStrategy) {
258c1fef4e8SMatthias Springer   return std::make_unique<
259c1fef4e8SMatthias Springer       mlir::sparse_tensor::SparsificationAndBufferizationPass>(
260f248d0b2SPeiming Liu       bufferizationOptions, sparsificationOptions, createSparseDeallocs,
261f248d0b2SPeiming Liu       enableRuntimeLibrary, enableBufferInitialization, vectorLength,
262*8f0c014bSYinying Li       enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy,
263*8f0c014bSYinying Li       parallelizationStrategy);
264c1fef4e8SMatthias Springer }
265