Bases: `SerializableMixin`, `ExponentialBackoffMixin`, `ABC`
Base Embedding Driver.
Attributes:

| Name | Type | Description |
|------|------|-------------|
| model | `str` | The name of the model to use. |
| tokenizer | `Optional[BaseTokenizer]` | An instance of `BaseTokenizer` to use when calculating tokens. |
Source code in griptape/drivers/embedding/base_embedding_driver.py
@define
class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base Embedding Driver.

    Attributes:
        model: The name of the model to use.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    chunker: Optional[BaseChunker] = field(init=False)

    def __attrs_post_init__(self) -> None:
        # A chunker only makes sense when a tokenizer exists to size the chunks,
        # so it stays None otherwise.
        self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer else None

    def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
        """Embed the text content of a `TextArtifact`."""
        return self.embed_string(artifact.to_text())

    def embed_string(self, string: str) -> list[float]:
        """Embed a string, retrying with exponential backoff on failure.

        Strings exceeding the tokenizer's input limit are embedded piecewise
        via `_embed_long_string`; otherwise the string is embedded directly.
        """
        for attempt in self.retrying():
            with attempt:
                tokenizer = self.tokenizer
                over_limit = (
                    tokenizer is not None
                    and tokenizer.count_tokens(string) > tokenizer.max_input_tokens
                )
                if over_limit:
                    return self._embed_long_string(string)
                return self.try_embed_chunk(string)
        else:
            # for/else: only reached if the retry loop exhausts without returning.
            raise RuntimeError("Failed to embed string.")

    @abstractmethod
    def try_embed_chunk(self, chunk: str) -> list[float]: ...

    def _embed_long_string(self, string: str) -> list[float]:
        """Embeds a string that is too long to embed in one go.

        Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
        """
        chunks = self.chunker.chunk(string)  # pyright: ignore[reportOptionalMemberAccess] In practice this is never None

        embeddings = [self.try_embed_chunk(chunk.value) for chunk in chunks]
        weights = [len(chunk) for chunk in chunks]

        # Weighted average of the per-chunk embeddings, then normalize to unit length.
        averaged = np.average(embeddings, axis=0, weights=weights)
        normalized = averaged / np.linalg.norm(averaged)
        return normalized.tolist()
|
`chunker: Optional[BaseChunker] = field(init=False)` — class-attribute, instance-attribute

`model: str = field(kw_only=True, metadata={'serializable': True})` — class-attribute, instance-attribute

`tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)` — class-attribute, instance-attribute
__attrs_post_init__()
Source code in griptape/drivers/embedding/base_embedding_driver.py
# attrs post-init hook: builds the chunker once the fields are set.
| def __attrs_post_init__(self) -> None:
# A chunker is only created when a tokenizer was supplied; otherwise it stays None.
self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer else None
|
embed_string(string)
Source code in griptape/drivers/embedding/base_embedding_driver.py
# Embed a string, retrying with exponential backoff via the mixin's retrying().
| def embed_string(self, string: str) -> list[float]:
for attempt in self.retrying():
with attempt:
# Strings over the tokenizer's input limit are embedded piecewise.
if self.tokenizer is not None and self.tokenizer.count_tokens(string) > self.tokenizer.max_input_tokens:
return self._embed_long_string(string)
else:
return self.try_embed_chunk(string)
# for/else: only reached if the retry loop exhausts without returning.
else:
raise RuntimeError("Failed to embed string.")
|
embed_text_artifact(artifact)
Source code in griptape/drivers/embedding/base_embedding_driver.py
# Convenience wrapper: embeds the text content of a TextArtifact.
| def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
return self.embed_string(artifact.to_text())
|
try_embed_chunk(chunk)
abstractmethod
Source code in griptape/drivers/embedding/base_embedding_driver.py
# Abstract hook: concrete drivers implement the actual embedding call here.
| @abstractmethod
def try_embed_chunk(self, chunk: str) -> list[float]: ...
|