xref: /llvm-project/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp (revision 9c4611f9c7a7055b18f0a30a4c9074b9917e4ab0)
1 //===-- OneToNFuncConversions.cpp - Func 1:N type conversion ---*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The patterns in this file are heavily inspired by (and partly copied from)
10 // convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
11 // patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp, but they work
12 // for 1:N type conversions.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
17 
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Transforms/OneToNTypeConversion.h"
20 
21 using namespace mlir;
22 using namespace mlir::func;
23 
24 namespace {
25 
26 class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
27 public:
28   using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
29 
30   LogicalResult matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter,
31                                 const OneToNTypeMapping &operandMapping,
32                                 const OneToNTypeMapping &resultMapping,
33                                 ValueRange convertedOperands) const override {
34     Location loc = op->getLoc();
35 
36     // Nothing to do if the op doesn't have any non-identity conversions for its
37     // operands or results.
38     if (!operandMapping.hasNonIdentityConversion() &&
39         !resultMapping.hasNonIdentityConversion())
40       return failure();
41 
42     // Create new CallOp.
43     auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
44                                          convertedOperands);
45     newOp->setAttrs(op->getAttrs());
46 
47     rewriter.replaceOp(op, newOp->getResults(), resultMapping);
48     return success();
49   }
50 };
51 
52 class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
53 public:
54   using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
55 
56   LogicalResult
57   matchAndRewrite(FuncOp op, OneToNPatternRewriter &rewriter,
58                   const OneToNTypeMapping & /*operandMapping*/,
59                   const OneToNTypeMapping & /*resultMapping*/,
60                   ValueRange /*convertedOperands*/) const override {
61     auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
62 
63     // Construct mapping for function arguments.
64     OneToNTypeMapping argumentMapping(op.getArgumentTypes());
65     if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(),
66                                                  argumentMapping)))
67       return failure();
68 
69     // Construct mapping for function results.
70     OneToNTypeMapping funcResultMapping(op.getResultTypes());
71     if (failed(typeConverter->computeTypeMapping(op.getResultTypes(),
72                                                  funcResultMapping)))
73       return failure();
74 
75     // Nothing to do if the op doesn't have any non-identity conversions for its
76     // operands or results.
77     if (!argumentMapping.hasNonIdentityConversion() &&
78         !funcResultMapping.hasNonIdentityConversion())
79       return failure();
80 
81     // Update the function signature in-place.
82     auto newType = FunctionType::get(rewriter.getContext(),
83                                      argumentMapping.getConvertedTypes(),
84                                      funcResultMapping.getConvertedTypes());
85     rewriter.updateRootInPlace(op, [&] { op.setType(newType); });
86 
87     // Update block signatures.
88     if (!op.isExternal()) {
89       Region *region = &op.getBody();
90       Block *block = &region->front();
91       rewriter.applySignatureConversion(block, argumentMapping);
92     }
93 
94     return success();
95   }
96 };
97 
98 class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
99 public:
100   using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
101 
102   LogicalResult matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter,
103                                 const OneToNTypeMapping &operandMapping,
104                                 const OneToNTypeMapping & /*resultMapping*/,
105                                 ValueRange convertedOperands) const override {
106     // Nothing to do if there is no non-identity conversion.
107     if (!operandMapping.hasNonIdentityConversion())
108       return failure();
109 
110     // Convert operands.
111     rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
112 
113     return success();
114   }
115 };
116 
117 } // namespace
118 
119 namespace mlir {
120 
121 void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
122                                         RewritePatternSet &patterns) {
123   patterns.add<
124       // clang-format off
125       ConvertTypesInFuncCallOp,
126       ConvertTypesInFuncFuncOp,
127       ConvertTypesInFuncReturnOp
128       // clang-format on
129       >(typeConverter, patterns.getContext());
130 }
131 
132 } // namespace mlir
133