//===- XeGPUOps.cpp - MLIR XeGPU ops implementation -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "xegpu"

namespace mlir {
namespace xegpu {

static void transpose(llvm::ArrayRef<int64_t> trans,
                      SmallVector<int64_t> &shape) {
  SmallVector<int64_t> old = shape;
  for (size_t i = 0; i < trans.size(); i++)
    shape[i] = old[trans[i]];
}

template <typename T>
static std::string makeString(T array, bool breakline = false) {
  std::string buf;
  llvm::raw_string_ostream os(buf);
  os << "[";
  for (size_t i = 1; i < array.size(); i++) {
    os << array[i - 1] << ", ";
    if (breakline)
      os << "\n\t\t";
  }
  os << array.back() << "]";
  return buf;
}

static SmallVector<int64_t> getShapeOf(Type type) {
  SmallVector<int64_t> shape;
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    shape = SmallVector<int64_t>(ty.getShape());
  else
    shape.push_back(1);
  return shape;
}

static int64_t getRankOf(Value val) {
  auto type = val.getType();
  if (auto ty = llvm::dyn_cast<ShapedType>(type))
    return ty.getRank();
  return 0;
}

static bool isReadHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::STREAMING ||
         kind == CachePolicy::READ_INVALIDATE;
}

static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
  if (!attr)
    return true;
  auto kind = attr.getValue();
  return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
         kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}

// Validation of nd instruction arguments succeeds if any of these is true:
// - the tensor descriptor and the output vector shapes exactly match.
// - the tensor descriptor has a sg_map attribute and the distributed vector
//   shape matches the tensor descriptor shape when scaled using the sg_map
//   factors on each dimension.
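// For example (illustrative): with wi_layout = [1, 16], an 8x16 tensor
// descriptor accepts a distributed vector of shape 8x1, since 1 * 8 == 8 and
// 16 * 1 == 16.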
static bool isArgShapesValid(ArrayRef<int64_t> descShape,
                             ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
  if (descShape == valShape) {
    if (!sgMap)
      return true;

    // This can be relaxed if necessary by supporting distribution of non-2D
    // shapes. Until those constraints are defined, the check lives here
    // instead of in the tensor descriptor type.
    return valShape.size() == sgMap.getWiLayout().size();
  }

  if (!sgMap)
    return false;

  if (valShape.size() != descShape.size())
    return false;

  for (const auto &[factor, dim, expected] :
       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
    if (factor * dim != expected)
      return false;
  }

  return true;
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
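// A sketch of typical usage (assuming the dialect's documented assembly
// format; value names and shapes are arbitrary):
//   %tdesc = xegpu.create_nd_tdesc %src[%x, %y]
//       : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16>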
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && offsets.size() == (size_t)ty.getRank());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);

  build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        staticOffsets /* const offsets */, {} /* empty const shape*/,
        {} /* empty const strides*/);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<IntegerType> source,
                           llvm::ArrayRef<OpFoldResult> offsets,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && offsets.size() && strides.size() &&
         shape.size() == strides.size() && shape.size() == offsets.size());

  llvm::SmallVector<int64_t> staticOffsets;
  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicOffsets;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}

LogicalResult CreateNdDescOp::verify() {
  auto rank = (int64_t)getMixedOffsets().size();
  bool invalidRank = false;
  bool invalidElemTy = false;

  // Memory space of the created TensorDesc should match with the source.
  // Both the source and the TensorDesc are considered to be in global memory
  // by default if the memory scope attribute is not specified. If the source
  // is an integer, it is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(getType().getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  // Check that the source type matches the rank if it is a memref.
  // It also should have the same element type as the TensorDesc.
  auto memrefTy = dyn_cast<MemRefType>(getSourceType());
  if (memrefTy) {
    invalidRank |= (memrefTy.getRank() != rank);
    invalidElemTy |= memrefTy.getElementType() != getElementType();
  }

  // Mismatches among shape, strides, and offsets are already handled by
  // OffsetSizeAndStrideOpInterface, so they are not checked here.
  if (invalidRank)
    return emitOpError(
        "Expecting the rank of shape, strides, offsets, and source (if source "
        "is a memref) should match with each other.");

  // check result TensorDesc rank
  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

  if (invalidRank)
    return emitOpError(
        "Expecting the TensorDesc rank is up to 2 and not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

  if (invalidElemTy)
    return emitOpError("TensorDesc should have the same element "
                       "type with the source if it is a memref.\n");

  if (getType().isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (getType().getRank() == 2 &&
      tdescMemorySpace == static_cast<unsigned>(MemorySpace::SLM))
    return emitOpError("SLM is not supported for 2D Block TensorDesc.\n");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchNdOp::verify() {
  auto tdescTy = getTensorDescType();
  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadNdOp
//===----------------------------------------------------------------------===//
LogicalResult LoadNdOp::verify() {
  auto tdescTy = getTensorDescType();
  auto valueTy = getType();

  if (tdescTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (tdescTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valueTy)
    return emitOpError("Invalid result, it should be a VectorType.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto array_len = tdescTy.getArrayLength();
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (getTranspose()) {
    auto trans = getTranspose().value();

    // Make sure the transpose value is valid.
    bool valid = std::all_of(trans.begin(), trans.end(), [&](int t) {
      return t >= 0 && t < tdescTy.getRank();
    });

    if (valid)
      transpose(trans, tdescShape);
    else
      mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
  }

  if (getPacked()) {
    if (tdescTy.getRank() == 2) {
      const int axis = 0;
      auto vnni_factor = valueShape.back();
      tdescShape[axis] /= vnni_factor;
      tdescShape.push_back(vnni_factor);
    } else {
      mlir::emitWarning(getLoc())
          << "Invalid Packed Attr. It is ignored (available for 2D "
             "TensorDesc only).";
    }
  }
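  // For example (illustrative): loading a 16x16 f16 TensorDesc with the
  // packed attribute and an 8x16x2 result vector uses vnni_factor = 2, so the
  // expected shape becomes [16/2, 16, 2] = [8, 16, 2] and matches.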
  if (array_len > 1) {
    auto it = tdescShape.begin();
    tdescShape.insert(it, array_len);
  }
  auto sgMap = tdescTy.getSGMapAttr();

  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
    return emitOpError() << "Result shape doesn't match TensorDesc shape. "
                         << "The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
                         << makeString(valueShape) << ".\n";
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
LogicalResult StoreNdOp::verify() {
  auto dstTy = getTensorDescType(); // Tile
  auto valTy = getValueType();      // Vector

  if (dstTy.getRank() > 2)
    return emitOpError("Expecting a 1D/2D TensorDesc.\n");

  if (dstTy.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  if (!valTy)
    return emitOpError("Expecting a VectorType result.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto tdescShape = getShapeOf(dstTy);
  auto valueShape = getShapeOf(valTy);
  auto sgMap = dstTy.getSGMapAttr();

  if (!isArgShapesValid(tdescShape, valueShape, sgMap))
    return emitOpError() << "Result shape doesn't match TensorDesc shape. "
                         << "The expected shape is " << makeString(tdescShape)
                         << ". But the given shape is "
                         << makeString(valueShape) << ".\n";
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateNDOffsetOp
//===----------------------------------------------------------------------===//
LogicalResult UpdateNdOffsetOp::verify() {
  auto ty = getTensorDescType();
  if (ty.isScattered())
    return emitOpError("Expects a non-scattered TensorDesc.\n");

  // The number of offsets specified must match the rank of the tensor
  // descriptor.
  if (ty.getRank() != (int64_t)getNumOffsets()) {
    return emitOpError("Invalid number of offsets.");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_CreateDescOp
//===----------------------------------------------------------------------===//
void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<OpFoldResult> offsets) {
  auto loc = source.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get(size, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, TensorDesc, source, offset);
}

void CreateDescOp::build(OpBuilder &builder, OperationState &state,
                         TensorDescType TensorDesc, Value source,
                         llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, TensorDesc, source, ofrs);
}

LogicalResult CreateDescOp::verify() {
  auto tdescTy = getTensorDescType();

  if (getRankOf(getSource()) > 1)
    return emitOpError(
        "Expecting the source is a 1D memref or pointer (uint64_t).");

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  // Memory space of the created TensorDesc should match with the source.
  // Both the source and the TensorDesc are considered to be in global memory
  // by default if the memory scope attribute is not specified. If the source
  // is an integer, it is treated as a pointer to global memory.
  auto srcMemorySpace = getSourceMemorySpace();
  auto tdescMemorySpace = static_cast<unsigned>(tdescTy.getMemorySpace());
  if (srcMemorySpace != tdescMemorySpace)
    return emitOpError("Memory space mismatch.")
           << " Source: " << srcMemorySpace
           << ", TensorDesc: " << tdescMemorySpace;

  auto chunkSize = tdescTy.getChunkSize();

  // check chunk_size
  llvm::SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
                                                    16, 32, 64, 128, 256};
  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
                       "8, 16, 32, 64, 128, or 256.");

  // check total size
  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
  auto bitsPerLane = elemBits * chunkSize;
  if (chunkSize > 1 && bitsPerLane % 32) {
    // For 8-bit and 16-bit data, the hardware only supports a chunk size of 1.
    // For 32-bit data, the hardware can support larger chunk sizes, so
    // 8-bit/16-bit data can be bitcast to 32-bit data for better performance.
    // This requires the per-lane access size to be 32-bit aligned for the
    // optimization to work.
    return emitOpError(
        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
  }
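  // For example (illustrative): f16 elements (16 bits) with chunk_size = 2
  // give 32 bits per lane and pass the check above, while chunk_size = 3
  // gives 48 bits per lane and is rejected.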
  auto lscConstraints = 512 * 8; // each access is up to 512 bytes.
  if (elemBits * tdescTy.getNumElements() > lscConstraints)
    return emitOpError("total access size (simd_lanes * chunk_size * "
                       "sizeof(elemTy)) is up to 512 bytes.");

  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
  if (chunkSize != 1)
    shape.push_back(chunkSize);

  auto tdescShape = getShapeOf(tdescTy);
  if (shape != tdescShape)
    return emitOpError("Incorrect TensorDesc shape. ")
           << "Expected is " << makeString(shape) << "\n";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchOp
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
LogicalResult LoadGatherOp::verify() {
  auto tdescTy = getTensorDescType();
  auto maskTy = getMaskType();
  auto valueTy = getValueType();

  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isReadHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isReadHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isReadHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto tdescElemTy = tdescTy.getElementType();
  auto valueElemTy = getElementType();
  if (tdescElemTy != valueElemTy)
    return emitOpError(
        "Value should have the same element type as TensorDesc.");

  auto maskShape = getShapeOf(maskTy);
  auto valueShape = getShapeOf(valueTy);
  auto tdescShape = getShapeOf(tdescTy);

  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("load_gather has to be transposed.");
    transpose({1, 0}, tdescShape);
  }
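  // For example (illustrative): a scattered TensorDesc of shape 16x8
  // (16 offsets, chunk_size 8) with the transpose attribute expects a result
  // vector of shape 8x16.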
  if (valueShape != tdescShape)
    return emitOpError("Unexpected result shape")
           << "(Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
LogicalResult StoreScatterOp::verify() {
  auto tdescTy = getTensorDescType();
  if (!tdescTy.isScattered())
    return emitOpError("Expects a scattered TensorDesc.\n");

  if (!isWriteHintOrNone(getL1HintAttr()))
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();

  if (!isWriteHintOrNone(getL2HintAttr()))
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();

  if (!isWriteHintOrNone(getL3HintAttr()))
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();

  auto maskTy = getMaskType();
  auto valueTy = getValueType();
  auto maskShape = getShapeOf(maskTy);
  auto tdescShape = getShapeOf(tdescTy);
  auto valueShape = getShapeOf(valueTy);

  if (tdescShape[0] != maskShape[0])
    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");

  if (tdescTy.getRank() == 2) {
    if (!getTransposeAttr())
      return emitOpError("store_scatter has to be transposed.");
    transpose({1, 0}, tdescShape);
  }

  if (valueShape != tdescShape)
    return emitOpError("Unexpected value shape")
           << "(Expected shape: " << makeString(tdescShape)
           << ", Given shape: " << makeString(valueShape) << ").\n";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           mlir::Value tensorDesc,
                           llvm::ArrayRef<OpFoldResult> offsets) {
  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
  assert(tdescTy && "Expecting the source is a TensorDescType value.");

  auto loc = tensorDesc.getLoc();
  int64_t size = static_cast<int64_t>(offsets.size());
  auto type = VectorType::get({size}, builder.getIndexType());
  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
  build(builder, state, tdescTy, tensorDesc, offset);
}

void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
  build(builder, state, tensorDesc, ofrs);
}

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
  int64_t lhsRank = getLhsType().getRank();
  int64_t rhsRank = getRhsType().getRank();

  if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3))
    return emitOpError("expecting lhs to be a 2D vector, and rhs to be either "
                       "2D or 3D (packed) vector.");

  auto lhsShape = getLhsType().getShape();
  auto rhsShape = getRhsType().getShape();
  // For a packed (3D) rhs, the effective K-dimension is the product of the
  // outer dimension and the packing factor, i.e. rhsShape[0] * rhsShape[2].
  auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
  if (bK != lhsShape[1])
    return emitOpError("K-dimension mismatch.");

  return success();
}

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUOps.cpp.inc>