15669660fSChao Chen //===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===// 25669660fSChao Chen // 35669660fSChao Chen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45669660fSChao Chen // See https://llvm.org/LICENSE.txt for license information. 55669660fSChao Chen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65669660fSChao Chen // 75669660fSChao Chen //===----------------------------------------------------------------------===// 85669660fSChao Chen 99c697b3aSChao Chen #include "mlir/Dialect/Arith/Utils/Utils.h" 1061b24c61SChao Chen #include "mlir/Dialect/Utils/StaticValueUtils.h" 1161b24c61SChao Chen #include "mlir/Dialect/XeGPU/IR/XeGPU.h" 1261b24c61SChao Chen #include "mlir/IR/Builders.h" 13b01879ecSChao Chen #include "mlir/IR/TypeUtilities.h" 14b01879ecSChao Chen 15b01879ecSChao Chen #include "llvm/Support/Debug.h" 165669660fSChao Chen 175669660fSChao Chen #define DEBUG_TYPE "xegpu" 185669660fSChao Chen 195669660fSChao Chen namespace mlir { 205669660fSChao Chen namespace xegpu { 21daebe5c4SChao Chen 2261b24c61SChao Chen static void transpose(llvm::ArrayRef<int64_t> trans, 23b01879ecSChao Chen SmallVector<int64_t> &shape) { 24b01879ecSChao Chen SmallVector<int64_t> old = shape; 2561b24c61SChao Chen for (size_t i = 0; i < trans.size(); i++) 2661b24c61SChao Chen shape[i] = old[trans[i]]; 2761b24c61SChao Chen } 2861b24c61SChao Chen 2961b24c61SChao Chen template <typename T> 3061b24c61SChao Chen static std::string makeString(T array, bool breakline = false) { 3161b24c61SChao Chen std::string buf; 3261b24c61SChao Chen buf.clear(); 3361b24c61SChao Chen llvm::raw_string_ostream os(buf); 3461b24c61SChao Chen os << "["; 3561b24c61SChao Chen for (size_t i = 1; i < array.size(); i++) { 3661b24c61SChao Chen os << array[i - 1] << ", "; 3761b24c61SChao Chen if (breakline) 3861b24c61SChao Chen os << "\n\t\t"; 3961b24c61SChao Chen } 4061b24c61SChao Chen os << array.back() << "]"; 4161b24c61SChao Chen return buf; 4261b24c61SChao Chen } 4361b24c61SChao Chen 44b01879ecSChao Chen static SmallVector<int64_t> getShapeOf(Type type) { 45b01879ecSChao Chen SmallVector<int64_t> shape; 46b01879ecSChao Chen if (auto ty = llvm::dyn_cast<ShapedType>(type)) 47b01879ecSChao Chen shape = SmallVector<int64_t>(ty.getShape()); 48b01879ecSChao Chen else 49b01879ecSChao Chen shape.push_back(1); 50b01879ecSChao Chen return shape; 51b01879ecSChao Chen } 52b01879ecSChao Chen 53b01879ecSChao Chen static int64_t getRankOf(Value val) { 54b01879ecSChao Chen auto type = val.getType(); 55b01879ecSChao Chen if (auto ty = llvm::dyn_cast<ShapedType>(type)) 56b01879ecSChao Chen return ty.getRank(); 57b01879ecSChao Chen return 0; 58c9731a3dSKazu Hirata } 59b01879ecSChao Chen 60b01879ecSChao Chen static bool isReadHintOrNone(const CachePolicyAttr &attr) { 61b01879ecSChao Chen if (!attr) 62b01879ecSChao Chen return true; 63b01879ecSChao Chen auto kind = attr.getValue(); 64b01879ecSChao Chen return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || 65b01879ecSChao Chen kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE; 66b01879ecSChao Chen } 67b01879ecSChao Chen 68b01879ecSChao Chen static bool isWriteHintOrNone(const CachePolicyAttr &attr) { 69b01879ecSChao Chen if (!attr) 70b01879ecSChao Chen return true; 71b01879ecSChao Chen auto kind = attr.getValue(); 72b01879ecSChao Chen return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || 73b01879ecSChao Chen kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH; 74b01879ecSChao Chen } 75b01879ecSChao Chen 76fa6f88afSPetr Kurapov // Validations for nd instruction arguments is successful if any of these are 77fa6f88afSPetr Kurapov // true: 78fa6f88afSPetr Kurapov // - tensor descriptor and the output vector shapes exactly match. 79fa6f88afSPetr Kurapov // - tensor descriptor has a sg_map attribute and the distributed vector shape 80fa6f88afSPetr Kurapov // matches the tensor descriptor shape when scaled using sg_map factors on 81fa6f88afSPetr Kurapov // each dimension. 82fa6f88afSPetr Kurapov static bool isArgShapesValid(ArrayRef<int64_t> descShape, 83fa6f88afSPetr Kurapov ArrayRef<int64_t> valShape, SGMapAttr sgMap) { 84fa6f88afSPetr Kurapov if (descShape == valShape) { 85fa6f88afSPetr Kurapov if (!sgMap) 86fa6f88afSPetr Kurapov return true; 87fa6f88afSPetr Kurapov 88fa6f88afSPetr Kurapov // this can be relaxed if necessary by supporting non-2d shapes distribution 89fa6f88afSPetr Kurapov // until the constraints are defined this lives here instead of the tensor 90fa6f88afSPetr Kurapov // descriptor type. 91fa6f88afSPetr Kurapov return valShape.size() == sgMap.getWiLayout().size(); 92fa6f88afSPetr Kurapov } 93fa6f88afSPetr Kurapov 94fa6f88afSPetr Kurapov if (!sgMap) 95fa6f88afSPetr Kurapov return false; 96fa6f88afSPetr Kurapov 97fa6f88afSPetr Kurapov if (valShape.size() != descShape.size()) 98fa6f88afSPetr Kurapov return false; 99fa6f88afSPetr Kurapov 100fa6f88afSPetr Kurapov for (const auto &[factor, dim, expected] : 101fa6f88afSPetr Kurapov llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) { 102fa6f88afSPetr Kurapov if (factor * dim != expected) 103fa6f88afSPetr Kurapov return false; 104fa6f88afSPetr Kurapov } 105fa6f88afSPetr Kurapov 106fa6f88afSPetr Kurapov return true; 107fa6f88afSPetr Kurapov } 108fa6f88afSPetr Kurapov 10961b24c61SChao Chen //===----------------------------------------------------------------------===// 11061b24c61SChao Chen // XeGPU_CreateNdDescOp 11161b24c61SChao Chen //===----------------------------------------------------------------------===// 11261b24c61SChao Chen void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, 11361b24c61SChao Chen Type tdesc, TypedValue<MemRefType> source, 11461b24c61SChao Chen llvm::ArrayRef<OpFoldResult> offsets) { 115258091e7SJie Fu [[maybe_unused]] auto ty = source.getType(); 11661b24c61SChao Chen assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank()); 11761b24c61SChao Chen 11861b24c61SChao Chen llvm::SmallVector<int64_t> staticOffsets; 11961b24c61SChao Chen llvm::SmallVector<Value> dynamicOffsets; 12061b24c61SChao Chen dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 12161b24c61SChao Chen 12261b24c61SChao Chen build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */, 12361b24c61SChao Chen ValueRange({}) /* empty dynamic shape */, 12461b24c61SChao Chen ValueRange({}) /* empty dynamic strides */, 12561b24c61SChao Chen staticOffsets /* const offsets */, {} /* empty const shape*/, 12661b24c61SChao Chen {} /* empty const strides*/); 12761b24c61SChao Chen } 12861b24c61SChao Chen 12961b24c61SChao Chen void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, 13027046badSMd Abdullah Shahneous Bari Type tdesc, TypedValue<MemRefType> source, 13127046badSMd Abdullah Shahneous Bari llvm::ArrayRef<OpFoldResult> offsets, 13227046badSMd Abdullah Shahneous Bari llvm::ArrayRef<OpFoldResult> shape, 13327046badSMd Abdullah Shahneous Bari llvm::ArrayRef<OpFoldResult> strides) { 13427046badSMd Abdullah Shahneous Bari assert(shape.size() && offsets.size() && strides.size() && 13527046badSMd Abdullah Shahneous Bari shape.size() == strides.size() && shape.size() == offsets.size()); 13627046badSMd Abdullah Shahneous Bari 13727046badSMd Abdullah Shahneous Bari llvm::SmallVector<int64_t> staticOffsets; 13827046badSMd Abdullah Shahneous Bari llvm::SmallVector<int64_t> staticShape; 13927046badSMd Abdullah Shahneous Bari llvm::SmallVector<int64_t> staticStrides; 14027046badSMd Abdullah Shahneous Bari llvm::SmallVector<Value> dynamicOffsets; 14127046badSMd Abdullah Shahneous Bari llvm::SmallVector<Value> dynamicShape; 14227046badSMd Abdullah Shahneous Bari llvm::SmallVector<Value> dynamicStrides; 14327046badSMd Abdullah Shahneous Bari 14427046badSMd Abdullah Shahneous Bari dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 14527046badSMd Abdullah Shahneous Bari dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); 14627046badSMd Abdullah Shahneous Bari dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 14727046badSMd Abdullah Shahneous Bari 14827046badSMd Abdullah Shahneous Bari auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); 14927046badSMd Abdullah Shahneous Bari auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); 15027046badSMd Abdullah Shahneous Bari auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); 15127046badSMd Abdullah Shahneous Bari 15227046badSMd Abdullah Shahneous Bari build(builder, state, tdesc, source, dynamicOffsets, dynamicShape, 15327046badSMd Abdullah Shahneous Bari dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr); 15427046badSMd Abdullah Shahneous Bari } 15527046badSMd Abdullah Shahneous Bari 15627046badSMd Abdullah Shahneous Bari void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, 15761b24c61SChao Chen Type tdesc, TypedValue<IntegerType> source, 15861b24c61SChao Chen llvm::ArrayRef<OpFoldResult> offsets, 15961b24c61SChao Chen llvm::ArrayRef<OpFoldResult> shape, 16061b24c61SChao Chen llvm::ArrayRef<OpFoldResult> strides) { 16161b24c61SChao Chen assert(shape.size() && offsets.size() && strides.size() && 16261b24c61SChao Chen shape.size() == strides.size() && shape.size() == offsets.size()); 16361b24c61SChao Chen 16461b24c61SChao Chen llvm::SmallVector<int64_t> staticOffsets; 16561b24c61SChao Chen llvm::SmallVector<int64_t> staticShape; 16661b24c61SChao Chen llvm::SmallVector<int64_t> staticStrides; 16761b24c61SChao Chen llvm::SmallVector<Value> dynamicOffsets; 16861b24c61SChao Chen llvm::SmallVector<Value> dynamicShape; 16961b24c61SChao Chen llvm::SmallVector<Value> dynamicStrides; 17061b24c61SChao Chen 17161b24c61SChao Chen dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); 17261b24c61SChao Chen dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); 17373a2fd47SArtem Kroviakov dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); 17461b24c61SChao Chen 17561b24c61SChao Chen auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); 17661b24c61SChao Chen auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); 17761b24c61SChao Chen auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); 17861b24c61SChao Chen 17961b24c61SChao Chen build(builder, state, tdesc, source, dynamicOffsets, dynamicShape, 18061b24c61SChao Chen dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr); 18161b24c61SChao Chen } 18261b24c61SChao Chen 18361b24c61SChao Chen LogicalResult CreateNdDescOp::verify() { 18461b24c61SChao Chen auto rank = (int64_t)getMixedOffsets().size(); 1856c783e19SChao Chen bool invalidRank = false; 18661b24c61SChao Chen bool invalidElemTy = false; 18761b24c61SChao Chen 1888b5e8414SChao Chen // Memory space of created TensorDesc should match with the source. 1898b5e8414SChao Chen // Both source and TensorDesc are considered for global memory by default, 1908b5e8414SChao Chen // if the memory scope attr is not specified. If source is an integer, 1918b5e8414SChao Chen // it is considered as ptr to global memory. 1928b5e8414SChao Chen auto srcMemorySpace = getSourceMemorySpace(); 1938b5e8414SChao Chen auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace()); 1948b5e8414SChao Chen if (srcMemorySpace != tdescMemorySpace) 1958b5e8414SChao Chen return emitOpError("Memory space mismatch.") 1968b5e8414SChao Chen << " Source: " << srcMemorySpace 1978b5e8414SChao Chen << ", TensorDesc: " << tdescMemorySpace; 1988b5e8414SChao Chen 19961b24c61SChao Chen // check source type matches the rank if it is a memref. 20061b24c61SChao Chen // It also should have the same ElementType as TensorDesc. 201a5757c5bSChristian Sigg auto memrefTy = dyn_cast<MemRefType>(getSourceType()); 20261b24c61SChao Chen if (memrefTy) { 20361b24c61SChao Chen invalidRank |= (memrefTy.getRank() != rank); 20461b24c61SChao Chen invalidElemTy |= memrefTy.getElementType() != getElementType(); 20561b24c61SChao Chen } 20661b24c61SChao Chen 20761b24c61SChao Chen // mismatches among shape, strides, and offsets are 20861b24c61SChao Chen // already handeled by OffsetSizeAndStrideOpInterface. 20961b24c61SChao Chen // So they are not check here. 21061b24c61SChao Chen if (invalidRank) 21161b24c61SChao Chen return emitOpError( 2126c783e19SChao Chen "Expecting the rank of shape, strides, offsets, and source (if source " 2136c783e19SChao Chen "is a memref) should match with each other."); 2146c783e19SChao Chen 2156c783e19SChao Chen // check result TensorDesc rank 2166c783e19SChao Chen invalidRank = (getType().getRank() > 2 || getType().getRank() > rank); 2176c783e19SChao Chen 2186c783e19SChao Chen if (invalidRank) 2196c783e19SChao Chen return emitOpError( 2206c783e19SChao Chen "Expecting the TensorDesc rank is up to 2 and not greater than the " 2216c783e19SChao Chen "ranks of shape, strides, offsets or the memref source."); 22261b24c61SChao Chen 22361b24c61SChao Chen if (invalidElemTy) 22461b24c61SChao Chen return emitOpError("TensorDesc should have the same element " 22561b24c61SChao Chen "type with the source if it is a memref.\n"); 22661b24c61SChao Chen 2278b5e8414SChao Chen if (getType().isScattered()) 228b01879ecSChao Chen return emitOpError("Expects a non-scattered TensorDesc.\n"); 229b01879ecSChao Chen 2308b5e8414SChao Chen if (getType().getRank() == 2 && 2318b5e8414SChao Chen tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM)) 2328b5e8414SChao Chen return emitOpError("SLM is not supported for 2D Block TensorDesc.\n"); 2338b5e8414SChao Chen 234b01879ecSChao Chen return success(); 235b01879ecSChao Chen } 236b01879ecSChao Chen 237b01879ecSChao Chen //===----------------------------------------------------------------------===// 238b01879ecSChao Chen // XeGPU_PrefetchNdOp 239b01879ecSChao Chen //===----------------------------------------------------------------------===// 240b01879ecSChao Chen LogicalResult PrefetchNdOp::verify() { 241b01879ecSChao Chen auto tdescTy = getTensorDescType(); 2428b5e8414SChao Chen if (tdescTy.isScattered()) 243b01879ecSChao Chen return emitOpError("Expects a non-scattered TensorDesc.\n"); 244b01879ecSChao Chen 245b01879ecSChao Chen if (!isReadHintOrNone(getL1HintAttr())) 246fa6f88afSPetr Kurapov return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 247b01879ecSChao Chen 248b01879ecSChao Chen if (!isReadHintOrNone(getL2HintAttr())) 249fa6f88afSPetr Kurapov return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 250b01879ecSChao Chen 251b01879ecSChao Chen if (!isReadHintOrNone(getL3HintAttr())) 252fa6f88afSPetr Kurapov return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 253b01879ecSChao Chen 25461b24c61SChao Chen return success(); 25561b24c61SChao Chen } 25661b24c61SChao Chen 25761b24c61SChao Chen //===----------------------------------------------------------------------===// 25861b24c61SChao Chen // XeGPU_LoadNdOp 25961b24c61SChao Chen //===----------------------------------------------------------------------===// 26061b24c61SChao Chen LogicalResult LoadNdOp::verify() { 26161b24c61SChao Chen auto tdescTy = getTensorDescType(); 26261b24c61SChao Chen auto valueTy = getType(); 26361b24c61SChao Chen 2646c783e19SChao Chen if (tdescTy.getRank() > 2) 2656c783e19SChao Chen return emitOpError("Expecting a 1D/2D TensorDesc.\n"); 266b01879ecSChao Chen 2678b5e8414SChao Chen if (tdescTy.isScattered()) 268b01879ecSChao Chen return emitOpError("Expects a non-scattered TensorDesc.\n"); 26961b24c61SChao Chen 27061b24c61SChao Chen if (!valueTy) 27161b24c61SChao Chen return emitOpError("Invalid result, it should be a VectorType.\n"); 27261b24c61SChao Chen 273b01879ecSChao Chen if (!isReadHintOrNone(getL1HintAttr())) 274fa6f88afSPetr Kurapov return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 27561b24c61SChao Chen 276b01879ecSChao Chen if (!isReadHintOrNone(getL2HintAttr())) 277fa6f88afSPetr Kurapov return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 278b01879ecSChao Chen 279b01879ecSChao Chen if (!isReadHintOrNone(getL3HintAttr())) 280fa6f88afSPetr Kurapov return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 28161b24c61SChao Chen 28261b24c61SChao Chen auto array_len = tdescTy.getArrayLength(); 283b01879ecSChao Chen auto tdescShape = getShapeOf(tdescTy); 284b01879ecSChao Chen auto valueShape = getShapeOf(valueTy); 28561b24c61SChao Chen 28661b24c61SChao Chen if (getTranspose()) { 28761b24c61SChao Chen auto trans = getTranspose().value(); 2886c783e19SChao Chen 2896c783e19SChao Chen // Make sure the transpose value is valid. 2906c783e19SChao Chen bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) { 2916c783e19SChao Chen return t >= 0 && t < tdescTy.getRank(); 2926c783e19SChao Chen }); 2936c783e19SChao Chen 2946c783e19SChao Chen if (valid) 29561b24c61SChao Chen transpose(trans, tdescShape); 29661b24c61SChao Chen else 297*ba6774f9SAdam Siemieniuk mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored."; 29861b24c61SChao Chen } 29961b24c61SChao Chen 3006c783e19SChao Chen if (getPacked()) { 3016c783e19SChao Chen if (tdescTy.getRank() == 2) { 3026c783e19SChao Chen const int axis = 0; 30361b24c61SChao Chen auto vnni_factor = valueShape.back(); 30461b24c61SChao Chen tdescShape[axis] /= vnni_factor; 30561b24c61SChao Chen tdescShape.push_back(vnni_factor); 3066c783e19SChao Chen } else { 307*ba6774f9SAdam Siemieniuk mlir::emitWarning(getLoc()) 308*ba6774f9SAdam Siemieniuk << "Invalid Packed Attr. It is ignored (available for 2D " 309*ba6774f9SAdam Siemieniuk "TensorDesc only)."; 3106c783e19SChao Chen } 31161b24c61SChao Chen } 31261b24c61SChao Chen 31361b24c61SChao Chen if (array_len > 1) { 31461b24c61SChao Chen auto it = tdescShape.begin(); 31561b24c61SChao Chen tdescShape.insert(it, array_len); 31661b24c61SChao Chen } 317fa6f88afSPetr Kurapov auto sgMap = tdescTy.getSGMapAttr(); 31861b24c61SChao Chen 319fa6f88afSPetr Kurapov if (!isArgShapesValid(tdescShape, valueShape, sgMap)) 32061b24c61SChao Chen return emitOpError() << "Result shape doesn't match TensorDesc shape." 32161b24c61SChao Chen << "The expected shape is " << makeString(tdescShape) 32261b24c61SChao Chen << ". But the given shape is " 32361b24c61SChao Chen << makeString(valueShape) << ".\n"; 32461b24c61SChao Chen return success(); 32561b24c61SChao Chen } 32661b24c61SChao Chen 32761b24c61SChao Chen //===----------------------------------------------------------------------===// 32861b24c61SChao Chen // XeGPU_StoreNdOp 32961b24c61SChao Chen //===----------------------------------------------------------------------===// 33061b24c61SChao Chen LogicalResult StoreNdOp::verify() { 331b01879ecSChao Chen auto dstTy = getTensorDescType(); // Tile 332b01879ecSChao Chen auto valTy = getValueType(); // Vector 33361b24c61SChao Chen 3346c783e19SChao Chen if (dstTy.getRank() > 2) 3356c783e19SChao Chen return emitOpError("Expecting a 1D/2D TensorDesc.\n"); 336b01879ecSChao Chen 3378b5e8414SChao Chen if (dstTy.isScattered()) 338b01879ecSChao Chen return emitOpError("Expects a non-scattered TensorDesc.\n"); 33961b24c61SChao Chen 34061b24c61SChao Chen if (!valTy) 341fa6f88afSPetr Kurapov return emitOpError("Expecting a VectorType result.\n"); 34261b24c61SChao Chen 343b01879ecSChao Chen if (!isWriteHintOrNone(getL1HintAttr())) 344fa6f88afSPetr Kurapov return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 34561b24c61SChao Chen 346b01879ecSChao Chen if (!isWriteHintOrNone(getL2HintAttr())) 347fa6f88afSPetr Kurapov return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 348b01879ecSChao Chen 349b01879ecSChao Chen if (!isWriteHintOrNone(getL3HintAttr())) 350fa6f88afSPetr Kurapov return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 351b01879ecSChao Chen 352fa6f88afSPetr Kurapov auto tdescShape = getShapeOf(dstTy); 353fa6f88afSPetr Kurapov auto valueShape = getShapeOf(valTy); 354fa6f88afSPetr Kurapov auto sgMap = dstTy.getSGMapAttr(); 355fa6f88afSPetr Kurapov 356fa6f88afSPetr Kurapov if (!isArgShapesValid(tdescShape, valueShape, sgMap)) 357fa6f88afSPetr Kurapov return emitOpError() << "Result shape doesn't match TensorDesc shape." 358fa6f88afSPetr Kurapov << "The expected shape is " << makeString(tdescShape) 359fa6f88afSPetr Kurapov << ". But the given shape is " 360fa6f88afSPetr Kurapov << makeString(valueShape) << ".\n"; 361b01879ecSChao Chen return success(); 36261b24c61SChao Chen } 36361b24c61SChao Chen 364b01879ecSChao Chen //===----------------------------------------------------------------------===// 365b01879ecSChao Chen // XeGPU_UpdateNDOffsetOp 366b01879ecSChao Chen //===----------------------------------------------------------------------===// 367b01879ecSChao Chen LogicalResult UpdateNdOffsetOp::verify() { 368b01879ecSChao Chen auto ty = getTensorDescType(); 3698b5e8414SChao Chen if (ty.isScattered()) 370b01879ecSChao Chen return emitOpError("Expects a non-scattered TensorDesc.\n"); 371b01879ecSChao Chen 372b01879ecSChao Chen // number of offsets specified must match the rank of the tensor descriptor 373b01879ecSChao Chen if (ty.getRank() != (int64_t)getNumOffsets()) { 374b01879ecSChao Chen return emitOpError("Invalid number of offsets."); 375b01879ecSChao Chen } 376b01879ecSChao Chen return success(); 377b01879ecSChao Chen } 378b01879ecSChao Chen 379b01879ecSChao Chen //===----------------------------------------------------------------------===// 380b01879ecSChao Chen // XeGPU_CreateDescOp 381b01879ecSChao Chen //===----------------------------------------------------------------------===// 382b01879ecSChao Chen 3839c697b3aSChao Chen void CreateDescOp::build(OpBuilder &builder, OperationState &state, 3849c697b3aSChao Chen TensorDescType TensorDesc, Value source, 3859c697b3aSChao Chen llvm::ArrayRef<OpFoldResult> offsets) { 3869c697b3aSChao Chen auto loc = source.getLoc(); 3879c697b3aSChao Chen int64_t size = static_cast<int64_t>(offsets.size()); 3889c697b3aSChao Chen auto type = VectorType::get(size, builder.getIndexType()); 3899c697b3aSChao Chen auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); 3909c697b3aSChao Chen auto offset = builder.create<vector::FromElementsOp>(loc, type, values); 3919c697b3aSChao Chen build(builder, state, TensorDesc, source, offset); 3929c697b3aSChao Chen } 3939c697b3aSChao Chen 3949c697b3aSChao Chen void CreateDescOp::build(OpBuilder &builder, OperationState &state, 3959c697b3aSChao Chen TensorDescType TensorDesc, Value source, 3969c697b3aSChao Chen llvm::ArrayRef<int64_t> offsets) { 3979c697b3aSChao Chen auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets); 3989c697b3aSChao Chen build(builder, state, TensorDesc, source, ofrs); 3999c697b3aSChao Chen } 4009c697b3aSChao Chen 401b01879ecSChao Chen LogicalResult CreateDescOp::verify() { 402b01879ecSChao Chen auto tdescTy = getTensorDescType(); 403b01879ecSChao Chen 404b01879ecSChao Chen if (getRankOf(getSource()) > 1) 405b01879ecSChao Chen return emitOpError( 406b01879ecSChao Chen "Expecting the source is a 1D memref or pointer (uint64_t)."); 407b01879ecSChao Chen 4088b5e8414SChao Chen if (!tdescTy.isScattered()) 409b01879ecSChao Chen return emitOpError("Expects a scattered TensorDesc.\n"); 410b01879ecSChao Chen 4118b5e8414SChao Chen // Memory space of created TensorDesc should match with the source. 4128b5e8414SChao Chen // Both source and TensorDesc are considered for global memory by default, 4138b5e8414SChao Chen // if the memory scope attr is not specified. If source is an integer, 4148b5e8414SChao Chen // it is considered as ptr to global memory. 4158b5e8414SChao Chen auto srcMemorySpace = getSourceMemorySpace(); 4168b5e8414SChao Chen auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace()); 4178b5e8414SChao Chen if (srcMemorySpace != tdescMemorySpace) 4188b5e8414SChao Chen return emitOpError("Memory space mismatch.") 4198b5e8414SChao Chen << " Source: " << srcMemorySpace 4208b5e8414SChao Chen << ", TensorDesc: " << tdescMemorySpace; 4218b5e8414SChao Chen 4228b5e8414SChao Chen auto chunkSize = tdescTy.getChunkSize(); 4238b5e8414SChao Chen 4248b5e8414SChao Chen // check chunk_size 4258b5e8414SChao Chen llvm::SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8, 4268b5e8414SChao Chen 16, 32, 64, 128, 256}; 4278b5e8414SChao Chen if (!llvm::is_contained(supportedChunkSizes, chunkSize)) 4288b5e8414SChao Chen return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, " 4298b5e8414SChao Chen "8, 16, 32, 64, 128, or 256."); 4308b5e8414SChao Chen 4318b5e8414SChao Chen // check total size 4328b5e8414SChao Chen auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth(); 4338b5e8414SChao Chen auto bitsPerLane = elemBits * chunkSize; 4348b5e8414SChao Chen if (chunkSize > 1 && bitsPerLane % 32) { 4358b5e8414SChao Chen // For 8-bit and 16-bit data, the hardware only supports chunk size of 1. 4368b5e8414SChao Chen // For 32-bit data, the hardware can support larger larger chunk size. So 4378b5e8414SChao Chen // we can bitcast 8-bit/16-bit data to 32-bit data for better performance. 4388b5e8414SChao Chen // But this requires the total size is 32 bit aligned to make the 4398b5e8414SChao Chen // optimization work. 4408b5e8414SChao Chen return emitOpError( 4418b5e8414SChao Chen "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned."); 4428b5e8414SChao Chen } 4438b5e8414SChao Chen 4448b5e8414SChao Chen auto lscConstraints = 512 * 8; // each access is upto 512 bytes. 4458b5e8414SChao Chen if (elemBits * tdescTy.getNumElements() > lscConstraints) 4468b5e8414SChao Chen return emitOpError("total access size (simd_lanes * chunk_size * " 4478b5e8414SChao Chen "sizeof(elemTy)) is upto 512 bytes."); 4488b5e8414SChao Chen 449b01879ecSChao Chen SmallVector<int64_t> shape({(int64_t)getNumOffsets()}); 450b01879ecSChao Chen if (chunkSize != 1) 451b01879ecSChao Chen shape.push_back(chunkSize); 452b01879ecSChao Chen 453b01879ecSChao Chen auto tdescShape = getShapeOf(tdescTy); 454b01879ecSChao Chen if (shape != tdescShape) 455b01879ecSChao Chen return emitOpError("Incorrect TensorDesc shape. ") 456b01879ecSChao Chen << "Expected is " << makeString(shape) << "\n"; 457b01879ecSChao Chen 458b01879ecSChao Chen return success(); 459b01879ecSChao Chen } 460b01879ecSChao Chen 461b01879ecSChao Chen //===----------------------------------------------------------------------===// 462b01879ecSChao Chen // XeGPU_PrefetchOp 463b01879ecSChao Chen //===----------------------------------------------------------------------===// 464b01879ecSChao Chen LogicalResult PrefetchOp::verify() { 465b01879ecSChao Chen auto tdescTy = getTensorDescType(); 4668b5e8414SChao Chen if (!tdescTy.isScattered()) 467b01879ecSChao Chen return emitOpError("Expects a scattered TensorDesc.\n"); 468b01879ecSChao Chen 469b01879ecSChao Chen if (!isReadHintOrNone(getL1HintAttr())) 470fa6f88afSPetr Kurapov return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 471b01879ecSChao Chen 472b01879ecSChao Chen if (!isReadHintOrNone(getL2HintAttr())) 473fa6f88afSPetr Kurapov return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 474b01879ecSChao Chen 475b01879ecSChao Chen if (!isReadHintOrNone(getL3HintAttr())) 476fa6f88afSPetr Kurapov return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 477b01879ecSChao Chen 478b01879ecSChao Chen return success(); 479b01879ecSChao Chen } 480b01879ecSChao Chen 481b01879ecSChao Chen //===----------------------------------------------------------------------===// 482b01879ecSChao Chen // XeGPU_LoadGatherOp 483b01879ecSChao Chen //===----------------------------------------------------------------------===// 484b01879ecSChao Chen LogicalResult LoadGatherOp::verify() { 485b01879ecSChao Chen auto tdescTy = getTensorDescType(); 486b01879ecSChao Chen auto maskTy = getMaskType(); 487b01879ecSChao Chen auto valueTy = getValueType(); 488b01879ecSChao Chen 4898b5e8414SChao Chen if (!tdescTy.isScattered()) 490b01879ecSChao Chen return emitOpError("Expects a scattered TensorDesc.\n"); 491b01879ecSChao Chen 492b01879ecSChao Chen if (!isReadHintOrNone(getL1HintAttr())) 493fa6f88afSPetr Kurapov return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 494b01879ecSChao Chen 495b01879ecSChao Chen if (!isReadHintOrNone(getL2HintAttr())) 496fa6f88afSPetr Kurapov return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 497b01879ecSChao Chen 498b01879ecSChao Chen if (!isReadHintOrNone(getL3HintAttr())) 499fa6f88afSPetr Kurapov return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 500b01879ecSChao Chen 501b01879ecSChao Chen auto tdescElemTy = tdescTy.getElementType(); 502b01879ecSChao Chen auto valueElemTy = getElementType(); 503b01879ecSChao Chen if (tdescElemTy != valueElemTy) 504b01879ecSChao Chen return emitOpError( 505b01879ecSChao Chen "Value should have the same element type as TensorDesc."); 506b01879ecSChao Chen 507b01879ecSChao Chen auto maskShape = getShapeOf(maskTy); 508b01879ecSChao Chen auto valueShape = getShapeOf(valueTy); 509b01879ecSChao Chen auto tdescShape = getShapeOf(tdescTy); 510b01879ecSChao Chen 511b01879ecSChao Chen if (tdescShape[0] != maskShape[0]) 512b01879ecSChao Chen return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); 513b01879ecSChao Chen 5148b5e8414SChao Chen if (tdescTy.getRank() == 2) { 5158b5e8414SChao Chen if (!getTransposeAttr()) 5168b5e8414SChao Chen return emitOpError("load_gather has to be transposed."); 5178b5e8414SChao Chen transpose({1, 0}, tdescShape); 518b01879ecSChao Chen } 519b01879ecSChao Chen 520b01879ecSChao Chen if (valueShape != tdescShape) 521b01879ecSChao Chen return emitOpError("Unexpected result shape") 522b01879ecSChao Chen << "(Expected shape: " << makeString(tdescShape) 523b01879ecSChao Chen << ", Given shape: " << makeString(valueShape) << ").\n"; 524b01879ecSChao Chen 525b01879ecSChao Chen return success(); 526b01879ecSChao Chen } 527b01879ecSChao Chen 528b01879ecSChao Chen //===----------------------------------------------------------------------===// 529b01879ecSChao Chen // XeGPU_StoreScatterOp 530b01879ecSChao Chen //===----------------------------------------------------------------------===// 531b01879ecSChao Chen LogicalResult StoreScatterOp::verify() { 532b01879ecSChao Chen auto tdescTy = getTensorDescType(); 5338b5e8414SChao Chen if (!tdescTy.isScattered()) 534b01879ecSChao Chen return emitOpError("Expects a scattered TensorDesc.\n"); 535b01879ecSChao Chen 536b01879ecSChao Chen if (!isWriteHintOrNone(getL1HintAttr())) 537fa6f88afSPetr Kurapov return emitOpError("invalid l1_hint: ") << getL1HintAttr(); 538b01879ecSChao Chen 539b01879ecSChao Chen if (!isWriteHintOrNone(getL2HintAttr())) 540fa6f88afSPetr Kurapov return emitOpError("invalid l2_hint: ") << getL2HintAttr(); 541b01879ecSChao Chen 542b01879ecSChao Chen if (!isWriteHintOrNone(getL3HintAttr())) 543fa6f88afSPetr Kurapov return emitOpError("invalid l3_hint: ") << getL3HintAttr(); 544b01879ecSChao Chen 545b01879ecSChao Chen auto maskTy = getMaskType(); 5468b5e8414SChao Chen auto valueTy = getValueType(); 547b01879ecSChao Chen auto maskShape = getShapeOf(maskTy); 548b01879ecSChao Chen auto tdescShape = getShapeOf(tdescTy); 5498b5e8414SChao Chen auto valueShape = getShapeOf(valueTy); 550b01879ecSChao Chen if (tdescShape[0] != maskShape[0]) 551b01879ecSChao Chen return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); 552b01879ecSChao Chen 5538b5e8414SChao Chen if (tdescTy.getRank() == 2) { 5548b5e8414SChao Chen if (!getTransposeAttr()) 5558b5e8414SChao Chen return emitOpError("load_gather has to be transposed."); 5568b5e8414SChao Chen transpose({1, 0}, tdescShape); 5578b5e8414SChao Chen } 5588b5e8414SChao Chen 5598b5e8414SChao Chen if (valueShape != tdescShape) 5608b5e8414SChao Chen return emitOpError("Unexpected value shape") 5618b5e8414SChao Chen << "(Expected shape: " << makeString(tdescShape) 5628b5e8414SChao Chen << ", Given shape: " << makeString(valueShape) << ").\n"; 5638b5e8414SChao Chen 56461b24c61SChao Chen return success(); 56561b24c61SChao Chen } 5669c697b3aSChao Chen 5679c697b3aSChao Chen //===----------------------------------------------------------------------===// 5689c697b3aSChao Chen // XeGPU_UpdateOffsetOp 5699c697b3aSChao Chen //===----------------------------------------------------------------------===// 5709c697b3aSChao Chen void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, 5719c697b3aSChao Chen mlir::Value tensorDesc, 5729c697b3aSChao Chen llvm::ArrayRef<OpFoldResult> offsets) { 5739c697b3aSChao Chen auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType()); 5749c697b3aSChao Chen assert(tdescTy && "Expecting the source is a TensorDescType value."); 5759c697b3aSChao Chen auto loc = tensorDesc.getLoc(); 5769c697b3aSChao Chen int64_t size = static_cast<int64_t>(offsets.size()); 5779c697b3aSChao Chen auto type = VectorType::get({size}, builder.getIndexType()); 5789c697b3aSChao Chen auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); 5799c697b3aSChao Chen auto offset = builder.create<vector::FromElementsOp>(loc, type, values); 5809c697b3aSChao Chen build(builder, state, tdescTy, tensorDesc, offset); 5819c697b3aSChao Chen } 5829c697b3aSChao Chen 5839c697b3aSChao Chen void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, 5849c697b3aSChao Chen Value tensorDesc, llvm::ArrayRef<int64_t> offsets) { 5859c697b3aSChao Chen auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets); 5869c697b3aSChao Chen build(builder, state, tensorDesc, ofrs); 5879c697b3aSChao Chen } 5889c697b3aSChao Chen 58903bb10dfSChao Chen //===----------------------------------------------------------------------===// 59003bb10dfSChao Chen // XeGPU_DpasOp 59103bb10dfSChao Chen //===----------------------------------------------------------------------===// 59203bb10dfSChao Chen LogicalResult DpasOp::verify() { 59303bb10dfSChao Chen int64_t lhsRank = getLhsType().getRank(); 59403bb10dfSChao Chen int64_t rhsRank = getRhsType().getRank(); 59503bb10dfSChao Chen 5966c783e19SChao Chen if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3)) 5976c783e19SChao Chen return emitOpError("expecting lhs to be a 2D vector, and rhs to be either " 5986c783e19SChao Chen "2D or 3D (packed) vector."); 59903bb10dfSChao Chen 60003bb10dfSChao Chen auto lhsShape = getLhsType().getShape(); 60103bb10dfSChao Chen auto rhsShape = getRhsType().getShape(); 6026c783e19SChao Chen auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0]; 6036c783e19SChao Chen if (bK != lhsShape[1]) 6046c783e19SChao Chen return emitOpError("K-dimension mismatch."); 60503bb10dfSChao Chen 60603bb10dfSChao Chen return success(); 60703bb10dfSChao Chen } 6085669660fSChao Chen 6095669660fSChao Chen } // namespace xegpu 6105669660fSChao Chen } // namespace mlir 6115669660fSChao Chen 6125669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc> 6135669660fSChao Chen #define GET_OP_CLASSES 6145669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc> 615