xref: /llvm-project/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
//===- FuncTransformOps.cpp - Implementation of Func transform ops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
10 
11 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Dialect/Transform/IR/TransformOps.h"
17 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Apply...ConversionPatternsOp
24 //===----------------------------------------------------------------------===//
25 
26 void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns(
27     TypeConverter &typeConverter, RewritePatternSet &patterns) {
28   populateFuncToLLVMConversionPatterns(
29       static_cast<LLVMTypeConverter &>(typeConverter), patterns);
30 }
31 
32 LogicalResult
33 transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
34     transform::TypeConverterBuilderOpInterface builder) {
35   if (builder.getTypeConverterType() != "LLVMTypeConverter")
36     return emitOpError("expected LLVMTypeConverter");
37   return success();
38 }
39 
40 //===----------------------------------------------------------------------===//
41 // CastAndCallOp
42 //===----------------------------------------------------------------------===//
43 
44 DiagnosedSilenceableFailure
45 transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
46                                 transform::TransformResults &results,
47                                 transform::TransformState &state) {
48   SmallVector<Value> inputs;
49   if (getInputs())
50     llvm::append_range(inputs, state.getPayloadValues(getInputs()));
51 
52   SetVector<Value> outputs;
53   if (getOutputs()) {
54     for (auto output : state.getPayloadValues(getOutputs()))
55       outputs.insert(output);
56 
57     // Verify that the set of output values to be replaced is unique.
58     if (outputs.size() !=
59         llvm::range_size(state.getPayloadValues(getOutputs()))) {
60       return emitSilenceableFailure(getLoc())
61              << "cast and call output values must be unique";
62     }
63   }
64 
65   // Get the insertion point for the call.
66   auto insertionOps = state.getPayloadOps(getInsertionPoint());
67   if (!llvm::hasSingleElement(insertionOps)) {
68     return emitSilenceableFailure(getLoc())
69            << "Only one op can be specified as an insertion point";
70   }
71   bool insertAfter = getInsertAfter();
72   Operation *insertionPoint = *insertionOps.begin();
73 
74   // Check that all inputs dominate the insertion point, and the insertion
75   // point dominates all users of the outputs.
76   DominanceInfo dom(insertionPoint);
77   for (Value output : outputs) {
78     for (Operation *user : output.getUsers()) {
79       // If we are inserting after the insertion point operation, the
80       // insertion point operation must properly dominate the user. Otherwise
81       // basic dominance is enough.
82       bool doesDominate = insertAfter
83                               ? dom.properlyDominates(insertionPoint, user)
84                               : dom.dominates(insertionPoint, user);
85       if (!doesDominate) {
86         return emitDefiniteFailure()
87                << "User " << user << " is not dominated by insertion point "
88                << insertionPoint;
89       }
90     }
91   }
92 
93   for (Value input : inputs) {
94     // If we are inserting before the insertion point operation, the
95     // input must properly dominate the insertion point operation. Otherwise
96     // basic dominance is enough.
97     bool doesDominate = insertAfter
98                             ? dom.dominates(input, insertionPoint)
99                             : dom.properlyDominates(input, insertionPoint);
100     if (!doesDominate) {
101       return emitDefiniteFailure()
102              << "input " << input << " does not dominate insertion point "
103              << insertionPoint;
104     }
105   }
106 
107   // Get the function to call. This can either be specified by symbol or as a
108   // transform handle.
109   func::FuncOp targetFunction = nullptr;
110   if (getFunctionName()) {
111     targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112         insertionPoint, *getFunctionName());
113     if (!targetFunction) {
114       return emitDefiniteFailure()
115              << "unresolved symbol " << *getFunctionName();
116     }
117   } else if (getFunction()) {
118     auto payloadOps = state.getPayloadOps(getFunction());
119     if (!llvm::hasSingleElement(payloadOps)) {
120       return emitDefiniteFailure() << "requires a single function to call";
121     }
122     targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
123     if (!targetFunction) {
124       return emitDefiniteFailure() << "invalid non-function callee";
125     }
126   } else {
127     llvm_unreachable("Invalid CastAndCall op without a function to call");
128     return emitDefiniteFailure();
129   }
130 
131   // Verify that the function argument and result lengths match the inputs and
132   // outputs given to this op.
133   if (targetFunction.getNumArguments() != inputs.size()) {
134     return emitSilenceableFailure(targetFunction.getLoc())
135            << "mismatch between number of function arguments "
136            << targetFunction.getNumArguments() << " and number of inputs "
137            << inputs.size();
138   }
139   if (targetFunction.getNumResults() != outputs.size()) {
140     return emitSilenceableFailure(targetFunction.getLoc())
141            << "mismatch between number of function results "
142            << targetFunction->getNumResults() << " and number of outputs "
143            << outputs.size();
144   }
145 
146   // Gather all specified converters.
147   mlir::TypeConverter converter;
148   if (!getRegion().empty()) {
149     for (Operation &op : getRegion().front()) {
150       cast<transform::TypeConverterBuilderOpInterface>(&op)
151           .populateTypeMaterializations(converter);
152     }
153   }
154 
155   if (insertAfter)
156     rewriter.setInsertionPointAfter(insertionPoint);
157   else
158     rewriter.setInsertionPoint(insertionPoint);
159 
160   for (auto [input, type] :
161        llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
162     if (input.getType() != type) {
163       Value newInput = converter.materializeSourceConversion(
164           rewriter, input.getLoc(), type, input);
165       if (!newInput) {
166         return emitDefiniteFailure() << "Failed to materialize conversion of "
167                                      << input << " to type " << type;
168       }
169       input = newInput;
170     }
171   }
172 
173   auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
174                                               targetFunction, inputs);
175 
176   // Cast the call results back to the expected types. If any conversions fail
177   // this is a definite failure as the call has been constructed at this point.
178   for (auto [output, newOutput] :
179        llvm::zip_equal(outputs, callOp.getResults())) {
180     Value convertedOutput = newOutput;
181     if (output.getType() != newOutput.getType()) {
182       convertedOutput = converter.materializeTargetConversion(
183           rewriter, output.getLoc(), output.getType(), newOutput);
184       if (!convertedOutput) {
185         return emitDefiniteFailure()
186                << "Failed to materialize conversion of " << newOutput
187                << " to type " << output.getType();
188       }
189     }
190     rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
191   }
192   results.set(cast<OpResult>(getResult()), {callOp});
193   return DiagnosedSilenceableFailure::success();
194 }
195 
196 LogicalResult transform::CastAndCallOp::verify() {
197   if (!getRegion().empty()) {
198     for (Operation &op : getRegion().front()) {
199       if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
200         InFlightDiagnostic diag = emitOpError()
201                                   << "expected children ops to implement "
202                                      "TypeConverterBuilderOpInterface";
203         diag.attachNote(op.getLoc()) << "op without interface";
204         return diag;
205       }
206     }
207   }
208   if (!getFunction() && !getFunctionName()) {
209     return emitOpError() << "expected a function handle or name to call";
210   }
211   if (getFunction() && getFunctionName()) {
212     return emitOpError() << "function handle and name are mutually exclusive";
213   }
214   return success();
215 }
216 
217 void transform::CastAndCallOp::getEffects(
218     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
219   transform::onlyReadsHandle(getInsertionPointMutable(), effects);
220   if (getInputs())
221     transform::onlyReadsHandle(getInputsMutable(), effects);
222   if (getOutputs())
223     transform::onlyReadsHandle(getOutputsMutable(), effects);
224   if (getFunction())
225     transform::onlyReadsHandle(getFunctionMutable(), effects);
226   transform::producesHandle(getOperation()->getOpResults(), effects);
227   transform::modifiesPayload(effects);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Transform op registration
232 //===----------------------------------------------------------------------===//
233 
234 namespace {
235 class FuncTransformDialectExtension
236     : public transform::TransformDialectExtension<
237           FuncTransformDialectExtension> {
238 public:
239   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
240 
241   using Base::Base;
242 
243   void init() {
244     declareGeneratedDialect<LLVM::LLVMDialect>();
245 
246     registerTransformOps<
247 #define GET_OP_LIST
248 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
249         >();
250   }
251 };
252 } // namespace
253 
254 #define GET_OP_CLASSES
255 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc"
256 
257 void mlir::func::registerTransformDialectExtension(DialectRegistry &registry) {
258   registry.addExtensions<FuncTransformDialectExtension>();
259 }
260