Bases: BaseEmbeddingDriver
Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that sends each chunk to a SageMaker runtime endpoint.

    The endpoint is invoked through a ``sagemaker-runtime`` client created
    from ``session``; the JSON reply is expected to carry an ``"embedding"``
    key.
    """

    # boto3 session used to build the runtime client; created on demand so that
    # boto3 is only imported when a default session is actually needed.
    session: boto3.Session = field(
        default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True
    )
    # Sent as ``EndpointName`` on every invocation.
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    # Sent verbatim as ``CustomAttributes`` on every invocation.
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    # When not None, sent as ``InferenceComponentName``; otherwise the argument is omitted.
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # Backing slot exposed to callers under the ``client`` alias.
    # NOTE(review): presumably populated lazily by ``lazy_property`` -- confirm against that decorator.
    _client: SageMakerClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> SageMakerClient:
        """Return a ``sagemaker-runtime`` client built from ``self.session``."""
        return self.session.client("sagemaker-runtime")

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single text chunk by invoking the configured endpoint.

        Args:
            chunk: Text sent to the model as ``text_inputs``.

        Returns:
            The value under ``"embedding"`` in the decoded response. If its
            first element is itself a list, that inner list is returned.

        Raises:
            ValueError: If the response has no ``Body``, has no ``"embedding"``
                key, or the embedding is empty.
        """
        request_kwargs = {
            "EndpointName": self.endpoint,
            "ContentType": "application/json",
            "Body": json.dumps({"text_inputs": chunk, "mode": "embedding"}).encode("utf-8"),
            "CustomAttributes": self.custom_attributes,
        }
        # Only send the inference component when one is configured.
        if self.inference_component_name is not None:
            request_kwargs["InferenceComponentName"] = self.inference_component_name

        endpoint_response = self.client.invoke_endpoint(**request_kwargs)

        # Fail with the driver's own error instead of an AttributeError on None.
        body = endpoint_response.get("Body")
        if body is None:
            raise ValueError("invalid response from model")
        response = json.loads(body.read().decode("utf-8"))

        if "embedding" not in response:
            raise ValueError("invalid response from model")

        embedding = response["embedding"]
        if not embedding:
            raise ValueError("model response is empty")

        # A nested list is unwrapped to its first entry.
        if isinstance(embedding[0], list):
            return embedding[0]
        return embedding
|
custom_attributes: str = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True})
class-attribute
instance-attribute
endpoint: str = field(kw_only=True, metadata={'serializable': True})
class-attribute
instance-attribute
inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute
instance-attribute
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
class-attribute
instance-attribute
client()
Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@lazy_property()
def client(self) -> SageMakerClient:
    """Return a ``sagemaker-runtime`` client built from ``self.session``."""
    return self.session.client("sagemaker-runtime")
|
try_embed_chunk(chunk)
Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single text chunk by invoking the configured endpoint.

    Args:
        chunk: Text sent to the model as ``text_inputs``.

    Returns:
        The value under ``"embedding"`` in the decoded response. If its
        first element is itself a list, that inner list is returned.

    Raises:
        ValueError: If the response has no ``Body``, has no ``"embedding"``
            key, or the embedding is empty.
    """
    request_kwargs = {
        "EndpointName": self.endpoint,
        "ContentType": "application/json",
        "Body": json.dumps({"text_inputs": chunk, "mode": "embedding"}).encode("utf-8"),
        "CustomAttributes": self.custom_attributes,
    }
    # Only send the inference component when one is configured.
    if self.inference_component_name is not None:
        request_kwargs["InferenceComponentName"] = self.inference_component_name

    endpoint_response = self.client.invoke_endpoint(**request_kwargs)

    # Fail with the driver's own error instead of an AttributeError on None.
    body = endpoint_response.get("Body")
    if body is None:
        raise ValueError("invalid response from model")
    response = json.loads(body.read().decode("utf-8"))

    if "embedding" not in response:
        raise ValueError("invalid response from model")

    embedding = response["embedding"]
    if not embedding:
        raise ValueError("model response is empty")

    # A nested list is unwrapped to its first entry.
    if isinstance(embedding[0], list):
        return embedding[0]
    return embedding
|