Skip to content

amazon_bedrock_titan_embedding_driver

AmazonBedrockTitanEmbeddingDriver

Bases: BaseEmbeddingDriver

Amazon Bedrock Titan Embedding Driver.

Attributes:

Name Type Description
model str

Embedding model name. Defaults to DEFAULT_MODEL.

tokenizer BaseTokenizer

Optionally provide a custom tokenizer. Defaults to an AmazonBedrockTokenizer created for the configured model.

session Session

Optionally provide a custom boto3.Session.

client BedrockRuntimeClient

Optionally provide a custom bedrock-runtime client.

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@define
class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
    """Amazon Bedrock Titan Embedding Driver.

    Sends text (or image bytes) to an Amazon Bedrock Titan embedding model through the
    `bedrock-runtime` `invoke_model` API and returns the `"embedding"` field of the
    decoded JSON reply.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        tokenizer: Optionally provide a custom tokenizer. Defaults to an
            `AmazonBedrockTokenizer` created for `model`.
        session: Optionally provide custom `boto3.Session`. Defaults to a fresh
            `boto3.Session()`.
        client: Optionally provide custom `bedrock-runtime` client. The `client`
            property otherwise builds one via `session.client("bedrock-runtime")`.
    """

    DEFAULT_MODEL = "amazon.titan-embed-text-v1"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    # Stored privately and exposed to callers through the `client` constructor alias;
    # excluded from serialization.
    _client: Optional[BedrockRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> BedrockRuntimeClient:
        """Return the `bedrock-runtime` client built from `self.session`.

        NOTE(review): presumably `lazy_property` returns a caller-supplied `client`
        (stored in `_client`) when one was given, and caches the built client
        otherwise. Confirm against its implementation.
        """
        return self.session.client("bedrock-runtime")

    def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
        """Embed a text or image artifact.

        Text artifacts are delegated to `try_embed_chunk`. Any other artifact has its
        `value` base64-encoded and sent as the `inputImage` field.

        NOTE(review): image payloads presumably require a multimodal Titan model
        (e.g. `amazon.titan-embed-image-v1`) rather than the text-only DEFAULT_MODEL.
        Confirm before relying on the default.

        Args:
            artifact: The artifact to embed. For images, `value` must be bytes-like
                (required by `base64.b64encode`).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The `"embedding"` list from the model response.
        """
        if isinstance(artifact, TextArtifact):
            return self.try_embed_chunk(artifact.value)
        return self._invoke_model({"inputImage": base64.b64encode(artifact.value).decode()})["embedding"]

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed a single text chunk by sending it as the `inputText` field.

        Args:
            chunk: The text to embed.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The `"embedding"` list from the model response.
        """
        return self._invoke_model(
            {
                "inputText": chunk,
            }
        )["embedding"]

    def _invoke_model(self, payload: dict) -> dict[str, Any]:
        """Send `payload` as JSON to `self.model` and return the decoded JSON reply.

        Args:
            payload: Request body; serialized with `json.dumps`.

        Returns:
            The parsed JSON object read from the response's `body` stream.
        """
        response = self.client.invoke_model(
            body=json.dumps(payload),
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )
        return json.loads(response.get("body").read())

DEFAULT_MODEL = 'amazon.titan-embed-text-v1' class-attribute instance-attribute

_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

_invoke_model(payload)

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def _invoke_model(self, payload: dict) -> dict[str, Any]:
    """POST `payload` as JSON to the configured Bedrock model and decode the reply.

    Args:
        payload: Request body; serialized with `json.dumps`.

    Returns:
        The JSON object parsed from the bytes read out of the response's `body`.
    """
    bedrock = self.client
    reply = bedrock.invoke_model(
        body=json.dumps(payload),
        modelId=self.model,
        contentType="application/json",
        accept="application/json",
    )
    raw_body = reply.get("body").read()
    return json.loads(raw_body)

client()

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@lazy_property()
def client(self) -> BedrockRuntimeClient:
    """Create the Bedrock runtime client from this driver's boto3 session.

    NOTE(review): caching/override behavior comes from `lazy_property`, which is not
    shown here. Presumably the client is built on first access only.
    """
    service_name = "bedrock-runtime"
    return self.session.client(service_name)

try_embed_artifact(artifact, **kwargs)

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
    """Return the embedding for a text or image artifact.

    Non-text artifacts have their bytes base64-encoded and sent as `inputImage`.
    Text artifacts are forwarded to `try_embed_chunk`. Extra keyword arguments are
    ignored.
    """
    if not isinstance(artifact, TextArtifact):
        encoded_image = base64.b64encode(artifact.value).decode()
        image_reply = self._invoke_model({"inputImage": encoded_image})
        return image_reply["embedding"]
    return self.try_embed_chunk(artifact.value)

try_embed_chunk(chunk, **kwargs)

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Return the embedding for one text chunk, sent to the model as `inputText`.

    Extra keyword arguments are accepted and ignored.
    """
    request = {"inputText": chunk}
    decoded = self._invoke_model(request)
    return decoded["embedding"]