1920c4612SNicolas Vasilache //===- FuncTransformOps.cpp - Implementation of CF transform ops ---===// 2920c4612SNicolas Vasilache // 3920c4612SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4920c4612SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information. 5920c4612SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6920c4612SNicolas Vasilache // 7920c4612SNicolas Vasilache //===----------------------------------------------------------------------===// 8920c4612SNicolas Vasilache 9920c4612SNicolas Vasilache #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" 10920c4612SNicolas Vasilache 11920c4612SNicolas Vasilache #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 12920c4612SNicolas Vasilache #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 13920c4612SNicolas Vasilache #include "mlir/Dialect/Func/IR/FuncOps.h" 14920c4612SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15920c4612SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformDialect.h" 16920c4612SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformOps.h" 175a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 1842b16035SQuinn Dawkins #include "mlir/Transforms/DialectConversion.h" 19920c4612SNicolas Vasilache 20920c4612SNicolas Vasilache using namespace mlir; 21920c4612SNicolas Vasilache 22920c4612SNicolas Vasilache //===----------------------------------------------------------------------===// 23920c4612SNicolas Vasilache // Apply...ConversionPatternsOp 24920c4612SNicolas Vasilache //===----------------------------------------------------------------------===// 25920c4612SNicolas Vasilache 26920c4612SNicolas Vasilache void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns( 27920c4612SNicolas Vasilache TypeConverter &typeConverter, RewritePatternSet &patterns) { 28920c4612SNicolas Vasilache populateFuncToLLVMConversionPatterns( 29920c4612SNicolas Vasilache static_cast<LLVMTypeConverter &>(typeConverter), patterns); 30920c4612SNicolas Vasilache } 31920c4612SNicolas Vasilache 32920c4612SNicolas Vasilache LogicalResult 33920c4612SNicolas Vasilache transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter( 34920c4612SNicolas Vasilache transform::TypeConverterBuilderOpInterface builder) { 35920c4612SNicolas Vasilache if (builder.getTypeConverterType() != "LLVMTypeConverter") 36920c4612SNicolas Vasilache return emitOpError("expected LLVMTypeConverter"); 37920c4612SNicolas Vasilache return success(); 38920c4612SNicolas Vasilache } 39920c4612SNicolas Vasilache 40920c4612SNicolas Vasilache //===----------------------------------------------------------------------===// 4142b16035SQuinn Dawkins // CastAndCallOp 4242b16035SQuinn Dawkins //===----------------------------------------------------------------------===// 4342b16035SQuinn Dawkins 4442b16035SQuinn Dawkins DiagnosedSilenceableFailure 4542b16035SQuinn Dawkins transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter, 4642b16035SQuinn Dawkins transform::TransformResults &results, 4742b16035SQuinn Dawkins transform::TransformState &state) { 4842b16035SQuinn Dawkins SmallVector<Value> inputs; 4942b16035SQuinn Dawkins if (getInputs()) 5042b16035SQuinn Dawkins llvm::append_range(inputs, state.getPayloadValues(getInputs())); 5142b16035SQuinn Dawkins 5242b16035SQuinn Dawkins SetVector<Value> outputs; 5342b16035SQuinn Dawkins if (getOutputs()) { 5442b16035SQuinn Dawkins for (auto output : state.getPayloadValues(getOutputs())) 5542b16035SQuinn Dawkins outputs.insert(output); 5642b16035SQuinn Dawkins 5742b16035SQuinn Dawkins // Verify that the set of output values to be replaced is unique. 5842b16035SQuinn Dawkins if (outputs.size() != 5942b16035SQuinn Dawkins llvm::range_size(state.getPayloadValues(getOutputs()))) { 6042b16035SQuinn Dawkins return emitSilenceableFailure(getLoc()) 6142b16035SQuinn Dawkins << "cast and call output values must be unique"; 6242b16035SQuinn Dawkins } 6342b16035SQuinn Dawkins } 6442b16035SQuinn Dawkins 6542b16035SQuinn Dawkins // Get the insertion point for the call. 6642b16035SQuinn Dawkins auto insertionOps = state.getPayloadOps(getInsertionPoint()); 6742b16035SQuinn Dawkins if (!llvm::hasSingleElement(insertionOps)) { 6842b16035SQuinn Dawkins return emitSilenceableFailure(getLoc()) 6942b16035SQuinn Dawkins << "Only one op can be specified as an insertion point"; 7042b16035SQuinn Dawkins } 7142b16035SQuinn Dawkins bool insertAfter = getInsertAfter(); 7242b16035SQuinn Dawkins Operation *insertionPoint = *insertionOps.begin(); 7342b16035SQuinn Dawkins 7442b16035SQuinn Dawkins // Check that all inputs dominate the insertion point, and the insertion 7542b16035SQuinn Dawkins // point dominates all users of the outputs. 7642b16035SQuinn Dawkins DominanceInfo dom(insertionPoint); 7742b16035SQuinn Dawkins for (Value output : outputs) { 7842b16035SQuinn Dawkins for (Operation *user : output.getUsers()) { 7942b16035SQuinn Dawkins // If we are inserting after the insertion point operation, the 8042b16035SQuinn Dawkins // insertion point operation must properly dominate the user. Otherwise 8142b16035SQuinn Dawkins // basic dominance is enough. 8242b16035SQuinn Dawkins bool doesDominate = insertAfter 8342b16035SQuinn Dawkins ? dom.properlyDominates(insertionPoint, user) 8442b16035SQuinn Dawkins : dom.dominates(insertionPoint, user); 8542b16035SQuinn Dawkins if (!doesDominate) { 8642b16035SQuinn Dawkins return emitDefiniteFailure() 8742b16035SQuinn Dawkins << "User " << user << " is not dominated by insertion point " 8842b16035SQuinn Dawkins << insertionPoint; 8942b16035SQuinn Dawkins } 9042b16035SQuinn Dawkins } 9142b16035SQuinn Dawkins } 9242b16035SQuinn Dawkins 9342b16035SQuinn Dawkins for (Value input : inputs) { 9442b16035SQuinn Dawkins // If we are inserting before the insertion point operation, the 9542b16035SQuinn Dawkins // input must properly dominate the insertion point operation. Otherwise 9642b16035SQuinn Dawkins // basic dominance is enough. 9742b16035SQuinn Dawkins bool doesDominate = insertAfter 9842b16035SQuinn Dawkins ? dom.dominates(input, insertionPoint) 9942b16035SQuinn Dawkins : dom.properlyDominates(input, insertionPoint); 10042b16035SQuinn Dawkins if (!doesDominate) { 10142b16035SQuinn Dawkins return emitDefiniteFailure() 10242b16035SQuinn Dawkins << "input " << input << " does not dominate insertion point " 10342b16035SQuinn Dawkins << insertionPoint; 10442b16035SQuinn Dawkins } 10542b16035SQuinn Dawkins } 10642b16035SQuinn Dawkins 10742b16035SQuinn Dawkins // Get the function to call. This can either be specified by symbol or as a 10842b16035SQuinn Dawkins // transform handle. 10942b16035SQuinn Dawkins func::FuncOp targetFunction = nullptr; 11042b16035SQuinn Dawkins if (getFunctionName()) { 11142b16035SQuinn Dawkins targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>( 11242b16035SQuinn Dawkins insertionPoint, *getFunctionName()); 11342b16035SQuinn Dawkins if (!targetFunction) { 11442b16035SQuinn Dawkins return emitDefiniteFailure() 11542b16035SQuinn Dawkins << "unresolved symbol " << *getFunctionName(); 11642b16035SQuinn Dawkins } 11742b16035SQuinn Dawkins } else if (getFunction()) { 11842b16035SQuinn Dawkins auto payloadOps = state.getPayloadOps(getFunction()); 11942b16035SQuinn Dawkins if (!llvm::hasSingleElement(payloadOps)) { 12042b16035SQuinn Dawkins return emitDefiniteFailure() << "requires a single function to call"; 12142b16035SQuinn Dawkins } 12242b16035SQuinn Dawkins targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin()); 12342b16035SQuinn Dawkins if (!targetFunction) { 12442b16035SQuinn Dawkins return emitDefiniteFailure() << "invalid non-function callee"; 12542b16035SQuinn Dawkins } 12642b16035SQuinn Dawkins } else { 12742b16035SQuinn Dawkins llvm_unreachable("Invalid CastAndCall op without a function to call"); 12842b16035SQuinn Dawkins return emitDefiniteFailure(); 12942b16035SQuinn Dawkins } 13042b16035SQuinn Dawkins 13142b16035SQuinn Dawkins // Verify that the function argument and result lengths match the inputs and 13242b16035SQuinn Dawkins // outputs given to this op. 13342b16035SQuinn Dawkins if (targetFunction.getNumArguments() != inputs.size()) { 13442b16035SQuinn Dawkins return emitSilenceableFailure(targetFunction.getLoc()) 13542b16035SQuinn Dawkins << "mismatch between number of function arguments " 13642b16035SQuinn Dawkins << targetFunction.getNumArguments() << " and number of inputs " 13742b16035SQuinn Dawkins << inputs.size(); 13842b16035SQuinn Dawkins } 13942b16035SQuinn Dawkins if (targetFunction.getNumResults() != outputs.size()) { 14042b16035SQuinn Dawkins return emitSilenceableFailure(targetFunction.getLoc()) 14142b16035SQuinn Dawkins << "mismatch between number of function results " 14242b16035SQuinn Dawkins << targetFunction->getNumResults() << " and number of outputs " 14342b16035SQuinn Dawkins << outputs.size(); 14442b16035SQuinn Dawkins } 14542b16035SQuinn Dawkins 14642b16035SQuinn Dawkins // Gather all specified converters. 14742b16035SQuinn Dawkins mlir::TypeConverter converter; 14842b16035SQuinn Dawkins if (!getRegion().empty()) { 14942b16035SQuinn Dawkins for (Operation &op : getRegion().front()) { 15042b16035SQuinn Dawkins cast<transform::TypeConverterBuilderOpInterface>(&op) 15142b16035SQuinn Dawkins .populateTypeMaterializations(converter); 15242b16035SQuinn Dawkins } 15342b16035SQuinn Dawkins } 15442b16035SQuinn Dawkins 15542b16035SQuinn Dawkins if (insertAfter) 15642b16035SQuinn Dawkins rewriter.setInsertionPointAfter(insertionPoint); 15742b16035SQuinn Dawkins else 15842b16035SQuinn Dawkins rewriter.setInsertionPoint(insertionPoint); 15942b16035SQuinn Dawkins 16042b16035SQuinn Dawkins for (auto [input, type] : 16142b16035SQuinn Dawkins llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) { 16242b16035SQuinn Dawkins if (input.getType() != type) { 16342b16035SQuinn Dawkins Value newInput = converter.materializeSourceConversion( 16442b16035SQuinn Dawkins rewriter, input.getLoc(), type, input); 16542b16035SQuinn Dawkins if (!newInput) { 16642b16035SQuinn Dawkins return emitDefiniteFailure() << "Failed to materialize conversion of " 16742b16035SQuinn Dawkins << input << " to type " << type; 16842b16035SQuinn Dawkins } 16942b16035SQuinn Dawkins input = newInput; 17042b16035SQuinn Dawkins } 17142b16035SQuinn Dawkins } 17242b16035SQuinn Dawkins 17342b16035SQuinn Dawkins auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(), 17442b16035SQuinn Dawkins targetFunction, inputs); 17542b16035SQuinn Dawkins 17642b16035SQuinn Dawkins // Cast the call results back to the expected types. If any conversions fail 17742b16035SQuinn Dawkins // this is a definite failure as the call has been constructed at this point. 17842b16035SQuinn Dawkins for (auto [output, newOutput] : 17942b16035SQuinn Dawkins llvm::zip_equal(outputs, callOp.getResults())) { 18042b16035SQuinn Dawkins Value convertedOutput = newOutput; 18142b16035SQuinn Dawkins if (output.getType() != newOutput.getType()) { 18242b16035SQuinn Dawkins convertedOutput = converter.materializeTargetConversion( 18342b16035SQuinn Dawkins rewriter, output.getLoc(), output.getType(), newOutput); 18442b16035SQuinn Dawkins if (!convertedOutput) { 18542b16035SQuinn Dawkins return emitDefiniteFailure() 18642b16035SQuinn Dawkins << "Failed to materialize conversion of " << newOutput 18742b16035SQuinn Dawkins << " to type " << output.getType(); 18842b16035SQuinn Dawkins } 18942b16035SQuinn Dawkins } 19042b16035SQuinn Dawkins rewriter.replaceAllUsesExcept(output, convertedOutput, callOp); 19142b16035SQuinn Dawkins } 19242b16035SQuinn Dawkins results.set(cast<OpResult>(getResult()), {callOp}); 19342b16035SQuinn Dawkins return DiagnosedSilenceableFailure::success(); 19442b16035SQuinn Dawkins } 19542b16035SQuinn Dawkins 19642b16035SQuinn Dawkins LogicalResult transform::CastAndCallOp::verify() { 19742b16035SQuinn Dawkins if (!getRegion().empty()) { 19842b16035SQuinn Dawkins for (Operation &op : getRegion().front()) { 19942b16035SQuinn Dawkins if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) { 20042b16035SQuinn Dawkins InFlightDiagnostic diag = emitOpError() 20142b16035SQuinn Dawkins << "expected children ops to implement " 20242b16035SQuinn Dawkins "TypeConverterBuilderOpInterface"; 20342b16035SQuinn Dawkins diag.attachNote(op.getLoc()) << "op without interface"; 20442b16035SQuinn Dawkins return diag; 20542b16035SQuinn Dawkins } 20642b16035SQuinn Dawkins } 20742b16035SQuinn Dawkins } 20842b16035SQuinn Dawkins if (!getFunction() && !getFunctionName()) { 20942b16035SQuinn Dawkins return emitOpError() << "expected a function handle or name to call"; 21042b16035SQuinn Dawkins } 21142b16035SQuinn Dawkins if (getFunction() && getFunctionName()) { 21242b16035SQuinn Dawkins return emitOpError() << "function handle and name are mutually exclusive"; 21342b16035SQuinn Dawkins } 21442b16035SQuinn Dawkins return success(); 21542b16035SQuinn Dawkins } 21642b16035SQuinn Dawkins 21742b16035SQuinn Dawkins void transform::CastAndCallOp::getEffects( 21842b16035SQuinn Dawkins SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 2192c1ae801Sdonald chen transform::onlyReadsHandle(getInsertionPointMutable(), effects); 22042b16035SQuinn Dawkins if (getInputs()) 2212c1ae801Sdonald chen transform::onlyReadsHandle(getInputsMutable(), effects); 22242b16035SQuinn Dawkins if (getOutputs()) 2232c1ae801Sdonald chen transform::onlyReadsHandle(getOutputsMutable(), effects); 22442b16035SQuinn Dawkins if (getFunction()) 2252c1ae801Sdonald chen transform::onlyReadsHandle(getFunctionMutable(), effects); 2262c1ae801Sdonald chen transform::producesHandle(getOperation()->getOpResults(), effects); 22742b16035SQuinn Dawkins transform::modifiesPayload(effects); 22842b16035SQuinn Dawkins } 22942b16035SQuinn Dawkins 23042b16035SQuinn Dawkins //===----------------------------------------------------------------------===// 231920c4612SNicolas Vasilache // Transform op registration 232920c4612SNicolas Vasilache //===----------------------------------------------------------------------===// 233920c4612SNicolas Vasilache 234920c4612SNicolas Vasilache namespace { 235920c4612SNicolas Vasilache class FuncTransformDialectExtension 236920c4612SNicolas Vasilache : public transform::TransformDialectExtension< 237920c4612SNicolas Vasilache FuncTransformDialectExtension> { 238920c4612SNicolas Vasilache public: 239*84cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension) 240*84cc1865SNikhil Kalra 241920c4612SNicolas Vasilache using Base::Base; 242920c4612SNicolas Vasilache 243920c4612SNicolas Vasilache void init() { 244920c4612SNicolas Vasilache declareGeneratedDialect<LLVM::LLVMDialect>(); 245920c4612SNicolas Vasilache 246920c4612SNicolas Vasilache registerTransformOps< 247920c4612SNicolas Vasilache #define GET_OP_LIST 248920c4612SNicolas Vasilache #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" 249920c4612SNicolas Vasilache >(); 250920c4612SNicolas Vasilache } 251920c4612SNicolas Vasilache }; 252920c4612SNicolas Vasilache } // namespace 253920c4612SNicolas Vasilache 254920c4612SNicolas Vasilache #define GET_OP_CLASSES 255920c4612SNicolas Vasilache #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" 256920c4612SNicolas Vasilache 257920c4612SNicolas Vasilache void mlir::func::registerTransformDialectExtension(DialectRegistry ®istry) { 258920c4612SNicolas Vasilache registry.addExtensions<FuncTransformDialectExtension>(); 259920c4612SNicolas Vasilache } 260