
Google prompt driver

GooglePromptDriver

Bases: BasePromptDriver

Attributes:

Name          Type             Description
api_key       Optional[str]    Google API key.
model         str              Google model name.
model_client  Any              Custom GenerativeModel client.
tokenizer     BaseTokenizer    Custom GoogleTokenizer.
top_p         Optional[float]  Optional value for top_p.
top_k         Optional[int]    Optional value for top_k.
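
For orientation, a minimal instantiation might look like the sketch below. The model name, environment variable, and sampling values are illustrative assumptions, not defaults of the driver.

import os

from griptape.drivers import GooglePromptDriver

# Illustrative configuration; "gemini-pro" and the sampling values are assumptions.
driver = GooglePromptDriver(
    api_key=os.environ["GOOGLE_API_KEY"],
    model="gemini-pro",
    top_p=0.9,  # nucleus-sampling cutoff
    top_k=40,   # restrict sampling to the 40 most likely tokens
)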

Source code in griptape/drivers/prompt/google_prompt_driver.py
@define
class GooglePromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Google API key.
        model: Google model name.
        model_client: Custom `GenerativeModel` client.
        tokenizer: Custom `GoogleTokenizer`.
        top_p: Optional value for top_p.
        top_k: Optional value for top_k.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    model_client: Any = field(default=Factory(lambda self: self._default_model_client(), takes_self=True), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

        inputs = self._prompt_stack_to_model_input(prompt_stack)
        response = self.model_client.generate_content(
            inputs,
            generation_config=GenerationConfig(
                stop_sequences=self.tokenizer.stop_sequences,
                max_output_tokens=self.max_output_tokens(inputs),
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
            ),
        )

        return TextArtifact(value=response.text)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

        inputs = self._prompt_stack_to_model_input(prompt_stack)
        response = self.model_client.generate_content(
            inputs,
            stream=True,
            generation_config=GenerationConfig(
                stop_sequences=self.tokenizer.stop_sequences,
                max_output_tokens=self.max_output_tokens(inputs),
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
            ),
        )

        for chunk in response:
            yield TextArtifact(value=chunk.text)

    def _default_model_client(self) -> GenerativeModel:
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)

        return genai.GenerativeModel(self.model)

    def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list[ContentDict]:
        inputs = [
            self.__to_content_dict(prompt_input) for prompt_input in prompt_stack.inputs if not prompt_input.is_system()
        ]

        # Gemini does not have the notion of a system message, so we insert it as part of the first message in the history.
        system = next((i for i in prompt_stack.inputs if i.is_system()), None)
        if system is not None:
            inputs[0]["parts"].insert(0, system.content)

        return inputs

    def __to_content_dict(self, prompt_input: PromptStack.Input) -> ContentDict:
        ContentDict = import_optional_dependency("google.generativeai.types").ContentDict

        return ContentDict({"role": self.__to_google_role(prompt_input), "parts": [prompt_input.content]})

    def __to_google_role(self, prompt_input: PromptStack.Input) -> str:
        if prompt_input.is_system():
            return "user"
        elif prompt_input.is_assistant():
            return "model"
        else:
            return "user"

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False})

model: str = field(kw_only=True, metadata={'serializable': True})

model_client: Any = field(default=Factory(lambda self: self._default_model_client(), takes_self=True), kw_only=True)

tokenizer: BaseTokenizer = field(default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True), kw_only=True)

top_k: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True})

top_p: Optional[float] = field(default=None, kw_only=True, metadata={'serializable': True})

__to_content_dict(prompt_input)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_content_dict(self, prompt_input: PromptStack.Input) -> ContentDict:
    ContentDict = import_optional_dependency("google.generativeai.types").ContentDict

    return ContentDict({"role": self.__to_google_role(prompt_input), "parts": [prompt_input.content]})
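
The helper simply pairs the mapped role with the raw content. For a user input whose content is "Hello", the result is, conceptually:

# {"role": "user", "parts": ["Hello"]}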

__to_google_role(prompt_input)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_role(self, prompt_input: PromptStack.Input) -> str:
    if prompt_input.is_system():
        return "user"
    elif prompt_input.is_assistant():
        return "model"
    else:
        return "user"

try_run(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

    inputs = self._prompt_stack_to_model_input(prompt_stack)
    response = self.model_client.generate_content(
        inputs,
        generation_config=GenerationConfig(
            stop_sequences=self.tokenizer.stop_sequences,
            max_output_tokens=self.max_output_tokens(inputs),
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
        ),
    )

    return TextArtifact(value=response.text)
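
A hedged usage sketch, reusing the driver from the instantiation example above and the PromptStack API assumed earlier (in normal operation try_run is reached through the base driver's public run path rather than called directly):

from griptape.utils import PromptStack

stack = PromptStack()
stack.add_user_input("Summarize the water cycle in one sentence.")

artifact = driver.try_run(stack)
print(artifact.value)  # generated text, returned as a TextArtifact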

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

    inputs = self._prompt_stack_to_model_input(prompt_stack)
    response = self.model_client.generate_content(
        inputs,
        stream=True,
        generation_config=GenerationConfig(
            stop_sequences=self.tokenizer.stop_sequences,
            max_output_tokens=self.max_output_tokens(inputs),
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
        ),
    )

    for chunk in response:
        yield TextArtifact(value=chunk.text)
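
The streaming variant yields one TextArtifact per response chunk, so callers can render partial output as it arrives. A sketch under the same assumptions as above:

for chunk in driver.try_stream(stack):
    print(chunk.value, end="", flush=True)  # print partial text as it streams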