1 //===- BufferizationTransformOps.h - Bufferization transform ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" 10 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 13 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/IR/FunctionInterfaces.h" 19 20 using namespace mlir; 21 using namespace mlir::bufferization; 22 using namespace mlir::transform; 23 24 //===----------------------------------------------------------------------===// 25 // OneShotBufferizeOp 26 //===----------------------------------------------------------------------===// 27 28 DiagnosedSilenceableFailure 29 transform::OneShotBufferizeOp::apply(TransformResults &transformResults, 30 TransformState &state) { 31 OneShotBufferizationOptions options; 32 options.allowReturnAllocs = getAllowReturnAllocs(); 33 options.allowUnknownOps = getAllowUnknownOps(); 34 options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); 35 options.createDeallocs = getCreateDeallocs(); 36 options.testAnalysisOnly = getTestAnalysisOnly(); 37 options.printConflicts = getPrintConflicts(); 38 if (getFunctionBoundaryTypeConversion().has_value()) 39 options.setFunctionBoundaryTypeConversion( 40 *getFunctionBoundaryTypeConversion()); 41 42 auto payloadOps = state.getPayloadOps(getTarget()); 43 for (Operation *target : payloadOps) { 44 if (!isa<ModuleOp, FunctionOpInterface>(target)) 45 return emitSilenceableError() << "expected module or function target"; 46 auto moduleOp = dyn_cast<ModuleOp>(target); 47 if (options.bufferizeFunctionBoundaries) { 48 if (!moduleOp) 49 return emitSilenceableError() << "expected module target"; 50 if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) 51 return emitSilenceableError() << "bufferization failed"; 52 } else { 53 if (failed(bufferization::runOneShotBufferize(target, options))) 54 return emitSilenceableError() << "bufferization failed"; 55 } 56 } 57 58 // This transform op is currently restricted to ModuleOps and function ops. 59 // Such ops are modified in-place. 60 transformResults.set(cast<OpResult>(getTransformed()), payloadOps); 61 return DiagnosedSilenceableFailure::success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // EliminateEmptyTensorsOp 66 //===----------------------------------------------------------------------===// 67 68 void transform::EliminateEmptyTensorsOp::getEffects( 69 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 70 onlyReadsHandle(getTarget(), effects); 71 modifiesPayload(effects); 72 } 73 74 DiagnosedSilenceableFailure 75 transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults, 76 TransformState &state) { 77 IRRewriter rewriter(getContext()); 78 OneShotBufferizationOptions options; 79 options.allowReturnAllocs = true; 80 81 for (Operation *target : state.getPayloadOps(getTarget())) { 82 OneShotAnalysisState state(target, options); 83 if (failed(analyzeOp(target, state))) 84 return mlir::emitSilenceableFailure(target->getLoc()) 85 << "failed to analyze op"; 86 if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( 87 rewriter, target, state))) 88 return mlir::emitSilenceableFailure(target->getLoc()) 89 << "failed to eliminate insert_slice anchored tensor.empty ops"; 90 } 91 return DiagnosedSilenceableFailure::success(); 92 } 93 94 //===----------------------------------------------------------------------===// 95 // EmptyTensorToAllocTensorOp 96 //===----------------------------------------------------------------------===// 97 98 DiagnosedSilenceableFailure 99 EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target, 100 ApplyToEachResultList &results, 101 transform::TransformState &state) { 102 IRRewriter rewriter(target->getContext()); 103 rewriter.setInsertionPoint(target); 104 auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( 105 target, target.getType(), target.getDynamicSizes()); 106 results.push_back(alloc); 107 return DiagnosedSilenceableFailure::success(); 108 } 109 110 //===----------------------------------------------------------------------===// 111 // Transform op registration 112 //===----------------------------------------------------------------------===// 113 114 namespace { 115 /// Registers new ops and declares PDL as dependent dialect since the additional 116 /// ops are using PDL types for operands and results. 117 class BufferizationTransformDialectExtension 118 : public transform::TransformDialectExtension< 119 BufferizationTransformDialectExtension> { 120 public: 121 using Base::Base; 122 123 void init() { 124 declareGeneratedDialect<bufferization::BufferizationDialect>(); 125 declareGeneratedDialect<memref::MemRefDialect>(); 126 127 registerTransformOps< 128 #define GET_OP_LIST 129 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 130 >(); 131 } 132 }; 133 } // namespace 134 135 #define GET_OP_CLASSES 136 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" 137 138 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" 139 140 void mlir::bufferization::registerTransformDialectExtension( 141 DialectRegistry ®istry) { 142 registry.addExtensions<BufferizationTransformDialectExtension>(); 143 } 144