1 //===- FuncTransformOps.cpp - Implementation of CF transform ops ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" 10 11 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16 #include "mlir/Dialect/Transform/IR/TransformOps.h" 17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 // Apply...ConversionPatternsOp 24 //===----------------------------------------------------------------------===// 25 26 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns( 27 TypeConverter &typeConverter, RewritePatternSet &patterns) { 28 populateFuncToLLVMConversionPatterns( 29 static_cast<LLVMTypeConverter &>(typeConverter), patterns); 30 } 31 32 LogicalResult 33 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter( 34 transform::TypeConverterBuilderOpInterface builder) { 35 if (builder.getTypeConverterType() != "LLVMTypeConverter") 36 return emitOpError("expected LLVMTypeConverter"); 37 return success(); 38 } 39 40 //===----------------------------------------------------------------------===// 41 // CastAndCallOp 42 //===----------------------------------------------------------------------===// 43 44 DiagnosedSilenceableFailure 45 transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, 46 transform::TransformResults &results, 47 transform::TransformState &state) { 48 SmallVector<Value> inputs; 49 if (getInputs()) 50 llvm::append_range(inputs, state.getPayloadValues(getInputs())); 51 52 SetVector<Value> outputs; 53 if (getOutputs()) { 54 for (auto output : state.getPayloadValues(getOutputs())) 55 outputs.insert(output); 56 57 // Verify that the set of output values to be replaced is unique. 58 if (outputs.size() != 59 llvm::range_size(state.getPayloadValues(getOutputs()))) { 60 return emitSilenceableFailure(getLoc()) 61 << "cast and call output values must be unique"; 62 } 63 } 64 65 // Get the insertion point for the call. 66 auto insertionOps = state.getPayloadOps(getInsertionPoint()); 67 if (!llvm::hasSingleElement(insertionOps)) { 68 return emitSilenceableFailure(getLoc()) 69 << "Only one op can be specified as an insertion point"; 70 } 71 bool insertAfter = getInsertAfter(); 72 Operation *insertionPoint = *insertionOps.begin(); 73 74 // Check that all inputs dominate the insertion point, and the insertion 75 // point dominates all users of the outputs. 76 DominanceInfo dom(insertionPoint); 77 for (Value output : outputs) { 78 for (Operation *user : output.getUsers()) { 79 // If we are inserting after the insertion point operation, the 80 // insertion point operation must properly dominate the user. Otherwise 81 // basic dominance is enough. 82 bool doesDominate = insertAfter 83 ? dom.properlyDominates(insertionPoint, user) 84 : dom.dominates(insertionPoint, user); 85 if (!doesDominate) { 86 return emitDefiniteFailure() 87 << "User " << user << " is not dominated by insertion point " 88 << insertionPoint; 89 } 90 } 91 } 92 93 for (Value input : inputs) { 94 // If we are inserting before the insertion point operation, the 95 // input must properly dominate the insertion point operation. Otherwise 96 // basic dominance is enough. 97 bool doesDominate = insertAfter 98 ? dom.dominates(input, insertionPoint) 99 : dom.properlyDominates(input, insertionPoint); 100 if (!doesDominate) { 101 return emitDefiniteFailure() 102 << "input " << input << " does not dominate insertion point " 103 << insertionPoint; 104 } 105 } 106 107 // Get the function to call. This can either be specified by symbol or as a 108 // transform handle. 109 func::FuncOp targetFunction = nullptr; 110 if (getFunctionName()) { 111 targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>( 112 insertionPoint, *getFunctionName()); 113 if (!targetFunction) { 114 return emitDefiniteFailure() 115 << "unresolved symbol " << *getFunctionName(); 116 } 117 } else if (getFunction()) { 118 auto payloadOps = state.getPayloadOps(getFunction()); 119 if (!llvm::hasSingleElement(payloadOps)) { 120 return emitDefiniteFailure() << "requires a single function to call"; 121 } 122 targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin()); 123 if (!targetFunction) { 124 return emitDefiniteFailure() << "invalid non-function callee"; 125 } 126 } else { 127 llvm_unreachable("Invalid CastAndCall op without a function to call"); 128 return emitDefiniteFailure(); 129 } 130 131 // Verify that the function argument and result lengths match the inputs and 132 // outputs given to this op. 133 if (targetFunction.getNumArguments() != inputs.size()) { 134 return emitSilenceableFailure(targetFunction.getLoc()) 135 << "mismatch between number of function arguments " 136 << targetFunction.getNumArguments() << " and number of inputs " 137 << inputs.size(); 138 } 139 if (targetFunction.getNumResults() != outputs.size()) { 140 return emitSilenceableFailure(targetFunction.getLoc()) 141 << "mismatch between number of function results " 142 << targetFunction->getNumResults() << " and number of outputs " 143 << outputs.size(); 144 } 145 146 // Gather all specified converters. 147 mlir::TypeConverter converter; 148 if (!getRegion().empty()) { 149 for (Operation &op : getRegion().front()) { 150 cast<transform::TypeConverterBuilderOpInterface>(&op) 151 .populateTypeMaterializations(converter); 152 } 153 } 154 155 if (insertAfter) 156 rewriter.setInsertionPointAfter(insertionPoint); 157 else 158 rewriter.setInsertionPoint(insertionPoint); 159 160 for (auto [input, type] : 161 llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) { 162 if (input.getType() != type) { 163 Value newInput = converter.materializeSourceConversion( 164 rewriter, input.getLoc(), type, input); 165 if (!newInput) { 166 return emitDefiniteFailure() << "Failed to materialize conversion of " 167 << input << " to type " << type; 168 } 169 input = newInput; 170 } 171 } 172 173 auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(), 174 targetFunction, inputs); 175 176 // Cast the call results back to the expected types. If any conversions fail 177 // this is a definite failure as the call has been constructed at this point. 178 for (auto [output, newOutput] : 179 llvm::zip_equal(outputs, callOp.getResults())) { 180 Value convertedOutput = newOutput; 181 if (output.getType() != newOutput.getType()) { 182 convertedOutput = converter.materializeTargetConversion( 183 rewriter, output.getLoc(), output.getType(), newOutput); 184 if (!convertedOutput) { 185 return emitDefiniteFailure() 186 << "Failed to materialize conversion of " << newOutput 187 << " to type " << output.getType(); 188 } 189 } 190 rewriter.replaceAllUsesExcept(output, convertedOutput, callOp); 191 } 192 results.set(cast<OpResult>(getResult()), {callOp}); 193 return DiagnosedSilenceableFailure::success(); 194 } 195 196 LogicalResult transform::CastAndCallOp::verify() { 197 if (!getRegion().empty()) { 198 for (Operation &op : getRegion().front()) { 199 if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) { 200 InFlightDiagnostic diag = emitOpError() 201 << "expected children ops to implement " 202 "TypeConverterBuilderOpInterface"; 203 diag.attachNote(op.getLoc()) << "op without interface"; 204 return diag; 205 } 206 } 207 } 208 if (!getFunction() && !getFunctionName()) { 209 return emitOpError() << "expected a function handle or name to call"; 210 } 211 if (getFunction() && getFunctionName()) { 212 return emitOpError() << "function handle and name are mutually exclusive"; 213 } 214 return success(); 215 } 216 217 void transform::CastAndCallOp::getEffects( 218 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 219 transform::onlyReadsHandle(getInsertionPointMutable(), effects); 220 if (getInputs()) 221 transform::onlyReadsHandle(getInputsMutable(), effects); 222 if (getOutputs()) 223 transform::onlyReadsHandle(getOutputsMutable(), effects); 224 if (getFunction()) 225 transform::onlyReadsHandle(getFunctionMutable(), effects); 226 transform::producesHandle(getOperation()->getOpResults(), effects); 227 transform::modifiesPayload(effects); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // Transform op registration 232 //===----------------------------------------------------------------------===// 233 234 namespace { 235 class FuncTransformDialectExtension 236 : public transform::TransformDialectExtension< 237 FuncTransformDialectExtension> { 238 public: 239 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension) 240 241 using Base::Base; 242 243 void init() { 244 declareGeneratedDialect<LLVM::LLVMDialect>(); 245 246 registerTransformOps< 247 #define GET_OP_LIST 248 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" 249 >(); 250 } 251 }; 252 } // namespace 253 254 #define GET_OP_CLASSES 255 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" 256 257 void mlir::func::registerTransformDialectExtension(DialectRegistry ®istry) { 258 registry.addExtensions<FuncTransformDialectExtension>(); 259 } 260