xref: /llvm-project/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp (revision e95e94adc6bb748de015ac3053e7f0786b65f351)
1 //===- TestFromLLVMIRTranslation.cpp - Import Test dialect from LLVM IR ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between LLVM IR and the MLIR Test dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "TestDialect.h"
14 #include "TestOps.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Support/LLVM.h"
20 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
21 #include "mlir/Target/LLVMIR/Import.h"
22 #include "mlir/Target/LLVMIR/ModuleImport.h"
23 #include "mlir/Tools/mlir-translate/Translation.h"
24 
25 #include "llvm/IR/Instructions.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IR/Verifier.h"
28 #include "llvm/IRReader/IRReader.h"
29 #include "llvm/Support/SourceMgr.h"
30 
31 using namespace mlir;
32 using namespace test;
33 
getSupportedInstructionsImpl()34 static ArrayRef<unsigned> getSupportedInstructionsImpl() {
35   static unsigned instructions[] = {llvm::Instruction::Load};
36   return instructions;
37 }
38 
convertLoad(OpBuilder & builder,llvm::Instruction * inst,ArrayRef<llvm::Value * > llvmOperands,LLVM::ModuleImport & moduleImport)39 static LogicalResult convertLoad(OpBuilder &builder, llvm::Instruction *inst,
40                                  ArrayRef<llvm::Value *> llvmOperands,
41                                  LLVM::ModuleImport &moduleImport) {
42   FailureOr<Value> addr = moduleImport.convertValue(llvmOperands[0]);
43   if (failed(addr))
44     return failure();
45   // Create the LoadOp
46   Value loadOp = builder.create<LLVM::LoadOp>(
47       moduleImport.translateLoc(inst->getDebugLoc()),
48       moduleImport.convertType(inst->getType()), *addr);
49   moduleImport.mapValue(inst) = builder.create<SameOperandElementTypeOp>(
50       loadOp.getLoc(), loadOp.getType(), loadOp, loadOp);
51   return success();
52 }
53 
54 namespace {
55 class TestDialectLLVMImportDialectInterface
56     : public LLVMImportDialectInterface {
57 public:
58   using LLVMImportDialectInterface::LLVMImportDialectInterface;
59 
60   LogicalResult
convertInstruction(OpBuilder & builder,llvm::Instruction * inst,ArrayRef<llvm::Value * > llvmOperands,LLVM::ModuleImport & moduleImport) const61   convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
62                      ArrayRef<llvm::Value *> llvmOperands,
63                      LLVM::ModuleImport &moduleImport) const override {
64     switch (inst->getOpcode()) {
65     case llvm::Instruction::Load:
66       return convertLoad(builder, inst, llvmOperands, moduleImport);
67     default:
68       break;
69     }
70     return failure();
71   }
72 
getSupportedInstructions() const73   ArrayRef<unsigned> getSupportedInstructions() const override {
74     return getSupportedInstructionsImpl();
75   }
76 };
77 } // namespace
78 
79 namespace mlir {
registerTestFromLLVMIR()80 void registerTestFromLLVMIR() {
81   TranslateToMLIRRegistration registration(
82       "test-import-llvmir", "test dialect from LLVM IR",
83       [](llvm::SourceMgr &sourceMgr,
84          MLIRContext *context) -> OwningOpRef<Operation *> {
85         llvm::SMDiagnostic err;
86         llvm::LLVMContext llvmContext;
87         std::unique_ptr<llvm::Module> llvmModule =
88             llvm::parseIR(*sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()),
89                           err, llvmContext);
90         if (!llvmModule) {
91           std::string errStr;
92           llvm::raw_string_ostream errStream(errStr);
93           err.print(/*ProgName=*/"", errStream);
94           emitError(UnknownLoc::get(context)) << errStream.str();
95           return {};
96         }
97         if (llvm::verifyModule(*llvmModule, &llvm::errs()))
98           return nullptr;
99 
100         return translateLLVMIRToModule(std::move(llvmModule), context, false);
101       },
102       [](DialectRegistry &registry) {
103         registry.insert<DLTIDialect>();
104         registry.insert<test::TestDialect>();
105         registerLLVMDialectImport(registry);
106         registry.addExtension(
107             +[](MLIRContext *ctx, test::TestDialect *dialect) {
108               dialect->addInterfaces<TestDialectLLVMImportDialectInterface>();
109             });
110       });
111 }
112 } // namespace mlir
113