xref: /llvm-project/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (revision 2f743ac52e945e155ff3cb1f8ca5287b306b831e)
16ad7b97eSAart Bik //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
26ad7b97eSAart Bik //
36ad7b97eSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46ad7b97eSAart Bik // See https://llvm.org/LICENSE.txt for license information.
56ad7b97eSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66ad7b97eSAart Bik //
76ad7b97eSAart Bik //===----------------------------------------------------------------------===//
86ad7b97eSAart Bik //
96ad7b97eSAart Bik // This file implements the AMX dialect and its operations.
106ad7b97eSAart Bik //
116ad7b97eSAart Bik //===----------------------------------------------------------------------===//
126ad7b97eSAart Bik 
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/ADT/TypeSwitch.h"

#include <limits>
21*2f743ac5SIlya Enkovich 
226ad7b97eSAart Bik using namespace mlir;
236ad7b97eSAart Bik 
24485cc55eSStella Laurenzo #include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
25485cc55eSStella Laurenzo 
/// Register the AMX dialect's types and operations with the dialect.
/// Both lists come from TableGen output: including the generated .cpp.inc
/// files with GET_TYPEDEF_LIST / GET_OP_LIST defined expands, inside the
/// template argument lists below, to the generated type and op classes.
void amx::AMXDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
      >();
}
376ad7b97eSAart Bik 
386ad7b97eSAart Bik /// Verify that AMX supports the implied tile shape.
39*2f743ac5SIlya Enkovich static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
406ad7b97eSAart Bik   const unsigned kMaxRows = 16;
416ad7b97eSAart Bik   const unsigned kBitsPerRow = 64 * 8;
426ad7b97eSAart Bik   unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
436ad7b97eSAart Bik   if (tp.getDimSize(0) > kMaxRows)
446ad7b97eSAart Bik     return op->emitOpError("bad row height: ") << tp.getDimSize(0);
456ad7b97eSAart Bik   if (col > kBitsPerRow || col & 0x1f)
466ad7b97eSAart Bik     return op->emitOpError("bad column width: ") << (col >> 3);
476ad7b97eSAart Bik   return success();
486ad7b97eSAart Bik }
496ad7b97eSAart Bik 
506ad7b97eSAart Bik /// Verify that AMX supports the multiplication.
51*2f743ac5SIlya Enkovich static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
52*2f743ac5SIlya Enkovich                                      amx::TileType btp, amx::TileType ctp,
536ad7b97eSAart Bik                                      unsigned scale) {
546ad7b97eSAart Bik   unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
556ad7b97eSAart Bik   unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
566ad7b97eSAart Bik   unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
576ad7b97eSAart Bik   if (cm != am || cn != bn || ak != bk)
586ad7b97eSAart Bik     return op->emitOpError("bad mult shape: ")
596ad7b97eSAart Bik            << cm << " x " << cn << " x " << ak;
606ad7b97eSAart Bik   return success();
616ad7b97eSAart Bik }
626ad7b97eSAart Bik 
6338abdddfSRiver Riddle LogicalResult amx::TileZeroOp::verify() {
64*2f743ac5SIlya Enkovich   return verifyTileSize(*this, getTileType());
656ad7b97eSAart Bik }
666ad7b97eSAart Bik 
6738abdddfSRiver Riddle LogicalResult amx::TileLoadOp::verify() {
6838abdddfSRiver Riddle   unsigned rank = getMemRefType().getRank();
698df54a6aSJacques Pienaar   if (getIndices().size() != rank)
7038abdddfSRiver Riddle     return emitOpError("requires ") << rank << " indices";
71*2f743ac5SIlya Enkovich   return verifyTileSize(*this, getTileType());
726ad7b97eSAart Bik }
736ad7b97eSAart Bik 
7438abdddfSRiver Riddle LogicalResult amx::TileStoreOp::verify() {
7538abdddfSRiver Riddle   unsigned rank = getMemRefType().getRank();
768df54a6aSJacques Pienaar   if (getIndices().size() != rank)
7738abdddfSRiver Riddle     return emitOpError("requires ") << rank << " indices";
78*2f743ac5SIlya Enkovich   return verifyTileSize(*this, getTileType());
796ad7b97eSAart Bik }
806ad7b97eSAart Bik 
8138abdddfSRiver Riddle LogicalResult amx::TileMulFOp::verify() {
82*2f743ac5SIlya Enkovich   amx::TileType aType = getLhsTileType();
83*2f743ac5SIlya Enkovich   amx::TileType bType = getRhsTileType();
84*2f743ac5SIlya Enkovich   amx::TileType cType = getTileType();
8538abdddfSRiver Riddle   if (failed(verifyTileSize(*this, aType)) ||
8638abdddfSRiver Riddle       failed(verifyTileSize(*this, bType)) ||
8738abdddfSRiver Riddle       failed(verifyTileSize(*this, cType)) ||
8838abdddfSRiver Riddle       failed(verifyMultShape(*this, aType, bType, cType, 1)))
896ad7b97eSAart Bik     return failure();
906ad7b97eSAart Bik   Type ta = aType.getElementType();
916ad7b97eSAart Bik   Type tb = bType.getElementType();
926ad7b97eSAart Bik   Type tc = cType.getElementType();
93*2f743ac5SIlya Enkovich   if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
9438abdddfSRiver Riddle     return emitOpError("unsupported type combination");
956ad7b97eSAart Bik   return success();
966ad7b97eSAart Bik }
976ad7b97eSAart Bik 
9838abdddfSRiver Riddle LogicalResult amx::TileMulIOp::verify() {
99*2f743ac5SIlya Enkovich   amx::TileType aType = getLhsTileType();
100*2f743ac5SIlya Enkovich   amx::TileType bType = getRhsTileType();
101*2f743ac5SIlya Enkovich   amx::TileType cType = getTileType();
10238abdddfSRiver Riddle   if (failed(verifyTileSize(*this, aType)) ||
10338abdddfSRiver Riddle       failed(verifyTileSize(*this, bType)) ||
10438abdddfSRiver Riddle       failed(verifyTileSize(*this, cType)) ||
10538abdddfSRiver Riddle       failed(verifyMultShape(*this, aType, bType, cType, 2)))
1066ad7b97eSAart Bik     return failure();
1076ad7b97eSAart Bik   Type ta = aType.getElementType();
1086ad7b97eSAart Bik   Type tb = bType.getElementType();
1096ad7b97eSAart Bik   Type tc = cType.getElementType();
1106ad7b97eSAart Bik   if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
11138abdddfSRiver Riddle     return emitOpError("unsupported type combination");
1126ad7b97eSAart Bik   return success();
1136ad7b97eSAart Bik }
1146ad7b97eSAart Bik 
115*2f743ac5SIlya Enkovich Type amx::TileType::parse(AsmParser &parser) {
116*2f743ac5SIlya Enkovich   if (parser.parseLess())
117*2f743ac5SIlya Enkovich     return nullptr;
118*2f743ac5SIlya Enkovich 
119*2f743ac5SIlya Enkovich   SmallVector<int64_t, 2> shape;
120*2f743ac5SIlya Enkovich   if (parser.parseDimensionList(shape, false, true))
121*2f743ac5SIlya Enkovich     return nullptr;
122*2f743ac5SIlya Enkovich 
123*2f743ac5SIlya Enkovich   Type elementType;
124*2f743ac5SIlya Enkovich   if (parser.parseType(elementType))
125*2f743ac5SIlya Enkovich     return nullptr;
126*2f743ac5SIlya Enkovich 
127*2f743ac5SIlya Enkovich   if (parser.parseGreater())
128*2f743ac5SIlya Enkovich     return nullptr;
129*2f743ac5SIlya Enkovich 
130*2f743ac5SIlya Enkovich   return TileType::get(shape, elementType);
131*2f743ac5SIlya Enkovich }
132*2f743ac5SIlya Enkovich 
133*2f743ac5SIlya Enkovich void amx::TileType::print(AsmPrinter &os) const {
134*2f743ac5SIlya Enkovich   os << "<";
135*2f743ac5SIlya Enkovich   os.printDimensionList(getShape());
136*2f743ac5SIlya Enkovich   os << 'x';
137*2f743ac5SIlya Enkovich   os.printType(getElementType());
138*2f743ac5SIlya Enkovich   os << '>';
139*2f743ac5SIlya Enkovich }
140*2f743ac5SIlya Enkovich 
// TableGen-generated definitions of the AMX operation classes. The
// hand-written verify() methods in this file are members of these classes.
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.cpp.inc"

// TableGen-generated definitions of the AMX type classes. TileType's custom
// parse/print in this file are members of these classes.
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
146