Skip to content

cohere_rerank_driver

CohereRerankDriver

Bases: BaseRerankDriver

Source code in griptape/drivers/rerank/cohere_rerank_driver.py
@define(kw_only=True)
class CohereRerankDriver(BaseRerankDriver):
    """Rerank driver that delegates to the Cohere ``rerank`` endpoint.

    Attributes:
        model: Cohere rerank model name sent with every request.
        top_n: Maximum number of results to request; forwarded to Cohere unchanged
            (``None`` by default).
        api_key: Cohere API key used to build the default ``client``.
        client: Cohere client; defaults to ``cohere.Client(api_key)`` created via the
            optional ``cohere`` dependency.
    """

    model: str = field(default="rerank-english-v3.0", metadata={"serializable": True})
    top_n: Optional[int] = field(default=None, metadata={"serializable": True})

    # The API key is a secret: keep it out of serialized driver configurations.
    api_key: str = field(metadata={"serializable": False})
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
    )

    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        """Rerank ``artifacts`` by relevance to ``query`` using Cohere.

        Falsy artifacts (presumably ones with empty/blank text) are skipped, and
        artifacts with identical text are collapsed into one entry (the last
        occurrence is kept).

        Args:
            query: Text the artifacts are ranked against.
            artifacts: Candidate artifacts to rerank.

        Returns:
            The surviving artifacts in the order Cohere returns them (at most
            ``top_n`` when set), or an empty list when no artifacts remain.
        """
        # Cohere errors out if passed "empty" documents or no documents at all.
        # Keying by the text itself de-duplicates identical documents without the
        # collision risk of keying by ``hash(text)``.
        unique = {a.to_text(): a for a in artifacts if a}
        if not unique:
            return []

        candidates = list(unique.values())
        response = self.client.rerank(
            model=self.model,
            query=query,
            documents=list(unique.keys()),
            return_documents=False,
            top_n=self.top_n,
        )
        # Map results back by their position in ``documents`` rather than by echoed
        # text, so the lookup cannot miss if the returned text differs from what was sent.
        return [candidates[r.index] for r in response.results]

api_key: str = field(metadata={'serializable': True}) — class attribute, instance attribute

client: Client = field(default=Factory(lambda self: import_optional_dependency('cohere').Client(self.api_key), takes_self=True)) — class attribute, instance attribute

model: str = field(default='rerank-english-v3.0', metadata={'serializable': True}) — class attribute, instance attribute

top_n: Optional[int] = field(default=None) — class attribute, instance attribute

run(query, artifacts)

Source code in griptape/drivers/rerank/cohere_rerank_driver.py
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
    """Rerank ``artifacts`` by relevance to ``query`` using Cohere.

    Falsy artifacts (presumably ones with empty/blank text) are skipped, and
    artifacts with identical text are collapsed into one entry (the last
    occurrence is kept).

    Args:
        query: Text the artifacts are ranked against.
        artifacts: Candidate artifacts to rerank.

    Returns:
        The surviving artifacts in the order Cohere returns them (at most
        ``self.top_n`` when set), or an empty list when no artifacts remain.
    """
    # Cohere errors out if passed "empty" documents or no documents at all.
    # Keying by the text itself de-duplicates identical documents without the
    # collision risk of keying by ``hash(text)``.
    unique = {a.to_text(): a for a in artifacts if a}
    if not unique:
        return []

    candidates = list(unique.values())
    response = self.client.rerank(
        model=self.model,
        query=query,
        documents=list(unique.keys()),
        return_documents=False,
        top_n=self.top_n,
    )
    # Map results back by their position in ``documents`` rather than by echoed
    # text, so the lookup cannot miss if the returned text differs from what was sent.
    return [candidates[r.index] for r in response.results]