1 //===- BufferizationTransformOps.h - Bufferization transform ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" 10 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/Dialect/Tensor/IR/Tensor.h" 18 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 19 #include "mlir/IR/FunctionInterfaces.h" 20 21 using namespace mlir; 22 using namespace mlir::bufferization; 23 using namespace mlir::transform; 24 25 //===----------------------------------------------------------------------===// 26 // OneShotBufferizeOp 27 //===----------------------------------------------------------------------===// 28 29 LogicalResult transform::OneShotBufferizeOp::verify() { 30 if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy") 31 return emitOpError() << "unsupported memcpy op"; 32 return success(); 33 } 34 35 DiagnosedSilenceableFailure 36 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, 37 TransformResults &transformResults, 38 TransformState &state) { 39 OneShotBufferizationOptions options; 40 options.allowReturnAllocs = getAllowReturnAllocs(); 41 options.allowUnknownOps = getAllowUnknownOps(); 42 options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); 43 options.createDeallocs = getCreateDeallocs(); 44 options.testAnalysisOnly = getTestAnalysisOnly(); 45 options.printConflicts = getPrintConflicts(); 46 if (getFunctionBoundaryTypeConversion().has_value()) 47 options.setFunctionBoundaryTypeConversion( 48 *getFunctionBoundaryTypeConversion()); 49 if (getMemcpyOp() == "memref.copy") { 50 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { 51 b.create<memref::CopyOp>(loc, from, to); 52 return success(); 53 }; 54 } else if (getMemcpyOp() == "linalg.copy") { 55 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { 56 b.create<linalg::CopyOp>(loc, from, to); 57 return success(); 58 }; 59 } else { 60 llvm_unreachable("invalid copy op"); 61 } 62 63 auto payloadOps = state.getPayloadOps(getTarget()); 64 for (Operation *target : payloadOps) { 65 if (!isa<ModuleOp, FunctionOpInterface>(target)) 66 return emitSilenceableError() << "expected module or function target"; 67 auto moduleOp = dyn_cast<ModuleOp>(target); 68 if (options.bufferizeFunctionBoundaries) { 69 if (!moduleOp) 70 return emitSilenceableError() << "expected module target"; 71 if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) 72 return emitSilenceableError() << "bufferization failed"; 73 } else { 74 if (failed(bufferization::runOneShotBufferize(target, options))) 75 return emitSilenceableError() << "bufferization failed"; 76 } 77 } 78 79 // This transform op is currently restricted to ModuleOps and function ops. 80 // Such ops are modified in-place. 81 transformResults.set(cast<OpResult>(getTransformed()), payloadOps); 82 return DiagnosedSilenceableFailure::success(); 83 } 84 85 //===----------------------------------------------------------------------===// 86 // EliminateEmptyTensorsOp 87 //===----------------------------------------------------------------------===// 88 89 void transform::EliminateEmptyTensorsOp::getEffects( 90 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 91 onlyReadsHandle(getTarget(), effects); 92 modifiesPayload(effects); 93 } 94 95 DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( 96 transform::TransformRewriter &rewriter, TransformResults &transformResults, 97 TransformState &state) { 98 OneShotBufferizationOptions options; 99 options.allowReturnAllocs = true; 100 101 for (Operation *target : state.getPayloadOps(getTarget())) { 102 OneShotAnalysisState state(target, options); 103 if (failed(analyzeOp(target, state))) 104 return mlir::emitSilenceableFailure(target->getLoc()) 105 << "failed to analyze op"; 106 if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( 107 rewriter, target, state))) 108 return mlir::emitSilenceableFailure(target->getLoc()) 109 << "failed to eliminate insert_slice anchored tensor.empty ops"; 110 } 111 return DiagnosedSilenceableFailure::success(); 112 } 113 114 //===----------------------------------------------------------------------===// 115 // EmptyTensorToAllocTensorOp 116 //===----------------------------------------------------------------------===// 117 118 DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( 119 transform::TransformRewriter &rewriter, tensor::EmptyOp target, 120 ApplyToEachResultList &results, transform::TransformState &state) { 121 rewriter.setInsertionPoint(target); 122 auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( 123 target, target.getType(), target.getDynamicSizes()); 124 results.push_back(alloc); 125 return DiagnosedSilenceableFailure::success(); 126 } 127 128 //===----------------------------------------------------------------------===// 129 // Transform op registration 130 //===----------------------------------------------------------------------===// 131 132 namespace { 133 /// Registers new ops and declares PDL as dependent dialect since the additional 134 /// ops are using PDL types for operands and results. 135 class BufferizationTransformDialectExtension 136 : public transform::TransformDialectExtension< 137 BufferizationTransformDialectExtension> { 138 public: 139 using Base::Base; 140 141 void init() { 142 declareGeneratedDialect<bufferization::BufferizationDialect>(); 143 declareGeneratedDialect<memref::MemRefDialect>(); 144 145 registerTransformOps< 146 #define GET_OP_LIST 147 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 148 >(); 149 } 150 }; 151 } // namespace 152 153 #define GET_OP_CLASSES 154 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 155 156 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" 157 158 void mlir::bufferization::registerTransformDialectExtension( 159 DialectRegistry ®istry) { 160 registry.addExtensions<BufferizationTransformDialectExtension>(); 161 } 162