Skip to content

vector_store_retrieval_rag_module

VectorStoreRetrievalRagModule

Bases: BaseRetrievalRagModule

Source code in griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py
@define(kw_only=True)
class VectorStoreRetrievalRagModule(BaseRetrievalRagModule):
    """Retrieval RAG module that fetches text artifacts from a vector store.

    Attributes:
        vector_store_driver: Driver whose ``query`` method is called during ``run``.
            Defaults to ``Defaults.drivers_config.vector_store_driver``, resolved
            when the instance is created.
        query_params: Keyword arguments forwarded to ``vector_store_driver.query``.
            These are merged with whatever ``get_context_param(context, "query_params")``
            returns for the current context.
        process_query_output_fn: Turns the list of entries returned by the driver into
            a sequence of ``TextArtifact`` objects. By default each entry is
            converted with its ``to_artifact()`` method.
    """

    # Resolved lazily via Factory so the global default config is read per instance.
    vector_store_driver: BaseVectorStoreDriver = field(
        default=Factory(lambda: Defaults.drivers_config.vector_store_driver)
    )
    # Module-level defaults for the driver query; per-context params are merged in at run time.
    query_params: dict[str, Any] = field(factory=dict)
    # Factory returns a fresh converter function; the default maps each entry to its artifact.
    process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
        default=Factory(lambda: lambda entries: [entry.to_artifact() for entry in entries]),
    )

    def run(self, context: RagContext) -> Sequence[TextArtifact]:
        """Query the vector store with ``context.query`` and return the processed artifacts.

        Args:
            context: RAG context supplying the query text and, optionally, extra
                ``"query_params"`` for this module.

        Returns:
            Whatever ``process_query_output_fn`` produces from the driver's entries.
        """
        per_context_params = self.get_context_param(context, "query_params")
        # NOTE(review): presumably dict_merge lets the per-context values take precedence
        # over self.query_params — confirm against griptape.utils.dict_merge.
        merged_params = utils.dict_merge(self.query_params, per_context_params)
        matches = self.vector_store_driver.query(context.query, **merged_params)

        return self.process_query_output_fn(matches)

`process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(default=Factory(lambda: lambda es: [e.to_artifact() for e in es]))` (class-attribute, instance-attribute)

`query_params: dict[str, Any] = field(factory=dict)` (class-attribute, instance-attribute)

`vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Defaults.drivers_config.vector_store_driver))` (class-attribute, instance-attribute)

run(context)

Source code in griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py
def run(self, context: RagContext) -> Sequence[TextArtifact]:
    """Query the vector store with ``context.query`` and return the processed artifacts.

    Args:
        context: RAG context supplying the query text and, optionally, extra
            ``"query_params"`` for this module.

    Returns:
        Whatever ``process_query_output_fn`` produces from the driver's entries.
    """
    per_context_params = self.get_context_param(context, "query_params")
    # NOTE(review): presumably dict_merge lets the per-context values take precedence
    # over self.query_params — confirm against griptape.utils.dict_merge.
    merged_params = utils.dict_merge(self.query_params, per_context_params)
    matches = self.vector_store_driver.query(context.query, **merged_params)

    return self.process_query_output_fn(matches)