"""Helper library to traverse the values LLDB produces for Rust enums.

The helpers below expect an enum SBValue whose first child is a container of
variants. Each variant is expected to have an optional discriminant member
(named by DISCRIMINANT_MEMBER_NAME) and a payload member (named by
VALUE_MEMBER_NAME).
"""
from lldbsuite.test.lldbtest import *

# Name of the member holding the discriminant inside each variant.
DISCRIMINANT_MEMBER_NAME = "$discr$"
# Name of the member holding the variant's payload.
VALUE_MEMBER_NAME = "value"


class RustEnumValue:
    """Wrapper around an SBValue representing a Rust enum."""

    def __init__(self, value: lldb.SBValue):
        """Wrap *value*, the SBValue of the Rust enum to inspect."""
        self.value = value

    def getAllVariantTypes(self) -> list:
        """Return the display type name of every variant's payload, in order."""
        return [
            self.getVariantByIndex(i).GetDisplayTypeName()
            for i in range(self._inner().GetNumChildren())
        ]

    def _inner(self) -> lldb.SBValue:
        """Return the container of variants (the enum value's first child)."""
        return self.value.GetChildAtIndex(0)

    def getVariantByIndex(self, index) -> lldb.SBValue:
        """Return the payload member of the variant at child position *index*."""
        return (
            self._inner()
            .GetChildAtIndex(index)
            .GetChildMemberWithName(VALUE_MEMBER_NAME)
        )

    @staticmethod
    def _getDiscriminantValueAsUnsigned(discr_sbvalue: lldb.SBValue):
        """Read *discr_sbvalue* as an unsigned integer.

        Discriminants of 1, 2, 4 or 8 bytes are read from the raw data. A
        one-byte discriminant may be typed as 'unsigned char', which LLDB seems
        to treat as a character type, so GetValueAsUnsigned is not reliable for
        it. Other sizes, or a failed raw read, fall back to GetValueAsUnsigned().
        """
        byte_size = discr_sbvalue.GetType().GetByteSize()
        if byte_size not in (1, 2, 4, 8):
            return discr_sbvalue.GetValueAsUnsigned()

        data = discr_sbvalue.GetData()
        read_unsigned = {
            1: data.GetUnsignedInt8,
            2: data.GetUnsignedInt16,
            4: data.GetUnsignedInt32,
            8: data.GetUnsignedInt64,
        }[byte_size]
        error = lldb.SBError()
        result = read_unsigned(error, 0)
        if error.Fail():
            # Don't silently report 0 (which could match "$variant$0") when
            # the raw read failed; defer to LLDB's own conversion instead.
            return discr_sbvalue.GetValueAsUnsigned()
        return result

    def getCurrentVariantIndex(self) -> int:
        """Return the child position of the currently active variant.

        A variant is active when its name equals "$variant$<discriminant>",
        where <discriminant> is the value read from that variant's
        discriminant member. If no variant matches, the index of the last
        variant without a discriminant member is returned (0 if there is none).
        """
        inner = self._inner()
        default_index = 0
        for i in range(inner.GetNumChildren()):
            variant: lldb.SBValue = inner.GetChildAtIndex(i)
            discr = variant.GetChildMemberWithName(DISCRIMINANT_MEMBER_NAME)
            if not discr.IsValid():
                default_index = i
                continue
            discr_unsigned_value = RustEnumValue._getDiscriminantValueAsUnsigned(discr)
            if variant.GetName() == f"$variant${discr_unsigned_value}":
                # Return the child position, not the discriminant: callers
                # feed this into getVariantByIndex, and discriminant values
                # need not coincide with child positions.
                return i
        return default_index

    def getFields(self) -> list:
        """Return the field names of the variant container's type."""
        inner = self._inner()
        inner_type: lldb.SBType = inner.GetType()
        return [
            inner_type.GetFieldAtIndex(i).GetName()
            for i in range(inner.GetNumChildren())
        ]

    def getCurrentValue(self) -> lldb.SBValue:
        """Return the payload of the currently active variant."""
        return self.getVariantByIndex(self.getCurrentVariantIndex())