Skip to content

google_embedding_driver

GoogleEmbeddingDriver

Bases: BaseEmbeddingDriver

Google Embedding Driver.

Attributes:

Name Type Description
api_key str | None

Google API key.

model str

Google model name.

client Client

Custom google.genai.Client. If not provided, one is created lazily from api_key.

task_type str

Embedding model task type (https://ai.google.dev/gemini-api/docs/embeddings#task-types). Defaults to retrieval_document.

title str | None

Optional title for the content. Only works with retrieval_document task type.

Source code in griptape/drivers/embedding/google_embedding_driver.py
@define
class GoogleEmbeddingDriver(BaseEmbeddingDriver):
    """Google Embedding Driver.

    Attributes:
        api_key: Google API key.
        model: Google model name.
        client: Custom `google.genai.Client`. If not provided, one is created lazily from `api_key`.
        task_type: Embedding model task type (https://ai.google.dev/gemini-api/docs/embeddings#task-types). Defaults to `retrieval_document`.
        title: Optional title for the content. Only works with `retrieval_document` task type.
    """

    DEFAULT_MODEL = "models/embedding-001"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    # Marked non-serializable, presumably to keep the secret out of serialized output.
    api_key: str | None = field(default=None, kw_only=True, metadata={"serializable": False})
    task_type: str = field(default="retrieval_document", kw_only=True, metadata={"serializable": True})
    title: str | None = field(default=None, kw_only=True, metadata={"serializable": True})
    # Exposed to callers as the `client` constructor keyword; the public `client`
    # property below builds a default when this is left as None.
    _client: Client | None = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Client:
        """Build a default `google.genai.Client` authenticated with `api_key`."""
        genai = import_optional_dependency("google.genai")

        return genai.Client(api_key=self.api_key)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed a single text chunk via `client.models.embed_content`.

        Args:
            chunk: Text to embed.
            **kwargs: Ignored by this implementation.

        Returns:
            The values of the first embedding in the response.

        Raises:
            ValueError: If the response has no embeddings, or the first embedding has no values.
        """
        types = import_optional_dependency("google.genai.types")

        response = self.client.models.embed_content(
            model=self.model,
            contents=chunk,
            config=types.EmbedContentConfig(task_type=self.task_type, title=self.title),
        )

        # `response.embeddings` is Optional, so it may be None or empty. Fail with a
        # clear error instead of an opaque TypeError/IndexError, and never return
        # None where `list[float]` is promised.
        embeddings = response.embeddings
        values = embeddings[0].values if embeddings else None
        if values is None:
            raise ValueError(f"Google embedding response for model '{self.model}' contained no embedding values.")

        return values

    def _params(self, chunk: str) -> dict:
        """Return a request-parameter mapping with `chunk` as "input" and the configured model.

        NOTE(review): no method in this class calls this, and its keys do not match the
        `embed_content` call above. Confirm whether it is still needed.
        """
        return {"input": chunk, "model": self.model}

DEFAULT_MODEL = 'models/embedding-001' class-attribute instance-attribute

_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

task_type = field(default='retrieval_document', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

title = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_params(chunk)

Source code in griptape/drivers/embedding/google_embedding_driver.py
def _params(self, chunk: str) -> dict:
    """Assemble request parameters for one chunk.

    Returns:
        A dict mapping "input" to ``chunk`` and "model" to the driver's model name.
    """
    return dict(input=chunk, model=self.model)

client()

Source code in griptape/drivers/embedding/google_embedding_driver.py
@lazy_property()
def client(self) -> Client:
    """Create the fallback `google.genai.Client`, authenticated with the driver's API key."""
    return import_optional_dependency("google.genai").Client(api_key=self.api_key)

try_embed_chunk(chunk, **kwargs)

Source code in griptape/drivers/embedding/google_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed a single text chunk via `client.models.embed_content`.

    Args:
        chunk: Text to embed.
        **kwargs: Ignored by this implementation.

    Returns:
        The values of the first embedding in the response.

    Raises:
        ValueError: If the response has no embeddings, or the first embedding has no values.
    """
    types = import_optional_dependency("google.genai.types")

    response = self.client.models.embed_content(
        model=self.model,
        contents=chunk,
        config=types.EmbedContentConfig(task_type=self.task_type, title=self.title),
    )

    # `response.embeddings` is Optional, so it may be None or empty. Fail with a
    # clear error instead of an opaque TypeError/IndexError, and never return
    # None where `list[float]` is promised.
    embeddings = response.embeddings
    values = embeddings[0].values if embeddings else None
    if values is None:
        raise ValueError(f"Google embedding response for model '{self.model}' contained no embedding values.")

    return values