xref: /llvm-project/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Complex dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
14 #include "mlir/Dialect/Complex/IR/Complex.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "complex-to-spirv-pattern"
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // Operation conversion
27 //===----------------------------------------------------------------------===//
28 
29 namespace {
30 
31 struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
32   using OpConversionPattern::OpConversionPattern;
33 
34   LogicalResult
35   matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
36                   ConversionPatternRewriter &rewriter) const override {
37     auto spirvType =
38         getTypeConverter()->convertType<ShapedType>(constOp.getType());
39     if (!spirvType)
40       return rewriter.notifyMatchFailure(constOp,
41                                          "unable to convert result type");
42 
43     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
44         constOp, spirvType,
45         DenseElementsAttr::get(spirvType, constOp.getValue().getValue()));
46     return success();
47   }
48 };
49 
50 struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
51   using OpConversionPattern::OpConversionPattern;
52 
53   LogicalResult
54   matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
55                   ConversionPatternRewriter &rewriter) const override {
56     Type spirvType = getTypeConverter()->convertType(createOp.getType());
57     if (!spirvType)
58       return rewriter.notifyMatchFailure(createOp,
59                                          "unable to convert result type");
60 
61     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
62         createOp, spirvType, adaptor.getOperands());
63     return success();
64   }
65 };
66 
67 struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
68   using OpConversionPattern::OpConversionPattern;
69 
70   LogicalResult
71   matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
72                   ConversionPatternRewriter &rewriter) const override {
73     Type spirvType = getTypeConverter()->convertType(reOp.getType());
74     if (!spirvType)
75       return rewriter.notifyMatchFailure(reOp, "unable to convert result type");
76 
77     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
78         reOp, adaptor.getComplex(), llvm::ArrayRef(0));
79     return success();
80   }
81 };
82 
83 struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
84   using OpConversionPattern::OpConversionPattern;
85 
86   LogicalResult
87   matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
88                   ConversionPatternRewriter &rewriter) const override {
89     Type spirvType = getTypeConverter()->convertType(imOp.getType());
90     if (!spirvType)
91       return rewriter.notifyMatchFailure(imOp, "unable to convert result type");
92 
93     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
94         imOp, adaptor.getComplex(), llvm::ArrayRef(1));
95     return success();
96   }
97 };
98 
99 } // namespace
100 
101 //===----------------------------------------------------------------------===//
102 // Pattern population
103 //===----------------------------------------------------------------------===//
104 
105 void mlir::populateComplexToSPIRVPatterns(
106     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
107   MLIRContext *context = patterns.getContext();
108 
109   patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
110       typeConverter, context);
111 }
112