Skip to content

Huggingface pipeline prompt driver

HuggingFacePipelinePromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
params dict

Custom model run parameters.

model str

Hugging Face Hub model name.

tokenizer HuggingFaceTokenizer

Custom HuggingFaceTokenizer.

Source code in griptape/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@define
class HuggingFacePipelinePromptDriver(BasePromptDriver):
    """Prompt driver that runs a Hugging Face `transformers` pipeline locally.

    Attributes:
        params: Custom model run parameters.
        model: Hugging Face Hub model name.
        tokenizer: Custom `HuggingFaceTokenizer`.

    """

    SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
    DEFAULT_PARAMS = {"return_full_text": False, "num_return_sequences": 1}

    model: str = field(kw_only=True)
    params: dict = field(factory=dict, kw_only=True)
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            # takes_self=True lets the default tokenizer be derived from the
            # configured model name at construction time.
            lambda self: HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model)
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Generate a single completion for the prompt stack.

        Raises:
            Exception: if the pipeline's task is unsupported or the pipeline
                returns more than one sequence.
        """
        prompt = self.prompt_stack_to_string(prompt_stack)
        transformers = import_optional_dependency("transformers")

        # Budget max_new_tokens from whatever room the prompt leaves.
        generator = transformers.pipeline(
            tokenizer=self.tokenizer.tokenizer,
            model=self.model,
            max_new_tokens=self.tokenizer.count_tokens_left(prompt),
        )

        if generator.task not in self.SUPPORTED_TASKS:
            raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

        # Precedence (lowest to highest): defaults, pad token, user params.
        run_params = {
            **self.DEFAULT_PARAMS,
            "pad_token_id": self.tokenizer.tokenizer.eos_token_id,
            **self.params,
        }
        response = generator(prompt, **run_params)

        if len(response) != 1:
            raise Exception("completion with more than one choice is not supported yet")

        return TextArtifact(value=response[0]["generated_text"].strip())

    def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
        """Streaming is not implemented for local pipeline inference."""
        raise NotImplementedError("streaming is not supported")

DEFAULT_PARAMS = {'return_full_text': False, 'num_return_sequences': 1} class-attribute instance-attribute

SUPPORTED_TASKS = ['text2text-generation', 'text-generation'] class-attribute instance-attribute

model: str = field(kw_only=True) class-attribute instance-attribute

params: dict = field(factory=dict, kw_only=True) class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(tokenizer=import_optional_dependency('transformers').AutoTokenizer.from_pretrained(self.model)), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Generate a single completion for the prompt stack via a local pipeline.

    Raises:
        Exception: if the model's task is unsupported, or the pipeline
            returns more than one sequence.
    """
    prompt = self.prompt_stack_to_string(prompt_stack)
    pipeline = import_optional_dependency("transformers").pipeline

    # A fresh pipeline is built per call; max_new_tokens is budgeted from
    # the tokens remaining after the rendered prompt.
    generator = pipeline(
        tokenizer=self.tokenizer.tokenizer,
        model=self.model,
        max_new_tokens=self.tokenizer.count_tokens_left(prompt),
    )

    if generator.task in self.SUPPORTED_TASKS:
        # Reuse the EOS token id for padding; user-supplied self.params take
        # highest precedence in the merge below (rightmost wins with `|`).
        extra_params = {"pad_token_id": self.tokenizer.tokenizer.eos_token_id}

        response = generator(prompt, **(self.DEFAULT_PARAMS | extra_params | self.params))

        if len(response) == 1:
            return TextArtifact(value=response[0]["generated_text"].strip())
        else:
            raise Exception("completion with more than one choice is not supported yet")
    else:
        raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

try_stream(_)

Source code in griptape/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
    """Streaming is not implemented for local pipeline inference."""
    raise NotImplementedError("streaming is not supported")