1 //===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Complex dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h" 14 #include "mlir/Dialect/Complex/IR/Complex.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "complex-to-spirv-pattern" 22 23 using namespace mlir; 24 25 //===----------------------------------------------------------------------===// 26 // Operation conversion 27 //===----------------------------------------------------------------------===// 28 29 namespace { 30 31 struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> { 32 using OpConversionPattern::OpConversionPattern; 33 34 LogicalResult 35 matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor, 36 ConversionPatternRewriter &rewriter) const override { 37 auto spirvType = 38 getTypeConverter()->convertType<ShapedType>(constOp.getType()); 39 if (!spirvType) 40 return rewriter.notifyMatchFailure(constOp, 41 "unable to convert result type"); 42 43 rewriter.replaceOpWithNewOp<spirv::ConstantOp>( 44 constOp, spirvType, 45 DenseElementsAttr::get(spirvType, constOp.getValue().getValue())); 46 return success(); 47 } 48 }; 49 50 struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> { 51 using OpConversionPattern::OpConversionPattern; 52 53 LogicalResult 54 matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor, 55 ConversionPatternRewriter &rewriter) const override { 56 Type spirvType = getTypeConverter()->convertType(createOp.getType()); 57 if (!spirvType) 58 return rewriter.notifyMatchFailure(createOp, 59 "unable to convert result type"); 60 61 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 62 createOp, spirvType, adaptor.getOperands()); 63 return success(); 64 } 65 }; 66 67 struct ReOpPattern final : OpConversionPattern<complex::ReOp> { 68 using OpConversionPattern::OpConversionPattern; 69 70 LogicalResult 71 matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor, 72 ConversionPatternRewriter &rewriter) const override { 73 Type spirvType = getTypeConverter()->convertType(reOp.getType()); 74 if (!spirvType) 75 return rewriter.notifyMatchFailure(reOp, "unable to convert result type"); 76 77 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 78 reOp, adaptor.getComplex(), llvm::ArrayRef(0)); 79 return success(); 80 } 81 }; 82 83 struct ImOpPattern final : OpConversionPattern<complex::ImOp> { 84 using OpConversionPattern::OpConversionPattern; 85 86 LogicalResult 87 matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor, 88 ConversionPatternRewriter &rewriter) const override { 89 Type spirvType = getTypeConverter()->convertType(imOp.getType()); 90 if (!spirvType) 91 return rewriter.notifyMatchFailure(imOp, "unable to convert result type"); 92 93 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 94 imOp, adaptor.getComplex(), llvm::ArrayRef(1)); 95 return success(); 96 } 97 }; 98 99 } // namespace 100 101 //===----------------------------------------------------------------------===// 102 // Pattern population 103 //===----------------------------------------------------------------------===// 104 105 void mlir::populateComplexToSPIRVPatterns( 106 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 107 MLIRContext *context = patterns.getContext(); 108 109 patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>( 110 typeConverter, context); 111 } 112