Embedding Drivers
Overview
Embeddings in Griptape are multidimensional vector representations of text data. Because embeddings carry semantic information, texts with similar meanings map to nearby vectors. This makes them useful for extracting the most relevant chunks from large bodies of text for search and querying.
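To make "relevant" concrete, relevance between a query and a text chunk is commonly measured with cosine similarity between their embedding vectors. The minimal sketch below ranks chunks against a query in plain Python. The three-dimensional vectors are made-up toy values, not output from any Embedding Driver, and cosine_similarity and top_k are illustrative helpers, not Griptape functions.

```python
import math

# Made-up 3-dimensional vectors for illustration only. Real embedding
# models return hundreds or thousands of dimensions.
chunks = {
    "Griptape is a Python framework for AI apps.": [0.9, 0.1, 0.0],
    "Bananas are rich in potassium.": [0.0, 0.2, 0.95],
    "Agents in Griptape can use tools.": [0.8, 0.3, 0.1],
}
query_vector = [0.85, 0.2, 0.05]  # hypothetical embedding of "What is Griptape?"


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Dot product of a and b divided by the product of their lengths."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query: list[float], candidates: dict[str, list[float]], k: int) -> list[str]:
    """Return the k candidate texts whose vectors are most similar to the query."""
    ranked = sorted(
        candidates,
        key=lambda text: cosine_similarity(query, candidates[text]),
        reverse=True,
    )
    return ranked[:k]


print(top_k(query_vector, chunks, k=2))
```

In practice, the vectors would come from an Embedding Driver's embed_string() method rather than being written by hand.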
Griptape provides a way to build Embedding Drivers that are reused in downstream framework components. Every Embedding Driver has two basic methods that can be used to generate embeddings:
- embed_text_artifact() for TextArtifacts.
- embed_string() for any string.
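The relationship between the two methods can be sketched as follows, assuming that embedding a TextArtifact amounts to embedding its text value. ToyTextArtifact, ToyEmbeddingDriver, and LetterCountEmbeddingDriver are hypothetical names, not Griptape classes, and the letter-count "model" carries no semantic information; it only keeps the example runnable offline.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class ToyTextArtifact:
    """Hypothetical stand-in for a TextArtifact: it just wraps a string."""

    value: str


class ToyEmbeddingDriver(ABC):
    """Hypothetical base class illustrating the two-method interface."""

    @abstractmethod
    def embed_string(self, string: str) -> list[float]:
        """Turn any string into a vector of floats."""

    def embed_text_artifact(self, artifact: ToyTextArtifact) -> list[float]:
        # Assumption: an artifact is embedded by embedding its text value.
        return self.embed_string(artifact.value)


class LetterCountEmbeddingDriver(ToyEmbeddingDriver):
    """Toy 'model' that counts vowels, consonants, and spaces. Not a real embedding."""

    def embed_string(self, string: str) -> list[float]:
        lower = string.lower()
        vowels = sum(ch in "aeiou" for ch in lower)
        consonants = sum(ch.isalpha() and ch not in "aeiou" for ch in lower)
        spaces = lower.count(" ")
        return [float(vowels), float(consonants), float(spaces)]


driver = LetterCountEmbeddingDriver()
print(driver.embed_string("Hello Griptape!"))
print(driver.embed_text_artifact(ToyTextArtifact("Hello Griptape!")))
```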
You can optionally provide a Tokenizer via the tokenizer field to have the Driver automatically chunk input text that exceeds the model's token limit.
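A common strategy for text that exceeds the limit is to split it into chunks, embed each chunk, and combine the chunk vectors into one, for example with a length-weighted average rescaled to unit length. The sketch below shows that strategy in plain Python. It is an illustrative assumption, not a description of Griptape's internals: toy_embed is a hypothetical stand-in for a real model call, and whitespace-separated words stand in for tokens.

```python
import math


def toy_embed(text: str) -> list[float]:
    """Hypothetical stand-in for a model call: [word count, character count]."""
    return [float(len(text.split())), float(len(text))]


def chunk_by_tokens(text: str, max_tokens: int) -> list[str]:
    """Split text into chunks of at most max_tokens whitespace-separated words.

    Treating words as tokens is a simplifying assumption; real tokenizers
    count tokens differently.
    """
    words = text.split()
    return [" ".join(words[i : i + max_tokens]) for i in range(0, len(words), max_tokens)]


def embed_long_string(text: str, max_tokens: int) -> list[float]:
    """Embed each chunk, average the vectors weighted by chunk length, then normalize."""
    chunks = chunk_by_tokens(text, max_tokens)
    if not chunks:
        raise ValueError("cannot embed an empty string")
    vectors = [toy_embed(chunk) for chunk in chunks]
    weights = [len(chunk) for chunk in chunks]  # longer chunks contribute more
    total = sum(weights)
    averaged = [
        sum(w * vec[d] for w, vec in zip(weights, vectors)) / total
        for d in range(len(vectors[0]))
    ]
    norm = math.sqrt(sum(x * x for x in averaged))
    return [x / norm for x in averaged]  # rescale to unit length


print(chunk_by_tokens("one two three four five", max_tokens=2))
print(embed_long_string("one two three four five", max_tokens=2))
```

Weighting by chunk length keeps a short trailing chunk from pulling the combined vector as strongly as a full-size chunk.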
Embedding Drivers
OpenAI Embeddings
The OpenAiEmbeddingDriver uses the OpenAI Embeddings API.
from griptape.drivers import OpenAiEmbeddingDriver
embeddings = OpenAiEmbeddingDriver().embed_string("Hello Griptape!")
# display the first 3 values of the embedding vector
print(embeddings[:3])
Azure OpenAI Embeddings
The AzureOpenAiEmbeddingDriver uses the same parameters as OpenAiEmbeddingDriver with updated defaults.
Bedrock Titan Embeddings
Info
This driver requires the drivers-embedding-amazon-bedrock extra.
The AmazonBedrockTitanEmbeddingDriver uses the Amazon Bedrock Embeddings API.
from griptape.drivers import AmazonBedrockTitanEmbeddingDriver
embeddings = AmazonBedrockTitanEmbeddingDriver().embed_string("Hello world!")
# display the first 3 values of the embedding vector
print(embeddings[:3])
Google Embeddings
Info
This driver requires the drivers-embedding-google extra.
The GoogleEmbeddingDriver uses the Google Embeddings API.
from griptape.drivers import GoogleEmbeddingDriver
embeddings = GoogleEmbeddingDriver().embed_string("Hello world!")
# display the first 3 values of the embedding vector
print(embeddings[:3])
Hugging Face Hub Embeddings
Info
This driver requires the drivers-embedding-huggingface extra.
The HuggingFaceHubEmbeddingDriver connects to the Hugging Face Hub API. It supports models with the following tasks:
- feature-extraction
import os
from griptape.drivers import HuggingFaceHubEmbeddingDriver
from griptape.tokenizers import HuggingFaceTokenizer
from transformers import AutoTokenizer
driver = HuggingFaceHubEmbeddingDriver(
    api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"],
    model="sentence-transformers/all-MiniLM-L6-v2",
    # the tokenizer lets the Driver chunk input that exceeds the token limit
    tokenizer=HuggingFaceTokenizer(
        max_output_tokens=512,
        tokenizer=AutoTokenizer.from_pretrained(
            "sentence-transformers/all-MiniLM-L6-v2"
        ),
    ),
)
embeddings = driver.embed_string("Hello world!")
# display the first 3 values of the embedding vector
print(embeddings[:3])
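The feature-extraction task mentioned above produces embedding features from text. For many transformer models these features are one vector per input token, and sentence-embedding models such as sentence-transformers/all-MiniLM-L6-v2 combine them into a single sentence vector by mean pooling. Whether a given Hub model returns per-token or already pooled output depends on the model, so the sketch below only illustrates the pooling step on made-up numbers; mean_pool is a hypothetical helper and is not part of HuggingFaceHubEmbeddingDriver.

```python
def mean_pool(token_vectors: list[list[float]], attention_mask: list[int]) -> list[float]:
    """Average per-token vectors, ignoring padding positions (mask value 0)."""
    kept = [vec for vec, keep in zip(token_vectors, attention_mask) if keep]
    if not kept:
        raise ValueError("attention_mask selects no tokens")
    dims = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dims)]


# Made-up per-token features: three real tokens plus one padding token.
token_vectors = [
    [1.0, 0.0],
    [0.0, 1.0],
    [2.0, 2.0],
    [9.0, 9.0],  # padding position, excluded by the mask below
]
attention_mask = [1, 1, 1, 0]

sentence_vector = mean_pool(token_vectors, attention_mask)
print(sentence_vector)
```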
Multi Model Embedding Drivers
Certain embedding providers, such as Amazon SageMaker, support many types of models, each with its own slight differences in parameters and response formats. To support this variation across models, these Embedding Drivers take an Embedding Model Driver through the embedding_model_driver parameter. Embedding Model Drivers allow for model-specific customization of Embedding Drivers.
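The pattern can be illustrated with a small, self-contained sketch. Every class and method name below (EmbeddingModelDriver, to_request, from_response, ModelADriver, ModelBDriver, EndpointEmbeddingDriver) and both payload formats are hypothetical; they are not Griptape's or SageMaker's actual interfaces. A fake function stands in for the endpoint so the example runs offline.

```python
import json
from typing import Callable, Protocol


class EmbeddingModelDriver(Protocol):
    """Hypothetical interface: build the request body and parse the response."""

    def to_request(self, text: str) -> dict: ...

    def from_response(self, response: dict) -> list[float]: ...


class ModelADriver:
    """Hypothetical model expecting {"inputs": ...} and returning {"vectors": [[...]]}."""

    def to_request(self, text: str) -> dict:
        return {"inputs": text}

    def from_response(self, response: dict) -> list[float]:
        return response["vectors"][0]


class ModelBDriver:
    """Hypothetical model expecting {"text": ...} and returning {"embedding": [...]}."""

    def to_request(self, text: str) -> dict:
        return {"text": text}

    def from_response(self, response: dict) -> list[float]:
        return response["embedding"]


class EndpointEmbeddingDriver:
    """Provider-level driver: sends requests, delegates model-specific details."""

    def __init__(self, invoke: Callable[[str], str], model_driver: EmbeddingModelDriver):
        self.invoke = invoke  # stands in for the real endpoint call (JSON text in, JSON text out)
        self.model_driver = model_driver

    def embed_string(self, text: str) -> list[float]:
        body = json.dumps(self.model_driver.to_request(text))
        return self.model_driver.from_response(json.loads(self.invoke(body)))


def fake_endpoint(body: str) -> str:
    """Fake endpoint that answers in whichever format matches the request."""
    request = json.loads(body)
    if "inputs" in request:
        return json.dumps({"vectors": [[0.1, 0.2]]})
    return json.dumps({"embedding": [0.3, 0.4]})


print(EndpointEmbeddingDriver(fake_endpoint, ModelADriver()).embed_string("Hello"))
print(EndpointEmbeddingDriver(fake_endpoint, ModelBDriver()).embed_string("Hello"))
```

In this sketch, supporting a new model format only requires a new model driver class; the provider-level driver stays unchanged.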
SageMaker Embeddings
The AmazonSageMakerEmbeddingDriver uses Amazon SageMaker Endpoints to generate embeddings on AWS.
Info
This driver requires the drivers-embedding-amazon-sagemaker extra.
TensorFlow Hub Models
import os
from griptape.drivers import AmazonSageMakerEmbeddingDriver, SageMakerTensorFlowHubEmbeddingModelDriver
driver = AmazonSageMakerEmbeddingDriver(
model=os.environ["SAGEMAKER_TENSORFLOW_HUB_MODEL"],
embedding_model_driver=SageMakerTensorFlowHubEmbeddingModelDriver(),
)
embeddings = driver.embed_string("Hello world!")
# display the first 3 values of the embedding vector
print(embeddings[:3])
HuggingFace Models
import os
from griptape.drivers import AmazonSageMakerEmbeddingDriver, SageMakerHuggingFaceEmbeddingModelDriver
driver = AmazonSageMakerEmbeddingDriver(
model=os.environ["SAGEMAKER_HUGGINGFACE_MODEL"],
embedding_model_driver=SageMakerHuggingFaceEmbeddingModelDriver(),
)
embeddings = driver.embed_string("Hello world!")
# display the first 3 values of the embedding vector
print(embeddings[:3])
VoyageAI Embeddings
The VoyageAiEmbeddingDriver uses the VoyageAI Embeddings API.
Info
This driver requires the drivers-embedding-voyageai extra.
import os
from griptape.drivers import VoyageAiEmbeddingDriver
driver = VoyageAiEmbeddingDriver(
    api_key=os.environ["VOYAGE_API_KEY"]
)
embeddings = driver.embed_string("Hello world!")
# display the first 3 values of the embedding vector
print(embeddings[:3])
Override Default Structure Embedding Driver
Here is how you can override the Embedding Driver that is used by default in Structures.
import os
from griptape.structures import Agent
from griptape.tools import WebScraper, TaskMemoryClient
from griptape.drivers import (
    OpenAiChatPromptDriver,
    VoyageAiEmbeddingDriver,
)
from griptape.config import (
    StructureGlobalDriversConfig,
    StructureConfig,
)
agent = Agent(
    tools=[WebScraper(), TaskMemoryClient(off_prompt=False)],
    config=StructureConfig(
        global_drivers=StructureGlobalDriversConfig(
            prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
            # replace the default Embedding Driver with VoyageAI
            embedding_driver=VoyageAiEmbeddingDriver(api_key=os.environ["VOYAGE_API_KEY"]),
        )
    ),
)
agent.run("based on https://www.griptape.ai/, tell me what Griptape is")