RAG Engines

Note

This section is a work in progress.

RagEngine is an abstraction for implementing modular retrieval-augmented generation (RAG) pipelines.

RAG Stages

RagEngines consist of three stages: QueryRagStage, RetrievalRagStage, and ResponseRagStage. These stages are always executed sequentially. Each stage comprises one or more modules, whose selection and behavior can be customized. Because of this structure, RagEngines are not intended to replace Workflows or Pipelines.
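Conceptually, the three stages form a sequence of transformations over a shared context. The sketch below illustrates that flow only; the names and functions are illustrative stand-ins, not the Griptape API:

```python
from dataclasses import dataclass, field


@dataclass
class Context:
    """Illustrative stand-in for a RAG context passed between stages."""

    query: str
    text_chunks: list = field(default_factory=list)
    outputs: list = field(default_factory=list)


def query_stage(ctx: Context) -> Context:
    # e.g. translate or normalize the query before retrieval
    ctx.query = ctx.query.strip()
    return ctx


def retrieval_stage(ctx: Context) -> Context:
    # e.g. fetch candidate text chunks from a store
    ctx.text_chunks = ["chunk about " + ctx.query]
    return ctx


def response_stage(ctx: Context) -> Context:
    # e.g. generate an answer from the retrieved chunks
    ctx.outputs = [f"Answer based on {len(ctx.text_chunks)} chunk(s)"]
    return ctx


# The stages always run sequentially, each receiving the prior stage's context.
ctx = Context(query="  griptape  ")
for stage in (query_stage, retrieval_stage, response_stage):
    ctx = stage(ctx)
print(ctx.outputs[0])
```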

RAG Modules

RAG modules implement the actions performed in each stage of the RAG pipeline. RagEngine makes it easy for developers to add new modules and experiment with novel RAG strategies.

The three stages of the pipeline implemented in RAG Engines, together with their purposes and associated modules, are as follows:

Query Stage

This stage modifies input queries before they are passed to retrieval.

Query Stage Modules

  • TranslateQueryRagModule is for translating the query into another language.
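A query stage may chain several query modules, each transforming the query in turn. A minimal sketch of that pattern, with illustrative module functions rather than the Griptape API:

```python
def lowercase_module(query: str) -> str:
    # normalize casing before retrieval (illustrative transformation)
    return query.lower()


def strip_module(query: str) -> str:
    # drop stray surrounding whitespace
    return query.strip()


def run_query_stage(query: str, modules) -> str:
    # each module receives the output of the previous one
    for module in modules:
        query = module(query)
    return query


result = run_query_stage("  What Is Griptape?  ", [lowercase_module, strip_module])
print(result)  # -> "what is griptape?"
```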

Retrieval Stage

Results are retrieved in this stage, either from a vector store in the form of text chunks or in real time with a text loader. You may optionally add a rerank module to this stage to reorder results by their relevance to the original query.

Retrieval Stage Modules

  • TextChunksRerankRagModule is for re-ranking retrieved results.
  • TextLoaderRetrievalRagModule is for retrieving data with text loaders in real time.
  • VectorStoreRetrievalRagModule is for retrieving text chunks from a vector store.
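Reranking reorders retrieved chunks so the most relevant appear first. The sketch below uses a naive word-overlap score as a stand-in for a real rerank driver's relevance model; it is illustrative, not the Griptape implementation:

```python
def rerank(query: str, chunks: list[str]) -> list[str]:
    # score each chunk by word overlap with the query; a real rerank
    # driver would use a learned relevance model instead
    query_words = set(query.lower().split())

    def score(chunk: str) -> int:
        return len(query_words & set(chunk.lower().split()))

    return sorted(chunks, key=score, reverse=True)


chunks = [
    "pricing for cloud services",
    "griptape cloud services overview",
    "company history",
]
ranked = rerank("griptape cloud services", chunks)
print(ranked[0])  # -> "griptape cloud services overview"
```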

Response Stage

Responses are generated in this final stage.

Response Stage Modules

  • PromptResponseRagModule is for generating responses based on retrieved text chunks.
  • TextChunksResponseRagModule is for responding with retrieved text chunks.
  • FootnotePromptResponseRagModule is for responding with automatic footnotes from text chunk references.
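Prompt-based response modules typically work by inlining the retrieved chunks as context in the prompt sent to the LLM. A minimal sketch of that prompt assembly, with an illustrative template rather than Griptape's actual one:

```python
def build_prompt(query: str, chunks: list[str]) -> str:
    # inline retrieved chunks as bullet-point context, then append
    # the user's question (illustrative template)
    context = "\n".join(f"- {chunk}" for chunk in chunks)
    return (
        "Use the following context to answer the question.\n"
        f"Context:\n{context}\n"
        f"Question: {query}"
    )


prompt = build_prompt("What does Griptape offer?", ["chunk one", "chunk two"])
print(prompt)
```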

RAG Context

RagContext is a container object for passing around queries, text chunks, module configs, and other metadata. RagContext is modified by modules when appropriate. Some modules support runtime config overrides through RagContext.module_configs.
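The override mechanism can be pictured as each module looking up its entry in module_configs by name and merging it over its defaults. A dict-based sketch of that idea (illustrative, not the RagContext implementation):

```python
# Defaults a retrieval module might ship with (hypothetical values).
defaults = {"top_n": 10}

# Runtime overrides keyed by module name, mirroring the shape of
# RagContext.module_configs in the example below.
module_configs = {"MyRetriever": {"query_params": {"namespace": "griptape"}}}


def resolve_params(module_name: str, defaults: dict, module_configs: dict) -> dict:
    # later entries win, so runtime overrides take precedence
    overrides = module_configs.get(module_name, {}).get("query_params", {})
    return {**defaults, **overrides}


params = resolve_params("MyRetriever", defaults, module_configs)
print(params)  # -> {'top_n': 10, 'namespace': 'griptape'}
```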

Example

The following example shows a simple RAG pipeline that translates incoming queries into English, retrieves data from a local vector store, reranks the results with LocalRerankDriver, and generates a response:

from griptape.chunkers import TextChunker
from griptape.drivers.embedding.openai import OpenAiEmbeddingDriver
from griptape.drivers.prompt.openai import OpenAiChatPromptDriver
from griptape.drivers.rerank.local import LocalRerankDriver
from griptape.drivers.vector.local import LocalVectorStoreDriver
from griptape.engines.rag import RagContext, RagEngine
from griptape.engines.rag.modules import (
    PromptResponseRagModule,
    TextChunksRerankRagModule,
    TranslateQueryRagModule,
    VectorStoreRetrievalRagModule,
)
from griptape.engines.rag.stages import (
    QueryRagStage,
    ResponseRagStage,
    RetrievalRagStage,
)
from griptape.loaders import WebLoader
from griptape.rules import Rule, Ruleset

prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0)

vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())
artifact = WebLoader().load("https://www.griptape.ai")
chunks = TextChunker(max_tokens=500).chunk(artifact)

vector_store.upsert_text_artifacts(
    {
        "griptape": chunks,
    }
)

rag_engine = RagEngine(
    query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="english")]),
    retrieval_stage=RetrievalRagStage(
        max_chunks=5,
        retrieval_modules=[
            VectorStoreRetrievalRagModule(
                name="MyAwesomeRetriever",
                vector_store_driver=vector_store,
                query_params={"top_n": 20},
            )
        ],
        rerank_module=TextChunksRerankRagModule(rerank_driver=LocalRerankDriver()),
    ),
    response_stage=ResponseRagStage(
        response_modules=[
            PromptResponseRagModule(
                prompt_driver=prompt_driver,
                rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])],
            )
        ]
    ),
)

rag_context = RagContext(
    query="¿Qué ofrecen los servicios en la nube de Griptape?",
    module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}},
)

print(rag_engine.process(rag_context).outputs[0].to_text())