xref: /llvm-project/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMDGPU dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 #include <limits>
29 #include <optional>
30 
31 using namespace mlir;
32 using namespace mlir::amdgpu;
33 
34 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
35 
// Registers every AMDGPU operation and attribute with the dialect. The
// actual lists are generated by TableGen and pulled in via the GET_*_LIST
// macro includes below.
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}
46 
47 //===----------------------------------------------------------------------===//
48 // 8-bit float ops
49 //===----------------------------------------------------------------------===//
50 LogicalResult PackedTrunc2xFp8Op::verify() {
51   if (getExisting() && getExisting().getType() != getResult().getType())
52     return emitOpError("existing values must have same type as result");
53   return success();
54 }
55 
56 LogicalResult PackedStochRoundFp8Op::verify() {
57   if (getExisting() && getExisting().getType() != getResult().getType())
58     return emitOpError("existing values must have same type as result");
59   return success();
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // RawBuffer*Op
64 //===----------------------------------------------------------------------===//
65 template <typename T>
66 static LogicalResult verifyRawBufferOp(T &op) {
67   MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
68   Attribute memorySpace = bufferType.getMemorySpace();
69   bool isGlobal = false;
70   if (!memorySpace)
71     isGlobal = true;
72   else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
73     isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
74   else if (auto gpuMemorySpace =
75                llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
76     isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
77 
78   if (!isGlobal)
79     return op.emitOpError(
80         "Buffer ops must operate on a memref in global memory");
81   if (!bufferType.hasRank())
82     return op.emitOpError(
83         "Cannot meaningfully buffer_store to an unranked memref");
84   if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
85     return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
86                           " indices to memref");
87   return success();
88 }
89 
90 LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
91 
92 LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }
93 
// Delegates to the shared raw-buffer memref checks.
LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}
97 
// Delegates to the shared raw-buffer memref checks.
LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}
101 
// Delegates to the shared raw-buffer memref checks.
LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}
105 
// Delegates to the shared raw-buffer memref checks.
LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}
109 
// Delegates to the shared raw-buffer memref checks.
LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
113 
114 static std::optional<uint32_t> getConstantUint32(Value v) {
115   APInt cst;
116   if (!v.getType().isInteger(32))
117     return std::nullopt;
118   if (matchPattern(v, m_ConstantInt(&cst)))
119     return cst.getZExtValue();
120   return std::nullopt;
121 }
122 
/// Returns true when a bounds-checked raw-buffer access can be proven at
/// compile time to address past the end of its statically shaped memref.
/// Any piece of the address that is not a known constant makes the answer
/// "no" (i.e. the access is kept).
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  // With bounds checking disabled the access is not dropped by hardware,
  // so never fold it away.
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  // Accumulate the constant components of the linearized address: layout
  // offset, the optional immediate index offset, and the optional SGPR
  // offset (which must itself be a constant to conclude anything).
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  // Linearize the (all-constant) indices using the memref's strides.
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflow means don't drop
    return false;
  return result >= bufferType.getNumElements();
}
158 
159 namespace {
160 template <typename OpType>
161 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
162   using OpRewritePattern<OpType>::OpRewritePattern;
163 
164   LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
165     if (!staticallyOutOfBounds(op))
166       return failure();
167     Type loadType = op.getResult().getType();
168     rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
169                                              rw.getZeroAttr(loadType));
170     return success();
171   }
172 };
173 
174 template <typename OpType>
175 struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
176   using OpRewritePattern<OpType>::OpRewritePattern;
177 
178   LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
179     if (!staticallyOutOfBounds(op))
180       return failure();
181 
182     rw.eraseOp(op);
183     return success();
184   }
185 };
186 } // end namespace
187 
// Statically out-of-bounds loads fold to a zero constant.
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}
192 
// Statically out-of-bounds stores are erased.
void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}
197 
// Statically out-of-bounds atomic adds are erased (treated as writes).
void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}
202 
// Statically out-of-bounds atomic fmax ops are erased (treated as writes).
void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}
207 
// Statically out-of-bounds atomic smax ops are erased (treated as writes).
void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}
212 
// Statically out-of-bounds atomic umin ops are erased (treated as writes).
void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}
217 
// Cmpswap produces a result, so an out-of-bounds access folds like a load
// (to a zero constant) rather than being erased.
void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
223 
224 //===----------------------------------------------------------------------===//
225 // WMMAOp
226 //===----------------------------------------------------------------------===//
227 LogicalResult WMMAOp::verify() {
228   Type sourceAType = getSourceA().getType();
229   Type destType = getDestC().getType();
230 
231   VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
232   VectorType destVectorType = dyn_cast<VectorType>(destType);
233 
234   Type sourceAElemType = sourceVectorAType.getElementType();
235   Type destElemType = destVectorType.getElementType();
236 
237   bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
238   bool isSrcFloat =
239       isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
240           sourceAElemType);
241 
242   if (isDestFloat && !isSrcFloat) {
243     return emitOpError("Expected float sources with float destination");
244   }
245 
246   if (!isDestFloat && isSrcFloat) {
247     return emitOpError("Expected int sources with int destination");
248   }
249 
250   return success();
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // MFMAOp
255 //===----------------------------------------------------------------------===//
/// Verifies MFMA operand shapes and modifier flags against the op's
/// m/n/k/blocks attributes, assuming a 64-lane wavefront.
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  // Scalar operands count as length-1 "vectors" of themselves.
  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
    // f8 sources may mix the two f8 element types between A and B, so only
    // element kind (f8) and vector length are required to match.
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
      return emitOpError("expected both source operands to have f8 elements");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both f8 source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError(
          "expected both non-f8 source operand types to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  // Each lane of the wave must hold its share of the m*k*blocks source
  // elements.
  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  // Likewise for the m*n*blocks accumulator/destination elements.
  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  // Lane-permutation modifiers (blgp/cbsz/abid) are unsupported on f64 ops;
  // conversely, the negate flags exist only on f64 ops.
  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
330 
331 //===----------------------------------------------------------------------===//
332 // DPPOp
333 //===----------------------------------------------------------------------===//
/// Verifies that the DPP permutation kind and its (optional) argument
/// attribute are mutually consistent.
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  // NOTE(review): getIntOrFloatBitWidth asserts on non-int/float types;
  // presumably ODS constraints restrict src accordingly — confirm.
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {

  case DPPPerm::quad_perm: {
    // quad_perm takes an array of exactly four lane selectors, each in
    // [0, 3].
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
  } break;

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    // Row shifts/rotates require an integer shift amount in [1, 15].
    // Note: a present-but-non-integer attribute passes through unchecked
    // here; only integer attributes are range-validated.
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
  } break;

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    // These permutations are fully determined by the kind and take no
    // argument (a unit attribute is tolerated).
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
392 
393 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
394 
395 #define GET_ATTRDEF_CLASSES
396 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
397 
398 #define GET_OP_CLASSES
399 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
400