xref: /llvm-project/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (revision 2f743ac52e945e155ff3cb1f8ca5287b306b831e)
1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AMX dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/AMX/AMXDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/TypeUtilities.h"
19 
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace mlir;
23 
24 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
25 
/// Registers the AMX dialect's types and operations.
///
/// The template argument lists for addTypes<> and addOperations<> are not
/// spelled out here: each one is pasted in by including a .cpp.inc file after
/// defining the macro (GET_TYPEDEF_LIST / GET_OP_LIST) that selects the list
/// section of that file.
26 void amx::AMXDialect::initialize() {
27   addTypes<
28 #define GET_TYPEDEF_LIST
29 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
30       >();
31 
32   addOperations<
33 #define GET_OP_LIST
34 #include "mlir/Dialect/AMX/AMX.cpp.inc"
35       >();
36 }
37 
38 /// Verify that AMX supports the implied tile shape.
39 static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
40   const unsigned kMaxRows = 16;
41   const unsigned kBitsPerRow = 64 * 8;
42   unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
43   if (tp.getDimSize(0) > kMaxRows)
44     return op->emitOpError("bad row height: ") << tp.getDimSize(0);
45   if (col > kBitsPerRow || col & 0x1f)
46     return op->emitOpError("bad column width: ") << (col >> 3);
47   return success();
48 }
49 
50 /// Verify that AMX supports the multiplication.
51 static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
52                                      amx::TileType btp, amx::TileType ctp,
53                                      unsigned scale) {
54   unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
55   unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
56   unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
57   if (cm != am || cn != bn || ak != bk)
58     return op->emitOpError("bad mult shape: ")
59            << cm << " x " << cn << " x " << ak;
60   return success();
61 }
62 
63 LogicalResult amx::TileZeroOp::verify() {
64   return verifyTileSize(*this, getTileType());
65 }
66 
67 LogicalResult amx::TileLoadOp::verify() {
68   unsigned rank = getMemRefType().getRank();
69   if (getIndices().size() != rank)
70     return emitOpError("requires ") << rank << " indices";
71   return verifyTileSize(*this, getTileType());
72 }
73 
74 LogicalResult amx::TileStoreOp::verify() {
75   unsigned rank = getMemRefType().getRank();
76   if (getIndices().size() != rank)
77     return emitOpError("requires ") << rank << " indices";
78   return verifyTileSize(*this, getTileType());
79 }
80 
81 LogicalResult amx::TileMulFOp::verify() {
82   amx::TileType aType = getLhsTileType();
83   amx::TileType bType = getRhsTileType();
84   amx::TileType cType = getTileType();
85   if (failed(verifyTileSize(*this, aType)) ||
86       failed(verifyTileSize(*this, bType)) ||
87       failed(verifyTileSize(*this, cType)) ||
88       failed(verifyMultShape(*this, aType, bType, cType, 1)))
89     return failure();
90   Type ta = aType.getElementType();
91   Type tb = bType.getElementType();
92   Type tc = cType.getElementType();
93   if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
94     return emitOpError("unsupported type combination");
95   return success();
96 }
97 
98 LogicalResult amx::TileMulIOp::verify() {
99   amx::TileType aType = getLhsTileType();
100   amx::TileType bType = getRhsTileType();
101   amx::TileType cType = getTileType();
102   if (failed(verifyTileSize(*this, aType)) ||
103       failed(verifyTileSize(*this, bType)) ||
104       failed(verifyTileSize(*this, cType)) ||
105       failed(verifyMultShape(*this, aType, bType, cType, 2)))
106     return failure();
107   Type ta = aType.getElementType();
108   Type tb = bType.getElementType();
109   Type tc = cType.getElementType();
110   if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
111     return emitOpError("unsupported type combination");
112   return success();
113 }
114 
115 Type amx::TileType::parse(AsmParser &parser) {
116   if (parser.parseLess())
117     return nullptr;
118 
119   SmallVector<int64_t, 2> shape;
120   if (parser.parseDimensionList(shape, false, true))
121     return nullptr;
122 
123   Type elementType;
124   if (parser.parseType(elementType))
125     return nullptr;
126 
127   if (parser.parseGreater())
128     return nullptr;
129 
130   return TileType::get(shape, elementType);
131 }
132 
133 void amx::TileType::print(AsmPrinter &os) const {
134   os << "<";
135   os.printDimensionList(getShape());
136   os << 'x';
137   os.printType(getElementType());
138   os << '>';
139 }
140 
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/AMX/AMX.cpp.inc"
143 
144 #define GET_TYPEDEF_CLASSES
145 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
146