//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the basic operations for the AMX dialect.
//
// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
// multiply unit (TMUL), a tile control register (TILECFG), and eight
// tile registers TMM0 through TMM7 (TILEDATA).
//
// The AMX dialect provides a bridge between MLIR concepts, such as
// 2-d vector, operations, and memrefs, and the lower level details
// of Intel AMX, such as configuration setup, tile sizes, instructions,
// and tile release.
//
// Note that since configuration changes (implicit at dialect level) are
// costly, it is highly recommended to use the AMX dialect on same-shaped
// vectors, at least within a single method.
//
// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
//
//===----------------------------------------------------------------------===//

#ifndef AMX
#define AMX

include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypes.td"

//===----------------------------------------------------------------------===//
// AMX dialect definition.
//===----------------------------------------------------------------------===//

def AMX_Dialect : Dialect {
  let name = "amx";
  let cppNamespace = "::mlir::amx";
  let description = [{
    The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
    multiply unit (TMUL), a tile control register (TILECFG), and eight
    tile registers TMM0 through TMM7 (TILEDATA).

    This `AMX` dialect provides a bridge between MLIR concepts such as
    vectors and memrefs and the lower level LLVM IR support of AMX.
    The dialect is split into user-facing AMX ops (AMX_Op) and
    backend-facing intrinsic ops (AMX_IntrOp).

    Note that since configuration changes (implicit at dialect level) are
    costly, it is highly recommended to use the AMX dialect on same-shaped
    vectors, at least within a single method.

    For details, see the Intel documentation:
    https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
  }];
  let useDefaultTypePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
// AMX Tile definition.
//===----------------------------------------------------------------------===//

// Base class for types of the AMX dialect: attaches the type to AMX_Dialect
// and uses `typeMnemonic` as the type's mnemonic.
class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
  : TypeDef<AMX_Dialect, typeName, traits> {
  let mnemonic = typeMnemonic;
}

// Element types accepted for the `elementType` parameter of AMX_TileType.
def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
  let cppFunctionName = "isValidTileTypeElementType";
}

def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
  let summary = "AMX 2D tile to be used by AMX operations.";

  let description = [{
    This type is used to represent values in AMX tile registers. All AMX operations
    work on AMX tiles and these tiles cannot be used in other operations directly.
    LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
    element type for IR verification and lowering to LLVMIR dialect.
  }];

  let parameters = (ins
    ArrayRefParameter<"int64_t">:$shape,
    AMX_TileTypeElementType:$elementType
  );

  // The only builder takes the context from `elementType`; the default
  // builders are disabled via `skipDefaultBuilders` below.
  let builders = [
    TypeBuilderWithInferredContext<(ins
      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
      return $_get(elementType.getContext(), shape, elementType);
    }]>
  ];

  let extraClassDeclaration = [{
    /// Returns if this type is ranked (always true).
    bool hasRank() const { return true; }

    /// Clone this tile type with the given shape and element type. If the
    /// provided shape is `std::nullopt`, the current shape of the type is used.
    TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                       Type elementType) const {
      return get(shape.value_or(getShape()), elementType);
    }
  }];

  let hasCustomAssemblyFormat = 1;
  let skipDefaultBuilders = 1;
}

// Holds for a type that is an ::mlir::amx::TileType of rank 2.
def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
  CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;

// Constraint for a rank-2 AMX tile whose element type is one of
// `allowedTypes`.
class AMXTileOf<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
                      "::mlir::amx::TileType">;

def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;

def AMXTileF32 : AMXTileOf<[F32]>;

def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;

def AMXTileI32 : AMXTileOf<[I32]>;

def AMXTileI8 : AMXTileOf<[I8]>;

//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//

// Base class for the user-facing operations of the AMX dialect.
class AMX_Op<string mnemonic, list<Trait> traits = []> :
  Op<AMX_Dialect, mnemonic, traits> {}

// The "internal" intrinsics are meant for compiler usage.
// The name string handed to LLVM_IntrOpBase is "x86_" + `mnemonic` (with any
// "." replaced by "_") + "_internal".
class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
  LLVM_IntrOpBase<AMX_Dialect, mnemonic,
                  "x86_" # !subst(".", "_", mnemonic) # "_internal",
                  [], [], traits, numResults>;

//===----------------------------------------------------------------------===//
// AMX Op definitions (user facing).
//===----------------------------------------------------------------------===//

//
// Tile reset.
//

def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
  let summary = "tile zero operation";
  let description = [{
    Zeroes the destination tile, with the shape defined by the 2-dim
    tile type of the result. This is eventually lowered into the
    "tilezero" instruction with the corresponding tile configuration.

    Example:

    ```mlir
      %0 = amx.tile_zero : !amx.tile<16x16xbf16>
    ```
  }];
  let results = (outs AnyAMXTile:$res);
  let extraClassDeclaration = [{
    TileType getTileType() {
      return ::llvm::cast<TileType>(getRes().getType());
    }
  }];
  let assemblyFormat = "attr-dict `:` qualified(type($res))";
  let hasVerifier = 1;
}

//
// Tile memory operations.
//

// NOTE(review): this op carries the `Pure` trait while its $base operand is
// annotated with a `MemRead` effect; confirm that this combination yields the
// intended memory-effect and speculation behavior.
def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
  let summary = "tile load operation";
  let description = [{
    Loads a tile from memory defined by a base and indices, with the
    shape defined by the 2-dim tile type of the result. This is
    eventually lowered into the "tileloadd" instruction with the
    corresponding tile configuration.

    Example:

    ```mlir
      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
    ```
  }];
  let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
                   Variadic<Index>:$indices);
  let results = (outs AnyAMXTile:$res);
  let extraClassDeclaration = [{
    MemRefType getMemRefType() {
      return ::llvm::cast<MemRefType>(getBase().getType());
    }
    TileType getTileType() {
      return ::llvm::cast<TileType>(getRes().getType());
    }
  }];
  let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
                       "type($base) `into` qualified(type($res))";
  let hasVerifier = 1;
}

def TileStoreOp : AMX_Op<"tile_store"> {
  let summary = "tile store operation";
  let description = [{
    Stores a tile to memory defined by a base and indices, with the
    shape defined by the 2-dim tile type of the value. This is
    eventually lowered into the "tilestored" instruction with the
    corresponding tile configuration.

    Example:

    ```mlir
      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
    ```
  }];
  let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                   Variadic<Index>:$indices,
                   AnyAMXTile:$val);
  let extraClassDeclaration = [{
    MemRefType getMemRefType() {
      return ::llvm::cast<MemRefType>(getBase().getType());
    }
    TileType getTileType() {
      return ::llvm::cast<TileType>(getVal().getType());
    }
  }];
  let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
                       "type($base) `,` qualified(type($val))";
  let hasVerifier = 1;
}

//
// Tile arithmetic operations.
//

def TileMulFOp : AMX_Op<"tile_mulf", [
    Pure, AllTypesMatch<["acc", "res"]>]> {
  let summary = "tile multiplication operation (floating-point)";
  let description = [{
    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
    into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
    pairs of "bf16") and "f32 <- f16 x f16" (with pairs of "f16"). The
    operation is eventually lowered into the "tdpbf16ps" or "tdpfp16ps"
    instruction, respectively, with the corresponding tile configuration.

    Example:

    ```mlir
      %0 = amx.tile_mulf %a, %b, %c
        : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
    ```
  }];
  let arguments = (ins AMXTileF16OrBF16:$lhs,
                       AMXTileF16OrBF16:$rhs,
                       AMXTileF32:$acc);
  let results = (outs AMXTileF32:$res);
  let extraClassDeclaration = [{
    TileType getLhsTileType() {
      return ::llvm::cast<TileType>(getLhs().getType());
    }
    TileType getRhsTileType() {
      return ::llvm::cast<TileType>(getRhs().getType());
    }
    TileType getTileType() {
      return ::llvm::cast<TileType>(getRes().getType());
    }
  }];
  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
                       "qualified(type($lhs)) `,` qualified(type($rhs))"
                       " `,` qualified(type($acc)) ";
  let hasVerifier = 1;
}

def TileMulIOp : AMX_Op<"tile_muli", [
    Pure, AllTypesMatch<["acc", "res"]>]> {
  let summary = "tile multiplication operation (integer)";
  let description = [{
    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
    into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
    combinations (4 bytes packed into dwords in the columns of both the
    source operand tiles; the zero or sign extension is specified with
    the attributes and defaults to sign extended). The operation is eventually
    lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
    instructions with the corresponding tile configuration.

    Example:

    ```mlir
      %0 = amx.tile_muli %a zext, %b zext, %c
        : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
    ```
  }];
  let arguments = (ins AMXTileI8:$lhs,
                       AMXTileI8:$rhs,
                       AMXTileI32:$acc,
                       UnitAttr:$isZextLhs,
                       UnitAttr:$isZextRhs
                       );
  let results = (outs AMXTileI32:$res);
  let extraClassDeclaration = [{
    TileType getLhsTileType() {
      return ::llvm::cast<TileType>(getLhs().getType());
    }
    TileType getRhsTileType() {
      return ::llvm::cast<TileType>(getRhs().getType());
    }
    TileType getTileType() {
      return ::llvm::cast<TileType>(getRes().getType());
    }
  }];
  let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
                       "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// AMX IntrOp definitions (LLVM compiler facing).
//===----------------------------------------------------------------------===//

//
// Tile reset. Parameters define the tile size.
//

def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
  Arguments<(ins AnyInteger, AnyInteger)>;

//
// Tile memory operations. Parameters define the tile size,
// base address, and stride between consecutive rows for the
// memory operation.
//

def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
  Arguments<(ins AnyInteger,
                 AnyInteger, LLVM_AnyPointer, AnyInteger)>;

def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
  Arguments<(ins AnyInteger,
                 AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;

//
// Tile multiplication operations (series of dot products). Parameters
// define the tile sizes and source and destination tiles for the
// operation. Note that the prefix "tdp" stands for tile dot product.
//

// Common shape of every tile dot-product intrinsic below: one result, three
// integer operands followed by three LLVM-typed operands.
class AMX_TileDotIntrOp<string mnemonic> : AMX_IntrOp<mnemonic, 1>,
  Arguments<(ins AnyInteger, AnyInteger, AnyInteger,
                 LLVM_Type, LLVM_Type, LLVM_Type)>;

// bf16 tiles, dot product accumulated into an f32 tile.
def LLVM_x86_amx_tdpbf16ps : AMX_TileDotIntrOp<"tdpbf16ps">;

// f16 tiles, dot product accumulated into an f32 tile.
def LLVM_x86_amx_tdpfp16ps : AMX_TileDotIntrOp<"tdpfp16ps">;

// i8 tiles into an i32 tile, sign/sign extension.
def LLVM_x86_amx_tdpbssd : AMX_TileDotIntrOp<"tdpbssd">;

// i8 tiles into an i32 tile, sign/zero extension.
def LLVM_x86_amx_tdpbsud : AMX_TileDotIntrOp<"tdpbsud">;

// i8 tiles into an i32 tile, zero/sign extension.
def LLVM_x86_amx_tdpbusd : AMX_TileDotIntrOp<"tdpbusd">;

// i8 tiles into an i32 tile, zero/zero extension.
def LLVM_x86_amx_tdpbuud : AMX_TileDotIntrOp<"tdpbuud">;

#endif // AMX