
HuggingFaceHubPromptDriver

Bases: BasePromptDriver

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `api_token` | `str` | Hugging Face Hub API token. |
| `use_gpu` | `bool` | Use GPU during model run. |
| `params` | `dict` | Custom model run parameters. |
| `model` | `str` | Hugging Face Hub model name. |
| `client` | `InferenceApi` | Custom `InferenceApi`. |
| `tokenizer` | `HuggingFaceTokenizer` | Custom `HuggingFaceTokenizer`. |
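
Example usage, as a minimal sketch: the model name, token, and prompt are placeholders, and the `add_user_input` helper on `PromptStack` is assumed here and may differ by Griptape version:

driver = HuggingFaceHubPromptDriver(
    model="google/flan-t5-base",  # placeholder: any text2text-generation or text-generation model
    api_token="hf_...",           # placeholder: your Hugging Face Hub API token
)

prompt_stack = PromptStack()
prompt_stack.add_user_input("Write a haiku about the sea.")  # assumed PromptStack helper

result = driver.try_run(prompt_stack)  # returns a TextArtifact
print(result.value)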

Source code in griptape/griptape/drivers/prompt/hugging_face_hub_prompt_driver.py
@define
class HuggingFaceHubPromptDriver(BasePromptDriver):
    """
    Attributes:
        api_token: Hugging Face Hub API token.
        use_gpu: Use GPU during model run.
        params: Custom model run parameters.
        model: Hugging Face Hub model name.
        client: Custom `InferenceApi`.
        tokenizer: Custom `HuggingFaceTokenizer`.

    """

    SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
    MAX_NEW_TOKENS = 250
    DEFAULT_PARAMS = {"return_full_text": False, "max_new_tokens": MAX_NEW_TOKENS}

    api_token: str = field(kw_only=True)
    use_gpu: bool = field(default=False, kw_only=True)
    params: dict = field(factory=dict, kw_only=True)
    model: str = field(kw_only=True)
    client: InferenceApi = field(
        default=Factory(
            lambda self: import_optional_dependency("huggingface_hub").InferenceApi(
                repo_id=self.model, token=self.api_token, gpu=self.use_gpu
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
                max_tokens=self.MAX_NEW_TOKENS,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    stream: bool = field(default=False, kw_only=True)

    @stream.validator
    def validate_stream(self, _, stream):
        if stream:
            raise ValueError("streaming is not supported")

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        prompt = self.prompt_stack_to_string(prompt_stack)

        if self.client.task in self.SUPPORTED_TASKS:
            response = self.client(inputs=prompt, params=self.DEFAULT_PARAMS | self.params)

            if len(response) == 1:
                return TextArtifact(value=response[0]["generated_text"].strip())
            else:
                raise Exception("completion with more than one choice is not supported yet")
        else:
            raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

    def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
        raise NotImplementedError("streaming is not supported")
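
Note that `try_run` merges run parameters with the dict union operator (`DEFAULT_PARAMS | self.params`), so caller-supplied `params` take precedence over the defaults. A minimal sketch (model name and token are placeholders):

driver = HuggingFaceHubPromptDriver(
    model="gpt2",                    # placeholder model name
    api_token="hf_...",              # placeholder API token
    params={"max_new_tokens": 100},  # overrides DEFAULT_PARAMS["max_new_tokens"]
)

# Effective params sent to the Inference API:
# {"return_full_text": False, "max_new_tokens": 100}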

DEFAULT_PARAMS = {'return_full_text': False, 'max_new_tokens': MAX_NEW_TOKENS}

MAX_NEW_TOKENS = 250

SUPPORTED_TASKS = ['text2text-generation', 'text-generation']

api_token: str = field(kw_only=True)

client: InferenceApi = field(default=Factory(lambda self: import_optional_dependency('huggingface_hub').InferenceApi(repo_id=self.model, token=self.api_token, gpu=self.use_gpu), takes_self=True), kw_only=True)

model: str = field(kw_only=True)

params: dict = field(factory=dict, kw_only=True)

stream: bool = field(default=False, kw_only=True)

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(tokenizer=import_optional_dependency('transformers').AutoTokenizer.from_pretrained(self.model), max_tokens=self.MAX_NEW_TOKENS), takes_self=True), kw_only=True)

use_gpu: bool = field(default=False, kw_only=True)
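
Because `client` and `tokenizer` default to `Factory(..., takes_self=True)`, they are built from `model` and `api_token` at construction time; either can be replaced with a pre-built instance. A sketch (repo id and token are placeholders):

from huggingface_hub import InferenceApi

custom_client = InferenceApi(repo_id="gpt2", token="hf_...", gpu=False)

driver = HuggingFaceHubPromptDriver(
    model="gpt2",
    api_token="hf_...",
    client=custom_client,  # bypasses the Factory-built default
)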

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/hugging_face_hub_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    prompt = self.prompt_stack_to_string(prompt_stack)

    if self.client.task in self.SUPPORTED_TASKS:
        response = self.client(inputs=prompt, params=self.DEFAULT_PARAMS | self.params)

        if len(response) == 1:
            return TextArtifact(value=response[0]["generated_text"].strip())
        else:
            raise Exception("completion with more than one choice is not supported yet")
    else:
        raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

try_stream(_)

Source code in griptape/griptape/drivers/prompt/hugging_face_hub_prompt_driver.py
def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
    raise NotImplementedError("streaming is not supported")

validate_stream(_, stream)

Source code in griptape/griptape/drivers/prompt/hugging_face_hub_prompt_driver.py
@stream.validator
def validate_stream(self, _, stream):
    if stream:
        raise ValueError("streaming is not supported")