@define
class GooglePromptDriver(BasePromptDriver):
    """Prompt driver for Google generative AI (Gemini) models.

    Attributes:
        api_key: Google API key.
        model: Google model name.
        model_client: Custom `GenerativeModel` client.
        tokenizer: Custom `GoogleTokenizer`.
        top_p: Optional value for top_p.
        top_k: Optional value for top_k.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    model_client: Any = field(default=Factory(lambda self: self._default_model_client(), takes_self=True), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Generate a complete response for `prompt_stack` and return it as a `TextArtifact`."""
        inputs = self._prompt_stack_to_model_input(prompt_stack)
        response = self.model_client.generate_content(
            inputs,
            generation_config=self._generation_config(inputs),
        )
        return TextArtifact(value=response.text)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Generate a streaming response for `prompt_stack`, yielding one `TextArtifact` per chunk."""
        inputs = self._prompt_stack_to_model_input(prompt_stack)
        response = self.model_client.generate_content(
            inputs,
            stream=True,
            generation_config=self._generation_config(inputs),
        )
        for chunk in response:
            yield TextArtifact(value=chunk.text)

    def _generation_config(self, inputs: list[ContentDict]) -> Any:
        """Build the `GenerationConfig` shared by `try_run` and `try_stream`.

        Returns a `google.generativeai.types.GenerationConfig`; annotated `Any`
        because the type comes from an optional dependency.
        """
        GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

        return GenerationConfig(
            stop_sequences=self.tokenizer.stop_sequences,
            max_output_tokens=self.max_output_tokens(inputs),
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
        )

    def _default_model_client(self) -> GenerativeModel:
        """Configure the `google.generativeai` SDK with `api_key` and return a client for `model`."""
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)

        return genai.GenerativeModel(self.model)

    def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list[ContentDict]:
        """Convert `prompt_stack` into the list of `ContentDict` messages the SDK expects."""
        inputs = [
            self.__to_content_dict(prompt_input) for prompt_input in prompt_stack.inputs if not prompt_input.is_system()
        ]

        # Gemini does not have the notion of a system message, so we insert it as part of the first message.
        system = next((i for i in prompt_stack.inputs if i.is_system()), None)
        if system is not None:
            if inputs:
                inputs[0]["parts"].insert(0, system.content)
            else:
                # Edge case: the stack contains only a system input. Emit it as a
                # standalone user message instead of raising IndexError on inputs[0].
                ContentDict = import_optional_dependency("google.generativeai.types").ContentDict
                inputs.append(ContentDict({"role": "user", "parts": [system.content]}))

        return inputs

    def __to_content_dict(self, prompt_input: PromptStack.Input) -> ContentDict:
        """Wrap a single prompt input in a `ContentDict` with its Google-role equivalent."""
        ContentDict = import_optional_dependency("google.generativeai.types").ContentDict

        return ContentDict({"role": self.__to_google_role(prompt_input), "parts": [prompt_input.content]})

    def __to_google_role(self, prompt_input: PromptStack.Input) -> str:
        """Map a prompt input to a Google role.

        Gemini only supports "user" and "model" roles; system inputs (handled
        separately in `_prompt_stack_to_model_input`) and user inputs both map
        to "user".
        """
        return "model" if prompt_input.is_assistant() else "user"