1c1fef4e8SMatthias Springer //===- SparsificationAndBufferizationPass.cpp - Tensor to Memref Lowering -===// 2c1fef4e8SMatthias Springer // 3c1fef4e8SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c1fef4e8SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5c1fef4e8SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c1fef4e8SMatthias Springer // 7c1fef4e8SMatthias Springer //===----------------------------------------------------------------------===// 8c1fef4e8SMatthias Springer 9c1fef4e8SMatthias Springer #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 10c1fef4e8SMatthias Springer 118154494eSAart Bik #include "mlir/Dialect/Affine/IR/AffineOps.h" 12c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 13a0568eabSbixia1 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 14c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 15c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 16c7a9e5e5SPeiming Liu #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 17be630f07SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 18c1fef4e8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 19c1fef4e8SMatthias Springer #include "mlir/Dialect/Func/IR/FuncOps.h" 20ee42e236SAart Bik #include "mlir/Dialect/GPU/IR/GPUDialect.h" 21006340baSPeiming Liu #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 228154494eSAart Bik #include "mlir/Dialect/Linalg/IR/Linalg.h" 238154494eSAart Bik #include "mlir/Dialect/MemRef/IR/MemRef.h" 248154494eSAart Bik #include "mlir/Dialect/SCF/IR/SCF.h" 25c1fef4e8SMatthias Springer #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 26c1fef4e8SMatthias Springer #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 279f3334e9SMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h" 28c1fef4e8SMatthias Springer #include "mlir/Pass/PassManager.h" 2916aa4e4bSAart Bik #include "mlir/Transforms/Passes.h" 30c1fef4e8SMatthias Springer 31c1fef4e8SMatthias Springer using namespace mlir; 32c1fef4e8SMatthias Springer 33c1fef4e8SMatthias Springer namespace mlir { 348154494eSAart Bik 358154494eSAart Bik #define GEN_PASS_DEF_SPARSIFICATIONANDBUFFERIZATION 368154494eSAart Bik #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" 378154494eSAart Bik 38c1fef4e8SMatthias Springer namespace sparse_tensor { 39c1fef4e8SMatthias Springer 40c1fef4e8SMatthias Springer /// Return `true` if one of the given types is a sparse tensor type. 41c1fef4e8SMatthias Springer static bool containsSparseTensor(TypeRange types) { 42c1fef4e8SMatthias Springer for (Type t : types) 4348a73bc4SMatthias Springer if (isa<TensorType>(t) && getSparseTensorEncoding(t)) 44c1fef4e8SMatthias Springer return true; 45c1fef4e8SMatthias Springer return false; 46c1fef4e8SMatthias Springer } 47c1fef4e8SMatthias Springer 48c1fef4e8SMatthias Springer /// A pass that lowers tensor ops to memref ops, regardless of whether they are 49c1fef4e8SMatthias Springer /// dense or sparse. 50c1fef4e8SMatthias Springer /// 51c1fef4e8SMatthias Springer /// One-Shot Analysis is used to detect RaW conflicts and to insert buffer 52c1fef4e8SMatthias Springer /// copies of the tensor level (`insertTensorCopies`). Afterwards, the lowering 53c1fef4e8SMatthias Springer /// of tensor ops to memref ops follows a different code path depending on 54c1fef4e8SMatthias Springer /// whether the op is sparse or dense: 55c1fef4e8SMatthias Springer /// 56c1fef4e8SMatthias Springer /// * Sparse tensor ops are lowered through Sparsification and follow-up pass 57c1fef4e8SMatthias Springer /// that lowers sparse_tensor dialect ops. 58c1fef4e8SMatthias Springer /// * Dense tensor ops are lowered through BufferizableOpInterface 59c1fef4e8SMatthias Springer /// implementations. 60c1fef4e8SMatthias Springer class SparsificationAndBufferizationPass 618154494eSAart Bik : public impl::SparsificationAndBufferizationBase< 628154494eSAart Bik SparsificationAndBufferizationPass> { 63c1fef4e8SMatthias Springer public: 64438a7d4cSAart Bik // Private pass options only. 65c1fef4e8SMatthias Springer SparsificationAndBufferizationPass( 66c1fef4e8SMatthias Springer const bufferization::OneShotBufferizationOptions &bufferizationOptions, 67c1fef4e8SMatthias Springer const SparsificationOptions &sparsificationOptions, 68c44d307cSPeiming Liu bool createSparseDeallocs, bool enableRuntimeLibrary, 69438a7d4cSAart Bik bool enableBufferInitialization) 70c1fef4e8SMatthias Springer : bufferizationOptions(bufferizationOptions), 71c1fef4e8SMatthias Springer sparsificationOptions(sparsificationOptions), 72c44d307cSPeiming Liu createSparseDeallocs(createSparseDeallocs), 73c1fef4e8SMatthias Springer enableRuntimeLibrary(enableRuntimeLibrary), 74438a7d4cSAart Bik enableBufferInitialization(enableBufferInitialization) {} 75438a7d4cSAart Bik // Private pass options and visible pass options. 76438a7d4cSAart Bik SparsificationAndBufferizationPass( 77438a7d4cSAart Bik const bufferization::OneShotBufferizationOptions &bufferizationOptions, 78438a7d4cSAart Bik const SparsificationOptions &sparsificationOptions, 79438a7d4cSAart Bik bool createSparseDeallocs, bool enableRuntimeLibrary, 80438a7d4cSAart Bik bool enableBufferInitialization, unsigned vl, bool vla, bool index32, 81*8f0c014bSYinying Li bool gpu, SparseEmitStrategy emitStrategy, 82*8f0c014bSYinying Li SparseParallelizationStrategy parallelizationStrategy) 83438a7d4cSAart Bik : bufferizationOptions(bufferizationOptions), 84438a7d4cSAart Bik sparsificationOptions(sparsificationOptions), 85438a7d4cSAart Bik createSparseDeallocs(createSparseDeallocs), 86438a7d4cSAart Bik enableRuntimeLibrary(enableRuntimeLibrary), 87438a7d4cSAart Bik enableBufferInitialization(enableBufferInitialization) { 88438a7d4cSAart Bik // Set the visible pass options explicitly. 89438a7d4cSAart Bik vectorLength = vl; 90438a7d4cSAart Bik enableVLAVectorization = vla; 91438a7d4cSAart Bik enableSIMDIndex32 = index32; 92438a7d4cSAart Bik enableGPULibgen = gpu; 931ba2768cSPeiming Liu sparseEmitStrategy = emitStrategy; 94*8f0c014bSYinying Li parallelization = parallelizationStrategy; 955f32bcfbSAart Bik } 96c1fef4e8SMatthias Springer 97c1fef4e8SMatthias Springer /// Bufferize all dense ops. This assumes that no further analysis is needed 98c1fef4e8SMatthias Springer /// and that all required buffer copies were already inserted by 99c1fef4e8SMatthias Springer /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops. 100c1fef4e8SMatthias Springer LogicalResult runDenseBufferization() { 1019d34c052SMatthias Springer bufferization::OneShotBufferizationOptions updatedOptions = 1029d34c052SMatthias Springer bufferizationOptions; 1039d34c052SMatthias Springer // Skip all sparse ops. 1049d34c052SMatthias Springer updatedOptions.opFilter.denyOperation([&](Operation *op) { 105c1fef4e8SMatthias Springer if (containsSparseTensor(TypeRange(op->getResults())) || 106c1fef4e8SMatthias Springer containsSparseTensor(TypeRange(op->getOperands()))) 1079d34c052SMatthias Springer return true; 108c1fef4e8SMatthias Springer if (auto funcOp = dyn_cast<func::FuncOp>(op)) { 109c1fef4e8SMatthias Springer FunctionType funcType = funcOp.getFunctionType(); 110c1fef4e8SMatthias Springer if (containsSparseTensor(funcType.getInputs()) || 111c1fef4e8SMatthias Springer containsSparseTensor(funcType.getResults())) 112c1fef4e8SMatthias Springer return true; 1139d34c052SMatthias Springer } 1149d34c052SMatthias Springer return false; 115c1fef4e8SMatthias Springer }); 116c7a9e5e5SPeiming Liu 11748a73bc4SMatthias Springer if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()), 11848a73bc4SMatthias Springer updatedOptions))) 119c7a9e5e5SPeiming Liu return failure(); 120c7a9e5e5SPeiming Liu 121c7a9e5e5SPeiming Liu bufferization::removeBufferizationAttributesInModule(getOperation()); 122c7a9e5e5SPeiming Liu return success(); 123c1fef4e8SMatthias Springer } 124c1fef4e8SMatthias Springer 125c1fef4e8SMatthias Springer void runOnOperation() override { 1261ba2768cSPeiming Liu // Overrides the default emit strategy using user-provided value. 1271ba2768cSPeiming Liu this->sparsificationOptions.sparseEmitStrategy = sparseEmitStrategy; 1281ba2768cSPeiming Liu 129*8f0c014bSYinying Li // Overrides the default parallelization strategy using user-provided value. 130*8f0c014bSYinying Li this->sparsificationOptions.parallelizationStrategy = parallelization; 131*8f0c014bSYinying Li 132c1fef4e8SMatthias Springer // Run enabling transformations. 133b19c40c5SAart Bik { 134c1fef4e8SMatthias Springer OpPassManager pm("builtin.module"); 135c1fef4e8SMatthias Springer pm.addPass(createPreSparsificationRewritePass()); 136be630f07SMatthias Springer pm.addNestedPass<func::FuncOp>( 137be630f07SMatthias Springer bufferization::createEmptyTensorToAllocTensorPass()); 138c1fef4e8SMatthias Springer if (failed(runPipeline(pm, getOperation()))) 139c1fef4e8SMatthias Springer return signalPassFailure(); 140c1fef4e8SMatthias Springer } 141c1fef4e8SMatthias Springer 142c1fef4e8SMatthias Springer // Insert tensor copies. This step runs One-Shot Analysis (which analyzes 143c1fef4e8SMatthias Springer // SSA use-def chains of tensor IR) and decides where buffer copies are 144c1fef4e8SMatthias Springer // needed and where buffers can be written to in-place. These decisions are 145c1fef4e8SMatthias Springer // materialized in the IR in the form of `bufferization.alloc_tensor` ops. 146c1fef4e8SMatthias Springer // 147c1fef4e8SMatthias Springer // Note: All following steps in this pass must be careful not to modify the 148c1fef4e8SMatthias Springer // structure of the IR (i.e., tensor use-def chains), as that could 149c1fef4e8SMatthias Springer // invalidate the results of the analysis. From now on, only small and 150c1fef4e8SMatthias Springer // localized rewrites are allowed, such as replacing a tensor op with its 151c1fef4e8SMatthias Springer // memref equivalent. 152c1fef4e8SMatthias Springer if (failed(bufferization::insertTensorCopies(getOperation(), 153c1fef4e8SMatthias Springer bufferizationOptions))) 154c1fef4e8SMatthias Springer return signalPassFailure(); 155c1fef4e8SMatthias Springer 156b19c40c5SAart Bik // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of 157c1fef4e8SMatthias Springer // OneShotAnalysis are added to the IR via attributes. In that case, do not 158c1fef4e8SMatthias Springer // continue with the remaining pipeline. 159c1fef4e8SMatthias Springer if (bufferizationOptions.testAnalysisOnly) 160c1fef4e8SMatthias Springer return; 161c1fef4e8SMatthias Springer 162c1fef4e8SMatthias Springer // Bufferize all sparse ops. No further analysis is needed. All required 163c1fef4e8SMatthias Springer // buffer copies were already inserted by `insertTensorCopies` in the form 164c1fef4e8SMatthias Springer // of `bufferization.alloc_tensor` ops. 165c1fef4e8SMatthias Springer { 166c1fef4e8SMatthias Springer OpPassManager pm("builtin.module"); 1675f32bcfbSAart Bik if (enableGPULibgen) 1685f32bcfbSAart Bik pm.addPass(createSparseGPUCodegenPass(0, enableRuntimeLibrary)); 169c99951d4SPeiming Liu pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll)); 170c1fef4e8SMatthias Springer pm.addPass(createSparsificationPass(sparsificationOptions)); 171a02010b3SPeiming Liu if (sparsificationOptions.sparseEmitStrategy == 172a02010b3SPeiming Liu SparseEmitStrategy::kSparseIterator) { 173a02010b3SPeiming Liu pm.addNestedPass<func::FuncOp>(createSparseSpaceCollapsePass()); 174a02010b3SPeiming Liu pm.addNestedPass<func::FuncOp>(createLowerSparseIterationToSCFPass()); 175a02010b3SPeiming Liu } 176a02010b3SPeiming Liu 177dda3dc5eSPeiming Liu pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass()); 178f82bee13SPeiming Liu pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary, 179f82bee13SPeiming Liu /*enableConvert=*/true)); 180ef100c22SPeiming Liu pm.addPass( 181ef100c22SPeiming Liu createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric)); 182f82bee13SPeiming Liu pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass()); 18316aa4e4bSAart Bik pm.addPass(mlir::createLoopInvariantCodeMotionPass()); 184b6cad75eSPeiming Liu if (vectorLength > 0) { 18516aa4e4bSAart Bik pm.addPass(createSparseVectorizationPass( 18616aa4e4bSAart Bik vectorLength, enableVLAVectorization, enableSIMDIndex32)); 18716aa4e4bSAart Bik } 188c1fef4e8SMatthias Springer if (enableRuntimeLibrary) { 189f248d0b2SPeiming Liu pm.addPass(createSparseTensorConversionPass()); 190c1fef4e8SMatthias Springer } else { 191c44d307cSPeiming Liu pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs, 192c44d307cSPeiming Liu enableBufferInitialization)); 193c1fef4e8SMatthias Springer pm.addPass(createSparseBufferRewritePass(enableBufferInitialization)); 194c1fef4e8SMatthias Springer } 195c1fef4e8SMatthias Springer if (failed(runPipeline(pm, getOperation()))) 196c1fef4e8SMatthias Springer return signalPassFailure(); 197c1fef4e8SMatthias Springer } 198c1fef4e8SMatthias Springer 199c1fef4e8SMatthias Springer // Bufferize all dense ops. 200c1fef4e8SMatthias Springer if (failed(runDenseBufferization())) 201c1fef4e8SMatthias Springer signalPassFailure(); 202c1fef4e8SMatthias Springer } 203c1fef4e8SMatthias Springer 204c1fef4e8SMatthias Springer private: 205c1fef4e8SMatthias Springer bufferization::OneShotBufferizationOptions bufferizationOptions; 206c1fef4e8SMatthias Springer SparsificationOptions sparsificationOptions; 207c44d307cSPeiming Liu bool createSparseDeallocs; 208c1fef4e8SMatthias Springer bool enableRuntimeLibrary; 209c1fef4e8SMatthias Springer bool enableBufferInitialization; 210c1fef4e8SMatthias Springer }; 21116aa4e4bSAart Bik 212c1fef4e8SMatthias Springer } // namespace sparse_tensor 213c1fef4e8SMatthias Springer } // namespace mlir 214c1fef4e8SMatthias Springer 2158154494eSAart Bik mlir::bufferization::OneShotBufferizationOptions 2168154494eSAart Bik mlir::getBufferizationOptionsForSparsification(bool analysisOnly) { 2178154494eSAart Bik using namespace mlir::bufferization; 2188154494eSAart Bik OneShotBufferizationOptions options; 2198154494eSAart Bik options.bufferizeFunctionBoundaries = true; 2208154494eSAart Bik options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap); 2218154494eSAart Bik options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, 2228154494eSAart Bik const BufferizationOptions &options) { 2238154494eSAart Bik return getMemRefTypeWithStaticIdentityLayout( 2248154494eSAart Bik cast<TensorType>(value.getType()), memorySpace); 2258154494eSAart Bik }; 2268154494eSAart Bik if (analysisOnly) { 2278154494eSAart Bik options.testAnalysisOnly = true; 2288154494eSAart Bik options.printConflicts = true; 2298154494eSAart Bik } 230aec73eadSAart Bik // Since this mini-pipeline may be used in alternative pipelines (viz. 231aec73eadSAart Bik // different from the default "sparsifier" pipeline) where unknown ops 232aec73eadSAart Bik // are handled by alternative bufferization methods that are downstream 233aec73eadSAart Bik // of this mini-pipeline, we allow unknown ops by default (failure to 234aec73eadSAart Bik // bufferize is eventually apparent by failing to convert to LLVM IR). 235aec73eadSAart Bik options.allowUnknownOps = true; 2368154494eSAart Bik return options; 2378154494eSAart Bik } 2388154494eSAart Bik 2398154494eSAart Bik std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass() { 2408154494eSAart Bik SparsificationOptions sparseOptions; 241438a7d4cSAart Bik return std::make_unique< 242438a7d4cSAart Bik mlir::sparse_tensor::SparsificationAndBufferizationPass>( 2438154494eSAart Bik getBufferizationOptionsForSparsification(/*analysisOnly=*/false), 244f248d0b2SPeiming Liu sparseOptions, 2458154494eSAart Bik /*createSparseDeallocs=*/false, 2468154494eSAart Bik /*enableRuntimeLibrary=*/false, 247438a7d4cSAart Bik /*enableBufferInitialization=*/false); 2488154494eSAart Bik } 2498154494eSAart Bik 2508154494eSAart Bik std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass( 251c1fef4e8SMatthias Springer const bufferization::OneShotBufferizationOptions &bufferizationOptions, 252c1fef4e8SMatthias Springer const SparsificationOptions &sparsificationOptions, 253c44d307cSPeiming Liu bool createSparseDeallocs, bool enableRuntimeLibrary, 254c44d307cSPeiming Liu bool enableBufferInitialization, unsigned vectorLength, 2551ba2768cSPeiming Liu bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen, 256*8f0c014bSYinying Li SparseEmitStrategy emitStrategy, 257*8f0c014bSYinying Li SparseParallelizationStrategy parallelizationStrategy) { 258c1fef4e8SMatthias Springer return std::make_unique< 259c1fef4e8SMatthias Springer mlir::sparse_tensor::SparsificationAndBufferizationPass>( 260f248d0b2SPeiming Liu bufferizationOptions, sparsificationOptions, createSparseDeallocs, 261f248d0b2SPeiming Liu enableRuntimeLibrary, enableBufferInitialization, vectorLength, 262*8f0c014bSYinying Li enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy, 263*8f0c014bSYinying Li parallelizationStrategy); 264c1fef4e8SMatthias Springer } 265