1 //===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert memref ops into emitc ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" 14 15 #include "mlir/Dialect/EmitC/IR/EmitC.h" 16 #include "mlir/Dialect/MemRef/IR/MemRef.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 using namespace mlir; 22 23 namespace { 24 struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> { 25 using OpConversionPattern::OpConversionPattern; 26 27 LogicalResult 28 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands, 29 ConversionPatternRewriter &rewriter) const override { 30 31 if (!op.getType().hasStaticShape()) { 32 return rewriter.notifyMatchFailure( 33 op.getLoc(), "cannot transform alloca with dynamic shape"); 34 } 35 36 if (op.getAlignment().value_or(1) > 1) { 37 // TODO: Allow alignment if it is not more than the natural alignment 38 // of the C array. 39 return rewriter.notifyMatchFailure( 40 op.getLoc(), "cannot transform alloca with alignment requirement"); 41 } 42 43 auto resultTy = getTypeConverter()->convertType(op.getType()); 44 if (!resultTy) { 45 return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type"); 46 } 47 auto noInit = emitc::OpaqueAttr::get(getContext(), ""); 48 rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit); 49 return success(); 50 } 51 }; 52 53 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { 54 using OpConversionPattern::OpConversionPattern; 55 56 LogicalResult 57 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, 58 ConversionPatternRewriter &rewriter) const override { 59 60 if (!op.getType().hasStaticShape()) { 61 return rewriter.notifyMatchFailure( 62 op.getLoc(), "cannot transform global with dynamic shape"); 63 } 64 65 if (op.getAlignment().value_or(1) > 1) { 66 // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier. 67 return rewriter.notifyMatchFailure( 68 op.getLoc(), "global variable with alignment requirement is " 69 "currently not supported"); 70 } 71 auto resultTy = getTypeConverter()->convertType(op.getType()); 72 if (!resultTy) { 73 return rewriter.notifyMatchFailure(op.getLoc(), 74 "cannot convert result type"); 75 } 76 77 SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op); 78 if (visibility != SymbolTable::Visibility::Public && 79 visibility != SymbolTable::Visibility::Private) { 80 return rewriter.notifyMatchFailure( 81 op.getLoc(), 82 "only public and private visibility is currently supported"); 83 } 84 // We are explicit in specifing the linkage because the default linkage 85 // for constants is different in C and C++. 86 bool staticSpecifier = visibility == SymbolTable::Visibility::Private; 87 bool externSpecifier = !staticSpecifier; 88 89 Attribute initialValue = operands.getInitialValueAttr(); 90 if (isa_and_present<UnitAttr>(initialValue)) 91 initialValue = {}; 92 93 rewriter.replaceOpWithNewOp<emitc::GlobalOp>( 94 op, operands.getSymName(), resultTy, initialValue, externSpecifier, 95 staticSpecifier, operands.getConstant()); 96 return success(); 97 } 98 }; 99 100 struct ConvertGetGlobal final 101 : public OpConversionPattern<memref::GetGlobalOp> { 102 using OpConversionPattern::OpConversionPattern; 103 104 LogicalResult 105 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, 106 ConversionPatternRewriter &rewriter) const override { 107 108 auto resultTy = getTypeConverter()->convertType(op.getType()); 109 if (!resultTy) { 110 return rewriter.notifyMatchFailure(op.getLoc(), 111 "cannot convert result type"); 112 } 113 rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy, 114 operands.getNameAttr()); 115 return success(); 116 } 117 }; 118 119 struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> { 120 using OpConversionPattern::OpConversionPattern; 121 122 LogicalResult 123 matchAndRewrite(memref::LoadOp op, OpAdaptor operands, 124 ConversionPatternRewriter &rewriter) const override { 125 126 auto resultTy = getTypeConverter()->convertType(op.getType()); 127 if (!resultTy) { 128 return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type"); 129 } 130 131 auto arrayValue = 132 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref()); 133 if (!arrayValue) { 134 return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); 135 } 136 137 auto subscript = rewriter.create<emitc::SubscriptOp>( 138 op.getLoc(), arrayValue, operands.getIndices()); 139 140 rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript); 141 return success(); 142 } 143 }; 144 145 struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { 146 using OpConversionPattern::OpConversionPattern; 147 148 LogicalResult 149 matchAndRewrite(memref::StoreOp op, OpAdaptor operands, 150 ConversionPatternRewriter &rewriter) const override { 151 auto arrayValue = 152 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref()); 153 if (!arrayValue) { 154 return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); 155 } 156 157 auto subscript = rewriter.create<emitc::SubscriptOp>( 158 op.getLoc(), arrayValue, operands.getIndices()); 159 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript, 160 operands.getValue()); 161 return success(); 162 } 163 }; 164 } // namespace 165 166 void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { 167 typeConverter.addConversion( 168 [&](MemRefType memRefType) -> std::optional<Type> { 169 if (!memRefType.hasStaticShape() || 170 !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || 171 llvm::any_of(memRefType.getShape(), 172 [](int64_t dim) { return dim == 0; })) { 173 return {}; 174 } 175 Type convertedElementType = 176 typeConverter.convertType(memRefType.getElementType()); 177 if (!convertedElementType) 178 return {}; 179 return emitc::ArrayType::get(memRefType.getShape(), 180 convertedElementType); 181 }); 182 } 183 184 void mlir::populateMemRefToEmitCConversionPatterns( 185 RewritePatternSet &patterns, const TypeConverter &converter) { 186 patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, 187 ConvertStore>(converter, patterns.getContext()); 188 } 189