xref: /llvm-project/mlir/lib/Conversion/SPIRVCommon/Pattern.h (revision 72e8b286f03c7f6bacbec10ca9883f77d482284c)
1a54f4eaeSMogball //===- Pattern.h - SPIRV Common Conversion Patterns -----------------------===//
2a54f4eaeSMogball //
3a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information.
5a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a54f4eaeSMogball //
7a54f4eaeSMogball //===----------------------------------------------------------------------===//
8a54f4eaeSMogball 
9a54f4eaeSMogball #ifndef MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H
10a54f4eaeSMogball #define MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H
11a54f4eaeSMogball 
12a54f4eaeSMogball #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
13*72e8b286SQuinn Dawkins #include "mlir/IR/TypeUtilities.h"
14a54f4eaeSMogball #include "mlir/Transforms/DialectConversion.h"
15179978d7SJakub Kuderski #include "llvm/Support/FormatVariadic.h"
16a54f4eaeSMogball 
17a54f4eaeSMogball namespace mlir {
18a54f4eaeSMogball namespace spirv {
19a54f4eaeSMogball 
20d9edc1a5SThomas Raoux /// Converts elementwise unary, binary and ternary standard operations to SPIR-V
21d9edc1a5SThomas Raoux /// operations.
22a54f4eaeSMogball template <typename Op, typename SPIRVOp>
237f7e33c2SJakub Kuderski struct ElementwiseOpPattern : public OpConversionPattern<Op> {
24a54f4eaeSMogball   using OpConversionPattern<Op>::OpConversionPattern;
25a54f4eaeSMogball 
26a54f4eaeSMogball   LogicalResult
matchAndRewriteElementwiseOpPattern27a54f4eaeSMogball   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
28a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override {
29d9edc1a5SThomas Raoux     assert(adaptor.getOperands().size() <= 3);
30179978d7SJakub Kuderski     Type dstType = this->getTypeConverter()->convertType(op.getType());
31179978d7SJakub Kuderski     if (!dstType) {
32179978d7SJakub Kuderski       return rewriter.notifyMatchFailure(
33179978d7SJakub Kuderski           op->getLoc(),
34179978d7SJakub Kuderski           llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
35179978d7SJakub Kuderski     }
36179978d7SJakub Kuderski 
37a54f4eaeSMogball     if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
38*72e8b286SQuinn Dawkins         !getElementTypeOrSelf(op.getType()).isIndex() &&
39*72e8b286SQuinn Dawkins         dstType != op.getType()) {
40*72e8b286SQuinn Dawkins       op.dump();
41*72e8b286SQuinn Dawkins       return op.emitError("bitwidth emulation is not implemented yet on "
42*72e8b286SQuinn Dawkins                           "unsigned op pattern version");
43a54f4eaeSMogball     }
44a54f4eaeSMogball     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
45a54f4eaeSMogball                                                   adaptor.getOperands());
46a54f4eaeSMogball     return success();
47a54f4eaeSMogball   }
48a54f4eaeSMogball };
49a54f4eaeSMogball 
50be0a7e9fSMehdi Amini } // namespace spirv
51be0a7e9fSMehdi Amini } // namespace mlir
52a54f4eaeSMogball 
53a54f4eaeSMogball #endif // MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H
54