xref: /llvm-project/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h (revision 2ce168baed02c7a6fdb039f4a2d9e48dee31e5c9)
1 //===-- AArch64SMEAttributes.h - Helper for interpreting SME attributes -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
10 #define LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
11 
12 #include "llvm/IR/Function.h"
13 
14 namespace llvm {
15 
16 class Function;
17 class CallBase;
18 class AttributeList;
19 
20 /// SMEAttrs is a utility class to parse the SME ACLE attributes on functions.
21 /// It helps determine a function's requirements for PSTATE.ZA and PSTATE.SM. It
22 /// has interfaces to query whether a streaming mode change or lazy-save
23 /// mechanism is required when going from one function to another (e.g. through
24 /// a call).
25 class SMEAttrs {
26   unsigned Bitmask;
27 
28 public:
29   enum class StateValue {
30     None = 0,
31     In = 1,        // aarch64_in_zt0
32     Out = 2,       // aarch64_out_zt0
33     InOut = 3,     // aarch64_inout_zt0
34     Preserved = 4, // aarch64_preserves_zt0
35     New = 5        // aarch64_new_zt0
36   };
37 
38   // Enum with bitmasks for each individual SME feature.
39   enum Mask {
40     Normal = 0,
41     SM_Enabled = 1 << 0,      // aarch64_pstate_sm_enabled
42     SM_Compatible = 1 << 1,   // aarch64_pstate_sm_compatible
43     SM_Body = 1 << 2,         // aarch64_pstate_sm_body
44     SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
45     ZA_State_Agnostic = 1 << 4,
46     ZA_Shift = 5,
47     ZA_Mask = 0b111 << ZA_Shift,
48     ZT0_Shift = 8,
49     ZT0_Mask = 0b111 << ZT0_Shift
50   };
51 
52   SMEAttrs(unsigned Mask = Normal) : Bitmask(0) { set(Mask); }
53   SMEAttrs(const Function &F) : SMEAttrs(F.getAttributes()) {}
54   SMEAttrs(const CallBase &CB);
55   SMEAttrs(const AttributeList &L);
56   SMEAttrs(StringRef FuncName);
57 
58   void set(unsigned M, bool Enable = true);
59 
60   // Interfaces to query PSTATE.SM
61   bool hasStreamingBody() const { return Bitmask & SM_Body; }
62   bool hasStreamingInterface() const { return Bitmask & SM_Enabled; }
63   bool hasStreamingInterfaceOrBody() const {
64     return hasStreamingBody() || hasStreamingInterface();
65   }
66   bool hasStreamingCompatibleInterface() const {
67     return Bitmask & SM_Compatible;
68   }
69   bool hasNonStreamingInterface() const {
70     return !hasStreamingInterface() && !hasStreamingCompatibleInterface();
71   }
72   bool hasNonStreamingInterfaceAndBody() const {
73     return hasNonStreamingInterface() && !hasStreamingBody();
74   }
75 
76   /// \return true if a call from Caller -> Callee requires a change in
77   /// streaming mode.
78   bool requiresSMChange(const SMEAttrs &Callee) const;
79 
80   // Interfaces to query ZA
81   static StateValue decodeZAState(unsigned Bitmask) {
82     return static_cast<StateValue>((Bitmask & ZA_Mask) >> ZA_Shift);
83   }
84   static unsigned encodeZAState(StateValue S) {
85     return static_cast<unsigned>(S) << ZA_Shift;
86   }
87 
88   bool isNewZA() const { return decodeZAState(Bitmask) == StateValue::New; }
89   bool isInZA() const { return decodeZAState(Bitmask) == StateValue::In; }
90   bool isOutZA() const { return decodeZAState(Bitmask) == StateValue::Out; }
91   bool isInOutZA() const { return decodeZAState(Bitmask) == StateValue::InOut; }
92   bool isPreservesZA() const {
93     return decodeZAState(Bitmask) == StateValue::Preserved;
94   }
95   bool sharesZA() const {
96     StateValue State = decodeZAState(Bitmask);
97     return State == StateValue::In || State == StateValue::Out ||
98            State == StateValue::InOut || State == StateValue::Preserved;
99   }
100   bool hasAgnosticZAInterface() const { return Bitmask & ZA_State_Agnostic; }
101   bool hasSharedZAInterface() const { return sharesZA() || sharesZT0(); }
102   bool hasPrivateZAInterface() const {
103     return !hasSharedZAInterface() && !hasAgnosticZAInterface();
104   }
105   bool hasZAState() const { return isNewZA() || sharesZA(); }
106   bool requiresLazySave(const SMEAttrs &Callee) const {
107     return hasZAState() && Callee.hasPrivateZAInterface() &&
108            !(Callee.Bitmask & SME_ABI_Routine);
109   }
110 
111   // Interfaces to query ZT0 State
112   static StateValue decodeZT0State(unsigned Bitmask) {
113     return static_cast<StateValue>((Bitmask & ZT0_Mask) >> ZT0_Shift);
114   }
115   static unsigned encodeZT0State(StateValue S) {
116     return static_cast<unsigned>(S) << ZT0_Shift;
117   }
118 
119   bool isNewZT0() const { return decodeZT0State(Bitmask) == StateValue::New; }
120   bool isInZT0() const { return decodeZT0State(Bitmask) == StateValue::In; }
121   bool isOutZT0() const { return decodeZT0State(Bitmask) == StateValue::Out; }
122   bool isInOutZT0() const {
123     return decodeZT0State(Bitmask) == StateValue::InOut;
124   }
125   bool isPreservesZT0() const {
126     return decodeZT0State(Bitmask) == StateValue::Preserved;
127   }
128   bool sharesZT0() const {
129     StateValue State = decodeZT0State(Bitmask);
130     return State == StateValue::In || State == StateValue::Out ||
131            State == StateValue::InOut || State == StateValue::Preserved;
132   }
133   bool hasZT0State() const { return isNewZT0() || sharesZT0(); }
134   bool requiresPreservingZT0(const SMEAttrs &Callee) const {
135     return hasZT0State() && !Callee.sharesZT0() &&
136            !Callee.hasAgnosticZAInterface();
137   }
138   bool requiresDisablingZABeforeCall(const SMEAttrs &Callee) const {
139     return hasZT0State() && !hasZAState() && Callee.hasPrivateZAInterface() &&
140            !(Callee.Bitmask & SME_ABI_Routine);
141   }
142   bool requiresEnablingZAAfterCall(const SMEAttrs &Callee) const {
143     return requiresLazySave(Callee) || requiresDisablingZABeforeCall(Callee);
144   }
145   bool requiresPreservingAllZAState(const SMEAttrs &Callee) const {
146     return hasAgnosticZAInterface() && !Callee.hasAgnosticZAInterface() &&
147            !(Callee.Bitmask & SME_ABI_Routine);
148   }
149 };
150 
151 } // namespace llvm
152 
153 #endif // LLVM_LIB_TARGET_AARCH64_UTILS_AARCH64SMEATTRIBUTES_H
154