xref: /llvm-project/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- UBToSPIRV.cpp - UB to SPIRV-V dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
10 
11 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13 #include "mlir/Dialect/UB/IR/UBOps.h"
14 #include "mlir/Pass/Pass.h"
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_UBTOSPIRVCONVERSIONPASS
18 #include "mlir/Conversion/Passes.h.inc"
19 } // namespace mlir
20 
21 using namespace mlir;
22 
23 namespace {
24 
25 struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
26   using OpConversionPattern::OpConversionPattern;
27 
28   LogicalResult
29   matchAndRewrite(ub::PoisonOp op, OpAdaptor,
30                   ConversionPatternRewriter &rewriter) const override {
31     Type origType = op.getType();
32     if (!origType.isIntOrIndexOrFloat())
33       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
34         diag << "unsupported type " << origType;
35       });
36 
37     Type resType = getTypeConverter()->convertType(origType);
38     if (!resType)
39       return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
40         diag << "failed to convert result type " << origType;
41       });
42 
43     rewriter.replaceOpWithNewOp<spirv::UndefOp>(op, resType);
44     return success();
45   }
46 };
47 
48 } // namespace
49 
50 //===----------------------------------------------------------------------===//
51 // Pass Definition
52 //===----------------------------------------------------------------------===//
53 
54 namespace {
55 struct UBToSPIRVConversionPass final
56     : impl::UBToSPIRVConversionPassBase<UBToSPIRVConversionPass> {
57   using Base::Base;
58 
59   void runOnOperation() override {
60     Operation *op = getOperation();
61     spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
62     std::unique_ptr<SPIRVConversionTarget> target =
63         SPIRVConversionTarget::get(targetAttr);
64 
65     SPIRVConversionOptions options;
66     SPIRVTypeConverter typeConverter(targetAttr, options);
67 
68     RewritePatternSet patterns(&getContext());
69     ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
70 
71     if (failed(applyPartialConversion(op, *target, std::move(patterns))))
72       signalPassFailure();
73   }
74 };
75 } // namespace
76 
77 //===----------------------------------------------------------------------===//
78 // Pattern Population
79 //===----------------------------------------------------------------------===//
80 
81 void mlir::ub::populateUBToSPIRVConversionPatterns(
82     const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
83   patterns.add<PoisonOpLowering>(converter, patterns.getContext());
84 }
85