xref: /llvm-project/mlir/lib/Tools/mlir-translate/Translation.cpp (revision ed90f8026e45b9f922002ad03c6f37c4881be1d8)
1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Tools/mlir-translate/Translation.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Tools/ParseUtilties.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SourceMgr.h"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Translation Registry
26 //===----------------------------------------------------------------------===//
27 
28 struct TranslationBundle {
29   TranslateFunction translateFunction;
30   StringRef translateDescription;
31 };
32 
33 /// Get the mutable static map between registered file-to-file MLIR translations
34 /// and TranslateFunctions with its description that perform those translations.
35 static llvm::StringMap<TranslationBundle> &getTranslationRegistry() {
36   static llvm::StringMap<TranslationBundle> translationBundle;
37   return translationBundle;
38 }
39 
40 /// Register the given translation.
41 static void registerTranslation(StringRef name, StringRef description,
42                                 const TranslateFunction &function) {
43   auto &translationRegistry = getTranslationRegistry();
44   if (translationRegistry.find(name) != translationRegistry.end())
45     llvm::report_fatal_error(
46         "Attempting to overwrite an existing <file-to-file> function");
47   assert(function &&
48          "Attempting to register an empty translate <file-to-file> function");
49   translationRegistry[name].translateFunction = function;
50   translationRegistry[name].translateDescription = description;
51 }
52 
53 TranslateRegistration::TranslateRegistration(
54     StringRef name, StringRef description, const TranslateFunction &function) {
55   registerTranslation(name, description, function);
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // Translation to MLIR
60 //===----------------------------------------------------------------------===//
61 
62 // Puts `function` into the to-MLIR translation registry unless there is already
63 // a function registered for the same name.
64 static void registerTranslateToMLIRFunction(
65     StringRef name, StringRef description,
66     const TranslateSourceMgrToMLIRFunction &function) {
67   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
68                               MLIRContext *context) {
69     OwningOpRef<Operation *> op = function(sourceMgr, context);
70     if (!op || failed(verify(*op)))
71       return failure();
72     op.get()->print(output);
73     return success();
74   };
75   registerTranslation(name, description, wrappedFn);
76 }
77 
78 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
79     StringRef name, StringRef description,
80     const TranslateSourceMgrToMLIRFunction &function) {
81   registerTranslateToMLIRFunction(name, description, function);
82 }
83 /// Wraps `function` with a lambda that extracts a StringRef from a source
84 /// manager and registers the wrapper lambda as a to-MLIR conversion.
85 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
86     StringRef name, StringRef description,
87     const TranslateStringRefToMLIRFunction &function) {
88   registerTranslateToMLIRFunction(
89       name, description,
90       [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
91         const llvm::MemoryBuffer *buffer =
92             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
93         return function(buffer->getBuffer(), ctx);
94       });
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // Translation from MLIR
99 //===----------------------------------------------------------------------===//
100 
101 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
102     StringRef name, StringRef description,
103     const TranslateFromMLIRFunction &function,
104     const std::function<void(DialectRegistry &)> &dialectRegistration) {
105 
106   static llvm::cl::opt<bool> noImplicitModule{
107       "no-implicit-module",
108       llvm::cl::desc("Disable the parsing of an implicit top-level module op"),
109       llvm::cl::init(false)};
110 
111   registerTranslation(name, description,
112                       [function, dialectRegistration](
113                           llvm::SourceMgr &sourceMgr, raw_ostream &output,
114                           MLIRContext *context) {
115                         DialectRegistry registry;
116                         dialectRegistration(registry);
117                         context->appendDialectRegistry(registry);
118                         OwningOpRef<Operation *> op = parseSourceFileForTool(
119                             sourceMgr, context, !noImplicitModule);
120                         if (!op || failed(verify(*op)))
121                           return failure();
122                         return function(op.get(), output);
123                       });
124 }
125 
126 //===----------------------------------------------------------------------===//
127 // Translation Parser
128 //===----------------------------------------------------------------------===//
129 
130 TranslationParser::TranslationParser(llvm::cl::Option &opt)
131     : llvm::cl::parser<const TranslateFunction *>(opt) {
132   for (const auto &kv : getTranslationRegistry()) {
133     addLiteralOption(kv.first(), &kv.second.translateFunction,
134                      kv.second.translateDescription);
135   }
136 }
137 
138 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
139                                         size_t globalWidth) const {
140   TranslationParser *tp = const_cast<TranslationParser *>(this);
141   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
142                        [](const TranslationParser::OptionInfo *lhs,
143                           const TranslationParser::OptionInfo *rhs) {
144                          return lhs->Name.compare(rhs->Name);
145                        });
146   llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
147 }
148