Skip to content

google

__all__ = ['GooglePromptDriver'] module-attribute

GooglePromptDriver

Bases: BasePromptDriver

Google Prompt Driver.

Attributes:

Name Type Description
api_key Optional[str]

Google API key.

model str

Google model name.

client GenerativeModel

Custom GenerativeModel client.

top_p Optional[float]

Optional value for top_p.

top_k Optional[int]

Optional value for top_k.

Source code in griptape/drivers/prompt/google_prompt_driver.py
@define
class GooglePromptDriver(BasePromptDriver):
    """Google Prompt Driver.

    Attributes:
        api_key: Google API key.
        model: Google model name.
        tokenizer: Tokenizer for the model; defaults to a `GoogleTokenizer` built from `api_key` and `model`.
        client: Custom `GenerativeModel` client.
        top_p: Optional value for top_p.
        top_k: Optional value for top_k.
        use_native_tools: Whether prompt stack tools are sent to Google as function declarations.
        structured_output_strategy: Structured output strategy; `native` is rejected by this driver.
        tool_choice: Function calling mode placed in `tool_config` when native tools are sent.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
    _client: Optional[GenerativeModel] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Reject the `native` structured output strategy.

        Args:
            _: The attrs attribute being validated (unused).
            value: The strategy value assigned to the field.

        Returns:
            The value unchanged when it is accepted.

        Raises:
            ValueError: If `value` is `native`.
        """
        if value == "native":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @lazy_property()
    def client(self) -> GenerativeModel:
        """Configure `google.generativeai` with the API key and build a client for `model`."""
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)

        return genai.GenerativeModel(self.model)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Send the prompt stack to the model and return the complete response.

        Args:
            prompt_stack: Prompt stack converted into Google messages and request parameters.

        Returns:
            An assistant-role message with one content item per response part and the
            prompt/candidate token counts taken from the response usage metadata.
        """
        messages = self.__to_google_messages(prompt_stack)
        params = self._base_params(prompt_stack)
        # Use the module logger (as the stream loop does) instead of the root logger.
        logger.debug((messages, params))
        response: GenerateContentResponse = self.client.generate_content(messages, **params)
        logger.debug(response.to_dict())

        usage_metadata = response.usage_metadata

        return Message(
            content=[self.__to_prompt_stack_message_content(part) for part in response.parts],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(
                input_tokens=usage_metadata.prompt_token_count,
                output_tokens=usage_metadata.candidates_token_count,
            ),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Send the prompt stack to the model with streaming enabled and yield deltas.

        Args:
            prompt_stack: Prompt stack converted into Google messages and request parameters.

        Yields:
            One delta message per response chunk. Only the first part of a chunk is
            converted; a chunk without parts yields `None` content. The first delta
            carries the input token count, later deltas only the output token count.
        """
        messages = self.__to_google_messages(prompt_stack)
        params = {**self._base_params(prompt_stack), "stream": True}
        # Use the module logger, consistent with the per-chunk logging below.
        logger.debug((messages, params))
        response: GenerateContentResponse = self.client.generate_content(
            messages,
            **params,
        )

        prompt_token_count = None
        for chunk in response:
            logger.debug(chunk.to_dict())
            usage_metadata = chunk.usage_metadata

            content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None
            # Only want to output the prompt token count once since it is static each chunk
            if prompt_token_count is None:
                prompt_token_count = usage_metadata.prompt_token_count
                yield DeltaMessage(
                    content=content,
                    usage=DeltaMessage.Usage(
                        input_tokens=usage_metadata.prompt_token_count,
                        output_tokens=usage_metadata.candidates_token_count,
                    ),
                )
            else:
                yield DeltaMessage(
                    content=content,
                    usage=DeltaMessage.Usage(output_tokens=usage_metadata.candidates_token_count),
                )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments passed to `generate_content`.

        Also sets the client's system instruction from the prompt stack's system
        messages when there are any.

        Args:
            prompt_stack: Prompt stack supplying system messages, tools and output schema.

        Returns:
            A dict with `generation_config`, plus `tool_config` and `tools` when the
            prompt stack has tools and native tools are enabled.
        """
        types = import_optional_dependency("google.generativeai.types")
        protos = import_optional_dependency("google.generativeai.protos")

        system_messages = prompt_stack.system_messages
        # NOTE(review): the instruction is only written when system messages exist, so a
        # value set by an earlier call is left on the client otherwise - confirm intended.
        if system_messages:
            self.client._system_instruction = types.ContentDict(
                role="system",
                parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
            )

        params = {
            "generation_config": types.GenerationConfig(
                **{
                    # For some reason, providing stop sequences when streaming breaks native functions
                    # https://github.com/google-gemini/generative-ai-python/issues/446
                    "stop_sequences": [] if self.stream and self.use_native_tools else self.tokenizer.stop_sequences,
                    "max_output_tokens": self.max_tokens,
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                    "top_k": self.top_k,
                    # Spread last so extra_params can override any of the keys above.
                    **self.extra_params,
                },
            ),
        }

        if prompt_stack.tools and self.use_native_tools:
            params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

            # With an output schema and the "tool" strategy the mode is always "auto",
            # regardless of tool_choice.
            if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
                params["tool_config"]["function_calling_config"]["mode"] = "auto"

            params["tools"] = self.__to_google_tools(prompt_stack.tools)

        return params

    def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
        """Convert every non-system prompt stack message into a Google `ContentDict`."""
        types = import_optional_dependency("google.generativeai.types")

        return [
            types.ContentDict(
                {
                    "role": self.__to_google_role(message),
                    "parts": [self.__to_google_message_content(content) for content in message.content],
                },
            )
            for message in prompt_stack.messages
            if not message.is_system()
        ]

    def __to_google_role(self, message: Message) -> str:
        """Return `"model"` for assistant messages and `"user"` for everything else."""
        if message.is_assistant():
            return "model"
        else:
            return "user"

    def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Build one Google function declaration per activity of each tool.

        Args:
            tools: Tools whose activities are declared to the model.

        Returns:
            The function declarations, in tool order then activity order.
        """
        types = import_optional_dependency("google.generativeai.types")

        tool_declarations = []
        for tool in tools:
            for activity in tool.activities():
                schema = tool.to_activity_json_schema(activity, "Parameters Schema")

                # When the schema wraps its parameters in a "values" property, declare that inner schema.
                if "values" in schema["properties"]:
                    schema = schema["properties"]["values"]

                schema = remove_key_in_dict_recursively(schema, "additionalProperties")
                tool_declaration = types.FunctionDeclaration(
                    name=tool.to_native_tool_name(activity),
                    description=tool.activity_description(activity),
                    # "parameters" is omitted entirely when the schema has no properties.
                    **(
                        {
                            "parameters": {
                                "type": schema["type"],
                                "properties": schema["properties"],
                                "required": schema.get("required", []),
                            }
                        }
                        if schema.get("properties")
                        else {}
                    ),
                )

                tool_declarations.append(tool_declaration)

        return tool_declarations

    def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str:
        """Convert one prompt stack content item into the value sent to Google.

        Args:
            content: Text, image, action call, action result or generic content.

        Returns:
            Plain text, a `ContentDict`, a `Part`, or the raw artifact value for generic content.

        Raises:
            ValueError: If the content type is not one of the supported kinds.
        """
        types = import_optional_dependency("google.generativeai.types")
        protos = import_optional_dependency("google.generativeai.protos")

        if isinstance(content, TextMessageContent):
            return content.artifact.to_text()
        elif isinstance(content, ImageMessageContent):
            return types.ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value)
        elif isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return protos.Part(function_call=protos.FunctionCall(name=action.tag, args=action.input))
        elif isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            return protos.Part(
                function_response=protos.FunctionResponse(
                    name=content.action.to_native_tool_name(),
                    response=artifact.to_dict(),
                ),
            )
        elif isinstance(content, GenericMessageContent):
            return content.artifact.value
        else:
            raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

    def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
        """Convert a Google response part into prompt stack message content.

        Args:
            content: A response part holding either text or a function call.

        Returns:
            Text content for a text part, or action call content for a function call.

        Raises:
            ValueError: If the part has neither text nor a function call.
        """
        json_format = import_optional_dependency("google.protobuf.json_format")

        if content.text:
            return TextMessageContent(TextArtifact(content.text))
        elif content.function_call:
            function_call = content.function_call

            name, path = ToolAction.from_native_tool_name(function_call.name)

            # Convert the protobuf message to a plain dict to read the call arguments.
            args = json_format.MessageToDict(function_call._pb).get("args", {})
            return ActionCallMessageContent(
                artifact=ActionArtifact(value=ToolAction(tag=function_call.name, name=name, path=path, input=args)),
            )
        else:
            raise ValueError(f"Unsupported message content type {content}")

    def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
        """Convert a streamed Google response part into prompt stack delta content.

        Args:
            content: A response part holding either text or a function call.

        Returns:
            Text delta content for a text part, or action call delta content whose
            `partial_input` is the JSON-encoded call arguments.

        Raises:
            ValueError: If the part has neither text nor a function call.
        """
        json_format = import_optional_dependency("google.protobuf.json_format")

        if content.text:
            return TextDeltaMessageContent(content.text)
        elif content.function_call:
            function_call = content.function_call

            name, path = ToolAction.from_native_tool_name(function_call.name)

            # Convert the protobuf message to a plain dict to read the call arguments.
            args = json_format.MessageToDict(function_call._pb).get("args", {})
            return ActionCallDeltaMessageContent(
                tag=function_call.name,
                name=name,
                path=path,
                partial_input=json.dumps(args),
            )
        else:
            raise ValueError(f"Unsupported message content type {content}")

_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structured_output_strategy = field(default='tool', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer = field(default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

tool_choice = field(default='auto', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_k = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_p = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_google_message_content(content)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str:
    """Translate one prompt stack content item into the value sent to Google.

    Args:
        content: Text, image, action call, action result or generic content.

    Returns:
        Plain text, a `ContentDict`, a `Part`, or the raw artifact value for generic content.

    Raises:
        ValueError: If the content type is not one of the supported kinds.
    """
    types = import_optional_dependency("google.generativeai.types")
    protos = import_optional_dependency("google.generativeai.protos")

    # Text is sent as a bare string.
    if isinstance(content, TextMessageContent):
        return content.artifact.to_text()

    # Images are sent as a mime type plus the artifact's value.
    if isinstance(content, ImageMessageContent):
        image = content.artifact
        return types.ContentDict(mime_type=image.mime_type, data=image.value)

    # A tool call the model made earlier becomes a function_call part.
    if isinstance(content, ActionCallMessageContent):
        tool_action = content.artifact.value
        call = protos.FunctionCall(name=tool_action.tag, args=tool_action.input)
        return protos.Part(function_call=call)

    # The tool's result becomes a function_response part.
    if isinstance(content, ActionResultMessageContent):
        result_artifact = content.artifact
        reply = protos.FunctionResponse(
            name=content.action.to_native_tool_name(),
            response=result_artifact.to_dict(),
        )
        return protos.Part(function_response=reply)

    # Generic content is passed through untouched.
    if isinstance(content, GenericMessageContent):
        return content.artifact.value

    raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

__to_google_messages(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
    """Convert the non-system messages of a prompt stack into Google `ContentDict`s.

    Args:
        prompt_stack: Prompt stack whose messages are converted.

    Returns:
        One `ContentDict` (role plus converted parts) per non-system message, in order.
    """
    types = import_optional_dependency("google.generativeai.types")

    google_messages = []
    for message in prompt_stack.messages:
        # System messages are not part of the message list sent here.
        if message.is_system():
            continue

        role = self.__to_google_role(message)
        parts = [self.__to_google_message_content(item) for item in message.content]
        google_messages.append(types.ContentDict({"role": role, "parts": parts}))

    return google_messages

__to_google_role(message)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_role(self, message: Message) -> str:
    """Map a prompt stack message to a Google role name.

    Args:
        message: Message whose role is translated.

    Returns:
        `"model"` when the message is an assistant message, otherwise `"user"`.
    """
    return "model" if message.is_assistant() else "user"

__to_google_tools(tools)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Build one Google function declaration per activity of each tool.

    Args:
        tools: Tools whose activities are declared to the model.

    Returns:
        The function declarations, in tool order then activity order.
    """
    types = import_optional_dependency("google.generativeai.types")

    declarations = []
    for tool in tools:
        for activity in tool.activities():
            params_schema = tool.to_activity_json_schema(activity, "Parameters Schema")

            # When the schema wraps its parameters in a "values" property, use that inner schema.
            properties = params_schema["properties"]
            if "values" in properties:
                params_schema = properties["values"]

            params_schema = remove_key_in_dict_recursively(params_schema, "additionalProperties")

            native_name = tool.to_native_tool_name(activity)
            description = tool.activity_description(activity)

            # "parameters" is only passed when the schema actually has properties.
            optional_kwargs = {}
            if params_schema.get("properties"):
                optional_kwargs["parameters"] = {
                    "type": params_schema["type"],
                    "properties": params_schema["properties"],
                    "required": params_schema.get("required", []),
                }

            declarations.append(
                types.FunctionDeclaration(name=native_name, description=description, **optional_kwargs)
            )

    return declarations

__to_prompt_stack_delta_message_content(content)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
    """Convert a streamed Google response part into prompt stack delta content.

    Args:
        content: A response part holding either text or a function call.

    Returns:
        Text delta content for a text part, or action call delta content whose
        `partial_input` is the JSON-encoded call arguments.

    Raises:
        ValueError: If the part has neither text nor a function call.
    """
    json_format = import_optional_dependency("google.protobuf.json_format")

    if content.text:
        return TextDeltaMessageContent(content.text)

    if not content.function_call:
        raise ValueError(f"Unsupported message content type {content}")

    call = content.function_call
    tool_name, activity_path = ToolAction.from_native_tool_name(call.name)

    # Convert the protobuf message to a plain dict to read the call arguments.
    call_args = json_format.MessageToDict(call._pb).get("args", {})

    return ActionCallDeltaMessageContent(
        tag=call.name,
        name=tool_name,
        path=activity_path,
        partial_input=json.dumps(call_args),
    )

__to_prompt_stack_message_content(content)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
    """Convert a Google response part into prompt stack message content.

    Args:
        content: A response part holding either text or a function call.

    Returns:
        Text content for a text part, or action call content for a function call.

    Raises:
        ValueError: If the part has neither text nor a function call.
    """
    json_format = import_optional_dependency("google.protobuf.json_format")

    if content.text:
        return TextMessageContent(TextArtifact(content.text))

    if not content.function_call:
        raise ValueError(f"Unsupported message content type {content}")

    call = content.function_call
    tool_name, activity_path = ToolAction.from_native_tool_name(call.name)

    # Convert the protobuf message to a plain dict to read the call arguments.
    call_args = json_format.MessageToDict(call._pb).get("args", {})

    tool_action = ToolAction(tag=call.name, name=tool_name, path=activity_path, input=call_args)
    return ActionCallMessageContent(artifact=ActionArtifact(value=tool_action))

_base_params(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the keyword arguments passed to `generate_content`.

    Also sets the client's system instruction from the prompt stack's system
    messages when there are any.

    Args:
        prompt_stack: Prompt stack supplying system messages, tools and output schema.

    Returns:
        A dict with `generation_config`, plus `tool_config` and `tools` when the
        prompt stack has tools and native tools are enabled.
    """
    types = import_optional_dependency("google.generativeai.types")
    protos = import_optional_dependency("google.generativeai.protos")

    system_messages = prompt_stack.system_messages
    if system_messages:
        system_parts = [protos.Part(text=message.to_text()) for message in system_messages]
        self.client._system_instruction = types.ContentDict(role="system", parts=system_parts)

    # For some reason, providing stop sequences when streaming breaks native functions
    # https://github.com/google-gemini/generative-ai-python/issues/446
    if self.stream and self.use_native_tools:
        stop_sequences = []
    else:
        stop_sequences = self.tokenizer.stop_sequences

    generation_options = {
        "stop_sequences": stop_sequences,
        "max_output_tokens": self.max_tokens,
        "temperature": self.temperature,
        "top_p": self.top_p,
        "top_k": self.top_k,
        # Spread last so extra_params can override any of the keys above.
        **self.extra_params,
    }
    params = {"generation_config": types.GenerationConfig(**generation_options)}

    if prompt_stack.tools and self.use_native_tools:
        # With an output schema and the "tool" strategy the mode is always "auto",
        # regardless of tool_choice.
        force_auto = prompt_stack.output_schema is not None and self.structured_output_strategy == "tool"
        mode = "auto" if force_auto else self.tool_choice

        params["tool_config"] = {"function_calling_config": {"mode": mode}}
        params["tools"] = self.__to_google_tools(prompt_stack.tools)

    return params

client()

Source code in griptape/drivers/prompt/google_prompt_driver.py
@lazy_property()
def client(self) -> GenerativeModel:
    """Configure `google.generativeai` with the API key and build a client for `model`.

    Returns:
        A `GenerativeModel` created for `self.model`.
    """
    google_genai = import_optional_dependency("google.generativeai")
    google_genai.configure(api_key=self.api_key)

    generative_model = google_genai.GenerativeModel(self.model)
    return generative_model

try_run(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send the prompt stack to the model and return the complete response.

    Args:
        prompt_stack: Prompt stack converted into Google messages and request parameters.

    Returns:
        An assistant-role message with one content item per response part and the
        prompt/candidate token counts taken from the response usage metadata.
    """
    messages = self.__to_google_messages(prompt_stack)
    params = self._base_params(prompt_stack)
    # Log through the module-level logger (the one try_stream uses per chunk)
    # rather than the root logger, so debug output follows one configuration.
    logger.debug((messages, params))
    response: GenerateContentResponse = self.client.generate_content(messages, **params)
    logger.debug(response.to_dict())

    usage_metadata = response.usage_metadata

    return Message(
        content=[self.__to_prompt_stack_message_content(part) for part in response.parts],
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(
            input_tokens=usage_metadata.prompt_token_count,
            output_tokens=usage_metadata.candidates_token_count,
        ),
    )

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Send the prompt stack to the model with streaming enabled and yield deltas.

    Args:
        prompt_stack: Prompt stack converted into Google messages and request parameters.

    Yields:
        One delta message per response chunk. Only the first part of a chunk is
        converted; a chunk without parts yields `None` content. The first delta
        carries the input token count, later deltas only the output token count.
    """
    messages = self.__to_google_messages(prompt_stack)
    params = {**self._base_params(prompt_stack), "stream": True}
    # Use the same module-level logger as the per-chunk logging below instead of
    # the root logger, so all debug output follows one configuration.
    logger.debug((messages, params))
    response: GenerateContentResponse = self.client.generate_content(
        messages,
        **params,
    )

    prompt_token_count = None
    for chunk in response:
        logger.debug(chunk.to_dict())
        usage_metadata = chunk.usage_metadata

        content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None
        # Only want to output the prompt token count once since it is static each chunk
        if prompt_token_count is None:
            prompt_token_count = usage_metadata.prompt_token_count
            yield DeltaMessage(
                content=content,
                usage=DeltaMessage.Usage(
                    input_tokens=usage_metadata.prompt_token_count,
                    output_tokens=usage_metadata.candidates_token_count,
                ),
            )
        else:
            yield DeltaMessage(
                content=content,
                usage=DeltaMessage.Usage(output_tokens=usage_metadata.candidates_token_count),
            )

validate_structured_output_strategy(_, value)

Source code in griptape/drivers/prompt/google_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Reject the `native` structured output strategy.

    Args:
        _: The attrs attribute being validated (unused).
        value: The strategy value assigned to the field.

    Returns:
        The value unchanged when it is accepted.

    Raises:
        ValueError: If `value` is `native`.
    """
    if value != "native":
        return value

    raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")