Skip to content

Serializable mixin

`T = TypeVar('T', bound='SerializableMixin')` — module attribute.

SerializableMixin

Bases: Generic[T]

Source code in griptape/mixins/serializable_mixin.py
@define(slots=False)
class SerializableMixin(Generic[T]):
    """Mixin that adds dict/JSON (de)serialization to an attrs class.

    (De)serialization goes through a Marshmallow schema generated from the
    class by `BaseSchema.from_attrs_cls`.
    """

    # Name of the concrete class; defaults to the instance's class name. `from_dict`
    # reads it back to choose which subclass to load when the target class is abstract.
    # NOTE(review): the "serializable" metadata presumably tells BaseSchema to include
    # this field — confirm against BaseSchema.
    type: str = field(
        default=Factory(lambda self: self.__class__.__name__, takes_self=True),
        kw_only=True,
        metadata={"serializable": True},
    )

    @classmethod
    def get_schema(cls: type[T], subclass_name: Optional[str] = None) -> Schema:
        """Generates a Marshmallow schema for the class.

        Args:
            subclass_name: An optional subclass name. Required if the class is abstract
                (i.e. lists `ABC` among its direct bases).

        Returns:
            An instance of the schema generated for `cls`, or for the class named by
            `subclass_name` when `cls` is abstract.

        Raises:
            ValueError: If `cls` is abstract and `subclass_name` is None, if no class
                named `subclass_name` can be found, or if the object found is not a
                subclass of `cls`.
        """
        if ABC in cls.__bases__:
            if subclass_name is None:
                raise ValueError(f"Type field is required for abstract class: {cls.__name__}")

            subclass_cls = cls._import_cls_rec(cls.__module__, subclass_name)

            # `subclass_name` usually comes straight from serialized data, and the lookup
            # walks up the package tree, so it can resolve to any attribute of that name.
            # Only accept real subclasses so the loaded object actually is a `cls`.
            if not (isinstance(subclass_cls, type) and issubclass(subclass_cls, cls)):
                raise ValueError(f"{subclass_name} is not a subclass of {cls.__name__}")

            schema_class = BaseSchema.from_attrs_cls(subclass_cls)
        else:
            schema_class = BaseSchema.from_attrs_cls(cls)

        return schema_class()

    @classmethod
    def from_dict(cls: type[T], data: dict) -> T:
        """Deserializes an instance from a dict.

        Args:
            data: The serialized fields. The optional `"type"` key names the concrete
                class to load; it is required when `cls` is abstract.

        Returns:
            The instance produced by the schema's `load`.
        """
        return cast(T, cls.get_schema(subclass_name=data.get("type")).load(data))

    @classmethod
    def from_json(cls: type[T], data: str) -> T:
        """Deserializes an instance from a JSON string (see `from_dict`)."""
        return cls.from_dict(json.loads(data))

    def __str__(self) -> str:
        """Returns the JSON encoding of `to_dict()`."""
        return json.dumps(self.to_dict())

    def to_json(self) -> str:
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict())

    def to_dict(self) -> dict:
        """Serializes this instance to a plain dict using the schema for its class."""
        schema = BaseSchema.from_attrs_cls(self.__class__)

        return dict(schema().dump(self))

    @classmethod
    def _import_cls_rec(cls, module_name: str, class_name: str) -> type:
        """Imports a class given a module name and class name.

        Looks in `module_name` first, then in each parent package in turn
        (e.g. `a.b.c`, then `a.b`, then `a`), and returns the first non-None
        attribute named `class_name`.

        Args:
            module_name: The module name.
            class_name: The class name.

        Returns:
            The attribute found. Callers should verify it is the kind of class they expect.

        Raises:
            ValueError: If no module on the path has an attribute named `class_name`.
        """
        parts = module_name.split(".")
        while parts:
            try:
                found = getattr(import_module(".".join(parts)), class_name, None)
            except ModuleNotFoundError:
                # A missing module on the path is not fatal; keep walking up.
                found = None
            if found is not None:
                return found
            parts.pop()

        raise ValueError(f"Unable to import class: {class_name}")

`type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True, metadata={'serializable': True})` — class attribute, instance attribute.

__str__()

Source code in griptape/mixins/serializable_mixin.py
def __str__(self) -> str:
    """Render the object as the JSON encoding of its serialized dict."""
    serialized = self.to_dict()
    return json.dumps(serialized)

from_dict(data) classmethod

Source code in griptape/mixins/serializable_mixin.py
@classmethod
def from_dict(cls: type[T], data: dict) -> T:
    """Build an instance from its dict form; the optional "type" key selects the schema."""
    schema = cls.get_schema(subclass_name=data.get("type"))
    return cast(T, schema.load(data))

from_json(data) classmethod

Source code in griptape/mixins/serializable_mixin.py
@classmethod
def from_json(cls: type[T], data: str) -> T:
    """Decode a JSON string and hand the resulting dict to from_dict."""
    decoded = json.loads(data)
    return cls.from_dict(decoded)

get_schema(subclass_name=None) classmethod

Generates a Marshmallow schema for the class.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `subclass_name` | `Optional[str]` | An optional subclass name. Required if the class is abstract. | `None` |
Source code in griptape/mixins/serializable_mixin.py
@classmethod
def get_schema(cls: type[T], subclass_name: Optional[str] = None) -> Schema:
    """Generates a Marshmallow schema for the class.

    Args:
        subclass_name: An optional subclass name. Required if the class is abstract
            (i.e. lists `ABC` among its direct bases).

    Returns:
        An instance of the schema generated for `cls`, or for the class named by
        `subclass_name` when `cls` is abstract.

    Raises:
        ValueError: If `cls` is abstract and `subclass_name` is None, if no class
            named `subclass_name` can be found, or if the object found is not a
            subclass of `cls`.
    """
    if ABC in cls.__bases__:
        if subclass_name is None:
            raise ValueError(f"Type field is required for abstract class: {cls.__name__}")

        subclass_cls = cls._import_cls_rec(cls.__module__, subclass_name)

        # `subclass_name` usually comes straight from serialized data, and the lookup
        # walks up the package tree, so it can resolve to any attribute of that name.
        # Only accept real subclasses so the loaded object actually is a `cls`.
        if not (isinstance(subclass_cls, type) and issubclass(subclass_cls, cls)):
            raise ValueError(f"{subclass_name} is not a subclass of {cls.__name__}")

        schema_class = BaseSchema.from_attrs_cls(subclass_cls)
    else:
        schema_class = BaseSchema.from_attrs_cls(cls)

    return schema_class()

to_dict()

Source code in griptape/mixins/serializable_mixin.py
def to_dict(self) -> dict:
    """Serialize this object to a plain dict using a schema built from its class."""
    schema_cls = BaseSchema.from_attrs_cls(self.__class__)
    dumped = schema_cls().dump(self)
    return dict(dumped)

to_json()

Source code in griptape/mixins/serializable_mixin.py
def to_json(self) -> str:
    """Return this object's serialized dict encoded as a JSON string."""
    as_dict = self.to_dict()
    return json.dumps(as_dict)