xref: /llvm-project/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1f1f05a91SKrzysztof Drewniak //===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2f1f05a91SKrzysztof Drewniak //
3f1f05a91SKrzysztof Drewniak // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f1f05a91SKrzysztof Drewniak // See https://llvm.org/LICENSE.txt for license information.
5f1f05a91SKrzysztof Drewniak // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f1f05a91SKrzysztof Drewniak //
7f1f05a91SKrzysztof Drewniak //===----------------------------------------------------------------------===//
8f1f05a91SKrzysztof Drewniak //
9f1f05a91SKrzysztof Drewniak // This file implements the AMDGPU dialect and its operations.
10f1f05a91SKrzysztof Drewniak //
11f1f05a91SKrzysztof Drewniak //===----------------------------------------------------------------------===//
12f1f05a91SKrzysztof Drewniak 
13cc470374SKrzysztof Drewniak #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
14c55b41d5SKrzysztof Drewniak 
15d6abdf46SKrzysztof Drewniak #include "mlir/Dialect/Arith/IR/Arith.h"
16499abb24SKrzysztof Drewniak #include "mlir/Dialect/GPU/IR/GPUDialect.h"
171387ba48SGiuseppe Rossini #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
18f1f05a91SKrzysztof Drewniak #include "mlir/IR/Builders.h"
19c55b41d5SKrzysztof Drewniak #include "mlir/IR/BuiltinTypes.h"
20c55b41d5SKrzysztof Drewniak #include "mlir/IR/Diagnostics.h"
21c55b41d5SKrzysztof Drewniak #include "mlir/IR/DialectImplementation.h"
22d6abdf46SKrzysztof Drewniak #include "mlir/IR/Matchers.h"
23f1f05a91SKrzysztof Drewniak #include "mlir/IR/OpImplementation.h"
24d6abdf46SKrzysztof Drewniak #include "mlir/IR/PatternMatch.h"
25f1f05a91SKrzysztof Drewniak #include "mlir/IR/TypeUtilities.h"
26c55b41d5SKrzysztof Drewniak #include "llvm/ADT/TypeSwitch.h"
27f1f05a91SKrzysztof Drewniak 
28d6abdf46SKrzysztof Drewniak #include <limits>
29a1fe1f5fSKazu Hirata #include <optional>
30d6abdf46SKrzysztof Drewniak 
31f1f05a91SKrzysztof Drewniak using namespace mlir;
32c55b41d5SKrzysztof Drewniak using namespace mlir::amdgpu;
33f1f05a91SKrzysztof Drewniak 
34cc470374SKrzysztof Drewniak #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
35f1f05a91SKrzysztof Drewniak 
36c55b41d5SKrzysztof Drewniak void AMDGPUDialect::initialize() {
37f1f05a91SKrzysztof Drewniak   addOperations<
38f1f05a91SKrzysztof Drewniak #define GET_OP_LIST
39cc470374SKrzysztof Drewniak #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
40f1f05a91SKrzysztof Drewniak       >();
41c55b41d5SKrzysztof Drewniak   addAttributes<
42c55b41d5SKrzysztof Drewniak #define GET_ATTRDEF_LIST
43cc470374SKrzysztof Drewniak #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
44c55b41d5SKrzysztof Drewniak       >();
45f1f05a91SKrzysztof Drewniak }
46f1f05a91SKrzysztof Drewniak 
47f1f05a91SKrzysztof Drewniak //===----------------------------------------------------------------------===//
482ebd633fSKrzysztof Drewniak // 8-bit float ops
492ebd633fSKrzysztof Drewniak //===----------------------------------------------------------------------===//
502ebd633fSKrzysztof Drewniak LogicalResult PackedTrunc2xFp8Op::verify() {
512ebd633fSKrzysztof Drewniak   if (getExisting() && getExisting().getType() != getResult().getType())
522ebd633fSKrzysztof Drewniak     return emitOpError("existing values must have same type as result");
532ebd633fSKrzysztof Drewniak   return success();
542ebd633fSKrzysztof Drewniak }
552ebd633fSKrzysztof Drewniak 
562ebd633fSKrzysztof Drewniak LogicalResult PackedStochRoundFp8Op::verify() {
572ebd633fSKrzysztof Drewniak   if (getExisting() && getExisting().getType() != getResult().getType())
582ebd633fSKrzysztof Drewniak     return emitOpError("existing values must have same type as result");
592ebd633fSKrzysztof Drewniak   return success();
602ebd633fSKrzysztof Drewniak }
612ebd633fSKrzysztof Drewniak 
622ebd633fSKrzysztof Drewniak //===----------------------------------------------------------------------===//
63f1f05a91SKrzysztof Drewniak // RawBuffer*Op
64f1f05a91SKrzysztof Drewniak //===----------------------------------------------------------------------===//
65f1f05a91SKrzysztof Drewniak template <typename T>
66f1f05a91SKrzysztof Drewniak static LogicalResult verifyRawBufferOp(T &op) {
67c1fa60b4STres Popp   MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
68499abb24SKrzysztof Drewniak   Attribute memorySpace = bufferType.getMemorySpace();
69499abb24SKrzysztof Drewniak   bool isGlobal = false;
70499abb24SKrzysztof Drewniak   if (!memorySpace)
71499abb24SKrzysztof Drewniak     isGlobal = true;
72c1fa60b4STres Popp   else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
73499abb24SKrzysztof Drewniak     isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
74c1fa60b4STres Popp   else if (auto gpuMemorySpace =
75c1fa60b4STres Popp                llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
76499abb24SKrzysztof Drewniak     isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
77499abb24SKrzysztof Drewniak 
78499abb24SKrzysztof Drewniak   if (!isGlobal)
79f1f05a91SKrzysztof Drewniak     return op.emitOpError(
80f1f05a91SKrzysztof Drewniak         "Buffer ops must operate on a memref in global memory");
81f1f05a91SKrzysztof Drewniak   if (!bufferType.hasRank())
82f1f05a91SKrzysztof Drewniak     return op.emitOpError(
83f1f05a91SKrzysztof Drewniak         "Cannot meaningfully buffer_store to an unranked memref");
848df54a6aSJacques Pienaar   if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
85f1f05a91SKrzysztof Drewniak     return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
86f1f05a91SKrzysztof Drewniak                           " indices to memref");
87f1f05a91SKrzysztof Drewniak   return success();
88f1f05a91SKrzysztof Drewniak }
89f1f05a91SKrzysztof Drewniak 
90c55b41d5SKrzysztof Drewniak LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
91c55b41d5SKrzysztof Drewniak 
92c55b41d5SKrzysztof Drewniak LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
93c55b41d5SKrzysztof Drewniak 
94c55b41d5SKrzysztof Drewniak LogicalResult RawBufferAtomicFaddOp::verify() {
95f1f05a91SKrzysztof Drewniak   return verifyRawBufferOp(*this);
96f1f05a91SKrzysztof Drewniak }
97f1f05a91SKrzysztof Drewniak 
98584f6436SManupa Karunaratne LogicalResult RawBufferAtomicFmaxOp::verify() {
99584f6436SManupa Karunaratne   return verifyRawBufferOp(*this);
100584f6436SManupa Karunaratne }
101584f6436SManupa Karunaratne 
102584f6436SManupa Karunaratne LogicalResult RawBufferAtomicSmaxOp::verify() {
103584f6436SManupa Karunaratne   return verifyRawBufferOp(*this);
104584f6436SManupa Karunaratne }
105584f6436SManupa Karunaratne 
106584f6436SManupa Karunaratne LogicalResult RawBufferAtomicUminOp::verify() {
107584f6436SManupa Karunaratne   return verifyRawBufferOp(*this);
108584f6436SManupa Karunaratne }
109584f6436SManupa Karunaratne 
11098c1104dSKrzysztof Drewniak LogicalResult RawBufferAtomicCmpswapOp::verify() {
11198c1104dSKrzysztof Drewniak   return verifyRawBufferOp(*this);
11298c1104dSKrzysztof Drewniak }
11398c1104dSKrzysztof Drewniak 
1140a81ace0SKazu Hirata static std::optional<uint32_t> getConstantUint32(Value v) {
115d6abdf46SKrzysztof Drewniak   APInt cst;
116d6abdf46SKrzysztof Drewniak   if (!v.getType().isInteger(32))
1171a36588eSKazu Hirata     return std::nullopt;
118d6abdf46SKrzysztof Drewniak   if (matchPattern(v, m_ConstantInt(&cst)))
119d6abdf46SKrzysztof Drewniak     return cst.getZExtValue();
1201a36588eSKazu Hirata   return std::nullopt;
121d6abdf46SKrzysztof Drewniak }
122d6abdf46SKrzysztof Drewniak 
123d6abdf46SKrzysztof Drewniak template <typename OpType>
124d6abdf46SKrzysztof Drewniak static bool staticallyOutOfBounds(OpType op) {
125d6abdf46SKrzysztof Drewniak   if (!op.getBoundsCheck())
126d6abdf46SKrzysztof Drewniak     return false;
127d6abdf46SKrzysztof Drewniak   MemRefType bufferType = op.getMemref().getType();
128d6abdf46SKrzysztof Drewniak   if (!bufferType.hasStaticShape())
129d6abdf46SKrzysztof Drewniak     return false;
130d6abdf46SKrzysztof Drewniak   int64_t offset;
131d6abdf46SKrzysztof Drewniak   SmallVector<int64_t> strides;
132*6aaa8f25SMatthias Springer   if (failed(bufferType.getStridesAndOffset(strides, offset)))
133d6abdf46SKrzysztof Drewniak     return false;
134d6abdf46SKrzysztof Drewniak   int64_t result = offset + op.getIndexOffset().value_or(0);
135d6abdf46SKrzysztof Drewniak   if (op.getSgprOffset()) {
1360a81ace0SKazu Hirata     std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
137d6abdf46SKrzysztof Drewniak     if (!sgprOffset)
138d6abdf46SKrzysztof Drewniak       return false;
139d6abdf46SKrzysztof Drewniak     result += *sgprOffset;
140d6abdf46SKrzysztof Drewniak   }
141d6abdf46SKrzysztof Drewniak   if (strides.size() != op.getIndices().size())
142d6abdf46SKrzysztof Drewniak     return false;
143d6abdf46SKrzysztof Drewniak   int64_t indexVal = 0;
144d6abdf46SKrzysztof Drewniak   for (auto pair : llvm::zip(strides, op.getIndices())) {
145d6abdf46SKrzysztof Drewniak     int64_t stride = std::get<0>(pair);
146d6abdf46SKrzysztof Drewniak     Value idx = std::get<1>(pair);
1470a81ace0SKazu Hirata     std::optional<uint32_t> idxVal = getConstantUint32(idx);
148d6abdf46SKrzysztof Drewniak     if (!idxVal)
149d6abdf46SKrzysztof Drewniak       return false;
150cbb09813SFangrui Song     indexVal += stride * *idxVal;
151d6abdf46SKrzysztof Drewniak   }
152d6abdf46SKrzysztof Drewniak   result += indexVal;
153d6abdf46SKrzysztof Drewniak   if (result > std::numeric_limits<uint32_t>::max())
154d6abdf46SKrzysztof Drewniak     // Overflow means don't drop
155d6abdf46SKrzysztof Drewniak     return false;
156d6abdf46SKrzysztof Drewniak   return result >= bufferType.getNumElements();
157d6abdf46SKrzysztof Drewniak }
158d6abdf46SKrzysztof Drewniak 
159d6abdf46SKrzysztof Drewniak namespace {
16098c1104dSKrzysztof Drewniak template <typename OpType>
16198c1104dSKrzysztof Drewniak struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
16298c1104dSKrzysztof Drewniak   using OpRewritePattern<OpType>::OpRewritePattern;
163d6abdf46SKrzysztof Drewniak 
16498c1104dSKrzysztof Drewniak   LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
165d6abdf46SKrzysztof Drewniak     if (!staticallyOutOfBounds(op))
166d6abdf46SKrzysztof Drewniak       return failure();
167d6abdf46SKrzysztof Drewniak     Type loadType = op.getResult().getType();
168d6abdf46SKrzysztof Drewniak     rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
169d6abdf46SKrzysztof Drewniak                                              rw.getZeroAttr(loadType));
170d6abdf46SKrzysztof Drewniak     return success();
171d6abdf46SKrzysztof Drewniak   }
172d6abdf46SKrzysztof Drewniak };
173d6abdf46SKrzysztof Drewniak 
174d6abdf46SKrzysztof Drewniak template <typename OpType>
175d6abdf46SKrzysztof Drewniak struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
176d6abdf46SKrzysztof Drewniak   using OpRewritePattern<OpType>::OpRewritePattern;
177d6abdf46SKrzysztof Drewniak 
178d6abdf46SKrzysztof Drewniak   LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
179d6abdf46SKrzysztof Drewniak     if (!staticallyOutOfBounds(op))
180d6abdf46SKrzysztof Drewniak       return failure();
181d6abdf46SKrzysztof Drewniak 
182d6abdf46SKrzysztof Drewniak     rw.eraseOp(op);
183d6abdf46SKrzysztof Drewniak     return success();
184d6abdf46SKrzysztof Drewniak   }
185d6abdf46SKrzysztof Drewniak };
186d6abdf46SKrzysztof Drewniak } // end namespace
187d6abdf46SKrzysztof Drewniak 
188d6abdf46SKrzysztof Drewniak void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
189d6abdf46SKrzysztof Drewniak                                                   MLIRContext *context) {
19098c1104dSKrzysztof Drewniak   results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
191d6abdf46SKrzysztof Drewniak }
192d6abdf46SKrzysztof Drewniak 
193d6abdf46SKrzysztof Drewniak void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
194d6abdf46SKrzysztof Drewniak                                                    MLIRContext *context) {
195d6abdf46SKrzysztof Drewniak   results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
196d6abdf46SKrzysztof Drewniak }
197d6abdf46SKrzysztof Drewniak 
198d6abdf46SKrzysztof Drewniak void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
199d6abdf46SKrzysztof Drewniak     RewritePatternSet &results, MLIRContext *context) {
200d6abdf46SKrzysztof Drewniak   results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
201d6abdf46SKrzysztof Drewniak }
202d6abdf46SKrzysztof Drewniak 
203584f6436SManupa Karunaratne void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
204584f6436SManupa Karunaratne     RewritePatternSet &results, MLIRContext *context) {
205584f6436SManupa Karunaratne   results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
206584f6436SManupa Karunaratne }
207584f6436SManupa Karunaratne 
208584f6436SManupa Karunaratne void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
209584f6436SManupa Karunaratne     RewritePatternSet &results, MLIRContext *context) {
210584f6436SManupa Karunaratne   results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
211584f6436SManupa Karunaratne }
212584f6436SManupa Karunaratne 
213584f6436SManupa Karunaratne void RawBufferAtomicUminOp::getCanonicalizationPatterns(
214584f6436SManupa Karunaratne     RewritePatternSet &results, MLIRContext *context) {
215584f6436SManupa Karunaratne   results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
216584f6436SManupa Karunaratne }
217584f6436SManupa Karunaratne 
21898c1104dSKrzysztof Drewniak void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
21998c1104dSKrzysztof Drewniak     RewritePatternSet &results, MLIRContext *context) {
22098c1104dSKrzysztof Drewniak   results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
22198c1104dSKrzysztof Drewniak       context);
22298c1104dSKrzysztof Drewniak }
22398c1104dSKrzysztof Drewniak 
224c55b41d5SKrzysztof Drewniak //===----------------------------------------------------------------------===//
2254b3eaee2SGiuseppe Rossini // WMMAOp
2264b3eaee2SGiuseppe Rossini //===----------------------------------------------------------------------===//
2274b3eaee2SGiuseppe Rossini LogicalResult WMMAOp::verify() {
2284b3eaee2SGiuseppe Rossini   Type sourceAType = getSourceA().getType();
2294b3eaee2SGiuseppe Rossini   Type destType = getDestC().getType();
2304b3eaee2SGiuseppe Rossini 
231a5757c5bSChristian Sigg   VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
232a5757c5bSChristian Sigg   VectorType destVectorType = dyn_cast<VectorType>(destType);
2334b3eaee2SGiuseppe Rossini 
2344b3eaee2SGiuseppe Rossini   Type sourceAElemType = sourceVectorAType.getElementType();
2354b3eaee2SGiuseppe Rossini   Type destElemType = destVectorType.getElementType();
2364b3eaee2SGiuseppe Rossini 
237a8e1c6f9SGiuseppe Rossini   bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
238a8e1c6f9SGiuseppe Rossini   bool isSrcFloat =
239a8e1c6f9SGiuseppe Rossini       isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
240a8e1c6f9SGiuseppe Rossini           sourceAElemType);
2414b3eaee2SGiuseppe Rossini 
2424b3eaee2SGiuseppe Rossini   if (isDestFloat && !isSrcFloat) {
2434b3eaee2SGiuseppe Rossini     return emitOpError("Expected float sources with float destination");
2444b3eaee2SGiuseppe Rossini   }
2454b3eaee2SGiuseppe Rossini 
2464b3eaee2SGiuseppe Rossini   if (!isDestFloat && isSrcFloat) {
2474b3eaee2SGiuseppe Rossini     return emitOpError("Expected int sources with int destination");
2484b3eaee2SGiuseppe Rossini   }
2494b3eaee2SGiuseppe Rossini 
2504b3eaee2SGiuseppe Rossini   return success();
2514b3eaee2SGiuseppe Rossini }
2524b3eaee2SGiuseppe Rossini 
2534b3eaee2SGiuseppe Rossini //===----------------------------------------------------------------------===//
254c55b41d5SKrzysztof Drewniak // MFMAOp
255c55b41d5SKrzysztof Drewniak //===----------------------------------------------------------------------===//
256c55b41d5SKrzysztof Drewniak LogicalResult MFMAOp::verify() {
257c55b41d5SKrzysztof Drewniak   constexpr uint32_t waveSize = 64;
258c55b41d5SKrzysztof Drewniak   Builder b(getContext());
259c55b41d5SKrzysztof Drewniak 
260c55b41d5SKrzysztof Drewniak   Type sourceType = getSourceA().getType();
261c55b41d5SKrzysztof Drewniak   Type destType = getDestC().getType();
262c55b41d5SKrzysztof Drewniak 
263c55b41d5SKrzysztof Drewniak   Type sourceElem = sourceType, destElem = destType;
264c55b41d5SKrzysztof Drewniak   uint32_t sourceLen = 1, destLen = 1;
265c1fa60b4STres Popp   if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
266c55b41d5SKrzysztof Drewniak     sourceLen = sourceVector.getNumElements();
267c55b41d5SKrzysztof Drewniak     sourceElem = sourceVector.getElementType();
268c55b41d5SKrzysztof Drewniak   }
269c1fa60b4STres Popp   if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
270c55b41d5SKrzysztof Drewniak     destLen = destVector.getNumElements();
271c55b41d5SKrzysztof Drewniak     destElem = destVector.getElementType();
272f1f05a91SKrzysztof Drewniak   }
273f1f05a91SKrzysztof Drewniak 
27422f0c7a4SKrzysztof Drewniak   Type sourceBType = getSourceB().getType();
2757a77f14cSMatthias Springer   if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
27622f0c7a4SKrzysztof Drewniak     int64_t sourceBLen = 1;
27722f0c7a4SKrzysztof Drewniak     Type sourceBElem = sourceBType;
278c1fa60b4STres Popp     if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
27922f0c7a4SKrzysztof Drewniak       sourceBLen = sourceBVector.getNumElements();
28022f0c7a4SKrzysztof Drewniak       sourceBElem = sourceBVector.getElementType();
28122f0c7a4SKrzysztof Drewniak     }
2827a77f14cSMatthias Springer     if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
28322f0c7a4SKrzysztof Drewniak       return emitOpError("expected both source operands to have f8 elements");
28422f0c7a4SKrzysztof Drewniak     if (sourceLen != sourceBLen)
28522f0c7a4SKrzysztof Drewniak       return emitOpError(
28622f0c7a4SKrzysztof Drewniak           "expected both f8 source vectors to have the same length");
28722f0c7a4SKrzysztof Drewniak   } else {
28822f0c7a4SKrzysztof Drewniak     if (sourceType != sourceBType)
28922f0c7a4SKrzysztof Drewniak       return emitOpError(
29022f0c7a4SKrzysztof Drewniak           "expected both non-f8 source operand types to match exactly");
29122f0c7a4SKrzysztof Drewniak   }
292c55b41d5SKrzysztof Drewniak   // Normalize the wider integer types the compiler expects to i8
293c55b41d5SKrzysztof Drewniak   if (sourceElem.isInteger(32)) {
294c55b41d5SKrzysztof Drewniak     sourceLen *= 4;
295c55b41d5SKrzysztof Drewniak     sourceElem = b.getI8Type();
296f1f05a91SKrzysztof Drewniak   }
297c55b41d5SKrzysztof Drewniak   if (sourceElem.isInteger(64)) {
298c55b41d5SKrzysztof Drewniak     sourceLen *= 8;
299c55b41d5SKrzysztof Drewniak     sourceElem = b.getI8Type();
300c55b41d5SKrzysztof Drewniak   }
301c55b41d5SKrzysztof Drewniak 
302c55b41d5SKrzysztof Drewniak   int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
303c55b41d5SKrzysztof Drewniak   if (sourceLen != numSourceElems)
304c55b41d5SKrzysztof Drewniak     return emitOpError("expected " + Twine(numSourceElems) +
305c55b41d5SKrzysztof Drewniak                        " source values for this operation but got " +
306c55b41d5SKrzysztof Drewniak                        Twine(sourceLen));
307c55b41d5SKrzysztof Drewniak 
308c55b41d5SKrzysztof Drewniak   int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
309c55b41d5SKrzysztof Drewniak   if (destLen != numDestElems)
310c55b41d5SKrzysztof Drewniak     return emitOpError("expected " + Twine(numDestElems) +
311c55b41d5SKrzysztof Drewniak                        " result values for this operation but got " +
312c55b41d5SKrzysztof Drewniak                        Twine(destLen));
313c55b41d5SKrzysztof Drewniak 
314c55b41d5SKrzysztof Drewniak   if (destElem.isF64() && getBlgp() != MFMAPermB::none)
315c55b41d5SKrzysztof Drewniak     return emitOpError(
316c55b41d5SKrzysztof Drewniak         "double-precision ops do not support permuting lanes of B");
317c55b41d5SKrzysztof Drewniak   if (destElem.isF64() && getCbsz() != 0)
318c55b41d5SKrzysztof Drewniak     return emitOpError(
319c55b41d5SKrzysztof Drewniak         "double-precision ops do not support permuting lanes of A");
3207e52e0fcSRob Suderman   if (getAbid() >= (1u << getCbsz()))
321c55b41d5SKrzysztof Drewniak     return emitOpError(
322c55b41d5SKrzysztof Drewniak         "block ID for permuting A (abid) must be below 2 ** cbsz");
323c55b41d5SKrzysztof Drewniak 
324c55b41d5SKrzysztof Drewniak   if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
325c55b41d5SKrzysztof Drewniak     return emitOpError(
326c55b41d5SKrzysztof Drewniak         "negation flags only available for double-precision operations");
327c55b41d5SKrzysztof Drewniak 
328c55b41d5SKrzysztof Drewniak   return success();
329c55b41d5SKrzysztof Drewniak }
330c55b41d5SKrzysztof Drewniak 
3311164e4aeSstefankoncarevic //===----------------------------------------------------------------------===//
3321164e4aeSstefankoncarevic // DPPOp
3331164e4aeSstefankoncarevic //===----------------------------------------------------------------------===//
3341164e4aeSstefankoncarevic LogicalResult DPPOp::verify() {
3351164e4aeSstefankoncarevic   Type srcType = getSrc().getType();
3361164e4aeSstefankoncarevic   if (srcType.getIntOrFloatBitWidth() > 64) {
3371164e4aeSstefankoncarevic     return emitOpError("integer and floating point types larger than 64 bits "
3381164e4aeSstefankoncarevic                        "are not supported");
3391164e4aeSstefankoncarevic   }
3401164e4aeSstefankoncarevic 
3411164e4aeSstefankoncarevic   DPPPerm kind = getKind();
3421164e4aeSstefankoncarevic   Attribute permArgument = getPermArgument().value_or(Attribute{});
3431164e4aeSstefankoncarevic 
3441164e4aeSstefankoncarevic   switch (kind) {
3451164e4aeSstefankoncarevic 
3461164e4aeSstefankoncarevic   case DPPPerm::quad_perm: {
3471164e4aeSstefankoncarevic     auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
3481164e4aeSstefankoncarevic     if (!quadPermAttr || quadPermAttr.size() != 4) {
3491164e4aeSstefankoncarevic       return emitOpError("quad_perm attribute must have exactly 4 elements");
3501164e4aeSstefankoncarevic     }
3511164e4aeSstefankoncarevic     for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
352d5746d73SFrank Schlimbach       int32_t num = elem.getInt();
3531164e4aeSstefankoncarevic       if (num < 0 || num > 3) {
3541164e4aeSstefankoncarevic         return emitOpError(
3551164e4aeSstefankoncarevic             "Each element of quad_perm must be in the range [0, 3]");
3561164e4aeSstefankoncarevic       }
3571164e4aeSstefankoncarevic     }
3581164e4aeSstefankoncarevic   } break;
3591164e4aeSstefankoncarevic 
3601164e4aeSstefankoncarevic   case DPPPerm::row_shl:
3611164e4aeSstefankoncarevic   case DPPPerm::row_shr:
3621164e4aeSstefankoncarevic   case DPPPerm::row_ror: {
3631164e4aeSstefankoncarevic     if (!permArgument) {
3641164e4aeSstefankoncarevic       return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
3651164e4aeSstefankoncarevic                          "' value not specified");
3661164e4aeSstefankoncarevic     }
3671164e4aeSstefankoncarevic     if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
3681164e4aeSstefankoncarevic       uint32_t attrValue = intAttr.getInt();
3691164e4aeSstefankoncarevic       if (attrValue < 1 || attrValue > 15) {
3701164e4aeSstefankoncarevic         return emitOpError("Attribute value must be between 1 and 15");
3711164e4aeSstefankoncarevic       }
3721164e4aeSstefankoncarevic     }
3731164e4aeSstefankoncarevic   } break;
3741164e4aeSstefankoncarevic 
3751164e4aeSstefankoncarevic   case DPPPerm::wave_shl:
3761164e4aeSstefankoncarevic   case DPPPerm::wave_shr:
3771164e4aeSstefankoncarevic   case DPPPerm::wave_rol:
3781164e4aeSstefankoncarevic   case DPPPerm::wave_ror:
3791164e4aeSstefankoncarevic   case DPPPerm::row_mirror:
3801164e4aeSstefankoncarevic   case DPPPerm::row_half_mirror:
3811164e4aeSstefankoncarevic   case DPPPerm::row_bcast_15:
3821164e4aeSstefankoncarevic   case DPPPerm::row_bcast_31: {
3831164e4aeSstefankoncarevic     if (permArgument && !isa<UnitAttr>(permArgument)) {
3841164e4aeSstefankoncarevic       return emitOpError("Expected unit attribute for permArgument, but found "
3851164e4aeSstefankoncarevic                          "non-trivial argument");
3861164e4aeSstefankoncarevic     }
3871164e4aeSstefankoncarevic     break;
3881164e4aeSstefankoncarevic   }
3891164e4aeSstefankoncarevic   }
3901164e4aeSstefankoncarevic   return success();
3911164e4aeSstefankoncarevic }
3921164e4aeSstefankoncarevic 
393cc470374SKrzysztof Drewniak #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
394c55b41d5SKrzysztof Drewniak 
395c55b41d5SKrzysztof Drewniak #define GET_ATTRDEF_CLASSES
396cc470374SKrzysztof Drewniak #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
397f1f05a91SKrzysztof Drewniak 
398f1f05a91SKrzysztof Drewniak #define GET_OP_CLASSES
399cc470374SKrzysztof Drewniak #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
400