Skip to content

huggingface_hub_prompt_driver

logger = logging.getLogger(Defaults.logging_config.logger_name) module-attribute

HuggingFaceHubPromptDriver

Bases: BasePromptDriver

Hugging Face Hub Prompt Driver.

Attributes:

Name Type Description
api_token str

Hugging Face Hub API token.

use_gpu str

Use GPU during model run.

model str

Hugging Face Hub model name.

client InferenceClient

Custom InferenceClient.

tokenizer HuggingFaceTokenizer

Custom HuggingFaceTokenizer.

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@define
class HuggingFaceHubPromptDriver(BasePromptDriver):
    """Prompt driver that generates text through a Hugging Face Hub `InferenceClient`.

    Attributes:
        api_token: Hugging Face Hub API token.
        max_tokens: Maximum number of new tokens to generate (sent as `max_new_tokens`).
        model: Hugging Face Hub model name.
        structured_output_strategy: Structured output strategy. `tool` is rejected by the validator.
        tokenizer: Custom `HuggingFaceTokenizer`. Defaults to one built from `model` and `max_tokens`.
        client: Custom `InferenceClient`. The `client` property builds one from `model` and `api_token`.
    """

    api_token: str = field(kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="native", kw_only=True, metadata={"serializable": True}
    )
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )
    _client: InferenceClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> InferenceClient:
        """Build an `InferenceClient` for `model`, authenticated with `api_token`."""
        return import_optional_dependency("huggingface_hub").InferenceClient(
            model=self.model,
            token=self.api_token,
        )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Reject the `tool` structured output strategy.

        Raises:
            ValueError: If `value` is `"tool"`.
        """
        if value == "tool":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Generate a single completion for the prompt stack.

        Args:
            prompt_stack: Conversation to render into a prompt string.

        Returns:
            An assistant `Message` holding the generated text and token usage
            counted with this driver's tokenizer.
        """
        prompt = self.prompt_stack_to_string(prompt_stack)
        full_params = self._base_params(prompt_stack)
        logger.debug(
            {
                "prompt": prompt,
                **full_params,
            }
        )

        response = self.client.text_generation(
            prompt,
            **full_params,
        )
        logger.debug(response)
        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
        output_tokens = len(self.tokenizer.tokenizer.encode(response))

        return Message(
            content=response,
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Stream a completion for the prompt stack.

        Args:
            prompt_stack: Conversation to render into a prompt string.

        Yields:
            One `DeltaMessage` per streamed text chunk, followed by a final
            `DeltaMessage` that carries only the token usage.
        """
        prompt = self.prompt_stack_to_string(prompt_stack)
        full_params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug(
            {
                "prompt": prompt,
                **full_params,
            }
        )

        response = self.client.text_generation(prompt, **full_params)

        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))

        # Collect chunks and join once at the end instead of repeated string concatenation.
        chunks: list[str] = []
        for token in response:
            logger.debug(token)
            chunks.append(token)
            yield DeltaMessage(content=TextDeltaMessageContent(token, index=0))

        # Output usage is counted over the whole generated text, not per chunk.
        output_tokens = len(self.tokenizer.tokenizer.encode("".join(chunks)))
        yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens))

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        """Render the prompt stack to the prompt string by decoding its chat-template tokens."""
        return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments passed to `text_generation`.

        `extra_params` is spread last, so it can override the defaults set here.
        A JSON grammar is added when the prompt stack has an output schema and
        the strategy is `native`.
        """
        params = {
            "return_full_text": False,
            "max_new_tokens": self.max_tokens,
            **self.extra_params,
        }

        if prompt_stack.output_schema and self.structured_output_strategy == "native":
            # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
            output_schema = prompt_stack.output_schema.json_schema("Output Schema")
            # Grammar does not support $schema and $id. Use pop with a default so a
            # schema that lacks either key does not raise KeyError.
            output_schema.pop("$schema", None)
            output_schema.pop("$id", None)
            params["grammar"] = {"type": "json", "value": output_schema}

        return params

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        """Convert each message into a `{"role", "content"}` dict.

        Raises:
            ValueError: If any message does not have exactly one content item.
        """
        messages = []
        for message in prompt_stack.messages:
            if len(message.content) != 1:
                raise ValueError("Invalid input content length.")
            messages.append({"role": message.role, "content": message.to_text()})

        return messages

    def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
        """Apply the tokenizer's chat template to the prompt stack and return the tokens.

        Raises:
            ValueError: If the chat template does not return a list.
        """
        messages = self._prompt_stack_to_messages(prompt_stack)
        tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

        if isinstance(tokens, list):
            return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].
        else:
            raise ValueError("Invalid output type.")

api_token: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

max_tokens: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structured_output_strategy: StructuredOutputStrategy = field(default='native', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

__prompt_stack_to_tokens(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
    """Apply the tokenizer's chat template to the prompt stack and return the tokens.

    Raises:
        ValueError: If the chat template does not return a list.
    """
    chat_messages = self._prompt_stack_to_messages(prompt_stack)
    token_ids = self.tokenizer.tokenizer.apply_chat_template(
        chat_messages,
        add_generation_prompt=True,
        tokenize=True,
    )

    # Guard clause: anything other than a list is treated as an error.
    if not isinstance(token_ids, list):
        raise ValueError("Invalid output type.")

    return token_ids  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].

client()

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@lazy_property()
def client(self) -> InferenceClient:
    """Build an `InferenceClient` for `model`, authenticated with `api_token`."""
    # huggingface_hub is an optional dependency, so it is resolved at call time.
    inference_client_cls = import_optional_dependency("huggingface_hub").InferenceClient
    return inference_client_cls(model=self.model, token=self.api_token)

prompt_stack_to_string(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    """Render the prompt stack to the prompt string by decoding its chat-template tokens."""
    decode = self.tokenizer.tokenizer.decode
    token_ids = self.__prompt_stack_to_tokens(prompt_stack)
    return decode(token_ids)

try_run(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Generate a single completion for the prompt stack.

    Returns:
        An assistant `Message` holding the generated text and token usage
        counted with this driver's tokenizer.
    """
    prompt_text = self.prompt_stack_to_string(prompt_stack)
    request_params = self._base_params(prompt_stack)
    logger.debug({"prompt": prompt_text, **request_params})

    generated = self.client.text_generation(prompt_text, **request_params)
    logger.debug(generated)

    # Usage is counted with the driver's tokenizer over the prompt and the generated text.
    prompt_token_count = len(self.__prompt_stack_to_tokens(prompt_stack))
    completion_token_count = len(self.tokenizer.tokenizer.encode(generated))
    usage = Message.Usage(input_tokens=prompt_token_count, output_tokens=completion_token_count)

    return Message(content=generated, role=Message.ASSISTANT_ROLE, usage=usage)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream a completion for the prompt stack.

    Yields:
        One `DeltaMessage` per streamed text chunk, followed by a final
        `DeltaMessage` that carries only the token usage.
    """
    prompt_text = self.prompt_stack_to_string(prompt_stack)
    request_params = {**self._base_params(prompt_stack), "stream": True}
    logger.debug({"prompt": prompt_text, **request_params})

    chunk_stream = self.client.text_generation(prompt_text, **request_params)

    prompt_token_count = len(self.__prompt_stack_to_tokens(prompt_stack))

    generated = ""
    for chunk in chunk_stream:
        logger.debug(chunk)
        generated += chunk
        yield DeltaMessage(content=TextDeltaMessageContent(chunk, index=0))

    # Output usage is counted over the whole generated text once streaming ends.
    completion_token_count = len(self.tokenizer.tokenizer.encode(generated))
    usage = DeltaMessage.Usage(input_tokens=prompt_token_count, output_tokens=completion_token_count)
    yield DeltaMessage(usage=usage)

validate_structured_output_strategy(_, value)

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Reject the `tool` structured output strategy.

    Raises:
        ValueError: If `value` is `"tool"`.
    """
    if value != "tool":
        return value

    raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")