RAG Engines
Note
This section is a work in progress.
RAG Engine is an abstraction for implementing modular retrieval-augmented generation (RAG) pipelines.
RAG Stages
RagEngines consist of three stages: QueryRagStage, RetrievalRagStage, and ResponseRagStage. These stages are always executed sequentially. Each stage comprises multiple modules, which are executed in a customized manner. Due to this unique structure, RagEngines are not intended to replace Workflows or Pipelines.
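For orientation, this is how the three stages fit together in a RagEngine. This is only a structural sketch: query_stage, retrieval_stage, and response_stage stand for stage objects like the ones built in the snippets below and in the complete example at the end of this section.

from griptape.engines.rag import RagEngine

# Stages run strictly in order: query -> retrieval -> response.
rag_engine = RagEngine(
    query_stage=query_stage,
    retrieval_stage=retrieval_stage,
    response_stage=response_stage,
)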
RAG Modules
RAG modules are used to implement actions in the different stages of the RAG pipeline. RagEngine enables developers to easily add new modules to experiment with novel RAG strategies.
The three stages of the pipeline implemented in RAG Engines, together with their purposes and associated modules, are as follows:
Query Stage
This stage is used for modifying input queries before they are submitted.
Query Stage Modules
TranslateQueryRagModule is for translating the query into another language.
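For instance, the complete example at the end of this section wires a TranslateQueryRagModule into the query stage. Shown here in isolation (it assumes an OpenAI API key is available for the prompt driver):

from griptape.drivers.prompt.openai import OpenAiChatPromptDriver
from griptape.engines.rag.modules import TranslateQueryRagModule
from griptape.engines.rag.stages import QueryRagStage

# Normalize incoming queries to English before retrieval.
query_stage = QueryRagStage(
    query_modules=[
        TranslateQueryRagModule(
            prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0),
            language="english",
        )
    ]
)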
Retrieval Stage
Results are retrieved in this stage, either from a vector store in the form of chunks, or with a text loader. You may optionally use a rerank module in this stage to rerank results in order of their relevance to the original query.
Retrieval Stage Modules
TextChunksRerankRagModule is for re-ranking retrieved results.
TextLoaderRetrievalRagModule is for retrieving data with text loaders in real time.
VectorStoreRetrievalRagModule is for retrieving text chunks from a vector store.
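Taken from the complete example at the end of this section, here is a retrieval stage in isolation: a vector store retrieval module fetches up to 20 candidate chunks, an optional rerank module reorders them, and max_chunks limits how many chunks the stage keeps.

from griptape.drivers.embedding.openai import OpenAiEmbeddingDriver
from griptape.drivers.rerank.local import LocalRerankDriver
from griptape.drivers.vector.local import LocalVectorStoreDriver
from griptape.engines.rag.modules import TextChunksRerankRagModule, VectorStoreRetrievalRagModule
from griptape.engines.rag.stages import RetrievalRagStage

# Retrieve candidate chunks from a local vector store, rerank them, and keep the top 5.
retrieval_stage = RetrievalRagStage(
    max_chunks=5,
    retrieval_modules=[
        VectorStoreRetrievalRagModule(
            name="MyAwesomeRetriever",
            vector_store_driver=LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()),
            query_params={"top_n": 20},
        )
    ],
    rerank_module=TextChunksRerankRagModule(rerank_driver=LocalRerankDriver()),
)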
Response Stage
Responses are generated in this final stage.
Response Stage Modules
PromptResponseRagModule is for generating responses based on retrieved text chunks.
TextChunksResponseRagModule is for responding with retrieved text chunks.
FootnotePromptResponseRagModule is for responding with automatic footnotes from text chunk references.
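As in the complete example at the end of this section, a response stage with a single PromptResponseRagModule looks like this in isolation; the ruleset simply shapes the style of the generated answer.

from griptape.drivers.prompt.openai import OpenAiChatPromptDriver
from griptape.engines.rag.modules import PromptResponseRagModule
from griptape.engines.rag.stages import ResponseRagStage
from griptape.rules import Rule, Ruleset

# Generate the final answer from the retrieved chunks, following a persona ruleset.
response_stage = ResponseRagStage(
    response_modules=[
        PromptResponseRagModule(
            prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0),
            rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])],
        )
    ]
)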
RAG Context
RagContext is a container object for passing around queries, text chunks, module configs, and other metadata. RagContext is modified by modules when appropriate. Some modules support runtime config overrides through RagContext.module_configs.
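For example, the RagContext used in the complete example below overrides the named retrieval module's query_params at run time:

from griptape.engines.rag import RagContext

# The "MyAwesomeRetriever" key matches the retrieval module's name; its query_params
# are overridden for this run so that only the "griptape" namespace is searched.
rag_context = RagContext(
    query="¿Qué ofrecen los servicios en la nube de Griptape?",
    module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}},
)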
Example
The following example shows a simple RAG pipeline that translates incoming queries into English, retrieves data from a local vector store, reranks the results using the local rerank driver, and generates a response:
from griptape.chunkers import TextChunker
from griptape.drivers.embedding.openai import OpenAiEmbeddingDriver
from griptape.drivers.prompt.openai import OpenAiChatPromptDriver
from griptape.drivers.rerank.local import LocalRerankDriver
from griptape.drivers.vector.local import LocalVectorStoreDriver
from griptape.engines.rag import RagContext, RagEngine
from griptape.engines.rag.modules import (
    PromptResponseRagModule,
    TextChunksRerankRagModule,
    TranslateQueryRagModule,
    VectorStoreRetrievalRagModule,
)
from griptape.engines.rag.stages import (
    QueryRagStage,
    ResponseRagStage,
    RetrievalRagStage,
)
from griptape.loaders import WebLoader
from griptape.rules import Rule, Ruleset

prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0)

# Load a web page, chunk it, and index the chunks in a local vector store
# under the "griptape" namespace.
vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())
artifact = WebLoader().load("https://www.griptape.ai")
chunks = TextChunker(max_tokens=500).chunk(artifact)
vector_store.upsert_text_artifacts(
    {
        "griptape": chunks,
    }
)

rag_engine = RagEngine(
    # Translate incoming queries into English before retrieval.
    query_stage=QueryRagStage(
        query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="english")]
    ),
    # Retrieve candidate chunks from the vector store, rerank them, and keep the top 5.
    retrieval_stage=RetrievalRagStage(
        max_chunks=5,
        retrieval_modules=[
            VectorStoreRetrievalRagModule(
                name="MyAwesomeRetriever",
                vector_store_driver=vector_store,
                query_params={"top_n": 20},
            )
        ],
        rerank_module=TextChunksRerankRagModule(rerank_driver=LocalRerankDriver()),
    ),
    # Generate the final answer from the retrieved chunks, following the persona ruleset.
    response_stage=ResponseRagStage(
        response_modules=[
            PromptResponseRagModule(
                prompt_driver=prompt_driver,
                rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])],
            )
        ]
    ),
)

# The module_configs key matches the retrieval module's name and overrides its
# query_params at run time, scoping the search to the "griptape" namespace.
rag_context = RagContext(
    query="¿Qué ofrecen los servicios en la nube de Griptape?",
    module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}},
)

print(rag_engine.process(rag_context).outputs[0].to_text())