xref: /llvm-project/mlir/lib/Tools/mlir-translate/Translation.cpp (revision c4cc755c72f189437c789c5f1564193af2d0b7d6)
1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Tools/mlir-translate/Translation.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/Parser/Parser.h"
19 #include "llvm/Support/SourceMgr.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Translation Registry
25 //===----------------------------------------------------------------------===//
26 
27 struct TranslationBundle {
28   TranslateFunction translateFunction;
29   StringRef translateDescription;
30 };
31 
32 /// Get the mutable static map between registered file-to-file MLIR translations
33 /// and TranslateFunctions with its description that perform those translations.
34 static llvm::StringMap<TranslationBundle> &getTranslationRegistry() {
35   static llvm::StringMap<TranslationBundle> translationBundle;
36   return translationBundle;
37 }
38 
39 /// Register the given translation.
40 static void registerTranslation(StringRef name, StringRef description,
41                                 const TranslateFunction &function) {
42   auto &translationRegistry = getTranslationRegistry();
43   if (translationRegistry.find(name) != translationRegistry.end())
44     llvm::report_fatal_error(
45         "Attempting to overwrite an existing <file-to-file> function");
46   assert(function &&
47          "Attempting to register an empty translate <file-to-file> function");
48   translationRegistry[name].translateFunction = function;
49   translationRegistry[name].translateDescription = description;
50 }
51 
52 TranslateRegistration::TranslateRegistration(
53     StringRef name, StringRef description, const TranslateFunction &function) {
54   registerTranslation(name, description, function);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // Translation to MLIR
59 //===----------------------------------------------------------------------===//
60 
61 // Puts `function` into the to-MLIR translation registry unless there is already
62 // a function registered for the same name.
63 static void registerTranslateToMLIRFunction(
64     StringRef name, StringRef description,
65     const TranslateSourceMgrToMLIRFunction &function) {
66   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
67                               MLIRContext *context) {
68     OwningOpRef<ModuleOp> module = function(sourceMgr, context);
69     if (!module || failed(verify(*module)))
70       return failure();
71     module->print(output);
72     return success();
73   };
74   registerTranslation(name, description, wrappedFn);
75 }
76 
77 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
78     StringRef name, StringRef description,
79     const TranslateSourceMgrToMLIRFunction &function) {
80   registerTranslateToMLIRFunction(name, description, function);
81 }
82 /// Wraps `function` with a lambda that extracts a StringRef from a source
83 /// manager and registers the wrapper lambda as a to-MLIR conversion.
84 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
85     StringRef name, StringRef description,
86     const TranslateStringRefToMLIRFunction &function) {
87   registerTranslateToMLIRFunction(
88       name, description,
89       [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
90         const llvm::MemoryBuffer *buffer =
91             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
92         return function(buffer->getBuffer(), ctx);
93       });
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // Translation from MLIR
98 //===----------------------------------------------------------------------===//
99 
100 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
101     StringRef name, StringRef description,
102     const TranslateFromMLIRFunction &function,
103     const std::function<void(DialectRegistry &)> &dialectRegistration) {
104   registerTranslation(name, description,
105                       [function, dialectRegistration](
106                           llvm::SourceMgr &sourceMgr, raw_ostream &output,
107                           MLIRContext *context) {
108                         DialectRegistry registry;
109                         dialectRegistration(registry);
110                         context->appendDialectRegistry(registry);
111                         auto module =
112                             parseSourceFile<ModuleOp>(sourceMgr, context);
113                         if (!module || failed(verify(*module)))
114                           return failure();
115                         return function(module.get(), output);
116                       });
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // Translation Parser
121 //===----------------------------------------------------------------------===//
122 
123 TranslationParser::TranslationParser(llvm::cl::Option &opt)
124     : llvm::cl::parser<const TranslateFunction *>(opt) {
125   for (const auto &kv : getTranslationRegistry()) {
126     addLiteralOption(kv.first(), &kv.second.translateFunction,
127                      kv.second.translateDescription);
128   }
129 }
130 
131 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
132                                         size_t globalWidth) const {
133   TranslationParser *tp = const_cast<TranslationParser *>(this);
134   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
135                        [](const TranslationParser::OptionInfo *lhs,
136                           const TranslationParser::OptionInfo *rhs) {
137                          return lhs->Name.compare(rhs->Name);
138                        });
139   llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
140 }
141