xref: /llvm-project/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp (revision 7267c85959aa2490e2950f7fb817a76af7e94043)
1 //===- FuncConversions.cpp - Function conversions -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Transforms/DialectConversion.h"
12 
13 using namespace mlir;
14 using namespace mlir::func;
15 
16 /// Flatten the given value ranges into a single vector of values.
17 static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
18   SmallVector<Value> result;
19   for (const auto &vals : values)
20     llvm::append_range(result, vals);
21   return result;
22 }
23 
24 namespace {
25 /// Converts the operand and result types of the CallOp, used together with the
26 /// FuncOpSignatureConversion.
27 struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
28   using OpConversionPattern<CallOp>::OpConversionPattern;
29 
30   /// Hook for derived classes to implement combined matching and rewriting.
31   LogicalResult
32   matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
33                   ConversionPatternRewriter &rewriter) const override {
34     // Convert the original function results. Keep track of how many result
35     // types an original result type is converted into.
36     SmallVector<size_t> numResultsReplacments;
37     SmallVector<Type, 1> convertedResults;
38     size_t numFlattenedResults = 0;
39     for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) {
40       if (failed(typeConverter->convertTypes(type, convertedResults)))
41         return failure();
42       numResultsReplacments.push_back(convertedResults.size() -
43                                       numFlattenedResults);
44       numFlattenedResults = convertedResults.size();
45     }
46 
47     // Substitute with the new result types from the corresponding FuncType
48     // conversion.
49     auto newCallOp = rewriter.create<CallOp>(
50         callOp.getLoc(), callOp.getCallee(), convertedResults,
51         flattenValues(adaptor.getOperands()));
52     SmallVector<ValueRange> replacements;
53     size_t offset = 0;
54     for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
55       replacements.push_back(
56           newCallOp->getResults().slice(offset, numResultsReplacments[i]));
57       offset += numResultsReplacments[i];
58     }
59     assert(offset == convertedResults.size() &&
60            "expected that all converted results are used");
61     rewriter.replaceOpWithMultiple(callOp, replacements);
62     return success();
63   }
64 };
65 } // namespace
66 
67 void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
68                                                const TypeConverter &converter) {
69   patterns.add<CallOpSignatureConversion>(converter, patterns.getContext());
70 }
71 
72 namespace {
73 /// Only needed to support partial conversion of functions where this pattern
74 /// ensures that the branch operation arguments matches up with the succesor
75 /// block arguments.
76 class BranchOpInterfaceTypeConversion
77     : public OpInterfaceConversionPattern<BranchOpInterface> {
78 public:
79   using OpInterfaceConversionPattern<
80       BranchOpInterface>::OpInterfaceConversionPattern;
81 
82   BranchOpInterfaceTypeConversion(
83       const TypeConverter &typeConverter, MLIRContext *ctx,
84       function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
85       : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
86         shouldConvertBranchOperand(shouldConvertBranchOperand) {}
87 
88   LogicalResult
89   matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
90                   ConversionPatternRewriter &rewriter) const final {
91     // For a branch operation, only some operands go to the target blocks, so
92     // only rewrite those.
93     SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
94     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
95          succIdx < succEnd; ++succIdx) {
96       OperandRange forwardedOperands =
97           op.getSuccessorOperands(succIdx).getForwardedOperands();
98       if (forwardedOperands.empty())
99         continue;
100 
101       for (int idx = forwardedOperands.getBeginOperandIndex(),
102                eidx = idx + forwardedOperands.size();
103            idx < eidx; ++idx) {
104         if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
105           newOperands[idx] = operands[idx];
106       }
107     }
108     rewriter.modifyOpInPlace(
109         op, [newOperands, op]() { op->setOperands(newOperands); });
110     return success();
111   }
112 
113 private:
114   function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
115 };
116 } // namespace
117 
118 namespace {
119 /// Only needed to support partial conversion of functions where this pattern
120 /// ensures that the branch operation arguments matches up with the succesor
121 /// block arguments.
122 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
123 public:
124   using OpConversionPattern<ReturnOp>::OpConversionPattern;
125 
126   LogicalResult
127   matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
128                   ConversionPatternRewriter &rewriter) const final {
129     rewriter.replaceOpWithNewOp<ReturnOp>(op,
130                                           flattenValues(adaptor.getOperands()));
131     return success();
132   }
133 };
134 } // namespace
135 
136 void mlir::populateBranchOpInterfaceTypeConversionPattern(
137     RewritePatternSet &patterns, const TypeConverter &typeConverter,
138     function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
139   patterns.add<BranchOpInterfaceTypeConversion>(
140       typeConverter, patterns.getContext(), shouldConvertBranchOperand);
141 }
142 
143 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
144     Operation *op, const TypeConverter &converter) {
145   // All successor operands of branch like operations must be rewritten.
146   if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
147     for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
148       auto successorOperands = branchOp.getSuccessorOperands(p);
149       if (!converter.isLegal(
150               successorOperands.getForwardedOperands().getTypes()))
151         return false;
152     }
153     return true;
154   }
155 
156   return false;
157 }
158 
159 void mlir::populateReturnOpTypeConversionPattern(
160     RewritePatternSet &patterns, const TypeConverter &typeConverter) {
161   patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
162 }
163 
164 bool mlir::isLegalForReturnOpTypeConversionPattern(
165     Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) {
166   // If this is a `return` and the user pass wants to convert/transform across
167   // function boundaries, then `converter` is invoked to check whether the
168   // `return` op is legal.
169   if (isa<ReturnOp>(op) && !returnOpAlwaysLegal)
170     return converter.isLegal(op);
171 
172   // ReturnLike operations have to be legalized with their parent. For
173   // return this is handled, for other ops they remain as is.
174   return op->hasTrait<OpTrait::ReturnLike>();
175 }
176 
177 bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) {
178   // If it is not a terminator, ignore it.
179   if (!op->mightHaveTrait<OpTrait::IsTerminator>())
180     return true;
181 
182   // If it is not the last operation in the block, also ignore it. We do
183   // this to handle unknown operations, as well.
184   Block *block = op->getBlock();
185   if (!block || &block->back() != op)
186     return true;
187 
188   // We don't want to handle terminators in nested regions, assume they are
189   // always legal.
190   if (!isa_and_nonnull<FuncOp>(op->getParentOp()))
191     return true;
192 
193   return false;
194 }
195