amazon_sagemaker_jumpstart

__all__ = ['AmazonSageMakerJumpstartPromptDriver'] module-attribute

AmazonSageMakerJumpstartPromptDriver

Bases: BasePromptDriver
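
Prompt driver for models deployed behind an Amazon SageMaker JumpStart inference endpoint. Prompts are rendered through a Hugging Face chat template, sent synchronously via the SageMaker runtime `invoke_endpoint` API, and token usage is counted locally with the same tokenizer. As the validators below enforce, streaming and structured output strategies other than "rule" are not supported.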

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@define
class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )
    structured_output_strategy: StructuredOutputStrategy = field(
        default="rule", kw_only=True, metadata={"serializable": True}
    )
    _client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        if value != "rule":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @lazy_property()
    def client(self) -> Any:
        return self.session.client("sagemaker-runtime")

    @stream.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_stream(self, _: Attribute, stream: bool) -> None:  # noqa: FBT001
        if stream:
            raise ValueError("streaming is not supported")

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        payload = {
            "inputs": self.prompt_stack_to_string(prompt_stack),
            "parameters": {**self._base_params(prompt_stack)},
        }
        logger.debug(payload)

        response = self.client.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=json.dumps(payload),
            CustomAttributes=self.custom_attributes,
            **(
                {"InferenceComponentName": self.inference_component_name}
                if self.inference_component_name is not None
                else {}
            ),
        )

        decoded_body = json.loads(response["Body"].read().decode("utf8"))
        logger.debug(decoded_body)

        if isinstance(decoded_body, list):
            if decoded_body:
                generated_text = decoded_body[0]["generated_text"]
            else:
                raise ValueError("model response is empty")
        else:
            generated_text = decoded_body["generated_text"]

        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
        output_tokens = len(self.tokenizer.tokenizer.encode(generated_text))

        return Message(
            content=[TextMessageContent(TextArtifact(generated_text))],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        raise NotImplementedError("streaming is not supported")

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_tokens,
            "do_sample": True,
            "eos_token_id": self.tokenizer.tokenizer.eos_token_id,
            "stop_strings": self.tokenizer.stop_sequences,
            "return_full_text": False,
            **self.extra_params,
        }

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        messages = []

        for message in prompt_stack.messages:
            messages.append({"role": message.role, "content": message.to_text()})

        return messages

    def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
        messages = self._prompt_stack_to_messages(prompt_stack)

        tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

        if isinstance(tokens, list):
            return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].
        else:
            raise ValueError("Invalid output type.")

custom_attributes: str = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

max_tokens: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structured_output_strategy: StructuredOutputStrategy = field(default='rule', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute
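
A minimal construction sketch. The endpoint name and model ID below are placeholders; `model` is inherited from `BasePromptDriver` and is used to load the Hugging Face tokenizer, so it should match the model served by the endpoint.

import boto3

from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import (
    AmazonSageMakerJumpstartPromptDriver,
)

driver = AmazonSageMakerJumpstartPromptDriver(
    endpoint="my-jumpstart-endpoint",       # placeholder: your SageMaker endpoint name
    model="HuggingFaceH4/zephyr-7b-beta",   # placeholder: Hugging Face model ID for the tokenizer
    session=boto3.Session(region_name="us-east-1"),
    max_tokens=250,
)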

__prompt_stack_to_tokens(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
    messages = self._prompt_stack_to_messages(prompt_stack)

    tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

    if isinstance(tokens, list):
        return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].
    else:
        raise ValueError("Invalid output type.")

client()

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@lazy_property()
def client(self) -> Any:
    return self.session.client("sagemaker-runtime")
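
The client is created lazily from `session` on first access. Because the private `_client` field carries the alias `client`, a pre-built client can also be injected at construction time; a sketch with placeholder names:

import boto3

session = boto3.Session(region_name="us-west-2")

driver = AmazonSageMakerJumpstartPromptDriver(
    endpoint="my-jumpstart-endpoint",       # placeholder
    model="HuggingFaceH4/zephyr-7b-beta",   # placeholder
    client=session.client("sagemaker-runtime"),
)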

prompt_stack_to_string(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))
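
This renders the stacked messages through the tokenizer's chat template and decodes the token ids back to a string, so the payload sent to the endpoint matches the template the model was trained on. A sketch of the round trip, assuming `PromptStack` lives in `griptape.common` and reusing the `driver` from above:

from griptape.common import PromptStack  # import path assumed

prompt_stack = PromptStack()
prompt_stack.add_user_message("Hello")

prompt = driver.prompt_stack_to_string(prompt_stack)
# e.g. "<|user|>\nHello</s>\n<|assistant|>\n" for a zephyr-style template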

try_run(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    payload = {
        "inputs": self.prompt_stack_to_string(prompt_stack),
        "parameters": {**self._base_params(prompt_stack)},
    }
    logger.debug(payload)

    response = self.client.invoke_endpoint(
        EndpointName=self.endpoint,
        ContentType="application/json",
        Body=json.dumps(payload),
        CustomAttributes=self.custom_attributes,
        **(
            {"InferenceComponentName": self.inference_component_name}
            if self.inference_component_name is not None
            else {}
        ),
    )

    decoded_body = json.loads(response["Body"].read().decode("utf8"))
    logger.debug(decoded_body)

    if isinstance(decoded_body, list):
        if decoded_body:
            generated_text = decoded_body[0]["generated_text"]
        else:
            raise ValueError("model response is empty")
    else:
        generated_text = decoded_body["generated_text"]

    input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
    output_tokens = len(self.tokenizer.tokenizer.encode(generated_text))

    return Message(
        content=[TextMessageContent(TextArtifact(generated_text))],
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
    )
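
A usage sketch, reusing the `driver` and `PromptStack` import from above (endpoint, model, and message content are placeholders):

prompt_stack = PromptStack()
prompt_stack.add_system_message("You are a terse assistant.")
prompt_stack.add_user_message("Write a haiku about the ocean.")

message = driver.try_run(prompt_stack)
print(message.to_text())                                        # generated text
print(message.usage.input_tokens, message.usage.output_tokens)  # locally counted tokens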

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    raise NotImplementedError("streaming is not supported")

validate_stream(_, stream)

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@stream.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_stream(self, _: Attribute, stream: bool) -> None:  # noqa: FBT001
    if stream:
        raise ValueError("streaming is not supported")

validate_structured_output_strategy(_, value)

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    if value != "rule":
        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

    return value
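
Likewise, only the "rule" strategy passes validation; any other value raises at construction time. A sketch, assuming "native" is one of the other `StructuredOutputStrategy` literals:

try:
    AmazonSageMakerJumpstartPromptDriver(
        endpoint="my-jumpstart-endpoint",       # placeholder
        model="HuggingFaceH4/zephyr-7b-beta",   # placeholder
        structured_output_strategy="native",    # assumed non-"rule" strategy value
    )
except ValueError as e:
    print(e)  # "AmazonSageMakerJumpstartPromptDriver does not support `native` structured output strategy."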