Skip to content

Hugging Face Hub embedding driver

HuggingFaceHubEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

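| Name | Type | Description |
|------|------|-------------|
| `api_token` | `str` | Hugging Face Hub API token. |
| `model` | `str` | Hugging Face Hub model name. |
| `client` | `InferenceClient` | Custom `InferenceClient`; by default built from `model` and `api_token`. |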

Source code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
@define
class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that calls a Hugging Face Hub model via `InferenceClient`.

    Attributes:
        api_token: Hugging Face Hub API token.
        model: Hugging Face Hub model name.
        client: Custom `InferenceClient`. When not supplied, one is created from
            `huggingface_hub.InferenceClient(model=model, token=api_token)`.
    """

    api_token: str = field(kw_only=True, metadata={"serializable": True})
    # The default factory receives the instance (takes_self=True) so it can read
    # `model` and `api_token` when building the client.
    # NOTE(review): `model` is not declared in this class; presumably it is an
    # attrs field on BaseEmbeddingDriver — confirm.
    client: InferenceClient = field(
        default=Factory(
            lambda self: import_optional_dependency("huggingface_hub").InferenceClient(
                model=self.model, token=self.api_token
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single chunk of text.

        Args:
            chunk: The text to embed.

        Returns:
            The result of `client.feature_extraction(chunk)`, flattened and
            converted to a plain Python list.
        """
        response = self.client.feature_extraction(chunk)

        # `flatten()` collapses whatever shape the response has into one dimension.
        # It is presumably a numpy array — TODO confirm against the huggingface_hub
        # version in use.
        return response.flatten().tolist()

api_token: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: InferenceClient = field(default=Factory(lambda self: import_optional_dependency('huggingface_hub').InferenceClient(model=self.model, token=self.api_token), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Return a flat list of floats for ``chunk``.

    The text is sent to ``self.client.feature_extraction``. The resulting
    array-like object is flattened to one dimension and converted to a plain list.
    """
    return self.client.feature_extraction(chunk).flatten().tolist()