15669660fSChao Chen //===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===// 25669660fSChao Chen // 35669660fSChao Chen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45669660fSChao Chen // See https://llvm.org/LICENSE.txt for license information. 55669660fSChao Chen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65669660fSChao Chen // 75669660fSChao Chen //===----------------------------------------------------------------------===// 85669660fSChao Chen 961b24c61SChao Chen #include "mlir/Dialect/XeGPU/IR/XeGPU.h" 1061b24c61SChao Chen #include "mlir/IR/Builders.h" 1161b24c61SChao Chen #include "mlir/IR/DialectImplementation.h" 1261b24c61SChao Chen #include "llvm/ADT/TypeSwitch.h" 135669660fSChao Chen 145669660fSChao Chen namespace mlir { 155669660fSChao Chen namespace xegpu { 165669660fSChao Chen 175669660fSChao Chen void XeGPUDialect::initialize() { 185669660fSChao Chen addTypes< 195669660fSChao Chen #define GET_TYPEDEF_LIST 205669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc> 215669660fSChao Chen >(); 225669660fSChao Chen addOperations< 235669660fSChao Chen #define GET_OP_LIST 245669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc> 255669660fSChao Chen >(); 265669660fSChao Chen addAttributes< 275669660fSChao Chen #define GET_ATTRDEF_LIST 285669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc> 295669660fSChao Chen >(); 305669660fSChao Chen } 315669660fSChao Chen 3261b24c61SChao Chen //===----------------------------------------------------------------------===// 338b5e8414SChao Chen // XeGPU_BlockTensorDescAttr 3461b24c61SChao Chen //===----------------------------------------------------------------------===// 358b5e8414SChao Chen BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context, 368b5e8414SChao Chen xegpu::MemorySpace memory_space, 378b5e8414SChao Chen int array_length, 388b5e8414SChao Chen bool boundary_check) { 398b5e8414SChao Chen auto scopeAttr = MemorySpaceAttr::get(context, memory_space); 40b01879ecSChao Chen auto lengthAttr = 41b01879ecSChao Chen IntegerAttr::get(IntegerType::get(context, 64), array_length); 42b01879ecSChao Chen auto boundaryAttr = BoolAttr::get(context, boundary_check); 438b5e8414SChao Chen return Base::get(context, scopeAttr, lengthAttr, boundaryAttr); 448b5e8414SChao Chen } 458b5e8414SChao Chen 468b5e8414SChao Chen //===----------------------------------------------------------------------===// 478b5e8414SChao Chen // XeGPU_ScatterTensorDescAttr 488b5e8414SChao Chen //===----------------------------------------------------------------------===// 498b5e8414SChao Chen ScatterTensorDescAttr 508b5e8414SChao Chen ScatterTensorDescAttr::get(mlir::MLIRContext *context, 518b5e8414SChao Chen xegpu::MemorySpace memory_space, int chunk_size) { 528b5e8414SChao Chen auto scopeAttr = MemorySpaceAttr::get(context, memory_space); 538b5e8414SChao Chen auto chunkSizeAttr = 548b5e8414SChao Chen IntegerAttr::get(IntegerType::get(context, 64), chunk_size); 558b5e8414SChao Chen return Base::get(context, scopeAttr, chunkSizeAttr); 56b01879ecSChao Chen } 5761b24c61SChao Chen 5861b24c61SChao Chen //===----------------------------------------------------------------------===// 59*9fa55ec3SPetr Kurapov // XeGPU_SGMapAttr 60*9fa55ec3SPetr Kurapov //===----------------------------------------------------------------------===// 61*9fa55ec3SPetr Kurapov namespace { 62*9fa55ec3SPetr Kurapov template <typename T, unsigned N> 63*9fa55ec3SPetr Kurapov LogicalResult parseIntArrayField(::mlir::AsmParser &parser, 64*9fa55ec3SPetr Kurapov llvm::SmallVector<T, N> &result, 65*9fa55ec3SPetr Kurapov llvm::StringRef fieldName) { 66*9fa55ec3SPetr Kurapov if (failed(parser.parseKeyword(fieldName))) { 67*9fa55ec3SPetr Kurapov parser.emitError(parser.getCurrentLocation(), 68*9fa55ec3SPetr Kurapov "unexpected field name. Expected " + fieldName + "."); 69*9fa55ec3SPetr Kurapov return failure(); 70*9fa55ec3SPetr Kurapov } 71*9fa55ec3SPetr Kurapov 72*9fa55ec3SPetr Kurapov if (failed(parser.parseEqual())) { 73*9fa55ec3SPetr Kurapov parser.emitError(parser.getCurrentLocation(), "expected '=' sign."); 74*9fa55ec3SPetr Kurapov return failure(); 75*9fa55ec3SPetr Kurapov } 76*9fa55ec3SPetr Kurapov 77*9fa55ec3SPetr Kurapov auto elemParser = [&]() -> llvm::ParseResult { 78*9fa55ec3SPetr Kurapov uint32_t elem = 0; 79*9fa55ec3SPetr Kurapov auto res = parser.parseInteger(elem); 80*9fa55ec3SPetr Kurapov result.push_back(elem); 81*9fa55ec3SPetr Kurapov return res; 82*9fa55ec3SPetr Kurapov }; 83*9fa55ec3SPetr Kurapov 84*9fa55ec3SPetr Kurapov return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, 85*9fa55ec3SPetr Kurapov elemParser, fieldName); 86*9fa55ec3SPetr Kurapov } 87*9fa55ec3SPetr Kurapov } // namespace 88*9fa55ec3SPetr Kurapov 89*9fa55ec3SPetr Kurapov mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser, 90*9fa55ec3SPetr Kurapov ::mlir::Type attrType) { 91*9fa55ec3SPetr Kurapov if (failed(parser.parseLess())) 92*9fa55ec3SPetr Kurapov return {}; 93*9fa55ec3SPetr Kurapov 94*9fa55ec3SPetr Kurapov llvm::SmallVector<uint32_t, 2> wi_layout, wi_data; 95*9fa55ec3SPetr Kurapov if (failed(parseIntArrayField(parser, wi_layout, "wi_layout"))) 96*9fa55ec3SPetr Kurapov return {}; 97*9fa55ec3SPetr Kurapov 98*9fa55ec3SPetr Kurapov if (failed(parser.parseComma())) 99*9fa55ec3SPetr Kurapov return {}; 100*9fa55ec3SPetr Kurapov 101*9fa55ec3SPetr Kurapov if (failed(parseIntArrayField(parser, wi_data, "wi_data"))) 102*9fa55ec3SPetr Kurapov return {}; 103*9fa55ec3SPetr Kurapov 104*9fa55ec3SPetr Kurapov return SGMapAttr::getChecked( 105*9fa55ec3SPetr Kurapov [&]() { return parser.emitError(parser.getNameLoc()); }, 106*9fa55ec3SPetr Kurapov parser.getContext(), wi_layout, wi_data); 107*9fa55ec3SPetr Kurapov } 108*9fa55ec3SPetr Kurapov 109*9fa55ec3SPetr Kurapov void SGMapAttr::print(::mlir::AsmPrinter &printer) const { 110*9fa55ec3SPetr Kurapov printer << "<"; 111*9fa55ec3SPetr Kurapov printer.printKeywordOrString("wi_layout"); 112*9fa55ec3SPetr Kurapov printer << " = [" << getWiLayout() << "], "; 113*9fa55ec3SPetr Kurapov printer.printKeywordOrString("wi_data"); 114*9fa55ec3SPetr Kurapov printer << " = [" << getWiData() << "]"; 115*9fa55ec3SPetr Kurapov printer << ">"; 116*9fa55ec3SPetr Kurapov } 117*9fa55ec3SPetr Kurapov 118*9fa55ec3SPetr Kurapov LogicalResult 119*9fa55ec3SPetr Kurapov SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 120*9fa55ec3SPetr Kurapov llvm::ArrayRef<uint32_t> wi_layout, 121*9fa55ec3SPetr Kurapov llvm::ArrayRef<uint32_t> wi_data) { 122*9fa55ec3SPetr Kurapov if (wi_layout.size() != 2) 123*9fa55ec3SPetr Kurapov return emitError() << "expected wi_layout of size 2"; 124*9fa55ec3SPetr Kurapov if (wi_data.size() != 2) 125*9fa55ec3SPetr Kurapov return emitError() << "expected wi_data of size 2"; 126*9fa55ec3SPetr Kurapov return success(); 127*9fa55ec3SPetr Kurapov } 128*9fa55ec3SPetr Kurapov 129*9fa55ec3SPetr Kurapov //===----------------------------------------------------------------------===// 13061b24c61SChao Chen // XeGPU_TensorDescType 13161b24c61SChao Chen //===----------------------------------------------------------------------===// 1328b5e8414SChao Chen 13361b24c61SChao Chen mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { 13461b24c61SChao Chen llvm::SmallVector<int64_t> shape; 13561b24c61SChao Chen mlir::Type elementType; 13661b24c61SChao Chen mlir::FailureOr<mlir::Attribute> encoding; 137*9fa55ec3SPetr Kurapov mlir::FailureOr<mlir::Attribute> sg_map; 13861b24c61SChao Chen 13961b24c61SChao Chen // Parse literal '<' 14061b24c61SChao Chen if (parser.parseLess()) 14161b24c61SChao Chen return {}; 14261b24c61SChao Chen 14361b24c61SChao Chen auto shapeLoc = parser.getCurrentLocation(); 14461b24c61SChao Chen if (mlir::failed(parser.parseDimensionList(shape))) { 14561b24c61SChao Chen parser.emitError(shapeLoc, "failed to parse parameter 'shape'"); 14661b24c61SChao Chen return {}; 14761b24c61SChao Chen } 14861b24c61SChao Chen 14961b24c61SChao Chen auto elemTypeLoc = parser.getCurrentLocation(); 15061b24c61SChao Chen if (mlir::failed(parser.parseType(elementType))) { 15161b24c61SChao Chen parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'"); 15261b24c61SChao Chen return {}; 15361b24c61SChao Chen } 15461b24c61SChao Chen 15561b24c61SChao Chen // parse optional attributes 156*9fa55ec3SPetr Kurapov while (mlir::succeeded(parser.parseOptionalComma())) { 157*9fa55ec3SPetr Kurapov mlir::Attribute attr; 158*9fa55ec3SPetr Kurapov ParseResult res = parser.parseAttribute(attr); 159*9fa55ec3SPetr Kurapov if (mlir::succeeded(res)) { 160*9fa55ec3SPetr Kurapov if (mlir::isa<SGMapAttr>(attr)) { 161*9fa55ec3SPetr Kurapov sg_map = attr; 162*9fa55ec3SPetr Kurapov continue; 16361b24c61SChao Chen } 164*9fa55ec3SPetr Kurapov if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) { 165*9fa55ec3SPetr Kurapov encoding = attr; 166*9fa55ec3SPetr Kurapov continue; 167*9fa55ec3SPetr Kurapov } 168*9fa55ec3SPetr Kurapov } 169*9fa55ec3SPetr Kurapov parser.emitError(parser.getCurrentLocation(), 170*9fa55ec3SPetr Kurapov "Failed to parse the attribute.\n"); 171*9fa55ec3SPetr Kurapov return {}; 172a7b6fdafSChao Chen } 17361b24c61SChao Chen 17461b24c61SChao Chen // Parse literal '>' 17561b24c61SChao Chen if (parser.parseGreater()) 17661b24c61SChao Chen return {}; 17761b24c61SChao Chen 17861b24c61SChao Chen return TensorDescType::get(parser.getContext(), shape, elementType, 179*9fa55ec3SPetr Kurapov encoding.value_or(mlir::Attribute()), 180*9fa55ec3SPetr Kurapov sg_map.value_or(mlir::Attribute())); 18161b24c61SChao Chen } 18261b24c61SChao Chen 18361b24c61SChao Chen void TensorDescType::print(::mlir::AsmPrinter &printer) const { 18461b24c61SChao Chen printer << "<"; 18561b24c61SChao Chen 18661b24c61SChao Chen auto shape = getShape(); 18761b24c61SChao Chen for (int64_t dim : shape) { 18861b24c61SChao Chen if (mlir::ShapedType::isDynamic(dim)) 18961b24c61SChao Chen printer << '?'; 19061b24c61SChao Chen else 19161b24c61SChao Chen printer << dim; 19261b24c61SChao Chen printer << 'x'; 19361b24c61SChao Chen } 19461b24c61SChao Chen 19561b24c61SChao Chen printer << getElementType(); 19661b24c61SChao Chen 19761b24c61SChao Chen if (auto encoding = getEncoding()) 19861b24c61SChao Chen printer << ", " << encoding; 19961b24c61SChao Chen 200*9fa55ec3SPetr Kurapov if (auto sg_map = getSgMap()) 201*9fa55ec3SPetr Kurapov printer << ", " << sg_map; 202*9fa55ec3SPetr Kurapov 20361b24c61SChao Chen printer << ">"; 20461b24c61SChao Chen } 2055669660fSChao Chen 206b01879ecSChao Chen TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape, 2078b5e8414SChao Chen mlir::Type elementType, int array_length, 2088b5e8414SChao Chen bool boundary_check, 209*9fa55ec3SPetr Kurapov MemorySpace memory_space, 210*9fa55ec3SPetr Kurapov mlir::Attribute sg_map) { 211b01879ecSChao Chen auto context = elementType.getContext(); 2128b5e8414SChao Chen auto attr = BlockTensorDescAttr::get(context, memory_space, array_length, 2138b5e8414SChao Chen boundary_check); 214*9fa55ec3SPetr Kurapov return Base::get(context, shape, elementType, attr, sg_map); 2158b5e8414SChao Chen } 2168b5e8414SChao Chen 2178b5e8414SChao Chen TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape, 2188b5e8414SChao Chen mlir::Type elementType, int chunk_size, 219*9fa55ec3SPetr Kurapov MemorySpace memory_space, 220*9fa55ec3SPetr Kurapov mlir::Attribute sg_map) { 2218b5e8414SChao Chen auto context = elementType.getContext(); 2228b5e8414SChao Chen auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size); 223*9fa55ec3SPetr Kurapov return Base::get(context, shape, elementType, attr, sg_map); 224b01879ecSChao Chen } 225b01879ecSChao Chen 2265669660fSChao Chen } // namespace xegpu 2275669660fSChao Chen } // namespace mlir 2285669660fSChao Chen 2295669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc> 2305669660fSChao Chen #define GET_ATTRDEF_CLASSES 2315669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc> 2325669660fSChao Chen #define GET_TYPEDEF_CLASSES 2335669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc> 234