Skip to content

Pgvector vector store driver

PgVectorVectorStoreDriver

Bases: BaseVectorStoreDriver

A vector store driver to Postgres using the PGVector extension.

Attributes:

Name Type Description
connection_string Optional[str]

An optional string describing the target Postgres database instance.

create_engine_params dict

Additional configuration params passed when creating the database connection.

engine Optional[Engine]

An optional sqlalchemy Postgres engine to use.

table_name str

The name of the table used to store vectors. This field has no default, so it must be provided.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
@define
class PgVectorVectorStoreDriver(BaseVectorStoreDriver):
    """A vector store driver to Postgres using the PGVector extension.

    Either ``engine`` or ``connection_string`` must be supplied; when both are given the
    engine is used and the connection string is not validated.

    Attributes:
        connection_string: An optional string describing the target Postgres database instance.
        create_engine_params: Additional configuration params passed when creating the database connection.
        engine: An optional sqlalchemy Postgres engine to use.
        table_name: The name of the table used to store vectors. Required (the field has no default).
    """

    connection_string: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    create_engine_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
    engine: Optional[Engine] = field(default=None, kw_only=True)
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    # ORM model class used for all reads and writes; built from table_name unless overridden.
    _model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True))

    @connection_string.validator  # pyright: ignore
    def validate_connection_string(self, _, connection_string: Optional[str]) -> None:
        """Checks the connection string when no engine was supplied.

        Raises:
            ValueError: If neither an engine nor a connection string is set, or if the
                connection string does not start with ``postgresql://``.
        """
        # If an engine is provided, the connection string is not used.
        if self.engine is not None:
            return

        # If an engine is not provided, a connection string is required.
        if connection_string is None:
            raise ValueError("An engine or connection string is required")

        if not connection_string.startswith("postgresql://"):
            raise ValueError("The connection string must describe a Postgres database connection")

    @engine.validator  # pyright: ignore
    def validate_engine(self, _, engine: Optional[Engine]) -> None:
        """Checks that an engine is present when no connection string was supplied.

        Raises:
            ValueError: If neither an engine nor a connection string is set.
        """
        # If a connection string is provided, an engine does not need to be provided.
        if self.connection_string is not None:
            return

        # If a connection string is not provided, an engine is required.
        if engine is None:
            raise ValueError("An engine or connection string is required")

    def __attrs_post_init__(self) -> None:
        """If an engine is provided, it will be used to connect to the database.
        If not, a connection string is used to create a new database connection here.
        """
        if self.engine is None:
            self.engine = cast(Engine, create_engine(self.connection_string, **self.create_engine_params))

    def setup(
        self, create_schema: bool = True, install_uuid_extension: bool = True, install_vector_extension: bool = True
    ) -> None:
        """Provides a mechanism to initialize the database schema and extensions.

        Args:
            create_schema: When True, create the model's table(s) via its metadata.
            install_uuid_extension: When True, run ``CREATE EXTENSION IF NOT EXISTS "uuid-ossp"``.
            install_vector_extension: When True, run ``CREATE EXTENSION IF NOT EXISTS "vector"``.
        """
        # Local import keeps this method self-contained; sqlalchemy is already a dependency of this module.
        from sqlalchemy import text

        statements = []
        if install_uuid_extension:
            statements.append('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
        if install_vector_extension:
            statements.append('CREATE EXTENSION IF NOT EXISTS "vector";')

        if statements:
            # Engine.execute() and raw-string execution were removed in SQLAlchemy 2.0.
            # engine.begin() + text() works on both 1.4 and 2.0 and commits on success.
            with self.engine.begin() as connection:
                for statement in statements:
                    connection.execute(text(statement))

        if create_schema:
            self._model.metadata.create_all(self.engine)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in the collection.

        Args:
            vector: The embedding to store.
            vector_id: Primary key of the row to write; passed to the model as ``id``.
            namespace: Optional namespace stored alongside the vector.
            meta: Optional metadata stored alongside the vector.
            **kwargs: Extra keyword arguments forwarded to the model constructor.

        Returns:
            The id of the merged row, as a string.
        """
        with Session(self.engine) as session:
            obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

            obj = session.merge(obj)
            session.commit()

            return str(obj.id)

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Retrieves a single vector entry by its primary key.

        Args:
            vector_id: Primary key of the row to load.
            namespace: Accepted for interface compatibility; the lookup is by primary key only
                and does not filter on namespace.

        Returns:
            The matching entry, or None when no row has the given id.
        """
        with Session(self.engine) as session:
            result = session.get(self._model, vector_id)

            # session.get() yields None for an unknown key; reading attributes off None
            # would raise AttributeError, so report "not found" explicitly.
            if result is None:
                return None

            return BaseVectorStoreDriver.Entry(
                # Stringify the id so it matches what load_entries() and query() return.
                id=str(result.id),
                vector=result.vector,
                namespace=result.namespace,
                meta=result.meta,
            )

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from the collection, optionally filtering to only
        those that match the provided namespace.

        A falsy namespace (None or an empty string) applies no filter.
        """
        with Session(self.engine) as session:
            query = session.query(self._model)
            if namespace:
                query = query.filter_by(namespace=namespace)

            results = query.all()

            return [
                BaseVectorStoreDriver.Entry(
                    id=str(result.id), vector=result.vector, namespace=result.namespace, meta=result.meta
                )
                for result in results
            ]

    def query(
        self,
        query: str,
        count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        distance_metric: str = "cosine_distance",
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Embeds the query string and searches the collection for the closest stored vectors,
        optionally filtering to only those that match the provided namespace.

        Args:
            query: Text to embed with the driver's embedding driver.
            count: Maximum number of rows to return (passed to ``limit``).
            namespace: When not None, only rows with this namespace are returned.
            include_vectors: When True, each result carries its stored vector; otherwise None.
            distance_metric: One of ``cosine_distance``, ``l2_distance`` or ``inner_product``.
            **kwargs: If ``filter`` is present and is a dict, its items are added as equality filters.

        Returns:
            Query results ordered by the selected distance expression, each scored with its value.

        Raises:
            ValueError: If ``distance_metric`` is not a supported metric.
        """
        distance_metrics = {
            "cosine_distance": self._model.vector.cosine_distance,
            "l2_distance": self._model.vector.l2_distance,
            "inner_product": self._model.vector.max_inner_product,
        }

        if distance_metric not in distance_metrics:
            raise ValueError("Invalid distance metric provided")

        op = distance_metrics[distance_metric]

        with Session(self.engine) as session:
            vector = self.embedding_driver.embed_string(query)

            # Build the distance expression once and reuse it for both the score column and the ordering.
            distance = op(vector)

            # The query should return both the vector and the distance metric score.
            query_result = session.query(self._model, distance.label("score")).order_by(distance)  # pyright: ignore

            filter_kwargs: Optional[OrderedDict] = None

            if namespace is not None:
                filter_kwargs = OrderedDict(namespace=namespace)

            if "filter" in kwargs and isinstance(kwargs["filter"], dict):
                filter_kwargs = filter_kwargs or OrderedDict()
                filter_kwargs.update(kwargs["filter"])

            if filter_kwargs is not None:
                query_result = query_result.filter_by(**filter_kwargs)

            results = query_result.limit(count).all()

            # Each row is a (model instance, score) pair.
            return [
                BaseVectorStoreDriver.QueryResult(
                    id=str(result[0].id),
                    vector=result[0].vector if include_vectors else None,
                    score=result[1],
                    meta=result[0].meta,
                    namespace=result[0].namespace,
                )
                for result in results
            ]

    def default_vector_model(self) -> Any:
        """Builds the ORM model class mapped to ``table_name``.

        Returns:
            A declarative model class with ``id``, ``vector``, ``namespace`` and ``meta`` columns.
        """
        # pgvector is loaded through the optional-dependency helper rather than a top-level import.
        Vector = import_optional_dependency("pgvector.sqlalchemy").Vector
        Base = declarative_base()

        @dataclass
        class VectorModel(Base):
            __tablename__ = self.table_name

            # Primary key; a uuid4 is generated when no id is supplied.
            id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
            vector = Column(Vector())
            namespace = Column(String)
            meta = Column(JSON)

        return VectorModel

    def delete_vector(self, vector_id: str) -> None:
        """Deletion is not supported by this driver.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

connection_string: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

create_engine_params: dict = field(factory=dict, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

engine: Optional[Engine] = field(default=None, kw_only=True) class-attribute instance-attribute

table_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__attrs_post_init__()

If an engine is provided, it will be used to connect to the database. If not, a connection string is used to create a new database connection here.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def __attrs_post_init__(self) -> None:
    """Make sure the driver ends up with an engine.

    An engine supplied by the caller is kept untouched. Otherwise a new one is
    created from ``connection_string``, with ``create_engine_params`` passed
    through as keyword arguments.
    """
    if self.engine is not None:
        return

    new_engine = create_engine(self.connection_string, **self.create_engine_params)
    self.engine = cast(Engine, new_engine)

default_vector_model()

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def default_vector_model(self) -> Any:
    """Builds the ORM model class mapped to ``self.table_name``.

    Returns:
        A declarative model class with ``id``, ``vector``, ``namespace`` and ``meta`` columns.
    """
    # pgvector is loaded through the optional-dependency helper rather than a top-level import.
    Vector = import_optional_dependency("pgvector.sqlalchemy").Vector
    # A fresh declarative base is created on every call.
    Base = declarative_base()

    @dataclass
    class VectorModel(Base):
        __tablename__ = self.table_name

        # Primary key; a uuid4 is generated when no id is supplied.
        id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
        vector = Column(Vector())
        namespace = Column(String)
        meta = Column(JSON)

    return VectorModel

delete_vector(vector_id)

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def delete_vector(self, vector_id: str):
    """Deletion is not supported by this driver.

    Raises:
        NotImplementedError: Always; the message names the concrete driver class.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support deletion.")

load_entries(namespace=None)

Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Load every stored entry, or only those in ``namespace`` when one is given.

    A falsy namespace (None or an empty string) applies no filter.

    Returns:
        One entry per row, with the row id converted to a string.
    """
    with Session(self.engine) as session:
        selection = session.query(self._model)
        selection = selection.filter_by(namespace=namespace) if namespace else selection

        return [
            BaseVectorStoreDriver.Entry(id=str(row.id), vector=row.vector, namespace=row.namespace, meta=row.meta)
            for row in selection.all()
        ]

load_entry(vector_id, namespace=None)

Retrieves a specific vector entry from the collection based on its identifier. The namespace argument is accepted, but the lookup is by identifier only and does not filter on namespace.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Retrieves a single vector entry by its primary key.

    Args:
        vector_id: Primary key of the row to load.
        namespace: Accepted for interface compatibility; the lookup is by primary key only
            and does not filter on namespace.

    Returns:
        The matching entry, or None when no row has the given id.
    """
    with Session(self.engine) as session:
        result = session.get(self._model, vector_id)

        # session.get() yields None for an unknown key; reading attributes off None
        # would raise AttributeError, so report "not found" explicitly.
        if result is None:
            return None

        return BaseVectorStoreDriver.Entry(
            # Stringify the id so it has the same form as the ids returned by the other read methods.
            id=str(result.id),
            vector=result.vector,
            namespace=result.namespace,
            meta=result.meta,
        )

query(query, count=BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, namespace=None, include_vectors=False, distance_metric='cosine_distance', **kwargs)

Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    distance_metric: str = "cosine_distance",
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Embed ``query`` and return the stored vectors closest to it.

    Args:
        query: Text to embed with the driver's embedding driver.
        count: Maximum number of rows to return (passed to ``limit``).
        namespace: When not None, only rows with this namespace are returned.
        include_vectors: When True, each result carries its stored vector; otherwise None.
        distance_metric: One of ``cosine_distance``, ``l2_distance`` or ``inner_product``.
        **kwargs: If ``filter`` is present and is a dict, its items are added as equality filters.

    Returns:
        Results ordered by the selected distance expression, each scored with its value.

    Raises:
        ValueError: If ``distance_metric`` is not a supported metric.
    """
    vector_column = self._model.vector
    supported_metrics = {
        "cosine_distance": vector_column.cosine_distance,
        "l2_distance": vector_column.l2_distance,
        "inner_product": vector_column.max_inner_product,
    }

    try:
        distance_fn = supported_metrics[distance_metric]
    except KeyError:
        raise ValueError("Invalid distance metric provided") from None

    with Session(self.engine) as session:
        embedding = self.embedding_driver.embed_string(query)

        # Select the model together with its distance to the embedding, nearest first.
        statement = session.query(self._model, distance_fn(embedding).label("score"))  # pyright: ignore
        statement = statement.order_by(distance_fn(embedding))  # pyright: ignore

        # Collect equality filters: namespace first, then any caller-supplied "filter" dict.
        conditions: Optional[OrderedDict] = None
        if namespace is not None:
            conditions = OrderedDict(namespace=namespace)

        extra_filter = kwargs.get("filter")
        if "filter" in kwargs and isinstance(extra_filter, dict):
            if not conditions:
                conditions = OrderedDict()
            conditions.update(extra_filter)

        if conditions is not None:
            statement = statement.filter_by(**conditions)

        # Each row is a (model instance, score) pair.
        matches = []
        for row in statement.limit(count).all():
            record, score = row[0], row[1]
            matches.append(
                BaseVectorStoreDriver.QueryResult(
                    id=str(record.id),
                    vector=record.vector if include_vectors else None,
                    score=score,
                    meta=record.meta,
                    namespace=record.namespace,
                )
            )
        return matches

setup(create_schema=True, install_uuid_extension=True, install_vector_extension=True)

Provides a mechanism to initialize the database schema and extensions.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def setup(
    self, create_schema: bool = True, install_uuid_extension: bool = True, install_vector_extension: bool = True
) -> None:
    """Provides a mechanism to initialize the database schema and extensions.

    Args:
        create_schema: When True, create the model's table(s) via its metadata.
        install_uuid_extension: When True, run ``CREATE EXTENSION IF NOT EXISTS "uuid-ossp"``.
        install_vector_extension: When True, run ``CREATE EXTENSION IF NOT EXISTS "vector"``.
    """
    # Local import keeps this function self-contained; sqlalchemy is already a dependency of this module.
    from sqlalchemy import text

    statements = []
    if install_uuid_extension:
        statements.append('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
    if install_vector_extension:
        statements.append('CREATE EXTENSION IF NOT EXISTS "vector";')

    if statements:
        # Engine.execute() and raw-string execution were removed in SQLAlchemy 2.0.
        # engine.begin() + text() works on both 1.4 and 2.0 and commits on success.
        with self.engine.begin() as connection:
            for statement in statements:
                connection.execute(text(statement))

    if create_schema:
        self._model.metadata.create_all(self.engine)

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in the collection.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Insert a vector row, or update the existing row with the same id.

    Args:
        vector: The embedding to store.
        vector_id: Primary key of the row to write; passed to the model as ``id``.
        namespace: Optional namespace stored alongside the vector.
        meta: Optional metadata stored alongside the vector.
        **kwargs: Extra keyword arguments forwarded to the model constructor.

    Returns:
        The id of the merged row, as a string.
    """
    with Session(self.engine) as session:
        pending = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

        # merge() returns the session-attached instance; read the id from that one.
        persisted = session.merge(pending)
        session.commit()

        row_id = getattr(persisted, "id")
        return str(row_id)

validate_connection_string(_, connection_string)

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
@connection_string.validator  # pyright: ignore
def validate_connection_string(self, _, connection_string: Optional[str]) -> None:
    """Check the connection string, but only when no engine was supplied.

    Raises:
        ValueError: If neither an engine nor a connection string is set, or if the
            connection string does not start with ``postgresql://``.
    """
    engine_supplied = self.engine is not None
    if engine_supplied:
        # The engine takes precedence, so the connection string is never used.
        return

    if connection_string is None:
        raise ValueError("An engine or connection string is required")

    looks_like_postgres = connection_string.startswith("postgresql://")
    if not looks_like_postgres:
        raise ValueError("The connection string must describe a Postgres database connection")

validate_engine(_, engine)

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
@engine.validator  # pyright: ignore
def validate_engine(self, _, engine: Optional[Engine]) -> None:
    """Require an engine whenever no connection string was supplied.

    Raises:
        ValueError: If neither an engine nor a connection string is set.
    """
    has_connection_string = self.connection_string is not None
    # Only the "nothing to connect with" combination is an error.
    if not has_connection_string and engine is None:
        raise ValueError("An engine or connection string is required")