Skip to content

base_embedding_driver

BaseEmbeddingDriver

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Base Embedding Driver.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `model` | `str` | The name of the model to use. |
| `tokenizer` | `Optional[BaseTokenizer]` | An instance of `BaseTokenizer` to use when calculating tokens. |

Source code in griptape/drivers/embedding/base_embedding_driver.py
@define
class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base Embedding Driver.

    Concrete drivers implement `try_embed_chunk`. This base class adds retrying and
    transparent handling of inputs that exceed the tokenizer's input limit.

    Attributes:
        model: The name of the model to use.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens. When set,
            strings with more than `tokenizer.max_input_tokens` tokens are split into chunks
            whose embeddings are combined into one vector.
        chunker: Chunker used to split over-long strings. Built from `tokenizer` after
            initialization; `None` when no tokenizer is configured.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    chunker: Optional[BaseChunker] = field(init=False)

    def __attrs_post_init__(self) -> None:
        # Explicit None check, matching `embed_string`, rather than truthiness: a tokenizer
        # object that defines `__len__`/`__bool__` could evaluate falsy here and leave
        # `chunker` unset while `embed_string` still routes long input to the chunking path.
        self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer is not None else None

    def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
        """Embed the text of `artifact` (its `to_text()` value) via `embed_string`."""
        return self.embed_string(artifact.to_text())

    def embed_string(self, string: str) -> list[float]:
        """Embed `string`, chunking it first when it is too long for one request.

        When `tokenizer` is set and `string` has more than `tokenizer.max_input_tokens`
        tokens, it is embedded via `_embed_long_string`. Otherwise it is passed to
        `try_embed_chunk` as-is. Each attempt yielded by `self.retrying()` embeds the whole
        string, so if the retry policy retries after a chunk fails, every chunk is re-embedded.

        Raises:
            RuntimeError: If the retry loop finishes without any attempt returning a result.
        """
        for attempt in self.retrying():
            with attempt:
                if self.tokenizer is not None and self.tokenizer.count_tokens(string) > self.tokenizer.max_input_tokens:
                    return self._embed_long_string(string)
                else:
                    return self.try_embed_chunk(string)

        # Only reached if the retry loop ends without an attempt returning.
        raise RuntimeError("Failed to embed string.")

    @abstractmethod
    def try_embed_chunk(self, chunk: str) -> list[float]: ...

    def _embed_long_string(self, string: str) -> list[float]:
        """Embeds a string that is too long to embed in one go.

        The string is split with `chunker` and each chunk is embedded separately. The chunk
        embeddings are then averaged, weighted by `len(chunk)`, and scaled to unit length.

        Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb

        Raises:
            ValueError: If neither a chunker nor a tokenizer to build one from is available.
        """
        if self.chunker is None:
            # `tokenizer` may have been assigned after construction, in which case
            # `__attrs_post_init__` never built a chunker; build and cache one now.
            if self.tokenizer is None:
                raise ValueError("A tokenizer is required to embed strings that exceed the input limit.")
            self.chunker = TextChunker(tokenizer=self.tokenizer)

        chunks = self.chunker.chunk(string)

        embedding_chunks = []
        length_chunks = []
        for chunk in chunks:
            embedding_chunks.append(self.try_embed_chunk(chunk.value))
            length_chunks.append(len(chunk))

        # generate weighted averages
        averaged = np.average(embedding_chunks, axis=0, weights=length_chunks)

        # Normalize length to 1. A zero vector has no direction, so return it unscaled
        # rather than dividing by zero, which would produce NaNs.
        norm = np.linalg.norm(averaged)
        if norm > 0:
            averaged = averaged / norm

        return averaged.tolist()

`chunker: Optional[BaseChunker] = field(init=False)` — class attribute, instance attribute

`model: str = field(kw_only=True, metadata={'serializable': True})` — class attribute, instance attribute

`tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)` — class attribute, instance attribute

__attrs_post_init__()

Source code in griptape/drivers/embedding/base_embedding_driver.py
def __attrs_post_init__(self) -> None:
    """Build the chunker used for over-long inputs once the attrs fields are set.

    A `TextChunker` is created only when a tokenizer is configured; otherwise `chunker` is `None`.
    """
    # Explicit None check (as in `embed_string`) instead of truthiness: a tokenizer that
    # defines `__len__` or `__bool__` could evaluate falsy and silently leave `chunker` unset.
    self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer is not None else None

embed_string(string)

Source code in griptape/drivers/embedding/base_embedding_driver.py
def embed_string(self, string: str) -> list[float]:
    """Return the embedding of `string`, chunking it when it exceeds the tokenizer's limit.

    Each attempt from `self.retrying()` checks the token count (when a tokenizer is set).
    The string goes straight to `try_embed_chunk` when it fits, and to `_embed_long_string`
    otherwise.

    Raises:
        RuntimeError: If the retry loop ends without any attempt returning a result.
    """
    for attempt in self.retrying():
        with attempt:
            tok = self.tokenizer
            fits_in_one_request = tok is None or tok.count_tokens(string) <= tok.max_input_tokens
            if fits_in_one_request:
                return self.try_embed_chunk(string)
            return self._embed_long_string(string)

    # The loop never breaks, so reaching this line means no attempt produced a result.
    raise RuntimeError("Failed to embed string.")

embed_text_artifact(artifact)

Source code in griptape/drivers/embedding/base_embedding_driver.py
def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
    """Embed an artifact by converting it to text and delegating to `embed_string`."""
    text = artifact.to_text()
    return self.embed_string(text)

try_embed_chunk(chunk) abstractmethod

Source code in griptape/drivers/embedding/base_embedding_driver.py
@abstractmethod
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single piece of text in one request; concrete drivers must implement this.

    `embed_string` calls it directly for inputs it does not split. `_embed_long_string`
    calls it once per chunk.
    """