xref: /llvm-project/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp (revision 346bd917929fa89dfe00d999effcde7ee3d8d4a7)
1 //===- FuncToEmitC.cpp - Func to EmitC Patterns -----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the Func dialect to the EmitC
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
15 
16 #include "mlir/Dialect/EmitC/IR/EmitC.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Conversion Patterns
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 class CallOpConversion final : public OpConversionPattern<func::CallOp> {
28 public:
29   using OpConversionPattern<func::CallOp>::OpConversionPattern;
30 
31   LogicalResult
matchAndRewrite(func::CallOp callOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const32   matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
33                   ConversionPatternRewriter &rewriter) const override {
34     // Multiple results func cannot be converted to `emitc.func`.
35     if (callOp.getNumResults() > 1)
36       return rewriter.notifyMatchFailure(
37           callOp, "only functions with zero or one result can be converted");
38 
39     rewriter.replaceOpWithNewOp<emitc::CallOp>(callOp, callOp.getResultTypes(),
40                                                adaptor.getOperands(),
41                                                callOp->getAttrs());
42 
43     return success();
44   }
45 };
46 
47 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
48 public:
49   using OpConversionPattern<func::FuncOp>::OpConversionPattern;
50 
51   LogicalResult
matchAndRewrite(func::FuncOp funcOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const52   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
53                   ConversionPatternRewriter &rewriter) const override {
54 
55     if (funcOp.getFunctionType().getNumResults() > 1)
56       return rewriter.notifyMatchFailure(
57           funcOp, "only functions with zero or one result can be converted");
58 
59     // Create the converted `emitc.func` op.
60     emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
61         funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType());
62 
63     // Copy over all attributes other than the function name and type.
64     for (const auto &namedAttr : funcOp->getAttrs()) {
65       if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
66           namedAttr.getName() != SymbolTable::getSymbolAttrName())
67         newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
68     }
69 
70     // Add `extern` to specifiers if `func.func` is declaration only.
71     if (funcOp.isDeclaration()) {
72       ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"});
73       newFuncOp.setSpecifiersAttr(specifiers);
74     }
75 
76     // Add `static` to specifiers if `func.func` is private but not a
77     // declaration.
78     if (funcOp.isPrivate() && !funcOp.isDeclaration()) {
79       ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"});
80       newFuncOp.setSpecifiersAttr(specifiers);
81     }
82 
83     if (!funcOp.isDeclaration())
84       rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
85                                   newFuncOp.end());
86     rewriter.eraseOp(funcOp);
87 
88     return success();
89   }
90 };
91 
92 class ReturnOpConversion final : public OpConversionPattern<func::ReturnOp> {
93 public:
94   using OpConversionPattern<func::ReturnOp>::OpConversionPattern;
95 
96   LogicalResult
matchAndRewrite(func::ReturnOp returnOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const97   matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
98                   ConversionPatternRewriter &rewriter) const override {
99     if (returnOp.getNumOperands() > 1)
100       return rewriter.notifyMatchFailure(
101           returnOp, "only zero or one operand is supported");
102 
103     rewriter.replaceOpWithNewOp<emitc::ReturnOp>(
104         returnOp,
105         returnOp.getNumOperands() ? adaptor.getOperands()[0] : nullptr);
106     return success();
107   }
108 };
109 } // namespace
110 
111 //===----------------------------------------------------------------------===//
112 // Pattern population
113 //===----------------------------------------------------------------------===//
114 
populateFuncToEmitCPatterns(RewritePatternSet & patterns)115 void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) {
116   MLIRContext *ctx = patterns.getContext();
117 
118   patterns.add<CallOpConversion, FuncOpConversion, ReturnOpConversion>(ctx);
119 }
120