Bases: BaseEmbeddingDriver
Attributes:
| Name | Type | Description |
| --- | --- | --- |
| api_token | str | Hugging Face Hub API token. |
| model | str | Hugging Face Hub model name. |
| client | InferenceClient | |
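A minimal usage sketch follows, built only from the attributes listed above. It assumes that `api_token` and `model` are the only required keyword arguments (that is, that `BaseEmbeddingDriver` needs nothing else), and that the module path matches the source file named on this page. The environment variable `HUGGINGFACE_HUB_ACCESS_TOKEN` and the model name are illustrative assumptions, not values taken from this page.

```python
import os


def build_driver(api_token: str, model: str):
    """Construct the driver from its two documented keyword-only attributes."""
    # Imported lazily so this module still loads where griptape is not installed.
    from griptape.drivers.embedding.huggingface_hub_embedding_driver import (
        HuggingFaceHubEmbeddingDriver,
    )

    return HuggingFaceHubEmbeddingDriver(api_token=api_token, model=model)


def embed_example() -> list[float]:
    """Embed one short string. Needs network access and a valid token."""
    driver = build_driver(
        # Hypothetical environment variable name; use whatever holds your token.
        api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"],
        # Assumed model; substitute any Hub model that supports feature extraction.
        model="sentence-transformers/all-MiniLM-L6-v2",
    )
    return driver.try_embed_chunk("Hello, world!")
```

Calling `embed_example()` should return a flat list of floats. Its length depends on the chosen model.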
Source code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
    @define
    class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver):
        """
        Attributes:
            api_token: Hugging Face Hub API token.
            model: Hugging Face Hub model name.
            client: Custom `InferenceApi`.
        """

        api_token: str = field(kw_only=True, metadata={"serializable": True})
        client: InferenceClient = field(
            default=Factory(
                lambda self: import_optional_dependency("huggingface_hub").InferenceClient(
                    model=self.model, token=self.api_token
                ),
                takes_self=True,
            ),
            kw_only=True,
        )

        def try_embed_chunk(self, chunk: str) -> list[float]:
            response = self.client.feature_extraction(chunk)

            return response.flatten().tolist()
api_token: str = field(kw_only=True, metadata={'serializable': True})
class-attribute, instance-attribute

client: InferenceClient = field(default=Factory(lambda self: import_optional_dependency('huggingface_hub').InferenceClient(model=self.model, token=self.api_token), takes_self=True), kw_only=True)
class-attribute, instance-attribute
try_embed_chunk(chunk)
Source code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
    def try_embed_chunk(self, chunk: str) -> list[float]:
        response = self.client.feature_extraction(chunk)

        return response.flatten().tolist()
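To illustrate what `flatten().tolist()` does to the response, the sketch below runs the same two steps against a stub client that returns fixed NumPy arrays instead of calling the Hub. `StubClient` is a hypothetical helper, and the array shapes are assumptions chosen for illustration. Whether a given model returns one pooled vector or one vector per token is not stated on this page.

```python
import numpy as np


class StubClient:
    """Hypothetical stand-in that returns a fixed array instead of calling the Hub."""

    def __init__(self, output: np.ndarray) -> None:
        self.output = output

    def feature_extraction(self, text: str) -> np.ndarray:
        return self.output


def try_embed_chunk(client: StubClient, chunk: str) -> list[float]:
    # Same two steps as the method above, written as a free function.
    response = client.feature_extraction(chunk)
    return response.flatten().tolist()


# A (1, 3) response collapses to a plain three-element list.
pooled = try_embed_chunk(StubClient(np.array([[0.5, 1.0, 1.5]])), "hi")

# A (2, 3) response, e.g. one row per token, is concatenated into six values.
per_token = try_embed_chunk(StubClient(np.zeros((2, 3))), "hi")
```

If a model returns per-token rows, the flattened list is longer than the model's embedding dimension, so it may be worth checking the output length once for any new model.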