Skip to content

text_loader_retrieval_rag_module

TextLoaderRetrievalRagModule

Bases: BaseRetrievalRagModule

Source code in griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py
@define(kw_only=True)
class TextLoaderRetrievalRagModule(BaseRetrievalRagModule):
    """Retrieval module that loads a source, indexes it, and queries it in one step.

    Each call to `run` loads the source through `loader`, upserts the loader
    output into `vector_store_driver` under a freshly generated namespace, and
    then queries that same namespace with the context's query.

    Attributes:
        loader: Loader whose `load` method is called with the resolved source.
        vector_store_driver: Driver used both to upsert the loader output and to run the query.
        source: Default source passed to the loader when the context does not supply one.
        query_params: Default keyword arguments forwarded to the driver's `query` call.
        process_query_output_fn: Callable applied to the query entries to produce the result;
            by default it calls `to_artifact()` on each entry.
    """

    loader: BaseTextLoader = field()
    vector_store_driver: BaseVectorStoreDriver = field()
    source: Any = field()
    query_params: dict[str, Any] = field(factory=dict)
    process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
        default=Factory(lambda: lambda es: [entry.to_artifact() for entry in es]),
    )

    def run(self, context: RagContext) -> Sequence[TextArtifact]:
        """Load the source, upsert it under a one-off namespace, and query that namespace.

        Args:
            context: RAG context supplying the query and, optionally, `source` and
                `query_params` overrides read through `get_context_param`.

        Returns:
            The value produced by `process_query_output_fn` for the entries returned
            by the vector store query.
        """
        # A random hex namespace is generated on every call and used for both the
        # upsert and the query. This method does not remove the upserted entries afterwards.
        run_namespace = uuid.uuid4().hex

        # A context-level source, when present, takes precedence over the configured one.
        override_source = self.get_context_param(context, "source")
        effective_source = override_source if override_source is not None else self.source

        # NOTE(review): assumes utils.dict_merge returns a new dict (so the assignment
        # below does not mutate self.query_params) and tolerates a missing context value - confirm.
        context_query_params = self.get_context_param(context, "query_params")
        merged_params = utils.dict_merge(self.query_params, context_query_params)

        # Set after the merge so the generated namespace always wins over any supplied one.
        merged_params["namespace"] = run_namespace

        loaded = self.loader.load(effective_source)
        self.vector_store_driver.upsert_text_artifacts({run_namespace: loaded})

        entries = self.vector_store_driver.query(context.query, **merged_params)
        return self.process_query_output_fn(entries)

loader: BaseTextLoader = field() class-attribute instance-attribute

process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(default=Factory(lambda: lambda es: [e.to_artifact() for e in es])) class-attribute instance-attribute

query_params: dict[str, Any] = field(factory=dict) class-attribute instance-attribute

source: Any = field() class-attribute instance-attribute

vector_store_driver: BaseVectorStoreDriver = field() class-attribute instance-attribute

run(context)

Source code in griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py
def run(self, context: RagContext) -> Sequence[TextArtifact]:
    """Load the source, upsert it under a one-off namespace, and query that namespace.

    Args:
        context: RAG context supplying the query and, optionally, `source` and
            `query_params` overrides read through `get_context_param`.

    Returns:
        The value produced by `process_query_output_fn` for the entries returned
        by the vector store query.
    """
    # A random hex namespace is generated on every call and used for both the
    # upsert and the query. This method does not remove the upserted entries afterwards.
    run_namespace = uuid.uuid4().hex

    # A context-level source, when present, takes precedence over the configured one.
    override_source = self.get_context_param(context, "source")
    effective_source = override_source if override_source is not None else self.source

    # NOTE(review): assumes utils.dict_merge returns a new dict (so the assignment
    # below does not mutate self.query_params) and tolerates a missing context value - confirm.
    context_query_params = self.get_context_param(context, "query_params")
    merged_params = utils.dict_merge(self.query_params, context_query_params)

    # Set after the merge so the generated namespace always wins over any supplied one.
    merged_params["namespace"] = run_namespace

    loaded = self.loader.load(effective_source)
    self.vector_store_driver.upsert_text_artifacts({run_namespace: loaded})

    entries = self.vector_store_driver.query(context.query, **merged_params)
    return self.process_query_output_fn(entries)