1 //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataLayoutAnalysis.h" 10 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" 12 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include <memory> 20 21 #define DEBUG_TYPE "convert-to-llvm" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_CONVERTTOLLVMPASS 25 #include "mlir/Conversion/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 30 namespace { 31 /// Base class for creating the internal implementation of `convert-to-llvm` 32 /// passes. 33 class ConvertToLLVMPassInterface { 34 public: 35 ConvertToLLVMPassInterface(MLIRContext *context, 36 ArrayRef<std::string> filterDialects); 37 virtual ~ConvertToLLVMPassInterface() = default; 38 39 /// Get the dependent dialects used by `convert-to-llvm`. 40 static void getDependentDialects(DialectRegistry ®istry); 41 42 /// Initialize the internal state of the `convert-to-llvm` pass 43 /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`. 44 /// This method returns whether the initialization process failed. 45 virtual LogicalResult initialize() = 0; 46 47 /// Transform `op` to LLVM with the conversions available in the pass. The 48 /// analysis manager can be used to query analyzes like `DataLayoutAnalysis` 49 /// to further configure the conversion process. This method is invoked by 50 /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the 51 /// transformation process failed. 52 virtual LogicalResult transform(Operation *op, 53 AnalysisManager manager) const = 0; 54 55 protected: 56 /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call 57 /// `visitor` with each of the interfaces. If `filterDialects` is non-empty, 58 /// then `visitor` is invoked only with the dialects in the `filterDialects` 59 /// list. 60 LogicalResult visitInterfaces( 61 llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor); 62 MLIRContext *context; 63 /// List of dialects names to use as filters. 64 ArrayRef<std::string> filterDialects; 65 }; 66 67 /// This DialectExtension can be attached to the context, which will invoke the 68 /// `apply()` method for every loaded dialect. If a dialect implements the 69 /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects 70 /// through the interface. This extension is loaded in the context before 71 /// starting a pass pipeline that involves dialect conversion to LLVM. 72 class LoadDependentDialectExtension : public DialectExtensionBase { 73 public: 74 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension) 75 76 LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {} 77 78 void apply(MLIRContext *context, 79 MutableArrayRef<Dialect *> dialects) const final { 80 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n"); 81 for (Dialect *dialect : dialects) { 82 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 83 if (!iface) 84 continue; 85 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for " 86 << dialect->getNamespace() << "\n"); 87 iface->loadDependentDialects(context); 88 } 89 } 90 91 /// Return a copy of this extension. 92 std::unique_ptr<DialectExtensionBase> clone() const final { 93 return std::make_unique<LoadDependentDialectExtension>(*this); 94 } 95 }; 96 97 //===----------------------------------------------------------------------===// 98 // StaticConvertToLLVM 99 //===----------------------------------------------------------------------===// 100 101 /// Static implementation of the `convert-to-llvm` pass. This version only looks 102 /// at dialect interfaces to configure the conversion process. 103 struct StaticConvertToLLVM : public ConvertToLLVMPassInterface { 104 /// Pattern set with conversions to LLVM. 105 std::shared_ptr<const FrozenRewritePatternSet> patterns; 106 /// The conversion target. 107 std::shared_ptr<const ConversionTarget> target; 108 /// The LLVM type converter. 109 std::shared_ptr<const LLVMTypeConverter> typeConverter; 110 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface; 111 112 /// Configure the conversion to LLVM at pass initialization. 113 LogicalResult initialize() final { 114 auto target = std::make_shared<ConversionTarget>(*context); 115 auto typeConverter = std::make_shared<LLVMTypeConverter>(context); 116 RewritePatternSet tempPatterns(context); 117 target->addLegalDialect<LLVM::LLVMDialect>(); 118 // Populate the patterns with the dialect interface. 119 if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) { 120 iface->populateConvertToLLVMConversionPatterns( 121 *target, *typeConverter, tempPatterns); 122 }))) 123 return failure(); 124 this->patterns = 125 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns)); 126 this->target = target; 127 this->typeConverter = typeConverter; 128 return success(); 129 } 130 131 /// Apply the conversion driver. 132 LogicalResult transform(Operation *op, AnalysisManager manager) const final { 133 if (failed(applyPartialConversion(op, *target, *patterns))) 134 return failure(); 135 return success(); 136 } 137 }; 138 139 //===----------------------------------------------------------------------===// 140 // DynamicConvertToLLVM 141 //===----------------------------------------------------------------------===// 142 143 /// Dynamic implementation of the `convert-to-llvm` pass. This version inspects 144 /// the IR to configure the conversion to LLVM. 145 struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface { 146 /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used 147 /// to partially configure the conversion process. 148 std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>> 149 interfaces; 150 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface; 151 152 /// Collect the dialect interfaces used to configure the conversion process. 153 LogicalResult initialize() final { 154 auto interfaces = 155 std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>(); 156 // Collect the interfaces. 157 if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) { 158 interfaces->push_back(iface); 159 }))) 160 return failure(); 161 this->interfaces = interfaces; 162 return success(); 163 } 164 165 /// Configure the conversion process and apply the conversion driver. 166 LogicalResult transform(Operation *op, AnalysisManager manager) const final { 167 RewritePatternSet patterns(context); 168 ConversionTarget target(*context); 169 target.addLegalDialect<LLVM::LLVMDialect>(); 170 // Get the data layout analysis. 171 const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>(); 172 LLVMTypeConverter typeConverter(context, &dlAnalysis); 173 174 // Configure the conversion with dialect level interfaces. 175 for (ConvertToLLVMPatternInterface *iface : *interfaces) 176 iface->populateConvertToLLVMConversionPatterns(target, typeConverter, 177 patterns); 178 179 // Configure the conversion attribute interfaces. 180 populateOpConvertToLLVMConversionPatterns(op, target, typeConverter, 181 patterns); 182 183 // Apply the conversion. 184 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 185 return failure(); 186 return success(); 187 } 188 }; 189 190 //===----------------------------------------------------------------------===// 191 // ConvertToLLVMPass 192 //===----------------------------------------------------------------------===// 193 194 /// This is a generic pass to convert to LLVM, it uses the 195 /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects 196 /// the injection of conversion patterns. 197 class ConvertToLLVMPass 198 : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> { 199 std::shared_ptr<const ConvertToLLVMPassInterface> impl; 200 201 public: 202 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase; 203 void getDependentDialects(DialectRegistry ®istry) const final { 204 ConvertToLLVMPassInterface::getDependentDialects(registry); 205 } 206 207 LogicalResult initialize(MLIRContext *context) final { 208 std::shared_ptr<ConvertToLLVMPassInterface> impl; 209 // Choose the pass implementation. 210 if (useDynamic) 211 impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects); 212 else 213 impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects); 214 if (failed(impl->initialize())) 215 return failure(); 216 this->impl = impl; 217 return success(); 218 } 219 220 void runOnOperation() final { 221 if (failed(impl->transform(getOperation(), getAnalysisManager()))) 222 return signalPassFailure(); 223 } 224 }; 225 226 } // namespace 227 228 //===----------------------------------------------------------------------===// 229 // ConvertToLLVMPassInterface 230 //===----------------------------------------------------------------------===// 231 232 ConvertToLLVMPassInterface::ConvertToLLVMPassInterface( 233 MLIRContext *context, ArrayRef<std::string> filterDialects) 234 : context(context), filterDialects(filterDialects) {} 235 236 void ConvertToLLVMPassInterface::getDependentDialects( 237 DialectRegistry ®istry) { 238 registry.insert<LLVM::LLVMDialect>(); 239 registry.addExtensions<LoadDependentDialectExtension>(); 240 } 241 242 LogicalResult ConvertToLLVMPassInterface::visitInterfaces( 243 llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor) { 244 if (!filterDialects.empty()) { 245 // Test mode: Populate only patterns from the specified dialects. Produce 246 // an error if the dialect is not loaded or does not implement the 247 // interface. 248 for (StringRef dialectName : filterDialects) { 249 Dialect *dialect = context->getLoadedDialect(dialectName); 250 if (!dialect) 251 return emitError(UnknownLoc::get(context)) 252 << "dialect not loaded: " << dialectName << "\n"; 253 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 254 if (!iface) 255 return emitError(UnknownLoc::get(context)) 256 << "dialect does not implement ConvertToLLVMPatternInterface: " 257 << dialectName << "\n"; 258 visitor(iface); 259 } 260 } else { 261 // Normal mode: Populate all patterns from all dialects that implement the 262 // interface. 263 for (Dialect *dialect : context->getLoadedDialects()) { 264 // First time we encounter this dialect: if it implements the interface, 265 // let's populate patterns ! 266 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); 267 if (!iface) 268 continue; 269 visitor(iface); 270 } 271 } 272 return success(); 273 } 274 275 //===----------------------------------------------------------------------===// 276 // API 277 //===----------------------------------------------------------------------===// 278 279 void mlir::registerConvertToLLVMDependentDialectLoading( 280 DialectRegistry ®istry) { 281 registry.addExtensions<LoadDependentDialectExtension>(); 282 } 283 284 std::unique_ptr<Pass> mlir::createConvertToLLVMPass() { 285 return std::make_unique<ConvertToLLVMPass>(); 286 } 287