Bases: BaseEmbeddingDriver
Hugging Face Hub Embedding Driver.
Attributes:
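| Name | Type | Description |
| --- | --- | --- |
| api_token | str | Hugging Face Hub API token. |
| model | str | Hugging Face Hub model name. |
| client | InferenceClient | Custom `InferenceClient`; if omitted, one is created lazily from `model` and `api_token`. |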
Source code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
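```python
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from attrs import define, field
from griptape.drivers import BaseEmbeddingDriver
from griptape.utils import import_optional_dependency
from griptape.utils.decorators import lazy_property
if TYPE_CHECKING:
    from huggingface_hub import InferenceClient

@define
class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver):
    """Hugging Face Hub Embedding Driver.
    Attributes:
        api_token: Hugging Face Hub API token.
        model: Hugging Face Hub model name.
        client: Custom `InferenceClient`.
    """
    api_token: str = field(kw_only=True, metadata={"serializable": True})
    _client: Optional[InferenceClient] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> InferenceClient:
        # huggingface_hub is imported only when the client is first accessed.
        return import_optional_dependency("huggingface_hub").InferenceClient(
            model=self.model,  # `model` is declared on BaseEmbeddingDriver, not here
            token=self.api_token,
        )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        # feature_extraction returns a NumPy array; flatten it into a plain list of floats.
        response = self.client.feature_extraction(chunk)
        return response.flatten().tolist()
```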
api_token: str = field(kw_only=True, metadata={'serializable': True})
(class-attribute, instance-attribute)
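The `metadata={"serializable": ...}` flags on the driver's fields presumably tell griptape's serialization which attributes to persist. One consequence worth noticing: `api_token` is flagged serializable, so a serialized driver would likely carry the token in plain text, while `client` is excluded. The sketch below imitates that filtering with the standard library's `dataclasses`; `DriverFieldsSketch` and `serializable_view` are hypothetical names, not griptape APIs, and treating `model` as serializable is an assumption, since it is declared on `BaseEmbeddingDriver` rather than shown here.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Optional


@dataclass
class DriverFieldsSketch:
    # Hypothetical stand-in mirroring the driver's fields and their metadata.
    # This uses plain dataclasses, not attrs or griptape.
    api_token: str = field(metadata={"serializable": True})
    # Assumption: `model` (inherited in the real driver) is also serializable.
    model: str = field(metadata={"serializable": True})
    client: Optional[Any] = field(default=None, metadata={"serializable": False})


def serializable_view(obj: Any) -> dict[str, Any]:
    # Keep only the fields explicitly flagged as serializable.
    return {
        f.name: getattr(obj, f.name)
        for f in fields(obj)
        if f.metadata.get("serializable", False)
    }


driver = DriverFieldsSketch(api_token="hf_xxx", model="some-model", client=object())
print(serializable_view(driver))  # {'api_token': 'hf_xxx', 'model': 'some-model'}
```

If persisting the token is undesirable, it is worth checking how griptape actually serializes drivers before storing serialized output.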
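client()
Lazily creates the Hugging Face Hub `InferenceClient` for `model`, authenticated with `api_token`. If a client was passed in through the `client` argument, that instance is used instead.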
Source code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
```python
@lazy_property()
def client(self) -> InferenceClient:
    return import_optional_dependency("huggingface_hub").InferenceClient(
        model=self.model,
        token=self.api_token,
    )
```
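The lazy-client behavior can be sketched without griptape or network access. `lazy_attr`, `FakeInferenceClient`, and `SketchDriver` are hypothetical names: `lazy_attr` imitates what `lazy_property` appears to do (build the value on first access and cache it in `_client`), but it is not griptape's implementation. Because the import happens inside the factory, `huggingface_hub` is only needed once `client` is first read.

```python
import functools
from typing import Any, Callable


def lazy_attr(func: Callable[[Any], Any]) -> property:
    # Hypothetical stand-in for lazy_property: build on first access,
    # cache in `_<name>`, and reuse the cached value afterwards.
    attr = f"_{func.__name__}"

    @functools.wraps(func)
    def getter(self: Any) -> Any:
        if getattr(self, attr, None) is None:
            setattr(self, attr, func(self))
        return getattr(self, attr)

    return property(getter)


class FakeInferenceClient:
    # Hypothetical offline stub standing in for huggingface_hub.InferenceClient.
    def __init__(self, model: str, token: str) -> None:
        self.model = model
        self.token = token


class SketchDriver:
    def __init__(self, model: str, api_token: str, client: Any = None) -> None:
        self.model = model
        self.api_token = api_token
        self._client = client  # mirrors field(default=None, alias="client")
        self.built = 0  # how many times the factory ran

    @lazy_attr
    def client(self) -> FakeInferenceClient:
        self.built += 1
        return FakeInferenceClient(model=self.model, token=self.api_token)


lazy = SketchDriver("some-model", "hf_xxx")
print(lazy.client is lazy.client, lazy.built)  # True 1
```

Passing a stub through the `client` argument skips the factory entirely, which is also a convenient way to unit-test code that uses the driver without calling the Hub.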
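try_embed_chunk(chunk)
Embeds one text chunk with the Inference API `feature_extraction` task and returns the response flattened into a list of floats.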
Source code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
```python
def try_embed_chunk(self, chunk: str) -> list[float]:
    response = self.client.feature_extraction(chunk)
    return response.flatten().tolist()
```
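`feature_extraction` returns a NumPy array whose shape depends on the model: sentence-embedding models typically return one pooled vector, while some models can return one vector per token. Because `flatten()` collapses every axis, a token-level response becomes all token vectors concatenated, so the length varies with the input. The arrays below are made-up stand-ins, not real API responses, and `mean_pool` is a hypothetical alternative shown only for contrast; the driver does not pool.

```python
import numpy as np


def flatten_like_driver(response: np.ndarray) -> list[float]:
    # Same post-processing as try_embed_chunk: collapse all axes into one list.
    return response.flatten().tolist()


def mean_pool(response: np.ndarray) -> list[float]:
    # Alternative, NOT what the driver does: average per-token vectors into one.
    return response.reshape(-1, response.shape[-1]).mean(axis=0).tolist()


# Made-up stand-ins for feature_extraction output; real shapes depend on the model.
pooled = np.array([[0.25, 0.5, 0.75]], dtype=np.float32)        # (1, 3): one vector
token_level = np.arange(6, dtype=np.float32).reshape(1, 2, 3)  # (1, 2, 3): 2 tokens x 3 dims

print(len(flatten_like_driver(pooled)))       # 3
print(len(flatten_like_driver(token_level)))  # 6: token vectors concatenated
print(mean_pool(token_level))                 # [1.5, 2.5, 3.5]
```

Choosing a model that returns pooled sentence embeddings keeps the driver's output at a fixed dimension regardless of chunk length.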