1 //===- BufferizationTransformOps.h - Bufferization transform ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" 10 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 20 #include "mlir/IR/FunctionInterfaces.h" 21 22 using namespace mlir; 23 using namespace mlir::bufferization; 24 using namespace mlir::transform; 25 26 //===----------------------------------------------------------------------===// 27 // OneShotBufferizeOp 28 //===----------------------------------------------------------------------===// 29 30 DiagnosedSilenceableFailure 31 transform::OneShotBufferizeOp::apply(TransformResults &transformResults, 32 TransformState &state) { 33 OneShotBufferizationOptions options; 34 options.allowReturnAllocs = getAllowReturnAllocs(); 35 options.allowUnknownOps = getAllowUnknownOps(); 36 options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); 37 options.createDeallocs = getCreateDeallocs(); 38 options.testAnalysisOnly = getTestAnalysisOnly(); 39 options.printConflicts = getPrintConflicts(); 40 if (getFunctionBoundaryTypeConversion().has_value()) 41 options.functionBoundaryTypeConversion = 42 *getFunctionBoundaryTypeConversion(); 43 44 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 45 for (Operation *target : payloadOps) { 46 if (!isa<ModuleOp, FunctionOpInterface>(target)) 47 return emitSilenceableError() << "expected module or function target"; 48 auto moduleOp = dyn_cast<ModuleOp>(target); 49 if (options.bufferizeFunctionBoundaries) { 50 if (!moduleOp) 51 return emitSilenceableError() << "expected module target"; 52 if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) 53 return emitSilenceableError() << "bufferization failed"; 54 } else { 55 if (failed(bufferization::runOneShotBufferize(target, options))) 56 return emitSilenceableError() << "bufferization failed"; 57 } 58 } 59 60 // This transform op is currently restricted to ModuleOps and function ops. 61 // Such ops are modified in-place. 62 transformResults.set(getTransformed().cast<OpResult>(), payloadOps); 63 return DiagnosedSilenceableFailure::success(); 64 } 65 66 //===----------------------------------------------------------------------===// 67 // EliminateEmptyTensorsOp 68 //===----------------------------------------------------------------------===// 69 70 void transform::EliminateEmptyTensorsOp::getEffects( 71 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 72 onlyReadsHandle(getTarget(), effects); 73 modifiesPayload(effects); 74 } 75 76 DiagnosedSilenceableFailure 77 transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults, 78 TransformState &state) { 79 IRRewriter rewriter(getContext()); 80 OneShotBufferizationOptions options; 81 options.allowReturnAllocs = true; 82 83 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 84 for (Operation *target : payloadOps) { 85 OneShotAnalysisState state(target, options); 86 if (failed(analyzeOp(target, state))) 87 return mlir::emitSilenceableFailure(target->getLoc()) 88 << "failed to analyze op"; 89 if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( 90 rewriter, target, state))) 91 return mlir::emitSilenceableFailure(target->getLoc()) 92 << "failed to eliminate insert_slice anchored tensor.empty ops"; 93 } 94 return DiagnosedSilenceableFailure::success(); 95 } 96 97 //===----------------------------------------------------------------------===// 98 // EmptyTensorToAllocTensorOp 99 //===----------------------------------------------------------------------===// 100 101 DiagnosedSilenceableFailure 102 EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target, 103 ApplyToEachResultList &results, 104 transform::TransformState &state) { 105 IRRewriter rewriter(target->getContext()); 106 rewriter.setInsertionPoint(target); 107 auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( 108 target, target.getType(), target.getDynamicSizes()); 109 results.push_back(alloc); 110 return DiagnosedSilenceableFailure::success(); 111 } 112 113 //===----------------------------------------------------------------------===// 114 // Transform op registration 115 //===----------------------------------------------------------------------===// 116 117 namespace { 118 /// Registers new ops and declares PDL as dependent dialect since the additional 119 /// ops are using PDL types for operands and results. 120 class BufferizationTransformDialectExtension 121 : public transform::TransformDialectExtension< 122 BufferizationTransformDialectExtension> { 123 public: 124 using Base::Base; 125 126 void init() { 127 declareDependentDialect<pdl::PDLDialect>(); 128 129 declareGeneratedDialect<bufferization::BufferizationDialect>(); 130 declareGeneratedDialect<memref::MemRefDialect>(); 131 132 registerTransformOps< 133 #define GET_OP_LIST 134 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 135 >(); 136 } 137 }; 138 } // namespace 139 140 #define GET_OP_CLASSES 141 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 142 143 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" 144 145 void mlir::bufferization::registerTransformDialectExtension( 146 DialectRegistry ®istry) { 147 registry.addExtensions<BufferizationTransformDialectExtension>(); 148 } 149