xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp (revision 1d5e3b2d6559a853c544099e4cf1d46f44f83368)
1 //===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops  ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the Integer Dot Product operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 #include "llvm/Support/FormatVariadic.h"
19 
20 using namespace mlir::spirv::AttrNames;
21 
22 namespace mlir::spirv {
23 
24 //===----------------------------------------------------------------------===//
25 // Integer Dot Product ops
26 //===----------------------------------------------------------------------===//
27 
28 template <typename IntegerDotProductOpTy>
verifyIntegerDotProduct(Operation * op)29 static LogicalResult verifyIntegerDotProduct(Operation *op) {
30   assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
31          "Not an integer dot product op?");
32   assert(op->getNumResults() == 1 && "Expected a single result");
33 
34   // ODS enforces that vector 1 and vector 2, and result and the accumulator
35   // have the same types.
36   Type factorTy = op->getOperand(0).getType();
37   StringAttr packedVectorFormatAttrName =
38       IntegerDotProductOpTy::getFormatAttrName(op->getName());
39   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
40     auto packedVectorFormat =
41         llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
42             op->getAttr(packedVectorFormatAttrName));
43     if (!packedVectorFormat)
44       return op->emitOpError("requires Packed Vector Format attribute for "
45                              "integer vector operands");
46 
47     assert(packedVectorFormat.getValue() ==
48                spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
49            "Unknown Packed Vector Format");
50     if (intTy.getWidth() != 32)
51       return op->emitOpError(
52           llvm::formatv("with specified Packed Vector Format ({0}) requires "
53                         "integer vector operands to be 32-bits wide",
54                         packedVectorFormat.getValue()));
55   } else {
56     if (op->hasAttr(packedVectorFormatAttrName))
57       return op->emitOpError(llvm::formatv(
58           "with invalid format attribute for vector operands of type '{0}'",
59           factorTy));
60   }
61 
62   Type resultTy = op->getResultTypes().front();
63   unsigned factorBitWidth = getBitWidth(factorTy);
64   unsigned resultBitWidth = getBitWidth(resultTy);
65   if (factorBitWidth > resultBitWidth)
66     return op->emitOpError(
67         llvm::formatv("result type has insufficient bit-width ({0} bits) "
68                       "for the specified vector operand type ({1} bits)",
69                       resultBitWidth, factorBitWidth));
70 
71   return success();
72 }
73 
getIntegerDotProductMinVersion()74 static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
75   return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
76 }
77 
getIntegerDotProductMaxVersion()78 static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
79   return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
80 }
81 
82 static SmallVector<ArrayRef<spirv::Extension>, 1>
getIntegerDotProductExtensions()83 getIntegerDotProductExtensions() {
84   // Requires the SPV_KHR_integer_dot_product extension, specified either
85   // explicitly or implied by target env's SPIR-V version >= 1.6.
86   static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
87   return {extension};
88 }
89 
90 template <typename IntegerDotProductOpTy>
91 static SmallVector<ArrayRef<spirv::Capability>, 1>
getIntegerDotProductCapabilities(Operation * op)92 getIntegerDotProductCapabilities(Operation *op) {
93   // Requires the the DotProduct capability and capabilities that depend on
94   // exact op types.
95   static const auto dotProductCap = spirv::Capability::DotProduct;
96   static const auto dotProductInput4x8BitPackedCap =
97       spirv::Capability::DotProductInput4x8BitPacked;
98   static const auto dotProductInput4x8BitCap =
99       spirv::Capability::DotProductInput4x8Bit;
100   static const auto dotProductInputAllCap =
101       spirv::Capability::DotProductInputAll;
102 
103   SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
104 
105   Type factorTy = op->getOperand(0).getType();
106   StringAttr packedVectorFormatAttrName =
107       IntegerDotProductOpTy::getFormatAttrName(op->getName());
108   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
109     auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
110         op->getAttr(packedVectorFormatAttrName));
111     if (formatAttr.getValue() ==
112         spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
113       capabilities.push_back(dotProductInput4x8BitPackedCap);
114 
115     return capabilities;
116   }
117 
118   auto vecTy = llvm::cast<VectorType>(factorTy);
119   if (vecTy.getElementTypeBitWidth() == 8) {
120     capabilities.push_back(dotProductInput4x8BitCap);
121     return capabilities;
122   }
123 
124   capabilities.push_back(dotProductInputAllCap);
125   return capabilities;
126 }
127 
128 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
129   LogicalResult OpName::verify() {                                             \
130     return verifyIntegerDotProduct<OpName>(*this);                             \
131   }                                                                            \
132   SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
133     return getIntegerDotProductExtensions();                                   \
134   }                                                                            \
135   SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
136     return getIntegerDotProductCapabilities<OpName>(*this);                    \
137   }                                                                            \
138   std::optional<spirv::Version> OpName::getMinVersion() {                      \
139     return getIntegerDotProductMinVersion();                                   \
140   }                                                                            \
141   std::optional<spirv::Version> OpName::getMaxVersion() {                      \
142     return getIntegerDotProductMaxVersion();                                   \
143   }
144 
145 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
146 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotOp)
147 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotOp)
148 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotAccSatOp)
149 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotAccSatOp)
150 SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotAccSatOp)
151 
152 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
153 
154 } // namespace mlir::spirv
155