xref: /llvm-project/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (revision 777142937a599d8a9cea5964b415d9cd13016d79)
1 //===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert memref ops into emitc ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
14 
15 #include "mlir/Dialect/EmitC/IR/EmitC.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
25   using OpConversionPattern::OpConversionPattern;
26 
27   LogicalResult
28   matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
29                   ConversionPatternRewriter &rewriter) const override {
30 
31     if (!op.getType().hasStaticShape()) {
32       return rewriter.notifyMatchFailure(
33           op.getLoc(), "cannot transform alloca with dynamic shape");
34     }
35 
36     if (op.getAlignment().value_or(1) > 1) {
37       // TODO: Allow alignment if it is not more than the natural alignment
38       // of the C array.
39       return rewriter.notifyMatchFailure(
40           op.getLoc(), "cannot transform alloca with alignment requirement");
41     }
42 
43     auto resultTy = getTypeConverter()->convertType(op.getType());
44     if (!resultTy) {
45       return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
46     }
47     auto noInit = emitc::OpaqueAttr::get(getContext(), "");
48     rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
49     return success();
50   }
51 };
52 
53 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
54   using OpConversionPattern::OpConversionPattern;
55 
56   LogicalResult
57   matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
58                   ConversionPatternRewriter &rewriter) const override {
59 
60     if (!op.getType().hasStaticShape()) {
61       return rewriter.notifyMatchFailure(
62           op.getLoc(), "cannot transform global with dynamic shape");
63     }
64 
65     if (op.getAlignment().value_or(1) > 1) {
66       // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
67       return rewriter.notifyMatchFailure(
68           op.getLoc(), "global variable with alignment requirement is "
69                        "currently not supported");
70     }
71     auto resultTy = getTypeConverter()->convertType(op.getType());
72     if (!resultTy) {
73       return rewriter.notifyMatchFailure(op.getLoc(),
74                                          "cannot convert result type");
75     }
76 
77     SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
78     if (visibility != SymbolTable::Visibility::Public &&
79         visibility != SymbolTable::Visibility::Private) {
80       return rewriter.notifyMatchFailure(
81           op.getLoc(),
82           "only public and private visibility is currently supported");
83     }
84     // We are explicit in specifing the linkage because the default linkage
85     // for constants is different in C and C++.
86     bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
87     bool externSpecifier = !staticSpecifier;
88 
89     Attribute initialValue = operands.getInitialValueAttr();
90     if (isa_and_present<UnitAttr>(initialValue))
91       initialValue = {};
92 
93     rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
94         op, operands.getSymName(), resultTy, initialValue, externSpecifier,
95         staticSpecifier, operands.getConstant());
96     return success();
97   }
98 };
99 
100 struct ConvertGetGlobal final
101     : public OpConversionPattern<memref::GetGlobalOp> {
102   using OpConversionPattern::OpConversionPattern;
103 
104   LogicalResult
105   matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
106                   ConversionPatternRewriter &rewriter) const override {
107 
108     auto resultTy = getTypeConverter()->convertType(op.getType());
109     if (!resultTy) {
110       return rewriter.notifyMatchFailure(op.getLoc(),
111                                          "cannot convert result type");
112     }
113     rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
114                                                     operands.getNameAttr());
115     return success();
116   }
117 };
118 
119 struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
120   using OpConversionPattern::OpConversionPattern;
121 
122   LogicalResult
123   matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
124                   ConversionPatternRewriter &rewriter) const override {
125 
126     auto resultTy = getTypeConverter()->convertType(op.getType());
127     if (!resultTy) {
128       return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
129     }
130 
131     auto arrayValue =
132         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
133     if (!arrayValue) {
134       return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
135     }
136 
137     auto subscript = rewriter.create<emitc::SubscriptOp>(
138         op.getLoc(), arrayValue, operands.getIndices());
139 
140     rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
141     return success();
142   }
143 };
144 
145 struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
146   using OpConversionPattern::OpConversionPattern;
147 
148   LogicalResult
149   matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
150                   ConversionPatternRewriter &rewriter) const override {
151     auto arrayValue =
152         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
153     if (!arrayValue) {
154       return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
155     }
156 
157     auto subscript = rewriter.create<emitc::SubscriptOp>(
158         op.getLoc(), arrayValue, operands.getIndices());
159     rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
160                                                  operands.getValue());
161     return success();
162   }
163 };
164 } // namespace
165 
166 void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
167   typeConverter.addConversion(
168       [&](MemRefType memRefType) -> std::optional<Type> {
169         if (!memRefType.hasStaticShape() ||
170             !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
171             llvm::any_of(memRefType.getShape(),
172                          [](int64_t dim) { return dim == 0; })) {
173           return {};
174         }
175         Type convertedElementType =
176             typeConverter.convertType(memRefType.getElementType());
177         if (!convertedElementType)
178           return {};
179         return emitc::ArrayType::get(memRefType.getShape(),
180                                      convertedElementType);
181       });
182 }
183 
184 void mlir::populateMemRefToEmitCConversionPatterns(
185     RewritePatternSet &patterns, const TypeConverter &converter) {
186   patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
187                ConvertStore>(converter, patterns.getContext());
188 }
189