Skip to content

schemas

__all__ = ['BaseSchema', 'PolymorphicSchema', 'Bytes', 'Union'] module-attribute

BaseSchema

Bases: Schema

Source code in griptape/schemas/base_schema.py
class BaseSchema(Schema):
    """Schema base class that can generate schemas from attrs classes.

    `from_attrs_cls` builds a schema class whose fields are derived from the type
    annotations of an attrs class that implements `SerializableMixin`.
    """

    class Meta:
        unknown = INCLUDE

    # Extends the base Schema type mapping with the field classes used for dict, bytes and Any.
    DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw}

    @classmethod
    def from_attrs_cls(cls, attrs_cls: type) -> type:
        """Generate a Schema from an attrs class.

        Only attributes whose metadata has a truthy ``serializable`` entry become
        schema fields. Each field is keyed by the attribute's alias (falling back
        to its name).

        Args:
            attrs_cls: An attrs class.

        Returns:
            A schema class named ``<attrs_cls name>Schema``.

        Raises:
            ValueError: If `attrs_cls` does not implement SerializableMixin.
        """
        from marshmallow import post_load

        from griptape.mixins.serializable_mixin import SerializableMixin

        class SubSchema(cls):
            @post_load
            def make_obj(self, data: Any, **kwargs) -> Any:
                # Map the serialized keys to their correct deserialization keys.
                # The local is deliberately not called `fields` so it does not shadow
                # the module-level `fields` import.
                attrs_fields = attrs.fields_dict(attrs_cls)
                for key in list(data):
                    if key in attrs_fields:
                        field = attrs_fields[key]
                        if field.metadata.get("deserialization_key"):
                            data[field.metadata["deserialization_key"]] = data.pop(key)

                return attrs_cls(**data)

        if issubclass(attrs_cls, SerializableMixin):
            cls._resolve_types(attrs_cls)
            return SubSchema.from_dict(
                {
                    a.alias or a.name: cls._get_field_for_type(
                        a.type, serialization_key=a.metadata.get("serialization_key")
                    )
                    for a in attrs.fields(attrs_cls)
                    if a.metadata.get("serializable")
                },
                name=f"{attrs_cls.__name__}Schema",
            )
        else:
            raise ValueError(f"Class must implement SerializableMixin: {attrs_cls}")

    @classmethod
    def _get_field_for_type(
        cls, field_type: type, serialization_key: Optional[str] = None
    ) -> fields.Field | fields.Nested:
        """Generate a marshmallow Field instance from a Python type.

        Args:
            field_type: A field type.
            serialization_key: The key to pull the data from before serializing.

        Returns:
            A field instance matching `field_type`; types without a mapping fall back to `fields.Raw`.

        Raises:
            ValueError: If `field_type` is a list-like sequence without an element type.
        """
        from griptape.schemas.polymorphic_schema import PolymorphicSchema

        field_class, args, optional = cls._get_field_type_info(field_type)

        if field_class is None:
            return fields.Constant(None, allow_none=True)

        # Resolve TypeVars to their bound type
        if isinstance(field_class, TypeVar):
            field_class = field_class.__bound__
            if field_class is None:
                return fields.Raw(allow_none=optional, attribute=serialization_key)

        if cls._is_union(field_type):
            return cls._handle_union(
                field_type,
                optional=optional,
                serialization_key=serialization_key,
            )
        elif attrs.has(field_class):
            # Direct ABC subclasses are loaded polymorphically; other attrs classes get their own schema.
            schema = PolymorphicSchema if ABC in field_class.__bases__ else cls.from_attrs_cls
            return fields.Nested(schema(field_class), allow_none=optional, attribute=serialization_key)
        elif cls._is_enum(field_type):
            return fields.String(allow_none=optional, attribute=serialization_key)
        elif cls._is_list_sequence(field_class):
            if args:
                return cls._handle_list(
                    args[0],
                    optional=optional,
                    serialization_key=serialization_key,
                )
            else:
                raise ValueError(f"Missing type for list field: {field_type}")
        field_class = cls.DATACLASS_TYPE_MAPPING.get(field_class, fields.Raw)

        return field_class(allow_none=optional, attribute=serialization_key)

    @classmethod
    def _handle_list(
        cls,
        list_type: type,
        *,
        optional: bool,
        serialization_key: Optional[str] = None,
    ) -> fields.Field:
        """Handle List Fields, including Union Types.

        Args:
            list_type: The List type to handle.
            optional: Whether the List can be none.
            serialization_key: The key to pull the data from before serializing.

        Returns:
            A marshmallow List field.
        """
        if cls._is_union(list_type):
            instance = cls._handle_union(
                list_type,
                optional=optional,
                serialization_key=serialization_key,
            )
        else:
            instance = cls._get_field_for_type(list_type, serialization_key=serialization_key)
        return fields.List(cls_or_instance=instance, allow_none=optional, attribute=serialization_key)

    @classmethod
    def _handle_union(
        cls,
        union_type: type,
        *,
        optional: bool,
        serialization_key: Optional[str] = None,
    ) -> fields.Field:
        """Handle Union Fields, including Unions with List Types.

        Args:
            union_type: The Union Type to handle.
            optional: Whether the Union can be None.
            serialization_key: The key to pull the data from before serializing.

        Returns:
            A marshmallow Union field.

        Raises:
            ValueError: If the union has no member other than NoneType.
        """
        union_args = get_args(union_type)
        candidate_fields = [cls._get_field_for_type(arg) for arg in union_args if arg is not type(None)]
        # The union is nullable only when NoneType is one of its members. `get_args`
        # reports that member as `type(None)`, never `None`, and testing a list of
        # booleans for truthiness would mark every union as nullable.
        if any(arg is type(None) for arg in union_args):
            optional = True
        if not candidate_fields:
            raise ValueError(f"Unsupported UnionType field: {union_type}")

        return UnionField(fields=candidate_fields, allow_none=optional, attribute=serialization_key)

    @classmethod
    def _get_field_type_info(cls, field_type: type) -> tuple[type, tuple[type, ...], bool]:
        """Get information about a field type.

        Args:
            field_type: A field type.

        Returns:
            A tuple of (origin type, type arguments, optional). For a Union the origin
            and arguments describe its first member, and optional is True only when
            the second member is NoneType. For a Literal the origin is the type of
            its first value and the arguments are empty.
        """
        origin = get_origin(field_type) or field_type
        args = get_args(field_type)
        optional = False

        if origin is Union:
            origin = args[0]
            if len(args) > 1 and args[1] is type(None):
                optional = True

            # Unwrap the first member; its own optional flag is intentionally discarded.
            origin, args, _ = cls._get_field_type_info(origin)
        elif origin is Literal:
            origin = type(args[0])
            args = ()

        return origin, args, optional

    @classmethod
    def _resolve_types(cls, attrs_cls: type) -> None:
        """Resolve types in an attrs class.

        Calls `attrs.resolve_types` with a local namespace containing the names that
        annotations may refer to. Imports are done inside the method.

        Args:
            attrs_cls: An attrs class.
        """
        from collections.abc import Sequence
        from typing import Any

        from pydantic import BaseModel
        from schema import Schema

        from griptape.artifacts import (
            ActionArtifact,
            AudioArtifact,
            BaseArtifact,
            BlobArtifact,
            BooleanArtifact,
            ErrorArtifact,
            GenericArtifact,
            ImageArtifact,
            InfoArtifact,
            JsonArtifact,
            ListArtifact,
            TextArtifact,
        )
        from griptape.common import (
            BaseDeltaMessageContent,
            BaseMessageContent,
            Message,
            PromptStack,
            Reference,
            ToolAction,
        )
        from griptape.drivers.audio_transcription import BaseAudioTranscriptionDriver
        from griptape.drivers.embedding import BaseEmbeddingDriver
        from griptape.drivers.image_generation import BaseImageGenerationDriver, BaseMultiModelImageGenerationDriver
        from griptape.drivers.image_generation_model import BaseImageGenerationModelDriver
        from griptape.drivers.memory.conversation import BaseConversationMemoryDriver
        from griptape.drivers.prompt import BasePromptDriver
        from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
        from griptape.drivers.ruleset import BaseRulesetDriver
        from griptape.drivers.text_to_speech import BaseTextToSpeechDriver
        from griptape.drivers.vector import BaseVectorStoreDriver
        from griptape.engines.rag import RagContext
        from griptape.events import EventListener
        from griptape.memory import TaskMemory
        from griptape.memory.structure import BaseConversationMemory, Run
        from griptape.memory.task.storage import BaseArtifactStorage
        from griptape.rules.base_rule import BaseRule
        from griptape.rules.ruleset import Ruleset
        from griptape.structures import Structure
        from griptape.tasks import BaseTask
        from griptape.tokenizers import BaseTokenizer
        from griptape.tools import BaseTool
        from griptape.utils import import_optional_dependency, is_dependency_installed

        attrs.resolve_types(
            attrs_cls,
            localns={
                "Any": Any,
                "BasePromptDriver": BasePromptDriver,
                "BaseEmbeddingDriver": BaseEmbeddingDriver,
                "BaseVectorStoreDriver": BaseVectorStoreDriver,
                "BaseTextToSpeechDriver": BaseTextToSpeechDriver,
                "BaseAudioTranscriptionDriver": BaseAudioTranscriptionDriver,
                "BaseConversationMemoryDriver": BaseConversationMemoryDriver,
                "BaseRulesetDriver": BaseRulesetDriver,
                "BaseImageGenerationDriver": BaseImageGenerationDriver,
                "BaseMultiModelImageGenerationDriver": BaseMultiModelImageGenerationDriver,
                "BaseImageGenerationModelDriver": BaseImageGenerationModelDriver,
                "BaseArtifact": BaseArtifact,
                "PromptStack": PromptStack,
                "EventListener": EventListener,
                "BaseMessageContent": BaseMessageContent,
                "BaseDeltaMessageContent": BaseDeltaMessageContent,
                "BaseTool": BaseTool,
                "BaseTask": BaseTask,
                "TextArtifact": TextArtifact,
                "ImageArtifact": ImageArtifact,
                "ErrorArtifact": ErrorArtifact,
                "InfoArtifact": InfoArtifact,
                "JsonArtifact": JsonArtifact,
                "BlobArtifact": BlobArtifact,
                "BooleanArtifact": BooleanArtifact,
                "ListArtifact": ListArtifact,
                "AudioArtifact": AudioArtifact,
                "ActionArtifact": ActionArtifact,
                "GenericArtifact": GenericArtifact,
                "Usage": Message.Usage,
                "Structure": Structure,
                "BaseTokenizer": BaseTokenizer,
                "ToolAction": ToolAction,
                "Reference": Reference,
                "Run": Run,
                "Sequence": Sequence,
                "TaskMemory": TaskMemory,
                "State": BaseTask.State,
                "BaseConversationMemory": BaseConversationMemory,
                "BaseArtifactStorage": BaseArtifactStorage,
                "BaseRule": BaseRule,
                "Ruleset": Ruleset,
                "StructuredOutputStrategy": StructuredOutputStrategy,
                "RagContext": RagContext,
                # Third party modules
                # Each optional dependency falls back to Any when it is not installed.
                "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any,
                "ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any,
                "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel
                if is_dependency_installed("google.generativeai")
                else Any,
                "boto3": import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any,
                "Anthropic": import_optional_dependency("anthropic").Anthropic
                if is_dependency_installed("anthropic")
                else Any,
                "BedrockClient": import_optional_dependency("mypy_boto3_bedrock").BedrockClient
                if is_dependency_installed("mypy_boto3_bedrock")
                else Any,
                "voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any,
                "Schema": Schema,
                "BaseModel": BaseModel,
            },
        )

    @classmethod
    def _is_list_sequence(cls, field_type: type | _SpecialForm) -> bool:
        """Return True when `field_type` is a Sequence class other than str, bytes or tuple."""
        if isinstance(field_type, type):
            if issubclass(field_type, (str, bytes, tuple)):
                return False
            else:
                return issubclass(field_type, Sequence)
        else:
            return False

    @classmethod
    def _is_union(cls, field_type: type) -> bool:
        """Return True when `field_type` is `Union` itself or a subscripted `Union[...]`."""
        return field_type is Union or get_origin(field_type) is Union

    @classmethod
    def _is_enum(cls, field_type: type) -> bool:
        """Return True when `field_type` is a class deriving from Enum."""
        return isinstance(field_type, type) and issubclass(field_type, Enum)

DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw} class-attribute instance-attribute

Meta

Source code in griptape/schemas/base_schema.py
class Meta:
    # Schema-level option for keys that have no declared field.
    # NOTE(review): presumably marshmallow's `INCLUDE` constant (keep unknown keys) — confirm against the module imports.
    unknown = INCLUDE
unknown = INCLUDE class-attribute instance-attribute

from_attrs_cls(attrs_cls) classmethod

Generate a Schema from an attrs class.

Parameters:

Name Type Description Default
attrs_cls type

An attrs class.

required
Source code in griptape/schemas/base_schema.py
@classmethod
def from_attrs_cls(cls, attrs_cls: type) -> type:
    """Build a Schema class describing an attrs class.

    Attributes whose metadata carries a truthy ``serializable`` entry become
    schema fields, keyed by the attribute alias or, failing that, its name.

    Args:
        attrs_cls: An attrs class.

    Raises:
        ValueError: If `attrs_cls` is not a SerializableMixin subclass.
    """
    from marshmallow import post_load

    from griptape.mixins.serializable_mixin import SerializableMixin

    class SubSchema(cls):
        @post_load
        def make_obj(self, data: Any, **kwargs) -> Any:
            # Move each loaded value to the deserialization key declared in the attribute metadata
            declared = attrs.fields_dict(attrs_cls)
            for name in list(data):
                attribute = declared.get(name)
                if attribute is None:
                    continue
                target_key = attribute.metadata.get("deserialization_key")
                if target_key:
                    data[target_key] = data.pop(name)

            return attrs_cls(**data)

    if not issubclass(attrs_cls, SerializableMixin):
        raise ValueError(f"Class must implement SerializableMixin: {attrs_cls}")

    cls._resolve_types(attrs_cls)

    schema_fields = {}
    for attribute in attrs.fields(attrs_cls):
        if not attribute.metadata.get("serializable"):
            continue
        field_name = attribute.alias or attribute.name
        schema_fields[field_name] = cls._get_field_for_type(
            attribute.type, serialization_key=attribute.metadata.get("serialization_key")
        )

    return SubSchema.from_dict(schema_fields, name=f"{attrs_cls.__name__}Schema")

Bytes

Bases: Field

Source code in griptape/schemas/bytes_field.py
class Bytes(fields.Field):
    """Field that represents `bytes` values as base64-encoded strings."""

    def _serialize(self, value: Any, attr: Any, obj: Any, **kwargs) -> "str | None":
        """Encode `value` as a base64 string.

        Args:
            value: The bytes to encode, or None.
            attr: Unused.
            obj: Unused.
            kwargs: Unused.

        Returns:
            The base64 text, or None when `value` is None.
        """
        # A nullable field may hold None; b64encode(None) would raise TypeError.
        if value is None:
            return None
        return base64.b64encode(value).decode()

    def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs) -> bytes:
        """Decode a base64 string back into bytes.

        Args:
            value: The base64-encoded input.
            attr: Unused.
            data: Unused.
            kwargs: Unused.

        Raises:
            ValidationError: If `value` cannot be base64-decoded.
        """
        try:
            return base64.b64decode(value)
        except (TypeError, ValueError) as e:
            # binascii.Error is a ValueError subclass. Report decoding failures as
            # validation errors so callers that collect ValidationError see them.
            raise ValidationError("Invalid base64-encoded data.") from e

    def _validate(self, value: Any) -> None:
        """Raise ValidationError unless `value` is a bytes instance."""
        if not isinstance(value, bytes):
            raise ValidationError("Invalid input type.")

PolymorphicSchema

Bases: BaseSchema

PolymorphicSchema is based on https://github.com/marshmallow-code/marshmallow-oneofschema.

Source code in griptape/schemas/polymorphic_schema.py
class PolymorphicSchema(BaseSchema):
    """PolymorphicSchema is based on https://github.com/marshmallow-code/marshmallow-oneofschema.

    Dumping writes the object's class name under `type_field`; loading reads that
    key and asks `inner_class.get_schema` for the schema to load the rest with.
    """

    def __init__(self, inner_class: Any, **kwargs) -> None:
        """Initialize the schema.

        Args:
            inner_class: Object whose `get_schema` is used to resolve schemas on load.
            **kwargs: Forwarded to the parent initializer.
        """
        super().__init__(**kwargs)

        self.inner_class = inner_class

    # Key that carries the type name in serialized data.
    type_field = "type"
    # When True, `get_data_type` removes `type_field` from the data it inspects.
    type_field_remove = True

    def get_obj_type(self, obj: Any) -> Any:
        """Returns name of the schema during dump() calls, given the object being dumped."""
        return obj.__class__.__name__

    def get_data_type(self, data: Any) -> Any:
        """Returns name of the schema during load() calls, given the data being loaded. Defaults to looking up `type_field` in the data."""
        data_type = data.get(self.type_field)
        if self.type_field in data and self.type_field_remove:
            data.pop(self.type_field)
        return data_type

    def dump(self, obj: Any, *, many: Any = None, **kwargs) -> Any:
        """Serialize one object, or an iterable of objects when `many` is truthy.

        Args:
            obj: The object (or iterable of objects) to dump.
            many: Overrides `self.many` when not None.
            **kwargs: Forwarded to `_dump`.

        Returns:
            The dumped item, or a list of dumped items when `many` is truthy.

        Raises:
            ValidationError: In `many` mode, when at least one item failed; messages are keyed by item index.
        """
        result_data = []
        result_errors = {}
        many = self.many if many is None else bool(many)
        if not many:
            # Errors from a single item propagate directly.
            result_data = self._dump(obj, **kwargs)
        else:
            for idx, o in enumerate(obj):
                try:
                    result_data.append(self._dump(o, **kwargs))
                except ValidationError as error:
                    result_errors[idx] = error.normalized_messages()
                    result_data.append(error.valid_data)

        if not result_errors:
            return result_data

        raise ValidationError(result_errors, data=obj, valid_data=result_data)  # pyright: ignore[reportArgumentType]

    def _dump(self, obj: Any, *, update_fields: bool = True, **kwargs) -> Any:
        """Dump a single object with a schema generated from its class and add the type name.

        Args:
            obj: The object to dump.
            update_fields: Accepted but not used in this method.
            **kwargs: Forwarded to the generated schema's `dump`.
        """
        obj_type = self.get_obj_type(obj)

        # NOTE(review): this branch and the one below return an (result, errors) tuple
        # instead of raising; with the `get_obj_type` defined on this class the first
        # cannot trigger because a class name is never empty.
        if not obj_type:
            return (None, {"_schema": f"Unknown object class: {obj.__class__.__name__}"})

        type_schema = BaseSchema.from_attrs_cls(obj.__class__)

        if not type_schema:
            return None, {"_schema": f"Unsupported object type: {obj_type}"}

        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        schema.context.update(getattr(self, "context", {}))

        result = schema.dump(obj, many=False, **kwargs)

        if result is not None:
            result[self.type_field] = obj_type  # pyright: ignore[reportArgumentType,reportCallIssue]

        return result

    def load(self, data: Any, *, many: Any = None, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
        """Deserialize one item, or an iterable of items when `many` is truthy.

        Args:
            data: The data (or iterable of data) to load.
            many: Overrides `self.many` when not None.
            partial: Overrides `self.partial` when not None.
            unknown: Forwarded to `_load` for every item.
            **kwargs: Forwarded to `_load`.

        Returns:
            The loaded item, or a list of loaded items when `many` is truthy.

        Raises:
            ValidationError: If any item failed to load.
        """
        result_data = []
        result_errors = {}
        many = self.many if many is None else bool(many)
        if partial is None:
            partial = self.partial
        if not many:
            try:
                result_data = self._load(data, partial=partial, unknown=unknown, **kwargs)
            except ValidationError as error:
                result_errors = error.normalized_messages()
                # `result_data` is still the initial list here, so valid data is reported inside a list.
                result_data.append(error.valid_data)
        else:
            for idx, item in enumerate(data):
                try:
                    # Forward `unknown` here too so both modes honour the caller's setting.
                    result_data.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
                except ValidationError as error:
                    result_errors[idx] = error.normalized_messages()
                    result_data.append(error.valid_data)

        if not result_errors:
            return result_data

        raise ValidationError(result_errors, data=data, valid_data=result_data)

    def _load(self, data: Any, *, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
        """Load a single dict using the schema resolved from its type field.

        Args:
            data: The dict to load; it is copied before being modified.
            partial: Forwarded to the resolved schema's `load`.
            unknown: Forwarded to the resolved schema's `load`; falls back to `self.unknown` when falsy.
            **kwargs: Forwarded to the resolved schema's `load`.

        Raises:
            ValidationError: If `data` is not a dict, has no type field, or names an unsupported type.
        """
        if not isinstance(data, dict):
            raise ValidationError({"_schema": f"Invalid data type: {data}"})

        # Work on a copy because `get_data_type` may pop the type field.
        data = dict(data)
        unknown = unknown or self.unknown
        data_type = self.get_data_type(data)

        if data_type is None:
            raise ValidationError({self.type_field: ["Missing data for required field."]})

        type_schema = self.inner_class.get_schema(data_type, module_name=data.get("module_name"))
        if not type_schema:
            raise ValidationError({self.type_field: [f"Unsupported value: {data_type}"]})

        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        schema.context.update(getattr(self, "context", {}))

        return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs)

    def validate(self, data: Any, *, many: Any = None, partial: Any = None) -> Any:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Return the validation messages produced by loading `data`, or an empty dict when it loads cleanly."""
        try:
            self.load(data, many=many, partial=partial)
        except ValidationError as ve:
            return ve.messages
        return {}

inner_class = inner_class instance-attribute

type_field = 'type' class-attribute instance-attribute

type_field_remove = True class-attribute instance-attribute

__init__(inner_class, **kwargs)

Source code in griptape/schemas/polymorphic_schema.py
def __init__(self, inner_class: Any, **kwargs) -> None:
    """Initialize the schema and store `inner_class` on the instance.

    Args:
        inner_class: Kept as `self.inner_class`.
        **kwargs: Forwarded unchanged to the parent initializer.
    """
    super().__init__(**kwargs)

    self.inner_class = inner_class

dump(obj, *, many=None, **kwargs)

Source code in griptape/schemas/polymorphic_schema.py
def dump(self, obj: Any, *, many: Any = None, **kwargs) -> Any:
    """Serialize a single object, or every object of an iterable when `many` is truthy.

    In single mode the result of `_dump` is returned as is. In `many` mode each
    item is dumped in turn; failures are collected per index and raised together
    as one ValidationError once all items have been processed.
    """
    use_many = self.many if many is None else bool(many)

    # Single item: nothing is caught here, so any error propagates to the caller.
    if not use_many:
        return self._dump(obj, **kwargs)

    dumped = []
    failures = {}
    for position, item in enumerate(obj):
        try:
            dumped.append(self._dump(item, **kwargs))
        except ValidationError as error:
            failures[position] = error.normalized_messages()
            dumped.append(error.valid_data)

    if failures:
        raise ValidationError(failures, data=obj, valid_data=dumped)  # pyright: ignore[reportArgumentType]
    return dumped

get_data_type(data)

Returns name of the schema during load() calls, given the data being loaded. Defaults to looking up type_field in the data.

Source code in griptape/schemas/polymorphic_schema.py
def get_data_type(self, data: Any) -> Any:
    """Return the schema name stored under `type_field` in `data` (None when absent).

    When `type_field_remove` is truthy and the key is present, the key is also
    removed from `data`.
    """
    key = self.type_field
    declared_type = data.get(key)
    should_strip = key in data and self.type_field_remove
    if should_strip:
        data.pop(key)
    return declared_type

get_obj_type(obj)

Returns name of the schema during dump() calls, given the object being dumped.

Source code in griptape/schemas/polymorphic_schema.py
def get_obj_type(self, obj: Any) -> Any:
    """Return the class name of `obj`, used as the schema name during dump() calls."""
    obj_class = obj.__class__
    return obj_class.__name__

load(data, *, many=None, partial=None, unknown=None, **kwargs)

Source code in griptape/schemas/polymorphic_schema.py
def load(self, data: Any, *, many: Any = None, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
    """Deserialize one item, or an iterable of items when `many` is truthy.

    Args:
        data: The data (or iterable of data) to load.
        many: Overrides `self.many` when not None.
        partial: Overrides `self.partial` when not None.
        unknown: Forwarded to `_load` for every item.
        **kwargs: Forwarded to `_load`.

    Returns:
        The loaded item, or a list of loaded items when `many` is truthy.

    Raises:
        ValidationError: If any item failed to load.
    """
    result_data = []
    result_errors = {}
    many = self.many if many is None else bool(many)
    if partial is None:
        partial = self.partial
    if not many:
        try:
            result_data = self._load(data, partial=partial, unknown=unknown, **kwargs)
        except ValidationError as error:
            result_errors = error.normalized_messages()
            # `result_data` is still the initial list here, so valid data is reported inside a list.
            result_data.append(error.valid_data)
    else:
        for idx, item in enumerate(data):
            try:
                # Forward `unknown` here too so both modes honour the caller's setting.
                result_data.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
            except ValidationError as error:
                result_errors[idx] = error.normalized_messages()
                result_data.append(error.valid_data)

    if not result_errors:
        return result_data

    raise ValidationError(result_errors, data=data, valid_data=result_data)

validate(data, *, many=None, partial=None)

Source code in griptape/schemas/polymorphic_schema.py
def validate(self, data: Any, *, many: Any = None, partial: Any = None) -> Any:  # pyright: ignore[reportIncompatibleMethodOverride]
    """Load `data` and report the outcome: the ValidationError messages on failure, an empty dict on success."""
    try:
        self.load(data, many=many, partial=partial)
    except ValidationError as failure:
        return failure.messages
    else:
        return {}

Union

Bases: Field

Field that accepts any one of multiple fields.

Source: https://github.com/adamboche/python-marshmallow-union. Each candidate field will be tried in turn until one succeeds.

Parameters:

Name Type Description Default
fields list[Field]

The list of candidate fields to try.

required
reverse_serialize_candidates bool

Whether to try the candidates in reverse order when serializing.

False
Source code in griptape/schemas/union_field.py
class Union(marshmallow.fields.Field):
    """Field that delegates to the first of several candidate fields that succeeds.

    Source: https://github.com/adamboche/python-marshmallow-union
    Candidates are tried one after another until one of them works.

    Args:
        fields: The list of candidate fields to try.
        reverse_serialize_candidates: Whether to try the candidates in reverse order when serializing.
    """

    def __init__(
        self,
        fields: list[marshmallow.fields.Field],
        *,
        reverse_serialize_candidates: bool = False,
        **kwargs: Any,
    ) -> None:
        self._candidate_fields = fields
        self._reverse_serialize_candidates = reverse_serialize_candidates
        super().__init__(**kwargs)

    def _serialize(self, value: Any, attr: str | None, obj: str, **kwargs: Any) -> Any:
        """Serialize ``value`` with the first candidate field that accepts it.

        Args:
            value: The value to be serialized.
            attr: The attribute or key to get from the object.
            obj: The object to pull the key from.
            kwargs: Field-specific keyword arguments; an ``error_store`` entry, if given, collects the failures.

        Raises:
            ExceptionGroupError: If every candidate raised TypeError or ValueError.
        """
        store = kwargs.pop("error_store", marshmallow.error_store.ErrorStore())

        candidates = self._candidate_fields
        if self._reverse_serialize_candidates:
            candidates = list(reversed(candidates))

        for candidate in candidates:
            try:
                # pylint: disable=protected-access
                return candidate._serialize(value, attr, obj, error_store=store, **kwargs)
            except (TypeError, ValueError) as failure:
                # Record the failure and move on to the next candidate
                store.store_error({attr: str(failure)})

        raise ExceptionGroupError("All serializers raised exceptions.", store.errors)

    def _deserialize(self, value: Any, attr: str | None = None, data: Any = None, **kwargs: Any) -> Any:
        """Deserialize ``value`` with the first candidate field that accepts it.

        Args:
            value: The value to be deserialized.
            attr: The attribute/key in `data` to be deserialized.
            data: The raw input data passed to the `Schema.load`.
            kwargs: Field-specific keyword arguments.

        Raises:
            ValidationError: If every candidate rejected the value; carries the messages of each attempt.
        """
        collected_messages = []
        for candidate in self._candidate_fields:
            try:
                return candidate.deserialize(value, attr, data, **kwargs)
            except marshmallow.exceptions.ValidationError as failure:
                collected_messages.append(failure.messages)

        raise marshmallow.exceptions.ValidationError(message=collected_messages, field_name=attr or "")

__init__(fields, *, reverse_serialize_candidates=False, **kwargs)

Source code in griptape/schemas/union_field.py
def __init__(
    self,
    fields: list[marshmallow.fields.Field],
    *,
    reverse_serialize_candidates: bool = False,
    **kwargs: Any,
) -> None:
    """Store the candidate fields and the serialization-order flag, then initialize the parent.

    Args:
        fields: Kept as `self._candidate_fields`.
        reverse_serialize_candidates: Kept as `self._reverse_serialize_candidates`.
        **kwargs: Forwarded unchanged to the parent initializer.
    """
    # Both attributes are assigned before the parent initializer runs.
    self._candidate_fields = fields
    self._reverse_serialize_candidates = reverse_serialize_candidates
    super().__init__(**kwargs)