xref: /llvm-project/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (revision 777142937a599d8a9cea5964b415d9cd13016d79)
10aa6d57eSMatthias Gehre //===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
20aa6d57eSMatthias Gehre //
30aa6d57eSMatthias Gehre // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40aa6d57eSMatthias Gehre // See https://llvm.org/LICENSE.txt for license information.
50aa6d57eSMatthias Gehre // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60aa6d57eSMatthias Gehre //
70aa6d57eSMatthias Gehre //===----------------------------------------------------------------------===//
80aa6d57eSMatthias Gehre //
90aa6d57eSMatthias Gehre // This file implements patterns to convert memref ops into emitc ops.
100aa6d57eSMatthias Gehre //
110aa6d57eSMatthias Gehre //===----------------------------------------------------------------------===//
120aa6d57eSMatthias Gehre 
130aa6d57eSMatthias Gehre #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
140aa6d57eSMatthias Gehre 
150aa6d57eSMatthias Gehre #include "mlir/Dialect/EmitC/IR/EmitC.h"
160aa6d57eSMatthias Gehre #include "mlir/Dialect/MemRef/IR/MemRef.h"
170aa6d57eSMatthias Gehre #include "mlir/IR/Builders.h"
180aa6d57eSMatthias Gehre #include "mlir/IR/PatternMatch.h"
190aa6d57eSMatthias Gehre #include "mlir/Transforms/DialectConversion.h"
200aa6d57eSMatthias Gehre 
210aa6d57eSMatthias Gehre using namespace mlir;
220aa6d57eSMatthias Gehre 
230aa6d57eSMatthias Gehre namespace {
240aa6d57eSMatthias Gehre struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
250aa6d57eSMatthias Gehre   using OpConversionPattern::OpConversionPattern;
260aa6d57eSMatthias Gehre 
270aa6d57eSMatthias Gehre   LogicalResult
280aa6d57eSMatthias Gehre   matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
290aa6d57eSMatthias Gehre                   ConversionPatternRewriter &rewriter) const override {
300aa6d57eSMatthias Gehre 
310aa6d57eSMatthias Gehre     if (!op.getType().hasStaticShape()) {
320aa6d57eSMatthias Gehre       return rewriter.notifyMatchFailure(
330aa6d57eSMatthias Gehre           op.getLoc(), "cannot transform alloca with dynamic shape");
340aa6d57eSMatthias Gehre     }
350aa6d57eSMatthias Gehre 
360aa6d57eSMatthias Gehre     if (op.getAlignment().value_or(1) > 1) {
370aa6d57eSMatthias Gehre       // TODO: Allow alignment if it is not more than the natural alignment
380aa6d57eSMatthias Gehre       // of the C array.
390aa6d57eSMatthias Gehre       return rewriter.notifyMatchFailure(
400aa6d57eSMatthias Gehre           op.getLoc(), "cannot transform alloca with alignment requirement");
410aa6d57eSMatthias Gehre     }
420aa6d57eSMatthias Gehre 
430aa6d57eSMatthias Gehre     auto resultTy = getTypeConverter()->convertType(op.getType());
440aa6d57eSMatthias Gehre     if (!resultTy) {
450aa6d57eSMatthias Gehre       return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
460aa6d57eSMatthias Gehre     }
470aa6d57eSMatthias Gehre     auto noInit = emitc::OpaqueAttr::get(getContext(), "");
480aa6d57eSMatthias Gehre     rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
490aa6d57eSMatthias Gehre     return success();
500aa6d57eSMatthias Gehre   }
510aa6d57eSMatthias Gehre };
520aa6d57eSMatthias Gehre 
5365484656SMatthias Gehre struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
5465484656SMatthias Gehre   using OpConversionPattern::OpConversionPattern;
5565484656SMatthias Gehre 
5665484656SMatthias Gehre   LogicalResult
5765484656SMatthias Gehre   matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
5865484656SMatthias Gehre                   ConversionPatternRewriter &rewriter) const override {
5965484656SMatthias Gehre 
6065484656SMatthias Gehre     if (!op.getType().hasStaticShape()) {
6165484656SMatthias Gehre       return rewriter.notifyMatchFailure(
6265484656SMatthias Gehre           op.getLoc(), "cannot transform global with dynamic shape");
6365484656SMatthias Gehre     }
6465484656SMatthias Gehre 
6565484656SMatthias Gehre     if (op.getAlignment().value_or(1) > 1) {
6665484656SMatthias Gehre       // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
6765484656SMatthias Gehre       return rewriter.notifyMatchFailure(
6865484656SMatthias Gehre           op.getLoc(), "global variable with alignment requirement is "
6965484656SMatthias Gehre                        "currently not supported");
7065484656SMatthias Gehre     }
7165484656SMatthias Gehre     auto resultTy = getTypeConverter()->convertType(op.getType());
7265484656SMatthias Gehre     if (!resultTy) {
7365484656SMatthias Gehre       return rewriter.notifyMatchFailure(op.getLoc(),
7465484656SMatthias Gehre                                          "cannot convert result type");
7565484656SMatthias Gehre     }
7665484656SMatthias Gehre 
7765484656SMatthias Gehre     SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
7865484656SMatthias Gehre     if (visibility != SymbolTable::Visibility::Public &&
7965484656SMatthias Gehre         visibility != SymbolTable::Visibility::Private) {
8065484656SMatthias Gehre       return rewriter.notifyMatchFailure(
8165484656SMatthias Gehre           op.getLoc(),
8265484656SMatthias Gehre           "only public and private visibility is currently supported");
8365484656SMatthias Gehre     }
8465484656SMatthias Gehre     // We are explicit in specifing the linkage because the default linkage
8565484656SMatthias Gehre     // for constants is different in C and C++.
8665484656SMatthias Gehre     bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
8765484656SMatthias Gehre     bool externSpecifier = !staticSpecifier;
8865484656SMatthias Gehre 
8965484656SMatthias Gehre     Attribute initialValue = operands.getInitialValueAttr();
9065484656SMatthias Gehre     if (isa_and_present<UnitAttr>(initialValue))
9165484656SMatthias Gehre       initialValue = {};
9265484656SMatthias Gehre 
9365484656SMatthias Gehre     rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
9465484656SMatthias Gehre         op, operands.getSymName(), resultTy, initialValue, externSpecifier,
9565484656SMatthias Gehre         staticSpecifier, operands.getConstant());
9665484656SMatthias Gehre     return success();
9765484656SMatthias Gehre   }
9865484656SMatthias Gehre };
9965484656SMatthias Gehre 
10065484656SMatthias Gehre struct ConvertGetGlobal final
10165484656SMatthias Gehre     : public OpConversionPattern<memref::GetGlobalOp> {
10265484656SMatthias Gehre   using OpConversionPattern::OpConversionPattern;
10365484656SMatthias Gehre 
10465484656SMatthias Gehre   LogicalResult
10565484656SMatthias Gehre   matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
10665484656SMatthias Gehre                   ConversionPatternRewriter &rewriter) const override {
10765484656SMatthias Gehre 
10865484656SMatthias Gehre     auto resultTy = getTypeConverter()->convertType(op.getType());
10965484656SMatthias Gehre     if (!resultTy) {
11065484656SMatthias Gehre       return rewriter.notifyMatchFailure(op.getLoc(),
11165484656SMatthias Gehre                                          "cannot convert result type");
11265484656SMatthias Gehre     }
11365484656SMatthias Gehre     rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
11465484656SMatthias Gehre                                                     operands.getNameAttr());
11565484656SMatthias Gehre     return success();
11665484656SMatthias Gehre   }
11765484656SMatthias Gehre };
11865484656SMatthias Gehre 
1190aa6d57eSMatthias Gehre struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
1200aa6d57eSMatthias Gehre   using OpConversionPattern::OpConversionPattern;
1210aa6d57eSMatthias Gehre 
1220aa6d57eSMatthias Gehre   LogicalResult
1230aa6d57eSMatthias Gehre   matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
1240aa6d57eSMatthias Gehre                   ConversionPatternRewriter &rewriter) const override {
1250aa6d57eSMatthias Gehre 
1260aa6d57eSMatthias Gehre     auto resultTy = getTypeConverter()->convertType(op.getType());
1270aa6d57eSMatthias Gehre     if (!resultTy) {
1280aa6d57eSMatthias Gehre       return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
1290aa6d57eSMatthias Gehre     }
1300aa6d57eSMatthias Gehre 
1311f268092SSimon Camphausen     auto arrayValue =
1321f268092SSimon Camphausen         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
1331f268092SSimon Camphausen     if (!arrayValue) {
1341f268092SSimon Camphausen       return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
1351f268092SSimon Camphausen     }
1361f268092SSimon Camphausen 
1370aa6d57eSMatthias Gehre     auto subscript = rewriter.create<emitc::SubscriptOp>(
1381f268092SSimon Camphausen         op.getLoc(), arrayValue, operands.getIndices());
1390aa6d57eSMatthias Gehre 
140e47b5075SSimon Camphausen     rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
1410aa6d57eSMatthias Gehre     return success();
1420aa6d57eSMatthias Gehre   }
1430aa6d57eSMatthias Gehre };
1440aa6d57eSMatthias Gehre 
1450aa6d57eSMatthias Gehre struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
1460aa6d57eSMatthias Gehre   using OpConversionPattern::OpConversionPattern;
1470aa6d57eSMatthias Gehre 
1480aa6d57eSMatthias Gehre   LogicalResult
1490aa6d57eSMatthias Gehre   matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
1500aa6d57eSMatthias Gehre                   ConversionPatternRewriter &rewriter) const override {
1511f268092SSimon Camphausen     auto arrayValue =
1521f268092SSimon Camphausen         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
1531f268092SSimon Camphausen     if (!arrayValue) {
1541f268092SSimon Camphausen       return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
1551f268092SSimon Camphausen     }
1560aa6d57eSMatthias Gehre 
1570aa6d57eSMatthias Gehre     auto subscript = rewriter.create<emitc::SubscriptOp>(
1581f268092SSimon Camphausen         op.getLoc(), arrayValue, operands.getIndices());
1590aa6d57eSMatthias Gehre     rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
1600aa6d57eSMatthias Gehre                                                  operands.getValue());
1610aa6d57eSMatthias Gehre     return success();
1620aa6d57eSMatthias Gehre   }
1630aa6d57eSMatthias Gehre };
1640aa6d57eSMatthias Gehre } // namespace
1650aa6d57eSMatthias Gehre 
1660aa6d57eSMatthias Gehre void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
1670aa6d57eSMatthias Gehre   typeConverter.addConversion(
1680aa6d57eSMatthias Gehre       [&](MemRefType memRefType) -> std::optional<Type> {
1690aa6d57eSMatthias Gehre         if (!memRefType.hasStaticShape() ||
170*77714293SSimon Camphausen             !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
171*77714293SSimon Camphausen             llvm::any_of(memRefType.getShape(),
172*77714293SSimon Camphausen                          [](int64_t dim) { return dim == 0; })) {
1730aa6d57eSMatthias Gehre           return {};
1740aa6d57eSMatthias Gehre         }
1750aa6d57eSMatthias Gehre         Type convertedElementType =
1760aa6d57eSMatthias Gehre             typeConverter.convertType(memRefType.getElementType());
1770aa6d57eSMatthias Gehre         if (!convertedElementType)
1780aa6d57eSMatthias Gehre           return {};
1790aa6d57eSMatthias Gehre         return emitc::ArrayType::get(memRefType.getShape(),
1800aa6d57eSMatthias Gehre                                      convertedElementType);
1810aa6d57eSMatthias Gehre       });
1820aa6d57eSMatthias Gehre }
1830aa6d57eSMatthias Gehre 
184206fad0eSMatthias Springer void mlir::populateMemRefToEmitCConversionPatterns(
185206fad0eSMatthias Springer     RewritePatternSet &patterns, const TypeConverter &converter) {
18665484656SMatthias Gehre   patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
18765484656SMatthias Gehre                ConvertStore>(converter, patterns.getContext());
1880aa6d57eSMatthias Gehre }
189