xref: /llvm-project/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (revision ba6774f997ee28157b0a3b8816cc76b94ed1da17)
1 //===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/Utils/Utils.h"
10 #include "mlir/Dialect/Utils/StaticValueUtils.h"
11 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/TypeUtilities.h"
14 
15 #include "llvm/Support/Debug.h"
16 
17 #define DEBUG_TYPE "xegpu"
18 
19 namespace mlir {
20 namespace xegpu {
21 
22 static void transpose(llvm::ArrayRef<int64_t> trans,
23                       SmallVector<int64_t> &shape) {
24   SmallVector<int64_t> old = shape;
25   for (size_t i = 0; i < trans.size(); i++)
26     shape[i] = old[trans[i]];
27 }
28 
29 template <typename T>
30 static std::string makeString(T array, bool breakline = false) {
31   std::string buf;
32   buf.clear();
33   llvm::raw_string_ostream os(buf);
34   os << "[";
35   for (size_t i = 1; i < array.size(); i++) {
36     os << array[i - 1] << ", ";
37     if (breakline)
38       os << "\n\t\t";
39   }
40   os << array.back() << "]";
41   return buf;
42 }
43 
44 static SmallVector<int64_t> getShapeOf(Type type) {
45   SmallVector<int64_t> shape;
46   if (auto ty = llvm::dyn_cast<ShapedType>(type))
47     shape = SmallVector<int64_t>(ty.getShape());
48   else
49     shape.push_back(1);
50   return shape;
51 }
52 
53 static int64_t getRankOf(Value val) {
54   auto type = val.getType();
55   if (auto ty = llvm::dyn_cast<ShapedType>(type))
56     return ty.getRank();
57   return 0;
58 }
59 
60 static bool isReadHintOrNone(const CachePolicyAttr &attr) {
61   if (!attr)
62     return true;
63   auto kind = attr.getValue();
64   return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
65          kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
66 }
67 
68 static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
69   if (!attr)
70     return true;
71   auto kind = attr.getValue();
72   return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
73          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
74 }
75 
76 // Validations for nd instruction arguments is successful if any of these are
77 // true:
78 // - tensor descriptor and the output vector shapes exactly match.
79 // - tensor descriptor has a sg_map attribute and the distributed vector shape
80 //   matches the tensor descriptor shape when scaled using sg_map factors on
81 //   each dimension.
82 static bool isArgShapesValid(ArrayRef<int64_t> descShape,
83                              ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
84   if (descShape == valShape) {
85     if (!sgMap)
86       return true;
87 
88     // this can be relaxed if necessary by supporting non-2d shapes distribution
89     // until the constraints are defined this lives here instead of the tensor
90     // descriptor type.
91     return valShape.size() == sgMap.getWiLayout().size();
92   }
93 
94   if (!sgMap)
95     return false;
96 
97   if (valShape.size() != descShape.size())
98     return false;
99 
100   for (const auto &[factor, dim, expected] :
101        llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
102     if (factor * dim != expected)
103       return false;
104   }
105 
106   return true;
107 }
108 
109 //===----------------------------------------------------------------------===//
110 // XeGPU_CreateNdDescOp
111 //===----------------------------------------------------------------------===//
112 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
113                            Type tdesc, TypedValue<MemRefType> source,
114                            llvm::ArrayRef<OpFoldResult> offsets) {
115   [[maybe_unused]] auto ty = source.getType();
116   assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());
117 
118   llvm::SmallVector<int64_t> staticOffsets;
119   llvm::SmallVector<Value> dynamicOffsets;
120   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
121 
122   build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
123         ValueRange({}) /* empty dynamic shape */,
124         ValueRange({}) /* empty dynamic strides */,
125         staticOffsets /* const offsets */, {} /* empty const shape*/,
126         {} /* empty const strides*/);
127 }
128 
129 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
130                            Type tdesc, TypedValue<MemRefType> source,
131                            llvm::ArrayRef<OpFoldResult> offsets,
132                            llvm::ArrayRef<OpFoldResult> shape,
133                            llvm::ArrayRef<OpFoldResult> strides) {
134   assert(shape.size() && offsets.size() && strides.size() &&
135          shape.size() == strides.size() && shape.size() == offsets.size());
136 
137   llvm::SmallVector<int64_t> staticOffsets;
138   llvm::SmallVector<int64_t> staticShape;
139   llvm::SmallVector<int64_t> staticStrides;
140   llvm::SmallVector<Value> dynamicOffsets;
141   llvm::SmallVector<Value> dynamicShape;
142   llvm::SmallVector<Value> dynamicStrides;
143 
144   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
145   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
146   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
147 
148   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
149   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
150   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
151 
152   build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
153         dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
154 }
155 
// Builder for a descriptor over a raw pointer (integer source). Since an
// integer carries no shape information, offsets, shape, and strides must all
// be supplied explicitly.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<IntegerType> source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  // All three lists must be non-empty and of equal length.
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  // Split each mixed (attribute-or-value) list into static ints and SSA
  // values; constants become ShapedType::kDynamic placeholders on the other
  // side of the split.
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
182 
LogicalResult CreateNdDescOp::verify() {
  // The rank implied by the number of (static + dynamic) offsets; the source
  // memref and the result descriptor are checked against it below.
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;

  // Memory space of created TensorDesc should match with the source.
  // Both source and TensorDesc are considered for global memory by default,
  // if the memory scope attr is not specified. If source is an integer,
  // it is considered as ptr to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // check source type matches the rank if it is a memref.
  // It also should have the same ElementType as TensorDesc.
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
  if (memrefTy) {
    invalidRank |= (memrefTy.getRank() != rank);
    invalidElemTy |= memrefTy.getElementType() != getElementType();
  }

  // mismatches among shape, strides, and offsets are
  // already handled by OffsetSizeAndStrideOpInterface.
  // So they are not checked here.
  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  // check result TensorDesc rank: at most 2D, and never higher-ranked than
  // the offsets/source.
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

  if (invalidRank)
    return emitOpError(
        "Expecting the TensorDesc rank is up to 2 and not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type with the source if it is a memref.\n");

  // create_nd_tdesc only produces block (non-scattered) descriptors.
  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // 2D block descriptors cannot live in SLM (shared local memory).
  if (getType().getRank() == 2 &&
      tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
    return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");

  return success();
}
236 
237 //===----------------------------------------------------------------------===//
238 // XeGPU_PrefetchNdOp
239 //===----------------------------------------------------------------------===//
240 LogicalResult PrefetchNdOp::verify() {
241   auto tdescTy = getTensorDescType();
242   if (tdescTy.isScattered())
243     return emitOpError("Expects a non-scattered TensorDesc.\n");
244 
245   if (!isReadHintOrNone(getL1HintAttr()))
246     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
247 
248   if (!isReadHintOrNone(getL2HintAttr()))
249     return emitOpError("invalid l2_hint: ") << getL2HintAttr();
250 
251   if (!isReadHintOrNone(getL3HintAttr()))
252     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
253 
254   return success();
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // XeGPU_LoadNdOp
259 //===----------------------------------------------------------------------===//
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  // Cache hints on a load, if present, must be read kinds.
  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  // Derive the expected result shape by applying, in order, the transpose,
  // packed, and array_length adjustments to the descriptor shape, then
  // compare against the actual result shape.
  auto array_len = tdescTy.getArrayLength();
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();

    // Make sure the transpose value is valid: every index must be within the
    // descriptor's rank.
    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
      return t >= 0 && t < tdescTy.getRank();
    });

    if (valid)
      transpose(trans, tdescShape);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      // Packed layout folds rows into the innermost dimension: dim-0 shrinks
      // by the vnni factor (taken from the result's innermost dim) and a new
      // innermost dim of that size is appended.
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }

  // An array_length > 1 loads several consecutive blocks; this prepends a
  // dimension of that size to the expected shape.
  if (array_len > 1) {
    auto it = tdescShape.begin();
    tdescShape.insert(it, array_len);
  }
  auto sgMap = tdescTy.getSGMapAttr();

  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
    return emitOpError() << "Result shape doesn't match TensorDesc shape."
                         << "The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
                         << makeString(valueShape) << ".\n";
  return success();
}
326 
327 //===----------------------------------------------------------------------===//
328 // XeGPU_StoreNdOp
329 //===----------------------------------------------------------------------===//
330 LogicalResult StoreNdOp::verify() {
331   auto dstTy = getTensorDescType(); // Tile
332   auto valTy = getValueType();      // Vector
333 
334   if (dstTy.getRank() > 2)
335     return emitOpError("Expecting a 1D/2D TensorDesc.\n");
336 
337   if (dstTy.isScattered())
338     return emitOpError("Expects a non-scattered TensorDesc.\n");
339 
340   if (!valTy)
341     return emitOpError("Expecting a VectorType result.\n");
342 
343   if (!isWriteHintOrNone(getL1HintAttr()))
344     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
345 
346   if (!isWriteHintOrNone(getL2HintAttr()))
347     return emitOpError("invalid l2_hint: ") << getL2HintAttr();
348 
349   if (!isWriteHintOrNone(getL3HintAttr()))
350     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
351 
352   auto tdescShape = getShapeOf(dstTy);
353   auto valueShape = getShapeOf(valTy);
354   auto sgMap = dstTy.getSGMapAttr();
355 
356   if (!isArgShapesValid(tdescShape, valueShape, sgMap))
357     return emitOpError() << "Result shape doesn't match TensorDesc shape."
358                          << "The expected shape is " << makeString(tdescShape)
359                          << ". But the given shape is "
360                          << makeString(valueShape) << ".\n";
361   return success();
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // XeGPU_UpdateNDOffsetOp
366 //===----------------------------------------------------------------------===//
367 LogicalResult UpdateNdOffsetOp::verify() {
368   auto ty = getTensorDescType();
369   if (ty.isScattered())
370     return emitOpError("Expects a non-scattered TensorDesc.\n");
371 
372   // number of offsets specified must match the rank of the tensor descriptor
373   if (ty.getRank() != (int64_t)getNumOffsets()) {
374     return emitOpError("Invalid number of offsets.");
375   }
376   return success();
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // XeGPU_CreateDescOp
381 //===----------------------------------------------------------------------===//
382 
383 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
384                          TensorDescType TensorDesc, Value source,
385                          llvm::ArrayRef<OpFoldResult> offsets) {
386   auto loc = source.getLoc();
387   int64_t size = static_cast<int64_t>(offsets.size());
388   auto type = VectorType::get(size, builder.getIndexType());
389   auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
390   auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
391   build(builder, state, TensorDesc, source, offset);
392 }
393 
394 void CreateDescOp::build(OpBuilder &builder, OperationState &state,
395                          TensorDescType TensorDesc, Value source,
396                          llvm::ArrayRef<int64_t> offsets) {
397   auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
398   build(builder, state, TensorDesc, source, ofrs);
399 }
400 
LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  // The offsets index a flat buffer, so the source is at most rank 1.
  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source is a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Memory space of created TensorDesc should match with the source.
  // Both source and TensorDesc are considered for global memory by default,
  // if the memory scope attr is not specified. If source is an integer,
  // it is considered as ptr to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  auto chunkSize = tdescTy.getChunkSize();

  // check chunk_size against the fixed set the op accepts.
  llvm::SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
                                                    16, 32, 64, 128, 256};
  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
                       "8, 16, 32, 64, 128, or 256.");

  // check total size
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
    // For 32-bit data, the hardware can support larger chunk size. So
    // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
    // But this requires the total size is 32 bit aligned to make the
    // optimization work.
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }

  auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) is upto 512 bytes.");

  // The descriptor shape must be [num_offsets] (chunk_size == 1) or
  // [num_offsets, chunk_size].
  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}
460 
461 //===----------------------------------------------------------------------===//
462 // XeGPU_PrefetchOp
463 //===----------------------------------------------------------------------===//
464 LogicalResult PrefetchOp::verify() {
465   auto tdescTy = getTensorDescType();
466   if (!tdescTy.isScattered())
467     return emitOpError("Expects a scattered TensorDesc.\n");
468 
469   if (!isReadHintOrNone(getL1HintAttr()))
470     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
471 
472   if (!isReadHintOrNone(getL2HintAttr()))
473     return emitOpError("invalid l2_hint: ") << getL2HintAttr();
474 
475   if (!isReadHintOrNone(getL3HintAttr()))
476     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
477 
478   return success();
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // XeGPU_LoadGatherOp
483 //===----------------------------------------------------------------------===//
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Cache hints on a gather, if present, must be read kinds.
  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto tdescElemTy = tdescTy.getElementType();
  auto valueElemTy = getElementType();
  if (tdescElemTy != valueElemTy)
    return emitOpError(
        "Value should have the same element type as TensorDesc.");

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  // One mask bit per descriptor row (one per offset/lane).
  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  // A 2D (chunked) gather must carry the transpose attribute; the expected
  // result shape is the transposed descriptor shape.
  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("load_gather has to be transposed.");
    transpose({1, 0}, tdescShape);
  }

  if (valueShape != tdescShape)
    return emitOpError("Unexpected result shape")
           << "(Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";

  return success();
}
527 
528 //===----------------------------------------------------------------------===//
529 // XeGPU_StoreScatterOp
530 //===----------------------------------------------------------------------===//
531 LogicalResult StoreScatterOp::verify() {
532   auto tdescTy = getTensorDescType();
533   if (!tdescTy.isScattered())
534     return emitOpError("Expects a scattered TensorDesc.\n");
535 
536   if (!isWriteHintOrNone(getL1HintAttr()))
537     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
538 
539   if (!isWriteHintOrNone(getL2HintAttr()))
540     return emitOpError("invalid l2_hint: ") << getL2HintAttr();
541 
542   if (!isWriteHintOrNone(getL3HintAttr()))
543     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
544 
545   auto maskTy = getMaskType();
546   auto valueTy = getValueType();
547   auto maskShape = getShapeOf(maskTy);
548   auto tdescShape = getShapeOf(tdescTy);
549   auto valueShape = getShapeOf(valueTy);
550   if (tdescShape[0] != maskShape[0])
551     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
552 
553   if (tdescTy.getRank() == 2) {
554     if (!getTransposeAttr())
555       return emitOpError("load_gather has to be transposed.");
556     transpose({1, 0}, tdescShape);
557   }
558 
559   if (valueShape != tdescShape)
560     return emitOpError("Unexpected value shape")
561            << "(Expected shape: " << makeString(tdescShape)
562            << ", Given shape: " << makeString(valueShape) << ").\n";
563 
564   return success();
565 }
566 
567 //===----------------------------------------------------------------------===//
568 // XeGPU_UpdateOffsetOp
569 //===----------------------------------------------------------------------===//
570 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
571                            mlir::Value tensorDesc,
572                            llvm::ArrayRef<OpFoldResult> offsets) {
573   auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
574   assert(tdescTy && "Expecting the source is a TensorDescType value.");
575   auto loc = tensorDesc.getLoc();
576   int64_t size = static_cast<int64_t>(offsets.size());
577   auto type = VectorType::get({size}, builder.getIndexType());
578   auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
579   auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
580   build(builder, state, tdescTy, tensorDesc, offset);
581 }
582 
583 void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
584                            Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
585   auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
586   build(builder, state, tensorDesc, ofrs);
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // XeGPU_DpasOp
591 //===----------------------------------------------------------------------===//
592 LogicalResult DpasOp::verify() {
593   int64_t lhsRank = getLhsType().getRank();
594   int64_t rhsRank = getRhsType().getRank();
595 
596   if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
597     return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
598                        "2D or 3D (packed) vector.");
599 
600   auto lhsShape = getLhsType().getShape();
601   auto rhsShape = getRhsType().getShape();
602   auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
603   if (bK != lhsShape[1])
604     return emitOpError("K-dimension mismatch.");
605 
606   return success();
607 }
608 
609 } // namespace xegpu
610 } // namespace mlir
611 
612 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
613 #define GET_OP_CLASSES
614 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
615