1 //===- AMDGPUDialect.cpp - MLIR AMDGPU dialect implementation --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the AMDGPU dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" 14 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 17 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/BuiltinTypes.h" 20 #include "mlir/IR/Diagnostics.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/Matchers.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/IR/TypeUtilities.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 28 #include <limits> 29 #include <optional> 30 31 using namespace mlir; 32 using namespace mlir::amdgpu; 33 34 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc" 35 36 void AMDGPUDialect::initialize() { 37 addOperations< 38 #define GET_OP_LIST 39 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" 40 >(); 41 addAttributes< 42 #define GET_ATTRDEF_LIST 43 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" 44 >(); 45 } 46 47 //===----------------------------------------------------------------------===// 48 // 8-bit float ops 49 //===----------------------------------------------------------------------===// 50 LogicalResult PackedTrunc2xFp8Op::verify() { 51 if (getExisting() && getExisting().getType() != getResult().getType()) 52 return emitOpError("existing values must have same type as result"); 53 return success(); 54 } 55 56 LogicalResult PackedStochRoundFp8Op::verify() { 57 if (getExisting() && 
getExisting().getType() != getResult().getType()) 58 return emitOpError("existing values must have same type as result"); 59 return success(); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // RawBuffer*Op 64 //===----------------------------------------------------------------------===// 65 template <typename T> 66 static LogicalResult verifyRawBufferOp(T &op) { 67 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType()); 68 Attribute memorySpace = bufferType.getMemorySpace(); 69 bool isGlobal = false; 70 if (!memorySpace) 71 isGlobal = true; 72 else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace)) 73 isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1; 74 else if (auto gpuMemorySpace = 75 llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace)) 76 isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global; 77 78 if (!isGlobal) 79 return op.emitOpError( 80 "Buffer ops must operate on a memref in global memory"); 81 if (!bufferType.hasRank()) 82 return op.emitOpError( 83 "Cannot meaningfully buffer_store to an unranked memref"); 84 if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank()) 85 return op.emitOpError("Expected " + Twine(bufferType.getRank()) + 86 " indices to memref"); 87 return success(); 88 } 89 90 LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); } 91 92 LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); } 93 94 LogicalResult RawBufferAtomicFaddOp::verify() { 95 return verifyRawBufferOp(*this); 96 } 97 98 LogicalResult RawBufferAtomicFmaxOp::verify() { 99 return verifyRawBufferOp(*this); 100 } 101 102 LogicalResult RawBufferAtomicSmaxOp::verify() { 103 return verifyRawBufferOp(*this); 104 } 105 106 LogicalResult RawBufferAtomicUminOp::verify() { 107 return verifyRawBufferOp(*this); 108 } 109 110 LogicalResult RawBufferAtomicCmpswapOp::verify() { 111 return verifyRawBufferOp(*this); 
112 } 113 114 static std::optional<uint32_t> getConstantUint32(Value v) { 115 APInt cst; 116 if (!v.getType().isInteger(32)) 117 return std::nullopt; 118 if (matchPattern(v, m_ConstantInt(&cst))) 119 return cst.getZExtValue(); 120 return std::nullopt; 121 } 122 123 template <typename OpType> 124 static bool staticallyOutOfBounds(OpType op) { 125 if (!op.getBoundsCheck()) 126 return false; 127 MemRefType bufferType = op.getMemref().getType(); 128 if (!bufferType.hasStaticShape()) 129 return false; 130 int64_t offset; 131 SmallVector<int64_t> strides; 132 if (failed(bufferType.getStridesAndOffset(strides, offset))) 133 return false; 134 int64_t result = offset + op.getIndexOffset().value_or(0); 135 if (op.getSgprOffset()) { 136 std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset()); 137 if (!sgprOffset) 138 return false; 139 result += *sgprOffset; 140 } 141 if (strides.size() != op.getIndices().size()) 142 return false; 143 int64_t indexVal = 0; 144 for (auto pair : llvm::zip(strides, op.getIndices())) { 145 int64_t stride = std::get<0>(pair); 146 Value idx = std::get<1>(pair); 147 std::optional<uint32_t> idxVal = getConstantUint32(idx); 148 if (!idxVal) 149 return false; 150 indexVal += stride * *idxVal; 151 } 152 result += indexVal; 153 if (result > std::numeric_limits<uint32_t>::max()) 154 // Overflow means don't drop 155 return false; 156 return result >= bufferType.getNumElements(); 157 } 158 159 namespace { 160 template <typename OpType> 161 struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> { 162 using OpRewritePattern<OpType>::OpRewritePattern; 163 164 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override { 165 if (!staticallyOutOfBounds(op)) 166 return failure(); 167 Type loadType = op.getResult().getType(); 168 rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType, 169 rw.getZeroAttr(loadType)); 170 return success(); 171 } 172 }; 173 174 template <typename OpType> 175 struct 
RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> { 176 using OpRewritePattern<OpType>::OpRewritePattern; 177 178 LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override { 179 if (!staticallyOutOfBounds(op)) 180 return failure(); 181 182 rw.eraseOp(op); 183 return success(); 184 } 185 }; 186 } // end namespace 187 188 void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, 189 MLIRContext *context) { 190 results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context); 191 } 192 193 void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, 194 MLIRContext *context) { 195 results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context); 196 } 197 198 void RawBufferAtomicFaddOp::getCanonicalizationPatterns( 199 RewritePatternSet &results, MLIRContext *context) { 200 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context); 201 } 202 203 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns( 204 RewritePatternSet &results, MLIRContext *context) { 205 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context); 206 } 207 208 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns( 209 RewritePatternSet &results, MLIRContext *context) { 210 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context); 211 } 212 213 void RawBufferAtomicUminOp::getCanonicalizationPatterns( 214 RewritePatternSet &results, MLIRContext *context) { 215 results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context); 216 } 217 218 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns( 219 RewritePatternSet &results, MLIRContext *context) { 220 results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>( 221 context); 222 } 223 224 //===----------------------------------------------------------------------===// 225 // WMMAOp 226 //===----------------------------------------------------------------------===// 227 
LogicalResult WMMAOp::verify() {
  Type sourceAType = getSourceA().getType();
  Type destType = getDestC().getType();

  // Fix: the previous code dereferenced the result of dyn_cast<VectorType>
  // without a null check, which is a null dereference if either operand is
  // not a vector. These operands are presumably constrained to vectors by
  // the op's ODS definition (TODO confirm against AMDGPU.td); cast<> makes
  // that invariant explicit and asserts instead of crashing through null.
  VectorType sourceVectorAType = cast<VectorType>(sourceAType);
  VectorType destVectorType = cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type destElemType = destVectorType.getElementType();

  // Float accumulators require float-typed sources (f16/bf16/f8 variants);
  // integer accumulators require integer sources.
  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
  bool isSrcFloat =
      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
          sourceAElemType);

  if (isDestFloat && !isSrcFloat) {
    return emitOpError("Expected float sources with float destination");
  }

  if (!isDestFloat && isSrcFloat) {
    return emitOpError("Expected int sources with int destination");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
LogicalResult MFMAOp::verify() {
  // MFMA intrinsics operate on a full 64-lane wavefront; element counts
  // below are per-lane shares of the tile.
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  // Scalars are treated as length-1 vectors for the element-count checks.
  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
    // Mixed f8 formats between A and B are allowed, but both operands must
    // have f8 elements and equal vector lengths.
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
      return emitOpError("expected both source operands to have f8 elements");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both f8 source vectors to have the same length");
  } else {
    // For every non-f8 element type the two source operands must match
    // exactly.
    if (sourceType != sourceBType)
      return emitOpError(
          "expected both non-f8 source operand types to match exactly");
  }
  // Normalize the wider integer types the compiler expects to i8
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  // Each lane supplies its share of the M x K (x blocks) input tile.
  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  // And its share of the M x N (x blocks) accumulator tile.
  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  // The lane-permutation modifiers (blgp, cbsz/abid) are not available on
  // double-precision variants; negation flags exist only for them.
  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}

//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//
334 LogicalResult DPPOp::verify() { 335 Type srcType = getSrc().getType(); 336 if (srcType.getIntOrFloatBitWidth() > 64) { 337 return emitOpError("integer and floating point types larger than 64 bits " 338 "are not supported"); 339 } 340 341 DPPPerm kind = getKind(); 342 Attribute permArgument = getPermArgument().value_or(Attribute{}); 343 344 switch (kind) { 345 346 case DPPPerm::quad_perm: { 347 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument); 348 if (!quadPermAttr || quadPermAttr.size() != 4) { 349 return emitOpError("quad_perm attribute must have exactly 4 elements"); 350 } 351 for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) { 352 int32_t num = elem.getInt(); 353 if (num < 0 || num > 3) { 354 return emitOpError( 355 "Each element of quad_perm must be in the range [0, 3]"); 356 } 357 } 358 } break; 359 360 case DPPPerm::row_shl: 361 case DPPPerm::row_shr: 362 case DPPPerm::row_ror: { 363 if (!permArgument) { 364 return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) + 365 "' value not specified"); 366 } 367 if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) { 368 uint32_t attrValue = intAttr.getInt(); 369 if (attrValue < 1 || attrValue > 15) { 370 return emitOpError("Attribute value must be between 1 and 15"); 371 } 372 } 373 } break; 374 375 case DPPPerm::wave_shl: 376 case DPPPerm::wave_shr: 377 case DPPPerm::wave_rol: 378 case DPPPerm::wave_ror: 379 case DPPPerm::row_mirror: 380 case DPPPerm::row_half_mirror: 381 case DPPPerm::row_bcast_15: 382 case DPPPerm::row_bcast_31: { 383 if (permArgument && !isa<UnitAttr>(permArgument)) { 384 return emitOpError("Expected unit attribute for permArgument, but found " 385 "non-trivial argument"); 386 } 387 break; 388 } 389 } 390 return success(); 391 } 392 393 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc" 394 395 #define GET_ATTRDEF_CLASSES 396 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" 397 398 #define GET_OP_CLASSES 399 #include 
"mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" 400