//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {

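// Applies the permutation `trans` to `shape` in place, i.e.,
// shape[i] = old_shape[trans[i]]. For example, trans = [1, 0] turns a shape
// of [8, 16] into [16, 8].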
static void transpose(llvm::ArrayRef<int64_t> trans,
                      SmallVector<int64_t> &shape) {
  SmallVector<int64_t> old = shape;
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
}

template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  buf.clear();
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  os << array.back() << "]";
  return buf;
}

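// Returns the shape of `type` if it is a ShapedType; otherwise treats the
// value as a scalar and returns {1}.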
static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}

static int64_t getRankOf(Value val) {
  auto type = val.getType();
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    return ty.getRank();
  return 0;
}

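// The two helpers below accept an absent cache hint, or a hint that is legal
// for the given direction: reads allow CACHED, UNCACHED, STREAMING, and
// READ_INVALIDATE; writes allow CACHED, UNCACHED, WRITE_BACK, and
// WRITE_THROUGH.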
static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}

// Validation of nd instruction arguments succeeds if any of the following
// holds:
// - the tensor descriptor and the output vector shapes exactly match;
// - the tensor descriptor has an sg_map attribute and the distributed vector
//   shape matches the tensor descriptor shape when scaled by the sg_map
//   factors on each dimension.
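// For example (illustrative numbers): with descShape = [8, 16] and an sg_map
// whose wi_layout is [1, 16], a distributed valShape of [8, 1] is accepted,
// since 1 * 8 == 8 and 16 * 1 == 16.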
static bool isArgShapesValid(ArrayRef<int64_t> descShape,
                             ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
  if (descShape == valShape) {
    if (!sgMap)
      return true;

    // This can be relaxed if necessary by supporting the distribution of
    // non-2D shapes; until those constraints are defined, this check lives
    // here instead of in the tensor descriptor type.
    return valShape.size() == sgMap.getWiLayout().size();
  }

  if (!sgMap)
    return false;

  if (valShape.size() != descShape.size())
    return false;

  for (const auto &[factor, dim, expected] :
       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
    if (factor * dim != expected)
      return false;
  }

  return true;
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
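// Illustrative use of the builders below (hypothetical caller code, assuming
// an OpBuilder `b`, a location `loc`, a result type `tdescTy`, and a
// static-shaped memref value `src` of type TypedValue<MemRefType>):
//
//   SmallVector<OpFoldResult> offsets = {b.getIndexAttr(0), b.getIndexAttr(0)};
//   auto desc = b.create<xegpu::CreateNdDescOp>(loc, tdescTy, src, offsets);
//
// dispatchIndexOpFoldResults() splits the OpFoldResult lists so that constant
// entries end up in the static attributes while SSA values are forwarded as
// dynamic operands.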
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        staticOffsets /* const offsets */, {} /* empty const shape */,
        {} /* empty const strides */);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<IntegerType> source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

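// Schematic example of an op that passes the checks below (shown for
// illustration; see the XeGPU op definitions for the authoritative syntax):
//
//   %tdesc = xegpu.create_nd_tdesc %src[0, 0]
//       : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
//
// Here the source is a rank-2 memref, the element types match, both sides use
// the default (global) memory space, and the TensorDesc rank is at most 2.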
LogicalResult CreateNdDescOp::verify() {
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;

  // The memory space of the created TensorDesc should match that of the
  // source. Both the source and the TensorDesc are treated as global memory
  // by default if the memory space attribute is not specified. If the source
  // is an integer, it is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check that the source rank matches if it is a memref. The source should
  // also have the same element type as the TensorDesc.
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
  if (memrefTy) {
    invalidRank |= (memrefTy.getRank() != rank);
    invalidElemTy |= memrefTy.getElementType() != getElementType();
  }

  // Mismatches among shape, strides, and offsets are already handled by
  // OffsetSizeAndStrideOpInterface, so they are not checked here.
  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) to match each other.");

  // check result TensorDesc rank
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

  if (invalidRank)
    return emitOpError(
        "Expecting the TensorDesc rank to be at most 2 and not greater than "
        "the ranks of shape, strides, offsets, or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type as the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (getType().getRank() == 2 &&
      tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
    return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = tdescTy.getArrayLength();
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();

    // Make sure the transpose value is valid.
    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
      return t >= 0 && t < tdescTy.getRank();
    });

    if (valid)
      transpose(trans, tdescShape);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }
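  // For example (illustrative): transpose = [1, 0] on a tensor_desc<16x8xf16>
  // turns tdescShape [16, 8] into [8, 16], which the loaded vector must then
  // match.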

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }
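  // Worked example (illustrative): loading a tensor_desc<16x16xf16> with
  // `packed` typically yields vector<8x16x2xf16>; here vnni_factor is 2, so
  // tdescShape [16, 16] becomes [8, 16, 2] before the shape comparison below.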

  if (array_len > 1) {
    auto it = tdescShape.begin();
    tdescShape.insert(it, array_len);
  }
  auto sgMap = tdescTy.getSGMapAttr();

  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
    return emitOpError() << "Result shape doesn't match TensorDesc shape. "
                         << "The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
                         << makeString(valueShape) << ".\n";
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType value.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  auto sgMap = dstTy.getSGMapAttr();

  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
    return emitOpError() << "Value shape doesn't match TensorDesc shape. "
                         << "The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
                         << makeString(valueShape) << ".\n";
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateNDOffsetOp
//===----------------------------------------------------------------------===//
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets specified must match the rank of the tensor
  // descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//

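// The builders below lower the ArrayRef-of-offsets form to the op's single
// vector operand: each OpFoldResult is materialized as an index value
// (constants via arith.constant) and the values are packed into a
// vector<Nxindex> with vector.from_elements. Illustratively, offsets
// {0, 16, 32, 48} become a vector<4xindex> built from four index constants.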
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}

LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source to be a 1D memref or a pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // The memory space of the created TensorDesc should match that of the
  // source. Both the source and the TensorDesc are treated as global memory
  // by default if the memory space attribute is not specified. If the source
  // is an integer, it is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  auto chunkSize = tdescTy.getChunkSize();

  // check chunk_size
  llvm::SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
                                                    16, 32, 64, 128, 256};
  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
                       "8, 16, 32, 64, 128, or 256.");

  // check total size
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    // For 8-bit and 16-bit data, the hardware only supports a chunk size of 1.
    // For 32-bit data, the hardware can support a larger chunk size, so
    // 8-bit/16-bit data can be bitcast to 32-bit data for better performance.
    // This optimization only works when the total access size per lane is
    // 32-bit aligned.
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }

  auto lscConstraints = 512 * 8; // Each access is up to 512 bytes.
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) is up to 512 bytes.");
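
  // Worked example for the two size checks above (illustrative numbers): for
  // f16 elements (16 bits), chunk_size 4 gives 64 bits per lane, which is
  // 32-bit aligned and accepted, while chunk_size 3 gives 48 bits and is
  // rejected. A tensor_desc<16x8xf32> accesses 16 * 8 * 4 = 512 bytes, which
  // is exactly the LSC limit.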

  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto tdescElemTy = tdescTy.getElementType();
  auto valueElemTy = getElementType();
  if (tdescElemTy != valueElemTy)
    return emitOpError(
        "Value should have the same element type as TensorDesc.");

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("load_gather has to be transposed.");
    transpose({1, 0}, tdescShape);
  }
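  // For example (illustrative): a scattered tensor_desc with shape [16, 8]
  // (16 offsets, chunk_size 8) and a mask of vector<16xi1> must be loaded
  // with the transpose attribute, and the result is then expected to have
  // shape [8, 16].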

  if (valueShape != tdescShape)
    return emitOpError("Unexpected result shape ")
           << "(Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto maskTy = getMaskType();
  auto valueTy = getValueType();
  auto maskShape = getShapeOf(maskTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);
  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("store_scatter has to be transposed.");
    transpose({1, 0}, tdescShape);
  }

  if (valueShape != tdescShape)
    return emitOpError("Unexpected value shape ")
           << "(Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source to be a TensorDescType value.");
  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();

  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
                       "a 2D or a 3D (packed) vector.");

  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");
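
  // For example (illustrative): with lhs = vector<8x16xf16> and a packed
  // rhs = vector<8x16x2xf16>, bK = 8 * 2 = 16, which matches lhsShape[1].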

  return success();
}

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>