xref: /llvm-project/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1c312ce0aSLei Zhang //===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
2c312ce0aSLei Zhang //
3c312ce0aSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c312ce0aSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
5c312ce0aSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c312ce0aSLei Zhang //
7c312ce0aSLei Zhang //===----------------------------------------------------------------------===//
8c312ce0aSLei Zhang //
9c312ce0aSLei Zhang // This file implements patterns to convert Complex dialect to SPIR-V dialect.
10c312ce0aSLei Zhang //
11c312ce0aSLei Zhang //===----------------------------------------------------------------------===//
12c312ce0aSLei Zhang 
13c312ce0aSLei Zhang #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
14c312ce0aSLei Zhang #include "mlir/Dialect/Complex/IR/Complex.h"
15c312ce0aSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16c312ce0aSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17c312ce0aSLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18c312ce0aSLei Zhang #include "mlir/Transforms/DialectConversion.h"
19c312ce0aSLei Zhang #include "llvm/Support/Debug.h"
20c312ce0aSLei Zhang 
21c312ce0aSLei Zhang #define DEBUG_TYPE "complex-to-spirv-pattern"
22c312ce0aSLei Zhang 
23c312ce0aSLei Zhang using namespace mlir;
24c312ce0aSLei Zhang 
25c312ce0aSLei Zhang //===----------------------------------------------------------------------===//
26c312ce0aSLei Zhang // Operation conversion
27c312ce0aSLei Zhang //===----------------------------------------------------------------------===//
28c312ce0aSLei Zhang 
29c312ce0aSLei Zhang namespace {
30c312ce0aSLei Zhang 
3152aaac63SLei Zhang struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
3252aaac63SLei Zhang   using OpConversionPattern::OpConversionPattern;
3352aaac63SLei Zhang 
3452aaac63SLei Zhang   LogicalResult
3552aaac63SLei Zhang   matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
3652aaac63SLei Zhang                   ConversionPatternRewriter &rewriter) const override {
3752aaac63SLei Zhang     auto spirvType =
3852aaac63SLei Zhang         getTypeConverter()->convertType<ShapedType>(constOp.getType());
3952aaac63SLei Zhang     if (!spirvType)
4052aaac63SLei Zhang       return rewriter.notifyMatchFailure(constOp,
4152aaac63SLei Zhang                                          "unable to convert result type");
4252aaac63SLei Zhang 
4352aaac63SLei Zhang     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
4452aaac63SLei Zhang         constOp, spirvType,
4552aaac63SLei Zhang         DenseElementsAttr::get(spirvType, constOp.getValue().getValue()));
4652aaac63SLei Zhang     return success();
4752aaac63SLei Zhang   }
4852aaac63SLei Zhang };
4952aaac63SLei Zhang 
50c312ce0aSLei Zhang struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
51c312ce0aSLei Zhang   using OpConversionPattern::OpConversionPattern;
52c312ce0aSLei Zhang 
53c312ce0aSLei Zhang   LogicalResult
54c312ce0aSLei Zhang   matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
55c312ce0aSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
56c312ce0aSLei Zhang     Type spirvType = getTypeConverter()->convertType(createOp.getType());
57c312ce0aSLei Zhang     if (!spirvType)
58c312ce0aSLei Zhang       return rewriter.notifyMatchFailure(createOp,
59c312ce0aSLei Zhang                                          "unable to convert result type");
60c312ce0aSLei Zhang 
61c312ce0aSLei Zhang     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
62c312ce0aSLei Zhang         createOp, spirvType, adaptor.getOperands());
63c312ce0aSLei Zhang     return success();
64c312ce0aSLei Zhang   }
65c312ce0aSLei Zhang };
66c312ce0aSLei Zhang 
67c312ce0aSLei Zhang struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
68c312ce0aSLei Zhang   using OpConversionPattern::OpConversionPattern;
69c312ce0aSLei Zhang 
70c312ce0aSLei Zhang   LogicalResult
71c312ce0aSLei Zhang   matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
72c312ce0aSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
73c312ce0aSLei Zhang     Type spirvType = getTypeConverter()->convertType(reOp.getType());
74c312ce0aSLei Zhang     if (!spirvType)
75c312ce0aSLei Zhang       return rewriter.notifyMatchFailure(reOp, "unable to convert result type");
76c312ce0aSLei Zhang 
77c312ce0aSLei Zhang     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
78c312ce0aSLei Zhang         reOp, adaptor.getComplex(), llvm::ArrayRef(0));
79c312ce0aSLei Zhang     return success();
80c312ce0aSLei Zhang   }
81c312ce0aSLei Zhang };
82c312ce0aSLei Zhang 
83c312ce0aSLei Zhang struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
84c312ce0aSLei Zhang   using OpConversionPattern::OpConversionPattern;
85c312ce0aSLei Zhang 
86c312ce0aSLei Zhang   LogicalResult
87c312ce0aSLei Zhang   matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
88c312ce0aSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
89c312ce0aSLei Zhang     Type spirvType = getTypeConverter()->convertType(imOp.getType());
90c312ce0aSLei Zhang     if (!spirvType)
91c312ce0aSLei Zhang       return rewriter.notifyMatchFailure(imOp, "unable to convert result type");
92c312ce0aSLei Zhang 
93c312ce0aSLei Zhang     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
94c312ce0aSLei Zhang         imOp, adaptor.getComplex(), llvm::ArrayRef(1));
95c312ce0aSLei Zhang     return success();
96c312ce0aSLei Zhang   }
97c312ce0aSLei Zhang };
98c312ce0aSLei Zhang 
99c312ce0aSLei Zhang } // namespace
100c312ce0aSLei Zhang 
101c312ce0aSLei Zhang //===----------------------------------------------------------------------===//
102c312ce0aSLei Zhang // Pattern population
103c312ce0aSLei Zhang //===----------------------------------------------------------------------===//
104c312ce0aSLei Zhang 
105*206fad0eSMatthias Springer void mlir::populateComplexToSPIRVPatterns(
106*206fad0eSMatthias Springer     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
107c312ce0aSLei Zhang   MLIRContext *context = patterns.getContext();
108c312ce0aSLei Zhang 
10952aaac63SLei Zhang   patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
11052aaac63SLei Zhang       typeConverter, context);
111c312ce0aSLei Zhang }
112