1 //===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/Utils/Utils.h" 10 #include "mlir/Dialect/Utils/StaticValueUtils.h" 11 #include "mlir/Dialect/XeGPU/IR/XeGPU.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/TypeUtilities.h" 14 15 #include "llvm/Support/Debug.h" 16 17 #define DEBUG_TYPE "xegpu" 18 19 namespace mlir { 20 namespace xegpu { 21 22 static void transpose(llvm::ArrayRef<int64_t> trans, 23 SmallVector<int64_t> &shape) { 24 SmallVector<int64_t> old = shape; 25 for (size_t i = 0; i < trans.size(); i++) 26 shape[i] = old[trans[i]]; 27 } 28 29 template <typename T> 30 static std::string makeString(T array, bool breakline = false) { 31 std::string buf; 32 buf.clear(); 33 llvm::raw_string_ostream os(buf); 34 os << "["; 35 for (size_t i = 1; i < array.size(); i++) { 36 os << array[i - 1] << ", "; 37 if (breakline) 38 os << "\n\t\t"; 39 } 40 os << array.back() << "]"; 41 return buf; 42 } 43 44 static SmallVector<int64_t> getShapeOf(Type type) { 45 SmallVector<int64_t> shape; 46 if (auto ty = llvm::dyn_cast<ShapedType>(type)) 47 shape = SmallVector<int64_t>(ty.getShape()); 48 else 49 shape.push_back(1); 50 return shape; 51 } 52 53 static int64_t getRankOf(Value val) { 54 auto type = val.getType(); 55 if (auto ty = llvm::dyn_cast<ShapedType>(type)) 56 return ty.getRank(); 57 return 0; 58 } 59 60 static bool isReadHintOrNone(const CachePolicyAttr &attr) { 61 if (!attr) 62 return true; 63 auto kind = attr.getValue(); 64 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || 65 kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE; 66 } 67 68 static bool isWriteHintOrNone(const CachePolicyAttr &attr) { 69 if (!attr) 70 return true; 71 auto kind = attr.getValue(); 72 return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || 73 kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH; 74 } 75 76 // Validations for nd instruction arguments is successful if any of these are 77 // true: 78 // - tensor descriptor and the output vector shapes exactly match. 79 // - tensor descriptor has a sg_map attribute and the distributed vector shape 80 // matches the tensor descriptor shape when scaled using sg_map factors on 81 // each dimension. 82 static bool isArgShapesValid(ArrayRef<int64_t> descShape, 83 ArrayRef<int64_t> valShape, SGMapAttr sgMap) { 84 if (descShape == valShape) { 85 if (!sgMap) 86 return true; 87 88 // this can be relaxed if necessary by supporting non-2d shapes distribution 89 // until the constraints are defined this lives here instead of the tensor 90 // descriptor type. 91 return valShape.size() == sgMap.getWiLayout().size(); 92 } 93 94 if (!sgMap) 95 return false; 96 97 if (valShape.size() != descShape.size()) 98 return false; 99 100 for (const auto &[factor, dim, expected] : 101 llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) { 102 if (factor * dim != expected) 103 return false; 104 } 105 106 return true; 107 } 108 109 //===----------------------------------------------------------------------===// 110 // XeGPU_CreateNdDescOp 111 //===----------------------------------------------------------------------===// 112 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, 113 Type tdesc, TypedValue<MemRefType> source, 114 llvm::ArrayRef<OpFoldResult> offsets) { 115 [[maybe_unused]] auto ty = source.getType(); 116 assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank()); 117 118 llvm::SmallVector<int64_t> staticOffsets; 119 llvm::SmallVector<Value> dynamicOffsets; 120 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 121 122 build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */, 123 ValueRange({}) /* empty dynamic shape */, 124 ValueRange({}) /* empty dynamic strides */, 125 staticOffsets /* const offsets */, {} /* empty const shape*/, 126 {} /* empty const strides*/); 127 } 128 129 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, 130 Type tdesc, TypedValue<MemRefType> source, 131 llvm::ArrayRef<OpFoldResult> offsets, 132 llvm::ArrayRef<OpFoldResult> shape, 133 llvm::ArrayRef<OpFoldResult> strides) { 134 assert(shape.size() && offsets.size() && strides.size() && 135 shape.size() == strides.size() && shape.size() == offsets.size()); 136 137 llvm::SmallVector<int64_t> staticOffsets; 138 llvm::SmallVector<int64_t> staticShape; 139 llvm::SmallVector<int64_t> staticStrides; 140 llvm::SmallVector<Value> dynamicOffsets; 141 llvm::SmallVector<Value> dynamicShape; 142 llvm::SmallVector<Value> dynamicStrides; 143 144 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 145 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); 146 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 147 148 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); 149 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); 150 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); 151 152 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape, 153 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr); 154 } 155 156 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, 157 Type tdesc, TypedValue<IntegerType> source, 158 llvm::ArrayRef<OpFoldResult> offsets, 159 llvm::ArrayRef<OpFoldResult> shape, 160 llvm::ArrayRef<OpFoldResult> strides) { 161 assert(shape.size() && offsets.size() && strides.size() && 162 shape.size() == strides.size() && shape.size() == offsets.size()); 163 164 llvm::SmallVector<int64_t> staticOffsets; 165 llvm::SmallVector<int64_t> staticShape; 166 llvm::SmallVector<int64_t> staticStrides; 167 llvm::SmallVector<Value> dynamicOffsets; 168 llvm::SmallVector<Value> dynamicShape; 169 llvm::SmallVector<Value> dynamicStrides; 170 171 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 172 dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); 173 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 174 175 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); 176 auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); 177 auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); 178 179 build(builder, state, tdesc, source, dynamicOffsets, dynamicShape, 180 dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr); 181 } 182 183 LogicalResult CreateNdDescOp::verify() { 184 auto rank = (int64_t)getMixedOffsets().size(); 185 bool invalidRank = false; 186 bool invalidElemTy = false; 187 188 // Memory space of created TensorDesc should match with the source. 189 // Both source and TensorDesc are considered for global memory by default, 190 // if the memory scope attr is not specified. If source is an integer, 191 // it is considered as ptr to global memory. 192 auto srcMemorySpace = getSourceMemorySpace(); 193 auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace()); 194 if (srcMemorySpace != tdescMemorySpace) 195 return emitOpError("Memory space mismatch.") 196 << " Source: " << srcMemorySpace 197 << ", TensorDesc: " << tdescMemorySpace; 198 199 // check source type matches the rank if it is a memref. 200 // It also should have the same ElementType as TensorDesc. 201 auto memrefTy = dyn_cast<MemRefType>(getSourceType()); 202 if (memrefTy) { 203 invalidRank |= (memrefTy.getRank() != rank); 204 invalidElemTy |= memrefTy.getElementType() != getElementType(); 205 } 206 207 // mismatches among shape, strides, and offsets are 208 // already handeled by OffsetSizeAndStrideOpInterface. 209 // So they are not check here. 210 if (invalidRank) 211 return emitOpError( 212 "Expecting the rank of shape, strides, offsets, and source (if source " 213 "is a memref) should match with each other."); 214 215 // check result TensorDesc rank 216 invalidRank = (getType().getRank() > 2 || getType().getRank() > rank); 217 218 if (invalidRank) 219 return emitOpError( 220 "Expecting the TensorDesc rank is up to 2 and not greater than the " 221 "ranks of shape, strides, offsets or the memref source."); 222 223 if (invalidElemTy) 224 return emitOpError("TensorDesc should have the same element " 225 "type with the source if it is a memref.\n"); 226 227 if (getType().isScattered()) 228 return emitOpError("Expects a non-scattered TensorDesc.\n"); 229 230 if (getType().getRank() == 2 && 231 tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM)) 232 return emitOpError("SLM is not supported for 2D Block TensorDesc.\n"); 233 234 return success(); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // XeGPU_PrefetchNdOp 239 //===----------------------------------------------------------------------===// 240 LogicalResult PrefetchNdOp::verify() { 241 auto tdescTy = getTensorDescType(); 242 if (tdescTy.isScattered()) 243 return emitOpError("Expects a non-scattered TensorDesc.\n"); 244 245 if (!isReadHintOrNone(getL1HintAttr())) 246 return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 247 248 if (!isReadHintOrNone(getL2HintAttr())) 249 return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 250 251 if (!isReadHintOrNone(getL3HintAttr())) 252 return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 253 254 return success(); 255 } 256 257 //===----------------------------------------------------------------------===// 258 // XeGPU_LoadNdOp 259 //===----------------------------------------------------------------------===// 260 LogicalResult LoadNdOp::verify() { 261 auto tdescTy = getTensorDescType(); 262 auto valueTy = getType(); 263 264 if (tdescTy.getRank() > 2) 265 return emitOpError("Expecting a 1D/2D TensorDesc.\n"); 266 267 if (tdescTy.isScattered()) 268 return emitOpError("Expects a non-scattered TensorDesc.\n"); 269 270 if (!valueTy) 271 return emitOpError("Invalid result, it should be a VectorType.\n"); 272 273 if (!isReadHintOrNone(getL1HintAttr())) 274 return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 275 276 if (!isReadHintOrNone(getL2HintAttr())) 277 return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 278 279 if (!isReadHintOrNone(getL3HintAttr())) 280 return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 281 282 auto array_len = tdescTy.getArrayLength(); 283 auto tdescShape = getShapeOf(tdescTy); 284 auto valueShape = getShapeOf(valueTy); 285 286 if (getTranspose()) { 287 auto trans = getTranspose().value(); 288 289 // Make sure the transpose value is valid. 290 bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) { 291 return t >= 0 && t < tdescTy.getRank(); 292 }); 293 294 if (valid) 295 transpose(trans, tdescShape); 296 else 297 mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored."; 298 } 299 300 if (getPacked()) { 301 if (tdescTy.getRank() == 2) { 302 const int axis = 0; 303 auto vnni_factor = valueShape.back(); 304 tdescShape[axis] /= vnni_factor; 305 tdescShape.push_back(vnni_factor); 306 } else { 307 mlir::emitWarning(getLoc()) 308 << "Invalid Packed Attr. It is ignored (available for 2D " 309 "TensorDesc only)."; 310 } 311 } 312 313 if (array_len > 1) { 314 auto it = tdescShape.begin(); 315 tdescShape.insert(it, array_len); 316 } 317 auto sgMap = tdescTy.getSGMapAttr(); 318 319 if (!isArgShapesValid(tdescShape, valueShape, sgMap)) 320 return emitOpError() << "Result shape doesn't match TensorDesc shape." 321 << "The expected shape is " << makeString(tdescShape) 322 << ". But the given shape is " 323 << makeString(valueShape) << ".\n"; 324 return success(); 325 } 326 327 //===----------------------------------------------------------------------===// 328 // XeGPU_StoreNdOp 329 //===----------------------------------------------------------------------===// 330 LogicalResult StoreNdOp::verify() { 331 auto dstTy = getTensorDescType(); // Tile 332 auto valTy = getValueType(); // Vector 333 334 if (dstTy.getRank() > 2) 335 return emitOpError("Expecting a 1D/2D TensorDesc.\n"); 336 337 if (dstTy.isScattered()) 338 return emitOpError("Expects a non-scattered TensorDesc.\n"); 339 340 if (!valTy) 341 return emitOpError("Expecting a VectorType result.\n"); 342 343 if (!isWriteHintOrNone(getL1HintAttr())) 344 return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 345 346 if (!isWriteHintOrNone(getL2HintAttr())) 347 return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 348 349 if (!isWriteHintOrNone(getL3HintAttr())) 350 return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 351 352 auto tdescShape = getShapeOf(dstTy); 353 auto valueShape = getShapeOf(valTy); 354 auto sgMap = dstTy.getSGMapAttr(); 355 356 if (!isArgShapesValid(tdescShape, valueShape, sgMap)) 357 return emitOpError() << "Result shape doesn't match TensorDesc shape." 358 << "The expected shape is " << makeString(tdescShape) 359 << ". But the given shape is " 360 << makeString(valueShape) << ".\n"; 361 return success(); 362 } 363 364 //===----------------------------------------------------------------------===// 365 // XeGPU_UpdateNDOffsetOp 366 //===----------------------------------------------------------------------===// 367 LogicalResult UpdateNdOffsetOp::verify() { 368 auto ty = getTensorDescType(); 369 if (ty.isScattered()) 370 return emitOpError("Expects a non-scattered TensorDesc.\n"); 371 372 // number of offsets specified must match the rank of the tensor descriptor 373 if (ty.getRank() != (int64_t)getNumOffsets()) { 374 return emitOpError("Invalid number of offsets."); 375 } 376 return success(); 377 } 378 379 //===----------------------------------------------------------------------===// 380 // XeGPU_CreateDescOp 381 //===----------------------------------------------------------------------===// 382 383 void CreateDescOp::build(OpBuilder &builder, OperationState &state, 384 TensorDescType TensorDesc, Value source, 385 llvm::ArrayRef<OpFoldResult> offsets) { 386 auto loc = source.getLoc(); 387 int64_t size = static_cast<int64_t>(offsets.size()); 388 auto type = VectorType::get(size, builder.getIndexType()); 389 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); 390 auto offset = builder.create<vector::FromElementsOp>(loc, type, values); 391 build(builder, state, TensorDesc, source, offset); 392 } 393 394 void CreateDescOp::build(OpBuilder &builder, OperationState &state, 395 TensorDescType TensorDesc, Value source, 396 llvm::ArrayRef<int64_t> offsets) { 397 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets); 398 build(builder, state, TensorDesc, source, ofrs); 399 } 400 401 LogicalResult CreateDescOp::verify() { 402 auto tdescTy = getTensorDescType(); 403 404 if (getRankOf(getSource()) > 1) 405 return emitOpError( 406 "Expecting the source is a 1D memref or pointer (uint64_t)."); 407 408 if (!tdescTy.isScattered()) 409 return emitOpError("Expects a scattered TensorDesc.\n"); 410 411 // Memory space of created TensorDesc should match with the source. 412 // Both source and TensorDesc are considered for global memory by default, 413 // if the memory scope attr is not specified. If source is an integer, 414 // it is considered as ptr to global memory. 415 auto srcMemorySpace = getSourceMemorySpace(); 416 auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace()); 417 if (srcMemorySpace != tdescMemorySpace) 418 return emitOpError("Memory space mismatch.") 419 << " Source: " << srcMemorySpace 420 << ", TensorDesc: " << tdescMemorySpace; 421 422 auto chunkSize = tdescTy.getChunkSize(); 423 424 // check chunk_size 425 llvm::SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8, 426 16, 32, 64, 128, 256}; 427 if (!llvm::is_contained(supportedChunkSizes, chunkSize)) 428 return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, " 429 "8, 16, 32, 64, 128, or 256."); 430 431 // check total size 432 auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth(); 433 auto bitsPerLane = elemBits * chunkSize; 434 if (chunkSize > 1 && bitsPerLane % 32) { 435 // For 8-bit and 16-bit data, the hardware only supports chunk size of 1. 436 // For 32-bit data, the hardware can support larger larger chunk size. So 437 // we can bitcast 8-bit/16-bit data to 32-bit data for better performance. 438 // But this requires the total size is 32 bit aligned to make the 439 // optimization work. 440 return emitOpError( 441 "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned."); 442 } 443 444 auto lscConstraints = 512 * 8; // each access is upto 512 bytes. 445 if (elemBits * tdescTy.getNumElements() > lscConstraints) 446 return emitOpError("total access size (simd_lanes * chunk_size * " 447 "sizeof(elemTy)) is upto 512 bytes."); 448 449 SmallVector<int64_t> shape({(int64_t)getNumOffsets()}); 450 if (chunkSize != 1) 451 shape.push_back(chunkSize); 452 453 auto tdescShape = getShapeOf(tdescTy); 454 if (shape != tdescShape) 455 return emitOpError("Incorrect TensorDesc shape. ") 456 << "Expected is " << makeString(shape) << "\n"; 457 458 return success(); 459 } 460 461 //===----------------------------------------------------------------------===// 462 // XeGPU_PrefetchOp 463 //===----------------------------------------------------------------------===// 464 LogicalResult PrefetchOp::verify() { 465 auto tdescTy = getTensorDescType(); 466 if (!tdescTy.isScattered()) 467 return emitOpError("Expects a scattered TensorDesc.\n"); 468 469 if (!isReadHintOrNone(getL1HintAttr())) 470 return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 471 472 if (!isReadHintOrNone(getL2HintAttr())) 473 return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 474 475 if (!isReadHintOrNone(getL3HintAttr())) 476 return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 477 478 return success(); 479 } 480 481 //===----------------------------------------------------------------------===// 482 // XeGPU_LoadGatherOp 483 //===----------------------------------------------------------------------===// 484 LogicalResult LoadGatherOp::verify() { 485 auto tdescTy = getTensorDescType(); 486 auto maskTy = getMaskType(); 487 auto valueTy = getValueType(); 488 489 if (!tdescTy.isScattered()) 490 return emitOpError("Expects a scattered TensorDesc.\n"); 491 492 if (!isReadHintOrNone(getL1HintAttr())) 493 return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 494 495 if (!isReadHintOrNone(getL2HintAttr())) 496 return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 497 498 if (!isReadHintOrNone(getL3HintAttr())) 499 return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 500 501 auto tdescElemTy = tdescTy.getElementType(); 502 auto valueElemTy = getElementType(); 503 if (tdescElemTy != valueElemTy) 504 return emitOpError( 505 "Value should have the same element type as TensorDesc."); 506 507 auto maskShape = getShapeOf(maskTy); 508 auto valueShape = getShapeOf(valueTy); 509 auto tdescShape = getShapeOf(tdescTy); 510 511 if (tdescShape[0] != maskShape[0]) 512 return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); 513 514 if (tdescTy.getRank() == 2) { 515 if (!getTransposeAttr()) 516 return emitOpError("load_gather has to be transposed."); 517 transpose({1, 0}, tdescShape); 518 } 519 520 if (valueShape != tdescShape) 521 return emitOpError("Unexpected result shape") 522 << "(Expected shape: " << makeString(tdescShape) 523 << ", Given shape: " << makeString(valueShape) << ").\n"; 524 525 return success(); 526 } 527 528 //===----------------------------------------------------------------------===// 529 // XeGPU_StoreScatterOp 530 //===----------------------------------------------------------------------===// 531 LogicalResult StoreScatterOp::verify() { 532 auto tdescTy = getTensorDescType(); 533 if (!tdescTy.isScattered()) 534 return emitOpError("Expects a scattered TensorDesc.\n"); 535 536 if (!isWriteHintOrNone(getL1HintAttr())) 537 return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 538 539 if (!isWriteHintOrNone(getL2HintAttr())) 540 return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 541 542 if (!isWriteHintOrNone(getL3HintAttr())) 543 return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 544 545 auto maskTy = getMaskType(); 546 auto valueTy = getValueType(); 547 auto maskShape = getShapeOf(maskTy); 548 auto tdescShape = getShapeOf(tdescTy); 549 auto valueShape = getShapeOf(valueTy); 550 if (tdescShape[0] != maskShape[0]) 551 return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); 552 553 if (tdescTy.getRank() == 2) { 554 if (!getTransposeAttr()) 555 return emitOpError("load_gather has to be transposed."); 556 transpose({1, 0}, tdescShape); 557 } 558 559 if (valueShape != tdescShape) 560 return emitOpError("Unexpected value shape") 561 << "(Expected shape: " << makeString(tdescShape) 562 << ", Given shape: " << makeString(valueShape) << ").\n"; 563 564 return success(); 565 } 566 567 //===----------------------------------------------------------------------===// 568 // XeGPU_UpdateOffsetOp 569 //===----------------------------------------------------------------------===// 570 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, 571 mlir::Value tensorDesc, 572 llvm::ArrayRef<OpFoldResult> offsets) { 573 auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType()); 574 assert(tdescTy && "Expecting the source is a TensorDescType value."); 575 auto loc = tensorDesc.getLoc(); 576 int64_t size = static_cast<int64_t>(offsets.size()); 577 auto type = VectorType::get({size}, builder.getIndexType()); 578 auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); 579 auto offset = builder.create<vector::FromElementsOp>(loc, type, values); 580 build(builder, state, tdescTy, tensorDesc, offset); 581 } 582 583 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, 584 Value tensorDesc, llvm::ArrayRef<int64_t> offsets) { 585 auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets); 586 build(builder, state, tensorDesc, ofrs); 587 } 588 589 //===----------------------------------------------------------------------===// 590 // XeGPU_DpasOp 591 //===----------------------------------------------------------------------===// 592 LogicalResult DpasOp::verify() { 593 int64_t lhsRank = getLhsType().getRank(); 594 int64_t rhsRank = getRhsType().getRank(); 595 596 if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3)) 597 return emitOpError("expecting lhs to be a 2D vector, and rhs to be either " 598 "2D or 3D (packed) vector."); 599 600 auto lhsShape = getLhsType().getShape(); 601 auto rhsShape = getRhsType().getShape(); 602 auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0]; 603 if (bK != lhsShape[1]) 604 return emitOpError("K-dimension mismatch."); 605 606 return success(); 607 } 608 609 } // namespace xegpu 610 } // namespace mlir 611 612 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc> 613 #define GET_OP_CLASSES 614 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc> 615