Bases: BasePromptModelDriver
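SageMakerLlamaPromptModelDriver adapts prompt stacks, model parameters, and endpoint responses for Llama 2 chat models hosted on Amazon SageMaker. A minimal usage sketch, assuming boto3 credentials are configured and that AmazonSageMakerPromptDriver attaches itself as this driver's `prompt_driver`; the endpoint environment variable below is a placeholder:

```python
import os

from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerLlamaPromptModelDriver

# SAGEMAKER_LLAMA_ENDPOINT is a hypothetical variable holding your SageMaker endpoint name.
prompt_driver = AmazonSageMakerPromptDriver(
    model=os.environ["SAGEMAKER_LLAMA_ENDPOINT"],
    prompt_model_driver=SageMakerLlamaPromptModelDriver(),
)
```

The prompt driver delegates request construction to prompt_stack_to_model_input and prompt_stack_to_model_params, and response parsing to process_output.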
Source code in griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py

```python
@define
class SageMakerLlamaPromptModelDriver(BasePromptModelDriver):
    DEFAULT_MAX_TOKENS = 600

    _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> HuggingFaceTokenizer:
        if self._tokenizer is None:
            self._tokenizer = HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").LlamaTokenizerFast.from_pretrained(
                    "hf-internal-testing/llama-tokenizer"
                ),
                max_output_tokens=self.max_tokens or self.DEFAULT_MAX_TOKENS,
            )

        return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list:
        return [[{"role": i.role, "content": i.content} for i in prompt_stack.inputs]]

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_driver.prompt_stack_to_string(prompt_stack)

        return {
            "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
            "temperature": self.prompt_driver.temperature,
        }

    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        if isinstance(output, list):
            return TextArtifact(output[0]["generation"]["content"].strip())
        else:
            raise ValueError("output must be an instance of 'list'")
```
DEFAULT_MAX_TOKENS = 600 (class-attribute, instance-attribute)

tokenizer: HuggingFaceTokenizer (property)
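A small sketch of the lazy tokenizer property, assuming the transformers optional dependency is installed; the hf-internal-testing/llama-tokenizer checkpoint is downloaded on first access:

```python
from griptape.drivers import SageMakerLlamaPromptModelDriver

driver = SageMakerLlamaPromptModelDriver()

tokenizer = driver.tokenizer           # first access builds LlamaTokenizerFast and caches it
assert tokenizer is driver.tokenizer   # later accesses reuse the cached instance

# With max_tokens unset, the tokenizer's output budget falls back to DEFAULT_MAX_TOKENS (600).
print(tokenizer.max_output_tokens)
```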
process_output(output)

Source code in griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py

```python
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    if isinstance(output, list):
        return TextArtifact(output[0]["generation"]["content"].strip())
    else:
        raise ValueError("output must be an instance of 'list'")
```
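A hedged sketch of what this parses; the response body below is a hypothetical example of the JSON a SageMaker-hosted Llama 2 chat endpoint returns, inferred from the indexing in the method rather than from the endpoint's own documentation:

```python
from griptape.drivers import SageMakerLlamaPromptModelDriver

driver = SageMakerLlamaPromptModelDriver()

# Hypothetical, already JSON-decoded endpoint response.
output = [{"generation": {"role": "assistant", "content": "  Hello! How can I help?  "}}]

artifact = driver.process_output(output)
print(artifact.value)  # "Hello! How can I help?" with surrounding whitespace stripped

# Any non-list payload (e.g. raw bytes) raises ValueError("output must be an instance of 'list'").
```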
prompt_stack_to_model_input(prompt_stack)

Source code in griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py

```python
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list:
    return [[{"role": i.role, "content": i.content} for i in prompt_stack.inputs]]
```
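A quick sketch of the conversion, assuming PromptStack is importable from griptape.utils and exposes add_system_input/add_user_input; each input becomes a role/content message, wrapped in a single-element batch:

```python
from griptape.drivers import SageMakerLlamaPromptModelDriver
from griptape.utils import PromptStack

stack = PromptStack()
stack.add_system_input("You are a helpful assistant.")
stack.add_user_input("What is Griptape?")

SageMakerLlamaPromptModelDriver().prompt_stack_to_model_input(stack)
# [[{"role": "system", "content": "You are a helpful assistant."},
#   {"role": "user", "content": "What is Griptape?"}]]
```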
prompt_stack_to_model_params(prompt_stack)

Source code in griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py

```python
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_driver.prompt_stack_to_string(prompt_stack)

    return {
        "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
        "temperature": self.prompt_driver.temperature,
    }
```
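A sketch of the resulting parameters, assuming the model driver is attached to an AmazonSageMakerPromptDriver (which sets its prompt_driver field); max_new_tokens is whatever output budget the parent prompt driver computes for the rendered prompt, and temperature is passed through unchanged:

```python
import os

from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerLlamaPromptModelDriver
from griptape.utils import PromptStack

stack = PromptStack()
stack.add_user_input("Summarize the plot of Hamlet in one sentence.")

# Hypothetical endpoint name in SAGEMAKER_LLAMA_ENDPOINT; the prompt driver wires itself
# into the model driver's prompt_driver field.
prompt_driver = AmazonSageMakerPromptDriver(
    model=os.environ["SAGEMAKER_LLAMA_ENDPOINT"],
    prompt_model_driver=SageMakerLlamaPromptModelDriver(),
    temperature=0.2,
)

params = prompt_driver.prompt_model_driver.prompt_stack_to_model_params(stack)
# e.g. {"max_new_tokens": <output budget computed by the prompt driver>, "temperature": 0.2}
```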