Skip to content

Google embedding driver

GoogleEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `api_key` | `Optional[str]` | Google API key. |
| `model` | `str` | Google model name. |
| `task_type` | `str` | Embedding model task type (https://ai.google.dev/tutorials/python_quickstart#use_embeddings). Defaults to `retrieval_document`. |
| `title` | `Optional[str]` | Optional title for the content. Only works with the `retrieval_document` task type. |

Source code in griptape/drivers/embedding/google_embedding_driver.py
@define
class GoogleEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that delegates to the `google.generativeai` package.

    Attributes:
        api_key: Google API key; passed to `genai.configure` before every request and
            excluded from serialization.
        model: Google embedding model name. Defaults to `DEFAULT_MODEL`.
        task_type: Embedding model task type
            (https://ai.google.dev/tutorials/python_quickstart#use_embeddings).
            Defaults to `retrieval_document`.
        title: Optional title for the content. Only works with the `retrieval_document` task type.
    """

    DEFAULT_MODEL = "models/embedding-001"

    model: str = field(
        default=DEFAULT_MODEL,
        kw_only=True,
        metadata={"serializable": True},
    )
    api_key: Optional[str] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": False},  # keep the secret out of serialized output
    )
    task_type: str = field(
        default="retrieval_document",
        kw_only=True,
        metadata={"serializable": True},
    )
    title: Optional[str] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True},
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed one chunk of text with `genai.embed_content`.

        Args:
            chunk: Text to embed.

        Returns:
            The value under the `"embedding"` key of the API response.
        """
        google_genai = import_optional_dependency("google.generativeai")
        # The SDK holds the key globally, so (re)configure it right before the call.
        google_genai.configure(api_key=self.api_key)

        response = google_genai.embed_content(
            model=self.model,
            content=chunk,
            task_type=self.task_type,
            title=self.title,
        )
        return response["embedding"]

    def _params(self, chunk: str) -> dict:
        """Build a parameter dict holding the input text and the model name.

        Note: `try_embed_chunk` in this class does not call this helper.
        """
        return dict(input=chunk, model=self.model)

`DEFAULT_MODEL = 'models/embedding-001'` (class attribute, instance attribute)

`api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False})` (class attribute, instance attribute)

`model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True})` (class attribute, instance attribute)

`task_type: str = field(default='retrieval_document', kw_only=True, metadata={'serializable': True})` (class attribute, instance attribute)

`title: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True})` (class attribute, instance attribute)

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/google_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed one chunk of text with `genai.embed_content`.

    Args:
        chunk: Text to embed.

    Returns:
        The value under the `"embedding"` key of the API response.
    """
    google_genai = import_optional_dependency("google.generativeai")
    # The SDK holds the key globally, so (re)configure it right before the call.
    google_genai.configure(api_key=self.api_key)

    response = google_genai.embed_content(
        model=self.model,
        content=chunk,
        task_type=self.task_type,
        title=self.title,
    )
    return response["embedding"]