Skip to content

astradb_vector_store_driver

AstraDbVectorStoreDriver

Bases: BaseVectorStoreDriver

A Vector Store Driver for Astra DB.

Attributes:

Name Type Description
embedding_driver

a griptape.drivers.BaseEmbeddingDriver for embedding computations within the store

api_endpoint str

the "API Endpoint" for the Astra DB instance.

token Optional[str | TokenProvider]

a Database Token ("AstraCS:...") secret to access Astra DB. An instance of astrapy.authentication.TokenProvider is also accepted.

collection_name str

the name of the collection on Astra DB. The collection must have been created beforehand, and support vectors with a vector dimension matching the embeddings being used by this driver.

environment Optional[str]

the environment ("prod", "hcd", ...) hosting the target Data API. It can be omitted for production Astra DB targets. See astrapy.constants.Environment for allowed values.

astra_db_namespace Optional[str]

optional specification of the namespace (in the Astra database) for the data. Note: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store.

caller_name str

the name of the caller for the Astra DB client. Defaults to "griptape".

client DataAPIClient

an instance of astrapy.DataAPIClient for the Astra DB.

collection Collection

an instance of astrapy.Collection for the Astra DB.

Source code in griptape/drivers/vector/astradb_vector_store_driver.py
@define
class AstraDbVectorStoreDriver(BaseVectorStoreDriver):
    """Vector Store Driver backed by an Astra DB collection.

    Attributes:
        embedding_driver: the `griptape.drivers.BaseEmbeddingDriver` used to embed query strings for this store.
        api_endpoint: the "API Endpoint" of the target Astra DB instance.
        token: a Database Token ("AstraCS:...") secret granting access to Astra DB.
            An `astrapy.authentication.TokenProvider` instance is accepted as well.
        collection_name: name of the Astra DB collection to work on. The collection must already exist and must
            support vectors whose dimension matches the embeddings used by this driver.
        environment: the environment ("prod", "hcd", ...) hosting the target Data API. May be left out for
            production Astra DB targets. See `astrapy.constants.Environment` for allowed values.
        astra_db_namespace: optional namespace (in the Astra database) holding the data.
            *Note*: this is unrelated to the "namespace" used elsewhere in this class, which is a grouping
            within this vector store.
        caller_name: caller name handed to the Astra DB client. Defaults to "griptape".
        client: the `astrapy.DataAPIClient` used to reach Astra DB.
        collection: the `astrapy.Collection` that documents are read from and written to.
    """

    api_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    token: Optional[str | astrapy.authentication.TokenProvider] = field(
        kw_only=True,
        default=None,
        metadata={"serializable": False},
    )
    collection_name: str = field(kw_only=True, metadata={"serializable": True})
    environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
    astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    caller_name: str = field(default="griptape", kw_only=True, metadata={"serializable": False})
    # Private storage slots; the "client" / "collection" aliases are the names accepted at construction time.
    _client: astrapy.DataAPIClient = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
    _collection: astrapy.Collection = field(
        default=None,
        kw_only=True,
        alias="collection",
        metadata={"serializable": False},
    )

    @lazy_property()
    def client(self) -> astrapy.DataAPIClient:
        """Create the Data API client from `caller_name` and `environment`.

        `astrapy` is resolved through `import_optional_dependency` only when this
        body runs, not at module import time.
        """
        astrapy_module = import_optional_dependency("astrapy")
        return astrapy_module.DataAPIClient(caller_name=self.caller_name, environment=self.environment)

    @lazy_property()
    def collection(self) -> astrapy.Collection:
        """Resolve the configured collection through the client's database handle."""
        database = self.client.get_database(
            self.api_endpoint,
            token=self.token,
            namespace=self.astra_db_namespace,
        )
        return database.get_collection(self.collection_name)

    def delete_vector(self, vector_id: str) -> None:
        """Remove the vector stored under the given ID.

        No error is raised when the ID was never stored: the call succeeds either way.

        Args:
            vector_id: ID of the vector to delete.
        """
        id_filter = {"_id": vector_id}
        self.collection.delete_one(id_filter)

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs: Any,
    ) -> str:
        """Store a vector, replacing any document already saved under the same ID.

        Args:
            vector: the vector to be upserted.
            vector_id: the ID to store the vector under. When omitted, the document is inserted
                without an `_id` and the ID reported by the insert result is returned.
            namespace: a namespace (a grouping within the vector store) to assign the vector to.
            meta: a metadata dictionary to store alongside the vector.
            kwargs: additional keyword arguments. None of them is read by this method.

        Returns:
            the ID of the written vector.
        """
        candidate_fields = {"$vector": vector, "_id": vector_id, "namespace": namespace, "meta": meta}
        # Only fields that were actually provided end up in the stored document.
        document = {name: value for name, value in candidate_fields.items() if value is not None}
        if vector_id is None:
            return self.collection.insert_one(document).inserted_id
        # upsert=True: replace the document if the ID exists, create it otherwise.
        self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
        return vector_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Fetch one entry by ID, optionally restricted to a namespace.

        Args:
            vector_id: the ID of the required vector.
            namespace: a namespace, within the vector store, to constrain the search.

        Returns:
            The matching `BaseVectorStoreDriver.Entry`, or None when nothing matches.
        """
        find_filter: dict[str, str] = {}
        if vector_id is not None:
            find_filter["_id"] = vector_id
        if namespace is not None:
            find_filter["namespace"] = namespace
        document = self.collection.find_one(filter=find_filter, projection={"*": 1})
        if document is None:
            return None
        return BaseVectorStoreDriver.Entry(
            id=document["_id"],
            vector=document.get("$vector"),
            meta=document.get("meta"),
            namespace=document.get("namespace"),
        )

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Fetch every entry in the store, optionally restricted to a namespace.

        Args:
            namespace: a namespace, within the vector store, to constrain the search.

        Returns:
            A list of `BaseVectorStoreDriver.Entry` items, one per document found.
        """
        find_filter: dict[str, str] = {"namespace": namespace} if namespace is not None else {}
        documents = self.collection.find(filter=find_filter, projection={"*": 1})
        return [
            BaseVectorStoreDriver.Entry(
                id=document["_id"],
                vector=document.get("$vector"),
                meta=document.get("meta"),
                namespace=document.get("namespace"),
            )
            for document in documents
        ]

    def query(
        self,
        query: str,
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs: Any,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Embed a query string and return the most similar entries from the store.

        Args:
            query: the query string.
            count: the maximum number of results to return. When omitted (or zero),
                `BaseVectorStoreDriver.DEFAULT_QUERY_COUNT` is used.
            namespace: the namespace to filter results by.
            include_vectors: whether to include vector data in the results.
            kwargs: additional keyword arguments. Only the free-form dict `filter` is read
                (it is forwarded to the Data API query); any other key is not used.

        Returns:
            A list of `BaseVectorStoreDriver.Entry` items whose `score` attribute
            holds the similarity to the query.
        """
        # Copy the caller's filter; an explicit namespace overrides any "namespace" key it carries.
        find_filter: dict[str, Any] = {**(kwargs.get("filter") or {})}
        if namespace is not None:
            find_filter["namespace"] = namespace
        # Wildcard projection only when vectors are requested; otherwise the projection is left unset.
        find_projection: Optional[dict[str, int]] = None
        if include_vectors:
            find_projection = {"*": 1}
        query_vector = self.embedding_driver.embed_string(query)
        result_limit = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        documents = self.collection.find(
            filter=find_filter,
            sort={"$vector": query_vector},
            limit=result_limit,
            projection=find_projection,
            include_similarity=True,
        )
        return [
            BaseVectorStoreDriver.Entry(
                id=document["_id"],
                vector=document.get("$vector"),
                score=document["$similarity"],
                meta=document.get("meta"),
                namespace=document.get("namespace"),
            )
            for document in documents
        ]

api_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

caller_name: str = field(default='griptape', kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

collection_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

environment: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': True}) class-attribute instance-attribute

token: Optional[str | astrapy.authentication.TokenProvider] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

client()

Source code in griptape/drivers/vector/astradb_vector_store_driver.py
@lazy_property()
def client(self) -> astrapy.DataAPIClient:
    """Create the Data API client from `caller_name` and `environment`.

    `astrapy` is resolved through `import_optional_dependency` only when this
    body runs, not at module import time.
    """
    astrapy_module = import_optional_dependency("astrapy")
    return astrapy_module.DataAPIClient(caller_name=self.caller_name, environment=self.environment)

collection()

Source code in griptape/drivers/vector/astradb_vector_store_driver.py
@lazy_property()
def collection(self) -> astrapy.Collection:
    """Resolve the configured collection through the client's database handle."""
    database = self.client.get_database(
        self.api_endpoint,
        token=self.token,
        namespace=self.astra_db_namespace,
    )
    return database.get_collection(self.collection_name)

delete_vector(vector_id)

Delete a vector from Astra DB store.

The method succeeds regardless of whether a vector with the provided ID was actually stored or not in the first place.

Parameters:

Name Type Description Default
vector_id str

ID of the vector to delete.

required
Source code in griptape/drivers/vector/astradb_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None:
    """Remove the vector stored under the given ID.

    No error is raised when the ID was never stored: the call succeeds either way.

    Args:
        vector_id: ID of the vector to delete.
    """
    id_filter = {"_id": vector_id}
    self.collection.delete_one(id_filter)

load_entries(*, namespace=None)

Load entries from the Astra DB store.

Parameters:

Name Type Description Default
namespace Optional[str]

a namespace, within the vector store, to constrain the search.

None

Returns:

Type Description
list[Entry]

A list of vector (BaseVectorStoreDriver.Entry) entries.

Source code in griptape/drivers/vector/astradb_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Fetch every entry in the store, optionally restricted to a namespace.

    Args:
        namespace: a namespace, within the vector store, to constrain the search.

    Returns:
        A list of `BaseVectorStoreDriver.Entry` items, one per document found.
    """
    find_filter: dict[str, str] = {"namespace": namespace} if namespace is not None else {}
    documents = self.collection.find(filter=find_filter, projection={"*": 1})
    return [
        BaseVectorStoreDriver.Entry(
            id=document["_id"],
            vector=document.get("$vector"),
            meta=document.get("meta"),
            namespace=document.get("namespace"),
        )
        for document in documents
    ]

load_entry(vector_id, *, namespace=None)

Load a single vector entry from the Astra DB store given its ID.

Parameters:

Name Type Description Default
vector_id str

the ID of the required vector.

required
namespace Optional[str]

a namespace, within the vector store, to constrain the search.

None

Returns:

Type Description
Optional[Entry]

The vector entry (a BaseVectorStoreDriver.Entry) if found, otherwise None.

Source code in griptape/drivers/vector/astradb_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Fetch one entry by ID, optionally restricted to a namespace.

    Args:
        vector_id: the ID of the required vector.
        namespace: a namespace, within the vector store, to constrain the search.

    Returns:
        The matching `BaseVectorStoreDriver.Entry`, or None when nothing matches.
    """
    find_filter: dict[str, str] = {}
    if vector_id is not None:
        find_filter["_id"] = vector_id
    if namespace is not None:
        find_filter["namespace"] = namespace
    document = self.collection.find_one(filter=find_filter, projection={"*": 1})
    if document is None:
        return None
    return BaseVectorStoreDriver.Entry(
        id=document["_id"],
        vector=document.get("$vector"),
        meta=document.get("meta"),
        namespace=document.get("namespace"),
    )

query(query, *, count=None, namespace=None, include_vectors=False, **kwargs)

Run a similarity search on the Astra DB store, based on a query string.

Parameters:

Name Type Description Default
query str

the query string.

required
count Optional[int]

the maximum number of results to return. If omitted, defaults will apply.

None
namespace Optional[str]

the namespace to filter results by.

None
include_vectors bool

whether to include vector data in the results.

False
kwargs Any

additional keyword arguments. Currently only the free-form dict filter is recognized (and goes straight to the Data API query); any other keyword argument is ignored.

{}

Returns:

Type Description
list[Entry]

A list of vector (BaseVectorStoreDriver.Entry) entries, with their score attribute set to the vector similarity to the query.

Source code in griptape/drivers/vector/astradb_vector_store_driver.py
def query(
    self,
    query: str,
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
    """Embed a query string and return the most similar entries from the store.

    Args:
        query: the query string.
        count: the maximum number of results to return. When omitted (or zero),
            `BaseVectorStoreDriver.DEFAULT_QUERY_COUNT` is used.
        namespace: the namespace to filter results by.
        include_vectors: whether to include vector data in the results.
        kwargs: additional keyword arguments. Only the free-form dict `filter` is read
            (it is forwarded to the Data API query); any other key is not used.

    Returns:
        A list of `BaseVectorStoreDriver.Entry` items whose `score` attribute
        holds the similarity to the query.
    """
    # Copy the caller's filter; an explicit namespace overrides any "namespace" key it carries.
    find_filter: dict[str, Any] = {**(kwargs.get("filter") or {})}
    if namespace is not None:
        find_filter["namespace"] = namespace
    # Wildcard projection only when vectors are requested; otherwise the projection is left unset.
    find_projection: Optional[dict[str, int]] = None
    if include_vectors:
        find_projection = {"*": 1}
    query_vector = self.embedding_driver.embed_string(query)
    result_limit = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    documents = self.collection.find(
        filter=find_filter,
        sort={"$vector": query_vector},
        limit=result_limit,
        projection=find_projection,
        include_similarity=True,
    )
    return [
        BaseVectorStoreDriver.Entry(
            id=document["_id"],
            vector=document.get("$vector"),
            score=document["$similarity"],
            meta=document.get("meta"),
            namespace=document.get("namespace"),
        )
        for document in documents
    ]

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)

Write a vector to the Astra DB store.

In case the provided ID exists already, an overwrite will take place.

Parameters:

Name Type Description Default
vector list[float]

the vector to be upserted.

required
vector_id Optional[str]

the ID for the vector to store. If omitted, a server-provided new ID will be employed.

None
namespace Optional[str]

a namespace (a grouping within the vector store) to assign the vector to.

None
meta Optional[dict]

a metadata dictionary associated to the vector.

None
kwargs Any

additional keyword arguments. Currently none is used: if they are passed, they are ignored.

{}

Returns:

Type Description
str

the ID of the written vector (str).

Source code in griptape/drivers/vector/astradb_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs: Any,
) -> str:
    """Store a vector, replacing any document already saved under the same ID.

    Args:
        vector: the vector to be upserted.
        vector_id: the ID to store the vector under. When omitted, the document is inserted
            without an `_id` and the ID reported by the insert result is returned.
        namespace: a namespace (a grouping within the vector store) to assign the vector to.
        meta: a metadata dictionary to store alongside the vector.
        kwargs: additional keyword arguments. None of them is read by this method.

    Returns:
        the ID of the written vector.
    """
    candidate_fields = {"$vector": vector, "_id": vector_id, "namespace": namespace, "meta": meta}
    # Only fields that were actually provided end up in the stored document.
    document = {name: value for name, value in candidate_fields.items() if value is not None}
    if vector_id is None:
        return self.collection.insert_one(document).inserted_id
    # upsert=True: replace the document if the ID exists, create it otherwise.
    self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
    return vector_id