Skip to content

schemas

__all__ = ['BaseSchema', 'PolymorphicSchema', 'Bytes', 'Union'] module-attribute

BaseSchema

Bases: Schema

Source code in griptape/schemas/base_schema.py
class BaseSchema(Schema):
    """Base marshmallow schema that can generate schemas from attrs classes."""

    class Meta:
        # marshmallow `unknown` option; INCLUDE is documented by marshmallow as keeping
        # keys that have no declared field instead of raising or dropping them.
        unknown = INCLUDE

    # Python type -> marshmallow field class. Starts from marshmallow's own TYPE_MAPPING
    # and adds the extra types handled by `_get_field_for_type`.
    DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw}

    @classmethod
    def from_attrs_cls(cls, attrs_cls: type) -> type:
        """Generate a Schema from an attrs class.

        Only attributes whose metadata has a truthy "serializable" entry become
        schema fields; each is keyed by its alias, falling back to its name.

        Args:
            attrs_cls: An attrs class.

        Returns:
            A Schema class named "<attrs_cls name>Schema" whose post-load hook
            instantiates `attrs_cls` from the loaded data.

        Raises:
            ValueError: If `attrs_cls` is not a subclass of SerializableMixin.
        """
        from marshmallow import post_load

        from griptape.mixins.serializable_mixin import SerializableMixin

        class SubSchema(cls):
            @post_load
            def make_obj(self, data: Any, **kwargs) -> Any:
                return attrs_cls(**data)

        if issubclass(attrs_cls, SerializableMixin):
            # String annotations must be resolved before `a.type` can be inspected.
            cls._resolve_types(attrs_cls)
            return SubSchema.from_dict(
                {
                    a.alias or a.name: cls._get_field_for_type(a.type)
                    for a in attrs.fields(attrs_cls)
                    if a.metadata.get("serializable")
                },
                name=f"{attrs_cls.__name__}Schema",
            )
        else:
            raise ValueError(f"Class must implement SerializableMixin: {attrs_cls}")

    @classmethod
    def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested:
        """Generate a marshmallow Field instance from a Python type.

        Args:
            field_type: A field type.

        Returns:
            A marshmallow field instance matching `field_type`.

        Raises:
            ValueError: If a list type has no element type, or the type has no
                entry in `DATACLASS_TYPE_MAPPING`.
        """
        from griptape.schemas.polymorphic_schema import PolymorphicSchema

        field_class, args, optional = cls._get_field_type_info(field_type)

        # Resolve TypeVars to their bound type
        if isinstance(field_class, TypeVar):
            field_class = field_class.__bound__
        if field_class is None:
            return fields.Constant(None, allow_none=True)
        if cls._is_union(field_type):
            return cls._handle_union(field_type, optional=optional)
        elif attrs.has(field_class):
            # Direct ABC subclasses are dispatched on their concrete type at load/dump time.
            schema = PolymorphicSchema if ABC in field_class.__bases__ else cls.from_attrs_cls
            return fields.Nested(schema(field_class), allow_none=optional)
        elif cls._is_enum(field_type):
            return fields.String(allow_none=optional)
        elif cls._is_list_sequence(field_class):
            if args:
                return cls._handle_list(args[0], optional=optional)
            else:
                raise ValueError(f"Missing type for list field: {field_type}")
        field_class = cls.DATACLASS_TYPE_MAPPING.get(field_class)
        if field_class is None:
            raise ValueError(f"Unsupported field type: {field_type}")
        return field_class(allow_none=optional)

    @classmethod
    def _handle_list(cls, list_type: type, *, optional: bool) -> fields.Field:
        """Handle List Fields, including Union Types.

        Args:
            list_type: The List type to handle.
            optional: Whether the List can be none.

        Returns:
            A marshmallow List field.

        Raises:
            ValueError: If the element type resolves to a constant None field.
        """
        if cls._is_union(list_type):
            union_field = cls._handle_union(list_type, optional=optional)
            return fields.List(cls_or_instance=union_field, allow_none=optional)
        list_field = cls._get_field_for_type(list_type)
        if isinstance(list_field, fields.Constant) and list_field.constant is None:
            raise ValueError(f"List elements cannot be None: {list_type}")
        return fields.List(cls_or_instance=list_field, allow_none=optional)

    @classmethod
    def _handle_union(cls, union_type: type, *, optional: bool) -> fields.Field:
        """Handle Union Fields, including Unions with List Types.

        Args:
            union_type: The Union Type to handle.
            optional: Whether the Union can be None.

        Returns:
            A marshmallow Union field. It allows None when `optional` is True or
            when NoneType is one of the union's members.

        Raises:
            ValueError: If the union has no member other than NoneType.
        """
        union_args = get_args(union_type)
        candidate_fields = [cls._get_field_for_type(arg) for arg in union_args if arg is not type(None)]
        # get_args() reports a None member as NoneType, so compare against type(None).
        # Only a union that really contains it is promoted to optional.
        if any(arg is type(None) for arg in union_args):
            optional = True
        if not candidate_fields:
            raise ValueError(f"Unsupported UnionType field: {union_type}")

        return UnionField(fields=candidate_fields, allow_none=optional)

    @classmethod
    def _get_field_type_info(cls, field_type: type) -> tuple[type, tuple[type, ...], bool]:
        """Get information about a field type.

        Args:
            field_type: A field type.

        Returns:
            A tuple of (origin type, type arguments, optional flag). For a Union
            the origin and arguments come from its first member, and the flag is
            True only when the second member is NoneType. For a Literal the
            origin is the type of its first value and the arguments are empty.
        """
        origin = get_origin(field_type) or field_type
        args = get_args(field_type)
        optional = False

        if origin is Union:
            origin = args[0]
            if len(args) > 1 and args[1] is type(None):
                optional = True

            # Unwrap the first member; its own optional flag is deliberately discarded.
            origin, args, _ = cls._get_field_type_info(origin)
        elif origin is Literal:
            origin = type(args[0])
            args = ()

        return origin, args, optional

    @classmethod
    def _resolve_types(cls, attrs_cls: type) -> None:
        """Resolve types in an attrs class.

        Calls `attrs.resolve_types` with a local namespace containing the
        names imported below, so string annotations using them can be resolved.

        Args:
            attrs_cls: An attrs class.
        """
        from collections.abc import Sequence
        from typing import Any

        from griptape.artifacts import BaseArtifact
        from griptape.common import (
            BaseDeltaMessageContent,
            BaseMessageContent,
            Message,
            PromptStack,
            Reference,
            ToolAction,
        )
        from griptape.drivers import (
            BaseAudioTranscriptionDriver,
            BaseConversationMemoryDriver,
            BaseEmbeddingDriver,
            BaseImageGenerationDriver,
            BaseImageQueryDriver,
            BasePromptDriver,
            BaseRulesetDriver,
            BaseTextToSpeechDriver,
            BaseVectorStoreDriver,
        )
        from griptape.events import EventListener
        from griptape.memory import TaskMemory
        from griptape.memory.structure import BaseConversationMemory, Run
        from griptape.memory.task.storage import BaseArtifactStorage
        from griptape.structures import Structure
        from griptape.tasks import BaseTask
        from griptape.tokenizers import BaseTokenizer
        from griptape.tools import BaseTool
        from griptape.utils import import_optional_dependency, is_dependency_installed

        attrs.resolve_types(
            attrs_cls,
            localns={
                "Any": Any,
                "BasePromptDriver": BasePromptDriver,
                "BaseImageQueryDriver": BaseImageQueryDriver,
                "BaseEmbeddingDriver": BaseEmbeddingDriver,
                "BaseVectorStoreDriver": BaseVectorStoreDriver,
                "BaseTextToSpeechDriver": BaseTextToSpeechDriver,
                "BaseAudioTranscriptionDriver": BaseAudioTranscriptionDriver,
                "BaseConversationMemoryDriver": BaseConversationMemoryDriver,
                "BaseRulesetDriver": BaseRulesetDriver,
                "BaseImageGenerationDriver": BaseImageGenerationDriver,
                "BaseArtifact": BaseArtifact,
                "PromptStack": PromptStack,
                "EventListener": EventListener,
                "BaseMessageContent": BaseMessageContent,
                "BaseDeltaMessageContent": BaseDeltaMessageContent,
                "BaseTool": BaseTool,
                "BaseTask": BaseTask,
                "Usage": Message.Usage,
                "Structure": Structure,
                "BaseTokenizer": BaseTokenizer,
                "ToolAction": ToolAction,
                "Reference": Reference,
                "Run": Run,
                "Sequence": Sequence,
                "TaskMemory": TaskMemory,
                "State": BaseTask.State,
                "BaseConversationMemory": BaseConversationMemory,
                "BaseArtifactStorage": BaseArtifactStorage,
                # Third party modules
                # Each falls back to Any when the optional dependency is not installed.
                "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any,
                "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel
                if is_dependency_installed("google.generativeai")
                else Any,
                "boto3": import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any,
                "Anthropic": import_optional_dependency("anthropic").Anthropic
                if is_dependency_installed("anthropic")
                else Any,
                "BedrockClient": import_optional_dependency("mypy_boto3_bedrock").BedrockClient
                if is_dependency_installed("mypy_boto3_bedrock")
                else Any,
                "voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any,
            },
        )

    @classmethod
    def _is_list_sequence(cls, field_type: type | _SpecialForm) -> bool:
        """Return True for Sequence subclasses other than str, bytes and tuple."""
        if isinstance(field_type, type):
            if issubclass(field_type, str) or issubclass(field_type, bytes) or issubclass(field_type, tuple):
                return False
            else:
                return issubclass(field_type, Sequence)
        else:
            return False

    @classmethod
    def _is_union(cls, field_type: type) -> bool:
        """Return True when `field_type` is `Union` itself or a parameterized Union."""
        return field_type is Union or get_origin(field_type) is Union

    @classmethod
    def _is_enum(cls, field_type: type) -> bool:
        """Return True when `field_type` is a class derived from Enum."""
        return isinstance(field_type, type) and issubclass(field_type, Enum)

DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw} class-attribute instance-attribute

Meta

Source code in griptape/schemas/base_schema.py
class Meta:
    # marshmallow `unknown` option; INCLUDE is documented by marshmallow as keeping
    # keys that have no declared field instead of raising or dropping them.
    unknown = INCLUDE
unknown = INCLUDE class-attribute instance-attribute

from_attrs_cls(attrs_cls) classmethod

Generate a Schema from an attrs class.

Parameters:

Name Type Description Default
attrs_cls type

An attrs class.

required
Source code in griptape/schemas/base_schema.py
@classmethod
def from_attrs_cls(cls, attrs_cls: type) -> type:
    """Build a Schema class that mirrors the serializable attributes of an attrs class.

    Args:
        attrs_cls: An attrs class; it must subclass SerializableMixin.

    Returns:
        A Schema class named "<attrs_cls name>Schema". Its post-load hook
        instantiates `attrs_cls` from the loaded data.

    Raises:
        ValueError: If `attrs_cls` does not subclass SerializableMixin.
    """
    from marshmallow import post_load

    from griptape.mixins.serializable_mixin import SerializableMixin

    # Defined before the subclass check, exactly as the check-then-build flow expects.
    class SubSchema(cls):
        @post_load
        def make_obj(self, data: Any, **kwargs) -> Any:
            return attrs_cls(**data)

    if not issubclass(attrs_cls, SerializableMixin):
        raise ValueError(f"Class must implement SerializableMixin: {attrs_cls}")

    # Annotations have to be resolved before each attribute's type is inspected.
    cls._resolve_types(attrs_cls)

    schema_fields = {}
    for attribute in attrs.fields(attrs_cls):
        if attribute.metadata.get("serializable"):
            # Prefer the alias as the key; fall back to the attribute name.
            schema_fields[attribute.alias or attribute.name] = cls._get_field_for_type(attribute.type)

    return SubSchema.from_dict(schema_fields, name=f"{attrs_cls.__name__}Schema")

Bytes

Bases: Field

Source code in griptape/schemas/bytes_field.py
class Bytes(fields.Field):
    """Marshmallow field that represents bytes as base64-encoded text."""

    def _serialize(self, value: Any, attr: Any, obj: Any, **kwargs) -> str | None:
        """Encode `value` as a base64 string, passing None through unchanged."""
        # marshmallow hands None to _serialize; b64encode(None) would raise TypeError.
        if value is None:
            return None
        return base64.b64encode(value).decode()

    def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs) -> bytes:
        """Decode base64 `value` back into bytes."""
        return base64.b64decode(value)

    def _validate(self, value: Any) -> None:
        """Reject any value that is not a bytes instance.

        Raises:
            ValidationError: If `value` is not bytes.
        """
        if not isinstance(value, bytes):
            raise ValidationError("Invalid input type.")

PolymorphicSchema

Bases: BaseSchema

PolymorphicSchema is based on https://github.com/marshmallow-code/marshmallow-oneofschema.

Source code in griptape/schemas/polymorphic_schema.py
class PolymorphicSchema(BaseSchema):
    """PolymorphicSchema is based on https://github.com/marshmallow-code/marshmallow-oneofschema.

    Dumping writes the object's class name under `type_field`. Loading reads that
    key and asks `inner_class.get_schema` for the schema to load the data with.
    """

    def __init__(self, inner_class: Any, **kwargs) -> None:
        """Create the schema.

        Args:
            inner_class: Object whose `get_schema` is used to find the schema when loading.
            kwargs: Passed on to the parent schema constructor.
        """
        super().__init__(**kwargs)

        self.inner_class = inner_class

    # Key that carries the type name in serialized data.
    type_field = "type"
    # When True, the type key is removed from the data before it is loaded.
    type_field_remove = True

    def get_obj_type(self, obj: Any) -> Any:
        """Returns name of the schema during dump() calls, given the object being dumped."""
        return obj.__class__.__name__

    def get_data_type(self, data: Any) -> Any:
        """Returns name of the schema during load() calls, given the data being loaded. Defaults to looking up `type_field` in the data."""
        data_type = data.get(self.type_field)
        if self.type_field in data and self.type_field_remove:
            data.pop(self.type_field)
        return data_type

    def dump(self, obj: Any, *, many: Any = None, **kwargs) -> Any:
        """Serialize one object, or an iterable of objects when `many` is truthy.

        Args:
            obj: The object, or iterable of objects, to serialize.
            many: Overrides `self.many` when not None.
            kwargs: Passed on to `_dump`.

        Returns:
            The result of `_dump` for one object, or a list of results.

        Raises:
            ValidationError: In many mode, if any item failed; messages are keyed by item index.
        """
        many = self.many if many is None else bool(many)
        if not many:
            # Single-object errors are not collected here; they propagate from _dump.
            return self._dump(obj, **kwargs)

        result_data = []
        result_errors = {}
        for idx, o in enumerate(obj):
            try:
                result_data.append(self._dump(o, **kwargs))
            except ValidationError as error:
                result_errors[idx] = error.normalized_messages()
                result_data.append(error.valid_data)

        if result_errors:
            raise ValidationError(result_errors, data=obj, valid_data=result_data)  # pyright: ignore[reportArgumentType]
        return result_data

    def _dump(self, obj: Any, *, update_fields: bool = True, **kwargs) -> Any:
        """Serialize one object with the schema generated for its concrete class.

        Args:
            obj: The object to serialize.
            update_fields: Accepted but not used by this implementation.
            kwargs: Passed on to the concrete schema's `dump`.

        Returns:
            The serialized data with the object's type name stored under `type_field`.
        """
        obj_type = self.get_obj_type(obj)

        if not obj_type:
            return (None, {"_schema": f"Unknown object class: {obj.__class__.__name__}"})

        type_schema = BaseSchema.from_attrs_cls(obj.__class__)

        if not type_schema:
            return None, {"_schema": f"Unsupported object type: {obj_type}"}

        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        schema.context.update(getattr(self, "context", {}))

        result = schema.dump(obj, many=False, **kwargs)

        if result is not None:
            result[self.type_field] = obj_type  # pyright: ignore[reportArgumentType,reportCallIssue]

        return result

    def load(self, data: Any, *, many: Any = None, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
        """Deserialize one mapping, or an iterable of mappings when `many` is truthy.

        Args:
            data: The data to deserialize.
            many: Overrides `self.many` when not None.
            partial: Overrides `self.partial` when not None.
            unknown: Unknown-field policy forwarded to `_load` in both modes.
            kwargs: Passed on to `_load`.

        Returns:
            The loaded object, or a list of loaded objects in many mode.

        Raises:
            ValidationError: If loading failed; in many mode messages are keyed by item index.
        """
        result_data = []
        result_errors = {}
        many = self.many if many is None else bool(many)
        if partial is None:
            partial = self.partial
        if not many:
            try:
                result_data = self._load(data, partial=partial, unknown=unknown, **kwargs)
            except ValidationError as error:
                result_errors = error.normalized_messages()
                # NOTE(review): valid_data ends up wrapped in a one-element list here.
                result_data.append(error.valid_data)
        else:
            for idx, item in enumerate(data):
                try:
                    # Forward `unknown` per item so many mode honours it like single mode.
                    result_data.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
                except ValidationError as error:
                    result_errors[idx] = error.normalized_messages()
                    result_data.append(error.valid_data)

        if result_errors:
            raise ValidationError(result_errors, data=data, valid_data=result_data)
        return result_data

    def _load(self, data: Any, *, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
        """Deserialize one mapping with the schema selected by its type key.

        Args:
            data: A dict holding the serialized object and its type key.
            partial: Passed on to the selected schema's `load`.
            unknown: Unknown-field policy; falls back to `self.unknown` when falsy.
            kwargs: Passed on to the selected schema's `load`.

        Raises:
            ValidationError: If `data` is not a dict, has no type value, or no schema is found for the type.
        """
        if not isinstance(data, dict):
            raise ValidationError({"_schema": f"Invalid data type: {data}"})

        # Work on a copy: get_data_type may pop the type key.
        data = dict(data)
        unknown = unknown or self.unknown
        data_type = self.get_data_type(data)

        if data_type is None:
            raise ValidationError({self.type_field: ["Missing data for required field."]})

        type_schema = self.inner_class.get_schema(data_type, module_name=data.get("module_name"))
        if not type_schema:
            raise ValidationError({self.type_field: [f"Unsupported value: {data_type}"]})

        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        schema.context.update(getattr(self, "context", {}))

        return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs)

    def validate(self, data: Any, *, many: Any = None, partial: Any = None) -> Any:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Run `load` and return its validation messages, or an empty dict on success."""
        try:
            self.load(data, many=many, partial=partial)
        except ValidationError as ve:
            return ve.messages
        return {}

inner_class = inner_class instance-attribute

type_field = 'type' class-attribute instance-attribute

type_field_remove = True class-attribute instance-attribute

__init__(inner_class, **kwargs)

Source code in griptape/schemas/polymorphic_schema.py
def __init__(self, inner_class: Any, **kwargs) -> None:
    """Create the schema.

    Args:
        inner_class: Stored on the instance as `inner_class`.
        kwargs: Passed on to the parent constructor.
    """
    super().__init__(**kwargs)

    self.inner_class = inner_class

dump(obj, *, many=None, **kwargs)

Source code in griptape/schemas/polymorphic_schema.py
def dump(self, obj: Any, *, many: Any = None, **kwargs) -> Any:
    """Serialize one object, or each item of an iterable when in many mode.

    Args:
        obj: The object, or iterable of objects, to serialize.
        many: Overrides `self.many` when not None.
        kwargs: Passed on to `_dump`.

    Returns:
        The `_dump` result for a single object, or a list of results.

    Raises:
        ValidationError: In many mode, if any item failed; messages are keyed by item index.
    """
    use_many = self.many if many is None else bool(many)
    if not use_many:
        # Nothing is collected for a single object; errors propagate from _dump.
        return self._dump(obj, **kwargs)

    dumped = []
    failures = {}
    for position, item in enumerate(obj):
        try:
            dumped.append(self._dump(item, **kwargs))
        except ValidationError as exc:
            failures[position] = exc.normalized_messages()
            # Keep positions aligned by storing whatever part was valid.
            dumped.append(exc.valid_data)

    if failures:
        raise ValidationError(failures, data=obj, valid_data=dumped)  # pyright: ignore[reportArgumentType]
    return dumped

get_data_type(data)

Returns name of the schema during load() calls, given the data being loaded. Defaults to looking up type_field in the data.

Source code in griptape/schemas/polymorphic_schema.py
def get_data_type(self, data: Any) -> Any:
    """Return the schema name stored under `type_field` in the data being loaded.

    When `type_field_remove` is set and the key is present, the key is also
    removed from `data`.
    """
    key = self.type_field
    type_name = data.get(key)
    if key in data and self.type_field_remove:
        data.pop(key)
    return type_name

get_obj_type(obj)

Returns name of the schema during dump() calls, given the object being dumped.

Source code in griptape/schemas/polymorphic_schema.py
def get_obj_type(self, obj: Any) -> Any:
    """Return the class name of the object being dumped; it names the schema during dump() calls."""
    obj_class = obj.__class__
    return obj_class.__name__

load(data, *, many=None, partial=None, unknown=None, **kwargs)

Source code in griptape/schemas/polymorphic_schema.py
def load(self, data: Any, *, many: Any = None, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
    """Deserialize one mapping, or an iterable of mappings when `many` is truthy.

    Args:
        data: The data to deserialize.
        many: Overrides `self.many` when not None.
        partial: Overrides `self.partial` when not None.
        unknown: Unknown-field policy forwarded to `_load` in both modes.
        kwargs: Passed on to `_load`.

    Returns:
        The loaded object, or a list of loaded objects in many mode.

    Raises:
        ValidationError: If loading failed; in many mode messages are keyed by item index.
    """
    result_data = []
    result_errors = {}
    many = self.many if many is None else bool(many)
    if partial is None:
        partial = self.partial
    if not many:
        try:
            result_data = self._load(data, partial=partial, unknown=unknown, **kwargs)
        except ValidationError as error:
            result_errors = error.normalized_messages()
            # NOTE(review): valid_data ends up wrapped in a one-element list here.
            result_data.append(error.valid_data)
    else:
        for idx, item in enumerate(data):
            try:
                # Forward `unknown` per item so many mode honours it like single mode.
                result_data.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
            except ValidationError as error:
                result_errors[idx] = error.normalized_messages()
                result_data.append(error.valid_data)

    if result_errors:
        raise ValidationError(result_errors, data=data, valid_data=result_data)
    return result_data

validate(data, *, many=None, partial=None)

Source code in griptape/schemas/polymorphic_schema.py
def validate(self, data: Any, *, many: Any = None, partial: Any = None) -> Any:  # pyright: ignore[reportIncompatibleMethodOverride]
    """Load `data` and report validation problems instead of raising.

    Returns:
        The messages of the ValidationError raised by `load`, or an empty dict
        when loading succeeds.
    """
    try:
        self.load(data, many=many, partial=partial)
    except ValidationError as exc:
        return exc.messages
    else:
        return {}

Union

Bases: Field

Field that accepts any one of multiple fields.

Source: https://github.com/adamboche/python-marshmallow-union. Each candidate field is tried in turn until one succeeds.

Parameters:

Name Type Description Default
fields list[Field]

The list of candidate fields to try.

required
reverse_serialize_candidates bool

Whether to try the candidates in reverse order when serializing.

False
Source code in griptape/schemas/union_field.py
class Union(marshmallow.fields.Field):
    """Field that delegates to the first of several candidate fields that works.

    Source: https://github.com/adamboche/python-marshmallow-union
    Candidates are tried one after another until one succeeds.

    Args:
        fields: The list of candidate fields to try.
        reverse_serialize_candidates: Whether to try the candidates in reverse order when serializing.
    """

    def __init__(
        self,
        fields: list[marshmallow.fields.Field],
        *,
        reverse_serialize_candidates: bool = False,
        **kwargs: Any,
    ) -> None:
        self._candidate_fields = fields
        self._reverse_serialize_candidates = reverse_serialize_candidates
        super().__init__(**kwargs)

    def _serialize(self, value: Any, attr: str | None, obj: str, **kwargs: Any) -> Any:
        """Serialize `value` with the first candidate field that does not fail.

        Args:
            value: The value to be serialized.
            attr: The attribute or key to get from the object.
            obj: The object to pull the key from.
            kwargs: Field-specific keyword arguments; an "error_store" entry is reused if given.

        Raises:
            ExceptionGroupError: If every candidate raised TypeError or ValueError.
        """
        # One store is shared by all candidates so every failure is recorded.
        store = kwargs.pop("error_store", marshmallow.error_store.ErrorStore())

        candidates = self._candidate_fields
        if self._reverse_serialize_candidates:
            candidates = list(reversed(candidates))

        for candidate in candidates:
            try:
                # pylint: disable=protected-access
                return candidate._serialize(value, attr, obj, error_store=store, **kwargs)
            except (TypeError, ValueError) as failure:
                store.store_error({attr: str(failure)})

        raise ExceptionGroupError("All serializers raised exceptions.", store.errors)

    def _deserialize(self, value: Any, attr: str | None = None, data: Any = None, **kwargs: Any) -> Any:
        """Deserialize `value` with the first candidate field that accepts it.

        Args:
            value: The value to be deserialized.
            attr: The attribute/key in `data` to be deserialized.
            data: The raw input data passed to the `Schema.load`.
            kwargs: Field-specific keyword arguments.

        Raises:
            ValidationError: If every candidate rejected the value; carries the messages of all of them.
        """
        collected = []
        for candidate in self._candidate_fields:
            try:
                return candidate.deserialize(value, attr, data, **kwargs)
            except marshmallow.exceptions.ValidationError as failure:
                collected.append(failure.messages)

        raise marshmallow.exceptions.ValidationError(message=collected, field_name=attr or "")

__init__(fields, *, reverse_serialize_candidates=False, **kwargs)

Source code in griptape/schemas/union_field.py
def __init__(
    self,
    fields: list[marshmallow.fields.Field],
    *,
    reverse_serialize_candidates: bool = False,
    **kwargs: Any,
) -> None:
    """Store the candidate fields and serialization order, then initialize the parent.

    Args:
        fields: The list of candidate fields to try.
        reverse_serialize_candidates: Whether to try the candidates in reverse order when serializing.
        kwargs: Passed on to the parent constructor.
    """
    self._candidate_fields = fields
    self._reverse_serialize_candidates = reverse_serialize_candidates
    super().__init__(**kwargs)