xref: /llvm-project/mlir/lib/Tools/mlir-translate/Translation.cpp (revision d30727fb6c15650dcd1432d5501e2d37f3fd5dda)
1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Tools/mlir-translate/Translation.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/Parser/Parser.h"
19 #include "mlir/Tools/ParseUtilties.h"
20 #include "llvm/Support/SourceMgr.h"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Translation CommandLine Options
26 //===----------------------------------------------------------------------===//
27 
28 struct TranslationOptions {
29   llvm::cl::opt<bool> noImplicitModule{
30       "no-implicit-module",
31       llvm::cl::desc("Disable the parsing of an implicit top-level module op"),
32       llvm::cl::init(false)};
33 };
34 
35 static llvm::ManagedStatic<TranslationOptions> clOptions;
36 
37 void mlir::registerTranslationCLOptions() { *clOptions; }
38 
39 //===----------------------------------------------------------------------===//
40 // Translation Registry
41 //===----------------------------------------------------------------------===//
42 
43 struct TranslationBundle {
44   TranslateFunction translateFunction;
45   StringRef translateDescription;
46 };
47 
48 /// Get the mutable static map between registered file-to-file MLIR translations
49 /// and TranslateFunctions with its description that perform those translations.
50 static llvm::StringMap<TranslationBundle> &getTranslationRegistry() {
51   static llvm::StringMap<TranslationBundle> translationBundle;
52   return translationBundle;
53 }
54 
55 /// Register the given translation.
56 static void registerTranslation(StringRef name, StringRef description,
57                                 const TranslateFunction &function) {
58   auto &translationRegistry = getTranslationRegistry();
59   if (translationRegistry.find(name) != translationRegistry.end())
60     llvm::report_fatal_error(
61         "Attempting to overwrite an existing <file-to-file> function");
62   assert(function &&
63          "Attempting to register an empty translate <file-to-file> function");
64   translationRegistry[name].translateFunction = function;
65   translationRegistry[name].translateDescription = description;
66 }
67 
68 TranslateRegistration::TranslateRegistration(
69     StringRef name, StringRef description, const TranslateFunction &function) {
70   registerTranslation(name, description, function);
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // Translation to MLIR
75 //===----------------------------------------------------------------------===//
76 
77 // Puts `function` into the to-MLIR translation registry unless there is already
78 // a function registered for the same name.
79 static void registerTranslateToMLIRFunction(
80     StringRef name, StringRef description,
81     const TranslateSourceMgrToMLIRFunction &function) {
82   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
83                               MLIRContext *context) {
84     OwningOpRef<Operation *> op = function(sourceMgr, context);
85     if (!op || failed(verify(*op)))
86       return failure();
87     op.get()->print(output);
88     return success();
89   };
90   registerTranslation(name, description, wrappedFn);
91 }
92 
93 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
94     StringRef name, StringRef description,
95     const TranslateSourceMgrToMLIRFunction &function) {
96   registerTranslateToMLIRFunction(name, description, function);
97 }
98 /// Wraps `function` with a lambda that extracts a StringRef from a source
99 /// manager and registers the wrapper lambda as a to-MLIR conversion.
100 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
101     StringRef name, StringRef description,
102     const TranslateStringRefToMLIRFunction &function) {
103   registerTranslateToMLIRFunction(
104       name, description,
105       [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
106         const llvm::MemoryBuffer *buffer =
107             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
108         return function(buffer->getBuffer(), ctx);
109       });
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Translation from MLIR
114 //===----------------------------------------------------------------------===//
115 
116 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
117     StringRef name, StringRef description,
118     const TranslateFromMLIRFunction &function,
119     const std::function<void(DialectRegistry &)> &dialectRegistration) {
120 
121   registerTranslation(
122       name, description,
123       [function, dialectRegistration](llvm::SourceMgr &sourceMgr,
124                                       raw_ostream &output,
125                                       MLIRContext *context) {
126         DialectRegistry registry;
127         dialectRegistration(registry);
128         context->appendDialectRegistry(registry);
129         bool implicitModule =
130             (!clOptions.isConstructed() || !clOptions->noImplicitModule);
131         OwningOpRef<Operation *> op =
132             parseSourceFileForTool(sourceMgr, context, implicitModule);
133         if (!op || failed(verify(*op)))
134           return failure();
135         return function(op.get(), output);
136       });
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Translation Parser
141 //===----------------------------------------------------------------------===//
142 
143 TranslationParser::TranslationParser(llvm::cl::Option &opt)
144     : llvm::cl::parser<const TranslateFunction *>(opt) {
145   for (const auto &kv : getTranslationRegistry()) {
146     addLiteralOption(kv.first(), &kv.second.translateFunction,
147                      kv.second.translateDescription);
148   }
149 }
150 
151 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
152                                         size_t globalWidth) const {
153   TranslationParser *tp = const_cast<TranslationParser *>(this);
154   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
155                        [](const TranslationParser::OptionInfo *lhs,
156                           const TranslationParser::OptionInfo *rhs) {
157                          return lhs->Name.compare(rhs->Name);
158                        });
159   llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
160 }
161