Skip to content

amazon_sagemaker_jumpstart

__all__ = ['AmazonSageMakerJumpstartEmbeddingDriver'] module-attribute

AmazonSageMakerJumpstartEmbeddingDriver

Bases: BaseEmbeddingDriver

Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that calls an Amazon SageMaker endpoint via the `sagemaker-runtime` client.

    Attributes:
        session: boto3 session used to create the runtime client. Defaults to
            `import_optional_dependency("boto3").Session()`.
        endpoint: Value sent as `EndpointName` on every invocation.
        custom_attributes: Value sent as `CustomAttributes`. Defaults to `"accept_eula=true"`.
        inference_component_name: Sent as `InferenceComponentName` only when it is not None.
    """

    session: boto3.Session = field(
        default=Factory(lambda: import_optional_dependency("boto3").Session()),
        kw_only=True,
    )
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(
        default="accept_eula=true",
        kw_only=True,
        metadata={"serializable": True},
    )
    inference_component_name: Optional[str] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True},
    )
    # Accepted at construction time under the name "client" (see `alias`); not serialized.
    # NOTE(review): presumably `lazy_property` stores the built client here - confirm against its definition.
    _client: SageMakerClient = field(
        default=None,
        kw_only=True,
        alias="client",
        metadata={"serializable": False},
    )

    @lazy_property()
    def client(self) -> SageMakerClient:
        """Create a `sagemaker-runtime` client from the configured session."""
        runtime_client = self.session.client("sagemaker-runtime")
        return runtime_client

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Invoke the endpoint for a single chunk and return its embedding.

        Args:
            chunk: Text sent to the model as `text_inputs`.

        Returns:
            The value under the response's `embedding` key; when that value is
            a list whose first element is itself a list, the first element is
            returned instead.

        Raises:
            ValueError: If the response has no `embedding` key, or the
                embedding is empty.
        """
        # Resolve the client call first, matching the order in which a direct
        # `self.client.invoke_endpoint(...)` expression would evaluate.
        invoke = self.client.invoke_endpoint

        invoke_kwargs = {
            "EndpointName": self.endpoint,
            "ContentType": "application/json",
            "Body": json.dumps({"text_inputs": chunk, "mode": "embedding"}).encode("utf-8"),
            "CustomAttributes": self.custom_attributes,
        }
        # The inference component is only passed when one was configured.
        if self.inference_component_name is not None:
            invoke_kwargs["InferenceComponentName"] = self.inference_component_name

        raw_response = invoke(**invoke_kwargs)
        parsed = json.loads(raw_response.get("Body").read().decode("utf-8"))

        if "embedding" not in parsed:
            raise ValueError("invalid response from model")

        vector = parsed["embedding"]
        if not vector:
            raise ValueError("model response is empty")

        # A nested list is unwrapped to its first row; a flat list is returned as-is.
        return vector[0] if isinstance(vector[0], list) else vector

custom_attributes: str = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

client()

Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@lazy_property()
def client(self) -> SageMakerClient:
    """Create a `sagemaker-runtime` client from the configured session."""
    runtime_client = self.session.client("sagemaker-runtime")
    return runtime_client

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Invoke the endpoint for a single chunk and return its embedding.

    Args:
        chunk: Text sent to the model as `text_inputs`.

    Returns:
        The value under the response's `embedding` key; when that value is a
        list whose first element is itself a list, the first element is
        returned instead.

    Raises:
        ValueError: If the response has no `embedding` key, or the embedding
            is empty.
    """
    # Resolve the client call first, matching the order in which a direct
    # `self.client.invoke_endpoint(...)` expression would evaluate.
    invoke = self.client.invoke_endpoint

    invoke_kwargs = {
        "EndpointName": self.endpoint,
        "ContentType": "application/json",
        "Body": json.dumps({"text_inputs": chunk, "mode": "embedding"}).encode("utf-8"),
        "CustomAttributes": self.custom_attributes,
    }
    # The inference component is only passed when one was configured.
    if self.inference_component_name is not None:
        invoke_kwargs["InferenceComponentName"] = self.inference_component_name

    raw_response = invoke(**invoke_kwargs)
    parsed = json.loads(raw_response.get("Body").read().decode("utf-8"))

    if "embedding" not in parsed:
        raise ValueError("invalid response from model")

    vector = parsed["embedding"]
    if not vector:
        raise ValueError("model response is empty")

    # A nested list is unwrapped to its first row; a flat list is returned as-is.
    return vector[0] if isinstance(vector[0], list) else vector