13ba66435SRiver Riddle //===- FuncToSPIRV.cpp - Func to SPIR-V Patterns ------------------===// 23ba66435SRiver Riddle // 33ba66435SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43ba66435SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53ba66435SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63ba66435SRiver Riddle // 73ba66435SRiver Riddle //===----------------------------------------------------------------------===// 83ba66435SRiver Riddle // 93ba66435SRiver Riddle // This file implements patterns to convert Func dialect to SPIR-V dialect. 103ba66435SRiver Riddle // 113ba66435SRiver Riddle //===----------------------------------------------------------------------===// 123ba66435SRiver Riddle 133ba66435SRiver Riddle #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" 143ba66435SRiver Riddle #include "../SPIRVCommon/Pattern.h" 153ba66435SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 163ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 173ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 183ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 193ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 203ba66435SRiver Riddle #include "mlir/IR/AffineMap.h" 213ba66435SRiver Riddle #include "llvm/Support/Debug.h" 223ba66435SRiver Riddle 233ba66435SRiver Riddle #define DEBUG_TYPE "func-to-spirv-pattern" 243ba66435SRiver Riddle 253ba66435SRiver Riddle using namespace mlir; 263ba66435SRiver Riddle 273ba66435SRiver Riddle //===----------------------------------------------------------------------===// 283ba66435SRiver Riddle // Operation conversion 293ba66435SRiver Riddle //===----------------------------------------------------------------------===// 303ba66435SRiver Riddle 313ba66435SRiver Riddle // Note that DRR cannot be used for the patterns in this file: we may need to 323ba66435SRiver Riddle // convert type along the way, which requires ConversionPattern. DRR generates 333ba66435SRiver Riddle // normal RewritePattern. 343ba66435SRiver Riddle 353ba66435SRiver Riddle namespace { 363ba66435SRiver Riddle 375ab6ef75SJakub Kuderski /// Converts func.return to spirv.Return. 383ba66435SRiver Riddle class ReturnOpPattern final : public OpConversionPattern<func::ReturnOp> { 393ba66435SRiver Riddle public: 403ba66435SRiver Riddle using OpConversionPattern<func::ReturnOp>::OpConversionPattern; 413ba66435SRiver Riddle 423ba66435SRiver Riddle LogicalResult 433ba66435SRiver Riddle matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor, 443ba66435SRiver Riddle ConversionPatternRewriter &rewriter) const override { 453ba66435SRiver Riddle if (returnOp.getNumOperands() > 1) 463ba66435SRiver Riddle return failure(); 473ba66435SRiver Riddle 483ba66435SRiver Riddle if (returnOp.getNumOperands() == 1) { 493ba66435SRiver Riddle rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>( 503ba66435SRiver Riddle returnOp, adaptor.getOperands()[0]); 513ba66435SRiver Riddle } else { 523ba66435SRiver Riddle rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp); 533ba66435SRiver Riddle } 543ba66435SRiver Riddle return success(); 553ba66435SRiver Riddle } 563ba66435SRiver Riddle }; 573ba66435SRiver Riddle 585ab6ef75SJakub Kuderski /// Converts func.call to spirv.FunctionCall. 59c0ccb692Sxndcn class CallOpPattern final : public OpConversionPattern<func::CallOp> { 60c0ccb692Sxndcn public: 61c0ccb692Sxndcn using OpConversionPattern<func::CallOp>::OpConversionPattern; 62c0ccb692Sxndcn 63c0ccb692Sxndcn LogicalResult 64c0ccb692Sxndcn matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, 65c0ccb692Sxndcn ConversionPatternRewriter &rewriter) const override { 665ab6ef75SJakub Kuderski // multiple results func was not converted to spirv.func 67c0ccb692Sxndcn if (callOp.getNumResults() > 1) 68c0ccb692Sxndcn return failure(); 69c0ccb692Sxndcn if (callOp.getNumResults() == 1) { 70c0ccb692Sxndcn auto resultType = 71c0ccb692Sxndcn getTypeConverter()->convertType(callOp.getResult(0).getType()); 72c0ccb692Sxndcn if (!resultType) 73c0ccb692Sxndcn return failure(); 74c0ccb692Sxndcn rewriter.replaceOpWithNewOp<spirv::FunctionCallOp>( 75c0ccb692Sxndcn callOp, resultType, adaptor.getOperands(), callOp->getAttrs()); 76c0ccb692Sxndcn } else { 77c0ccb692Sxndcn rewriter.replaceOpWithNewOp<spirv::FunctionCallOp>( 78c0ccb692Sxndcn callOp, TypeRange(), adaptor.getOperands(), callOp->getAttrs()); 79c0ccb692Sxndcn } 80c0ccb692Sxndcn return success(); 81c0ccb692Sxndcn } 82c0ccb692Sxndcn }; 83c0ccb692Sxndcn 843ba66435SRiver Riddle } // namespace 853ba66435SRiver Riddle 863ba66435SRiver Riddle //===----------------------------------------------------------------------===// 873ba66435SRiver Riddle // Pattern population 883ba66435SRiver Riddle //===----------------------------------------------------------------------===// 893ba66435SRiver Riddle 90*206fad0eSMatthias Springer void mlir::populateFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, 913ba66435SRiver Riddle RewritePatternSet &patterns) { 923ba66435SRiver Riddle MLIRContext *context = patterns.getContext(); 933ba66435SRiver Riddle 94c0ccb692Sxndcn patterns.add<ReturnOpPattern, CallOpPattern>(typeConverter, context); 953ba66435SRiver Riddle } 96