xref: /llvm-project/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- NVVMToLLVM.cpp - NVVM to LLVM dialect conversion -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the translation of NVVM ops that are not supported in
10 // LLVM core.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
15 
16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
21 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/IR/Value.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Support/LLVM.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 #define DEBUG_TYPE "nvvm-to-llvm"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 #define DBGSNL() (llvm::dbgs() << "\n")
33 
34 namespace mlir {
35 #define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
36 #include "mlir/Conversion/Passes.h.inc"
37 } // namespace mlir
38 
39 using namespace mlir;
40 using namespace NVVM;
41 
42 namespace {
43 
44 struct PtxLowering
45     : public OpInterfaceRewritePattern<BasicPtxBuilderInterface> {
46   using OpInterfaceRewritePattern<
47       BasicPtxBuilderInterface>::OpInterfaceRewritePattern;
48 
PtxLowering__anon308f0b0a0111::PtxLowering49   PtxLowering(MLIRContext *context, PatternBenefit benefit = 2)
50       : OpInterfaceRewritePattern(context, benefit) {}
51 
matchAndRewrite__anon308f0b0a0111::PtxLowering52   LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
53                                 PatternRewriter &rewriter) const override {
54     if (op.hasIntrinsic()) {
55       LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
56       return failure();
57     }
58 
59     SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
60     LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
61     PtxBuilder generator(op, rewriter);
62 
63     op.getAsmValues(rewriter, asmValues);
64     for (auto &[asmValue, modifier] : asmValues) {
65       LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
66       generator.insertValue(asmValue, modifier);
67     }
68 
69     generator.buildAndReplaceOp();
70     return success();
71   }
72 };
73 
74 struct ConvertNVVMToLLVMPass
75     : public impl::ConvertNVVMToLLVMPassBase<ConvertNVVMToLLVMPass> {
76   using Base::Base;
77 
getDependentDialects__anon308f0b0a0111::ConvertNVVMToLLVMPass78   void getDependentDialects(DialectRegistry &registry) const override {
79     registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
80   }
81 
runOnOperation__anon308f0b0a0111::ConvertNVVMToLLVMPass82   void runOnOperation() override {
83     ConversionTarget target(getContext());
84     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
85     RewritePatternSet pattern(&getContext());
86     mlir::populateNVVMToLLVMConversionPatterns(pattern);
87     if (failed(
88             applyPartialConversion(getOperation(), target, std::move(pattern))))
89       signalPassFailure();
90   }
91 };
92 
93 /// Implement the interface to convert NVVM to LLVM.
94 struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
95   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
loadDependentDialects__anon308f0b0a0111::NVVMToLLVMDialectInterface96   void loadDependentDialects(MLIRContext *context) const final {
97     context->loadDialect<NVVMDialect>();
98   }
99 
100   /// Hook for derived dialect interface to provide conversion patterns
101   /// and mark dialect legal for the conversion target.
populateConvertToLLVMConversionPatterns__anon308f0b0a0111::NVVMToLLVMDialectInterface102   void populateConvertToLLVMConversionPatterns(
103       ConversionTarget &target, LLVMTypeConverter &typeConverter,
104       RewritePatternSet &patterns) const final {
105     populateNVVMToLLVMConversionPatterns(patterns);
106   }
107 };
108 
109 } // namespace
110 
populateNVVMToLLVMConversionPatterns(RewritePatternSet & patterns)111 void mlir::populateNVVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
112   patterns.add<PtxLowering>(patterns.getContext());
113 }
114 
registerConvertNVVMToLLVMInterface(DialectRegistry & registry)115 void mlir::registerConvertNVVMToLLVMInterface(DialectRegistry &registry) {
116   registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
117     dialect->addInterfaces<NVVMToLLVMDialectInterface>();
118   });
119 }
120