1c312ce0aSLei Zhang //===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===// 2c312ce0aSLei Zhang // 3c312ce0aSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4c312ce0aSLei Zhang // See https://llvm.org/LICENSE.txt for license information. 5c312ce0aSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6c312ce0aSLei Zhang // 7c312ce0aSLei Zhang //===----------------------------------------------------------------------===// 8c312ce0aSLei Zhang // 9c312ce0aSLei Zhang // This file implements patterns to convert Complex dialect to SPIR-V dialect. 10c312ce0aSLei Zhang // 11c312ce0aSLei Zhang //===----------------------------------------------------------------------===// 12c312ce0aSLei Zhang 13c312ce0aSLei Zhang #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h" 14c312ce0aSLei Zhang #include "mlir/Dialect/Complex/IR/Complex.h" 15c312ce0aSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16c312ce0aSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17c312ce0aSLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18c312ce0aSLei Zhang #include "mlir/Transforms/DialectConversion.h" 19c312ce0aSLei Zhang #include "llvm/Support/Debug.h" 20c312ce0aSLei Zhang 21c312ce0aSLei Zhang #define DEBUG_TYPE "complex-to-spirv-pattern" 22c312ce0aSLei Zhang 23c312ce0aSLei Zhang using namespace mlir; 24c312ce0aSLei Zhang 25c312ce0aSLei Zhang //===----------------------------------------------------------------------===// 26c312ce0aSLei Zhang // Operation conversion 27c312ce0aSLei Zhang //===----------------------------------------------------------------------===// 28c312ce0aSLei Zhang 29c312ce0aSLei Zhang namespace { 30c312ce0aSLei Zhang 3152aaac63SLei Zhang struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> { 3252aaac63SLei Zhang using OpConversionPattern::OpConversionPattern; 3352aaac63SLei Zhang 3452aaac63SLei Zhang LogicalResult 3552aaac63SLei Zhang matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor, 3652aaac63SLei Zhang ConversionPatternRewriter &rewriter) const override { 3752aaac63SLei Zhang auto spirvType = 3852aaac63SLei Zhang getTypeConverter()->convertType<ShapedType>(constOp.getType()); 3952aaac63SLei Zhang if (!spirvType) 4052aaac63SLei Zhang return rewriter.notifyMatchFailure(constOp, 4152aaac63SLei Zhang "unable to convert result type"); 4252aaac63SLei Zhang 4352aaac63SLei Zhang rewriter.replaceOpWithNewOp<spirv::ConstantOp>( 4452aaac63SLei Zhang constOp, spirvType, 4552aaac63SLei Zhang DenseElementsAttr::get(spirvType, constOp.getValue().getValue())); 4652aaac63SLei Zhang return success(); 4752aaac63SLei Zhang } 4852aaac63SLei Zhang }; 4952aaac63SLei Zhang 50c312ce0aSLei Zhang struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> { 51c312ce0aSLei Zhang using OpConversionPattern::OpConversionPattern; 52c312ce0aSLei Zhang 53c312ce0aSLei Zhang LogicalResult 54c312ce0aSLei Zhang matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor, 55c312ce0aSLei Zhang ConversionPatternRewriter &rewriter) const override { 56c312ce0aSLei Zhang Type spirvType = getTypeConverter()->convertType(createOp.getType()); 57c312ce0aSLei Zhang if (!spirvType) 58c312ce0aSLei Zhang return rewriter.notifyMatchFailure(createOp, 59c312ce0aSLei Zhang "unable to convert result type"); 60c312ce0aSLei Zhang 61c312ce0aSLei Zhang rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 62c312ce0aSLei Zhang createOp, spirvType, adaptor.getOperands()); 63c312ce0aSLei Zhang return success(); 64c312ce0aSLei Zhang } 65c312ce0aSLei Zhang }; 66c312ce0aSLei Zhang 67c312ce0aSLei Zhang struct ReOpPattern final : OpConversionPattern<complex::ReOp> { 68c312ce0aSLei Zhang using OpConversionPattern::OpConversionPattern; 69c312ce0aSLei Zhang 70c312ce0aSLei Zhang LogicalResult 71c312ce0aSLei Zhang matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor, 72c312ce0aSLei Zhang ConversionPatternRewriter &rewriter) const override { 73c312ce0aSLei Zhang Type spirvType = getTypeConverter()->convertType(reOp.getType()); 74c312ce0aSLei Zhang if (!spirvType) 75c312ce0aSLei Zhang return rewriter.notifyMatchFailure(reOp, "unable to convert result type"); 76c312ce0aSLei Zhang 77c312ce0aSLei Zhang rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 78c312ce0aSLei Zhang reOp, adaptor.getComplex(), llvm::ArrayRef(0)); 79c312ce0aSLei Zhang return success(); 80c312ce0aSLei Zhang } 81c312ce0aSLei Zhang }; 82c312ce0aSLei Zhang 83c312ce0aSLei Zhang struct ImOpPattern final : OpConversionPattern<complex::ImOp> { 84c312ce0aSLei Zhang using OpConversionPattern::OpConversionPattern; 85c312ce0aSLei Zhang 86c312ce0aSLei Zhang LogicalResult 87c312ce0aSLei Zhang matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor, 88c312ce0aSLei Zhang ConversionPatternRewriter &rewriter) const override { 89c312ce0aSLei Zhang Type spirvType = getTypeConverter()->convertType(imOp.getType()); 90c312ce0aSLei Zhang if (!spirvType) 91c312ce0aSLei Zhang return rewriter.notifyMatchFailure(imOp, "unable to convert result type"); 92c312ce0aSLei Zhang 93c312ce0aSLei Zhang rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 94c312ce0aSLei Zhang imOp, adaptor.getComplex(), llvm::ArrayRef(1)); 95c312ce0aSLei Zhang return success(); 96c312ce0aSLei Zhang } 97c312ce0aSLei Zhang }; 98c312ce0aSLei Zhang 99c312ce0aSLei Zhang } // namespace 100c312ce0aSLei Zhang 101c312ce0aSLei Zhang //===----------------------------------------------------------------------===// 102c312ce0aSLei Zhang // Pattern population 103c312ce0aSLei Zhang //===----------------------------------------------------------------------===// 104c312ce0aSLei Zhang 105*206fad0eSMatthias Springer void mlir::populateComplexToSPIRVPatterns( 106*206fad0eSMatthias Springer const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 107c312ce0aSLei Zhang MLIRContext *context = patterns.getContext(); 108c312ce0aSLei Zhang 10952aaac63SLei Zhang patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>( 11052aaac63SLei Zhang typeConverter, context); 111c312ce0aSLei Zhang } 112