xref: /llvm-project/mlir/include/mlir/Dialect/AMX/AMX.td (revision 2f743ac52e945e155ff3cb1f8ca5287b306b831e)
1//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the basic operations for the AMX dialect.
10//
11// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
12// multiply unit (TMUL), a tile control register (TILECFG), and eight
13// tile registers TMM0 through TMM7 (TILEDATA).
14//
// The AMX dialect provides a bridge between MLIR concepts, such as
// 2-d vectors, operations, and memrefs, and the lower-level details
// of Intel AMX, such as configuration setup, tile sizes, instructions,
// and tile release.
19//
20// Note that since configuration changes (implicit at dialect level) are
21// costly, it is highly recommended to use the AMX dialect on same-shaped
22// vectors, at least within a single method.
23//
24// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
25//
26//===----------------------------------------------------------------------===//
27
28#ifndef AMX
29#define AMX
30
31include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
32include "mlir/Interfaces/SideEffectInterfaces.td"
33include "mlir/IR/AttrTypeBase.td"
34include "mlir/IR/BuiltinTypes.td"
35
36//===----------------------------------------------------------------------===//
37// AMX dialect definition.
38//===----------------------------------------------------------------------===//
39
def AMX_Dialect : Dialect {
  let name = "amx";
  let cppNamespace = "::mlir::amx";
  let description = [{
    The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
    multiply unit (TMUL), a tile control register (TILECFG), and eight
    tile registers TMM0 through TMM7 (TILEDATA).

    This `AMX` dialect provides a bridge between MLIR concepts, such as
    vectors and memrefs, and the lower-level LLVM IR support of AMX.
    The dialect is split into user-facing AMX ops (AMX_Op) and
    backend-facing intrinsic ops (AMX_IntrOp).

    Note that since configuration changes (implicit at dialect level) are
    costly, it is highly recommended to use the AMX dialect on same-shaped
    vectors, at least within a single method.

    For details, see the Intel documentation:
    https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
  }];
  // Use the ODS-generated type printer/parser dispatch for dialect types
  // such as `!amx.tile`; individual types may still declare a custom
  // assembly format.
  let useDefaultTypePrinterParser = 1;
}
62
63//===----------------------------------------------------------------------===//
64// AMX Tile definition.
65//===----------------------------------------------------------------------===//
66
// Base class for every type of the AMX dialect. `typeMnemonic` is the name
// that follows the `!amx.` prefix in the textual IR (e.g. `!amx.tile<...>`).
class AMX_Type<string typeName, string typeMnemonic,
               list<Trait> traits = []>
    : TypeDef<AMX_Dialect, typeName, traits> {
  let mnemonic = typeMnemonic;
}
71
// Element types that may appear inside an `!amx.tile`. The generated C++
// predicate for this constraint is named `isValidTileTypeElementType`.
def AMX_TileTypeElementType
    : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
  let cppFunctionName = "isValidTileTypeElementType";
}
75
// The `!amx.tile` type (e.g. `!amx.tile<16x16xbf16>`): a shaped value that
// lives in an AMX tile register.
def AMX_TileType
    : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
  let summary = "AMX 2D tile to be used by AMX operations.";

  let description = [{
    This type is used to represent values in AMX tile registers. All AMX
    operations work on AMX tiles, and these tiles cannot be used in other
    operations directly. The LLVM IR type for an AMX tile is a primitive
    type, but in MLIR we provide the shape and element type for IR
    verification and for lowering to the LLVM dialect.
  }];

  let parameters = (ins
    ArrayRefParameter<"int64_t">:$shape,
    AMX_TileTypeElementType:$elementType
  );

  // Default builders are skipped (see below); this is the only builder, and
  // it takes the MLIRContext from `elementType`.
  let builders = [
    TypeBuilderWithInferredContext<(ins
      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
      return $_get(elementType.getContext(), shape, elementType);
    }]>
  ];

  let extraClassDeclaration = [{
    /// Returns whether this type is ranked; tile types always are.
    bool hasRank() const { return true; }

    /// Clone this tile type with the given shape and element type. If the
    /// provided shape is `std::nullopt`, the current shape of the type is used.
    TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                       Type elementType) const {
      return get(shape.value_or(getShape()), elementType);
    }
  }];

  let hasCustomAssemblyFormat = 1;
  let skipDefaultBuilders = 1;
}
113
// Holds for values of type `amx::TileType` with exactly two dimensions.
// The `isa` check comes first so that the `cast` in the rank check is safe.
def IsAMXTilePred : And<[
  CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
  CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>
]>;
116
// Type constraint: a rank-2 AMX tile whose element type is one of
// `allowedTypes`.
class AMXTileOf<list<Type> allowedTypes>
    : ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
                          "::mlir::amx::TileType">;
120
// A tile of any element type supported by `!amx.tile`.
def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;

// Element-type-specific tile constraints used by the AMX operations.
def AMXTileF32       : AMXTileOf<[F32]>;
def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
def AMXTileI32       : AMXTileOf<[I32]>;
def AMXTileI8        : AMXTileOf<[I8]>;
130
131//===----------------------------------------------------------------------===//
132// AMX Op and IntrOp definitions.
133//===----------------------------------------------------------------------===//
134
// Base class for the user-facing operations of the AMX dialect.
class AMX_Op<string mnemonic, list<Trait> traits = []>
    : Op<AMX_Dialect, mnemonic, traits>;
137
// Base class for the backend-facing AMX intrinsic operations. The intrinsic
// name handed to LLVM_IntrOpBase is built from `mnemonic` by replacing every
// "." with "_" and wrapping the result as "x86_<mnemonic>_internal". These
// "internal" intrinsics are meant for compiler usage.
class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []>
    : LLVM_IntrOpBase<AMX_Dialect, mnemonic,
                      !strconcat("x86_", !subst(".", "_", mnemonic),
                                 "_internal"),
                      [], [], traits, numResults>;
143
144//===----------------------------------------------------------------------===//
145// AMX Op definitions (user facing).
146//===----------------------------------------------------------------------===//
147
148//
149// Tile reset.
150//
151
def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
  let summary = "tile zero operation";
  let description = [{
    Zeroes the destination tile, with the shape defined by the 2-dim
    tile type of the result. This is eventually lowered into the
    "tilezero" instruction with the corresponding tile configuration.

    Example:

    ```mlir
      %0 = amx.tile_zero : !amx.tile<16x16xbf16>
    ```
  }];
  let results = (outs AnyAMXTile:$res);
  let extraClassDeclaration = [{
    /// Returns the type of the zeroed result tile.
    TileType getTileType() {
      return ::llvm::cast<TileType>(getRes().getType());
    }
  }];
  let assemblyFormat = "attr-dict `:` qualified(type($res))";
  let hasVerifier = 1;
}
174
175//
176// Tile memory operations.
177//
178
// NOTE(review): `Pure` expands to NoMemoryEffect + AlwaysSpeculatable, which
// sits uneasily with the `MemRead` effect declared on `$base` below (a load
// from an arbitrary memref is not generally speculatable). Confirm whether
// the trait list should rely on the operand effect alone.
def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
  let summary = "tile load operation";
  let description = [{
    Loads a tile from memory defined by a base and indices, with the
    shape defined by the 2-dim tile type of the result. This is
    eventually lowered into the "tileloadd" instruction with the
    corresponding tile configuration.

    Example:

    ```mlir
      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
    ```
  }];
  let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
                   Variadic<Index>:$indices);
  let results = (outs AnyAMXTile:$res);
  let extraClassDeclaration = [{
    /// Returns the type of the memref the tile is loaded from.
    MemRefType getMemRefType() {
      return ::llvm::cast<MemRefType>(getBase().getType());
    }
    /// Returns the type of the loaded tile.
    TileType getTileType() {
      return ::llvm::cast<TileType>(getRes().getType());
    }
  }];
  let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
                       "type($base) `into` qualified(type($res))";
  let hasVerifier = 1;
}
208
def TileStoreOp : AMX_Op<"tile_store"> {
  let summary = "tile store operation";
  let description = [{
    Stores a tile to memory defined by a base and indices, with the
    shape defined by the 2-dim tile type of the value. This is
    eventually lowered into the "tilestored" instruction with the
    corresponding tile configuration.

    Example:

    ```mlir
      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
    ```
  }];
  let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                   Variadic<Index>:$indices,
                   AnyAMXTile:$val);
  let extraClassDeclaration = [{
    /// Returns the type of the memref the tile is stored to.
    MemRefType getMemRefType() {
      return ::llvm::cast<MemRefType>(getBase().getType());
    }
    /// Returns the type of the stored tile value.
    TileType getTileType() {
      return ::llvm::cast<TileType>(getVal().getType());
    }
  }];
  let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
                       "type($base) `,` qualified(type($val))";
  let hasVerifier = 1;
}
238
239//
240// Tile arithmetic operations.
241//
242
def TileMulFOp : AMX_Op<"tile_mulf", [
    Pure, AllTypesMatch<["acc", "res"]>]> {
  let summary = "tile multiplication operation (floating-point)";
  let description = [{
    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
    into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
    pairs of "bf16") and "f32 <- f16 x f16" (with pairs of "f16"). The
    operation is eventually lowered into the "tdpbf16ps" or "tdpfp16ps"
    instruction, respectively, with the corresponding tile configuration.

    Example:

    ```mlir
      %0 = amx.tile_mulf %a, %b, %c
        : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
    ```
  }];
  let arguments = (ins AMXTileF16OrBF16:$lhs,
                       AMXTileF16OrBF16:$rhs,
                       AMXTileF32:$acc);
  let results = (outs AMXTileF32:$res);
  let extraClassDeclaration = [{
    /// Returns the type of the left-hand source tile.
    TileType getLhsTileType() {
      return ::llvm::cast<TileType>(getLhs().getType());
    }
    /// Returns the type of the right-hand source tile.
    TileType getRhsTileType() {
      return ::llvm::cast<TileType>(getRhs().getType());
    }
    /// Returns the type of the result (and accumulator) tile.
    TileType getTileType() {
      return ::llvm::cast<TileType>(getRes().getType());
    }
  }];
  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
                       "qualified(type($lhs)) `,` qualified(type($rhs))"
                       " `,` qualified(type($acc)) ";
  let hasVerifier = 1;
}
279
def TileMulIOp : AMX_Op<"tile_muli", [
    Pure, AllTypesMatch<["acc", "res"]>]> {
  let summary = "tile multiplication operation (integer)";
  let description = [{
    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
    into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
    combinations (4 bytes packed into dwords in the columns of both the
    source operand tiles; the zero or sign extension is specified with
    the attributes and defaults to sign extension). The operation is
    eventually lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or
    "tdpbuud" instructions with the corresponding tile configuration.

    Example:

    ```mlir
      %0 = amx.tile_muli %a zext, %b zext, %c
        : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
    ```
  }];
  // The optional unit attributes (spelled `zext` after an operand) request
  // zero extension of the corresponding i8 source; when absent, the source
  // is sign extended.
  let arguments = (ins AMXTileI8:$lhs,
                       AMXTileI8:$rhs,
                       AMXTileI32:$acc,
                       UnitAttr:$isZextLhs,
                       UnitAttr:$isZextRhs);
  let results = (outs AMXTileI32:$res);
  let extraClassDeclaration = [{
    /// Returns the type of the left-hand source tile.
    TileType getLhsTileType() {
      return ::llvm::cast<TileType>(getLhs().getType());
    }
    /// Returns the type of the right-hand source tile.
    TileType getRhsTileType() {
      return ::llvm::cast<TileType>(getRhs().getType());
    }
    /// Returns the type of the result (and accumulator) tile.
    TileType getTileType() {
      return ::llvm::cast<TileType>(getRes().getType());
    }
  }];
  let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` "
                       "$rhs (`zext` $isZextRhs^)? `,` "
                       "$acc attr-dict `:` "
                       "qualified(type($lhs)) `,` qualified(type($rhs)) `,` "
                       "qualified(type($acc)) ";
  let hasVerifier = 1;
}
321
322//===----------------------------------------------------------------------===//
323// AMX IntrOp definitions (LLVM compiler facing).
324//===----------------------------------------------------------------------===//
325
326//
327// Tile reset. Parameters define the tile size.
328//
329
// Produces one tile; the two integer operands define the tile size.
def LLVM_x86_amx_tilezero
    : AMX_IntrOp<"tilezero", 1>,
      Arguments<(ins AnyInteger, AnyInteger)>;
332
333//
334// Tile memory operations. Parameters define the tile size,
335// base address, and stride between consecutive rows for the
336// memory operation.
337//
338
// Produces one tile. Operands (presumably, in order): two integers for the
// tile size, the base address, and the row stride — confirm against the
// LLVM intrinsic definition.
def LLVM_x86_amx_tileloadd64
    : AMX_IntrOp<"tileloadd64", 1>,
      Arguments<(ins AnyInteger, AnyInteger,
                     LLVM_AnyPointer, AnyInteger)>;
342
// Produces no result. Takes the same leading operands as the tile load plus
// a trailing LLVM_Type operand, presumably the tile value being stored —
// confirm against the LLVM intrinsic definition.
def LLVM_x86_amx_tilestored64
    : AMX_IntrOp<"tilestored64", 0>,
      Arguments<(ins AnyInteger, AnyInteger,
                     LLVM_AnyPointer, AnyInteger, LLVM_Type)>;
346
//
// Tile multiplication operations (series of dot products). Parameters
// define the tile sizes and source and destination tiles for the
// operation. Note that the prefix "tdp" stands for tile dot product.
//

// Common shape of every tile dot-product intrinsic: one result, three
// integer operands defining the tile sizes, followed by three tile operands.
class AMX_TileDotProductIntrOp<string mnemonic>
    : AMX_IntrOp<mnemonic, 1>,
      Arguments<(ins AnyInteger, AnyInteger, AnyInteger,
                     LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of bf16 tiles into f32 tile.
def LLVM_x86_amx_tdpbf16ps : AMX_TileDotProductIntrOp<"tdpbf16ps">;

// Dot product of f16 tiles into f32 tile.
def LLVM_x86_amx_tdpfp16ps : AMX_TileDotProductIntrOp<"tdpfp16ps">;

// Dot product of i8 tiles into i32 tile (with sign/sign extension).
def LLVM_x86_amx_tdpbssd : AMX_TileDotProductIntrOp<"tdpbssd">;

// Dot product of i8 tiles into i32 tile (with sign/zero extension).
def LLVM_x86_amx_tdpbsud : AMX_TileDotProductIntrOp<"tdpbsud">;

// Dot product of i8 tiles into i32 tile (with zero/sign extension).
def LLVM_x86_amx_tdpbusd : AMX_TileDotProductIntrOp<"tdpbusd">;

// Dot product of i8 tiles into i32 tile (with zero/zero extension).
def LLVM_x86_amx_tdpbuud : AMX_TileDotProductIntrOp<"tdpbuud">;
388
389#endif // AMX
390