16ad7b97eSAart Bik //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// 26ad7b97eSAart Bik // 36ad7b97eSAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 46ad7b97eSAart Bik // See https://llvm.org/LICENSE.txt for license information. 56ad7b97eSAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 66ad7b97eSAart Bik // 76ad7b97eSAart Bik //===----------------------------------------------------------------------===// 86ad7b97eSAart Bik // 96ad7b97eSAart Bik // This file implements the AMX dialect and its operations. 106ad7b97eSAart Bik // 116ad7b97eSAart Bik //===----------------------------------------------------------------------===// 126ad7b97eSAart Bik 136ad7b97eSAart Bik #include "mlir/Dialect/AMX/AMXDialect.h" 146ad7b97eSAart Bik #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 156ad7b97eSAart Bik #include "mlir/IR/Builders.h" 16*2f743ac5SIlya Enkovich #include "mlir/IR/DialectImplementation.h" 176ad7b97eSAart Bik #include "mlir/IR/OpImplementation.h" 186ad7b97eSAart Bik #include "mlir/IR/TypeUtilities.h" 196ad7b97eSAart Bik 20*2f743ac5SIlya Enkovich #include "llvm/ADT/TypeSwitch.h" 21*2f743ac5SIlya Enkovich 226ad7b97eSAart Bik using namespace mlir; 236ad7b97eSAart Bik 24485cc55eSStella Laurenzo #include "mlir/Dialect/AMX/AMXDialect.cpp.inc" 25485cc55eSStella Laurenzo 266ad7b97eSAart Bik void amx::AMXDialect::initialize() { 27*2f743ac5SIlya Enkovich addTypes< 28*2f743ac5SIlya Enkovich #define GET_TYPEDEF_LIST 29*2f743ac5SIlya Enkovich #include "mlir/Dialect/AMX/AMXTypes.cpp.inc" 30*2f743ac5SIlya Enkovich >(); 31*2f743ac5SIlya Enkovich 326ad7b97eSAart Bik addOperations< 336ad7b97eSAart Bik #define GET_OP_LIST 346ad7b97eSAart Bik #include "mlir/Dialect/AMX/AMX.cpp.inc" 356ad7b97eSAart Bik >(); 366ad7b97eSAart Bik } 376ad7b97eSAart Bik 386ad7b97eSAart Bik /// Verify that AMX supports the implied tile shape. 39*2f743ac5SIlya Enkovich static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) { 406ad7b97eSAart Bik const unsigned kMaxRows = 16; 416ad7b97eSAart Bik const unsigned kBitsPerRow = 64 * 8; 426ad7b97eSAart Bik unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); 436ad7b97eSAart Bik if (tp.getDimSize(0) > kMaxRows) 446ad7b97eSAart Bik return op->emitOpError("bad row height: ") << tp.getDimSize(0); 456ad7b97eSAart Bik if (col > kBitsPerRow || col & 0x1f) 466ad7b97eSAart Bik return op->emitOpError("bad column width: ") << (col >> 3); 476ad7b97eSAart Bik return success(); 486ad7b97eSAart Bik } 496ad7b97eSAart Bik 506ad7b97eSAart Bik /// Verify that AMX supports the multiplication. 51*2f743ac5SIlya Enkovich static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, 52*2f743ac5SIlya Enkovich amx::TileType btp, amx::TileType ctp, 536ad7b97eSAart Bik unsigned scale) { 546ad7b97eSAart Bik unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; 556ad7b97eSAart Bik unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; 566ad7b97eSAart Bik unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); 576ad7b97eSAart Bik if (cm != am || cn != bn || ak != bk) 586ad7b97eSAart Bik return op->emitOpError("bad mult shape: ") 596ad7b97eSAart Bik << cm << " x " << cn << " x " << ak; 606ad7b97eSAart Bik return success(); 616ad7b97eSAart Bik } 626ad7b97eSAart Bik 6338abdddfSRiver Riddle LogicalResult amx::TileZeroOp::verify() { 64*2f743ac5SIlya Enkovich return verifyTileSize(*this, getTileType()); 656ad7b97eSAart Bik } 666ad7b97eSAart Bik 6738abdddfSRiver Riddle LogicalResult amx::TileLoadOp::verify() { 6838abdddfSRiver Riddle unsigned rank = getMemRefType().getRank(); 698df54a6aSJacques Pienaar if (getIndices().size() != rank) 7038abdddfSRiver Riddle return emitOpError("requires ") << rank << " indices"; 71*2f743ac5SIlya Enkovich return verifyTileSize(*this, getTileType()); 726ad7b97eSAart Bik } 736ad7b97eSAart Bik 7438abdddfSRiver Riddle LogicalResult amx::TileStoreOp::verify() { 7538abdddfSRiver Riddle unsigned rank = getMemRefType().getRank(); 768df54a6aSJacques Pienaar if (getIndices().size() != rank) 7738abdddfSRiver Riddle return emitOpError("requires ") << rank << " indices"; 78*2f743ac5SIlya Enkovich return verifyTileSize(*this, getTileType()); 796ad7b97eSAart Bik } 806ad7b97eSAart Bik 8138abdddfSRiver Riddle LogicalResult amx::TileMulFOp::verify() { 82*2f743ac5SIlya Enkovich amx::TileType aType = getLhsTileType(); 83*2f743ac5SIlya Enkovich amx::TileType bType = getRhsTileType(); 84*2f743ac5SIlya Enkovich amx::TileType cType = getTileType(); 8538abdddfSRiver Riddle if (failed(verifyTileSize(*this, aType)) || 8638abdddfSRiver Riddle failed(verifyTileSize(*this, bType)) || 8738abdddfSRiver Riddle failed(verifyTileSize(*this, cType)) || 8838abdddfSRiver Riddle failed(verifyMultShape(*this, aType, bType, cType, 1))) 896ad7b97eSAart Bik return failure(); 906ad7b97eSAart Bik Type ta = aType.getElementType(); 916ad7b97eSAart Bik Type tb = bType.getElementType(); 926ad7b97eSAart Bik Type tc = cType.getElementType(); 93*2f743ac5SIlya Enkovich if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32()) 9438abdddfSRiver Riddle return emitOpError("unsupported type combination"); 956ad7b97eSAart Bik return success(); 966ad7b97eSAart Bik } 976ad7b97eSAart Bik 9838abdddfSRiver Riddle LogicalResult amx::TileMulIOp::verify() { 99*2f743ac5SIlya Enkovich amx::TileType aType = getLhsTileType(); 100*2f743ac5SIlya Enkovich amx::TileType bType = getRhsTileType(); 101*2f743ac5SIlya Enkovich amx::TileType cType = getTileType(); 10238abdddfSRiver Riddle if (failed(verifyTileSize(*this, aType)) || 10338abdddfSRiver Riddle failed(verifyTileSize(*this, bType)) || 10438abdddfSRiver Riddle failed(verifyTileSize(*this, cType)) || 10538abdddfSRiver Riddle failed(verifyMultShape(*this, aType, bType, cType, 2))) 1066ad7b97eSAart Bik return failure(); 1076ad7b97eSAart Bik Type ta = aType.getElementType(); 1086ad7b97eSAart Bik Type tb = bType.getElementType(); 1096ad7b97eSAart Bik Type tc = cType.getElementType(); 1106ad7b97eSAart Bik if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) 11138abdddfSRiver Riddle return emitOpError("unsupported type combination"); 1126ad7b97eSAart Bik return success(); 1136ad7b97eSAart Bik } 1146ad7b97eSAart Bik 115*2f743ac5SIlya Enkovich Type amx::TileType::parse(AsmParser &parser) { 116*2f743ac5SIlya Enkovich if (parser.parseLess()) 117*2f743ac5SIlya Enkovich return nullptr; 118*2f743ac5SIlya Enkovich 119*2f743ac5SIlya Enkovich SmallVector<int64_t, 2> shape; 120*2f743ac5SIlya Enkovich if (parser.parseDimensionList(shape, false, true)) 121*2f743ac5SIlya Enkovich return nullptr; 122*2f743ac5SIlya Enkovich 123*2f743ac5SIlya Enkovich Type elementType; 124*2f743ac5SIlya Enkovich if (parser.parseType(elementType)) 125*2f743ac5SIlya Enkovich return nullptr; 126*2f743ac5SIlya Enkovich 127*2f743ac5SIlya Enkovich if (parser.parseGreater()) 128*2f743ac5SIlya Enkovich return nullptr; 129*2f743ac5SIlya Enkovich 130*2f743ac5SIlya Enkovich return TileType::get(shape, elementType); 131*2f743ac5SIlya Enkovich } 132*2f743ac5SIlya Enkovich 133*2f743ac5SIlya Enkovich void amx::TileType::print(AsmPrinter &os) const { 134*2f743ac5SIlya Enkovich os << "<"; 135*2f743ac5SIlya Enkovich os.printDimensionList(getShape()); 136*2f743ac5SIlya Enkovich os << 'x'; 137*2f743ac5SIlya Enkovich os.printType(getElementType()); 138*2f743ac5SIlya Enkovich os << '>'; 139*2f743ac5SIlya Enkovich } 140*2f743ac5SIlya Enkovich 1416ad7b97eSAart Bik #define GET_OP_CLASSES 1426ad7b97eSAart Bik #include "mlir/Dialect/AMX/AMX.cpp.inc" 143*2f743ac5SIlya Enkovich 144*2f743ac5SIlya Enkovich #define GET_TYPEDEF_CLASSES 145*2f743ac5SIlya Enkovich #include "mlir/Dialect/AMX/AMXTypes.cpp.inc" 146