14529797aSMehdi Amini //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===// 24529797aSMehdi Amini // 34529797aSMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44529797aSMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 54529797aSMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64529797aSMehdi Amini // 74529797aSMehdi Amini //===----------------------------------------------------------------------===// 84529797aSMehdi Amini 9*7498eaa9SFabian Mora #include "mlir/Analysis/DataLayoutAnalysis.h" 104529797aSMehdi Amini #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 114529797aSMehdi Amini #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" 124529797aSMehdi Amini #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13876a480cSMatthias Springer #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 144529797aSMehdi Amini #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 154529797aSMehdi Amini #include "mlir/IR/PatternMatch.h" 164529797aSMehdi Amini #include "mlir/Pass/Pass.h" 174529797aSMehdi Amini #include "mlir/Rewrite/FrozenRewritePatternSet.h" 184529797aSMehdi Amini #include "mlir/Transforms/DialectConversion.h" 194529797aSMehdi Amini #include <memory> 204529797aSMehdi Amini 214529797aSMehdi Amini #define DEBUG_TYPE "convert-to-llvm" 224529797aSMehdi Amini 234529797aSMehdi Amini namespace mlir { 244529797aSMehdi Amini #define GEN_PASS_DEF_CONVERTTOLLVMPASS 254529797aSMehdi Amini #include "mlir/Conversion/Passes.h.inc" 264529797aSMehdi Amini } // namespace mlir 274529797aSMehdi Amini 284529797aSMehdi Amini using namespace mlir; 294529797aSMehdi Amini 304529797aSMehdi Amini namespace { 31*7498eaa9SFabian Mora /// Base class for creating the internal implementation of `convert-to-llvm` 32*7498eaa9SFabian Mora /// passes. 33*7498eaa9SFabian Mora class ConvertToLLVMPassInterface { 34*7498eaa9SFabian Mora public: 35*7498eaa9SFabian Mora ConvertToLLVMPassInterface(MLIRContext *context, 36*7498eaa9SFabian Mora ArrayRef<std::string> filterDialects); 37*7498eaa9SFabian Mora virtual ~ConvertToLLVMPassInterface() = default; 38*7498eaa9SFabian Mora 39*7498eaa9SFabian Mora /// Get the dependent dialects used by `convert-to-llvm`. 40*7498eaa9SFabian Mora static void getDependentDialects(DialectRegistry ®istry); 41*7498eaa9SFabian Mora 42*7498eaa9SFabian Mora /// Initialize the internal state of the `convert-to-llvm` pass 43*7498eaa9SFabian Mora /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`. 44*7498eaa9SFabian Mora /// This method returns whether the initialization process failed. 45*7498eaa9SFabian Mora virtual LogicalResult initialize() = 0; 46*7498eaa9SFabian Mora 47*7498eaa9SFabian Mora /// Transform `op` to LLVM with the conversions available in the pass. The 48*7498eaa9SFabian Mora /// analysis manager can be used to query analyzes like `DataLayoutAnalysis` 49*7498eaa9SFabian Mora /// to further configure the conversion process. This method is invoked by 50*7498eaa9SFabian Mora /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the 51*7498eaa9SFabian Mora /// transformation process failed. 52*7498eaa9SFabian Mora virtual LogicalResult transform(Operation *op, 53*7498eaa9SFabian Mora AnalysisManager manager) const = 0; 54*7498eaa9SFabian Mora 55*7498eaa9SFabian Mora protected: 56*7498eaa9SFabian Mora /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call 57*7498eaa9SFabian Mora /// `visitor` with each of the interfaces. If `filterDialects` is non-empty, 58*7498eaa9SFabian Mora /// then `visitor` is invoked only with the dialects in the `filterDialects` 59*7498eaa9SFabian Mora /// list. 60*7498eaa9SFabian Mora LogicalResult visitInterfaces( 61*7498eaa9SFabian Mora llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor); 62*7498eaa9SFabian Mora MLIRContext *context; 63*7498eaa9SFabian Mora /// List of dialects names to use as filters. 64*7498eaa9SFabian Mora ArrayRef<std::string> filterDialects; 65*7498eaa9SFabian Mora }; 664529797aSMehdi Amini 674529797aSMehdi Amini /// This DialectExtension can be attached to the context, which will invoke the 684529797aSMehdi Amini /// `apply()` method for every loaded dialect. If a dialect implements the 694529797aSMehdi Amini /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects 704529797aSMehdi Amini /// through the interface. This extension is loaded in the context before 714529797aSMehdi Amini /// starting a pass pipeline that involves dialect conversion to LLVM. 724529797aSMehdi Amini class LoadDependentDialectExtension : public DialectExtensionBase { 734529797aSMehdi Amini public: 7484cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension) 7584cc1865SNikhil Kalra 764529797aSMehdi Amini LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {} 774529797aSMehdi Amini 784529797aSMehdi Amini void apply(MLIRContext *context, 794529797aSMehdi Amini MutableArrayRef<Dialect *> dialects) const final { 804529797aSMehdi Amini LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n"); 814529797aSMehdi Amini for (Dialect *dialect : dialects) { 828b51b625SMehdi Amini auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 834529797aSMehdi Amini if (!iface) 844529797aSMehdi Amini continue; 854529797aSMehdi Amini LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for " 864529797aSMehdi Amini << dialect->getNamespace() << "\n"); 874529797aSMehdi Amini iface->loadDependentDialects(context); 884529797aSMehdi Amini } 894529797aSMehdi Amini } 904529797aSMehdi Amini 914529797aSMehdi Amini /// Return a copy of this extension. 92d4f47f29SAdrian Kuegel std::unique_ptr<DialectExtensionBase> clone() const final { 934529797aSMehdi Amini return std::make_unique<LoadDependentDialectExtension>(*this); 944529797aSMehdi Amini } 954529797aSMehdi Amini }; 964529797aSMehdi Amini 97*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 98*7498eaa9SFabian Mora // StaticConvertToLLVM 99*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 100*7498eaa9SFabian Mora 101*7498eaa9SFabian Mora /// Static implementation of the `convert-to-llvm` pass. This version only looks 102*7498eaa9SFabian Mora /// at dialect interfaces to configure the conversion process. 103*7498eaa9SFabian Mora struct StaticConvertToLLVM : public ConvertToLLVMPassInterface { 104*7498eaa9SFabian Mora /// Pattern set with conversions to LLVM. 105*7498eaa9SFabian Mora std::shared_ptr<const FrozenRewritePatternSet> patterns; 106*7498eaa9SFabian Mora /// The conversion target. 107*7498eaa9SFabian Mora std::shared_ptr<const ConversionTarget> target; 108*7498eaa9SFabian Mora /// The LLVM type converter. 109*7498eaa9SFabian Mora std::shared_ptr<const LLVMTypeConverter> typeConverter; 110*7498eaa9SFabian Mora using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface; 111*7498eaa9SFabian Mora 112*7498eaa9SFabian Mora /// Configure the conversion to LLVM at pass initialization. 113*7498eaa9SFabian Mora LogicalResult initialize() final { 114*7498eaa9SFabian Mora auto target = std::make_shared<ConversionTarget>(*context); 115*7498eaa9SFabian Mora auto typeConverter = std::make_shared<LLVMTypeConverter>(context); 116*7498eaa9SFabian Mora RewritePatternSet tempPatterns(context); 117*7498eaa9SFabian Mora target->addLegalDialect<LLVM::LLVMDialect>(); 118*7498eaa9SFabian Mora // Populate the patterns with the dialect interface. 119*7498eaa9SFabian Mora if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) { 120*7498eaa9SFabian Mora iface->populateConvertToLLVMConversionPatterns( 121*7498eaa9SFabian Mora *target, *typeConverter, tempPatterns); 122*7498eaa9SFabian Mora }))) 123*7498eaa9SFabian Mora return failure(); 124*7498eaa9SFabian Mora this->patterns = 125*7498eaa9SFabian Mora std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns)); 126*7498eaa9SFabian Mora this->target = target; 127*7498eaa9SFabian Mora this->typeConverter = typeConverter; 128*7498eaa9SFabian Mora return success(); 129*7498eaa9SFabian Mora } 130*7498eaa9SFabian Mora 131*7498eaa9SFabian Mora /// Apply the conversion driver. 132*7498eaa9SFabian Mora LogicalResult transform(Operation *op, AnalysisManager manager) const final { 133*7498eaa9SFabian Mora if (failed(applyPartialConversion(op, *target, *patterns))) 134*7498eaa9SFabian Mora return failure(); 135*7498eaa9SFabian Mora return success(); 136*7498eaa9SFabian Mora } 137*7498eaa9SFabian Mora }; 138*7498eaa9SFabian Mora 139*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 140*7498eaa9SFabian Mora // DynamicConvertToLLVM 141*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 142*7498eaa9SFabian Mora 143*7498eaa9SFabian Mora /// Dynamic implementation of the `convert-to-llvm` pass. This version inspects 144*7498eaa9SFabian Mora /// the IR to configure the conversion to LLVM. 145*7498eaa9SFabian Mora struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface { 146*7498eaa9SFabian Mora /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used 147*7498eaa9SFabian Mora /// to partially configure the conversion process. 148*7498eaa9SFabian Mora std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>> 149*7498eaa9SFabian Mora interfaces; 150*7498eaa9SFabian Mora using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface; 151*7498eaa9SFabian Mora 152*7498eaa9SFabian Mora /// Collect the dialect interfaces used to configure the conversion process. 153*7498eaa9SFabian Mora LogicalResult initialize() final { 154*7498eaa9SFabian Mora auto interfaces = 155*7498eaa9SFabian Mora std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>(); 156*7498eaa9SFabian Mora // Collect the interfaces. 157*7498eaa9SFabian Mora if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) { 158*7498eaa9SFabian Mora interfaces->push_back(iface); 159*7498eaa9SFabian Mora }))) 160*7498eaa9SFabian Mora return failure(); 161*7498eaa9SFabian Mora this->interfaces = interfaces; 162*7498eaa9SFabian Mora return success(); 163*7498eaa9SFabian Mora } 164*7498eaa9SFabian Mora 165*7498eaa9SFabian Mora /// Configure the conversion process and apply the conversion driver. 166*7498eaa9SFabian Mora LogicalResult transform(Operation *op, AnalysisManager manager) const final { 167*7498eaa9SFabian Mora RewritePatternSet patterns(context); 168*7498eaa9SFabian Mora ConversionTarget target(*context); 169*7498eaa9SFabian Mora target.addLegalDialect<LLVM::LLVMDialect>(); 170*7498eaa9SFabian Mora // Get the data layout analysis. 171*7498eaa9SFabian Mora const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>(); 172*7498eaa9SFabian Mora LLVMTypeConverter typeConverter(context, &dlAnalysis); 173*7498eaa9SFabian Mora 174*7498eaa9SFabian Mora // Configure the conversion with dialect level interfaces. 175*7498eaa9SFabian Mora for (ConvertToLLVMPatternInterface *iface : *interfaces) 176*7498eaa9SFabian Mora iface->populateConvertToLLVMConversionPatterns(target, typeConverter, 177*7498eaa9SFabian Mora patterns); 178*7498eaa9SFabian Mora 179*7498eaa9SFabian Mora // Configure the conversion attribute interfaces. 180*7498eaa9SFabian Mora populateOpConvertToLLVMConversionPatterns(op, target, typeConverter, 181*7498eaa9SFabian Mora patterns); 182*7498eaa9SFabian Mora 183*7498eaa9SFabian Mora // Apply the conversion. 184*7498eaa9SFabian Mora if (failed(applyPartialConversion(op, target, std::move(patterns)))) 185*7498eaa9SFabian Mora return failure(); 186*7498eaa9SFabian Mora return success(); 187*7498eaa9SFabian Mora } 188*7498eaa9SFabian Mora }; 189*7498eaa9SFabian Mora 190*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 191*7498eaa9SFabian Mora // ConvertToLLVMPass 192*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 193*7498eaa9SFabian Mora 1944529797aSMehdi Amini /// This is a generic pass to convert to LLVM, it uses the 1954529797aSMehdi Amini /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects 1964529797aSMehdi Amini /// the injection of conversion patterns. 1974529797aSMehdi Amini class ConvertToLLVMPass 1984529797aSMehdi Amini : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> { 199*7498eaa9SFabian Mora std::shared_ptr<const ConvertToLLVMPassInterface> impl; 2004529797aSMehdi Amini 2014529797aSMehdi Amini public: 2024529797aSMehdi Amini using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase; 2034529797aSMehdi Amini void getDependentDialects(DialectRegistry ®istry) const final { 204*7498eaa9SFabian Mora ConvertToLLVMPassInterface::getDependentDialects(registry); 205*7498eaa9SFabian Mora } 206*7498eaa9SFabian Mora 207*7498eaa9SFabian Mora LogicalResult initialize(MLIRContext *context) final { 208*7498eaa9SFabian Mora std::shared_ptr<ConvertToLLVMPassInterface> impl; 209*7498eaa9SFabian Mora // Choose the pass implementation. 210*7498eaa9SFabian Mora if (useDynamic) 211*7498eaa9SFabian Mora impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects); 212*7498eaa9SFabian Mora else 213*7498eaa9SFabian Mora impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects); 214*7498eaa9SFabian Mora if (failed(impl->initialize())) 215*7498eaa9SFabian Mora return failure(); 216*7498eaa9SFabian Mora this->impl = impl; 217*7498eaa9SFabian Mora return success(); 218*7498eaa9SFabian Mora } 219*7498eaa9SFabian Mora 220*7498eaa9SFabian Mora void runOnOperation() final { 221*7498eaa9SFabian Mora if (failed(impl->transform(getOperation(), getAnalysisManager()))) 222*7498eaa9SFabian Mora return signalPassFailure(); 223*7498eaa9SFabian Mora } 224*7498eaa9SFabian Mora }; 225*7498eaa9SFabian Mora 226*7498eaa9SFabian Mora } // namespace 227*7498eaa9SFabian Mora 228*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 229*7498eaa9SFabian Mora // ConvertToLLVMPassInterface 230*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 231*7498eaa9SFabian Mora 232*7498eaa9SFabian Mora ConvertToLLVMPassInterface::ConvertToLLVMPassInterface( 233*7498eaa9SFabian Mora MLIRContext *context, ArrayRef<std::string> filterDialects) 234*7498eaa9SFabian Mora : context(context), filterDialects(filterDialects) {} 235*7498eaa9SFabian Mora 236*7498eaa9SFabian Mora void ConvertToLLVMPassInterface::getDependentDialects( 237*7498eaa9SFabian Mora DialectRegistry ®istry) { 2384529797aSMehdi Amini registry.insert<LLVM::LLVMDialect>(); 2394529797aSMehdi Amini registry.addExtensions<LoadDependentDialectExtension>(); 2404529797aSMehdi Amini } 2414529797aSMehdi Amini 242*7498eaa9SFabian Mora LogicalResult ConvertToLLVMPassInterface::visitInterfaces( 243*7498eaa9SFabian Mora llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor) { 2441fdbbd15SMatthias Springer if (!filterDialects.empty()) { 2451fdbbd15SMatthias Springer // Test mode: Populate only patterns from the specified dialects. Produce 2461fdbbd15SMatthias Springer // an error if the dialect is not loaded or does not implement the 2471fdbbd15SMatthias Springer // interface. 248*7498eaa9SFabian Mora for (StringRef dialectName : filterDialects) { 2491fdbbd15SMatthias Springer Dialect *dialect = context->getLoadedDialect(dialectName); 2501fdbbd15SMatthias Springer if (!dialect) 2511fdbbd15SMatthias Springer return emitError(UnknownLoc::get(context)) 2521fdbbd15SMatthias Springer << "dialect not loaded: " << dialectName << "\n"; 2538b51b625SMehdi Amini auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 2541fdbbd15SMatthias Springer if (!iface) 2551fdbbd15SMatthias Springer return emitError(UnknownLoc::get(context)) 2561fdbbd15SMatthias Springer << "dialect does not implement ConvertToLLVMPatternInterface: " 2571fdbbd15SMatthias Springer << dialectName << "\n"; 258*7498eaa9SFabian Mora visitor(iface); 2591fdbbd15SMatthias Springer } 2601fdbbd15SMatthias Springer } else { 2611fdbbd15SMatthias Springer // Normal mode: Populate all patterns from all dialects that implement the 2621fdbbd15SMatthias Springer // interface. 2634529797aSMehdi Amini for (Dialect *dialect : context->getLoadedDialects()) { 2644529797aSMehdi Amini // First time we encounter this dialect: if it implements the interface, 2654529797aSMehdi Amini // let's populate patterns ! 2668b51b625SMehdi Amini auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 2674529797aSMehdi Amini if (!iface) 2684529797aSMehdi Amini continue; 269*7498eaa9SFabian Mora visitor(iface); 2704529797aSMehdi Amini } 2711fdbbd15SMatthias Springer } 2724529797aSMehdi Amini return success(); 2734529797aSMehdi Amini } 2744529797aSMehdi Amini 275*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 276*7498eaa9SFabian Mora // API 277*7498eaa9SFabian Mora //===----------------------------------------------------------------------===// 2781fdbbd15SMatthias Springer 2799e7b6f46SMehdi Amini void mlir::registerConvertToLLVMDependentDialectLoading( 2809e7b6f46SMehdi Amini DialectRegistry ®istry) { 2819e7b6f46SMehdi Amini registry.addExtensions<LoadDependentDialectExtension>(); 2829e7b6f46SMehdi Amini } 2839e7b6f46SMehdi Amini 2841fdbbd15SMatthias Springer std::unique_ptr<Pass> mlir::createConvertToLLVMPass() { 2851fdbbd15SMatthias Springer return std::make_unique<ConvertToLLVMPass>(); 2861fdbbd15SMatthias Springer } 287