xref: /llvm-project/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp (revision 7498eaa9abf2e4ac0c10fa9a02576d708cc1b624)
1 //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataLayoutAnalysis.h"
10 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include <memory>
20 
21 #define DEBUG_TYPE "convert-to-llvm"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 /// Base class for creating the internal implementation of `convert-to-llvm`
32 /// passes.
33 class ConvertToLLVMPassInterface {
34 public:
35   ConvertToLLVMPassInterface(MLIRContext *context,
36                              ArrayRef<std::string> filterDialects);
37   virtual ~ConvertToLLVMPassInterface() = default;
38 
39   /// Get the dependent dialects used by `convert-to-llvm`.
40   static void getDependentDialects(DialectRegistry &registry);
41 
42   /// Initialize the internal state of the `convert-to-llvm` pass
43   /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
44   /// This method returns whether the initialization process failed.
45   virtual LogicalResult initialize() = 0;
46 
47   /// Transform `op` to LLVM with the conversions available in the pass. The
48   /// analysis manager can be used to query analyzes like `DataLayoutAnalysis`
49   /// to further configure the conversion process. This method is invoked by
50   /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the
51   /// transformation process failed.
52   virtual LogicalResult transform(Operation *op,
53                                   AnalysisManager manager) const = 0;
54 
55 protected:
56   /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
57   /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
58   /// then `visitor` is invoked only with the dialects in the `filterDialects`
59   /// list.
60   LogicalResult visitInterfaces(
61       llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor);
62   MLIRContext *context;
63   /// List of dialects names to use as filters.
64   ArrayRef<std::string> filterDialects;
65 };
66 
67 /// This DialectExtension can be attached to the context, which will invoke the
68 /// `apply()` method for every loaded dialect. If a dialect implements the
69 /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
70 /// through the interface. This extension is loaded in the context before
71 /// starting a pass pipeline that involves dialect conversion to LLVM.
72 class LoadDependentDialectExtension : public DialectExtensionBase {
73 public:
74   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
75 
76   LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
77 
78   void apply(MLIRContext *context,
79              MutableArrayRef<Dialect *> dialects) const final {
80     LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
81     for (Dialect *dialect : dialects) {
82       auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
83       if (!iface)
84         continue;
85       LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
86                               << dialect->getNamespace() << "\n");
87       iface->loadDependentDialects(context);
88     }
89   }
90 
91   /// Return a copy of this extension.
92   std::unique_ptr<DialectExtensionBase> clone() const final {
93     return std::make_unique<LoadDependentDialectExtension>(*this);
94   }
95 };
96 
97 //===----------------------------------------------------------------------===//
98 // StaticConvertToLLVM
99 //===----------------------------------------------------------------------===//
100 
101 /// Static implementation of the `convert-to-llvm` pass. This version only looks
102 /// at dialect interfaces to configure the conversion process.
103 struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
104   /// Pattern set with conversions to LLVM.
105   std::shared_ptr<const FrozenRewritePatternSet> patterns;
106   /// The conversion target.
107   std::shared_ptr<const ConversionTarget> target;
108   /// The LLVM type converter.
109   std::shared_ptr<const LLVMTypeConverter> typeConverter;
110   using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
111 
112   /// Configure the conversion to LLVM at pass initialization.
113   LogicalResult initialize() final {
114     auto target = std::make_shared<ConversionTarget>(*context);
115     auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
116     RewritePatternSet tempPatterns(context);
117     target->addLegalDialect<LLVM::LLVMDialect>();
118     // Populate the patterns with the dialect interface.
119     if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
120           iface->populateConvertToLLVMConversionPatterns(
121               *target, *typeConverter, tempPatterns);
122         })))
123       return failure();
124     this->patterns =
125         std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
126     this->target = target;
127     this->typeConverter = typeConverter;
128     return success();
129   }
130 
131   /// Apply the conversion driver.
132   LogicalResult transform(Operation *op, AnalysisManager manager) const final {
133     if (failed(applyPartialConversion(op, *target, *patterns)))
134       return failure();
135     return success();
136   }
137 };
138 
139 //===----------------------------------------------------------------------===//
140 // DynamicConvertToLLVM
141 //===----------------------------------------------------------------------===//
142 
143 /// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
144 /// the IR to configure the conversion to LLVM.
145 struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
146   /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
147   /// to partially configure the conversion process.
148   std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
149       interfaces;
150   using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
151 
152   /// Collect the dialect interfaces used to configure the conversion process.
153   LogicalResult initialize() final {
154     auto interfaces =
155         std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
156     // Collect the interfaces.
157     if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
158           interfaces->push_back(iface);
159         })))
160       return failure();
161     this->interfaces = interfaces;
162     return success();
163   }
164 
165   /// Configure the conversion process and apply the conversion driver.
166   LogicalResult transform(Operation *op, AnalysisManager manager) const final {
167     RewritePatternSet patterns(context);
168     ConversionTarget target(*context);
169     target.addLegalDialect<LLVM::LLVMDialect>();
170     // Get the data layout analysis.
171     const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
172     LLVMTypeConverter typeConverter(context, &dlAnalysis);
173 
174     // Configure the conversion with dialect level interfaces.
175     for (ConvertToLLVMPatternInterface *iface : *interfaces)
176       iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
177                                                      patterns);
178 
179     // Configure the conversion attribute interfaces.
180     populateOpConvertToLLVMConversionPatterns(op, target, typeConverter,
181                                               patterns);
182 
183     // Apply the conversion.
184     if (failed(applyPartialConversion(op, target, std::move(patterns))))
185       return failure();
186     return success();
187   }
188 };
189 
190 //===----------------------------------------------------------------------===//
191 // ConvertToLLVMPass
192 //===----------------------------------------------------------------------===//
193 
194 /// This is a generic pass to convert to LLVM, it uses the
195 /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
196 /// the injection of conversion patterns.
197 class ConvertToLLVMPass
198     : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
199   std::shared_ptr<const ConvertToLLVMPassInterface> impl;
200 
201 public:
202   using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
203   void getDependentDialects(DialectRegistry &registry) const final {
204     ConvertToLLVMPassInterface::getDependentDialects(registry);
205   }
206 
207   LogicalResult initialize(MLIRContext *context) final {
208     std::shared_ptr<ConvertToLLVMPassInterface> impl;
209     // Choose the pass implementation.
210     if (useDynamic)
211       impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
212     else
213       impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
214     if (failed(impl->initialize()))
215       return failure();
216     this->impl = impl;
217     return success();
218   }
219 
220   void runOnOperation() final {
221     if (failed(impl->transform(getOperation(), getAnalysisManager())))
222       return signalPassFailure();
223   }
224 };
225 
226 } // namespace
227 
228 //===----------------------------------------------------------------------===//
229 // ConvertToLLVMPassInterface
230 //===----------------------------------------------------------------------===//
231 
232 ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
233     MLIRContext *context, ArrayRef<std::string> filterDialects)
234     : context(context), filterDialects(filterDialects) {}
235 
236 void ConvertToLLVMPassInterface::getDependentDialects(
237     DialectRegistry &registry) {
238   registry.insert<LLVM::LLVMDialect>();
239   registry.addExtensions<LoadDependentDialectExtension>();
240 }
241 
242 LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
243     llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor) {
244   if (!filterDialects.empty()) {
245     // Test mode: Populate only patterns from the specified dialects. Produce
246     // an error if the dialect is not loaded or does not implement the
247     // interface.
248     for (StringRef dialectName : filterDialects) {
249       Dialect *dialect = context->getLoadedDialect(dialectName);
250       if (!dialect)
251         return emitError(UnknownLoc::get(context))
252                << "dialect not loaded: " << dialectName << "\n";
253       auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
254       if (!iface)
255         return emitError(UnknownLoc::get(context))
256                << "dialect does not implement ConvertToLLVMPatternInterface: "
257                << dialectName << "\n";
258       visitor(iface);
259     }
260   } else {
261     // Normal mode: Populate all patterns from all dialects that implement the
262     // interface.
263     for (Dialect *dialect : context->getLoadedDialects()) {
264       // First time we encounter this dialect: if it implements the interface,
265       // let's populate patterns !
266       auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
267       if (!iface)
268         continue;
269       visitor(iface);
270     }
271   }
272   return success();
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // API
277 //===----------------------------------------------------------------------===//
278 
279 void mlir::registerConvertToLLVMDependentDialectLoading(
280     DialectRegistry &registry) {
281   registry.addExtensions<LoadDependentDialectExtension>();
282 }
283 
284 std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
285   return std::make_unique<ConvertToLLVMPass>();
286 }
287