SageMaker Falcon Prompt Model Driver

SageMakerFalconPromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
@define
class SageMakerFalconPromptModelDriver(BasePromptModelDriver):
    tokenizer: BaseTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(
                    "tiiuae/falcon-40b", model_max_length=self.max_tokens
                )
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
        return self.prompt_driver.prompt_stack_to_string(prompt_stack)

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_model_input(prompt_stack)
        stop_sequences = self.prompt_driver.tokenizer.stop_sequences

        return {
            "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
            "temperature": self.prompt_driver.temperature,
            "do_sample": True,
            "stop": stop_sequences,
        }

    def process_output(self, output: list[dict]) -> TextArtifact:
        return TextArtifact(output[0]["generated_text"].strip())
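
This driver is not used on its own; it is passed to an AmazonSageMakerPromptDriver, which invokes the SageMaker endpoint and delegates payload construction to this class. A minimal usage sketch (the endpoint name is a placeholder, and the exact constructor arguments may vary across griptape versions):

from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver
from griptape.structures import Agent

agent = Agent(
    prompt_driver=AmazonSageMakerPromptDriver(
        model="my-falcon-endpoint",  # hypothetical SageMaker endpoint name
        prompt_model_driver=SageMakerFalconPromptModelDriver(),
    )
)
agent.run("What is the capital of France?")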

tokenizer: BaseTokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(tokenizer=import_optional_dependency('transformers').AutoTokenizer.from_pretrained('tiiuae/falcon-40b', model_max_length=self.max_tokens)), takes_self=True), kw_only=True)

The default tokenizer is a HuggingFaceTokenizer built from the tiiuae/falcon-40b vocabulary, with model_max_length capped at the driver's max_tokens.

process_output(output)

Source code in griptape/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
def process_output(self, output: list[dict]) -> TextArtifact:
    return TextArtifact(output[0]["generated_text"].strip())
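
The SageMaker Falcon container returns a JSON array with one object per generated sequence. process_output takes the first element's generated_text field, strips surrounding whitespace, and wraps it in a TextArtifact. A sketch of the equivalent transformation (the payload below is illustrative, not captured from a real endpoint):

from griptape.artifacts import TextArtifact

raw_output = [{"generated_text": "  Paris is the capital of France.  "}]

# Equivalent to driver.process_output(raw_output):
artifact = TextArtifact(raw_output[0]["generated_text"].strip())
print(artifact.value)  # Paris is the capital of France.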

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
    return self.prompt_driver.prompt_stack_to_string(prompt_stack)
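
Falcon endpoints take a single plain-text prompt, so this method flattens the PromptStack into one string via the prompt driver's prompt_stack_to_string. A sketch of building a stack to pass in (the rendering in the comment is an assumption about the prompt driver's default formatting, and driver is a configured instance as in the usage sketch above):

from griptape.utils import PromptStack

stack = PromptStack()
stack.add_system_input("You are a helpful assistant.")
stack.add_user_input("What is the capital of France?")

# The prompt driver renders the stack as a single role-labeled string, roughly:
#   You are a helpful assistant.
#   User: What is the capital of France?
#   Assistant:
prompt = driver.prompt_stack_to_model_input(stack)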

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_stack_to_model_input(prompt_stack)
    stop_sequences = self.prompt_driver.tokenizer.stop_sequences

    return {
        "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
        "temperature": self.prompt_driver.temperature,
        "do_sample": True,
        "stop": stop_sequences,
    }
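
The returned dictionary maps onto the Hugging Face text-generation parameters that the SageMaker Falcon container accepts: max_new_tokens is the token budget left after counting the prompt against the driver's limit, temperature is copied from the prompt driver, sampling is always enabled, and stop carries the prompt driver tokenizer's stop sequences. An illustrative payload (all values are examples, not defaults pulled from the source):

{
    "max_new_tokens": 512,        # budget remaining after the prompt is tokenized
    "temperature": 0.7,           # taken from the prompt driver
    "do_sample": True,            # sampling must be on for temperature to take effect
    "stop": ["Observation:"],     # stop sequences from the prompt driver's tokenizer
}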