xref: /llvm-project/mlir/lib/Tools/mlir-translate/Translation.cpp (revision cf487cce6f64fe2a397d43108bfdbad01a8754fb)
1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Tools/mlir-translate/Translation.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Tools/ParseUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SourceMgr.h"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Translation CommandLine Options
26 //===----------------------------------------------------------------------===//
27 
28 struct TranslationOptions {
29   llvm::cl::opt<bool> noImplicitModule{
30       "no-implicit-module",
31       llvm::cl::desc("Disable the parsing of an implicit top-level module op"),
32       llvm::cl::init(false)};
33 };
34 
35 static llvm::ManagedStatic<TranslationOptions> clOptions;
36 
37 void mlir::registerTranslationCLOptions() { *clOptions; }
38 
39 //===----------------------------------------------------------------------===//
40 // Translation Registry
41 //===----------------------------------------------------------------------===//
42 
43 /// Get the mutable static map between registered file-to-file MLIR
44 /// translations.
45 static llvm::StringMap<Translation> &getTranslationRegistry() {
46   static llvm::StringMap<Translation> translationBundle;
47   return translationBundle;
48 }
49 
50 /// Register the given translation.
51 static void registerTranslation(StringRef name, StringRef description,
52                                 Optional<llvm::Align> inputAlignment,
53                                 const TranslateFunction &function) {
54   auto &registry = getTranslationRegistry();
55   if (registry.count(name))
56     llvm::report_fatal_error(
57         "Attempting to overwrite an existing <file-to-file> function");
58   assert(function &&
59          "Attempting to register an empty translate <file-to-file> function");
60   registry[name] = Translation(function, description, inputAlignment);
61 }
62 
63 TranslateRegistration::TranslateRegistration(
64     StringRef name, StringRef description, const TranslateFunction &function) {
65   registerTranslation(name, description, /*inputAlignment=*/std::nullopt,
66                       function);
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Translation to MLIR
71 //===----------------------------------------------------------------------===//
72 
73 // Puts `function` into the to-MLIR translation registry unless there is already
74 // a function registered for the same name.
75 static void registerTranslateToMLIRFunction(
76     StringRef name, StringRef description,
77     const DialectRegistrationFunction &dialectRegistration,
78     Optional<llvm::Align> inputAlignment,
79     const TranslateSourceMgrToMLIRFunction &function) {
80   auto wrappedFn = [function, dialectRegistration](
81                        const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
82                        raw_ostream &output, MLIRContext *context) {
83     DialectRegistry registry;
84     dialectRegistration(registry);
85     context->appendDialectRegistry(registry);
86     OwningOpRef<Operation *> op = function(sourceMgr, context);
87     if (!op || failed(verify(*op)))
88       return failure();
89     op.get()->print(output);
90     return success();
91   };
92   registerTranslation(name, description, inputAlignment, wrappedFn);
93 }
94 
95 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
96     StringRef name, StringRef description,
97     const TranslateSourceMgrToMLIRFunction &function,
98     const DialectRegistrationFunction &dialectRegistration,
99     Optional<llvm::Align> inputAlignment) {
100   registerTranslateToMLIRFunction(name, description, dialectRegistration,
101                                   inputAlignment, function);
102 }
103 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
104     StringRef name, StringRef description,
105     const TranslateRawSourceMgrToMLIRFunction &function,
106     const DialectRegistrationFunction &dialectRegistration,
107     Optional<llvm::Align> inputAlignment) {
108   registerTranslateToMLIRFunction(
109       name, description, dialectRegistration, inputAlignment,
110       [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
111                  MLIRContext *ctx) { return function(*sourceMgr, ctx); });
112 }
113 /// Wraps `function` with a lambda that extracts a StringRef from a source
114 /// manager and registers the wrapper lambda as a to-MLIR conversion.
115 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
116     StringRef name, StringRef description,
117     const TranslateStringRefToMLIRFunction &function,
118     const DialectRegistrationFunction &dialectRegistration,
119     Optional<llvm::Align> inputAlignment) {
120   registerTranslateToMLIRFunction(
121       name, description, dialectRegistration, inputAlignment,
122       [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
123                  MLIRContext *ctx) {
124         const llvm::MemoryBuffer *buffer =
125             sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
126         return function(buffer->getBuffer(), ctx);
127       });
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // Translation from MLIR
132 //===----------------------------------------------------------------------===//
133 
134 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
135     StringRef name, StringRef description,
136     const TranslateFromMLIRFunction &function,
137     const DialectRegistrationFunction &dialectRegistration) {
138   registerTranslation(
139       name, description, /*inputAlignment=*/std::nullopt,
140       [function,
141        dialectRegistration](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
142                             raw_ostream &output, MLIRContext *context) {
143         DialectRegistry registry;
144         dialectRegistration(registry);
145         context->appendDialectRegistry(registry);
146         bool implicitModule =
147             (!clOptions.isConstructed() || !clOptions->noImplicitModule);
148         OwningOpRef<Operation *> op =
149             parseSourceFileForTool(sourceMgr, context, implicitModule);
150         if (!op || failed(verify(*op)))
151           return failure();
152         return function(op.get(), output);
153       });
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // Translation Parser
158 //===----------------------------------------------------------------------===//
159 
160 TranslationParser::TranslationParser(llvm::cl::Option &opt)
161     : llvm::cl::parser<const Translation *>(opt) {
162   for (const auto &kv : getTranslationRegistry())
163     addLiteralOption(kv.first(), &kv.second, kv.second.getDescription());
164 }
165 
166 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
167                                         size_t globalWidth) const {
168   TranslationParser *tp = const_cast<TranslationParser *>(this);
169   llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
170                        [](const TranslationParser::OptionInfo *lhs,
171                           const TranslationParser::OptionInfo *rhs) {
172                          return lhs->Name.compare(rhs->Name);
173                        });
174   llvm::cl::parser<const Translation *>::printOptionInfo(o, globalWidth);
175 }
176