xref: /llvm-project/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRV.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
13ba66435SRiver Riddle //===- FuncToSPIRV.cpp - Func to SPIR-V Patterns ------------------===//
23ba66435SRiver Riddle //
33ba66435SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43ba66435SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53ba66435SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63ba66435SRiver Riddle //
73ba66435SRiver Riddle //===----------------------------------------------------------------------===//
83ba66435SRiver Riddle //
93ba66435SRiver Riddle // This file implements patterns to convert Func dialect to SPIR-V dialect.
103ba66435SRiver Riddle //
113ba66435SRiver Riddle //===----------------------------------------------------------------------===//
123ba66435SRiver Riddle 
133ba66435SRiver Riddle #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
143ba66435SRiver Riddle #include "../SPIRVCommon/Pattern.h"
153ba66435SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
163ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
173ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
183ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
193ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
203ba66435SRiver Riddle #include "mlir/IR/AffineMap.h"
213ba66435SRiver Riddle #include "llvm/Support/Debug.h"
223ba66435SRiver Riddle 
233ba66435SRiver Riddle #define DEBUG_TYPE "func-to-spirv-pattern"
243ba66435SRiver Riddle 
253ba66435SRiver Riddle using namespace mlir;
263ba66435SRiver Riddle 
273ba66435SRiver Riddle //===----------------------------------------------------------------------===//
283ba66435SRiver Riddle // Operation conversion
293ba66435SRiver Riddle //===----------------------------------------------------------------------===//
303ba66435SRiver Riddle 
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// generates plain RewritePatterns.
343ba66435SRiver Riddle 
353ba66435SRiver Riddle namespace {
363ba66435SRiver Riddle 
375ab6ef75SJakub Kuderski /// Converts func.return to spirv.Return.
383ba66435SRiver Riddle class ReturnOpPattern final : public OpConversionPattern<func::ReturnOp> {
393ba66435SRiver Riddle public:
403ba66435SRiver Riddle   using OpConversionPattern<func::ReturnOp>::OpConversionPattern;
413ba66435SRiver Riddle 
423ba66435SRiver Riddle   LogicalResult
433ba66435SRiver Riddle   matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
443ba66435SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
453ba66435SRiver Riddle     if (returnOp.getNumOperands() > 1)
463ba66435SRiver Riddle       return failure();
473ba66435SRiver Riddle 
483ba66435SRiver Riddle     if (returnOp.getNumOperands() == 1) {
493ba66435SRiver Riddle       rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(
503ba66435SRiver Riddle           returnOp, adaptor.getOperands()[0]);
513ba66435SRiver Riddle     } else {
523ba66435SRiver Riddle       rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
533ba66435SRiver Riddle     }
543ba66435SRiver Riddle     return success();
553ba66435SRiver Riddle   }
563ba66435SRiver Riddle };
573ba66435SRiver Riddle 
585ab6ef75SJakub Kuderski /// Converts func.call to spirv.FunctionCall.
59c0ccb692Sxndcn class CallOpPattern final : public OpConversionPattern<func::CallOp> {
60c0ccb692Sxndcn public:
61c0ccb692Sxndcn   using OpConversionPattern<func::CallOp>::OpConversionPattern;
62c0ccb692Sxndcn 
63c0ccb692Sxndcn   LogicalResult
64c0ccb692Sxndcn   matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
65c0ccb692Sxndcn                   ConversionPatternRewriter &rewriter) const override {
665ab6ef75SJakub Kuderski     // multiple results func was not converted to spirv.func
67c0ccb692Sxndcn     if (callOp.getNumResults() > 1)
68c0ccb692Sxndcn       return failure();
69c0ccb692Sxndcn     if (callOp.getNumResults() == 1) {
70c0ccb692Sxndcn       auto resultType =
71c0ccb692Sxndcn           getTypeConverter()->convertType(callOp.getResult(0).getType());
72c0ccb692Sxndcn       if (!resultType)
73c0ccb692Sxndcn         return failure();
74c0ccb692Sxndcn       rewriter.replaceOpWithNewOp<spirv::FunctionCallOp>(
75c0ccb692Sxndcn           callOp, resultType, adaptor.getOperands(), callOp->getAttrs());
76c0ccb692Sxndcn     } else {
77c0ccb692Sxndcn       rewriter.replaceOpWithNewOp<spirv::FunctionCallOp>(
78c0ccb692Sxndcn           callOp, TypeRange(), adaptor.getOperands(), callOp->getAttrs());
79c0ccb692Sxndcn     }
80c0ccb692Sxndcn     return success();
81c0ccb692Sxndcn   }
82c0ccb692Sxndcn };
83c0ccb692Sxndcn 
843ba66435SRiver Riddle } // namespace
853ba66435SRiver Riddle 
863ba66435SRiver Riddle //===----------------------------------------------------------------------===//
873ba66435SRiver Riddle // Pattern population
883ba66435SRiver Riddle //===----------------------------------------------------------------------===//
893ba66435SRiver Riddle 
90*206fad0eSMatthias Springer void mlir::populateFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
913ba66435SRiver Riddle                                        RewritePatternSet &patterns) {
923ba66435SRiver Riddle   MLIRContext *context = patterns.getContext();
933ba66435SRiver Riddle 
94c0ccb692Sxndcn   patterns.add<ReturnOpPattern, CallOpPattern>(typeConverter, context);
953ba66435SRiver Riddle }
96