Skip to content

Amazon SageMaker prompt driver

AmazonSageMakerPromptDriver

Bases: BaseMultiModelPromptDriver

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@define
class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver):
    """Prompt driver that invokes a SageMaker endpoint through a ``sagemaker-runtime`` client.

    Streaming is not supported: setting ``stream`` to a truthy value fails
    validation, and ``try_stream`` always raises.
    """

    # boto3 session; created on demand through import_optional_dependency("boto3").
    session: boto3.Session = field(
        default=Factory(lambda: import_optional_dependency("boto3").Session()),
        kw_only=True,
    )
    # Client used for invoke_endpoint; derived from ``session`` by default.
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True),
        kw_only=True,
    )
    # Forwarded verbatim as the CustomAttributes argument of invoke_endpoint.
    custom_attributes: str = field(
        default="accept_eula=true",
        kw_only=True,
        metadata={"serializable": True},
    )
    # Must stay falsy; see validate_stream.
    stream: bool = field(
        default=False,
        kw_only=True,
        metadata={"serializable": True},
    )

    @stream.validator  # pyright: ignore
    def validate_stream(self, _, stream):
        """Reject any truthy ``stream`` value.

        Raises:
            ValueError: if ``stream`` is truthy.
        """
        if not stream:
            return
        raise ValueError("streaming is not supported")

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Invoke the endpoint named by ``self.model`` and return the processed output.

        The request body is a JSON object with ``inputs`` and ``parameters``
        keys, both produced by ``self.prompt_model_driver`` from the prompt stack.

        Raises:
            Exception: if the decoded JSON response body is falsy.
        """
        model_driver = self.prompt_model_driver
        request = {
            "inputs": model_driver.prompt_stack_to_model_input(prompt_stack),
            "parameters": model_driver.prompt_stack_to_model_params(prompt_stack),
        }

        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.model,
            ContentType="application/json",
            Body=json.dumps(request),
            CustomAttributes=self.custom_attributes,
        )

        raw_body = response["Body"].read()
        decoded_body = json.loads(raw_body.decode("utf8"))

        if not decoded_body:
            raise Exception("model response is empty")
        return model_driver.process_output(decoded_body)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Always raise; this driver does not implement streaming.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("streaming is not supported")

custom_attributes: str = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda self: self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Invoke the endpoint named by ``self.model`` and return the processed output.

    The request body is a JSON object with ``inputs`` and ``parameters``
    keys, both produced by ``self.prompt_model_driver`` from the prompt stack.

    Raises:
        Exception: if the decoded JSON response body is falsy.
    """
    model_driver = self.prompt_model_driver
    request = {
        "inputs": model_driver.prompt_stack_to_model_input(prompt_stack),
        "parameters": model_driver.prompt_stack_to_model_params(prompt_stack),
    }

    response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.model,
        ContentType="application/json",
        Body=json.dumps(request),
        CustomAttributes=self.custom_attributes,
    )

    raw_body = response["Body"].read()
    decoded_body = json.loads(raw_body.decode("utf8"))

    if not decoded_body:
        raise Exception("model response is empty")
    return model_driver.process_output(decoded_body)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Always raise; this driver does not implement streaming.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError("streaming is not supported")

validate_stream(_, stream)

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@stream.validator  # pyright: ignore
def validate_stream(self, _, stream):
    """Reject any truthy ``stream`` value.

    Raises:
        ValueError: if ``stream`` is truthy.
    """
    if not stream:
        return
    raise ValueError("streaming is not supported")