Skip to content

rag

__all__ = ['RagContext', 'RagEngine'] module-attribute

RagContext

Bases: SerializableMixin

Used by RagEngine stages and modules to pass context that individual modules are expected to update in the `run` method.

Attributes:

Name Type Description
query str

Query provided by the user.

module_configs dict[str, dict]

Dictionary of module configs, keyed by module name; each value is a dictionary of config parameters for that module.

before_query list[str]

An optional list of strings to add before the query in response modules.

after_query list[str]

An optional list of strings to add after the query in response modules.

text_chunks list[TextArtifact]

A list of text chunks to pass around from the retrieval stage to the response stage.

outputs list[BaseArtifact]

List of outputs from the response stage.

Source code in griptape/engines/rag/rag_context.py
@define(kw_only=True)
class RagContext(SerializableMixin):
    """Used by RagEngine stages and modules to pass context that individual modules are expected to update in the `run` method.

    Attributes:
        query: Query provided by the user.
        module_configs: Dictionary of module configs, keyed by module name; each value is a dictionary of config parameters for that module.
        before_query: An optional list of strings to add before the query in response modules.
        after_query: An optional list of strings to add after the query in response modules.
        text_chunks: A list of text chunks to pass around from the retrieval stage to the response stage.
        outputs: List of outputs from the response stage.
    """

    query: str = field(metadata={"serializable": True})
    module_configs: dict[str, dict] = field(factory=dict, metadata={"serializable": True})
    before_query: list[str] = field(factory=list, metadata={"serializable": True})
    after_query: list[str] = field(factory=list, metadata={"serializable": True})
    text_chunks: list[TextArtifact] = field(factory=list, metadata={"serializable": True})
    outputs: list[BaseArtifact] = field(factory=list, metadata={"serializable": True})

    def get_references(self) -> list[Reference]:
        """Return the references extracted from ``text_chunks`` by ``utils.references_from_artifacts``."""
        return utils.references_from_artifacts(self.text_chunks)

after_query: list[str] = field(factory=list, metadata={'serializable': True}) class-attribute instance-attribute

before_query: list[str] = field(factory=list, metadata={'serializable': True}) class-attribute instance-attribute

module_configs: dict[str, dict] = field(factory=dict, metadata={'serializable': True}) class-attribute instance-attribute

outputs: list[BaseArtifact] = field(factory=list, metadata={'serializable': True}) class-attribute instance-attribute

query: str = field(metadata={'serializable': True}) class-attribute instance-attribute

text_chunks: list[TextArtifact] = field(factory=list, metadata={'serializable': True}) class-attribute instance-attribute

get_references()

Source code in griptape/engines/rag/rag_context.py
def get_references(self) -> list[Reference]:
    """Collect references from the context's text chunks.

    Returns:
        Whatever ``utils.references_from_artifacts`` produces for ``self.text_chunks``.
    """
    chunks = self.text_chunks
    return utils.references_from_artifacts(chunks)

RagEngine

Source code in griptape/engines/rag/rag_engine.py
@define(kw_only=True)
class RagEngine:
    """Runs a RAG pipeline made of up to three optional stages.

    Configured stages run in a fixed order (query, then retrieval, then response),
    each receiving the ``RagContext`` returned by the previous one.

    Attributes:
        query_stage: Optional stage run first.
        retrieval_stage: Optional stage run after the query stage.
        response_stage: Optional stage run last.
    """

    query_stage: Optional[QueryRagStage] = field(default=None)
    retrieval_stage: Optional[RetrievalRagStage] = field(default=None)
    response_stage: Optional[ResponseRagStage] = field(default=None)

    def __attrs_post_init__(self) -> None:
        """Validate that module names are unique across all configured stages.

        Raises:
            ValueError: If two or more modules (in the same or different stages) share a name.
        """
        stages = (self.query_stage, self.retrieval_stage, self.response_stage)
        module_names = [m.name for stage in stages if stage is not None for m in stage.modules]

        # Collect each repeated name once, in first-repeat order.
        seen = set()
        duplicates = []
        for name in module_names:
            if name in seen and name not in duplicates:
                duplicates.append(name)
            seen.add(name)

        if duplicates:
            # Keep the original message as a prefix so existing `match=` checks still succeed,
            # but name the offending modules so the misconfiguration is easy to locate.
            raise ValueError(f"module names have to be unique, duplicates: {duplicates}")

    def process_query(self, query: str) -> RagContext:
        """Wrap ``query`` in a new ``RagContext`` and run it through the configured stages."""
        return self.process(RagContext(query=query))

    def process(self, context: RagContext) -> RagContext:
        """Run each configured stage in order, threading the context through them.

        Args:
            context: Context passed to the first configured stage.

        Returns:
            The context returned by the last configured stage, or ``context`` itself
            when no stages are configured.
        """
        # `is not None` matches the check used in __attrs_post_init__.
        for stage in (self.query_stage, self.retrieval_stage, self.response_stage):
            if stage is not None:
                context = stage.run(context)

        return context

query_stage: Optional[QueryRagStage] = field(default=None) class-attribute instance-attribute

response_stage: Optional[ResponseRagStage] = field(default=None) class-attribute instance-attribute

retrieval_stage: Optional[RetrievalRagStage] = field(default=None) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/engines/rag/rag_engine.py
def __attrs_post_init__(self) -> None:
    modules = []

    if self.query_stage is not None:
        modules.extend(self.query_stage.modules)

    if self.retrieval_stage is not None:
        modules.extend(self.retrieval_stage.modules)

    if self.response_stage is not None:
        modules.extend(self.response_stage.modules)

    module_names = [m.name for m in modules]

    if len(module_names) > len(set(module_names)):
        raise ValueError("module names have to be unique")

process(context)

Source code in griptape/engines/rag/rag_engine.py
def process(self, context: RagContext) -> RagContext:
    """Run each configured stage in order, threading the context through them.

    Args:
        context: Context passed to the first configured stage.

    Returns:
        The context returned by the last configured stage, or ``context`` itself
        when no stages are configured.
    """
    # `is not None` matches the check used in __attrs_post_init__.
    for stage in (self.query_stage, self.retrieval_stage, self.response_stage):
        if stage is not None:
            context = stage.run(context)

    return context

process_query(query)

Source code in griptape/engines/rag/rag_engine.py
def process_query(self, query: str) -> RagContext:
    """Wrap ``query`` in a new ``RagContext`` and run it through the engine's stages."""
    initial_context = RagContext(query=query)
    return self.process(initial_context)