Bases: BaseEmbeddingDriver
Cohere Embedding Driver.
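Attributes:
api_key: Cohere API key.
model: Cohere model name.
client: Custom `cohere.Client`.
tokenizer: Custom `CohereTokenizer`.
input_type: Cohere embedding input type.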
Source code in griptape/drivers/embedding/cohere_embedding_driver.py
@define
class CohereEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that calls the Cohere ``embed`` endpoint.

    Attributes:
        api_key: Cohere API key. Used to build the default client; excluded from serialization.
        model: Cohere model name, sent to the tokenizer and to every ``embed`` call.
        client: Custom `cohere.Client`. When omitted, the lazy ``client`` property builds one from ``api_key``.
        tokenizer: Custom `CohereTokenizer`. Defaults to one built from ``model`` and ``client``.
        input_type: Cohere embedding input type, forwarded to every ``embed`` call.
    """

    # NOTE(review): "models/embedding-001" follows Google's model-ID scheme, not Cohere's
    # (e.g. "embed-english-v3.0"). Confirm whether this default is intended.
    DEFAULT_MODEL = "models/embedding-001"

    api_key: str = field(kw_only=True, metadata={"serializable": False})
    input_type: str = field(kw_only=True, metadata={"serializable": True})
    # Accepted as ``client=``. Presumably the backing store for the ``client`` lazy property
    # below; lazy_property's caching behavior is defined elsewhere.
    _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
    tokenizer: CohereTokenizer = field(
        default=Factory(
            lambda driver: CohereTokenizer(model=driver.model, client=driver.client),
            takes_self=True,
        ),
        kw_only=True,
    )

    @lazy_property()
    def client(self) -> Client:
        """Build a ``cohere.Client`` authenticated with ``api_key``.

        ``cohere`` is loaded through ``import_optional_dependency``. Presumably this means the
        package is only needed once a client is actually built.
        """
        cohere = import_optional_dependency("cohere")
        return cohere.Client(self.api_key)

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single text chunk with Cohere.

        Args:
            chunk: Text to embed; sent as a one-element ``texts`` batch.

        Returns:
            The first vector in the response's ``embeddings`` list.

        Raises:
            ValueError: If ``embeddings`` is not a list. Per the error message, this is a
                non-float embedding response.
        """
        response = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)
        vectors = response.embeddings
        # Only a plain list of vectors is handled; any other payload shape is rejected.
        if not isinstance(vectors, list):
            raise ValueError("Non-float embeddings are not supported.")
        return vectors[0]
|
DEFAULT_MODEL = 'models/embedding-001'
class-attribute
instance-attribute
api_key: str = field(kw_only=True, metadata={'serializable': False})
class-attribute
instance-attribute
tokenizer: CohereTokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True)
class-attribute
instance-attribute
client()
Source code in griptape/drivers/embedding/cohere_embedding_driver.py
@lazy_property()
def client(self) -> Client:
    """Build a ``cohere.Client`` authenticated with ``api_key``.

    ``cohere`` is loaded through ``import_optional_dependency``. Presumably this means the
    package is only needed once a client is actually built.
    """
    cohere = import_optional_dependency("cohere")
    return cohere.Client(self.api_key)
|
try_embed_chunk(chunk)
Source code in griptape/drivers/embedding/cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single text chunk with Cohere.

    Args:
        chunk: Text to embed; sent as a one-element ``texts`` batch.

    Returns:
        The first vector in the response's ``embeddings`` list.

    Raises:
        ValueError: If ``embeddings`` is not a list. Per the error message, this is a
            non-float embedding response.
    """
    response = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)
    vectors = response.embeddings
    # Only a plain list of vectors is handled; any other payload shape is rejected.
    if not isinstance(vectors, list):
        raise ValueError("Non-float embeddings are not supported.")
    return vectors[0]
|