1 //===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/XeGPU/IR/XeGPU.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 namespace mlir { 15 namespace xegpu { 16 17 void XeGPUDialect::initialize() { 18 addTypes< 19 #define GET_TYPEDEF_LIST 20 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc> 21 >(); 22 addOperations< 23 #define GET_OP_LIST 24 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc> 25 >(); 26 addAttributes< 27 #define GET_ATTRDEF_LIST 28 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc> 29 >(); 30 } 31 32 //===----------------------------------------------------------------------===// 33 // XeGPU_BlockTensorDescAttr 34 //===----------------------------------------------------------------------===// 35 BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context, 36 xegpu::MemorySpace memory_space, 37 int array_length, 38 bool boundary_check) { 39 auto scopeAttr = MemorySpaceAttr::get(context, memory_space); 40 auto lengthAttr = 41 IntegerAttr::get(IntegerType::get(context, 64), array_length); 42 auto boundaryAttr = BoolAttr::get(context, boundary_check); 43 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr); 44 } 45 46 //===----------------------------------------------------------------------===// 47 // XeGPU_ScatterTensorDescAttr 48 //===----------------------------------------------------------------------===// 49 ScatterTensorDescAttr 50 ScatterTensorDescAttr::get(mlir::MLIRContext *context, 51 xegpu::MemorySpace memory_space, int chunk_size) { 52 auto scopeAttr = MemorySpaceAttr::get(context, memory_space); 53 auto chunkSizeAttr = 54 IntegerAttr::get(IntegerType::get(context, 64), chunk_size); 55 return Base::get(context, scopeAttr, chunkSizeAttr); 56 } 57 58 //===----------------------------------------------------------------------===// 59 // XeGPU_SGMapAttr 60 //===----------------------------------------------------------------------===// 61 namespace { 62 template <typename T, unsigned N> 63 LogicalResult parseIntArrayField(::mlir::AsmParser &parser, 64 llvm::SmallVector<T, N> &result, 65 llvm::StringRef fieldName) { 66 if (failed(parser.parseKeyword(fieldName))) { 67 parser.emitError(parser.getCurrentLocation(), 68 "unexpected field name. Expected " + fieldName + "."); 69 return failure(); 70 } 71 72 if (failed(parser.parseEqual())) { 73 parser.emitError(parser.getCurrentLocation(), "expected '=' sign."); 74 return failure(); 75 } 76 77 auto elemParser = [&]() -> llvm::ParseResult { 78 uint32_t elem = 0; 79 auto res = parser.parseInteger(elem); 80 result.push_back(elem); 81 return res; 82 }; 83 84 return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, 85 elemParser, fieldName); 86 } 87 } // namespace 88 89 mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser, 90 ::mlir::Type attrType) { 91 if (failed(parser.parseLess())) 92 return {}; 93 94 llvm::SmallVector<uint32_t, 2> wi_layout, wi_data; 95 if (failed(parseIntArrayField(parser, wi_layout, "wi_layout"))) 96 return {}; 97 98 if (failed(parser.parseComma())) 99 return {}; 100 101 if (failed(parseIntArrayField(parser, wi_data, "wi_data"))) 102 return {}; 103 104 return SGMapAttr::getChecked( 105 [&]() { return parser.emitError(parser.getNameLoc()); }, 106 parser.getContext(), wi_layout, wi_data); 107 } 108 109 void SGMapAttr::print(::mlir::AsmPrinter &printer) const { 110 printer << "<"; 111 printer.printKeywordOrString("wi_layout"); 112 printer << " = [" << getWiLayout() << "], "; 113 printer.printKeywordOrString("wi_data"); 114 printer << " = [" << getWiData() << "]"; 115 printer << ">"; 116 } 117 118 LogicalResult 119 SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, 120 llvm::ArrayRef<uint32_t> wi_layout, 121 llvm::ArrayRef<uint32_t> wi_data) { 122 if (wi_layout.size() != 2) 123 return emitError() << "expected wi_layout of size 2"; 124 if (wi_data.size() != 2) 125 return emitError() << "expected wi_data of size 2"; 126 return success(); 127 } 128 129 //===----------------------------------------------------------------------===// 130 // XeGPU_TensorDescType 131 //===----------------------------------------------------------------------===// 132 133 mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { 134 llvm::SmallVector<int64_t> shape; 135 mlir::Type elementType; 136 mlir::FailureOr<mlir::Attribute> encoding; 137 mlir::FailureOr<mlir::Attribute> sg_map; 138 139 // Parse literal '<' 140 if (parser.parseLess()) 141 return {}; 142 143 auto shapeLoc = parser.getCurrentLocation(); 144 if (mlir::failed(parser.parseDimensionList(shape))) { 145 parser.emitError(shapeLoc, "failed to parse parameter 'shape'"); 146 return {}; 147 } 148 149 auto elemTypeLoc = parser.getCurrentLocation(); 150 if (mlir::failed(parser.parseType(elementType))) { 151 parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'"); 152 return {}; 153 } 154 155 // parse optional attributes 156 while (mlir::succeeded(parser.parseOptionalComma())) { 157 mlir::Attribute attr; 158 ParseResult res = parser.parseAttribute(attr); 159 if (mlir::succeeded(res)) { 160 if (mlir::isa<SGMapAttr>(attr)) { 161 sg_map = attr; 162 continue; 163 } 164 if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) { 165 encoding = attr; 166 continue; 167 } 168 } 169 parser.emitError(parser.getCurrentLocation(), 170 "Failed to parse the attribute.\n"); 171 return {}; 172 } 173 174 // Parse literal '>' 175 if (parser.parseGreater()) 176 return {}; 177 178 return TensorDescType::get(parser.getContext(), shape, elementType, 179 encoding.value_or(mlir::Attribute()), 180 sg_map.value_or(mlir::Attribute())); 181 } 182 183 void TensorDescType::print(::mlir::AsmPrinter &printer) const { 184 printer << "<"; 185 186 auto shape = getShape(); 187 for (int64_t dim : shape) { 188 if (mlir::ShapedType::isDynamic(dim)) 189 printer << '?'; 190 else 191 printer << dim; 192 printer << 'x'; 193 } 194 195 printer << getElementType(); 196 197 if (auto encoding = getEncoding()) 198 printer << ", " << encoding; 199 200 if (auto sg_map = getSgMap()) 201 printer << ", " << sg_map; 202 203 printer << ">"; 204 } 205 206 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape, 207 mlir::Type elementType, int array_length, 208 bool boundary_check, 209 MemorySpace memory_space, 210 mlir::Attribute sg_map) { 211 auto context = elementType.getContext(); 212 auto attr = BlockTensorDescAttr::get(context, memory_space, array_length, 213 boundary_check); 214 return Base::get(context, shape, elementType, attr, sg_map); 215 } 216 217 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape, 218 mlir::Type elementType, int chunk_size, 219 MemorySpace memory_space, 220 mlir::Attribute sg_map) { 221 auto context = elementType.getContext(); 222 auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size); 223 return Base::get(context, shape, elementType, attr, sg_map); 224 } 225 226 } // namespace xegpu 227 } // namespace mlir 228 229 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc> 230 #define GET_ATTRDEF_CLASSES 231 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc> 232 #define GET_TYPEDEF_CLASSES 233 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc> 234