"""
StructDataclass: Base class for auto-decoding/encoding struct-like dataclasses.
"""
import inspect
import re
import struct
from copy import deepcopy
from dataclasses import dataclass, field, is_dataclass
from pystructtype.structtypes import iterate_types
@dataclass
class StructState:
    """
    Per-attribute bookkeeping recorded by ``StructDataclass.__post_init__``.

    ``_decode``, ``_encode`` and ``size`` walk these entries in order to know which
    attribute to read or write and how many elements it spans.
    """

    # Name of the attribute on the StructDataclass instance
    name: str
    # struct format fragment for one element of the attribute
    # (for a nested StructDataclass this is that struct's full format)
    struct_fmt: str
    # Number of elements: 1 for a scalar/single nested struct, list length otherwise
    size: int
    # Count prefix applied to the element format when greater than 1
    chunk_size: int
class StructDataclass:
    """
    Base class for struct-backed dataclasses.

    Subclasses are turned into dataclasses automatically when they are defined.
    Instances can then be decoded from, and encoded to, raw bytes through the
    ``struct`` module, using a format string assembled from the attributes that
    ``iterate_types`` reports for the subclass.
    """
def __init_subclass__(cls: "type[StructDataclass]", **kwargs: object) -> None:
    """
    Automatically configure the subclass as a dataclass and give struct fields defaults.

    Every attribute yielded by ``iterate_types`` that is a pystructtype type or a class
    is processed:

    * Scalar fields (no TypeMeta, or ``TypeMeta.size == 1``) must not be list types.
      When the class does not already define a value for the attribute, the default
      is ``TypeMeta.default`` if one is set (a class default becomes a
      ``default_factory``); otherwise a ``default_factory`` built from the base type.
    * Array fields (``TypeMeta.size > 1``) must be list types. They always receive a
      ``default_factory`` producing a fresh list for each instance, built from a truthy
      ``TypeMeta.default`` (a list that is deep-copied, a class instantiated ``size``
      times, or a value deep-copied ``size`` times) or, failing that, from the base
      type instantiated ``size`` times.

    Subclasses that are already dataclasses are left untouched.

    :param kwargs: Forwarded to ``super().__init_subclass__``
    :raises ValueError: A list type declared with size 1, or a non-list type with size > 1
    :raises TypeError: A scalar field whose TypeMeta default is a list
    """
    super().__init_subclass__(**kwargs)

    # If the class is already a dataclass, skip
    if is_dataclass(cls):
        return

    for type_iterator in iterate_types(cls):
        # Neither a pystructtype type nor a class: nothing to give a default to
        if not type_iterator.is_pystructtype and not inspect.isclass(type_iterator.base_type):
            continue

        key = type_iterator.key
        type_meta = type_iterator.type_meta

        if not type_meta or type_meta.size == 1:
            if type_iterator.is_list:
                raise ValueError(f"Attribute {key} is defined as a list type but has size set to 1")

            # Keep a default the class already declares. Compare against None rather
            # than truthiness so explicit falsy defaults (0, False, b"") survive.
            if getattr(cls, key, None) is not None:
                continue

            if type_meta and type_meta.default is not None:
                default = type_meta.default
                if isinstance(default, list):
                    raise TypeError(f"default value for {key} attribute can not be a list")
                if inspect.isclass(default):
                    default = field(default_factory=default)
                setattr(cls, key, default)
                continue

            base_default = type_iterator.base_type
            if inspect.isclass(base_default):
                setattr(cls, key, field(default_factory=base_default))
            else:
                # Copy per instance so instances never share one default object
                setattr(cls, key, field(default_factory=lambda d=base_default: deepcopy(d)))  # type: ignore
            continue

        if not type_iterator.is_list:
            raise ValueError(f"Attribute {key} is not a list type but has a size > 1")

        size = type_meta.size
        # Truthiness is deliberate here: a falsy default (e.g. 0 or []) falls back to
        # instantiating the base type `size` times.
        if type_meta.default:
            default = type_meta.default
            if isinstance(default, list):
                # Snapshot now, then deep-copy for every instance: a shallow list() copy
                # would make all instances share the same mutable elements.
                snapshot = deepcopy(default)
                default_list = field(default_factory=lambda d=snapshot: deepcopy(d))  # type: ignore
            elif inspect.isclass(default):
                default_list = field(
                    default_factory=lambda d=default, s=size: [d() for _ in range(s)]  # type: ignore
                )
            else:
                default_list = field(
                    default_factory=lambda d=default, s=size: [deepcopy(d) for _ in range(s)]  # type: ignore
                )
        else:
            default_list = field(
                default_factory=lambda d=type_iterator.base_type, s=size: [d() for _ in range(s)]  # type: ignore
            )
        setattr(cls, key, default_list)

    dataclass(cls)
def __post_init__(self) -> None:
    """
    Build the per-instance decode/encode plan once the dataclass fields exist.

    One StructState is recorded in ``self._state`` for each attribute that has type
    info or is a nested StructDataclass. Each such attribute also appends its element
    format, repeated ``size`` times, to ``self.struct_fmt``. A chunk size greater than
    1 is emitted as a count prefix on base-type formats. Any other attribute is a plain
    class variable and is ignored.

    Afterwards ``_simplify_format`` is called, and the byte length is cached in
    ``self._byte_length``. ``"="`` uses standard sizes with no alignment.
    """
    self._state: list[StructState] = []
    self.struct_fmt = ""

    for type_iterator in iterate_types(self.__class__):
        info = type_iterator.type_info
        base = type_iterator.base_type

        if info:
            element_fmt = info.format
            prefix = type_iterator.chunk_size if type_iterator.chunk_size > 1 else ""
            piece = f"{prefix}{element_fmt}"
        elif inspect.isclass(base) and issubclass(base, StructDataclass):
            nested = getattr(self, type_iterator.key)
            # A list of nested structs shares one format; take it from the first item
            element_fmt = (nested[0] if type_iterator.is_list else nested).struct_fmt
            piece = element_fmt
        else:
            # Regular class variable: contributes nothing to the packed layout
            continue

        self._state.append(
            StructState(type_iterator.key, element_fmt, type_iterator.size, type_iterator.chunk_size)
        )
        self.struct_fmt += piece * type_iterator.size

    self._simplify_format()
    self._byte_length = struct.calcsize("=" + self.struct_fmt)
[docs]
def size(self) -> int:
"""
The size of this struct is defined as the sum of the sizes of all attributes
:return: Combined size of the struct
"""
return sum(state.size for state in self._state)
@staticmethod
[docs]
def _endian(little_endian: bool) -> str:
"""
Return "<" or ">" depending on endianness, to pass to struct decode/encode
:param little_endian: True if we expect little_endian, else False
:return: "<" if little_endian else ">"
"""
return "<" if little_endian else ">"
@staticmethod
[docs]
def _to_bytes(data: list[int] | bytes) -> bytes:
"""
Convert a list of ints into bytes
:param data: a list of ints or a bytes object
:return: a bytes object
"""
if isinstance(data, bytes):
return data
return bytes(data)
@staticmethod
[docs]
def _to_list(data: list[int] | bytes) -> list[int]:
"""
Convert a bytes object into a list of ints
:param data: a list of ints or a bytes object
:return: a list of ints
"""
if isinstance(data, bytes):
return list(data)
return data
[docs]
def _decode(self, data: list[int]) -> None:
"""
Internal decoding function for the StructDataclass.
Extend this function if you wish to add extra processing to your StructDataclass decoding processing
:param data: A list of ints to decode into the StructDataclass
"""
idx = 0
for state in self._state:
attr = getattr(self, state.name)
if isinstance(attr, list) and isinstance(attr[0], StructDataclass):
# If the current attribute is a list, and contains subclasses of StructDataclass
# Call _decode on the required subset of bytes for each list item
list_idx = 0
sub_struct_byte_length = attr[0].size()
while list_idx < state.size:
instance: StructDataclass = attr[list_idx]
instance._decode(data[idx : idx + sub_struct_byte_length])
list_idx += 1
idx += sub_struct_byte_length
elif isinstance(attr, StructDataclass):
# If the current attribute is not a list, and is a subclass of StructDataclass
# Call _decode on the required subset of bytes for the item
sub_struct_byte_length = attr.size()
attr._decode(data[idx : idx + sub_struct_byte_length])
idx += sub_struct_byte_length
elif state.size == 1:
# The current attribute is a base type of size 1
setattr(self, state.name, data[idx])
idx += 1
else:
# The current attribute is a list of base types
list_idx = 0
while list_idx < state.size:
getattr(self, state.name)[list_idx] = data[idx]
list_idx += 1
idx += 1
[docs]
def decode(self, data: list[int] | bytes, little_endian: bool = False) -> None:
"""
Decode the given data into this subclass of StructDataclass
:param data: list of ints or a bytes object
:param little_endian: True if decoding little_endian formatted data, else False
:raises ValueError: If the input data is not the correct length for the struct
"""
data = self._to_bytes(data)
expected_len = struct.calcsize(self._endian(little_endian) + self.struct_fmt)
if len(data) != expected_len:
raise ValueError(f"Input data length {len(data)} does not match expected struct size {expected_len}")
# Decode
self._decode(list(struct.unpack(self._endian(little_endian) + self.struct_fmt, data)))
[docs]
def _encode(self) -> list[int]:
"""
Internal encoding function for the StructDataclass.
Extend this function if you wish to add extra processing to your StructDataclass encoding processing
:return: list of encoded int data
"""
result: list[int] = []
for state in self._state:
attr = getattr(self, state.name)
if isinstance(attr, list) and isinstance(attr[0], StructDataclass):
# Attribute is a list of StructDataclass subclasses.
# Simply call _encode on each item in the list
item: StructDataclass
for item in attr:
result.extend(item._encode())
elif isinstance(attr, StructDataclass):
# Attribute is a StructDataclass subclass
# Call _encode on it
result.extend(attr._encode())
elif state.size == 1:
# Attribute is a single base type
# Append it to the result
result.append(getattr(self, state.name))
else:
# Attribute is a list of base types
# Extend it to the result
result.extend(getattr(self, state.name))
return result
[docs]
def encode(self, little_endian: bool = False) -> bytes:
"""
Encode the data from this subclass of StructDataclass into bytes
:param little_endian: True if encoding little_endian formatted data, else False
:return: encoded bytes
"""
result = self._encode()
return struct.pack(self._endian(little_endian) + self.struct_fmt, *result)