xref: /llvm-project/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
//===-- OneToNFuncConversions.cpp - Func 1:N type conversion ----*- C++ -*-===//
20ceb7a12SIngo Müller //
30ceb7a12SIngo Müller // Licensed under the Apache License v2.0 with LLVM Exceptions.
40ceb7a12SIngo Müller // See https://llvm.org/LICENSE.txt for license information.
50ceb7a12SIngo Müller // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60ceb7a12SIngo Müller //
70ceb7a12SIngo Müller //===----------------------------------------------------------------------===//
80ceb7a12SIngo Müller //
// The patterns in this file are heavily inspired by (and copied from)
// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp, but work for
// 1:N type conversions.
130ceb7a12SIngo Müller //
140ceb7a12SIngo Müller //===----------------------------------------------------------------------===//
150ceb7a12SIngo Müller 
160ceb7a12SIngo Müller #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
170ceb7a12SIngo Müller 
180ceb7a12SIngo Müller #include "mlir/Dialect/Func/IR/FuncOps.h"
190ceb7a12SIngo Müller #include "mlir/Transforms/OneToNTypeConversion.h"
200ceb7a12SIngo Müller 
210ceb7a12SIngo Müller using namespace mlir;
220ceb7a12SIngo Müller using namespace mlir::func;
230ceb7a12SIngo Müller 
240ceb7a12SIngo Müller namespace {
250ceb7a12SIngo Müller 
260ceb7a12SIngo Müller class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
270ceb7a12SIngo Müller public:
280ceb7a12SIngo Müller   using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
290ceb7a12SIngo Müller 
304e563616SIngo Müller   LogicalResult
314e563616SIngo Müller   matchAndRewrite(CallOp op, OpAdaptor adaptor,
324e563616SIngo Müller                   OneToNPatternRewriter &rewriter) const override {
330ceb7a12SIngo Müller     Location loc = op->getLoc();
344e563616SIngo Müller     const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
350ceb7a12SIngo Müller 
360ceb7a12SIngo Müller     // Nothing to do if the op doesn't have any non-identity conversions for its
370ceb7a12SIngo Müller     // operands or results.
384e563616SIngo Müller     if (!adaptor.getOperandMapping().hasNonIdentityConversion() &&
390ceb7a12SIngo Müller         !resultMapping.hasNonIdentityConversion())
400ceb7a12SIngo Müller       return failure();
410ceb7a12SIngo Müller 
420ceb7a12SIngo Müller     // Create new CallOp.
43e67080dfSJacques Pienaar     auto newOp =
44e67080dfSJacques Pienaar         rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
45e67080dfSJacques Pienaar                                 adaptor.getFlatOperands(), op->getAttrs());
460ceb7a12SIngo Müller 
470ceb7a12SIngo Müller     rewriter.replaceOp(op, newOp->getResults(), resultMapping);
480ceb7a12SIngo Müller     return success();
490ceb7a12SIngo Müller   }
500ceb7a12SIngo Müller };
510ceb7a12SIngo Müller 
520ceb7a12SIngo Müller class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
530ceb7a12SIngo Müller public:
540ceb7a12SIngo Müller   using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
550ceb7a12SIngo Müller 
564e563616SIngo Müller   LogicalResult
574e563616SIngo Müller   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
584e563616SIngo Müller                   OneToNPatternRewriter &rewriter) const override {
590ceb7a12SIngo Müller     // Nothing to do if there is no non-identity conversion.
604e563616SIngo Müller     if (!adaptor.getOperandMapping().hasNonIdentityConversion())
610ceb7a12SIngo Müller       return failure();
620ceb7a12SIngo Müller 
630ceb7a12SIngo Müller     // Convert operands.
645fcf907bSMatthias Springer     rewriter.modifyOpInPlace(
654e563616SIngo Müller         op, [&] { op->setOperands(adaptor.getFlatOperands()); });
660ceb7a12SIngo Müller 
670ceb7a12SIngo Müller     return success();
680ceb7a12SIngo Müller   }
690ceb7a12SIngo Müller };
700ceb7a12SIngo Müller 
710ceb7a12SIngo Müller } // namespace
720ceb7a12SIngo Müller 
730ceb7a12SIngo Müller namespace mlir {
740ceb7a12SIngo Müller 
75*206fad0eSMatthias Springer void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter,
760ceb7a12SIngo Müller                                         RewritePatternSet &patterns) {
770ceb7a12SIngo Müller   patterns.add<
780ceb7a12SIngo Müller       // clang-format off
790ceb7a12SIngo Müller       ConvertTypesInFuncCallOp,
800ceb7a12SIngo Müller       ConvertTypesInFuncReturnOp
810ceb7a12SIngo Müller       // clang-format on
820ceb7a12SIngo Müller       >(typeConverter, patterns.getContext());
83a2590e0cSMatthias Springer   populateOneToNFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
84a2590e0cSMatthias Springer       typeConverter, patterns);
850ceb7a12SIngo Müller }
860ceb7a12SIngo Müller 
870ceb7a12SIngo Müller } // namespace mlir
88