Skip to content

local

__all__ = ['LocalRerankDriver'] module-attribute

LocalRerankDriver

Bases: BaseRerankDriver, FuturesExecutorMixin

Source code in griptape/drivers/rerank/local_rerank_driver.py
@define(kw_only=True)
class LocalRerankDriver(BaseRerankDriver, FuturesExecutorMixin):
    """Rerank text artifacts locally by embedding similarity to a query.

    The query and every artifact are embedded with ``embedding_driver``. Artifacts are
    then ordered by ``calculate_relatedness(query_embedding, artifact_embedding)``,
    highest score first.

    Attributes:
        calculate_relatedness: Callable taking two embeddings and returning a sortable
            score. Defaults to cosine similarity, ``dot(x, y) / (norm(x) * norm(y))``.
        embedding_driver: Driver used to embed both the query and the artifacts.
            Defaults to ``Defaults.drivers_config.embedding_driver``.
    """

    # Cosine similarity. NOTE(review): an all-zero embedding makes the denominator zero;
    # depending on what ``dot``/``norm`` return this may raise or produce NaN, and NaN
    # scores do not sort meaningfully. Confirm embeddings can never be all-zero.
    calculate_relatedness: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
    embedding_driver: BaseEmbeddingDriver = field(
        kw_only=True, default=Factory(lambda: Defaults.drivers_config.embedding_driver), metadata={"serializable": True}
    )

    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        """Return ``artifacts`` reordered from most to least related to ``query``.

        Args:
            query: Text the artifacts are ranked against.
            artifacts: Artifacts to rank. The input list is not modified.

        Returns:
            A new list with the same artifacts sorted by descending relatedness score.
            Artifacts with equal scores keep their input order, because ``list.sort``
            is stable even with ``reverse=True``.
        """
        # Nothing to rank: skip embedding the query, which may be a remote/billable call.
        if not artifacts:
            return []

        query_embedding = self.embedding_driver.embed(query)

        # Embed the artifacts concurrently. with_contextvars wraps the callable,
        # presumably so the caller's context variables are visible inside the executor.
        with self.create_futures_executor() as futures_executor:
            artifact_embeddings = execute_futures_list(
                [
                    futures_executor.submit(with_contextvars(self.embedding_driver.embed_text_artifact), a)
                    for a in artifacts
                ],
            )

        # The zip below pairs results with inputs by position. This assumes
        # execute_futures_list returns results in submission order.
        artifacts_and_relatednesses = [
            (artifact, self.calculate_relatedness(query_embedding, artifact_embedding))
            for artifact, artifact_embedding in zip(artifacts, artifact_embeddings)
        ]

        artifacts_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)

        return [artifact for artifact, _ in artifacts_and_relatednesses]

calculate_relatedness = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y))) class-attribute instance-attribute

embedding_driver = field(kw_only=True, default=Factory(lambda: Defaults.drivers_config.embedding_driver), metadata={'serializable': True}) class-attribute instance-attribute

run(query, artifacts)

Source code in griptape/drivers/rerank/local_rerank_driver.py
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
    """Return ``artifacts`` reordered from most to least related to ``query``.

    Args:
        query: Text the artifacts are ranked against.
        artifacts: Artifacts to rank. The input list is not modified.

    Returns:
        A new list with the same artifacts sorted by descending relatedness score.
        Artifacts with equal scores keep their input order, because ``list.sort``
        is stable even with ``reverse=True``.
    """
    # Nothing to rank: skip embedding the query, which may be a remote/billable call.
    if not artifacts:
        return []

    query_embedding = self.embedding_driver.embed(query)

    # Embed the artifacts concurrently. with_contextvars wraps the callable,
    # presumably so the caller's context variables are visible inside the executor.
    with self.create_futures_executor() as futures_executor:
        artifact_embeddings = execute_futures_list(
            [
                futures_executor.submit(with_contextvars(self.embedding_driver.embed_text_artifact), a)
                for a in artifacts
            ],
        )

    # The zip below pairs results with inputs by position. This assumes
    # execute_futures_list returns results in submission order.
    artifacts_and_relatednesses = [
        (artifact, self.calculate_relatedness(query_embedding, artifact_embedding))
        for artifact, artifact_embedding in zip(artifacts, artifact_embeddings)
    ]

    artifacts_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)

    return [artifact for artifact, _ in artifacts_and_relatednesses]