Skip to content

anthropic

__all__ = ['AnthropicPromptDriver'] module-attribute

AnthropicPromptDriver

Bases: BasePromptDriver

Anthropic Prompt Driver.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| api_key | Optional[str] | Anthropic API key. |
| model | str | Anthropic model name. |
| client | Client | Custom Anthropic client. |

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
@define
class AnthropicPromptDriver(BasePromptDriver):
    """Anthropic Prompt Driver.

    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
        tokenizer: Tokenizer whose `stop_sequences` are sent with every request. Defaults to an
            `AnthropicTokenizer` built for `model`.
        top_p: Value sent as the `top_p` request parameter.
        top_k: Value sent as the `top_k` request parameter.
        tool_choice: Value sent as `tool_choice` when native tools are attached to the request.
        use_native_tools: Whether the prompt stack's tools are sent as native tool definitions.
        structured_output_strategy: Structured output strategy. `native` is rejected by the validator.
        max_tokens: Value sent as the `max_tokens` request parameter.
    """

    api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
    top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
    # Defaults to None, hence Optional. Exposed to callers through the `client` alias.
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Client:
        """Build an `Anthropic` client from `api_key`.

        NOTE(review): presumably `lazy_property` only calls this when no client was supplied
        and stores the result in `_client` -- confirm against the decorator.
        """
        return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Reject the `native` structured output strategy.

        Args:
            _: The attribute being validated (unused).
            value: The strategy assigned to `structured_output_strategy`.

        Returns:
            `value` unchanged when it is accepted.

        Raises:
            ValueError: If `value` is `"native"`.
        """
        if value == "native":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Send the prompt stack in a single request and return the assistant message.

        Args:
            prompt_stack: The prompt stack to convert into request parameters.

        Returns:
            An assistant `Message` holding the converted response content and token usage.

        Raises:
            ValueError: If the response contains a content block type other than `text` or `tool_use`.
        """
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = self.client.messages.create(**params)

        logger.debug(response.model_dump())

        return Message(
            content=[self.__to_prompt_stack_message_content(content) for content in response.content],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Send the prompt stack with `stream=True` and yield a delta per relevant event.

        Args:
            prompt_stack: The prompt stack to convert into request parameters.

        Yields:
            Content deltas for `content_block_start` / `content_block_delta` events and usage
            deltas for `message_start` (input tokens) and `message_delta` (output tokens) events.
            Every other event type is logged and skipped.
        """
        params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug(params)
        events = self.client.messages.create(**params)

        for event in events:
            logger.debug(event)
            if event.type in ("content_block_delta", "content_block_start"):
                yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
            elif event.type == "message_start":
                yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=event.message.usage.input_tokens))
            elif event.type == "message_delta":
                yield DeltaMessage(usage=DeltaMessage.Usage(output_tokens=event.usage.output_tokens))

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments shared by `try_run` and `try_stream`.

        Args:
            prompt_stack: The prompt stack supplying messages, tools and the output schema.

        Returns:
            A new dict of request parameters.
        """
        # System messages are not part of `messages`; they travel in the separate `system` key.
        messages = self.__to_anthropic_messages([i for i in prompt_stack.messages if not i.is_system()])

        # Only the first system message is forwarded.
        system_messages = prompt_stack.system_messages
        system_message = system_messages[0].to_text() if system_messages else None

        params = {
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
            "messages": messages,
            **({"system": system_message} if system_message else {}),
            # Spread last so `extra_params` overrides any key set above.
            **self.extra_params,
        }

        if prompt_stack.tools and self.use_native_tools:
            params["tool_choice"] = self.tool_choice

            # With an output schema and the `tool` strategy, `tool_choice` is switched to "any".
            if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
                params["tool_choice"] = {"type": "any"}

            params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)

        return params

    def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
        """Convert each message into a `{"role": ..., "content": ...}` dict, preserving order."""
        return [
            {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}
            for message in messages
        ]

    def __to_anthropic_role(self, message: Message) -> str:
        """Return `"assistant"` for assistant messages and `"user"` for every other message."""
        if message.is_assistant():
            return "assistant"
        else:
            return "user"

    def __to_anthropic_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Build one tool definition per activity of each tool.

        Args:
            tools: The tools whose activities are exposed.

        Returns:
            A list of dicts with `name`, `description` and `input_schema` keys.
        """
        tool_schemas = [
            {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
                "input_schema": tool.to_activity_json_schema(activity, "Input Schema"),
            }
            for tool in tools
            for activity in tool.activities()
        ]

        # Anthropic doesn't support $schema and $id
        # `pop` with a default tolerates schemas that never carried these keys.
        for tool_schema in tool_schemas:
            tool_schema["input_schema"].pop("$schema", None)
            tool_schema["input_schema"].pop("$id", None)

        return tool_schemas

    def __to_anthropic_content(self, message: Message) -> str | list[dict]:
        """Return the message text when every content item is text, else a list of content blocks."""
        if message.has_all_content_type(TextMessageContent):
            return message.to_text()
        else:
            return [self.__to_anthropic_message_content(content) for content in message.content]

    def __to_anthropic_message_content(self, content: BaseMessageContent) -> dict:
        """Convert a single message content item into a content block.

        Text, image, action-call and action-result content map to `text`, `image`, `tool_use`
        and `tool_result` blocks respectively. Any other content type falls through and returns
        `content.artifact.value` as-is.

        Raises:
            ValueError: If an action result holds an artifact type that cannot be converted.
        """
        if isinstance(content, TextMessageContent):
            return {"type": "text", "text": content.artifact.value}
        elif isinstance(content, ImageMessageContent):
            artifact = content.artifact

            return {
                "type": "image",
                "source": {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64},
            }
        elif isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return {"type": "tool_use", "id": action.tag, "name": action.to_native_tool_name(), "input": action.input}
        elif isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            # A list artifact expands to one block per element; anything else yields one block.
            if isinstance(artifact, ListArtifact):
                message_content = [self.__to_anthropic_tool_result_content(item) for item in artifact.value]
            else:
                message_content = [self.__to_anthropic_tool_result_content(artifact)]

            return {
                "type": "tool_result",
                "tool_use_id": content.action.tag,
                "content": message_content,
                # Checked on the outer artifact, not on the elements of a list artifact.
                "is_error": isinstance(artifact, ErrorArtifact),
            }
        else:
            return content.artifact.value

    def __to_anthropic_tool_result_content(self, artifact: BaseArtifact) -> dict:
        """Convert one tool result artifact into an `image` or `text` block.

        Raises:
            ValueError: If `artifact` is not an image, text, error or info artifact.
        """
        if isinstance(artifact, ImageArtifact):
            return {
                "type": "image",
                "source": {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64},
            }
        elif isinstance(artifact, (TextArtifact, ErrorArtifact, InfoArtifact)):
            return {"type": "text", "text": artifact.to_text()}
        else:
            raise ValueError(f"Unsupported tool result artifact type: {type(artifact)}")

    def __to_prompt_stack_message_content(self, content: ContentBlock) -> BaseMessageContent:
        """Convert a response content block into prompt stack message content.

        Raises:
            ValueError: If the block type is neither `text` nor `tool_use`.
        """
        if content.type == "text":
            return TextMessageContent(TextArtifact(content.text))
        elif content.type == "tool_use":
            name, path = ToolAction.from_native_tool_name(content.name)

            return ActionCallMessageContent(
                artifact=ActionArtifact(
                    value=ToolAction(tag=content.id, name=name, path=path, input=content.input),  # pyright: ignore[reportArgumentType]
                ),
            )
        else:
            raise ValueError(f"Unsupported message content type: {content.type}")

    def __to_prompt_stack_delta_message_content(
        self,
        event: ContentBlockDeltaEvent | ContentBlockStartEvent,
    ) -> BaseDeltaMessageContent:
        """Convert a streaming content event into delta message content.

        A `content_block_start` event carries the tool call's tag/name/path or the initial text.
        A `content_block_delta` event carries a text fragment or a partial JSON input fragment.

        Raises:
            ValueError: If the event, block or delta type is not one of the handled types.
        """
        if event.type == "content_block_start":
            content_block = event.content_block

            if content_block.type == "tool_use":
                name, path = ToolAction.from_native_tool_name(content_block.name)

                return ActionCallDeltaMessageContent(index=event.index, tag=content_block.id, name=name, path=path)
            elif content_block.type == "text":
                return TextDeltaMessageContent(content_block.text, index=event.index)
            else:
                raise ValueError(f"Unsupported content block type: {content_block.type}")
        elif event.type == "content_block_delta":
            content_block_delta = event.delta

            if content_block_delta.type == "text_delta":
                return TextDeltaMessageContent(content_block_delta.text, index=event.index)
            elif content_block_delta.type == "input_json_delta":
                return ActionCallDeltaMessageContent(index=event.index, partial_input=content_block_delta.partial_json)
            else:
                raise ValueError(f"Unsupported message content type: {event}")
        else:
            raise ValueError(f"Unsupported message content type: {event}")

api_key: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

max_tokens: int = field(default=1000, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structured_output_strategy: StructuredOutputStrategy = field(default='tool', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

tool_choice: dict = field(default=Factory(lambda: {'type': 'auto'}), kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

top_k: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_p: float = field(default=0.999, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

use_native_tools: bool = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_anthropic_content(message)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_content(self, message: Message) -> str | list[dict]:
    """Return the message text when every content item is text, else a list of content blocks."""
    text_only = message.has_all_content_type(TextMessageContent)
    if not text_only:
        # Mixed content: convert each item into its own block.
        return [self.__to_anthropic_message_content(part) for part in message.content]

    return message.to_text()

__to_anthropic_message_content(content)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_message_content(self, content: BaseMessageContent) -> dict:
    """Convert a single message content item into a content block.

    Text, image, action-call and action-result content map to `text`, `image`, `tool_use`
    and `tool_result` blocks respectively. Any other content type returns
    `content.artifact.value` as-is.
    """
    if isinstance(content, TextMessageContent):
        return {"type": "text", "text": content.artifact.value}

    if isinstance(content, ImageMessageContent):
        image = content.artifact
        source = {"type": "base64", "media_type": image.mime_type, "data": image.base64}
        return {"type": "image", "source": source}

    if isinstance(content, ActionCallMessageContent):
        call = content.artifact.value
        return {"type": "tool_use", "id": call.tag, "name": call.to_native_tool_name(), "input": call.input}

    if isinstance(content, ActionResultMessageContent):
        result = content.artifact
        # A list artifact expands to one block per element; anything else yields one block.
        parts = result.value if isinstance(result, ListArtifact) else [result]
        blocks = [self.__to_anthropic_tool_result_content(part) for part in parts]

        return {
            "type": "tool_result",
            "tool_use_id": content.action.tag,
            "content": blocks,
            # Checked on the outer artifact, not on the elements of a list artifact.
            "is_error": isinstance(result, ErrorArtifact),
        }

    return content.artifact.value

__to_anthropic_messages(messages)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
    """Convert each message into a `{"role": ..., "content": ...}` dict, preserving order."""
    converted = []
    for msg in messages:
        role = self.__to_anthropic_role(msg)
        body = self.__to_anthropic_content(msg)
        converted.append({"role": role, "content": body})

    return converted

__to_anthropic_role(message)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_role(self, message: Message) -> str:
    """Return `"assistant"` for assistant messages and `"user"` for every other message."""
    return "assistant" if message.is_assistant() else "user"

__to_anthropic_tool_result_content(artifact)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_tool_result_content(self, artifact: BaseArtifact) -> dict:
    """Convert one tool result artifact into an `image` or `text` block.

    Raises:
        ValueError: If `artifact` is not an image, text, error or info artifact.
    """
    if isinstance(artifact, ImageArtifact):
        source = {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64}
        return {"type": "image", "source": source}

    textual_types = (TextArtifact, ErrorArtifact, InfoArtifact)
    if not isinstance(artifact, textual_types):
        raise ValueError(f"Unsupported tool result artifact type: {type(artifact)}")

    return {"type": "text", "text": artifact.to_text()}

__to_anthropic_tools(tools)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Build one tool definition per activity of each tool.

    Args:
        tools: The tools whose activities are exposed.

    Returns:
        A list of dicts with `name`, `description` and `input_schema` keys.
    """
    tool_schemas = [
        {
            "name": tool.to_native_tool_name(activity),
            "description": tool.activity_description(activity),
            "input_schema": tool.to_activity_json_schema(activity, "Input Schema"),
        }
        for tool in tools
        for activity in tool.activities()
    ]

    # Anthropic doesn't support $schema and $id
    # `pop` with a default tolerates schemas that never carried these keys,
    # where `del` would raise KeyError.
    for tool_schema in tool_schemas:
        tool_schema["input_schema"].pop("$schema", None)
        tool_schema["input_schema"].pop("$id", None)

    return tool_schemas

__to_prompt_stack_delta_message_content(event)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_prompt_stack_delta_message_content(
    self,
    event: ContentBlockDeltaEvent | ContentBlockStartEvent,
) -> BaseDeltaMessageContent:
    """Convert a streaming content event into delta message content.

    A `content_block_start` event carries the tool call's tag/name/path or the initial text.
    A `content_block_delta` event carries a text fragment or a partial JSON input fragment.

    Raises:
        ValueError: If the event, block or delta type is not one of the handled types.
    """
    if event.type == "content_block_start":
        block = event.content_block

        if block.type == "tool_use":
            name, path = ToolAction.from_native_tool_name(block.name)
            return ActionCallDeltaMessageContent(index=event.index, tag=block.id, name=name, path=path)

        if block.type == "text":
            return TextDeltaMessageContent(block.text, index=event.index)

        raise ValueError(f"Unsupported content block type: {block.type}")

    if event.type == "content_block_delta":
        delta = event.delta

        if delta.type == "text_delta":
            return TextDeltaMessageContent(delta.text, index=event.index)

        if delta.type == "input_json_delta":
            return ActionCallDeltaMessageContent(index=event.index, partial_input=delta.partial_json)

    # Reached for an unhandled event type or an unhandled delta type.
    raise ValueError(f"Unsupported message content type: {event}")

__to_prompt_stack_message_content(content)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_prompt_stack_message_content(self, content: ContentBlock) -> BaseMessageContent:
    """Convert a response content block into prompt stack message content.

    Raises:
        ValueError: If the block type is neither `text` nor `tool_use`.
    """
    if content.type == "text":
        return TextMessageContent(TextArtifact(content.text))

    if content.type != "tool_use":
        raise ValueError(f"Unsupported message content type: {content.type}")

    # The native tool name is split back into the tool's name and activity path.
    name, path = ToolAction.from_native_tool_name(content.name)
    action = ToolAction(tag=content.id, name=name, path=path, input=content.input)  # pyright: ignore[reportArgumentType]

    return ActionCallMessageContent(artifact=ActionArtifact(value=action))

client()

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
@lazy_property()
def client(self) -> Client:
    """Build an `Anthropic` client from `api_key`.

    NOTE(review): presumably `lazy_property` only calls this when no client was supplied
    and caches the result -- confirm against the decorator.
    """
    anthropic_module = import_optional_dependency("anthropic")
    return anthropic_module.Anthropic(api_key=self.api_key)

try_run(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send the prompt stack in a single request and return the assistant message.

    Args:
        prompt_stack: The prompt stack to convert into request parameters.

    Returns:
        An assistant `Message` holding the converted response content and token usage.
    """
    request = self._base_params(prompt_stack)
    logger.debug(request)
    response = self.client.messages.create(**request)

    logger.debug(response.model_dump())

    # Convert content first, then read usage, mirroring the order of the response handling.
    converted = [self.__to_prompt_stack_message_content(block) for block in response.content]
    usage = Message.Usage(
        input_tokens=response.usage.input_tokens,
        output_tokens=response.usage.output_tokens,
    )

    return Message(content=converted, role=Message.ASSISTANT_ROLE, usage=usage)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Send the prompt stack with `stream=True` and yield a delta per relevant event.

    Args:
        prompt_stack: The prompt stack to convert into request parameters.

    Yields:
        Content deltas for `content_block_start` / `content_block_delta` events and usage
        deltas for `message_start` (input tokens) and `message_delta` (output tokens) events.
        Every other event type is logged and skipped.
    """
    request = dict(self._base_params(prompt_stack), stream=True)
    logger.debug(request)
    stream = self.client.messages.create(**request)

    for event in stream:
        logger.debug(event)
        kind = event.type

        if kind in ("content_block_delta", "content_block_start"):
            yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
        elif kind == "message_start":
            input_usage = DeltaMessage.Usage(input_tokens=event.message.usage.input_tokens)
            yield DeltaMessage(usage=input_usage)
        elif kind == "message_delta":
            output_usage = DeltaMessage.Usage(output_tokens=event.usage.output_tokens)
            yield DeltaMessage(usage=output_usage)

validate_structured_output_strategy(_, value)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Reject the `native` structured output strategy.

    Args:
        _: The attribute being validated (unused).
        value: The strategy assigned to `structured_output_strategy`.

    Returns:
        `value` unchanged when it is accepted.

    Raises:
        ValueError: If `value` is `"native"`.
    """
    if value != "native":
        return value

    raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")