1 //===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===// 2 // 3 // Licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The patterns in this file are heavily inspired (and copied from) 10 // convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the 11 // patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N 12 // type conversions. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" 17 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Transforms/OneToNTypeConversion.h" 20 21 using namespace mlir; 22 using namespace mlir::func; 23 24 namespace { 25 26 class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> { 27 public: 28 using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern; 29 30 LogicalResult 31 matchAndRewrite(CallOp op, OpAdaptor adaptor, 32 OneToNPatternRewriter &rewriter) const override { 33 Location loc = op->getLoc(); 34 const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); 35 36 // Nothing to do if the op doesn't have any non-identity conversions for its 37 // operands or results. 38 if (!adaptor.getOperandMapping().hasNonIdentityConversion() && 39 !resultMapping.hasNonIdentityConversion()) 40 return failure(); 41 42 // Create new CallOp. 43 auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(), 44 adaptor.getFlatOperands()); 45 newOp->setAttrs(op->getAttrs()); 46 47 rewriter.replaceOp(op, newOp->getResults(), resultMapping); 48 return success(); 49 } 50 }; 51 52 class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> { 53 public: 54 using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern; 55 56 LogicalResult 57 matchAndRewrite(FuncOp op, OpAdaptor adaptor, 58 OneToNPatternRewriter &rewriter) const override { 59 auto *typeConverter = getTypeConverter<OneToNTypeConverter>(); 60 61 // Construct mapping for function arguments. 62 OneToNTypeMapping argumentMapping(op.getArgumentTypes()); 63 if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(), 64 argumentMapping))) 65 return failure(); 66 67 // Construct mapping for function results. 68 OneToNTypeMapping funcResultMapping(op.getResultTypes()); 69 if (failed(typeConverter->computeTypeMapping(op.getResultTypes(), 70 funcResultMapping))) 71 return failure(); 72 73 // Nothing to do if the op doesn't have any non-identity conversions for its 74 // operands or results. 75 if (!argumentMapping.hasNonIdentityConversion() && 76 !funcResultMapping.hasNonIdentityConversion()) 77 return failure(); 78 79 // Update the function signature in-place. 80 auto newType = FunctionType::get(rewriter.getContext(), 81 argumentMapping.getConvertedTypes(), 82 funcResultMapping.getConvertedTypes()); 83 rewriter.updateRootInPlace(op, [&] { op.setType(newType); }); 84 85 // Update block signatures. 86 if (!op.isExternal()) { 87 Region *region = &op.getBody(); 88 Block *block = ®ion->front(); 89 rewriter.applySignatureConversion(block, argumentMapping); 90 } 91 92 return success(); 93 } 94 }; 95 96 class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> { 97 public: 98 using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern; 99 100 LogicalResult 101 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 102 OneToNPatternRewriter &rewriter) const override { 103 // Nothing to do if there is no non-identity conversion. 104 if (!adaptor.getOperandMapping().hasNonIdentityConversion()) 105 return failure(); 106 107 // Convert operands. 108 rewriter.updateRootInPlace( 109 op, [&] { op->setOperands(adaptor.getFlatOperands()); }); 110 111 return success(); 112 } 113 }; 114 115 } // namespace 116 117 namespace mlir { 118 119 void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, 120 RewritePatternSet &patterns) { 121 patterns.add< 122 // clang-format off 123 ConvertTypesInFuncCallOp, 124 ConvertTypesInFuncFuncOp, 125 ConvertTypesInFuncReturnOp 126 // clang-format on 127 >(typeConverter, patterns.getContext()); 128 } 129 130 } // namespace mlir 131