xref: /llvm-project/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp (revision 4e563616a5fffa1204286a9aa03604a68a7db835)
//===-- OneToNFuncConversions.cpp - Func 1:N type conversions ---*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// The patterns in this file are heavily inspired by (and partly copied from)
// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp, but work for
// 1:N type conversions.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
17 
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Transforms/OneToNTypeConversion.h"
20 
21 using namespace mlir;
22 using namespace mlir::func;
23 
24 namespace {
25 
26 class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
27 public:
28   using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
29 
30   LogicalResult
31   matchAndRewrite(CallOp op, OpAdaptor adaptor,
32                   OneToNPatternRewriter &rewriter) const override {
33     Location loc = op->getLoc();
34     const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
35 
36     // Nothing to do if the op doesn't have any non-identity conversions for its
37     // operands or results.
38     if (!adaptor.getOperandMapping().hasNonIdentityConversion() &&
39         !resultMapping.hasNonIdentityConversion())
40       return failure();
41 
42     // Create new CallOp.
43     auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
44                                          adaptor.getFlatOperands());
45     newOp->setAttrs(op->getAttrs());
46 
47     rewriter.replaceOp(op, newOp->getResults(), resultMapping);
48     return success();
49   }
50 };
51 
52 class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
53 public:
54   using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
55 
56   LogicalResult
57   matchAndRewrite(FuncOp op, OpAdaptor adaptor,
58                   OneToNPatternRewriter &rewriter) const override {
59     auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
60 
61     // Construct mapping for function arguments.
62     OneToNTypeMapping argumentMapping(op.getArgumentTypes());
63     if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(),
64                                                  argumentMapping)))
65       return failure();
66 
67     // Construct mapping for function results.
68     OneToNTypeMapping funcResultMapping(op.getResultTypes());
69     if (failed(typeConverter->computeTypeMapping(op.getResultTypes(),
70                                                  funcResultMapping)))
71       return failure();
72 
73     // Nothing to do if the op doesn't have any non-identity conversions for its
74     // operands or results.
75     if (!argumentMapping.hasNonIdentityConversion() &&
76         !funcResultMapping.hasNonIdentityConversion())
77       return failure();
78 
79     // Update the function signature in-place.
80     auto newType = FunctionType::get(rewriter.getContext(),
81                                      argumentMapping.getConvertedTypes(),
82                                      funcResultMapping.getConvertedTypes());
83     rewriter.updateRootInPlace(op, [&] { op.setType(newType); });
84 
85     // Update block signatures.
86     if (!op.isExternal()) {
87       Region *region = &op.getBody();
88       Block *block = &region->front();
89       rewriter.applySignatureConversion(block, argumentMapping);
90     }
91 
92     return success();
93   }
94 };
95 
96 class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
97 public:
98   using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
99 
100   LogicalResult
101   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
102                   OneToNPatternRewriter &rewriter) const override {
103     // Nothing to do if there is no non-identity conversion.
104     if (!adaptor.getOperandMapping().hasNonIdentityConversion())
105       return failure();
106 
107     // Convert operands.
108     rewriter.updateRootInPlace(
109         op, [&] { op->setOperands(adaptor.getFlatOperands()); });
110 
111     return success();
112   }
113 };
114 
115 } // namespace
116 
117 namespace mlir {
118 
119 void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
120                                         RewritePatternSet &patterns) {
121   patterns.add<
122       // clang-format off
123       ConvertTypesInFuncCallOp,
124       ConvertTypesInFuncFuncOp,
125       ConvertTypesInFuncReturnOp
126       // clang-format on
127       >(typeConverter, patterns.getContext());
128 }
129 
130 } // namespace mlir
131