Bases: BaseRetrievalRagModule
Source code in griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py
@define(kw_only=True)
class TextLoaderRetrievalRagModule(BaseRetrievalRagModule):
    """Retrieval module that loads a source on demand and queries it through a vector store.

    Each call to ``run`` loads the source with ``loader``. It upserts the resulting
    artifacts into ``vector_store_driver`` under a freshly generated namespace, queries
    that namespace with the RAG context's query, and post-processes the matches with
    ``process_query_output_fn``.
    """

    # Loader that turns the (possibly overridden) source into artifacts for upserting.
    loader: BaseTextLoader = field()
    # Vector store that receives the loaded artifacts and answers the query.
    vector_store_driver: BaseVectorStoreDriver = field()
    # Default source passed to ``loader.load``; replaced by the "source" context param when that is not None.
    source: Any = field()
    # Extra keyword arguments for ``vector_store_driver.query``. They are merged with the
    # "query_params" context param, and run() always sets the "namespace" key itself.
    query_params: dict[str, Any] = field(factory=dict)
    # Converts the query's entries into the returned artifacts; by default calls ``to_artifact()`` on each entry.
    process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
        default=Factory(lambda: lambda entries: [entry.to_artifact() for entry in entries]),
    )

    def run(self, context: RagContext) -> Sequence[TextArtifact]:
        """Load the source, index it under a per-call namespace, and query that namespace.

        Args:
            context: RAG context. ``context.query`` is the query text, and
                ``get_context_param`` may supply "source" and "query_params" overrides.

        Returns:
            The result of ``process_query_output_fn`` applied to the entries returned
            by ``vector_store_driver.query``.

        NOTE(review): nothing here deletes the entries upserted under the per-call
        namespace, so every call leaves them in the vector store. Confirm whether the
        configured driver is meant to be ephemeral.
        """
        # A fresh random namespace per call keeps this call's artifacts separate from earlier ones.
        scratch_namespace = uuid.uuid4().hex

        override = self.get_context_param(context, "source")
        effective_source = override if override is not None else self.source

        # Combine module-level and context-level query params (precedence is decided by
        # utils.dict_merge; presumably context values win — confirm).
        merged_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))
        # Scope the query to the namespace written below.
        merged_params["namespace"] = scratch_namespace

        artifacts = self.loader.load(effective_source)
        self.vector_store_driver.upsert_text_artifacts({scratch_namespace: artifacts})

        matches = self.vector_store_driver.query(context.query, **merged_params)
        return self.process_query_output_fn(matches)
|
loader: BaseTextLoader = field()
class-attribute
instance-attribute
process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(default=Factory(lambda: lambda es: [e.to_artifact() for e in es]))
class-attribute
instance-attribute
query_params: dict[str, Any] = field(factory=dict)
class-attribute
instance-attribute
source: Any = field()
class-attribute
instance-attribute
vector_store_driver: BaseVectorStoreDriver = field()
class-attribute
instance-attribute
run(context)
Source code in griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py
def run(self, context: RagContext) -> Sequence[TextArtifact]:
    """Load the source, index it under a per-call namespace, and query that namespace.

    Args:
        context: RAG context. ``context.query`` is the query text, and
            ``get_context_param`` may supply "source" and "query_params" overrides.

    Returns:
        The result of ``process_query_output_fn`` applied to the entries returned
        by ``vector_store_driver.query``.

    NOTE(review): nothing here deletes the entries upserted under the per-call
    namespace, so every call leaves them in the vector store. Confirm whether the
    configured driver is meant to be ephemeral.
    """
    # A fresh random namespace per call keeps this call's artifacts separate from earlier ones.
    scratch_namespace = uuid.uuid4().hex

    override = self.get_context_param(context, "source")
    effective_source = override if override is not None else self.source

    # Combine module-level and context-level query params (precedence is decided by
    # utils.dict_merge; presumably context values win — confirm).
    merged_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))
    # Scope the query to the namespace written below.
    merged_params["namespace"] = scratch_namespace

    artifacts = self.loader.load(effective_source)
    self.vector_store_driver.upsert_text_artifacts({scratch_namespace: artifacts})

    matches = self.vector_store_driver.query(context.query, **merged_params)
    return self.process_query_output_fn(matches)
|