xref: /llvm-project/mlir/lib/Tools/mlir-translate/Translation.cpp (revision a1fe1f5f77d48b03b76884a9b9b91a6795193ac1)
1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Tools/mlir-translate/Translation.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/Parser/Parser.h"
19 #include "mlir/Tools/ParseUtilities.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include <optional>
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // Translation CommandLine Options
27 //===----------------------------------------------------------------------===//
28 
29 struct TranslationOptions {
30   llvm::cl::opt<bool> noImplicitModule{
31       "no-implicit-module",
32       llvm::cl::desc("Disable the parsing of an implicit top-level module op"),
33       llvm::cl::init(false)};
34 };
35 
36 static llvm::ManagedStatic<TranslationOptions> clOptions;
37 
38 void mlir::registerTranslationCLOptions() { *clOptions; }
39 
40 //===----------------------------------------------------------------------===//
41 // Translation Registry
42 //===----------------------------------------------------------------------===//
43 
44 /// Get the mutable static map between registered file-to-file MLIR
45 /// translations.
46 static llvm::StringMap<Translation> &getTranslationRegistry() {
47   static llvm::StringMap<Translation> translationBundle;
48   return translationBundle;
49 }
50 
51 /// Register the given translation.
52 static void registerTranslation(StringRef name, StringRef description,
53                                 Optional<llvm::Align> inputAlignment,
54                                 const TranslateFunction &function) {
55   auto &registry = getTranslationRegistry();
56   if (registry.count(name))
57     llvm::report_fatal_error(
58         "Attempting to overwrite an existing <file-to-file> function");
59   assert(function &&
60          "Attempting to register an empty translate <file-to-file> function");
61   registry[name] = Translation(function, description, inputAlignment);
62 }
63 
64 TranslateRegistration::TranslateRegistration(
65     StringRef name, StringRef description, const TranslateFunction &function) {
66   registerTranslation(name, description, /*inputAlignment=*/std::nullopt,
67                       function);
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // Translation to MLIR
72 //===----------------------------------------------------------------------===//
73 
74 // Puts `function` into the to-MLIR translation registry unless there is already
75 // a function registered for the same name.
76 static void registerTranslateToMLIRFunction(
77     StringRef name, StringRef description,
78     const DialectRegistrationFunction &dialectRegistration,
79     Optional<llvm::Align> inputAlignment,
80     const TranslateSourceMgrToMLIRFunction &function) {
81   auto wrappedFn = [function, dialectRegistration](
82                        const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
83                        raw_ostream &output, MLIRContext *context) {
84     DialectRegistry registry;
85     dialectRegistration(registry);
86     context->appendDialectRegistry(registry);
87     OwningOpRef<Operation *> op = function(sourceMgr, context);
88     if (!op || failed(verify(*op)))
89       return failure();
90     op.get()->print(output);
91     return success();
92   };
93   registerTranslation(name, description, inputAlignment, wrappedFn);
94 }
95 
96 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
97     StringRef name, StringRef description,
98     const TranslateSourceMgrToMLIRFunction &function,
99     const DialectRegistrationFunction &dialectRegistration,
100     Optional<llvm::Align> inputAlignment) {
101   registerTranslateToMLIRFunction(name, description, dialectRegistration,
102                                   inputAlignment, function);
103 }
104 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
105     StringRef name, StringRef description,
106     const TranslateRawSourceMgrToMLIRFunction &function,
107     const DialectRegistrationFunction &dialectRegistration,
108     Optional<llvm::Align> inputAlignment) {
109   registerTranslateToMLIRFunction(
110       name, description, dialectRegistration, inputAlignment,
111       [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
112                  MLIRContext *ctx) { return function(*sourceMgr, ctx); });
113 }
114 /// Wraps `function` with a lambda that extracts a StringRef from a source
115 /// manager and registers the wrapper lambda as a to-MLIR conversion.
116 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
117     StringRef name, StringRef description,
118     const TranslateStringRefToMLIRFunction &function,
119     const DialectRegistrationFunction &dialectRegistration,
120     Optional<llvm::Align> inputAlignment) {
121   registerTranslateToMLIRFunction(
122       name, description, dialectRegistration, inputAlignment,
123       [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
124                  MLIRContext *ctx) {
125         const llvm::MemoryBuffer *buffer =
126             sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
127         return function(buffer->getBuffer(), ctx);
128       });
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // Translation from MLIR
133 //===----------------------------------------------------------------------===//
134 
135 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
136     StringRef name, StringRef description,
137     const TranslateFromMLIRFunction &function,
138     const DialectRegistrationFunction &dialectRegistration) {
139   registerTranslation(
140       name, description, /*inputAlignment=*/std::nullopt,
141       [function,
142        dialectRegistration](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
143                             raw_ostream &output, MLIRContext *context) {
144         DialectRegistry registry;
145         dialectRegistration(registry);
146         context->appendDialectRegistry(registry);
147         bool implicitModule =
148             (!clOptions.isConstructed() || !clOptions->noImplicitModule);
149         OwningOpRef<Operation *> op =
150             parseSourceFileForTool(sourceMgr, context, implicitModule);
151         if (!op || failed(verify(*op)))
152           return failure();
153         return function(op.get(), output);
154       });
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // Translation Parser
159 //===----------------------------------------------------------------------===//
160 
161 TranslationParser::TranslationParser(llvm::cl::Option &opt)
162     : llvm::cl::parser<const Translation *>(opt) {
163   for (const auto &kv : getTranslationRegistry())
164     addLiteralOption(kv.first(), &kv.second, kv.second.getDescription());
165 }
166 
167 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
168                                         size_t globalWidth) const {
169   TranslationParser *tp = const_cast<TranslationParser *>(this);
170   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
171                        [](const TranslationParser::OptionInfo *lhs,
172                           const TranslationParser::OptionInfo *rhs) {
173                          return lhs->Name.compare(rhs->Name);
174                        });
175   llvm::cl::parser<const Translation *>::printOptionInfo(o, globalWidth);
176 }
177