1 //===- BufferizationTransformOps.h - Bufferization transform ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" 10 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/Tensor/IR/Tensor.h" 18 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 19 #include "mlir/IR/FunctionInterfaces.h" 20 21 using namespace mlir; 22 using namespace mlir::bufferization; 23 using namespace mlir::transform; 24 25 //===----------------------------------------------------------------------===// 26 // BufferLoopHoistingOp 27 //===----------------------------------------------------------------------===// 28 29 DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne( 30 TransformRewriter &rewriter, Operation *target, 31 ApplyToEachResultList &results, TransformState &state) { 32 bufferization::hoistBuffersFromLoops(target); 33 return DiagnosedSilenceableFailure::success(); 34 } 35 36 void transform::BufferLoopHoistingOp::getEffects( 37 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 38 onlyReadsHandle(getTarget(), effects); 39 modifiesPayload(effects); 40 } 41 42 //===----------------------------------------------------------------------===// 43 // OneShotBufferizeOp 44 //===----------------------------------------------------------------------===// 45 46 LogicalResult transform::OneShotBufferizeOp::verify() { 47 if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy") 48 return emitOpError() << "unsupported memcpy op"; 49 return success(); 50 } 51 52 DiagnosedSilenceableFailure 53 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, 54 TransformResults &transformResults, 55 TransformState &state) { 56 OneShotBufferizationOptions options; 57 options.allowReturnAllocs = getAllowReturnAllocs(); 58 options.allowUnknownOps = getAllowUnknownOps(); 59 options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); 60 options.createDeallocs = getCreateDeallocs(); 61 options.testAnalysisOnly = getTestAnalysisOnly(); 62 options.printConflicts = getPrintConflicts(); 63 if (getFunctionBoundaryTypeConversion().has_value()) 64 options.setFunctionBoundaryTypeConversion( 65 *getFunctionBoundaryTypeConversion()); 66 if (getMemcpyOp() == "memref.copy") { 67 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { 68 b.create<memref::CopyOp>(loc, from, to); 69 return success(); 70 }; 71 } else if (getMemcpyOp() == "linalg.copy") { 72 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { 73 b.create<linalg::CopyOp>(loc, from, to); 74 return success(); 75 }; 76 } else { 77 llvm_unreachable("invalid copy op"); 78 } 79 80 auto payloadOps = state.getPayloadOps(getTarget()); 81 for (Operation *target : payloadOps) { 82 if (!isa<ModuleOp, FunctionOpInterface>(target)) 83 return emitSilenceableError() << "expected module or function target"; 84 auto moduleOp = dyn_cast<ModuleOp>(target); 85 if (options.bufferizeFunctionBoundaries) { 86 if (!moduleOp) 87 return emitSilenceableError() << "expected module target"; 88 if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) 89 return emitSilenceableError() << "bufferization failed"; 90 } else { 91 if (failed(bufferization::runOneShotBufferize(target, options))) 92 return emitSilenceableError() << "bufferization failed"; 93 } 94 } 95 96 // This transform op is currently restricted to ModuleOps and function ops. 97 // Such ops are modified in-place. 98 transformResults.set(cast<OpResult>(getTransformed()), payloadOps); 99 return DiagnosedSilenceableFailure::success(); 100 } 101 102 //===----------------------------------------------------------------------===// 103 // EliminateEmptyTensorsOp 104 //===----------------------------------------------------------------------===// 105 106 void transform::EliminateEmptyTensorsOp::getEffects( 107 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 108 onlyReadsHandle(getTarget(), effects); 109 modifiesPayload(effects); 110 } 111 112 DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( 113 transform::TransformRewriter &rewriter, TransformResults &transformResults, 114 TransformState &state) { 115 OneShotBufferizationOptions options; 116 options.allowReturnAllocs = true; 117 118 for (Operation *target : state.getPayloadOps(getTarget())) { 119 OneShotAnalysisState state(target, options); 120 if (failed(analyzeOp(target, state))) 121 return mlir::emitSilenceableFailure(target->getLoc()) 122 << "failed to analyze op"; 123 if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( 124 rewriter, target, state))) 125 return mlir::emitSilenceableFailure(target->getLoc()) 126 << "failed to eliminate insert_slice anchored tensor.empty ops"; 127 } 128 return DiagnosedSilenceableFailure::success(); 129 } 130 131 //===----------------------------------------------------------------------===// 132 // EmptyTensorToAllocTensorOp 133 //===----------------------------------------------------------------------===// 134 135 DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( 136 transform::TransformRewriter &rewriter, tensor::EmptyOp target, 137 ApplyToEachResultList &results, transform::TransformState &state) { 138 rewriter.setInsertionPoint(target); 139 auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( 140 target, target.getType(), target.getDynamicSizes()); 141 results.push_back(alloc); 142 return DiagnosedSilenceableFailure::success(); 143 } 144 145 //===----------------------------------------------------------------------===// 146 // Transform op registration 147 //===----------------------------------------------------------------------===// 148 149 namespace { 150 /// Registers new ops and declares PDL as dependent dialect since the additional 151 /// ops are using PDL types for operands and results. 152 class BufferizationTransformDialectExtension 153 : public transform::TransformDialectExtension< 154 BufferizationTransformDialectExtension> { 155 public: 156 using Base::Base; 157 158 void init() { 159 declareGeneratedDialect<bufferization::BufferizationDialect>(); 160 declareGeneratedDialect<memref::MemRefDialect>(); 161 162 registerTransformOps< 163 #define GET_OP_LIST 164 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 165 >(); 166 } 167 }; 168 } // namespace 169 170 #define GET_OP_CLASSES 171 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 172 173 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" 174 175 void mlir::bufferization::registerTransformDialectExtension( 176 DialectRegistry ®istry) { 177 registry.addExtensions<BufferizationTransformDialectExtension>(); 178 } 179