Skip to content

huggingface_pipeline

__all__ = ['HuggingFacePipelinePromptDriver'] module-attribute

HuggingFacePipelinePromptDriver

Bases: BasePromptDriver

Hugging Face Pipeline Prompt Driver.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| model | str | Hugging Face Hub model name. |

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@define
class HuggingFacePipelinePromptDriver(BasePromptDriver):
    """Prompt Driver backed by a Hugging Face `text-generation` pipeline.

    Attributes:
        max_tokens: Value passed to the pipeline as `max_new_tokens`. Defaults to 250.
        model: Hugging Face Hub model name.
        tokenizer: Tokenizer wrapper. When not supplied it is built from `model`, with
            `max_tokens` as its maximum number of output tokens.
        structured_output_strategy: Structured output strategy. Defaults to `rule`;
            `native` and `tool` are rejected by the validator.
    """

    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )
    structured_output_strategy: StructuredOutputStrategy = field(
        default="rule",
        kw_only=True,
        metadata={"serializable": True},
    )
    # Backing field for `pipeline`; accepted by the initializer under the `pipeline` alias.
    _pipeline: TextGenerationPipeline = field(
        default=None,
        kw_only=True,
        alias="pipeline",
        metadata={"serializable": False},
    )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Reject the `native` and `tool` strategies; return any other value unchanged.

        Raises:
            ValueError: If `value` is `native` or `tool`.
        """
        if value not in ("native", "tool"):
            return value

        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

    @lazy_property()
    def pipeline(self) -> TextGenerationPipeline:
        """Create a `text-generation` pipeline for `model` using the driver's tokenizer."""
        transformers = import_optional_dependency("transformers")

        return transformers.pipeline(
            task="text-generation",
            model=self.model,
            max_new_tokens=self.max_tokens,
            tokenizer=self.tokenizer.tokenizer,
        )

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Run the pipeline on the prompt stack and wrap the generated text in a Message.

        Args:
            prompt_stack: Prompt stack whose messages are sent to the pipeline.

        Returns:
            An assistant Message holding the generated text and its token usage.

        Raises:
            Exception: If the pipeline result is not a list, or does not hold exactly one entry.
        """
        chat_messages = self._prompt_stack_to_messages(prompt_stack)
        call_params = self._base_params(prompt_stack)
        logger.debug((chat_messages, call_params))

        result = self.pipeline(chat_messages, **call_params)
        logger.debug(result)

        if not isinstance(result, list):
            raise Exception("invalid output format")
        if len(result) != 1:
            raise Exception("completion with more than one choice is not supported yet")

        # The generated conversation's last entry carries the text that is returned.
        generated_text = result[0]["generated_text"][-1]["content"]

        # Usage is measured locally: the templated prompt for input, the encoded reply for output.
        prompt_token_count = len(self.__prompt_stack_to_tokens(prompt_stack))
        reply_token_count = len(self.tokenizer.tokenizer.encode(generated_text))

        content = [TextMessageContent(TextArtifact(generated_text))]
        usage = Message.Usage(input_tokens=prompt_token_count, output_tokens=reply_token_count)

        return Message(content=content, role=Message.ASSISTANT_ROLE, usage=usage)

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Always raise: this driver does not implement streaming.

        Raises:
            NotImplementedError: On every call.
        """
        reason = "streaming is not supported"
        raise NotImplementedError(reason)

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        """Decode the chat-templated token ids of the prompt stack back into a string."""
        decode = self.tokenizer.tokenizer.decode
        token_ids = self.__prompt_stack_to_tokens(prompt_stack)

        return decode(token_ids)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments for the pipeline call.

        Entries in `extra_params` are merged last, so they override the defaults set here.
        """
        defaults = {
            "max_new_tokens": self.max_tokens,
            "temperature": self.temperature,
            "do_sample": True,
        }

        return {**defaults, **self.extra_params}

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        """Convert each prompt stack message into a `{"role", "content"}` dict."""
        return [{"role": entry.role, "content": entry.to_text()} for entry in prompt_stack.messages]

    def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
        """Apply the tokenizer's chat template to the prompt stack and return the tokens.

        Raises:
            ValueError: If the chat template does not return a list.
        """
        chat_messages = self._prompt_stack_to_messages(prompt_stack)
        tokens = self.tokenizer.tokenizer.apply_chat_template(
            chat_messages,
            add_generation_prompt=True,
            tokenize=True,
        )

        if not isinstance(tokens, list):
            raise ValueError("Invalid output type.")

        return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].

max_tokens: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structured_output_strategy: StructuredOutputStrategy = field(default='rule', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

__prompt_stack_to_tokens(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
    """Apply the tokenizer's chat template to the prompt stack and return the tokens.

    Raises:
        ValueError: If the chat template does not return a list.
    """
    chat_messages = self._prompt_stack_to_messages(prompt_stack)
    tokens = self.tokenizer.tokenizer.apply_chat_template(
        chat_messages,
        add_generation_prompt=True,
        tokenize=True,
    )

    if not isinstance(tokens, list):
        raise ValueError("Invalid output type.")

    return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].

pipeline()

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@lazy_property()
def pipeline(self) -> TextGenerationPipeline:
    """Create a `text-generation` pipeline for `model` using the driver's tokenizer."""
    transformers = import_optional_dependency("transformers")

    return transformers.pipeline(
        task="text-generation",
        model=self.model,
        max_new_tokens=self.max_tokens,
        tokenizer=self.tokenizer.tokenizer,
    )

prompt_stack_to_string(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    """Decode the token ids produced for the prompt stack back into a string."""
    decode = self.tokenizer.tokenizer.decode
    token_ids = self.__prompt_stack_to_tokens(prompt_stack)

    return decode(token_ids)

try_run(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Run the pipeline on the prompt stack and wrap the generated text in a Message.

    Args:
        prompt_stack: Prompt stack whose messages are sent to the pipeline.

    Returns:
        An assistant Message holding the generated text and its token usage.

    Raises:
        Exception: If the pipeline result is not a list, or does not hold exactly one entry.
    """
    chat_messages = self._prompt_stack_to_messages(prompt_stack)
    call_params = self._base_params(prompt_stack)
    logger.debug((chat_messages, call_params))

    result = self.pipeline(chat_messages, **call_params)
    logger.debug(result)

    if not isinstance(result, list):
        raise Exception("invalid output format")
    if len(result) != 1:
        raise Exception("completion with more than one choice is not supported yet")

    # The generated conversation's last entry carries the text that is returned.
    generated_text = result[0]["generated_text"][-1]["content"]

    # Usage is measured locally: the prompt tokens for input, the encoded reply for output.
    prompt_token_count = len(self.__prompt_stack_to_tokens(prompt_stack))
    reply_token_count = len(self.tokenizer.tokenizer.encode(generated_text))

    content = [TextMessageContent(TextArtifact(generated_text))]
    usage = Message.Usage(input_tokens=prompt_token_count, output_tokens=reply_token_count)

    return Message(content=content, role=Message.ASSISTANT_ROLE, usage=usage)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Always raise: this driver does not implement streaming.

    Raises:
        NotImplementedError: On every call.
    """
    reason = "streaming is not supported"
    raise NotImplementedError(reason)

validate_structured_output_strategy(_, value)

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Reject the `native` and `tool` strategies; return any other value unchanged.

    Raises:
        ValueError: If `value` is `native` or `tool`.
    """
    if value not in ("native", "tool"):
        return value

    raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")