1 //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Shape utility for AMX. 10 /// AMX hardware requires to config the shape of tile data register before use. 11 /// The 2D shape includes row and column. In AMX intrinsics interface the shape 12 /// is passed as 1st and 2nd parameter and they are lowered as the 1st and 2nd 13 /// machine operand of AMX pseudo instructions. ShapeT class is to facilitate 14 /// tile config and register allocator. The row and column are machine operand 15 /// of AMX pseudo instructions. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #ifndef LLVM_CODEGEN_TILESHAPEINFO_H 20 #define LLVM_CODEGEN_TILESHAPEINFO_H 21 22 #include "llvm/CodeGen/MachineInstr.h" 23 #include "llvm/CodeGen/MachineOperand.h" 24 #include "llvm/CodeGen/MachineRegisterInfo.h" 25 #include "llvm/CodeGen/Register.h" 26 27 namespace llvm { 28 29 class ShapeT { 30 public: 31 ShapeT(MachineOperand *Row, MachineOperand *Col, 32 const MachineRegisterInfo *MRI = nullptr) 33 : Row(Row), Col(Col) { 34 if (MRI) 35 deduceImm(MRI); 36 } 37 // When ShapeT has multiple shapes, we only use Shapes (never use Row and Col) 38 // and ImmShapes. Due to the most case is only one shape (just simply use 39 // Shape.Row or Shape.Col), so here we don't merge Row and Col into vector 40 // Shapes to keep the speed and code simplicity. 41 // TODO: The upper solution is a temporary way to minimize current tile 42 // register allocation code changes. It can not handle both Reg shape and 43 // Imm shape for different shapes (e.g. shape 1 is reg shape while shape 2 44 // is imm shape). Refine me when we have more multi-tile shape instructions! 45 ShapeT(ArrayRef<MachineOperand *> ShapesOperands, 46 const MachineRegisterInfo *MRI = nullptr) 47 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), 48 ColImm(InvalidImmShape) { 49 assert(ShapesOperands.size() % 2 == 0 && "Miss row or col!"); 50 51 for (auto *Shape : ShapesOperands) 52 Shapes.push_back(Shape); 53 54 if (MRI) 55 deduceImm(MRI); 56 } 57 ShapeT() 58 : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), 59 ColImm(InvalidImmShape) {} 60 // TODO: We need to extern cmp operator for multi-shapes if 61 // we have requirement in the future. 62 bool operator==(const ShapeT &Shape) const { 63 MachineOperand *R = Shape.Row; 64 MachineOperand *C = Shape.Col; 65 if (!R || !C) 66 return false; 67 if (!Row || !Col) 68 return false; 69 if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg()) 70 return true; 71 if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape)) 72 return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm(); 73 return false; 74 } 75 76 bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); } 77 78 MachineOperand *getRow(unsigned I = 0) const { 79 if (Shapes.empty()) 80 return Row; 81 assert(Shapes.size() / 2 >= I && "Get invalid row from id!"); 82 return Shapes[I * 2]; 83 } 84 85 MachineOperand *getCol(unsigned I = 0) const { 86 if (Shapes.empty()) 87 return Col; 88 assert(Shapes.size() / 2 >= I && "Get invalid col from id!"); 89 return Shapes[I * 2 + 1]; 90 } 91 92 int64_t getRowImm(unsigned I = 0) const { 93 if (ImmShapes.empty()) 94 return RowImm; 95 assert(ImmShapes.size() / 2 >= I && "Get invalid imm row from id!"); 96 return ImmShapes[I * 2]; 97 } 98 99 int64_t getColImm(unsigned I = 0) const { 100 if (ImmShapes.empty()) 101 return ColImm; 102 assert(ImmShapes.size() / 2 >= I && "Get invalid imm col from id!"); 103 return ImmShapes[I * 2 + 1]; 104 } 105 106 unsigned getShapeNum() { 107 if (Shapes.empty()) 108 return isValid() ? 1 : 0; 109 else 110 return Shapes.size() / 2; 111 } 112 113 bool isValid() { return (Row != nullptr) && (Col != nullptr); } 114 115 void deduceImm(const MachineRegisterInfo *MRI) { 116 // All def must be the same value, otherwise it is invalid MIs. 117 // Find the immediate. 118 // TODO copy propagation. 119 auto GetImm = [&](Register Reg) { 120 int64_t Imm = InvalidImmShape; 121 for (const MachineOperand &DefMO : MRI->def_operands(Reg)) { 122 const auto *MI = DefMO.getParent(); 123 if (MI->isMoveImmediate()) { 124 assert(MI->getNumOperands() == 2 && 125 "Unsupported number of operands in instruction for setting " 126 "row/column."); 127 if (MI->getOperand(1).isImm()) { 128 Imm = MI->getOperand(1).getImm(); 129 } else { 130 assert(MI->getOperand(1).isImplicit() && 131 "Operand 1 is assumed to be implicit."); 132 Imm = 0; 133 } 134 break; 135 } 136 } 137 return Imm; 138 }; 139 if (Shapes.empty()) { // Single Shape 140 RowImm = GetImm(Row->getReg()); 141 ColImm = GetImm(Col->getReg()); 142 // The number of rows of 2nd destination buffer is assigned by the one of 143 // 1st destination buffer. If the column size is equal to zero, the row 144 // size should be reset to zero too. 145 if (ColImm == 0) 146 Row = Col; 147 } else { // Multiple Shapes 148 for (auto *Shape : Shapes) { 149 int64_t ImmShape = GetImm(Shape->getReg()); 150 ImmShapes.push_back(ImmShape); 151 } 152 } 153 } 154 155 private: 156 static constexpr int64_t InvalidImmShape = -1; 157 MachineOperand *Row; 158 MachineOperand *Col; 159 int64_t RowImm = -1; 160 int64_t ColImm = -1; 161 // Multiple Shapes 162 SmallVector<MachineOperand *, 0> Shapes; 163 SmallVector<int64_t, 0> ImmShapes; 164 }; 165 166 } // namespace llvm 167 168 #endif 169