xref: /llvm-project/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (revision 8f0c014b12663129d8bfe0cc89f06e7a1d8b48c2)
1 //===- SparsificationAndBufferizationPass.cpp - Tensor to Memref Lowering -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
16 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
17 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
18 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/Dialect/Linalg/IR/Linalg.h"
23 #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
27 #include "mlir/Dialect/Vector/IR/VectorOps.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Transforms/Passes.h"
30 
31 using namespace mlir;
32 
33 namespace mlir {
34 
35 #define GEN_PASS_DEF_SPARSIFICATIONANDBUFFERIZATION
36 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
37 
38 namespace sparse_tensor {
39 
40 /// Return `true` if one of the given types is a sparse tensor type.
41 static bool containsSparseTensor(TypeRange types) {
42   for (Type t : types)
43     if (isa<TensorType>(t) && getSparseTensorEncoding(t))
44       return true;
45   return false;
46 }
47 
48 /// A pass that lowers tensor ops to memref ops, regardless of whether they are
49 /// dense or sparse.
50 ///
51 /// One-Shot Analysis is used to detect RaW conflicts and to insert buffer
52 /// copies of the tensor level (`insertTensorCopies`). Afterwards, the lowering
53 /// of tensor ops to memref ops follows a different code path depending on
54 /// whether the op is sparse or dense:
55 ///
56 /// * Sparse tensor ops are lowered through Sparsification and follow-up pass
57 ///   that lowers sparse_tensor dialect ops.
58 /// * Dense tensor ops are lowered through BufferizableOpInterface
59 ///   implementations.
class SparsificationAndBufferizationPass
    : public impl::SparsificationAndBufferizationBase<
          SparsificationAndBufferizationPass> {
public:
  /// Constructor for clients that only set the private (non-command-line)
  /// pass options. The visible (tablegen-declared) options keep their
  /// default values.
  SparsificationAndBufferizationPass(
      const bufferization::OneShotBufferizationOptions &bufferizationOptions,
      const SparsificationOptions &sparsificationOptions,
      bool createSparseDeallocs, bool enableRuntimeLibrary,
      bool enableBufferInitialization)
      : bufferizationOptions(bufferizationOptions),
        sparsificationOptions(sparsificationOptions),
        createSparseDeallocs(createSparseDeallocs),
        enableRuntimeLibrary(enableRuntimeLibrary),
        enableBufferInitialization(enableBufferInitialization) {}
  /// Constructor for clients that set both the private pass options and the
  /// visible (tablegen-declared) pass options explicitly.
  SparsificationAndBufferizationPass(
      const bufferization::OneShotBufferizationOptions &bufferizationOptions,
      const SparsificationOptions &sparsificationOptions,
      bool createSparseDeallocs, bool enableRuntimeLibrary,
      bool enableBufferInitialization, unsigned vl, bool vla, bool index32,
      bool gpu, SparseEmitStrategy emitStrategy,
      SparseParallelizationStrategy parallelizationStrategy)
      : bufferizationOptions(bufferizationOptions),
        sparsificationOptions(sparsificationOptions),
        createSparseDeallocs(createSparseDeallocs),
        enableRuntimeLibrary(enableRuntimeLibrary),
        enableBufferInitialization(enableBufferInitialization) {
    // Set the visible pass options explicitly. These are the tablegen-declared
    // members inherited from SparsificationAndBufferizationBase.
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = index32;
    enableGPULibgen = gpu;
    sparseEmitStrategy = emitStrategy;
    parallelization = parallelizationStrategy;
  }

  /// Bufferize all dense ops. This assumes that no further analysis is needed
  /// and that all required buffer copies were already inserted by
  /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops.
  LogicalResult runDenseBufferization() {
    bufferization::OneShotBufferizationOptions updatedOptions =
        bufferizationOptions;
    // Skip all sparse ops: any op that produces or consumes a sparse tensor,
    // and any function whose signature mentions a sparse tensor, has already
    // been lowered by the sparse pipeline and must not be touched here.
    updatedOptions.opFilter.denyOperation([&](Operation *op) {
      if (containsSparseTensor(TypeRange(op->getResults())) ||
          containsSparseTensor(TypeRange(op->getOperands())))
        return true;
      if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
        FunctionType funcType = funcOp.getFunctionType();
        if (containsSparseTensor(funcType.getInputs()) ||
            containsSparseTensor(funcType.getResults()))
          return true;
      }
      return false;
    });

    if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
                                                updatedOptions)))
      return failure();

    // Strip the bufferization attributes that were used to steer the
    // bufferization; they have no meaning after this point.
    bufferization::removeBufferizationAttributesInModule(getOperation());
    return success();
  }

  void runOnOperation() override {
    // Overrides the default emit strategy using user-provided value.
    this->sparsificationOptions.sparseEmitStrategy = sparseEmitStrategy;

    // Overrides the default parallelization strategy using user-provided value.
    this->sparsificationOptions.parallelizationStrategy = parallelization;

    // Run enabling transformations.
    {
      OpPassManager pm("builtin.module");
      pm.addPass(createPreSparsificationRewritePass());
      pm.addNestedPass<func::FuncOp>(
          bufferization::createEmptyTensorToAllocTensorPass());
      if (failed(runPipeline(pm, getOperation())))
        return signalPassFailure();
    }

    // Insert tensor copies. This step runs One-Shot Analysis (which analyzes
    // SSA use-def chains of tensor IR) and decides where buffer copies are
    // needed and where buffers can be written to in-place. These decisions are
    // materialized in the IR in the form of `bufferization.alloc_tensor` ops.
    //
    // Note: All following steps in this pass must be careful not to modify the
    // structure of the IR (i.e., tensor use-def chains), as that could
    // invalidate the results of the analysis. From now on, only small and
    // localized rewrites are allowed, such as replacing a tensor op with its
    // memref equivalent.
    if (failed(bufferization::insertTensorCopies(getOperation(),
                                                 bufferizationOptions)))
      return signalPassFailure();

    // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
    // OneShotAnalysis are added to the IR via attributes. In that case, do not
    // continue with the remaining pipeline.
    if (bufferizationOptions.testAnalysisOnly)
      return;

    // Bufferize all sparse ops. No further analysis is needed. All required
    // buffer copies were already inserted by `insertTensorCopies` in the form
    // of `bufferization.alloc_tensor` ops.
    {
      OpPassManager pm("builtin.module");
      // GPU libgen (when enabled) must run before sparsification so that it
      // can still recognize the high-level sparse kernels.
      if (enableGPULibgen)
        pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary));
      pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
      pm.addPass(createSparsificationPass(sparsificationOptions));
      // With the sparse-iterator emit strategy, additionally collapse iteration
      // spaces and lower `sparse_tensor.iterate` ops down to SCF loops.
      if (sparsificationOptions.sparseEmitStrategy ==
          SparseEmitStrategy::kSparseIterator) {
        pm.addNestedPass<func::FuncOp>(createSparseSpaceCollapsePass());
        pm.addNestedPass<func::FuncOp>(createLowerSparseIterationToSCFPass());
      }

      pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
      pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
                                                   /*enableConvert=*/true));
      pm.addPass(
          createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
      pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
      pm.addPass(mlir::createLoopInvariantCodeMotionPass());
      // Optional SIMD vectorization of the generated loops.
      if (vectorLength > 0) {
        pm.addPass(createSparseVectorizationPass(
            vectorLength, enableVLAVectorization, enableSIMDIndex32));
      }
      // Lower remaining sparse_tensor ops either to runtime library calls or
      // to actual buffer code, depending on the chosen strategy.
      if (enableRuntimeLibrary) {
        pm.addPass(createSparseTensorConversionPass());
      } else {
        pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs,
                                                 enableBufferInitialization));
        pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
      }
      if (failed(runPipeline(pm, getOperation())))
        return signalPassFailure();
    }

    // Bufferize all dense ops.
    if (failed(runDenseBufferization()))
      signalPassFailure();
  }

private:
  /// Options steering the One-Shot analysis/bufferization of dense ops.
  bufferization::OneShotBufferizationOptions bufferizationOptions;
  /// Options forwarded to the Sparsification pass.
  SparsificationOptions sparsificationOptions;
  /// Whether sparse tensor codegen emits deallocations.
  bool createSparseDeallocs;
  /// Whether to lower sparse ops to runtime library calls instead of codegen.
  bool enableRuntimeLibrary;
  /// Whether generated buffers are zero-initialized.
  bool enableBufferInitialization;
};
211 
212 } // namespace sparse_tensor
213 } // namespace mlir
214 
215 mlir::bufferization::OneShotBufferizationOptions
216 mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
217   using namespace mlir::bufferization;
218   OneShotBufferizationOptions options;
219   options.bufferizeFunctionBoundaries = true;
220   options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
221   options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
222                                       const BufferizationOptions &options) {
223     return getMemRefTypeWithStaticIdentityLayout(
224         cast<TensorType>(value.getType()), memorySpace);
225   };
226   if (analysisOnly) {
227     options.testAnalysisOnly = true;
228     options.printConflicts = true;
229   }
230   // Since this mini-pipeline may be used in alternative pipelines (viz.
231   // different from the default "sparsifier" pipeline) where unknown ops
232   // are handled by alternative bufferization methods that are downstream
233   // of this mini-pipeline, we allow unknown ops by default (failure to
234   // bufferize is eventually apparent by failing to convert to LLVM IR).
235   options.allowUnknownOps = true;
236   return options;
237 }
238 
239 std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass() {
240   SparsificationOptions sparseOptions;
241   return std::make_unique<
242       mlir::sparse_tensor::SparsificationAndBufferizationPass>(
243       getBufferizationOptionsForSparsification(/*analysisOnly=*/false),
244       sparseOptions,
245       /*createSparseDeallocs=*/false,
246       /*enableRuntimeLibrary=*/false,
247       /*enableBufferInitialization=*/false);
248 }
249 
250 std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
251     const bufferization::OneShotBufferizationOptions &bufferizationOptions,
252     const SparsificationOptions &sparsificationOptions,
253     bool createSparseDeallocs, bool enableRuntimeLibrary,
254     bool enableBufferInitialization, unsigned vectorLength,
255     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
256     SparseEmitStrategy emitStrategy,
257     SparseParallelizationStrategy parallelizationStrategy) {
258   return std::make_unique<
259       mlir::sparse_tensor::SparsificationAndBufferizationPass>(
260       bufferizationOptions, sparsificationOptions, createSparseDeallocs,
261       enableRuntimeLibrary, enableBufferInitialization, vectorLength,
262       enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy,
263       parallelizationStrategy);
264 }
265