1 //===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===// 2 // 3 // Licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The patterns in this file are heavily inspired (and copied from) 10 // convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the 11 // patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N 12 // type conversions. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" 17 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Transforms/OneToNTypeConversion.h" 20 21 using namespace mlir; 22 using namespace mlir::func; 23 24 namespace { 25 26 class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> { 27 public: 28 using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern; 29 30 LogicalResult matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter, 31 const OneToNTypeMapping &operandMapping, 32 const OneToNTypeMapping &resultMapping, 33 ValueRange convertedOperands) const override { 34 Location loc = op->getLoc(); 35 36 // Nothing to do if the op doesn't have any non-identity conversions for its 37 // operands or results. 38 if (!operandMapping.hasNonIdentityConversion() && 39 !resultMapping.hasNonIdentityConversion()) 40 return failure(); 41 42 // Create new CallOp. 43 auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(), 44 convertedOperands); 45 newOp->setAttrs(op->getAttrs()); 46 47 rewriter.replaceOp(op, newOp->getResults(), resultMapping); 48 return success(); 49 } 50 }; 51 52 class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> { 53 public: 54 using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern; 55 56 LogicalResult 57 matchAndRewrite(FuncOp op, OneToNPatternRewriter &rewriter, 58 const OneToNTypeMapping & /*operandMapping*/, 59 const OneToNTypeMapping & /*resultMapping*/, 60 ValueRange /*convertedOperands*/) const override { 61 auto *typeConverter = getTypeConverter<OneToNTypeConverter>(); 62 63 // Construct mapping for function arguments. 64 OneToNTypeMapping argumentMapping(op.getArgumentTypes()); 65 if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(), 66 argumentMapping))) 67 return failure(); 68 69 // Construct mapping for function results. 70 OneToNTypeMapping funcResultMapping(op.getResultTypes()); 71 if (failed(typeConverter->computeTypeMapping(op.getResultTypes(), 72 funcResultMapping))) 73 return failure(); 74 75 // Nothing to do if the op doesn't have any non-identity conversions for its 76 // operands or results. 77 if (!argumentMapping.hasNonIdentityConversion() && 78 !funcResultMapping.hasNonIdentityConversion()) 79 return failure(); 80 81 // Update the function signature in-place. 82 auto newType = FunctionType::get(rewriter.getContext(), 83 argumentMapping.getConvertedTypes(), 84 funcResultMapping.getConvertedTypes()); 85 rewriter.updateRootInPlace(op, [&] { op.setType(newType); }); 86 87 // Update block signatures. 88 if (!op.isExternal()) { 89 Region *region = &op.getBody(); 90 Block *block = ®ion->front(); 91 rewriter.applySignatureConversion(block, argumentMapping); 92 } 93 94 return success(); 95 } 96 }; 97 98 class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> { 99 public: 100 using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern; 101 102 LogicalResult matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter, 103 const OneToNTypeMapping &operandMapping, 104 const OneToNTypeMapping & /*resultMapping*/, 105 ValueRange convertedOperands) const override { 106 // Nothing to do if there is no non-identity conversion. 107 if (!operandMapping.hasNonIdentityConversion()) 108 return failure(); 109 110 // Convert operands. 111 rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); 112 113 return success(); 114 } 115 }; 116 117 } // namespace 118 119 namespace mlir { 120 121 void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, 122 RewritePatternSet &patterns) { 123 patterns.add< 124 // clang-format off 125 ConvertTypesInFuncCallOp, 126 ConvertTypesInFuncFuncOp, 127 ConvertTypesInFuncReturnOp 128 // clang-format on 129 >(typeConverter, patterns.getContext()); 130 } 131 132 } // namespace mlir 133