xref: /llvm-project/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp (revision 7267c85959aa2490e2950f7fb817a76af7e94043)
123aa5a74SRiver Riddle //===- FuncConversions.cpp - Function conversions -------------------------===//
223aa5a74SRiver Riddle //
323aa5a74SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
423aa5a74SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
523aa5a74SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
623aa5a74SRiver Riddle //
723aa5a74SRiver Riddle //===----------------------------------------------------------------------===//
823aa5a74SRiver Riddle 
923aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1023aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1123aa5a74SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
1223aa5a74SRiver Riddle 
1323aa5a74SRiver Riddle using namespace mlir;
1423aa5a74SRiver Riddle using namespace mlir::func;
1523aa5a74SRiver Riddle 
169df63b26SMatthias Springer /// Flatten the given value ranges into a single vector of values.
179df63b26SMatthias Springer static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
189df63b26SMatthias Springer   SmallVector<Value> result;
199df63b26SMatthias Springer   for (const auto &vals : values)
209df63b26SMatthias Springer     llvm::append_range(result, vals);
219df63b26SMatthias Springer   return result;
229df63b26SMatthias Springer }
239df63b26SMatthias Springer 
2423aa5a74SRiver Riddle namespace {
2523aa5a74SRiver Riddle /// Converts the operand and result types of the CallOp, used together with the
2623aa5a74SRiver Riddle /// FuncOpSignatureConversion.
2723aa5a74SRiver Riddle struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
2823aa5a74SRiver Riddle   using OpConversionPattern<CallOp>::OpConversionPattern;
2923aa5a74SRiver Riddle 
3023aa5a74SRiver Riddle   /// Hook for derived classes to implement combined matching and rewriting.
3123aa5a74SRiver Riddle   LogicalResult
329df63b26SMatthias Springer   matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
3323aa5a74SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
3408e6566dSMatthias Springer     // Convert the original function results. Keep track of how many result
3508e6566dSMatthias Springer     // types an original result type is converted into.
3608e6566dSMatthias Springer     SmallVector<size_t> numResultsReplacments;
3723aa5a74SRiver Riddle     SmallVector<Type, 1> convertedResults;
3808e6566dSMatthias Springer     size_t numFlattenedResults = 0;
3908e6566dSMatthias Springer     for (auto [idx, type] : llvm::enumerate(callOp.getResultTypes())) {
4008e6566dSMatthias Springer       if (failed(typeConverter->convertTypes(type, convertedResults)))
4123aa5a74SRiver Riddle         return failure();
4208e6566dSMatthias Springer       numResultsReplacments.push_back(convertedResults.size() -
4308e6566dSMatthias Springer                                       numFlattenedResults);
4408e6566dSMatthias Springer       numFlattenedResults = convertedResults.size();
4508e6566dSMatthias Springer     }
4688a3dc0eSAlex Zinenko 
4723aa5a74SRiver Riddle     // Substitute with the new result types from the corresponding FuncType
4823aa5a74SRiver Riddle     // conversion.
499df63b26SMatthias Springer     auto newCallOp = rewriter.create<CallOp>(
509df63b26SMatthias Springer         callOp.getLoc(), callOp.getCallee(), convertedResults,
519df63b26SMatthias Springer         flattenValues(adaptor.getOperands()));
5208e6566dSMatthias Springer     SmallVector<ValueRange> replacements;
5308e6566dSMatthias Springer     size_t offset = 0;
5408e6566dSMatthias Springer     for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
5508e6566dSMatthias Springer       replacements.push_back(
5608e6566dSMatthias Springer           newCallOp->getResults().slice(offset, numResultsReplacments[i]));
5708e6566dSMatthias Springer       offset += numResultsReplacments[i];
5808e6566dSMatthias Springer     }
5908e6566dSMatthias Springer     assert(offset == convertedResults.size() &&
6008e6566dSMatthias Springer            "expected that all converted results are used");
6108e6566dSMatthias Springer     rewriter.replaceOpWithMultiple(callOp, replacements);
6223aa5a74SRiver Riddle     return success();
6323aa5a74SRiver Riddle   }
6423aa5a74SRiver Riddle };
6523aa5a74SRiver Riddle } // namespace
6623aa5a74SRiver Riddle 
6723aa5a74SRiver Riddle void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
68206fad0eSMatthias Springer                                                const TypeConverter &converter) {
6923aa5a74SRiver Riddle   patterns.add<CallOpSignatureConversion>(converter, patterns.getContext());
7023aa5a74SRiver Riddle }
7123aa5a74SRiver Riddle 
7223aa5a74SRiver Riddle namespace {
7323aa5a74SRiver Riddle /// Only needed to support partial conversion of functions where this pattern
7423aa5a74SRiver Riddle /// ensures that the branch operation arguments matches up with the succesor
7523aa5a74SRiver Riddle /// block arguments.
7623aa5a74SRiver Riddle class BranchOpInterfaceTypeConversion
7723aa5a74SRiver Riddle     : public OpInterfaceConversionPattern<BranchOpInterface> {
7823aa5a74SRiver Riddle public:
7923aa5a74SRiver Riddle   using OpInterfaceConversionPattern<
8023aa5a74SRiver Riddle       BranchOpInterface>::OpInterfaceConversionPattern;
8123aa5a74SRiver Riddle 
8223aa5a74SRiver Riddle   BranchOpInterfaceTypeConversion(
83206fad0eSMatthias Springer       const TypeConverter &typeConverter, MLIRContext *ctx,
8423aa5a74SRiver Riddle       function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
8523aa5a74SRiver Riddle       : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
8623aa5a74SRiver Riddle         shouldConvertBranchOperand(shouldConvertBranchOperand) {}
8723aa5a74SRiver Riddle 
8823aa5a74SRiver Riddle   LogicalResult
8923aa5a74SRiver Riddle   matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
9023aa5a74SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
9123aa5a74SRiver Riddle     // For a branch operation, only some operands go to the target blocks, so
9223aa5a74SRiver Riddle     // only rewrite those.
9323aa5a74SRiver Riddle     SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
9423aa5a74SRiver Riddle     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
9523aa5a74SRiver Riddle          succIdx < succEnd; ++succIdx) {
960c789db5SMarkus Böck       OperandRange forwardedOperands =
970c789db5SMarkus Böck           op.getSuccessorOperands(succIdx).getForwardedOperands();
980c789db5SMarkus Böck       if (forwardedOperands.empty())
9923aa5a74SRiver Riddle         continue;
10023aa5a74SRiver Riddle 
1010c789db5SMarkus Böck       for (int idx = forwardedOperands.getBeginOperandIndex(),
1020c789db5SMarkus Böck                eidx = idx + forwardedOperands.size();
10323aa5a74SRiver Riddle            idx < eidx; ++idx) {
10423aa5a74SRiver Riddle         if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
10523aa5a74SRiver Riddle           newOperands[idx] = operands[idx];
10623aa5a74SRiver Riddle       }
10723aa5a74SRiver Riddle     }
1085fcf907bSMatthias Springer     rewriter.modifyOpInPlace(
10923aa5a74SRiver Riddle         op, [newOperands, op]() { op->setOperands(newOperands); });
11023aa5a74SRiver Riddle     return success();
11123aa5a74SRiver Riddle   }
11223aa5a74SRiver Riddle 
11323aa5a74SRiver Riddle private:
11423aa5a74SRiver Riddle   function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
11523aa5a74SRiver Riddle };
11623aa5a74SRiver Riddle } // namespace
11723aa5a74SRiver Riddle 
11823aa5a74SRiver Riddle namespace {
11923aa5a74SRiver Riddle /// Only needed to support partial conversion of functions where this pattern
12023aa5a74SRiver Riddle /// ensures that the branch operation arguments matches up with the succesor
12123aa5a74SRiver Riddle /// block arguments.
12223aa5a74SRiver Riddle class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
12323aa5a74SRiver Riddle public:
12423aa5a74SRiver Riddle   using OpConversionPattern<ReturnOp>::OpConversionPattern;
12523aa5a74SRiver Riddle 
12623aa5a74SRiver Riddle   LogicalResult
127*7267c859SMatthias Springer   matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
12823aa5a74SRiver Riddle                   ConversionPatternRewriter &rewriter) const final {
129*7267c859SMatthias Springer     rewriter.replaceOpWithNewOp<ReturnOp>(op,
130*7267c859SMatthias Springer                                           flattenValues(adaptor.getOperands()));
13123aa5a74SRiver Riddle     return success();
13223aa5a74SRiver Riddle   }
13323aa5a74SRiver Riddle };
13423aa5a74SRiver Riddle } // namespace
13523aa5a74SRiver Riddle 
13623aa5a74SRiver Riddle void mlir::populateBranchOpInterfaceTypeConversionPattern(
137206fad0eSMatthias Springer     RewritePatternSet &patterns, const TypeConverter &typeConverter,
13823aa5a74SRiver Riddle     function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
13923aa5a74SRiver Riddle   patterns.add<BranchOpInterfaceTypeConversion>(
14023aa5a74SRiver Riddle       typeConverter, patterns.getContext(), shouldConvertBranchOperand);
14123aa5a74SRiver Riddle }
14223aa5a74SRiver Riddle 
14323aa5a74SRiver Riddle bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
144206fad0eSMatthias Springer     Operation *op, const TypeConverter &converter) {
14523aa5a74SRiver Riddle   // All successor operands of branch like operations must be rewritten.
14623aa5a74SRiver Riddle   if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
14723aa5a74SRiver Riddle     for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
14823aa5a74SRiver Riddle       auto successorOperands = branchOp.getSuccessorOperands(p);
1490c789db5SMarkus Böck       if (!converter.isLegal(
1500c789db5SMarkus Böck               successorOperands.getForwardedOperands().getTypes()))
15123aa5a74SRiver Riddle         return false;
15223aa5a74SRiver Riddle     }
15323aa5a74SRiver Riddle     return true;
15423aa5a74SRiver Riddle   }
15523aa5a74SRiver Riddle 
15623aa5a74SRiver Riddle   return false;
15723aa5a74SRiver Riddle }
15823aa5a74SRiver Riddle 
159206fad0eSMatthias Springer void mlir::populateReturnOpTypeConversionPattern(
160206fad0eSMatthias Springer     RewritePatternSet &patterns, const TypeConverter &typeConverter) {
16123aa5a74SRiver Riddle   patterns.add<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
16223aa5a74SRiver Riddle }
16323aa5a74SRiver Riddle 
164206fad0eSMatthias Springer bool mlir::isLegalForReturnOpTypeConversionPattern(
165206fad0eSMatthias Springer     Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal) {
16623aa5a74SRiver Riddle   // If this is a `return` and the user pass wants to convert/transform across
1677557530fSFangrui Song   // function boundaries, then `converter` is invoked to check whether the
16823aa5a74SRiver Riddle   // `return` op is legal.
169e4c39501SRahul Joshi   if (isa<ReturnOp>(op) && !returnOpAlwaysLegal)
17023aa5a74SRiver Riddle     return converter.isLegal(op);
17123aa5a74SRiver Riddle 
17223aa5a74SRiver Riddle   // ReturnLike operations have to be legalized with their parent. For
17323aa5a74SRiver Riddle   // return this is handled, for other ops they remain as is.
17423aa5a74SRiver Riddle   return op->hasTrait<OpTrait::ReturnLike>();
17523aa5a74SRiver Riddle }
17623aa5a74SRiver Riddle 
17723aa5a74SRiver Riddle bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) {
17823aa5a74SRiver Riddle   // If it is not a terminator, ignore it.
17923aa5a74SRiver Riddle   if (!op->mightHaveTrait<OpTrait::IsTerminator>())
18023aa5a74SRiver Riddle     return true;
18123aa5a74SRiver Riddle 
18223aa5a74SRiver Riddle   // If it is not the last operation in the block, also ignore it. We do
18323aa5a74SRiver Riddle   // this to handle unknown operations, as well.
18423aa5a74SRiver Riddle   Block *block = op->getBlock();
18523aa5a74SRiver Riddle   if (!block || &block->back() != op)
18623aa5a74SRiver Riddle     return true;
18723aa5a74SRiver Riddle 
18823aa5a74SRiver Riddle   // We don't want to handle terminators in nested regions, assume they are
18923aa5a74SRiver Riddle   // always legal.
19023aa5a74SRiver Riddle   if (!isa_and_nonnull<FuncOp>(op->getParentOp()))
19123aa5a74SRiver Riddle     return true;
19223aa5a74SRiver Riddle 
19323aa5a74SRiver Riddle   return false;
19423aa5a74SRiver Riddle }
195