1 //===- BufferizationTransformOps.h - Bufferization transform ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" 10 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 15 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 20 #include "mlir/Interfaces/FunctionInterfaces.h" 21 22 using namespace mlir; 23 using namespace mlir::bufferization; 24 using namespace mlir::transform; 25 26 //===----------------------------------------------------------------------===// 27 // BufferLoopHoistingOp 28 //===----------------------------------------------------------------------===// 29 30 DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne( 31 TransformRewriter &rewriter, Operation *target, 32 ApplyToEachResultList &results, TransformState &state) { 33 bufferization::hoistBuffersFromLoops(target); 34 return DiagnosedSilenceableFailure::success(); 35 } 36 37 void transform::BufferLoopHoistingOp::getEffects( 38 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 39 onlyReadsHandle(getTargetMutable(), effects); 40 modifiesPayload(effects); 41 } 42 43 //===----------------------------------------------------------------------===// 44 // OneShotBufferizeOp 45 //===----------------------------------------------------------------------===// 46 47 LogicalResult transform::OneShotBufferizeOp::verify() { 48 if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy") 49 return emitOpError() << "unsupported memcpy op"; 50 if (getPrintConflicts() && !getTestAnalysisOnly()) 51 return emitOpError() << "'print_conflicts' requires 'test_analysis_only'"; 52 if (getDumpAliasSets() && !getTestAnalysisOnly()) 53 return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'"; 54 return success(); 55 } 56 57 DiagnosedSilenceableFailure 58 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, 59 TransformResults &transformResults, 60 TransformState &state) { 61 OneShotBufferizationOptions options; 62 options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops(); 63 options.allowUnknownOps = getAllowUnknownOps(); 64 options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); 65 options.dumpAliasSets = getDumpAliasSets(); 66 options.testAnalysisOnly = getTestAnalysisOnly(); 67 options.printConflicts = getPrintConflicts(); 68 if (getFunctionBoundaryTypeConversion().has_value()) 69 options.setFunctionBoundaryTypeConversion( 70 *getFunctionBoundaryTypeConversion()); 71 if (getMemcpyOp() == "memref.copy") { 72 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { 73 b.create<memref::CopyOp>(loc, from, to); 74 return success(); 75 }; 76 } else if (getMemcpyOp() == "linalg.copy") { 77 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { 78 b.create<linalg::CopyOp>(loc, from, to); 79 return success(); 80 }; 81 } else { 82 llvm_unreachable("invalid copy op"); 83 } 84 85 auto payloadOps = state.getPayloadOps(getTarget()); 86 for (Operation *target : payloadOps) { 87 if (!isa<ModuleOp, FunctionOpInterface>(target)) 88 return emitSilenceableError() << "expected module or function target"; 89 auto moduleOp = dyn_cast<ModuleOp>(target); 90 if (options.bufferizeFunctionBoundaries) { 91 if (!moduleOp) 92 return emitSilenceableError() << "expected module target"; 93 if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) 94 return emitSilenceableError() << "bufferization failed"; 95 } else { 96 if (failed(bufferization::runOneShotBufferize(target, options))) 97 return emitSilenceableError() << "bufferization failed"; 98 } 99 } 100 101 // This transform op is currently restricted to ModuleOps and function ops. 102 // Such ops are modified in-place. 103 transformResults.set(cast<OpResult>(getTransformed()), payloadOps); 104 return DiagnosedSilenceableFailure::success(); 105 } 106 107 //===----------------------------------------------------------------------===// 108 // EliminateEmptyTensorsOp 109 //===----------------------------------------------------------------------===// 110 111 void transform::EliminateEmptyTensorsOp::getEffects( 112 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 113 onlyReadsHandle(getTargetMutable(), effects); 114 modifiesPayload(effects); 115 } 116 117 DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( 118 transform::TransformRewriter &rewriter, TransformResults &transformResults, 119 TransformState &state) { 120 for (Operation *target : state.getPayloadOps(getTarget())) { 121 if (failed(bufferization::eliminateEmptyTensors(rewriter, target))) 122 return mlir::emitSilenceableFailure(target->getLoc()) 123 << "empty tensor elimination failed"; 124 } 125 return DiagnosedSilenceableFailure::success(); 126 } 127 128 //===----------------------------------------------------------------------===// 129 // EmptyTensorToAllocTensorOp 130 //===----------------------------------------------------------------------===// 131 132 DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( 133 transform::TransformRewriter &rewriter, tensor::EmptyOp target, 134 ApplyToEachResultList &results, transform::TransformState &state) { 135 rewriter.setInsertionPoint(target); 136 auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( 137 target, target.getType(), target.getDynamicSizes()); 138 results.push_back(alloc); 139 return DiagnosedSilenceableFailure::success(); 140 } 141 142 //===----------------------------------------------------------------------===// 143 // Transform op registration 144 //===----------------------------------------------------------------------===// 145 146 namespace { 147 /// Registers new ops and declares PDL as dependent dialect since the additional 148 /// ops are using PDL types for operands and results. 149 class BufferizationTransformDialectExtension 150 : public transform::TransformDialectExtension< 151 BufferizationTransformDialectExtension> { 152 public: 153 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 154 BufferizationTransformDialectExtension) 155 156 using Base::Base; 157 158 void init() { 159 declareGeneratedDialect<bufferization::BufferizationDialect>(); 160 declareGeneratedDialect<memref::MemRefDialect>(); 161 162 registerTransformOps< 163 #define GET_OP_LIST 164 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 165 >(); 166 } 167 }; 168 } // namespace 169 170 #define GET_OP_CLASSES 171 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 172 173 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" 174 175 void mlir::bufferization::registerTransformDialectExtension( 176 DialectRegistry ®istry) { 177 registry.addExtensions<BufferizationTransformDialectExtension>(); 178 } 179