xref: /llvm-project/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
//===- FuncTransformOps.cpp - Implementation of Func transform ops --------===//
2920c4612SNicolas Vasilache //
3920c4612SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4920c4612SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
5920c4612SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6920c4612SNicolas Vasilache //
7920c4612SNicolas Vasilache //===----------------------------------------------------------------------===//
8920c4612SNicolas Vasilache 
9920c4612SNicolas Vasilache #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
10920c4612SNicolas Vasilache 
11920c4612SNicolas Vasilache #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
12920c4612SNicolas Vasilache #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13920c4612SNicolas Vasilache #include "mlir/Dialect/Func/IR/FuncOps.h"
14920c4612SNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15920c4612SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16920c4612SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformOps.h"
175a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1842b16035SQuinn Dawkins #include "mlir/Transforms/DialectConversion.h"
19920c4612SNicolas Vasilache 
20920c4612SNicolas Vasilache using namespace mlir;
21920c4612SNicolas Vasilache 
22920c4612SNicolas Vasilache //===----------------------------------------------------------------------===//
23920c4612SNicolas Vasilache // Apply...ConversionPatternsOp
24920c4612SNicolas Vasilache //===----------------------------------------------------------------------===//
25920c4612SNicolas Vasilache 
26920c4612SNicolas Vasilache void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
27920c4612SNicolas Vasilache     TypeConverter &typeConverter, RewritePatternSet &patterns) {
28920c4612SNicolas Vasilache   populateFuncToLLVMConversionPatterns(
29920c4612SNicolas Vasilache       static_cast<LLVMTypeConverter &>(typeConverter), patterns);
30920c4612SNicolas Vasilache }
31920c4612SNicolas Vasilache 
32920c4612SNicolas Vasilache LogicalResult
33920c4612SNicolas Vasilache transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
34920c4612SNicolas Vasilache     transform::TypeConverterBuilderOpInterface builder) {
35920c4612SNicolas Vasilache   if (builder.getTypeConverterType() != "LLVMTypeConverter")
36920c4612SNicolas Vasilache     return emitOpError("expected LLVMTypeConverter");
37920c4612SNicolas Vasilache   return success();
38920c4612SNicolas Vasilache }
39920c4612SNicolas Vasilache 
40920c4612SNicolas Vasilache //===----------------------------------------------------------------------===//
4142b16035SQuinn Dawkins // CastAndCallOp
4242b16035SQuinn Dawkins //===----------------------------------------------------------------------===//
4342b16035SQuinn Dawkins 
4442b16035SQuinn Dawkins DiagnosedSilenceableFailure
4542b16035SQuinn Dawkins transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
4642b16035SQuinn Dawkins                                 transform::TransformResults &results,
4742b16035SQuinn Dawkins                                 transform::TransformState &state) {
4842b16035SQuinn Dawkins   SmallVector<Value> inputs;
4942b16035SQuinn Dawkins   if (getInputs())
5042b16035SQuinn Dawkins     llvm::append_range(inputs, state.getPayloadValues(getInputs()));
5142b16035SQuinn Dawkins 
5242b16035SQuinn Dawkins   SetVector<Value> outputs;
5342b16035SQuinn Dawkins   if (getOutputs()) {
5442b16035SQuinn Dawkins     for (auto output : state.getPayloadValues(getOutputs()))
5542b16035SQuinn Dawkins       outputs.insert(output);
5642b16035SQuinn Dawkins 
5742b16035SQuinn Dawkins     // Verify that the set of output values to be replaced is unique.
5842b16035SQuinn Dawkins     if (outputs.size() !=
5942b16035SQuinn Dawkins         llvm::range_size(state.getPayloadValues(getOutputs()))) {
6042b16035SQuinn Dawkins       return emitSilenceableFailure(getLoc())
6142b16035SQuinn Dawkins              << "cast and call output values must be unique";
6242b16035SQuinn Dawkins     }
6342b16035SQuinn Dawkins   }
6442b16035SQuinn Dawkins 
6542b16035SQuinn Dawkins   // Get the insertion point for the call.
6642b16035SQuinn Dawkins   auto insertionOps = state.getPayloadOps(getInsertionPoint());
6742b16035SQuinn Dawkins   if (!llvm::hasSingleElement(insertionOps)) {
6842b16035SQuinn Dawkins     return emitSilenceableFailure(getLoc())
6942b16035SQuinn Dawkins            << "Only one op can be specified as an insertion point";
7042b16035SQuinn Dawkins   }
7142b16035SQuinn Dawkins   bool insertAfter = getInsertAfter();
7242b16035SQuinn Dawkins   Operation *insertionPoint = *insertionOps.begin();
7342b16035SQuinn Dawkins 
7442b16035SQuinn Dawkins   // Check that all inputs dominate the insertion point, and the insertion
7542b16035SQuinn Dawkins   // point dominates all users of the outputs.
7642b16035SQuinn Dawkins   DominanceInfo dom(insertionPoint);
7742b16035SQuinn Dawkins   for (Value output : outputs) {
7842b16035SQuinn Dawkins     for (Operation *user : output.getUsers()) {
7942b16035SQuinn Dawkins       // If we are inserting after the insertion point operation, the
8042b16035SQuinn Dawkins       // insertion point operation must properly dominate the user. Otherwise
8142b16035SQuinn Dawkins       // basic dominance is enough.
8242b16035SQuinn Dawkins       bool doesDominate = insertAfter
8342b16035SQuinn Dawkins                               ? dom.properlyDominates(insertionPoint, user)
8442b16035SQuinn Dawkins                               : dom.dominates(insertionPoint, user);
8542b16035SQuinn Dawkins       if (!doesDominate) {
8642b16035SQuinn Dawkins         return emitDefiniteFailure()
8742b16035SQuinn Dawkins                << "User " << user << " is not dominated by insertion point "
8842b16035SQuinn Dawkins                << insertionPoint;
8942b16035SQuinn Dawkins       }
9042b16035SQuinn Dawkins     }
9142b16035SQuinn Dawkins   }
9242b16035SQuinn Dawkins 
9342b16035SQuinn Dawkins   for (Value input : inputs) {
9442b16035SQuinn Dawkins     // If we are inserting before the insertion point operation, the
9542b16035SQuinn Dawkins     // input must properly dominate the insertion point operation. Otherwise
9642b16035SQuinn Dawkins     // basic dominance is enough.
9742b16035SQuinn Dawkins     bool doesDominate = insertAfter
9842b16035SQuinn Dawkins                             ? dom.dominates(input, insertionPoint)
9942b16035SQuinn Dawkins                             : dom.properlyDominates(input, insertionPoint);
10042b16035SQuinn Dawkins     if (!doesDominate) {
10142b16035SQuinn Dawkins       return emitDefiniteFailure()
10242b16035SQuinn Dawkins              << "input " << input << " does not dominate insertion point "
10342b16035SQuinn Dawkins              << insertionPoint;
10442b16035SQuinn Dawkins     }
10542b16035SQuinn Dawkins   }
10642b16035SQuinn Dawkins 
10742b16035SQuinn Dawkins   // Get the function to call. This can either be specified by symbol or as a
10842b16035SQuinn Dawkins   // transform handle.
10942b16035SQuinn Dawkins   func::FuncOp targetFunction = nullptr;
11042b16035SQuinn Dawkins   if (getFunctionName()) {
11142b16035SQuinn Dawkins     targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
11242b16035SQuinn Dawkins         insertionPoint, *getFunctionName());
11342b16035SQuinn Dawkins     if (!targetFunction) {
11442b16035SQuinn Dawkins       return emitDefiniteFailure()
11542b16035SQuinn Dawkins              << "unresolved symbol " << *getFunctionName();
11642b16035SQuinn Dawkins     }
11742b16035SQuinn Dawkins   } else if (getFunction()) {
11842b16035SQuinn Dawkins     auto payloadOps = state.getPayloadOps(getFunction());
11942b16035SQuinn Dawkins     if (!llvm::hasSingleElement(payloadOps)) {
12042b16035SQuinn Dawkins       return emitDefiniteFailure() << "requires a single function to call";
12142b16035SQuinn Dawkins     }
12242b16035SQuinn Dawkins     targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
12342b16035SQuinn Dawkins     if (!targetFunction) {
12442b16035SQuinn Dawkins       return emitDefiniteFailure() << "invalid non-function callee";
12542b16035SQuinn Dawkins     }
12642b16035SQuinn Dawkins   } else {
12742b16035SQuinn Dawkins     llvm_unreachable("Invalid CastAndCall op without a function to call");
12842b16035SQuinn Dawkins     return emitDefiniteFailure();
12942b16035SQuinn Dawkins   }
13042b16035SQuinn Dawkins 
13142b16035SQuinn Dawkins   // Verify that the function argument and result lengths match the inputs and
13242b16035SQuinn Dawkins   // outputs given to this op.
13342b16035SQuinn Dawkins   if (targetFunction.getNumArguments() != inputs.size()) {
13442b16035SQuinn Dawkins     return emitSilenceableFailure(targetFunction.getLoc())
13542b16035SQuinn Dawkins            << "mismatch between number of function arguments "
13642b16035SQuinn Dawkins            << targetFunction.getNumArguments() << " and number of inputs "
13742b16035SQuinn Dawkins            << inputs.size();
13842b16035SQuinn Dawkins   }
13942b16035SQuinn Dawkins   if (targetFunction.getNumResults() != outputs.size()) {
14042b16035SQuinn Dawkins     return emitSilenceableFailure(targetFunction.getLoc())
14142b16035SQuinn Dawkins            << "mismatch between number of function results "
14242b16035SQuinn Dawkins            << targetFunction->getNumResults() << " and number of outputs "
14342b16035SQuinn Dawkins            << outputs.size();
14442b16035SQuinn Dawkins   }
14542b16035SQuinn Dawkins 
14642b16035SQuinn Dawkins   // Gather all specified converters.
14742b16035SQuinn Dawkins   mlir::TypeConverter converter;
14842b16035SQuinn Dawkins   if (!getRegion().empty()) {
14942b16035SQuinn Dawkins     for (Operation &op : getRegion().front()) {
15042b16035SQuinn Dawkins       cast<transform::TypeConverterBuilderOpInterface>(&op)
15142b16035SQuinn Dawkins           .populateTypeMaterializations(converter);
15242b16035SQuinn Dawkins     }
15342b16035SQuinn Dawkins   }
15442b16035SQuinn Dawkins 
15542b16035SQuinn Dawkins   if (insertAfter)
15642b16035SQuinn Dawkins     rewriter.setInsertionPointAfter(insertionPoint);
15742b16035SQuinn Dawkins   else
15842b16035SQuinn Dawkins     rewriter.setInsertionPoint(insertionPoint);
15942b16035SQuinn Dawkins 
16042b16035SQuinn Dawkins   for (auto [input, type] :
16142b16035SQuinn Dawkins        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
16242b16035SQuinn Dawkins     if (input.getType() != type) {
16342b16035SQuinn Dawkins       Value newInput = converter.materializeSourceConversion(
16442b16035SQuinn Dawkins           rewriter, input.getLoc(), type, input);
16542b16035SQuinn Dawkins       if (!newInput) {
16642b16035SQuinn Dawkins         return emitDefiniteFailure() << "Failed to materialize conversion of "
16742b16035SQuinn Dawkins                                      << input << " to type " << type;
16842b16035SQuinn Dawkins       }
16942b16035SQuinn Dawkins       input = newInput;
17042b16035SQuinn Dawkins     }
17142b16035SQuinn Dawkins   }
17242b16035SQuinn Dawkins 
17342b16035SQuinn Dawkins   auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
17442b16035SQuinn Dawkins                                               targetFunction, inputs);
17542b16035SQuinn Dawkins 
17642b16035SQuinn Dawkins   // Cast the call results back to the expected types. If any conversions fail
17742b16035SQuinn Dawkins   // this is a definite failure as the call has been constructed at this point.
17842b16035SQuinn Dawkins   for (auto [output, newOutput] :
17942b16035SQuinn Dawkins        llvm::zip_equal(outputs, callOp.getResults())) {
18042b16035SQuinn Dawkins     Value convertedOutput = newOutput;
18142b16035SQuinn Dawkins     if (output.getType() != newOutput.getType()) {
18242b16035SQuinn Dawkins       convertedOutput = converter.materializeTargetConversion(
18342b16035SQuinn Dawkins           rewriter, output.getLoc(), output.getType(), newOutput);
18442b16035SQuinn Dawkins       if (!convertedOutput) {
18542b16035SQuinn Dawkins         return emitDefiniteFailure()
18642b16035SQuinn Dawkins                << "Failed to materialize conversion of " << newOutput
18742b16035SQuinn Dawkins                << " to type " << output.getType();
18842b16035SQuinn Dawkins       }
18942b16035SQuinn Dawkins     }
19042b16035SQuinn Dawkins     rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
19142b16035SQuinn Dawkins   }
19242b16035SQuinn Dawkins   results.set(cast<OpResult>(getResult()), {callOp});
19342b16035SQuinn Dawkins   return DiagnosedSilenceableFailure::success();
19442b16035SQuinn Dawkins }
19542b16035SQuinn Dawkins 
19642b16035SQuinn Dawkins LogicalResult transform::CastAndCallOp::verify() {
19742b16035SQuinn Dawkins   if (!getRegion().empty()) {
19842b16035SQuinn Dawkins     for (Operation &op : getRegion().front()) {
19942b16035SQuinn Dawkins       if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
20042b16035SQuinn Dawkins         InFlightDiagnostic diag = emitOpError()
20142b16035SQuinn Dawkins                                   << "expected children ops to implement "
20242b16035SQuinn Dawkins                                      "TypeConverterBuilderOpInterface";
20342b16035SQuinn Dawkins         diag.attachNote(op.getLoc()) << "op without interface";
20442b16035SQuinn Dawkins         return diag;
20542b16035SQuinn Dawkins       }
20642b16035SQuinn Dawkins     }
20742b16035SQuinn Dawkins   }
20842b16035SQuinn Dawkins   if (!getFunction() && !getFunctionName()) {
20942b16035SQuinn Dawkins     return emitOpError() << "expected a function handle or name to call";
21042b16035SQuinn Dawkins   }
21142b16035SQuinn Dawkins   if (getFunction() && getFunctionName()) {
21242b16035SQuinn Dawkins     return emitOpError() << "function handle and name are mutually exclusive";
21342b16035SQuinn Dawkins   }
21442b16035SQuinn Dawkins   return success();
21542b16035SQuinn Dawkins }
21642b16035SQuinn Dawkins 
21742b16035SQuinn Dawkins void transform::CastAndCallOp::getEffects(
21842b16035SQuinn Dawkins     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2192c1ae801Sdonald chen   transform::onlyReadsHandle(getInsertionPointMutable(), effects);
22042b16035SQuinn Dawkins   if (getInputs())
2212c1ae801Sdonald chen     transform::onlyReadsHandle(getInputsMutable(), effects);
22242b16035SQuinn Dawkins   if (getOutputs())
2232c1ae801Sdonald chen     transform::onlyReadsHandle(getOutputsMutable(), effects);
22442b16035SQuinn Dawkins   if (getFunction())
2252c1ae801Sdonald chen     transform::onlyReadsHandle(getFunctionMutable(), effects);
2262c1ae801Sdonald chen   transform::producesHandle(getOperation()->getOpResults(), effects);
22742b16035SQuinn Dawkins   transform::modifiesPayload(effects);
22842b16035SQuinn Dawkins }
22942b16035SQuinn Dawkins 
23042b16035SQuinn Dawkins //===----------------------------------------------------------------------===//
231920c4612SNicolas Vasilache // Transform op registration
232920c4612SNicolas Vasilache //===----------------------------------------------------------------------===//
233920c4612SNicolas Vasilache 
namespace {
/// Transform dialect extension that registers the Func transform ops listed
/// in the ODS-generated op list and declares the LLVM dialect as a generated
/// dialect.
class FuncTransformDialectExtension
    : public transform::TransformDialectExtension<
          FuncTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)

  using Base::Base;

  void init() {
    // NOTE(review): presumably needed because the Func-to-LLVM conversion
    // patterns populated by ApplyFuncToLLVMConversionPatternsOp create LLVM
    // dialect ops -- confirm.
    declareGeneratedDialect<LLVM::LLVMDialect>();

    // Register every op in the generated op list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
        >();
  }
};
} // namespace
253920c4612SNicolas Vasilache 
// Pull in the ODS-generated op class implementations.
#define GET_OP_CLASSES
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
256920c4612SNicolas Vasilache 
/// Adds the Func transform dialect extension to `registry`.
void mlir::func::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<FuncTransformDialectExtension>();
}
260