Skip to content

cohere

__all__ = ['CohereEmbeddingDriver'] module-attribute

CohereEmbeddingDriver

Bases: BaseEmbeddingDriver

Cohere Embedding Driver.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| api_key | str | Cohere API key. |
| model | str | Cohere model name. |
| client | Client | Custom cohere.Client. |
| tokenizer | CohereTokenizer | Custom CohereTokenizer. |
| input_type | str | Cohere embedding input type. |

Source code in griptape/drivers/embedding/cohere_embedding_driver.py
@define
class CohereEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that produces vectors through the Cohere API.

    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
        tokenizer: Custom `CohereTokenizer`.
        input_type: Cohere embedding input type.
    """

    # NOTE(review): this value does not look like a Cohere model id and is not
    # referenced anywhere inside this class -- confirm it is intentional.
    DEFAULT_MODEL = "models/embedding-001"

    # The API key is flagged as non-serializable in its field metadata.
    api_key: str = field(kw_only=True, metadata={"serializable": False})
    input_type: str = field(kw_only=True, metadata={"serializable": True})
    # Exposed to callers under the `client` alias; defaults to None.
    _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
    # The default tokenizer is built from this driver's own model and client.
    tokenizer: CohereTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
        kw_only=True,
    )

    @lazy_property()
    def client(self) -> Client:
        """Create a `cohere.Client` from `api_key`.

        The `cohere` package is obtained through `import_optional_dependency`.
        Presumably `lazy_property` only runs this when no client was supplied
        -- confirm against its implementation.
        """
        cohere = import_optional_dependency("cohere")
        return cohere.Client(self.api_key)

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single chunk of text with the Cohere client.

        Args:
            chunk: Text to embed; sent as a one-element `texts` list together
                with this driver's `model` and `input_type`.

        Returns:
            The first entry of the response's `embeddings` list.

        Raises:
            ValueError: If the response's `embeddings` is not a list.
        """
        response = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)
        embeddings = response.embeddings

        # Only the plain-list form of the response is handled here.
        if not isinstance(embeddings, list):
            raise ValueError("Non-float embeddings are not supported.")

        return embeddings[0]

DEFAULT_MODEL = 'models/embedding-001' class-attribute instance-attribute

api_key: str = field(kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

input_type: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: CohereTokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True) class-attribute instance-attribute

client()

Source code in griptape/drivers/embedding/cohere_embedding_driver.py
@lazy_property()
def client(self) -> Client:
    """Create a `cohere.Client` from `self.api_key`.

    The `cohere` package is obtained through `import_optional_dependency`.
    """
    cohere = import_optional_dependency("cohere")
    return cohere.Client(self.api_key)

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single chunk of text with `self.client`.

    Args:
        chunk: Text to embed; sent as a one-element `texts` list together
            with `self.model` and `self.input_type`.

    Returns:
        The first entry of the response's `embeddings` list.

    Raises:
        ValueError: If the response's `embeddings` is not a list.
    """
    response = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)
    embeddings = response.embeddings

    # Only the plain-list form of the response is handled here.
    if not isinstance(embeddings, list):
        raise ValueError("Non-float embeddings are not supported.")

    return embeddings[0]