xref: /llvm-project/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
//===-- OneToNFuncConversions.cpp - Func 1:N type conversion ---*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The patterns in this file are heavily inspired by (and copied from)
// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp, but work for 1:N
// type conversions.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
17 
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Transforms/OneToNTypeConversion.h"
20 
21 using namespace mlir;
22 using namespace mlir::func;
23 
24 namespace {
25 
26 class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
27 public:
28   using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
29 
30   LogicalResult
31   matchAndRewrite(CallOp op, OpAdaptor adaptor,
32                   OneToNPatternRewriter &rewriter) const override {
33     Location loc = op->getLoc();
34     const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
35 
36     // Nothing to do if the op doesn't have any non-identity conversions for its
37     // operands or results.
38     if (!adaptor.getOperandMapping().hasNonIdentityConversion() &&
39         !resultMapping.hasNonIdentityConversion())
40       return failure();
41 
42     // Create new CallOp.
43     auto newOp =
44         rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
45                                 adaptor.getFlatOperands(), op->getAttrs());
46 
47     rewriter.replaceOp(op, newOp->getResults(), resultMapping);
48     return success();
49   }
50 };
51 
52 class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
53 public:
54   using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
55 
56   LogicalResult
57   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
58                   OneToNPatternRewriter &rewriter) const override {
59     // Nothing to do if there is no non-identity conversion.
60     if (!adaptor.getOperandMapping().hasNonIdentityConversion())
61       return failure();
62 
63     // Convert operands.
64     rewriter.modifyOpInPlace(
65         op, [&] { op->setOperands(adaptor.getFlatOperands()); });
66 
67     return success();
68   }
69 };
70 
71 } // namespace
72 
73 namespace mlir {
74 
75 void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter,
76                                         RewritePatternSet &patterns) {
77   patterns.add<
78       // clang-format off
79       ConvertTypesInFuncCallOp,
80       ConvertTypesInFuncReturnOp
81       // clang-format on
82       >(typeConverter, patterns.getContext());
83   populateOneToNFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
84       typeConverter, patterns);
85 }
86 
87 } // namespace mlir
88