1 //===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert memref ops into emitc ops. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h" 14 15 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" 16 #include "mlir/Dialect/EmitC/IR/EmitC.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 namespace mlir { 22 #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC 23 #include "mlir/Conversion/Passes.h.inc" 24 } // namespace mlir 25 26 using namespace mlir; 27 28 namespace { 29 struct ConvertMemRefToEmitCPass 30 : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { 31 void runOnOperation() override { 32 TypeConverter converter; 33 34 // Fallback for other types. 35 converter.addConversion([](Type type) -> std::optional<Type> { 36 if (emitc::isSupportedEmitCType(type)) 37 return type; 38 return {}; 39 }); 40 41 populateMemRefToEmitCTypeConversion(converter); 42 43 auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType, 44 ValueRange inputs, 45 Location loc) -> Value { 46 if (inputs.size() != 1) 47 return Value(); 48 49 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) 50 .getResult(0); 51 }; 52 53 converter.addSourceMaterialization(materializeAsUnrealizedCast); 54 converter.addTargetMaterialization(materializeAsUnrealizedCast); 55 56 RewritePatternSet patterns(&getContext()); 57 populateMemRefToEmitCConversionPatterns(patterns, converter); 58 59 ConversionTarget target(getContext()); 60 target.addIllegalDialect<memref::MemRefDialect>(); 61 target.addLegalDialect<emitc::EmitCDialect>(); 62 63 if (failed(applyPartialConversion(getOperation(), target, 64 std::move(patterns)))) 65 return signalPassFailure(); 66 } 67 }; 68 } // namespace 69