xref: /llvm-project/mlir/lib/Tools/mlir-translate/Translation.cpp (revision 0aea1f2f21b8b3984072dc2ea33857d077d91af2)
1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Tools/mlir-translate/Translation.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/Parser/Parser.h"
19 #include "mlir/Tools/ParseUtilities.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include <optional>
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Translation CommandLine Options
28 //===----------------------------------------------------------------------===//
29 
30 struct TranslationOptions {
31   llvm::cl::opt<bool> noImplicitModule{
32       "no-implicit-module",
33       llvm::cl::desc("Disable the parsing of an implicit top-level module op"),
34       llvm::cl::init(false)};
35 };
36 
37 static llvm::ManagedStatic<TranslationOptions> clOptions;
38 
registerTranslationCLOptions()39 void mlir::registerTranslationCLOptions() { *clOptions; }
40 
41 //===----------------------------------------------------------------------===//
42 // Translation Registry
43 //===----------------------------------------------------------------------===//
44 
45 /// Get the mutable static map between registered file-to-file MLIR
46 /// translations.
getTranslationRegistry()47 static llvm::StringMap<Translation> &getTranslationRegistry() {
48   static llvm::StringMap<Translation> translationBundle;
49   return translationBundle;
50 }
51 
52 /// Register the given translation.
registerTranslation(StringRef name,StringRef description,std::optional<llvm::Align> inputAlignment,const TranslateFunction & function)53 static void registerTranslation(StringRef name, StringRef description,
54                                 std::optional<llvm::Align> inputAlignment,
55                                 const TranslateFunction &function) {
56   auto &registry = getTranslationRegistry();
57   if (registry.count(name))
58     llvm::report_fatal_error(
59         "Attempting to overwrite an existing <file-to-file> function");
60   assert(function &&
61          "Attempting to register an empty translate <file-to-file> function");
62   registry[name] = Translation(function, description, inputAlignment);
63 }
64 
TranslateRegistration(StringRef name,StringRef description,const TranslateFunction & function)65 TranslateRegistration::TranslateRegistration(
66     StringRef name, StringRef description, const TranslateFunction &function) {
67   registerTranslation(name, description, /*inputAlignment=*/std::nullopt,
68                       function);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Translation to MLIR
73 //===----------------------------------------------------------------------===//
74 
75 // Puts `function` into the to-MLIR translation registry unless there is already
76 // a function registered for the same name.
registerTranslateToMLIRFunction(StringRef name,StringRef description,const DialectRegistrationFunction & dialectRegistration,std::optional<llvm::Align> inputAlignment,const TranslateSourceMgrToMLIRFunction & function)77 static void registerTranslateToMLIRFunction(
78     StringRef name, StringRef description,
79     const DialectRegistrationFunction &dialectRegistration,
80     std::optional<llvm::Align> inputAlignment,
81     const TranslateSourceMgrToMLIRFunction &function) {
82   auto wrappedFn = [function, dialectRegistration](
83                        const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
84                        raw_ostream &output, MLIRContext *context) {
85     DialectRegistry registry;
86     dialectRegistration(registry);
87     context->appendDialectRegistry(registry);
88     OwningOpRef<Operation *> op = function(sourceMgr, context);
89     if (!op || failed(verify(*op)))
90       return failure();
91     op.get()->print(output);
92     return success();
93   };
94   registerTranslation(name, description, inputAlignment, wrappedFn);
95 }
96 
TranslateToMLIRRegistration(StringRef name,StringRef description,const TranslateSourceMgrToMLIRFunction & function,const DialectRegistrationFunction & dialectRegistration,std::optional<llvm::Align> inputAlignment)97 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
98     StringRef name, StringRef description,
99     const TranslateSourceMgrToMLIRFunction &function,
100     const DialectRegistrationFunction &dialectRegistration,
101     std::optional<llvm::Align> inputAlignment) {
102   registerTranslateToMLIRFunction(name, description, dialectRegistration,
103                                   inputAlignment, function);
104 }
TranslateToMLIRRegistration(StringRef name,StringRef description,const TranslateRawSourceMgrToMLIRFunction & function,const DialectRegistrationFunction & dialectRegistration,std::optional<llvm::Align> inputAlignment)105 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
106     StringRef name, StringRef description,
107     const TranslateRawSourceMgrToMLIRFunction &function,
108     const DialectRegistrationFunction &dialectRegistration,
109     std::optional<llvm::Align> inputAlignment) {
110   registerTranslateToMLIRFunction(
111       name, description, dialectRegistration, inputAlignment,
112       [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
113                  MLIRContext *ctx) { return function(*sourceMgr, ctx); });
114 }
115 /// Wraps `function` with a lambda that extracts a StringRef from a source
116 /// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration(StringRef name,StringRef description,const TranslateStringRefToMLIRFunction & function,const DialectRegistrationFunction & dialectRegistration,std::optional<llvm::Align> inputAlignment)117 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
118     StringRef name, StringRef description,
119     const TranslateStringRefToMLIRFunction &function,
120     const DialectRegistrationFunction &dialectRegistration,
121     std::optional<llvm::Align> inputAlignment) {
122   registerTranslateToMLIRFunction(
123       name, description, dialectRegistration, inputAlignment,
124       [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
125                  MLIRContext *ctx) {
126         const llvm::MemoryBuffer *buffer =
127             sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
128         return function(buffer->getBuffer(), ctx);
129       });
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // Translation from MLIR
134 //===----------------------------------------------------------------------===//
135 
TranslateFromMLIRRegistration(StringRef name,StringRef description,const TranslateFromMLIRFunction & function,const DialectRegistrationFunction & dialectRegistration)136 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
137     StringRef name, StringRef description,
138     const TranslateFromMLIRFunction &function,
139     const DialectRegistrationFunction &dialectRegistration) {
140   registerTranslation(
141       name, description, /*inputAlignment=*/std::nullopt,
142       [function,
143        dialectRegistration](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
144                             raw_ostream &output, MLIRContext *context) {
145         DialectRegistry registry;
146         dialectRegistration(registry);
147         context->appendDialectRegistry(registry);
148         bool implicitModule =
149             (!clOptions.isConstructed() || !clOptions->noImplicitModule);
150         OwningOpRef<Operation *> op =
151             parseSourceFileForTool(sourceMgr, context, implicitModule);
152         if (!op || failed(verify(*op)))
153           return failure();
154         return function(op.get(), output);
155       });
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // Translation Parser
160 //===----------------------------------------------------------------------===//
161 
TranslationParser(llvm::cl::Option & opt)162 TranslationParser::TranslationParser(llvm::cl::Option &opt)
163     : llvm::cl::parser<const Translation *>(opt) {
164   for (const auto &kv : getTranslationRegistry())
165     addLiteralOption(kv.first(), &kv.second, kv.second.getDescription());
166 }
167 
printOptionInfo(const llvm::cl::Option & o,size_t globalWidth) const168 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
169                                         size_t globalWidth) const {
170   TranslationParser *tp = const_cast<TranslationParser *>(this);
171   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
172                        [](const TranslationParser::OptionInfo *lhs,
173                           const TranslationParser::OptionInfo *rhs) {
174                          return lhs->Name.compare(rhs->Name);
175                        });
176   llvm::cl::parser<const Translation *>::printOptionInfo(o, globalWidth);
177 }
178