Skip to content

google

__all__ = ['GooglePromptDriver'] module-attribute

GooglePromptDriver

Bases: BasePromptDriver

Google Prompt Driver.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `api_key` | `str \| None` | Google API key. |
| `model` | `str` | Google model name. |
| `client` | `Client` | Custom `google.genai.Client`. |
| `top_p` | `float \| None` | Optional value for top_p. |
| `top_k` | `int \| None` | Optional value for top_k. |

Source code in griptape/drivers/prompt/google_prompt_driver.py
@define
class GooglePromptDriver(BasePromptDriver):
    """Google Prompt Driver.

    Converts a `PromptStack` into `google.genai` request objects, sends it through
    `client.models`, and converts the returned parts back into Griptape messages.
    Reasoning-only parts returned by the model are skipped (see `__is_thought_part`).

    Attributes:
        api_key: Google API key.
        model: Google model name.
        client: Custom `google.genai.Client`.
        top_p: Optional value for top_p.
        top_k: Optional value for top_k.
        tokenizer: Tokenizer whose `stop_sequences` are sent with requests. Defaults to a
            `GoogleTokenizer` built from `api_key` and `model`.
        use_native_tools: When true and the prompt stack has tools, the tools are sent as
            Google function declarations.
        structured_output_strategy: How an output schema is requested. `"tool"` forces the
            function calling mode to `AUTO`; `"native"` sets a JSON response MIME type and
            response JSON schema.
        tool_choice: Function calling mode; upper-cased before it is sent.
    """

    # Not serialized: the key is a credential.
    api_key: str | None = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    # The default tokenizer is derived from the instance's own api_key and model.
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: float | None = field(default=None, kw_only=True, metadata={"serializable": True})
    top_k: int | None = field(default=None, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
    # Stored privately but accepted by the constructor as `client=` via the alias.
    _client: Client | None = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Client:
        """Build a `google.genai` client authenticated with `api_key`.

        NOTE(review): `lazy_property` is defined elsewhere; presumably this body only runs
        when no `client` was passed to the constructor -- confirm against the decorator.
        """
        genai = import_optional_dependency("google.genai")

        return genai.Client(api_key=self.api_key)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Send the prompt stack in a single request and return the model's reply.

        Args:
            prompt_stack: Conversation, tools and optional output schema to send.

        Returns:
            An assistant `Message` holding the converted non-thought parts of the first
            candidate, with token usage taken from the response's usage metadata (`None`
            values when the response carries no usage metadata).
        """
        types = import_optional_dependency("google.genai.types")

        messages = self.__to_google_messages(prompt_stack)
        params = self._base_params(prompt_stack)
        config = types.GenerateContentConfig(**params)
        logger.debug((messages, params))
        response: GenerateContentResponse = self.client.models.generate_content(
            model=self.model,
            contents=cast("ContentListUnionDict", messages),
            config=config,
        )
        logger.debug(response.model_dump())

        usage_metadata = response.usage_metadata
        # Only the first candidate is read; a missing candidate or content yields no parts.
        parts = response.candidates[0].content.parts if response.candidates and response.candidates[0].content else []

        return Message(
            content=[
                self.__to_prompt_stack_message_content(part)
                for part in (parts or [])
                if not self.__is_thought_part(part)
            ],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(
                input_tokens=usage_metadata.prompt_token_count if usage_metadata else None,
                output_tokens=usage_metadata.candidates_token_count if usage_metadata else None,
            ),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Send the prompt stack as a streaming request and yield one delta per chunk.

        Args:
            prompt_stack: Conversation, tools and optional output schema to send.

        Yields:
            One `DeltaMessage` per response chunk. Its content is the converted first
            non-thought part of the chunk's first candidate, or `None` when there is none.
            Input token usage is reported only on the first chunk that carries usage metadata.
        """
        types = import_optional_dependency("google.genai.types")

        messages = self.__to_google_messages(prompt_stack)
        params = self._base_params(prompt_stack)
        config = types.GenerateContentConfig(**params)
        logger.debug((messages, params))
        response = self.client.models.generate_content_stream(
            model=self.model,
            contents=cast("ContentListUnionDict", messages),
            config=config,
        )

        # Stays None until the first chunk with usage metadata has been reported.
        prompt_token_count = None
        for chunk in response:
            logger.debug(chunk.model_dump())
            usage_metadata = chunk.usage_metadata

            parts = chunk.candidates[0].content.parts if chunk.candidates and chunk.candidates[0].content else None
            # Gemini thinking models emit reasoning-only chunks (e.g. a bare `thought_signature`)
            # with no text or function_call; skip them since Griptape has no thought content type.
            non_thought_part = (
                next((part for part in parts if not self.__is_thought_part(part)), None) if parts else None
            )
            content = self.__to_prompt_stack_delta_message_content(non_thought_part) if non_thought_part else None

            # Only want to output the prompt token count once since it is static each chunk
            if prompt_token_count is None and usage_metadata is not None:
                prompt_token_count = usage_metadata.prompt_token_count
                yield DeltaMessage(
                    content=content,
                    usage=DeltaMessage.Usage(
                        input_tokens=usage_metadata.prompt_token_count,
                        output_tokens=usage_metadata.candidates_token_count,
                    ),
                )
            else:
                yield DeltaMessage(
                    content=content,
                    usage=DeltaMessage.Usage(
                        output_tokens=usage_metadata.candidates_token_count if usage_metadata else None,
                    ),
                )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Assemble the keyword arguments for `types.GenerateContentConfig`.

        Args:
            prompt_stack: Source of system messages, tools and the output schema.

        Returns:
            A dict of generation settings. `extra_params` are merged after the
            driver-derived settings and so override them; the tool and structured-output
            keys are assigned afterwards and override `extra_params` in turn.
        """
        types = import_optional_dependency("google.genai.types")

        system_messages = prompt_stack.system_messages
        system_instruction = None
        if system_messages:
            # All system messages are folded into a single Content, one text part each.
            system_instruction = types.Content(
                role="system",
                parts=[types.Part.from_text(text=system_message.to_text()) for system_message in system_messages],
            )

        params = {
            # For some reason, providing stop sequences when streaming breaks native functions
            # https://github.com/google-gemini/generative-ai-python/issues/446
            "stop_sequences": [] if self.stream and self.use_native_tools else self.tokenizer.stop_sequences,
            "max_output_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            **({"system_instruction": system_instruction} if system_instruction is not None else {}),
            **self.extra_params,
        }

        if prompt_stack.tools and self.use_native_tools:
            mode = self.tool_choice.upper()

            # With the "tool" structured output strategy the configured tool_choice is ignored.
            if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
                mode = "AUTO"

            params["tool_config"] = types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(mode=mode),
            )
            params["tools"] = self.__to_google_tools(prompt_stack.tools)

        if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
            params["response_mime_type"] = "application/json"
            params["response_json_schema"] = prompt_stack.to_output_json_schema()

        return params

    def __to_google_messages(self, prompt_stack: PromptStack) -> list[Content]:
        """Convert the non-system messages of the prompt stack into Google `Content` objects.

        System messages are excluded here; `_base_params` sends them as `system_instruction`.
        """
        types = import_optional_dependency("google.genai.types")

        return [
            types.Content(
                role=self.__to_google_role(message),
                parts=[self.__to_google_message_content(content) for content in message.content],
            )
            for message in prompt_stack.messages
            if not message.is_system()
        ]

    def __to_google_role(self, message: Message) -> str:
        """Return `"model"` for assistant messages and `"user"` for every other message."""
        if message.is_assistant():
            return "model"
        return "user"

    def __to_google_tools(self, tools: list[BaseTool]) -> list:
        """Convert every activity of every tool into a Google function declaration.

        Returns:
            A one-element list holding a single `types.Tool` with all declarations.
        """
        types = import_optional_dependency("google.genai.types")

        function_declarations = []
        for tool in tools:
            for activity in tool.activities():
                schema = tool.to_activity_json_schema(activity, "Parameters Schema")

                # When the schema wraps its fields in a "values" property, use that inner schema.
                if "values" in schema["properties"]:
                    schema = schema["properties"]["values"]

                schema = remove_key_in_dict_recursively(schema, "additionalProperties")
                schema = remove_key_in_dict_recursively(schema, "title", preserve_under_key="properties")
                function_declaration = types.FunctionDeclaration(
                    name=tool.to_native_tool_name(activity),
                    description=tool.activity_description(activity),
                    # `parameters` is only passed when the schema declares at least one property.
                    **(
                        {
                            "parameters": {
                                "type": schema["type"],
                                "properties": schema["properties"],
                                "required": schema.get("required", []),
                            }
                        }
                        if schema.get("properties")
                        else {}
                    ),
                )

                function_declarations.append(function_declaration)

        return [types.Tool(function_declarations=function_declarations)]

    def __to_google_message_content(self, content: BaseMessageContent) -> Part:
        """Convert one Griptape message content item into a Google `Part`.

        Raises:
            ValueError: If the content type is not handled, or an image content's artifact
                is not an `ImageArtifact`.
        """
        types = import_optional_dependency("google.genai.types")

        if isinstance(content, TextMessageContent):
            return types.Part.from_text(text=content.artifact.to_text())
        if isinstance(content, ImageMessageContent):
            if isinstance(content.artifact, ImageArtifact):
                return types.Part.from_bytes(data=content.artifact.value, mime_type=content.artifact.mime_type)
            # TODO: Google requires uploading to the files endpoint: https://ai.google.dev/gemini-api/docs/image-understanding#upload-image
            # Can be worked around by using GenericMessageContent, similar to videos.
            raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return types.Part.from_function_call(name=action.tag, args=action.input)
        if isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            return types.Part.from_function_response(
                name=content.action.to_native_tool_name(),
                response=artifact.to_dict(),
            )
        if isinstance(content, GenericMessageContent):
            value = content.artifact.value
            if isinstance(value, str):
                return types.Part.from_text(text=value)
            # Non-string generic values are passed through unchanged.
            return value
        raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

    def __is_thought_part(self, content: Part) -> bool:
        """Return whether the part is flagged as a thought (or signed) and has no text or function call."""
        # Gemini thinking models emit reasoning-only parts (e.g. a bare `thought_signature`)
        # with no text or function_call; Griptape has no thought content type, so callers skip them.
        return bool((content.thought or content.thought_signature) and not content.text and not content.function_call)

    def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
        """Convert a Google `Part` into Griptape message content.

        Text takes precedence over a function call when a part carries both.

        Raises:
            ValueError: If the part has neither text nor a function call.
        """
        if content.text:
            return TextMessageContent(TextArtifact(content.text))
        if content.function_call:
            function_call = content.function_call
            # The function name doubles as the action tag.
            tag = function_call.name or ""

            name, path = ToolAction.from_native_tool_name(tag)

            args = function_call.args or {}
            return ActionCallMessageContent(
                artifact=ActionArtifact(value=ToolAction(tag=tag, name=name, path=path, input=args)),
            )
        raise ValueError(f"Unsupported message content type {content}")

    def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
        """Convert a streamed Google `Part` into Griptape delta message content.

        Text takes precedence over a function call when a part carries both.

        Raises:
            ValueError: If the part has neither text nor a function call.
        """
        if content.text:
            return TextDeltaMessageContent(content.text)
        if content.function_call:
            function_call = content.function_call
            # The function name doubles as the action tag.
            tag = function_call.name or ""

            name, path = ToolAction.from_native_tool_name(tag)

            args = function_call.args or {}
            return ActionCallDeltaMessageContent(
                tag=tag,
                name=name,
                path=path,
                # The arguments are serialized as one JSON document per part.
                partial_input=json.dumps(args),
            )
        raise ValueError(f"Unsupported message content type {content}")

_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structured_output_strategy = field(default='tool', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer = field(default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

tool_choice = field(default='auto', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_k = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_p = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__is_thought_part(content)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __is_thought_part(self, content: Part) -> bool:
    """Report whether a response part carries reasoning only.

    A part qualifies when it is flagged `thought` or has a `thought_signature`, and at the
    same time has neither `text` nor a `function_call`. Griptape has no thought content
    type, so callers skip such parts.
    """
    marked_as_thought = content.thought or content.thought_signature
    if not marked_as_thought:
        return False
    # A part that also carries usable output is not treated as a thought.
    return not (content.text or content.function_call)

__to_google_message_content(content)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_message_content(self, content: BaseMessageContent) -> Part:
    """Translate a single Griptape message content item into a Google `Part`.

    Text, image, action call, action result and generic content are handled, in that
    order of checks.

    Raises:
        ValueError: If the content type is not one of the above, or an image content's
            artifact is not an `ImageArtifact`.
    """
    genai_types = import_optional_dependency("google.genai.types")

    if isinstance(content, TextMessageContent):
        return genai_types.Part.from_text(text=content.artifact.to_text())

    if isinstance(content, ImageMessageContent):
        if not isinstance(content.artifact, ImageArtifact):
            # TODO: Google requires uploading to the files endpoint: https://ai.google.dev/gemini-api/docs/image-understanding#upload-image
            # Can be worked around by using GenericMessageContent, similar to videos.
            raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
        image = content.artifact
        return genai_types.Part.from_bytes(data=image.value, mime_type=image.mime_type)

    if isinstance(content, ActionCallMessageContent):
        tool_action = content.artifact.value
        return genai_types.Part.from_function_call(name=tool_action.tag, args=tool_action.input)

    if isinstance(content, ActionResultMessageContent):
        result = content.artifact
        return genai_types.Part.from_function_response(
            name=content.action.to_native_tool_name(),
            response=result.to_dict(),
        )

    if isinstance(content, GenericMessageContent):
        payload = content.artifact.value
        # Strings become text parts; anything else is handed over unchanged.
        return genai_types.Part.from_text(text=payload) if isinstance(payload, str) else payload

    raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

__to_google_messages(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_messages(self, prompt_stack: PromptStack) -> list[Content]:
    """Build the Google `Content` list for every non-system message in the prompt stack.

    Each message becomes one `Content` whose role comes from `__to_google_role` and
    whose parts come from `__to_google_message_content`. System messages are left out.
    """
    genai_types = import_optional_dependency("google.genai.types")

    google_messages = []
    for message in prompt_stack.messages:
        if message.is_system():
            continue

        role = self.__to_google_role(message)
        parts = [self.__to_google_message_content(item) for item in message.content]
        google_messages.append(genai_types.Content(role=role, parts=parts))

    return google_messages

__to_google_role(message)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_role(self, message: Message) -> str:
    """Map a message to a Google role: `"model"` for assistant messages, otherwise `"user"`."""
    return "model" if message.is_assistant() else "user"

__to_google_tools(tools)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_tools(self, tools: list[BaseTool]) -> list:
    """Describe every activity of every tool as a Google function declaration.

    Returns:
        A one-element list containing a single `types.Tool` that holds all declarations.
    """
    genai_types = import_optional_dependency("google.genai.types")

    declarations = []
    for tool in tools:
        for activity in tool.activities():
            params_schema = tool.to_activity_json_schema(activity, "Parameters Schema")

            # When the fields are wrapped in a "values" property, work on that inner schema.
            top_level_properties = params_schema["properties"]
            if "values" in top_level_properties:
                params_schema = top_level_properties["values"]

            params_schema = remove_key_in_dict_recursively(params_schema, "additionalProperties")
            params_schema = remove_key_in_dict_recursively(params_schema, "title", preserve_under_key="properties")

            native_name = tool.to_native_tool_name(activity)
            description = tool.activity_description(activity)

            # `parameters` is only supplied when the schema declares at least one property.
            optional_kwargs = {}
            if params_schema.get("properties"):
                optional_kwargs["parameters"] = {
                    "type": params_schema["type"],
                    "properties": params_schema["properties"],
                    "required": params_schema.get("required", []),
                }

            declarations.append(
                genai_types.FunctionDeclaration(
                    name=native_name,
                    description=description,
                    **optional_kwargs,
                )
            )

    return [genai_types.Tool(function_declarations=declarations)]

__to_prompt_stack_delta_message_content(content)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
    """Turn a streamed Google `Part` into Griptape delta content.

    A part with text becomes a text delta. Otherwise a part with a function call becomes
    an action call delta whose arguments are serialized as JSON.

    Raises:
        ValueError: If the part has neither text nor a function call.
    """
    if content.text:
        return TextDeltaMessageContent(content.text)

    call = content.function_call
    if not call:
        raise ValueError(f"Unsupported message content type {content}")

    # The native function name is reused as the action tag.
    native_name = call.name or ""
    tool_name, activity_path = ToolAction.from_native_tool_name(native_name)

    return ActionCallDeltaMessageContent(
        tag=native_name,
        name=tool_name,
        path=activity_path,
        partial_input=json.dumps(call.args or {}),
    )

__to_prompt_stack_message_content(content)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
    """Turn a Google `Part` into Griptape message content.

    A part with text becomes text content. Otherwise a part with a function call becomes
    action call content wrapping a `ToolAction`.

    Raises:
        ValueError: If the part has neither text nor a function call.
    """
    if content.text:
        return TextMessageContent(TextArtifact(content.text))

    call = content.function_call
    if not call:
        raise ValueError(f"Unsupported message content type {content}")

    # The native function name is reused as the action tag.
    native_name = call.name or ""
    tool_name, activity_path = ToolAction.from_native_tool_name(native_name)

    tool_action = ToolAction(tag=native_name, name=tool_name, path=activity_path, input=call.args or {})
    return ActionCallMessageContent(
        artifact=ActionArtifact(value=tool_action),
    )

_base_params(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Collect the keyword arguments used to build `types.GenerateContentConfig`.

    Args:
        prompt_stack: Source of system messages, tools and the output schema.

    Returns:
        A dict of generation settings. `extra_params` override the driver-derived
        settings; the tool and structured-output keys are written last and override
        `extra_params` in turn.
    """
    genai_types = import_optional_dependency("google.genai.types")

    # All system messages are folded into a single Content, one text part each.
    system_messages = prompt_stack.system_messages
    system_instruction = None
    if system_messages:
        instruction_parts = [genai_types.Part.from_text(text=message.to_text()) for message in system_messages]
        system_instruction = genai_types.Content(role="system", parts=instruction_parts)

    # For some reason, providing stop sequences when streaming breaks native functions
    # https://github.com/google-gemini/generative-ai-python/issues/446
    if self.stream and self.use_native_tools:
        stop_sequences = []
    else:
        stop_sequences = self.tokenizer.stop_sequences

    params = {
        "stop_sequences": stop_sequences,
        "max_output_tokens": self.max_tokens,
        "temperature": self.temperature,
        "top_p": self.top_p,
        "top_k": self.top_k,
    }
    if system_instruction is not None:
        params["system_instruction"] = system_instruction
    params.update(self.extra_params)

    if prompt_stack.tools and self.use_native_tools:
        calling_mode = self.tool_choice.upper()

        # The "tool" structured output strategy replaces the configured tool_choice with AUTO.
        if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
            calling_mode = "AUTO"

        calling_config = genai_types.FunctionCallingConfig(mode=calling_mode)
        params["tool_config"] = genai_types.ToolConfig(function_calling_config=calling_config)
        params["tools"] = self.__to_google_tools(prompt_stack.tools)

    if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
        params["response_mime_type"] = "application/json"
        params["response_json_schema"] = prompt_stack.to_output_json_schema()

    return params

client()

Source code in griptape/drivers/prompt/google_prompt_driver.py
@lazy_property()
def client(self) -> Client:
    """Create a `google.genai` client authenticated with this driver's `api_key`."""
    return import_optional_dependency("google.genai").Client(api_key=self.api_key)

try_run(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send the prompt stack in one request and return the model's reply.

    Args:
        prompt_stack: Conversation, tools and optional output schema to send.

    Returns:
        An assistant `Message` built from the non-thought parts of the first candidate.
        Token usage comes from the response's usage metadata and is `None` when the
        response carries none.
    """
    genai_types = import_optional_dependency("google.genai.types")

    contents = self.__to_google_messages(prompt_stack)
    params = self._base_params(prompt_stack)
    config = genai_types.GenerateContentConfig(**params)
    logger.debug((contents, params))
    response: GenerateContentResponse = self.client.models.generate_content(
        model=self.model,
        contents=cast("ContentListUnionDict", contents),
        config=config,
    )
    logger.debug(response.model_dump())

    usage_metadata = response.usage_metadata

    # Only the first candidate is read; a missing candidate or content yields no parts.
    parts = []
    if response.candidates and response.candidates[0].content:
        parts = response.candidates[0].content.parts

    message_content = []
    for part in parts or []:
        if self.__is_thought_part(part):
            continue
        message_content.append(self.__to_prompt_stack_message_content(part))

    input_tokens = usage_metadata.prompt_token_count if usage_metadata else None
    output_tokens = usage_metadata.candidates_token_count if usage_metadata else None

    return Message(
        content=message_content,
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
    )

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Send the prompt stack as a streaming request and yield one delta per chunk.

    Args:
        prompt_stack: Conversation, tools and optional output schema to send.

    Yields:
        One `DeltaMessage` per response chunk. Its content is the converted first
        non-thought part of the chunk's first candidate, or `None` when there is none.
        Input token usage is reported only on the first chunk that carries usage metadata.
    """
    genai_types = import_optional_dependency("google.genai.types")

    contents = self.__to_google_messages(prompt_stack)
    params = self._base_params(prompt_stack)
    config = genai_types.GenerateContentConfig(**params)
    logger.debug((contents, params))
    stream = self.client.models.generate_content_stream(
        model=self.model,
        contents=cast("ContentListUnionDict", contents),
        config=config,
    )

    # Stays None until the first chunk with usage metadata has been reported.
    prompt_token_count = None
    for chunk in stream:
        logger.debug(chunk.model_dump())
        usage_metadata = chunk.usage_metadata

        parts = None
        if chunk.candidates and chunk.candidates[0].content:
            parts = chunk.candidates[0].content.parts

        # Gemini thinking models emit reasoning-only chunks (e.g. a bare `thought_signature`)
        # with no text or function_call; skip them since Griptape has no thought content type.
        first_visible_part = None
        if parts:
            for part in parts:
                if not self.__is_thought_part(part):
                    first_visible_part = part
                    break

        content = None
        if first_visible_part:
            content = self.__to_prompt_stack_delta_message_content(first_visible_part)

        # Only want to output the prompt token count once since it is static each chunk
        if prompt_token_count is None and usage_metadata is not None:
            prompt_token_count = usage_metadata.prompt_token_count
            usage = DeltaMessage.Usage(
                input_tokens=usage_metadata.prompt_token_count,
                output_tokens=usage_metadata.candidates_token_count,
            )
        else:
            usage = DeltaMessage.Usage(
                output_tokens=usage_metadata.candidates_token_count if usage_metadata else None,
            )

        yield DeltaMessage(content=content, usage=usage)