
SageMaker Falcon Prompt Model Driver

SageMakerFalconPromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
@define
class SageMakerFalconPromptModelDriver(BasePromptModelDriver):
    DEFAULT_MAX_TOKENS = 600

    _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> HuggingFaceTokenizer:
        if self._tokenizer is None:
            self._tokenizer = HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained("tiiuae/falcon-40b"),
                max_output_tokens=self.max_tokens or self.DEFAULT_MAX_TOKENS,
            )
        return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
        return self.prompt_driver.prompt_stack_to_string(prompt_stack)

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_model_input(prompt_stack)
        stop_sequences = self.prompt_driver.tokenizer.stop_sequences

        return {
            "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
            "temperature": self.prompt_driver.temperature,
            "do_sample": True,
            "stop": stop_sequences,
        }

    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        if isinstance(output, list):
            return TextArtifact(output[0]["generated_text"].strip())
        else:
            raise ValueError("output must be an instance of 'list'")
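
A minimal usage sketch. It assumes a Falcon model already deployed to a SageMaker endpoint (the endpoint name below is a placeholder) and a griptape version in which AmazonSageMakerPromptDriver accepts a prompt_model_driver; the wrapping prompt driver handles invoking the endpoint, while this model driver shapes the request and parses the response.

from griptape.structures import Agent
from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver

agent = Agent(
    prompt_driver=AmazonSageMakerPromptDriver(
        model="falcon-40b-endpoint",  # placeholder SageMaker endpoint name
        prompt_model_driver=SageMakerFalconPromptModelDriver(),
    )
)

agent.run("What is the capital of France?")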

DEFAULT_MAX_TOKENS = 600 (class attribute / instance attribute)

tokenizer: HuggingFaceTokenizer (property)

process_output(output)

Source code in griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    if isinstance(output, list):
        return TextArtifact(output[0]["generated_text"].strip())
    else:
        raise ValueError("output must be an instance of 'list'")
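
A short sketch of the payload shape this method expects: Falcon containers on SageMaker typically return a JSON list with a single generated_text entry, which is stripped and wrapped in a TextArtifact; any other payload type raises ValueError. The response dictionary below is hypothetical, and the bare instantiation assumes the base class's other fields all have defaults.

driver = SageMakerFalconPromptModelDriver()

# Hypothetical deserialized endpoint response.
output = [{"generated_text": "  Paris is the capital of France.  "}]

artifact = driver.process_output(output)
print(artifact.value)  # "Paris is the capital of France."

# Non-list payloads (str or bytes) are rejected with a ValueError.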

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
    return self.prompt_driver.prompt_stack_to_string(prompt_stack)
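
A sketch of building a PromptStack and rendering it into the raw string sent to the endpoint. It assumes PromptStack is importable from griptape.utils and that driver is a SageMakerFalconPromptModelDriver whose prompt_driver has been set by the wrapping AmazonSageMakerPromptDriver, as in the earlier example; the exact string layout comes from that prompt driver's prompt_stack_to_string.

from griptape.utils import PromptStack

stack = PromptStack()
stack.add_system_input("You are a concise assistant.")
stack.add_user_input("Summarize Hamlet in one sentence.")

# Delegates to self.prompt_driver.prompt_stack_to_string(stack), so the
# resulting formatting is determined by the wrapping prompt driver.
model_input = driver.prompt_stack_to_model_input(stack)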

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_stack_to_model_input(prompt_stack)
    stop_sequences = self.prompt_driver.tokenizer.stop_sequences

    return {
        "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
        "temperature": self.prompt_driver.temperature,
        "do_sample": True,
        "stop": stop_sequences,
    }
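
A hedged sketch of the parameter dictionary produced for the same stack as above; the concrete values depend on the wrapping prompt driver's temperature, its tokenizer's stop sequences, and the remaining token budget for the rendered prompt.

params = driver.prompt_stack_to_model_params(stack)
# Roughly:
# {
#     "max_new_tokens": ...,   # tokens left in the context window for this prompt
#     "temperature": 0.1,      # the wrapping prompt driver's temperature (example value)
#     "do_sample": True,
#     "stop": [...],           # the prompt driver tokenizer's stop sequences
# }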