
Anthropic prompt driver

AnthropicPromptDriver

Bases: BasePromptDriver

Attributes:

Name        Type                  Description
api_key     Optional[str]         Anthropic API key.
model       str                   Anthropic model name.
client      Any                   Custom Anthropic client.
tokenizer   AnthropicTokenizer    Custom AnthropicTokenizer.
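
A minimal construction sketch (the model name and the environment-variable fallback are illustrative assumptions, not part of this reference):

from griptape.drivers import AnthropicPromptDriver

# Build a driver with an explicit key. If api_key is left as None, the default
# Anthropic client typically reads the ANTHROPIC_API_KEY environment variable
# (SDK behavior, not enforced by this driver).
driver = AnthropicPromptDriver(
    model="claude-3-haiku-20240307",  # any Anthropic model name
    api_key="sk-ant-...",             # excluded from serialization
)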

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
@define
class AnthropicPromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
        tokenizer: Custom `AnthropicTokenizer`.
    """

    api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Any = field(
        default=Factory(
            lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )
    tokenizer: AnthropicTokenizer = field(
        default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
    top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        response = self.client.messages.create(**self._base_params(prompt_stack))

        return TextArtifact(value=response.content[0].text)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        response = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

        for chunk in response:
            if chunk.type == "content_block_delta":
                yield TextArtifact(value=chunk.delta.text)

    def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        messages = [
            {"role": self.__to_anthropic_role(prompt_input), "content": prompt_input.content}
            for prompt_input in prompt_stack.inputs
            if not prompt_input.is_system()
        ]
        system = next((i for i in prompt_stack.inputs if i.is_system()), None)

        if system is None:
            return {"messages": messages}
        else:
            return {"messages": messages, "system": system.content}

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        return {
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_output_tokens(self.prompt_stack_to_string(prompt_stack)),
            "top_p": self.top_p,
            "top_k": self.top_k,
            **self._prompt_stack_to_model_input(prompt_stack),
        }

    def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
        if prompt_input.is_system():
            return "system"
        elif prompt_input.is_assistant():
            return "assistant"
        else:
            return "user"

api_key: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

client: Any = field(default=Factory(lambda self: import_optional_dependency('anthropic').Anthropic(api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: AnthropicTokenizer = field(default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

top_k: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_p: float = field(default=0.999, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_anthropic_role(prompt_input)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
    if prompt_input.is_system():
        return "system"
    elif prompt_input.is_assistant():
        return "assistant"
    else:
        return "user"

try_run(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    response = self.client.messages.create(**self._base_params(prompt_stack))

    return TextArtifact(value=response.content[0].text)
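
A usage sketch, reusing the driver constructed above and assuming PromptStack (with its add_system_input/add_user_input helpers) is importable from griptape.utils in this version:

from griptape.utils import PromptStack

stack = PromptStack()
stack.add_system_input("You are a terse assistant.")  # forwarded as the `system` param
stack.add_user_input("Name one ocean.")

artifact = driver.try_run(stack)  # single, non-streaming completion
print(artifact.value)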

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    response = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

    for chunk in response:
        if chunk.type == "content_block_delta":
            yield TextArtifact(value=chunk.delta.text)
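
A streaming sketch under the same assumptions; each yielded TextArtifact wraps one content_block_delta chunk:

for artifact in driver.try_stream(stack):
    print(artifact.value, end="", flush=True)
print()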