xref: /llvm-project/llvm/include/llvm/CodeGen/TileShapeInfo.h (revision c72a751dabff4260dcc309e48008941d51b31d21)
1 //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row count and a
/// column count. In the AMX intrinsics interface the shape is passed as the
/// 1st and 2nd parameters, which are lowered to the 1st and 2nd machine
/// operands of the AMX pseudo instructions. The ShapeT class records these
/// row/column machine operands to facilitate tile configuration and tile
/// register allocation.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef LLVM_CODEGEN_TILESHAPEINFO_H
20 #define LLVM_CODEGEN_TILESHAPEINFO_H
21 
22 #include "llvm/CodeGen/MachineInstr.h"
23 #include "llvm/CodeGen/MachineOperand.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/CodeGen/Register.h"
26 
27 namespace llvm {
28 
29 class ShapeT {
30 public:
31   ShapeT(MachineOperand *Row, MachineOperand *Col,
32          const MachineRegisterInfo *MRI = nullptr)
33       : Row(Row), Col(Col) {
34     if (MRI)
35       deduceImm(MRI);
36   }
37   // When ShapeT has multiple shapes, we only use Shapes (never use Row and Col)
38   // and ImmShapes. Due to the most case is only one shape (just simply use
39   // Shape.Row or Shape.Col), so here we don't merge Row and Col into vector
40   // Shapes to keep the speed and code simplicity.
41   // TODO: The upper solution is a temporary way to minimize current tile
42   // register allocation code changes. It can not handle both Reg shape and
43   // Imm shape for different shapes (e.g. shape 1 is reg shape while shape 2
44   // is imm shape). Refine me when we have more multi-tile shape instructions!
45   ShapeT(ArrayRef<MachineOperand *> ShapesOperands,
46          const MachineRegisterInfo *MRI = nullptr)
47       : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
48         ColImm(InvalidImmShape) {
49     assert(ShapesOperands.size() % 2 == 0 && "Miss row or col!");
50 
51     for (auto *Shape : ShapesOperands)
52       Shapes.push_back(Shape);
53 
54     if (MRI)
55       deduceImm(MRI);
56   }
57   ShapeT()
58       : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
59         ColImm(InvalidImmShape) {}
60   // TODO: We need to extern cmp operator for multi-shapes if
61   // we have requirement in the future.
62   bool operator==(const ShapeT &Shape) const {
63     MachineOperand *R = Shape.Row;
64     MachineOperand *C = Shape.Col;
65     if (!R || !C)
66       return false;
67     if (!Row || !Col)
68       return false;
69     if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
70       return true;
71     if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
72       return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
73     return false;
74   }
75 
76   bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }
77 
78   MachineOperand *getRow(unsigned I = 0) const {
79     if (Shapes.empty())
80       return Row;
81     assert(Shapes.size() / 2 >= I && "Get invalid row from id!");
82     return Shapes[I * 2];
83   }
84 
85   MachineOperand *getCol(unsigned I = 0) const {
86     if (Shapes.empty())
87       return Col;
88     assert(Shapes.size() / 2 >= I && "Get invalid col from id!");
89     return Shapes[I * 2 + 1];
90   }
91 
92   int64_t getRowImm(unsigned I = 0) const {
93     if (ImmShapes.empty())
94       return RowImm;
95     assert(ImmShapes.size() / 2 >= I && "Get invalid imm row from id!");
96     return ImmShapes[I * 2];
97   }
98 
99   int64_t getColImm(unsigned I = 0) const {
100     if (ImmShapes.empty())
101       return ColImm;
102     assert(ImmShapes.size() / 2 >= I && "Get invalid imm col from id!");
103     return ImmShapes[I * 2 + 1];
104   }
105 
106   unsigned getShapeNum() {
107     if (Shapes.empty())
108       return isValid() ? 1 : 0;
109     else
110       return Shapes.size() / 2;
111   }
112 
113   bool isValid() { return (Row != nullptr) && (Col != nullptr); }
114 
115   void deduceImm(const MachineRegisterInfo *MRI) {
116     // All def must be the same value, otherwise it is invalid MIs.
117     // Find the immediate.
118     // TODO copy propagation.
119     auto GetImm = [&](Register Reg) {
120       int64_t Imm = InvalidImmShape;
121       for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
122         const auto *MI = DefMO.getParent();
123         if (MI->isMoveImmediate()) {
124           assert(MI->getNumOperands() == 2 &&
125                  "Unsupported number of operands in instruction for setting "
126                  "row/column.");
127           if (MI->getOperand(1).isImm()) {
128             Imm = MI->getOperand(1).getImm();
129           } else {
130             assert(MI->getOperand(1).isImplicit() &&
131                    "Operand 1 is assumed to be implicit.");
132             Imm = 0;
133           }
134           break;
135         }
136       }
137       return Imm;
138     };
139     if (Shapes.empty()) { // Single Shape
140       RowImm = GetImm(Row->getReg());
141       ColImm = GetImm(Col->getReg());
142       // The number of rows of 2nd destination buffer is assigned by the one of
143       // 1st destination buffer. If the column size is equal to zero, the row
144       // size should be reset to zero too.
145       if (ColImm == 0)
146         Row = Col;
147     } else { // Multiple Shapes
148       for (auto *Shape : Shapes) {
149         int64_t ImmShape = GetImm(Shape->getReg());
150         ImmShapes.push_back(ImmShape);
151       }
152     }
153   }
154 
155 private:
156   static constexpr int64_t InvalidImmShape = -1;
157   MachineOperand *Row;
158   MachineOperand *Col;
159   int64_t RowImm = -1;
160   int64_t ColImm = -1;
161   // Multiple Shapes
162   SmallVector<MachineOperand *, 0> Shapes;
163   SmallVector<int64_t, 0> ImmShapes;
164 };
165 
166 } // namespace llvm
167 
168 #endif
169