Bases: BaseEmbeddingDriver
Attributes:
Name |
Type |
Description |
model |
str
|
Embedding model name. Defaults to DEFAULT_MODEL.
|
input_type |
str
|
Defaults to search_query. Prepends special tokens to differentiate each input type from one another:
Use search_document when you encode documents for embeddings that you store in a vector database.
Use search_query when querying your vector DB to find relevant documents.
|
session |
Session
|
Optionally provide a custom boto3.Session.
|
tokenizer |
BedrockCohereTokenizer
|
Optionally provide a custom BedrockCohereTokenizer.
|
bedrock_client |
Any
|
Optionally provide a custom bedrock-runtime client.
|
Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@define
class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that sends text to a Cohere embedding model via a `bedrock-runtime` client.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        input_type: Value sent as `input_type` with every request. Defaults to `search_query`.
            Prepends special tokens to differentiate each type from one another:
            use `search_document` when you encode documents for embeddings that you store in a vector database,
            and `search_query` when querying your vector DB to find relevant documents.
        session: Optionally provide a custom `boto3.Session`.
        tokenizer: Optionally provide a custom `BedrockCohereTokenizer`.
        bedrock_client: Optionally provide a custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "cohere.embed-english-v3"

    model: str = field(default=DEFAULT_MODEL, kw_only=True)
    input_type: str = field(default="search_query", kw_only=True)

    # boto3 is resolved inside the factory, i.e. only when no session is passed in.
    session: boto3.Session = field(
        default=Factory(lambda: import_optional_dependency("boto3").Session()),
        kw_only=True,
    )

    # The two defaults below read other fields of the instance (`model`, `session`),
    # so they must stay declared after those fields.
    tokenizer: BedrockCohereTokenizer = field(
        default=Factory(lambda self: BedrockCohereTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
        kw_only=True,
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single chunk of text.

        Args:
            chunk: Text to embed; it is sent as the only entry of `texts`.

        Returns:
            The first item of the `embeddings` list in the decoded response body.
        """
        request = {"input_type": self.input_type, "texts": [chunk]}
        raw_response = self.bedrock_client.invoke_model(
            body=json.dumps(request),
            modelId=self.model,
            accept="*/*",
            contentType="application/json",
        )

        # The response body is a readable stream holding a JSON document.
        body_stream = raw_response.get("body")
        decoded = json.loads(body_stream.read())
        embeddings = decoded.get("embeddings")

        # Only one text was submitted, so only the first vector is returned.
        return embeddings[0]
|
DEFAULT_MODEL = 'cohere.embed-english-v3'
class-attribute
instance-attribute
bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True)
class-attribute
instance-attribute
model: str = field(default=DEFAULT_MODEL, kw_only=True)
class-attribute
instance-attribute
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
class-attribute
instance-attribute
tokenizer: BedrockCohereTokenizer = field(default=Factory(lambda self: BedrockCohereTokenizer(model=self.model), takes_self=True), kw_only=True)
class-attribute
instance-attribute
try_embed_chunk(chunk)
Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single chunk of text.

    Args:
        chunk: Text to embed; it is sent as the only entry of `texts`.

    Returns:
        The first item of the `embeddings` list in the decoded response body.
    """
    request = {"input_type": self.input_type, "texts": [chunk]}
    raw_response = self.bedrock_client.invoke_model(
        body=json.dumps(request),
        modelId=self.model,
        accept="*/*",
        contentType="application/json",
    )

    # The response body is a readable stream holding a JSON document.
    body_stream = raw_response.get("body")
    decoded = json.loads(body_stream.read())
    embeddings = decoded.get("embeddings")

    # Only one text was submitted, so only the first vector is returned.
    return embeddings[0]
|