xref: /llvm-project/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (revision 9fa55ec3d93435a790f9990b1c6565e3ee689b2c)
1 //===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 namespace mlir {
15 namespace xegpu {
16 
/// Registers the XeGPU dialect's types, operations and attributes.
///
/// Each registration call takes its template argument list from an included
/// `.cpp.inc` file (presumably TableGen-generated). The `GET_*_LIST` macro
/// defined immediately before each include selects the list section of that
/// file, so every `#define` must stay directly above its `#include`.
void XeGPUDialect::initialize() {
  // Types: list taken from XeGPUTypes.cpp.inc.
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  // Operations: list taken from XeGPU.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  // Attributes: list taken from XeGPUAttrs.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}
31 
32 //===----------------------------------------------------------------------===//
33 // XeGPU_BlockTensorDescAttr
34 //===----------------------------------------------------------------------===//
35 BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
36                                              xegpu::MemorySpace memory_space,
37                                              int array_length,
38                                              bool boundary_check) {
39   auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
40   auto lengthAttr =
41       IntegerAttr::get(IntegerType::get(context, 64), array_length);
42   auto boundaryAttr = BoolAttr::get(context, boundary_check);
43   return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // XeGPU_ScatterTensorDescAttr
48 //===----------------------------------------------------------------------===//
49 ScatterTensorDescAttr
50 ScatterTensorDescAttr::get(mlir::MLIRContext *context,
51                            xegpu::MemorySpace memory_space, int chunk_size) {
52   auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
53   auto chunkSizeAttr =
54       IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
55   return Base::get(context, scopeAttr, chunkSizeAttr);
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // XeGPU_SGMapAttr
60 //===----------------------------------------------------------------------===//
61 namespace {
62 template <typename T, unsigned N>
63 LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
64                                  llvm::SmallVector<T, N> &result,
65                                  llvm::StringRef fieldName) {
66   if (failed(parser.parseKeyword(fieldName))) {
67     parser.emitError(parser.getCurrentLocation(),
68                      "unexpected field name. Expected " + fieldName + ".");
69     return failure();
70   }
71 
72   if (failed(parser.parseEqual())) {
73     parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
74     return failure();
75   }
76 
77   auto elemParser = [&]() -> llvm::ParseResult {
78     uint32_t elem = 0;
79     auto res = parser.parseInteger(elem);
80     result.push_back(elem);
81     return res;
82   };
83 
84   return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
85                                         elemParser, fieldName);
86 }
87 } // namespace
88 
89 mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
90                                  ::mlir::Type attrType) {
91   if (failed(parser.parseLess()))
92     return {};
93 
94   llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
95   if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
96     return {};
97 
98   if (failed(parser.parseComma()))
99     return {};
100 
101   if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
102     return {};
103 
104   return SGMapAttr::getChecked(
105       [&]() { return parser.emitError(parser.getNameLoc()); },
106       parser.getContext(), wi_layout, wi_data);
107 }
108 
109 void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
110   printer << "<";
111   printer.printKeywordOrString("wi_layout");
112   printer << " = [" << getWiLayout() << "], ";
113   printer.printKeywordOrString("wi_data");
114   printer << " = [" << getWiData() << "]";
115   printer << ">";
116 }
117 
118 LogicalResult
119 SGMapAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
120                   llvm::ArrayRef<uint32_t> wi_layout,
121                   llvm::ArrayRef<uint32_t> wi_data) {
122   if (wi_layout.size() != 2)
123     return emitError() << "expected wi_layout of size 2";
124   if (wi_data.size() != 2)
125     return emitError() << "expected wi_data of size 2";
126   return success();
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // XeGPU_TensorDescType
131 //===----------------------------------------------------------------------===//
132 
133 mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
134   llvm::SmallVector<int64_t> shape;
135   mlir::Type elementType;
136   mlir::FailureOr<mlir::Attribute> encoding;
137   mlir::FailureOr<mlir::Attribute> sg_map;
138 
139   // Parse literal '<'
140   if (parser.parseLess())
141     return {};
142 
143   auto shapeLoc = parser.getCurrentLocation();
144   if (mlir::failed(parser.parseDimensionList(shape))) {
145     parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
146     return {};
147   }
148 
149   auto elemTypeLoc = parser.getCurrentLocation();
150   if (mlir::failed(parser.parseType(elementType))) {
151     parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
152     return {};
153   }
154 
155   // parse optional attributes
156   while (mlir::succeeded(parser.parseOptionalComma())) {
157     mlir::Attribute attr;
158     ParseResult res = parser.parseAttribute(attr);
159     if (mlir::succeeded(res)) {
160       if (mlir::isa<SGMapAttr>(attr)) {
161         sg_map = attr;
162         continue;
163       }
164       if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165         encoding = attr;
166         continue;
167       }
168     }
169     parser.emitError(parser.getCurrentLocation(),
170                      "Failed to parse the attribute.\n");
171     return {};
172   }
173 
174   // Parse literal '>'
175   if (parser.parseGreater())
176     return {};
177 
178   return TensorDescType::get(parser.getContext(), shape, elementType,
179                              encoding.value_or(mlir::Attribute()),
180                              sg_map.value_or(mlir::Attribute()));
181 }
182 
183 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
184   printer << "<";
185 
186   auto shape = getShape();
187   for (int64_t dim : shape) {
188     if (mlir::ShapedType::isDynamic(dim))
189       printer << '?';
190     else
191       printer << dim;
192     printer << 'x';
193   }
194 
195   printer << getElementType();
196 
197   if (auto encoding = getEncoding())
198     printer << ", " << encoding;
199 
200   if (auto sg_map = getSgMap())
201     printer << ", " << sg_map;
202 
203   printer << ">";
204 }
205 
206 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
207                                    mlir::Type elementType, int array_length,
208                                    bool boundary_check,
209                                    MemorySpace memory_space,
210                                    mlir::Attribute sg_map) {
211   auto context = elementType.getContext();
212   auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
213                                        boundary_check);
214   return Base::get(context, shape, elementType, attr, sg_map);
215 }
216 
217 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
218                                    mlir::Type elementType, int chunk_size,
219                                    MemorySpace memory_space,
220                                    mlir::Attribute sg_map) {
221   auto context = elementType.getContext();
222   auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
223   return Base::get(context, shape, elementType, attr, sg_map);
224 }
225 
226 } // namespace xegpu
227 } // namespace mlir
228 
229 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
230 #define GET_ATTRDEF_CLASSES
231 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
232 #define GET_TYPEDEF_CLASSES
233 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
234