Skip to content

Schemas

`__all__ = ['BaseSchema', 'PolymorphicSchema', 'Bytes']` (module attribute)

BaseSchema

Bases: Schema

Source code in griptape/schemas/base_schema.py
class BaseSchema(Schema):
    """Base marshmallow schema for griptape's serializable attrs classes.

    Unknown input keys are kept (``unknown = INCLUDE``), and the marshmallow
    type mapping is extended so ``dict`` maps to ``fields.Dict`` and ``bytes``
    to the base64 ``Bytes`` field.
    """

    class Meta:
        # marshmallow INCLUDE: keep unknown keys instead of rejecting them.
        unknown = INCLUDE

    # Python type -> marshmallow field class for non-nested, non-list fields.
    DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes}

    @classmethod
    def from_attrs_cls(cls, attrs_cls: type) -> type:
        """Generate a Schema from an attrs class.

        Only attributes whose metadata has a truthy ``"serializable"`` entry
        become schema fields. After loading, the generated schema builds an
        ``attrs_cls`` instance from the loaded data (``post_load``).

        Args:
            attrs_cls: An attrs class that subclasses ``SerializableMixin``.

        Returns:
            A new schema class named ``f"{attrs_cls.__name__}Schema"``.

        Raises:
            ValueError: If ``attrs_cls`` does not subclass ``SerializableMixin``.
        """
        from marshmallow import post_load
        from griptape.mixins import SerializableMixin

        class SubSchema(cls):
            @post_load
            def make_obj(self, data, **kwargs):
                return attrs_cls(**data)

        if issubclass(attrs_cls, SerializableMixin):
            cls._resolve_types(attrs_cls)
            return SubSchema.from_dict(
                {
                    a.name: cls._get_field_for_type(a.type)
                    for a in attrs.fields(attrs_cls)
                    if a.metadata.get("serializable")
                },
                name=f"{attrs_cls.__name__}Schema",
            )
        else:
            raise ValueError(f"Class must implement SerializableMixin: {attrs_cls}")

    @classmethod
    def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested:
        """Generate a marshmallow Field instance from a Python type.

        Args:
            field_type: A field type, possibly wrapped in a union/``Optional``
                or a ``Literal``.

        Returns:
            The field to use for ``field_type``:
            - a ``Nested`` field for attrs classes (a ``PolymorphicSchema`` when
              the class lists ``ABC`` among its direct bases);
            - a ``List`` field for list-like sequences;
            - otherwise an instance of the class from ``DATACLASS_TYPE_MAPPING``.

        Raises:
            ValueError: If a list-like type has no item type argument.
            KeyError: If the type has no entry in ``DATACLASS_TYPE_MAPPING``.
        """
        from griptape.schemas.polymorphic_schema import PolymorphicSchema

        field_class, args, optional = cls._get_field_type_info(field_type)

        if attrs.has(field_class):
            if ABC in field_class.__bases__:
                return fields.Nested(PolymorphicSchema(inner_class=field_class), allow_none=optional)
            # Pass the unwrapped class: ``field_type`` may be e.g. ``Optional[X]``,
            # which ``from_attrs_cls`` cannot hand to ``issubclass``.
            return fields.Nested(cls.from_attrs_cls(field_class), allow_none=optional)

        if cls.is_list_sequence(field_class):
            if not args:
                raise ValueError(f"Missing type for list field: {field_type}")
            return fields.List(cls_or_instance=cls._get_field_for_type(args[0]), allow_none=optional)

        field_cls = cls.DATACLASS_TYPE_MAPPING[field_class]
        return field_cls(allow_none=optional)

    @classmethod
    def _get_field_type_info(cls, field_type: type) -> tuple[type, tuple[type, ...], bool]:
        """Get information about a field type.

        Unions (``typing.Union``/``Optional`` and, when available, PEP 604
        ``X | Y``) are unwrapped to their first non-``None`` member. A
        ``Literal`` is reduced to the type of its first value.

        Args:
            field_type: A field type.

        Returns:
            A tuple ``(origin, args, optional)`` of:
            - the unwrapped base type;
            - its type arguments;
            - whether ``None`` was one of the union members.
        """
        import types

        none_type = type(None)
        # ``X | None`` (Python 3.10+) has origin ``types.UnionType``, not ``typing.Union``.
        union_type = getattr(types, "UnionType", None)

        origin = get_origin(field_type) or field_type
        args = get_args(field_type)
        optional = False

        if origin is Union or (union_type is not None and origin is union_type):
            members = [arg for arg in args if arg is not none_type]
            # ``None`` may appear in any position, e.g. ``Union[None, X]``.
            optional = len(members) < len(args)
            origin, args, _ = cls._get_field_type_info(members[0])
        elif origin is Literal:
            origin = type(args[0])
            args = ()

        return origin, args, optional

    @classmethod
    def _resolve_types(cls, attrs_cls: type) -> None:
        """Resolve types in an attrs class.

        Args:
            attrs_cls: An attrs class.
        """
        from griptape.utils.import_utils import import_optional_dependency, is_dependency_installed

        # These modules are required to avoid `NameError`s when resolving types.
        from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver
        from griptape.structures import Structure
        from griptape.utils import PromptStack
        from griptape.tokenizers.base_tokenizer import BaseTokenizer
        from typing import Any

        # Optional dependencies fall back to ``Any`` when not installed.
        boto3 = import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any
        Client = import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any

        attrs.resolve_types(
            attrs_cls,
            localns={
                "PromptStack": PromptStack,
                "Input": PromptStack.Input,
                "Structure": Structure,
                "BaseConversationMemoryDriver": BaseConversationMemoryDriver,
                "BasePromptDriver": BasePromptDriver,
                "BaseTokenizer": BaseTokenizer,
                "boto3": boto3,
                "Client": Client,
            },
        )

    @classmethod
    def is_list_sequence(cls, field_type: type) -> bool:
        """Return True if ``field_type`` should map to a marshmallow ``List`` field.

        ``str``, ``bytes`` and ``tuple`` are sequences, but they are excluded so
        they are not mapped to a ``List`` field.
        """
        if issubclass(field_type, (str, bytes, tuple)):
            return False
        return issubclass(field_type, Sequence)

`DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes}` (class attribute)

Meta

Source code in griptape/schemas/base_schema.py
class Meta:
    """marshmallow options applied to ``BaseSchema``."""

    # INCLUDE: unknown input keys are kept rather than rejected.
    unknown = INCLUDE
`unknown = INCLUDE` (class attribute)

from_attrs_cls(attrs_cls) classmethod

Generate a Schema from an attrs class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `attrs_cls` | `type` | An attrs class. | required |
Source code in griptape/schemas/base_schema.py
@classmethod
def from_attrs_cls(cls, attrs_cls: type) -> type:
    """Build a marshmallow schema class for ``attrs_cls``.

    Args:
        attrs_cls: An attrs class; it must subclass ``SerializableMixin``.

    Returns:
        A schema class named ``<ClassName>Schema``. Its fields are the attrs
        attributes flagged ``serializable`` in their metadata, and it
        constructs an ``attrs_cls`` instance after loading.

    Raises:
        ValueError: If ``attrs_cls`` is not a ``SerializableMixin`` subclass.
    """
    from marshmallow import post_load
    from griptape.mixins import SerializableMixin

    class SubSchema(cls):
        @post_load
        def make_obj(self, data, **kwargs):
            return attrs_cls(**data)

    if not issubclass(attrs_cls, SerializableMixin):
        raise ValueError(f"Class must implement SerializableMixin: {attrs_cls}")

    cls._resolve_types(attrs_cls)

    schema_fields = {}
    for attribute in attrs.fields(attrs_cls):
        if attribute.metadata.get("serializable"):
            schema_fields[attribute.name] = cls._get_field_for_type(attribute.type)

    return SubSchema.from_dict(schema_fields, name=f"{attrs_cls.__name__}Schema")

is_list_sequence(field_type) classmethod

Source code in griptape/schemas/base_schema.py
@classmethod
def is_list_sequence(cls, field_type: type) -> bool:
    """Tell whether ``field_type`` is a list-like ``Sequence``.

    ``str``, ``bytes`` and ``tuple`` (and their subclasses) are sequences,
    but they are deliberately reported as ``False``.
    """
    excluded = (str, bytes, tuple)
    return not issubclass(field_type, excluded) and issubclass(field_type, Sequence)

Bytes

Bases: Field

Source code in griptape/schemas/bytes_field.py
class Bytes(fields.Field):
    """marshmallow field that stores ``bytes`` as a base64-encoded string."""

    def _serialize(self, value, attr, obj, **kwargs):
        # marshmallow hands None to _serialize. Mirror the built-in fields and
        # emit None instead of crashing in b64encode. This matters for
        # Optional[bytes] fields generated with allow_none=True.
        if value is None:
            return None
        return base64.b64encode(value).decode()

    def _deserialize(self, value, attr, data, **kwargs):
        try:
            return base64.b64decode(value)
        except (TypeError, ValueError) as error:
            # binascii.Error (bad padding) is a ValueError subclass. Report a
            # field error instead of leaking a low-level exception.
            raise ValidationError("Invalid input type.") from error

    def _validate(self, value):
        if not isinstance(value, bytes):
            raise ValidationError("Invalid input type.")
        # Still run any user-supplied ``validate=`` validators.
        super()._validate(value)

PolymorphicSchema

Bases: BaseSchema

PolymorphicSchema is based on https://github.com/marshmallow-code/marshmallow-oneofschema

Source code in griptape/schemas/polymorphic_schema.py
class PolymorphicSchema(BaseSchema):
    """Schema that picks a concrete schema per object or record.

    Based on https://github.com/marshmallow-code/marshmallow-oneofschema.

    On dump, the schema is generated from the object's class via
    ``BaseSchema.from_attrs_cls``, and the class name is written under
    ``type_field``. On load, the value of ``type_field`` is passed to
    ``inner_class.get_schema`` to choose the schema.
    """

    def __init__(self, inner_class: Any, **kwargs):
        super().__init__(**kwargs)

        # Resolves a type name to a schema on load via ``get_schema``.
        self.inner_class = inner_class

    # Key holding the concrete type name in serialized data.
    type_field = "type"
    # Whether ``get_data_type`` removes ``type_field`` from the data before loading.
    type_field_remove = True

    def get_obj_type(self, obj):
        """Returns name of the schema during dump() calls, given the object
        being dumped."""
        return obj.__class__.__name__

    def get_data_type(self, data):
        """Returns name of the schema during load() calls, given the data being
        loaded. Defaults to looking up `type_field` in the data."""
        data_type = data.get(self.type_field)
        if self.type_field in data and self.type_field_remove:
            data.pop(self.type_field)
        return data_type

    def dump(self, obj, *, many=None, **kwargs):
        """Serialize ``obj``, or each item of ``obj`` when ``many``, via its concrete schema.

        Raises:
            ValidationError: When ``many`` is true and any item fails. Errors
                are keyed by item index, and ``valid_data`` holds per-item results.
        """
        many = self.many if many is None else bool(many)
        if not many:
            return self._dump(obj, **kwargs)

        result_data = []
        result_errors = {}
        for idx, o in enumerate(obj):
            try:
                result_data.append(self._dump(o, **kwargs))
            except ValidationError as error:
                result_errors[idx] = error.normalized_messages()
                result_data.append(error.valid_data)

        if result_errors:
            raise ValidationError(result_errors, data=obj, valid_data=result_data)  # pyright: ignore
        return result_data

    def _dump(self, obj, *, update_fields=True, **kwargs):
        """Dump one object with the schema generated for its class.

        ``update_fields`` is accepted for interface compatibility and is unused.

        Raises:
            ValidationError: If no type name or schema can be determined for ``obj``.
        """
        obj_type = self.get_obj_type(obj)
        if not obj_type:
            # Raise rather than return an error tuple, so ``dump`` reports this
            # like any other validation failure instead of returning it as data.
            raise ValidationError({"_schema": "Unknown object class: %s" % obj.__class__.__name__})

        type_schema = BaseSchema.from_attrs_cls(obj.__class__)
        if not type_schema:
            raise ValidationError({"_schema": "Unsupported object type: %s" % obj_type})

        # ``from_attrs_cls`` returns a schema class; instantiate unless it is already an instance.
        schema = type_schema if isinstance(type_schema, Schema) else type_schema()
        schema.context.update(getattr(self, "context", {}))

        result = schema.dump(obj, many=False, **kwargs)
        if result is not None:
            result[self.type_field] = obj_type  # pyright: ignore

        return result

    def load(self, data, *, many=None, partial=None, unknown=None, **kwargs):
        """Deserialize ``data``, or each item when ``many``, via the schema named by ``type_field``.

        Raises:
            ValidationError: If loading fails. With ``many``, errors are keyed by
                item index and ``valid_data`` holds one entry per item.
        """
        many = self.many if many is None else bool(many)
        if partial is None:
            partial = self.partial

        if not many:
            try:
                return self._load(data, partial=partial, unknown=unknown, **kwargs)
            except ValidationError as error:
                # Report the single item's partial data directly, not wrapped in a list.
                raise ValidationError(
                    error.normalized_messages(), data=data, valid_data=error.valid_data
                ) from error

        result_data = []
        result_errors = {}
        for idx, item in enumerate(data):
            try:
                # Forward ``unknown`` for every item, matching the single-item path.
                result_data.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
            except ValidationError as error:
                result_errors[idx] = error.normalized_messages()
                result_data.append(error.valid_data)

        if result_errors:
            raise ValidationError(result_errors, data=data, valid_data=result_data)
        return result_data

    def _load(self, data, *, partial=None, unknown=None, **kwargs):
        """Load one record using the schema named by its ``type_field`` value.

        Raises:
            ValidationError: If ``data`` is not a dict, lacks ``type_field``, or
                names a type ``inner_class.get_schema`` does not resolve.
        """
        if not isinstance(data, dict):
            raise ValidationError({"_schema": "Invalid data type: %s" % data})

        # Copy so removing ``type_field`` does not mutate the caller's dict.
        data = dict(data)
        unknown = unknown or self.unknown
        data_type = self.get_data_type(data)

        if data_type is None:
            raise ValidationError({self.type_field: ["Missing data for required field."]})

        type_schema = self.inner_class.get_schema(data_type)
        if not type_schema:
            raise ValidationError({self.type_field: ["Unsupported value: %s" % data_type]})

        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        schema.context.update(getattr(self, "context", {}))

        return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs)

    def validate(self, data, *, many=None, partial=None):  # pyright: ignore
        """Return the error messages from loading ``data``, or ``{}`` if it loads cleanly."""
        try:
            self.load(data, many=many, partial=partial)
        except ValidationError as ve:
            return ve.messages
        return {}

`inner_class = inner_class` (instance attribute)

`type_field = 'type'` (class attribute)

`type_field_remove = True` (class attribute)

__init__(inner_class, **kwargs)

Source code in griptape/schemas/polymorphic_schema.py
def __init__(self, inner_class: Any, **kwargs):
    """Initialize the schema.

    Args:
        inner_class: Object whose ``get_schema(type_name)`` is used on load to
            look up the concrete schema.
        **kwargs: Forwarded unchanged to the base schema constructor.
    """
    super().__init__(**kwargs)
    self.inner_class = inner_class

dump(obj, *, many=None, **kwargs)

Source code in griptape/schemas/polymorphic_schema.py
def dump(self, obj, *, many=None, **kwargs):
    """Dump one object, or every object in ``obj`` when ``many`` is true.

    Per-item ``ValidationError``s are collected. If any occurred, a single
    ``ValidationError`` carrying the collected messages and partial data is raised.
    """
    use_many = bool(many) if many is not None else self.many
    if use_many:
        dumped = []
        failures = {}
        for position, item in enumerate(obj):
            try:
                dumped.append(self._dump(item, **kwargs))
            except ValidationError as exc:
                failures[position] = exc.normalized_messages()
                dumped.append(exc.valid_data)
    else:
        dumped = self._dump(obj, **kwargs)
        failures = {}

    if failures:
        raise ValidationError(failures, data=obj, valid_data=dumped)  # pyright: ignore
    return dumped

get_data_type(data)

Returns the name of the schema during `load()` calls, given the data being loaded. Defaults to looking up `type_field` in the data.

Source code in griptape/schemas/polymorphic_schema.py
def get_data_type(self, data):
    """Return the schema name for ``load()``, read from ``data[self.type_field]``.

    A missing key yields ``None``. When ``self.type_field_remove`` is true and
    the key is present, it is popped from ``data``, mutating it.
    """
    field_name = self.type_field
    value = data.get(field_name)
    if self.type_field_remove and field_name in data:
        data.pop(field_name)
    return value

get_obj_type(obj)

Returns the name of the schema during `dump()` calls, given the object being dumped.

Source code in griptape/schemas/polymorphic_schema.py
def get_obj_type(self, obj):
    """Return the schema name used by ``dump()``: the class name of ``obj``."""
    obj_class = obj.__class__
    return obj_class.__name__

load(data, *, many=None, partial=None, unknown=None, **kwargs)

Source code in griptape/schemas/polymorphic_schema.py
def load(self, data, *, many=None, partial=None, unknown=None, **kwargs):
    """Deserialize ``data``, or each item when ``many``, via the schema named by ``type_field``.

    ``partial`` defaults to ``self.partial``. ``unknown`` is forwarded to
    ``_load`` in both the single-item and the ``many`` path.

    Raises:
        ValidationError: If loading fails. With ``many``, errors are keyed by
            item index and ``valid_data`` holds one entry per item.
    """
    many = self.many if many is None else bool(many)
    if partial is None:
        partial = self.partial

    if not many:
        try:
            return self._load(data, partial=partial, unknown=unknown, **kwargs)
        except ValidationError as error:
            # Report the single item's partial data directly, not wrapped in a list.
            raise ValidationError(
                error.normalized_messages(), data=data, valid_data=error.valid_data
            ) from error

    result_data = []
    result_errors = {}
    for idx, item in enumerate(data):
        try:
            # Forward ``unknown`` for every item, matching the single-item path.
            result_data.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
        except ValidationError as error:
            result_errors[idx] = error.normalized_messages()
            result_data.append(error.valid_data)

    if result_errors:
        raise ValidationError(result_errors, data=data, valid_data=result_data)
    return result_data

validate(data, *, many=None, partial=None)

Source code in griptape/schemas/polymorphic_schema.py
def validate(self, data, *, many=None, partial=None):  # pyright: ignore
    """Try loading ``data`` and report the outcome.

    Returns:
        The ``ValidationError`` messages if loading fails, otherwise ``{}``.
    """
    try:
        self.load(data, many=many, partial=partial)
    except ValidationError as err:
        outcome = err.messages
    else:
        outcome = {}
    return outcome