Skip to content

base_schema

BaseSchema

Bases: Schema

Source code in griptape/schemas/base_schema.py
class BaseSchema(Schema):
    """Base marshmallow Schema with support for generating schemas from attrs classes.

    Fields are derived from the type annotations of an attrs class's attributes whose
    metadata has a truthy ``serializable`` entry.
    """

    class Meta:
        # marshmallow INCLUDE: keys without a declared field are kept in the loaded data.
        unknown = INCLUDE

    # Python type -> marshmallow field class, extending marshmallow's default TYPE_MAPPING.
    DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw}

    @classmethod
    def from_attrs_cls(cls, attrs_cls: type) -> type:
        """Generate a Schema from an attrs class.

        Only attributes whose metadata has a truthy ``serializable`` entry become fields.
        Each field is keyed by the attribute's alias, falling back to its name. Loading
        data with the returned schema calls ``attrs_cls(**data)``.

        Args:
            attrs_cls: An attrs class. Must be a subclass of SerializableMixin.

        Returns:
            A Schema class named ``f"{attrs_cls.__name__}Schema"``.

        Raises:
            ValueError: If ``attrs_cls`` is not a subclass of SerializableMixin.
        """
        from marshmallow import post_load

        from griptape.mixins.serializable_mixin import SerializableMixin

        # Validate first so no throwaway schema class is created (and registered) on failure.
        if not issubclass(attrs_cls, SerializableMixin):
            raise ValueError(f"Class must implement SerializableMixin: {attrs_cls}")

        class SubSchema(cls):
            @post_load
            def make_obj(self, data: Any, **kwargs) -> Any:
                return attrs_cls(**data)

        # String (forward-reference) annotations must be resolved before inspecting a.type.
        cls._resolve_types(attrs_cls)
        return SubSchema.from_dict(
            {
                a.alias or a.name: cls._get_field_for_type(a.type)
                for a in attrs.fields(attrs_cls)
                if a.metadata.get("serializable")
            },
            name=f"{attrs_cls.__name__}Schema",
        )

    @classmethod
    def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested:
        """Generate a marshmallow Field instance from a Python type.

        Args:
            field_type: A field type.

        Returns:
            A marshmallow field.
            - A TypeVar is replaced by its bound; ``None`` yields ``Constant(None)``.
            - Unions yield a UnionField and attrs classes a Nested field.
            - Enums yield a String and list-like sequences a List.
            - Anything else is looked up in DATACLASS_TYPE_MAPPING, defaulting to fields.Raw.

        Raises:
            ValueError: If a list-like type has no element type.
        """
        from griptape.schemas.polymorphic_schema import PolymorphicSchema

        field_class, args, optional = cls._get_field_type_info(field_type)

        # Resolve TypeVars to their bound type; an unbound TypeVar's __bound__ is None.
        if isinstance(field_class, TypeVar):
            field_class = field_class.__bound__
        if field_class is None:
            return fields.Constant(None, allow_none=True)
        if cls._is_union(field_type):
            return cls._handle_union(field_type, optional=optional)
        elif attrs.has(field_class):
            # Classes that directly subclass ABC use PolymorphicSchema (presumably to pick the
            # concrete subtype at load time); others get a schema generated from their attrs fields.
            schema = PolymorphicSchema if ABC in field_class.__bases__ else cls.from_attrs_cls
            return fields.Nested(schema(field_class), allow_none=optional)
        elif cls._is_enum(field_type):
            return fields.String(allow_none=optional)
        elif cls._is_list_sequence(field_class):
            if args:
                return cls._handle_list(args[0], optional=optional)
            else:
                raise ValueError(f"Missing type for list field: {field_type}")
        field_class = cls.DATACLASS_TYPE_MAPPING.get(field_class, fields.Raw)

        return field_class(allow_none=optional)

    @classmethod
    def _handle_list(cls, list_type: type, *, optional: bool) -> fields.Field:
        """Handle List Fields, including Union Types.

        Args:
            list_type: The type of the list's elements.
            optional: Whether the List can be None.

        Returns:
            A marshmallow List field.

        Raises:
            ValueError: If the element type maps to a None constant.
        """
        if cls._is_union(list_type):
            union_field = cls._handle_union(list_type, optional=optional)
            return fields.List(cls_or_instance=union_field, allow_none=optional)
        list_field = cls._get_field_for_type(list_type)
        if isinstance(list_field, fields.Constant) and list_field.constant is None:
            raise ValueError(f"List elements cannot be None: {list_type}")
        return fields.List(cls_or_instance=list_field, allow_none=optional)

    @classmethod
    def _handle_union(cls, union_type: type, *, optional: bool) -> fields.Field:
        """Handle Union Fields, including Unions with List Types.

        Args:
            union_type: The Union Type to handle.
            optional: Whether the Union can be None. The field is also made nullable when
                the union itself contains ``NoneType``.

        Returns:
            A marshmallow Union field with one candidate field per non-None member.

        Raises:
            ValueError: If the union has no non-None members.
        """
        members = get_args(union_type)
        candidate_fields = [cls._get_field_for_type(arg) for arg in members if arg is not type(None)]
        if not candidate_fields:
            raise ValueError(f"Unsupported UnionType field: {union_type}")
        # get_args() reports None as NoneType; only its presence makes the union nullable.
        optional = optional or any(arg is type(None) for arg in members)

        return UnionField(fields=candidate_fields, allow_none=optional)

    @classmethod
    def _get_field_type_info(cls, field_type: type) -> tuple[type, tuple[type, ...], bool]:
        """Get information about a field type.

        Args:
            field_type: A field type.

        Returns:
            A tuple ``(origin, args, optional)``.
            - For a Union, ``origin``/``args`` describe its first non-None member, and
              ``optional`` is True when NoneType appears anywhere in the union.
            - For a Literal, ``origin`` is the type of its first value and ``args`` is empty.
        """
        origin = get_origin(field_type) or field_type
        args = get_args(field_type)
        optional = False

        if origin is Union:
            non_none = [arg for arg in args if arg is not type(None)]
            # None may appear in any position, e.g. Union[None, X] or Union[A, B, None].
            optional = len(non_none) < len(args)
            # Fall back to args[0] so malformed input fails exactly as before.
            origin, args, _ = cls._get_field_type_info(non_none[0] if non_none else args[0])
        elif origin is Literal:
            origin = type(args[0])
            args = ()

        return origin, args, optional

    @classmethod
    def _resolve_types(cls, attrs_cls: type) -> None:
        """Resolve types in an attrs class.

        String (forward-reference) annotations are evaluated against a namespace of griptape
        types. Optional third-party names fall back to ``Any`` when the dependency is not
        installed.

        Args:
            attrs_cls: An attrs class.
        """
        from collections.abc import Sequence
        from typing import Any

        from griptape.artifacts import BaseArtifact
        from griptape.common import (
            BaseDeltaMessageContent,
            BaseMessageContent,
            Message,
            PromptStack,
            Reference,
            ToolAction,
        )
        from griptape.drivers import (
            BaseAudioTranscriptionDriver,
            BaseConversationMemoryDriver,
            BaseEmbeddingDriver,
            BaseImageGenerationDriver,
            BasePromptDriver,
            BaseRulesetDriver,
            BaseTextToSpeechDriver,
            BaseVectorStoreDriver,
        )
        from griptape.events import EventListener
        from griptape.memory import TaskMemory
        from griptape.memory.structure import BaseConversationMemory, Run
        from griptape.memory.task.storage import BaseArtifactStorage
        from griptape.structures import Structure
        from griptape.tasks import BaseTask
        from griptape.tokenizers import BaseTokenizer
        from griptape.tools import BaseTool
        from griptape.utils import import_optional_dependency, is_dependency_installed

        attrs.resolve_types(
            attrs_cls,
            localns={
                "Any": Any,
                "BasePromptDriver": BasePromptDriver,
                "BaseEmbeddingDriver": BaseEmbeddingDriver,
                "BaseVectorStoreDriver": BaseVectorStoreDriver,
                "BaseTextToSpeechDriver": BaseTextToSpeechDriver,
                "BaseAudioTranscriptionDriver": BaseAudioTranscriptionDriver,
                "BaseConversationMemoryDriver": BaseConversationMemoryDriver,
                "BaseRulesetDriver": BaseRulesetDriver,
                "BaseImageGenerationDriver": BaseImageGenerationDriver,
                "BaseArtifact": BaseArtifact,
                "PromptStack": PromptStack,
                "EventListener": EventListener,
                "BaseMessageContent": BaseMessageContent,
                "BaseDeltaMessageContent": BaseDeltaMessageContent,
                "BaseTool": BaseTool,
                "BaseTask": BaseTask,
                "Usage": Message.Usage,
                "Structure": Structure,
                "BaseTokenizer": BaseTokenizer,
                "ToolAction": ToolAction,
                "Reference": Reference,
                "Run": Run,
                "Sequence": Sequence,
                "TaskMemory": TaskMemory,
                "State": BaseTask.State,
                "BaseConversationMemory": BaseConversationMemory,
                "BaseArtifactStorage": BaseArtifactStorage,
                # Third party modules
                "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any,
                "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel
                if is_dependency_installed("google.generativeai")
                else Any,
                "boto3": import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any,
                "Anthropic": import_optional_dependency("anthropic").Anthropic
                if is_dependency_installed("anthropic")
                else Any,
                "BedrockClient": import_optional_dependency("mypy_boto3_bedrock").BedrockClient
                if is_dependency_installed("mypy_boto3_bedrock")
                else Any,
                "voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any,
            },
        )

    @classmethod
    def _is_list_sequence(cls, field_type: type | _SpecialForm) -> bool:
        """Return True if ``field_type`` is a Sequence class other than str, bytes, or tuple."""
        if not isinstance(field_type, type):
            return False
        # str, bytes and tuple are Sequences but are excluded from list handling.
        if issubclass(field_type, (str, bytes, tuple)):
            return False
        return issubclass(field_type, Sequence)

    @classmethod
    def _is_union(cls, field_type: type) -> bool:
        """Return True for ``typing.Union`` itself or any subscripted ``Union[...]``."""
        return field_type is Union or get_origin(field_type) is Union

    @classmethod
    def _is_enum(cls, field_type: type) -> bool:
        """Return True if ``field_type`` is an Enum subclass."""
        return isinstance(field_type, type) and issubclass(field_type, Enum)

DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw} class-attribute instance-attribute

Meta

Source code in griptape/schemas/base_schema.py
class Meta:
    """marshmallow schema options: unknown input keys are included in the loaded data."""

    unknown = INCLUDE
unknown = INCLUDE class-attribute instance-attribute

from_attrs_cls(attrs_cls) classmethod

Generate a Schema from an attrs class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `attrs_cls` | `type` | An attrs class. | required |
Source code in griptape/schemas/base_schema.py
@classmethod
def from_attrs_cls(cls, attrs_cls: type) -> type:
    """Build a marshmallow Schema class for an attrs class.

    Every attribute whose metadata marks it ``serializable`` becomes a field. The field is
    keyed by the attribute's alias when set, otherwise by its name. The returned schema's
    post-load hook instantiates ``attrs_cls`` from the loaded data.

    Args:
        attrs_cls: An attrs class implementing SerializableMixin.

    Returns:
        The generated Schema class, named after ``attrs_cls`` with a ``Schema`` suffix.

    Raises:
        ValueError: When ``attrs_cls`` does not subclass SerializableMixin.
    """
    from marshmallow import post_load

    from griptape.mixins.serializable_mixin import SerializableMixin

    class SubSchema(cls):
        @post_load
        def make_obj(self, data: Any, **kwargs) -> Any:
            return attrs_cls(**data)

    if not issubclass(attrs_cls, SerializableMixin):
        raise ValueError(f"Class must implement SerializableMixin: {attrs_cls}")

    cls._resolve_types(attrs_cls)

    declared_fields = {}
    for attribute in attrs.fields(attrs_cls):
        if not attribute.metadata.get("serializable"):
            continue
        declared_fields[attribute.alias or attribute.name] = cls._get_field_for_type(attribute.type)

    schema_name = f"{attrs_cls.__name__}Schema"
    return SubSchema.from_dict(declared_fields, name=schema_name)