xref: /llvm-project/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (revision 9fa55ec3d93435a790f9990b1c6565e3ee689b2c)
15669660fSChao Chen //===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
25669660fSChao Chen //
35669660fSChao Chen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45669660fSChao Chen // See https://llvm.org/LICENSE.txt for license information.
55669660fSChao Chen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65669660fSChao Chen //
75669660fSChao Chen //===----------------------------------------------------------------------===//
85669660fSChao Chen 
961b24c61SChao Chen #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1061b24c61SChao Chen #include "mlir/IR/Builders.h"
1161b24c61SChao Chen #include "mlir/IR/DialectImplementation.h"
1261b24c61SChao Chen #include "llvm/ADT/TypeSwitch.h"
135669660fSChao Chen 
145669660fSChao Chen namespace mlir {
155669660fSChao Chen namespace xegpu {
165669660fSChao Chen 
/// Registers every type, operation and attribute of the XeGPU dialect.
///
/// Each `GET_*_LIST` macro makes the following TableGen-generated `.inc`
/// file emit the list of classes that becomes the template argument list of
/// the matching `add*` call.
void XeGPUDialect::initialize() {
  // Types declared in XeGPUTypes.td.
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  // Operations declared in XeGPUOps.td.
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  // Attributes declared in XeGPUAttrs.td.
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}
315669660fSChao Chen 
3261b24c61SChao Chen //===----------------------------------------------------------------------===//
338b5e8414SChao Chen // XeGPU_BlockTensorDescAttr
3461b24c61SChao Chen //===----------------------------------------------------------------------===//
358b5e8414SChao Chen BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
368b5e8414SChao Chen                                              xegpu::MemorySpace memory_space,
378b5e8414SChao Chen                                              int array_length,
388b5e8414SChao Chen                                              bool boundary_check) {
398b5e8414SChao Chen   auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
40b01879ecSChao Chen   auto lengthAttr =
41b01879ecSChao Chen       IntegerAttr::get(IntegerType::get(context, 64), array_length);
42b01879ecSChao Chen   auto boundaryAttr = BoolAttr::get(context, boundary_check);
438b5e8414SChao Chen   return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
448b5e8414SChao Chen }
458b5e8414SChao Chen 
468b5e8414SChao Chen //===----------------------------------------------------------------------===//
478b5e8414SChao Chen // XeGPU_ScatterTensorDescAttr
488b5e8414SChao Chen //===----------------------------------------------------------------------===//
498b5e8414SChao Chen ScatterTensorDescAttr
508b5e8414SChao Chen ScatterTensorDescAttr::get(mlir::MLIRContext *context,
518b5e8414SChao Chen                            xegpu::MemorySpace memory_space, int chunk_size) {
528b5e8414SChao Chen   auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
538b5e8414SChao Chen   auto chunkSizeAttr =
548b5e8414SChao Chen       IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
558b5e8414SChao Chen   return Base::get(context, scopeAttr, chunkSizeAttr);
56b01879ecSChao Chen }
5761b24c61SChao Chen 
5861b24c61SChao Chen //===----------------------------------------------------------------------===//
59*9fa55ec3SPetr Kurapov // XeGPU_SGMapAttr
60*9fa55ec3SPetr Kurapov //===----------------------------------------------------------------------===//
61*9fa55ec3SPetr Kurapov namespace {
62*9fa55ec3SPetr Kurapov template <typename T, unsigned N>
63*9fa55ec3SPetr Kurapov LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
64*9fa55ec3SPetr Kurapov                                  llvm::SmallVector<T, N> &result,
65*9fa55ec3SPetr Kurapov                                  llvm::StringRef fieldName) {
66*9fa55ec3SPetr Kurapov   if (failed(parser.parseKeyword(fieldName))) {
67*9fa55ec3SPetr Kurapov     parser.emitError(parser.getCurrentLocation(),
68*9fa55ec3SPetr Kurapov                      "unexpected field name. Expected " + fieldName + ".");
69*9fa55ec3SPetr Kurapov     return failure();
70*9fa55ec3SPetr Kurapov   }
71*9fa55ec3SPetr Kurapov 
72*9fa55ec3SPetr Kurapov   if (failed(parser.parseEqual())) {
73*9fa55ec3SPetr Kurapov     parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
74*9fa55ec3SPetr Kurapov     return failure();
75*9fa55ec3SPetr Kurapov   }
76*9fa55ec3SPetr Kurapov 
77*9fa55ec3SPetr Kurapov   auto elemParser = [&]() -> llvm::ParseResult {
78*9fa55ec3SPetr Kurapov     uint32_t elem = 0;
79*9fa55ec3SPetr Kurapov     auto res = parser.parseInteger(elem);
80*9fa55ec3SPetr Kurapov     result.push_back(elem);
81*9fa55ec3SPetr Kurapov     return res;
82*9fa55ec3SPetr Kurapov   };
83*9fa55ec3SPetr Kurapov 
84*9fa55ec3SPetr Kurapov   return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
85*9fa55ec3SPetr Kurapov                                         elemParser, fieldName);
86*9fa55ec3SPetr Kurapov }
87*9fa55ec3SPetr Kurapov } // namespace
88*9fa55ec3SPetr Kurapov 
89*9fa55ec3SPetr Kurapov mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
90*9fa55ec3SPetr Kurapov                                  ::mlir::Type attrType) {
91*9fa55ec3SPetr Kurapov   if (failed(parser.parseLess()))
92*9fa55ec3SPetr Kurapov     return {};
93*9fa55ec3SPetr Kurapov 
94*9fa55ec3SPetr Kurapov   llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
95*9fa55ec3SPetr Kurapov   if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
96*9fa55ec3SPetr Kurapov     return {};
97*9fa55ec3SPetr Kurapov 
98*9fa55ec3SPetr Kurapov   if (failed(parser.parseComma()))
99*9fa55ec3SPetr Kurapov     return {};
100*9fa55ec3SPetr Kurapov 
101*9fa55ec3SPetr Kurapov   if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
102*9fa55ec3SPetr Kurapov     return {};
103*9fa55ec3SPetr Kurapov 
104*9fa55ec3SPetr Kurapov   return SGMapAttr::getChecked(
105*9fa55ec3SPetr Kurapov       [&]() { return parser.emitError(parser.getNameLoc()); },
106*9fa55ec3SPetr Kurapov       parser.getContext(), wi_layout, wi_data);
107*9fa55ec3SPetr Kurapov }
108*9fa55ec3SPetr Kurapov 
109*9fa55ec3SPetr Kurapov void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
110*9fa55ec3SPetr Kurapov   printer << "<";
111*9fa55ec3SPetr Kurapov   printer.printKeywordOrString("wi_layout");
112*9fa55ec3SPetr Kurapov   printer << " = [" << getWiLayout() << "], ";
113*9fa55ec3SPetr Kurapov   printer.printKeywordOrString("wi_data");
114*9fa55ec3SPetr Kurapov   printer << " = [" << getWiData() << "]";
115*9fa55ec3SPetr Kurapov   printer << ">";
116*9fa55ec3SPetr Kurapov }
117*9fa55ec3SPetr Kurapov 
118*9fa55ec3SPetr Kurapov LogicalResult
119*9fa55ec3SPetr Kurapov SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
120*9fa55ec3SPetr Kurapov                   llvm::ArrayRef<uint32_t> wi_layout,
121*9fa55ec3SPetr Kurapov                   llvm::ArrayRef<uint32_t> wi_data) {
122*9fa55ec3SPetr Kurapov   if (wi_layout.size() != 2)
123*9fa55ec3SPetr Kurapov     return emitError() << "expected wi_layout of size 2";
124*9fa55ec3SPetr Kurapov   if (wi_data.size() != 2)
125*9fa55ec3SPetr Kurapov     return emitError() << "expected wi_data of size 2";
126*9fa55ec3SPetr Kurapov   return success();
127*9fa55ec3SPetr Kurapov }
128*9fa55ec3SPetr Kurapov 
129*9fa55ec3SPetr Kurapov //===----------------------------------------------------------------------===//
13061b24c61SChao Chen // XeGPU_TensorDescType
13161b24c61SChao Chen //===----------------------------------------------------------------------===//
1328b5e8414SChao Chen 
13361b24c61SChao Chen mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
13461b24c61SChao Chen   llvm::SmallVector<int64_t> shape;
13561b24c61SChao Chen   mlir::Type elementType;
13661b24c61SChao Chen   mlir::FailureOr<mlir::Attribute> encoding;
137*9fa55ec3SPetr Kurapov   mlir::FailureOr<mlir::Attribute> sg_map;
13861b24c61SChao Chen 
13961b24c61SChao Chen   // Parse literal '<'
14061b24c61SChao Chen   if (parser.parseLess())
14161b24c61SChao Chen     return {};
14261b24c61SChao Chen 
14361b24c61SChao Chen   auto shapeLoc = parser.getCurrentLocation();
14461b24c61SChao Chen   if (mlir::failed(parser.parseDimensionList(shape))) {
14561b24c61SChao Chen     parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
14661b24c61SChao Chen     return {};
14761b24c61SChao Chen   }
14861b24c61SChao Chen 
14961b24c61SChao Chen   auto elemTypeLoc = parser.getCurrentLocation();
15061b24c61SChao Chen   if (mlir::failed(parser.parseType(elementType))) {
15161b24c61SChao Chen     parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
15261b24c61SChao Chen     return {};
15361b24c61SChao Chen   }
15461b24c61SChao Chen 
15561b24c61SChao Chen   // parse optional attributes
156*9fa55ec3SPetr Kurapov   while (mlir::succeeded(parser.parseOptionalComma())) {
157*9fa55ec3SPetr Kurapov     mlir::Attribute attr;
158*9fa55ec3SPetr Kurapov     ParseResult res = parser.parseAttribute(attr);
159*9fa55ec3SPetr Kurapov     if (mlir::succeeded(res)) {
160*9fa55ec3SPetr Kurapov       if (mlir::isa<SGMapAttr>(attr)) {
161*9fa55ec3SPetr Kurapov         sg_map = attr;
162*9fa55ec3SPetr Kurapov         continue;
16361b24c61SChao Chen       }
164*9fa55ec3SPetr Kurapov       if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165*9fa55ec3SPetr Kurapov         encoding = attr;
166*9fa55ec3SPetr Kurapov         continue;
167*9fa55ec3SPetr Kurapov       }
168*9fa55ec3SPetr Kurapov     }
169*9fa55ec3SPetr Kurapov     parser.emitError(parser.getCurrentLocation(),
170*9fa55ec3SPetr Kurapov                      "Failed to parse the attribute.\n");
171*9fa55ec3SPetr Kurapov     return {};
172a7b6fdafSChao Chen   }
17361b24c61SChao Chen 
17461b24c61SChao Chen   // Parse literal '>'
17561b24c61SChao Chen   if (parser.parseGreater())
17661b24c61SChao Chen     return {};
17761b24c61SChao Chen 
17861b24c61SChao Chen   return TensorDescType::get(parser.getContext(), shape, elementType,
179*9fa55ec3SPetr Kurapov                              encoding.value_or(mlir::Attribute()),
180*9fa55ec3SPetr Kurapov                              sg_map.value_or(mlir::Attribute()));
18161b24c61SChao Chen }
18261b24c61SChao Chen 
18361b24c61SChao Chen void TensorDescType::print(::mlir::AsmPrinter &printer) const {
18461b24c61SChao Chen   printer << "<";
18561b24c61SChao Chen 
18661b24c61SChao Chen   auto shape = getShape();
18761b24c61SChao Chen   for (int64_t dim : shape) {
18861b24c61SChao Chen     if (mlir::ShapedType::isDynamic(dim))
18961b24c61SChao Chen       printer << '?';
19061b24c61SChao Chen     else
19161b24c61SChao Chen       printer << dim;
19261b24c61SChao Chen     printer << 'x';
19361b24c61SChao Chen   }
19461b24c61SChao Chen 
19561b24c61SChao Chen   printer << getElementType();
19661b24c61SChao Chen 
19761b24c61SChao Chen   if (auto encoding = getEncoding())
19861b24c61SChao Chen     printer << ", " << encoding;
19961b24c61SChao Chen 
200*9fa55ec3SPetr Kurapov   if (auto sg_map = getSgMap())
201*9fa55ec3SPetr Kurapov     printer << ", " << sg_map;
202*9fa55ec3SPetr Kurapov 
20361b24c61SChao Chen   printer << ">";
20461b24c61SChao Chen }
2055669660fSChao Chen 
206b01879ecSChao Chen TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
2078b5e8414SChao Chen                                    mlir::Type elementType, int array_length,
2088b5e8414SChao Chen                                    bool boundary_check,
209*9fa55ec3SPetr Kurapov                                    MemorySpace memory_space,
210*9fa55ec3SPetr Kurapov                                    mlir::Attribute sg_map) {
211b01879ecSChao Chen   auto context = elementType.getContext();
2128b5e8414SChao Chen   auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
2138b5e8414SChao Chen                                        boundary_check);
214*9fa55ec3SPetr Kurapov   return Base::get(context, shape, elementType, attr, sg_map);
2158b5e8414SChao Chen }
2168b5e8414SChao Chen 
2178b5e8414SChao Chen TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
2188b5e8414SChao Chen                                    mlir::Type elementType, int chunk_size,
219*9fa55ec3SPetr Kurapov                                    MemorySpace memory_space,
220*9fa55ec3SPetr Kurapov                                    mlir::Attribute sg_map) {
2218b5e8414SChao Chen   auto context = elementType.getContext();
2228b5e8414SChao Chen   auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
223*9fa55ec3SPetr Kurapov   return Base::get(context, shape, elementType, attr, sg_map);
224b01879ecSChao Chen }
225b01879ecSChao Chen 
2265669660fSChao Chen } // namespace xegpu
2275669660fSChao Chen } // namespace mlir
2285669660fSChao Chen 
2295669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
2305669660fSChao Chen #define GET_ATTRDEF_CLASSES
2315669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
2325669660fSChao Chen #define GET_TYPEDEF_CLASSES
2335669660fSChao Chen #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
234