1 //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the AMX dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/AMX/AMXDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/TypeUtilities.h" 19 20 #include "llvm/ADT/TypeSwitch.h" 21 22 using namespace mlir; 23 24 #include "mlir/Dialect/AMX/AMXDialect.cpp.inc" 25 26 void amx::AMXDialect::initialize() { 27 addTypes< 28 #define GET_TYPEDEF_LIST 29 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc" 30 >(); 31 32 addOperations< 33 #define GET_OP_LIST 34 #include "mlir/Dialect/AMX/AMX.cpp.inc" 35 >(); 36 } 37 38 /// Verify that AMX supports the implied tile shape. 39 static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) { 40 const unsigned kMaxRows = 16; 41 const unsigned kBitsPerRow = 64 * 8; 42 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); 43 if (tp.getDimSize(0) > kMaxRows) 44 return op->emitOpError("bad row height: ") << tp.getDimSize(0); 45 if (col > kBitsPerRow || col & 0x1f) 46 return op->emitOpError("bad column width: ") << (col >> 3); 47 return success(); 48 } 49 50 /// Verify that AMX supports the multiplication. 51 static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, 52 amx::TileType btp, amx::TileType ctp, 53 unsigned scale) { 54 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; 55 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; 56 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); 57 if (cm != am || cn != bn || ak != bk) 58 return op->emitOpError("bad mult shape: ") 59 << cm << " x " << cn << " x " << ak; 60 return success(); 61 } 62 63 LogicalResult amx::TileZeroOp::verify() { 64 return verifyTileSize(*this, getTileType()); 65 } 66 67 LogicalResult amx::TileLoadOp::verify() { 68 unsigned rank = getMemRefType().getRank(); 69 if (getIndices().size() != rank) 70 return emitOpError("requires ") << rank << " indices"; 71 return verifyTileSize(*this, getTileType()); 72 } 73 74 LogicalResult amx::TileStoreOp::verify() { 75 unsigned rank = getMemRefType().getRank(); 76 if (getIndices().size() != rank) 77 return emitOpError("requires ") << rank << " indices"; 78 return verifyTileSize(*this, getTileType()); 79 } 80 81 LogicalResult amx::TileMulFOp::verify() { 82 amx::TileType aType = getLhsTileType(); 83 amx::TileType bType = getRhsTileType(); 84 amx::TileType cType = getTileType(); 85 if (failed(verifyTileSize(*this, aType)) || 86 failed(verifyTileSize(*this, bType)) || 87 failed(verifyTileSize(*this, cType)) || 88 failed(verifyMultShape(*this, aType, bType, cType, 1))) 89 return failure(); 90 Type ta = aType.getElementType(); 91 Type tb = bType.getElementType(); 92 Type tc = cType.getElementType(); 93 if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32()) 94 return emitOpError("unsupported type combination"); 95 return success(); 96 } 97 98 LogicalResult amx::TileMulIOp::verify() { 99 amx::TileType aType = getLhsTileType(); 100 amx::TileType bType = getRhsTileType(); 101 amx::TileType cType = getTileType(); 102 if (failed(verifyTileSize(*this, aType)) || 103 failed(verifyTileSize(*this, bType)) || 104 failed(verifyTileSize(*this, cType)) || 105 failed(verifyMultShape(*this, aType, bType, cType, 2))) 106 return failure(); 107 Type ta = aType.getElementType(); 108 Type tb = bType.getElementType(); 109 Type tc = cType.getElementType(); 110 if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) 111 return emitOpError("unsupported type combination"); 112 return success(); 113 } 114 115 Type amx::TileType::parse(AsmParser &parser) { 116 if (parser.parseLess()) 117 return nullptr; 118 119 SmallVector<int64_t, 2> shape; 120 if (parser.parseDimensionList(shape, false, true)) 121 return nullptr; 122 123 Type elementType; 124 if (parser.parseType(elementType)) 125 return nullptr; 126 127 if (parser.parseGreater()) 128 return nullptr; 129 130 return TileType::get(shape, elementType); 131 } 132 133 void amx::TileType::print(AsmPrinter &os) const { 134 os << "<"; 135 os.printDimensionList(getShape()); 136 os << 'x'; 137 os.printType(getElementType()); 138 os << '>'; 139 } 140 141 #define GET_OP_CLASSES 142 #include "mlir/Dialect/AMX/AMX.cpp.inc" 143 144 #define GET_TYPEDEF_CLASSES 145 #include "mlir/Dialect/AMX/AMXTypes.cpp.inc" 146