Skip to content

Pgvector vector store driver


Bases: BaseVectorStoreDriver

A vector store driver to Postgres using the PGVector extension.


Name Type Description
connection_string Optional[str]

An optional string describing the target Postgres database instance.

create_engine_params dict

Additional configuration params passed when creating the database connection.

engine Optional[Engine]

An optional sqlalchemy Postgres engine to use.

table_name str

The name of the table used to store vectors.

Source code in griptape/drivers/vector/
class PgVectorVectorStoreDriver(BaseVectorStoreDriver):
    """A vector store driver to Postgres using the PGVector extension.

    Attributes:
        connection_string: An optional string describing the target Postgres database instance.
        create_engine_params: Additional configuration params passed when creating the database connection.
        engine: An optional sqlalchemy Postgres engine to use.
        table_name: The name of the table used to store vectors.
    """

    connection_string: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    create_engine_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
    engine: Optional[Engine] = field(default=None, kw_only=True)
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    # Defaults to the declarative model built by default_vector_model().
    _model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True))

    @connection_string.validator  # pyright: ignore
    def validate_connection_string(self, _, connection_string: Optional[str]) -> None:
        """Validate the connection string when no engine has been supplied.

        Raises:
            ValueError: If neither an engine nor a connection string is given, or
                the connection string does not start with ``postgresql://``.
        """
        # If an engine is provided, the connection string is not used.
        if self.engine is not None:
            return

        # If an engine is not provided, a connection string is required.
        if connection_string is None:
            raise ValueError("An engine or connection string is required")

        if not connection_string.startswith("postgresql://"):
            raise ValueError("The connection string must describe a Postgres database connection")

    @engine.validator  # pyright: ignore
    def validate_engine(self, _, engine: Optional[Engine]) -> None:
        """Validate that an engine is present when no connection string was supplied.

        Raises:
            ValueError: If neither an engine nor a connection string is given.
        """
        # If a connection string is provided, an engine does not need to be provided.
        if self.connection_string is not None:
            return

        # If a connection string is not provided, an engine is required.
        if engine is None:
            raise ValueError("An engine or connection string is required")

    def __attrs_post_init__(self) -> None:
        """If an engine is provided, it will be used to connect to the database.
        If not, a connection string is used to create a new database connection here.
        """
        if self.engine is None:
            self.engine = cast(Engine, create_engine(self.connection_string, **self.create_engine_params))

    def setup(
        self, create_schema: bool = True, install_uuid_extension: bool = True, install_vector_extension: bool = True
    ) -> None:
        """Provides a mechanism to initialize the database schema and extensions.

        Args:
            create_schema: Create the table(s) declared on the vector model's metadata.
            install_uuid_extension: Issue ``CREATE EXTENSION IF NOT EXISTS "uuid-ossp"``.
            install_vector_extension: Issue ``CREATE EXTENSION IF NOT EXISTS "vector"``.
        """
        # NOTE(review): Engine.execute() is the legacy connectionless API; confirm the
        # pinned SQLAlchemy version still provides it.
        if install_uuid_extension:
            self.engine.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')

        if install_vector_extension:
            self.engine.execute('CREATE EXTENSION IF NOT EXISTS "vector";')

        if create_schema:
            self._model.metadata.create_all(self.engine)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in the collection.

        Args:
            vector: The embedding values to store.
            vector_id: Primary key of the row to update; when ``None`` the model's
                column default generates one.
            namespace: Optional namespace stored alongside the vector.
            meta: Optional metadata stored alongside the vector.
            **kwargs: Extra column values forwarded to the model constructor.

        Returns:
            The identifier of the stored row, as a string.
        """
        with Session(self.engine) as session:
            obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

            obj = session.merge(obj)
            # Without an explicit commit the merge is discarded when the session closes.
            session.commit()

            return str(getattr(obj, "id"))

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
        """Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
        with Session(self.engine) as session:
            # NOTE(review): `namespace` is not used by the lookup, and a missing row
            # makes `result` None so the getattr calls below raise AttributeError.
            result = session.get(self._model, vector_id)

            return BaseVectorStoreDriver.Entry(
                id=getattr(result, "id"),
                vector=getattr(result, "vector"),
                namespace=getattr(result, "namespace"),
                meta=getattr(result, "meta"),
            )

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from the collection, optionally filtering to only
        those that match the provided namespace.
        """
        with Session(self.engine) as session:
            query = session.query(self._model)
            if namespace:
                query = query.filter_by(namespace=namespace)

            results = query.all()

            return [
                BaseVectorStoreDriver.Entry(
                    id=str(result.id), vector=result.vector, namespace=result.namespace, meta=result.meta
                )
                for result in results
            ]

    def query(
        self,
        query: str,
        count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        distance_metric: str = "cosine_distance",
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Performs a search on the collection to find vectors similar to the provided input vector,
        optionally filtering to only those that match the provided namespace.

        Args:
            query: Text that is embedded and compared against the stored vectors.
            count: Maximum number of rows to return.
            namespace: When given, only rows in this namespace are considered.
            include_vectors: Whether each result carries the stored vector.
            distance_metric: One of ``cosine_distance``, ``l2_distance`` or ``inner_product``.
            **kwargs: May contain ``filter``, a dict of extra column equality filters.

        Raises:
            ValueError: If ``distance_metric`` is not a supported name.
        """
        distance_metrics = {
            "cosine_distance": self._model.vector.cosine_distance,
            "l2_distance": self._model.vector.l2_distance,
            "inner_product": self._model.vector.max_inner_product,
        }

        if distance_metric not in distance_metrics:
            raise ValueError("Invalid distance metric provided")

        op = distance_metrics[distance_metric]

        with Session(self.engine) as session:
            vector = self.embedding_driver.embed_string(query)

            # The query should return both the vector and the distance metric score.
            query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector))  # pyright: ignore

            filter_kwargs: Optional[OrderedDict] = None

            if namespace is not None:
                filter_kwargs = OrderedDict(namespace=namespace)

            if "filter" in kwargs and isinstance(kwargs["filter"], dict):
                filter_kwargs = filter_kwargs or OrderedDict()
                filter_kwargs.update(kwargs["filter"])

            if filter_kwargs is not None:
                query_result = query_result.filter_by(**filter_kwargs)

            results = query_result.limit(count).all()

            # Each row is (model instance, score).
            # NOTE(review): assumes QueryResult accepts id/vector/score/meta/namespace;
            # confirm against BaseVectorStoreDriver.
            return [
                BaseVectorStoreDriver.QueryResult(
                    id=str(result[0].id),
                    vector=result[0].vector if include_vectors else None,
                    score=result[1],
                    meta=result[0].meta,
                    namespace=result[0].namespace,
                )
                for result in results
            ]

    def default_vector_model(self) -> Any:
        """Build the declarative model class mapped to ``self.table_name``."""
        Vector = import_optional_dependency("pgvector.sqlalchemy").Vector
        Base = declarative_base()

        class VectorModel(Base):
            __tablename__ = self.table_name

            id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
            vector = Column(Vector())
            namespace = Column(String)
            meta = Column(JSON)

        return VectorModel

    def delete_vector(self, vector_id: str):
        """Deletion is not implemented by this driver; always raises NotImplementedError."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

connection_string: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

create_engine_params: dict = field(factory=dict, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

engine: Optional[Engine] = field(default=None, kw_only=True) class-attribute instance-attribute

table_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute


If an engine is provided, it will be used to connect to the database. If not, a connection string is used to create a new database connection here.

Source code in griptape/drivers/vector/
def __attrs_post_init__(self) -> None:
    """If an engine is provided, it will be used to connect to the database.
    If not, a connection string is used to create a new database connection here.
    """
    # An engine is built only when the caller did not pass one; the stored
    # create_engine_params are forwarded as keyword arguments.
    if self.engine is None:
        self.engine = cast(Engine, create_engine(self.connection_string, **self.create_engine_params))


Source code in griptape/drivers/vector/
def default_vector_model(self) -> Any:
    """Build the declarative model class mapped to ``self.table_name``.

    Returns:
        A newly declared ``VectorModel`` class with ``id``, ``vector``,
        ``namespace`` and ``meta`` columns.
    """
    # The pgvector column type is resolved through the optional-dependency helper.
    pgvector_module = import_optional_dependency("pgvector.sqlalchemy")
    vector_column_type = pgvector_module.Vector
    declarative_root = declarative_base()

    class VectorModel(declarative_root):
        __tablename__ = self.table_name

        id = Column(
            UUID(as_uuid=True),
            primary_key=True,
            default=uuid.uuid4,
            unique=True,
            nullable=False,
        )
        vector = Column(vector_column_type())
        namespace = Column(String)
        meta = Column(JSON)

    return VectorModel


Source code in griptape/drivers/vector/
def delete_vector(self, vector_id: str):
    """Reject the request: this driver does not implement deletion.

    Args:
        vector_id: Identifier of the vector the caller wanted removed (unused).

    Raises:
        NotImplementedError: Always, naming the concrete driver class.
    """
    driver_name = type(self).__name__
    raise NotImplementedError(f"{driver_name} does not support deletion.")


Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace.

Source code in griptape/drivers/vector/
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Retrieves all vector entries from the collection, optionally filtering to only
    those that match the provided namespace.

    Args:
        namespace: When truthy, only rows whose ``namespace`` column equals it are returned.

    Returns:
        One ``BaseVectorStoreDriver.Entry`` per matching row, with the id converted to ``str``.
    """
    with Session(self.engine) as session:
        query = session.query(self._model)
        if namespace:
            query = query.filter_by(namespace=namespace)

        results = query.all()

        return [
            BaseVectorStoreDriver.Entry(
                id=str(result.id), vector=result.vector, namespace=result.namespace, meta=result.meta
            )
            for result in results
        ]

load_entry(vector_id, namespace=None)

Retrieves a specific vector entry from the collection based on its identifier and optional namespace.

Source code in griptape/drivers/vector/
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
    """Retrieves a specific vector entry from the collection based on its identifier and optional namespace.

    Args:
        vector_id: Primary key of the row to fetch.
        namespace: Accepted for interface compatibility; not used by the lookup below.

    Returns:
        A ``BaseVectorStoreDriver.Entry`` built from the row's id, vector, namespace and meta.
    """
    with Session(self.engine) as session:
        # NOTE(review): a missing row makes `result` None, so the getattr calls
        # below raise AttributeError.
        result = session.get(self._model, vector_id)

        return BaseVectorStoreDriver.Entry(
            id=getattr(result, "id"),
            vector=getattr(result, "vector"),
            namespace=getattr(result, "namespace"),
            meta=getattr(result, "meta"),
        )

query(query, count=BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, namespace=None, include_vectors=False, distance_metric='cosine_distance', **kwargs)

Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace.

Source code in griptape/drivers/vector/
def query(
    self,
    query: str,
    count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    distance_metric: str = "cosine_distance",
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Performs a search on the collection to find vectors similar to the provided input vector,
    optionally filtering to only those that match the provided namespace.

    Args:
        query: Text that is embedded and compared against the stored vectors.
        count: Maximum number of rows to return.
        namespace: When given, only rows in this namespace are considered.
        include_vectors: Whether each result carries the stored vector.
        distance_metric: One of ``cosine_distance``, ``l2_distance`` or ``inner_product``.
        **kwargs: May contain ``filter``, a dict of extra column equality filters.

    Returns:
        Results ordered by ascending value of the chosen distance operator.

    Raises:
        ValueError: If ``distance_metric`` is not a supported name.
    """
    distance_metrics = {
        "cosine_distance": self._model.vector.cosine_distance,
        "l2_distance": self._model.vector.l2_distance,
        "inner_product": self._model.vector.max_inner_product,
    }

    if distance_metric not in distance_metrics:
        raise ValueError("Invalid distance metric provided")

    op = distance_metrics[distance_metric]

    with Session(self.engine) as session:
        vector = self.embedding_driver.embed_string(query)

        # The query should return both the vector and the distance metric score.
        query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector))  # pyright: ignore

        filter_kwargs: Optional[OrderedDict] = None

        if namespace is not None:
            filter_kwargs = OrderedDict(namespace=namespace)

        if "filter" in kwargs and isinstance(kwargs["filter"], dict):
            filter_kwargs = filter_kwargs or OrderedDict()
            # Merge caller-supplied column filters with the namespace filter.
            filter_kwargs.update(kwargs["filter"])

        if filter_kwargs is not None:
            query_result = query_result.filter_by(**filter_kwargs)

        results = query_result.limit(count).all()

        # Each row is (model instance, score).
        # NOTE(review): assumes QueryResult accepts id/vector/score/meta/namespace;
        # confirm against BaseVectorStoreDriver.
        return [
            BaseVectorStoreDriver.QueryResult(
                id=str(result[0].id),
                vector=result[0].vector if include_vectors else None,
                score=result[1],
                meta=result[0].meta,
                namespace=result[0].namespace,
            )
            for result in results
        ]

setup(create_schema=True, install_uuid_extension=True, install_vector_extension=True)

Provides a mechanism to initialize the database schema and extensions.

Source code in griptape/drivers/vector/
def setup(
    self, create_schema: bool = True, install_uuid_extension: bool = True, install_vector_extension: bool = True
) -> None:
    """Provides a mechanism to initialize the database schema and extensions."""
    if install_uuid_extension:
        self.engine.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')

    if install_vector_extension:
        self.engine.execute('CREATE EXTENSION IF NOT EXISTS "vector";')

    if create_schema:

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in the collection.

Source code in griptape/drivers/vector/
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in the collection.

    Args:
        vector: The embedding values to store.
        vector_id: Primary key of the row to update, if any.
        namespace: Optional namespace stored alongside the vector.
        meta: Optional metadata stored alongside the vector.
        **kwargs: Extra column values forwarded to the model constructor.

    Returns:
        The identifier of the stored row, as a string.
    """
    with Session(self.engine) as session:
        obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

        obj = session.merge(obj)
        # Without an explicit commit the merge is discarded when the session closes.
        session.commit()

        return str(getattr(obj, "id"))

validate_connection_string(_, connection_string)

Source code in griptape/drivers/vector/
@connection_string.validator  # pyright: ignore
def validate_connection_string(self, _, connection_string: Optional[str]) -> None:
    """Validate the connection string when no engine has been supplied.

    Args:
        _: The attribute being validated (unused).
        connection_string: The candidate connection string, or ``None``.

    Raises:
        ValueError: If neither an engine nor a connection string is given, or
            the connection string does not start with ``postgresql://``.
    """
    # If an engine is provided, the connection string is not used.
    if self.engine is not None:
        return

    # If an engine is not provided, a connection string is required.
    if connection_string is None:
        raise ValueError("An engine or connection string is required")

    if not connection_string.startswith("postgresql://"):
        raise ValueError("The connection string must describe a Postgres database connection")

validate_engine(_, engine)

Source code in griptape/drivers/vector/
@engine.validator  # pyright: ignore
def validate_engine(self, _, engine: Optional[Engine]) -> None:
    """Validate that an engine is present when no connection string was supplied.

    Args:
        _: The attribute being validated (unused).
        engine: The candidate engine, or ``None``.

    Raises:
        ValueError: If neither an engine nor a connection string is given.
    """
    # If a connection string is provided, an engine does not need to be provided.
    if self.connection_string is not None:
        return

    # If a connection string is not provided, an engine is required.
    if engine is None:
        raise ValueError("An engine or connection string is required")