modules

__all__ = ['BaseRagModule', 'BaseQueryRagModule', 'TranslateQueryRagModule', 'BaseRetrievalRagModule', 'BaseRerankRagModule', 'TextChunksRerankRagModule', 'VectorStoreRetrievalRagModule', 'TextLoaderRetrievalRagModule', 'BaseBeforeResponseRagModule', 'BaseAfterResponseRagModule', 'BaseResponseRagModule', 'PromptResponseRagModule', 'TextChunksResponseRagModule', 'FootnotePromptResponseRagModule'] module-attribute

BaseAfterResponseRagModule

Bases: BaseRagModule, ABC

Source code in griptape/engines/rag/modules/response/base_after_response_rag_module.py
@define(kw_only=True)
class BaseAfterResponseRagModule(BaseRagModule, ABC):
    @abstractmethod
    def run(self, context: RagContext) -> RagContext: ...

run(context) abstractmethod

Source code in griptape/engines/rag/modules/response/base_after_response_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> RagContext: ...

BaseBeforeResponseRagModule

Bases: BaseRagModule, ABC

Source code in griptape/engines/rag/modules/response/base_before_response_rag_module.py
@define(kw_only=True)
class BaseBeforeResponseRagModule(BaseRagModule, ABC):
    @abstractmethod
    def run(self, context: RagContext) -> RagContext: ...

run(context) abstractmethod

Source code in griptape/engines/rag/modules/response/base_before_response_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> RagContext: ...
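
BaseBeforeResponseRagModule and BaseAfterResponseRagModule are hooks that run immediately before and after the response modules; both receive and return the RagContext. Below is a minimal sketch of a custom before-response hook. The class and its behavior are hypothetical, not part of the library; it records an instruction in RagContext.before_query, which response modules such as FootnotePromptResponseRagModule render ahead of the system prompt.

# Hypothetical example: a before-response hook that appends an instruction
# to RagContext.before_query for response modules that render it.
from attrs import define

from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseBeforeResponseRagModule


@define(kw_only=True)
class CiteSourcesRagModule(BaseBeforeResponseRagModule):
    def run(self, context: RagContext) -> RagContext:
        # before_query is a list of strings joined into the prompt by some response modules.
        context.before_query.append("Answer concisely and cite your sources.")

        return context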

BaseQueryRagModule

Bases: BaseRagModule, ABC

Source code in griptape/engines/rag/modules/query/base_query_rag_module.py
@define(kw_only=True)
class BaseQueryRagModule(BaseRagModule, ABC):
    @abstractmethod
    def run(self, context: RagContext) -> RagContext: ...

run(context) abstractmethod

Source code in griptape/engines/rag/modules/query/base_query_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> RagContext: ...

BaseRagModule

Bases: FuturesExecutorMixin, ABC

Source code in griptape/engines/rag/modules/base_rag_module.py
@define(kw_only=True)
class BaseRagModule(FuturesExecutorMixin, ABC):
    name: str = field(
        default=Factory(lambda self: f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True
    )

    def generate_prompt_stack(self, system_prompt: Optional[str], query: str) -> PromptStack:
        messages = []

        if system_prompt is not None:
            messages.append(Message(system_prompt, role=Message.SYSTEM_ROLE))

        messages.append(Message(query, role=Message.USER_ROLE))

        return PromptStack(messages=messages)

    def get_context_param(self, context: RagContext, key: str) -> Optional[Any]:
        return context.module_configs.get(self.name, {}).get(key)

    def set_context_param(self, context: RagContext, key: str, value: Any) -> None:
        if not isinstance(context.module_configs.get(self.name), dict):
            context.module_configs[self.name] = {}

        context.module_configs[self.name][key] = value

name: str = field(default=Factory(lambda self: f'{self.__class__.__name__}-{uuid.uuid4().hex}', takes_self=True), kw_only=True) class-attribute instance-attribute

generate_prompt_stack(system_prompt, query)

Source code in griptape/engines/rag/modules/base_rag_module.py
def generate_prompt_stack(self, system_prompt: Optional[str], query: str) -> PromptStack:
    messages = []

    if system_prompt is not None:
        messages.append(Message(system_prompt, role=Message.SYSTEM_ROLE))

    messages.append(Message(query, role=Message.USER_ROLE))

    return PromptStack(messages=messages)

get_context_param(context, key)

Source code in griptape/engines/rag/modules/base_rag_module.py
def get_context_param(self, context: RagContext, key: str) -> Optional[Any]:
    return context.module_configs.get(self.name, {}).get(key)

set_context_param(context, key, value)

Source code in griptape/engines/rag/modules/base_rag_module.py
def set_context_param(self, context: RagContext, key: str, value: Any) -> None:
    if not isinstance(context.module_configs.get(self.name), dict):
        context.module_configs[self.name] = {}

    context.module_configs[self.name][key] = value
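
get_context_param and set_context_param read and write per-module overrides stored in RagContext.module_configs, keyed by the module's name (which defaults to the class name plus a random suffix, so set an explicit name if you plan to address it). A small sketch, with an illustrative module name and key; the module falls back to the default vector store driver from Defaults.drivers_config, which is an assumption about your configuration.

# Illustrative sketch: per-module parameters stored under the module's name
# in RagContext.module_configs.
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import VectorStoreRetrievalRagModule

module = VectorStoreRetrievalRagModule(name="docs-retrieval")  # default vector store driver
context = RagContext(query="What is Griptape?")

# Equivalent to context.module_configs["docs-retrieval"] = {"query_params": {"count": 1}}
module.set_context_param(context, "query_params", {"count": 1})

module.get_context_param(context, "query_params")  # -> {"count": 1}
module.get_context_param(context, "missing-key")   # -> None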

BaseRerankRagModule

Bases: BaseRagModule, ABC

Source code in griptape/engines/rag/modules/retrieval/base_rerank_rag_module.py
@define(kw_only=True)
class BaseRerankRagModule(BaseRagModule, ABC):
    rerank_driver: BaseRerankDriver = field()

    @abstractmethod
    def run(self, context: RagContext) -> Sequence[BaseArtifact]: ...

rerank_driver: BaseRerankDriver = field() class-attribute instance-attribute

run(context) abstractmethod

Source code in griptape/engines/rag/modules/retrieval/base_rerank_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> Sequence[BaseArtifact]: ...

BaseResponseRagModule

Bases: BaseRagModule, ABC

Source code in griptape/engines/rag/modules/response/base_response_rag_module.py
@define(kw_only=True)
class BaseResponseRagModule(BaseRagModule, ABC):
    @abstractmethod
    def run(self, context: RagContext) -> BaseArtifact: ...

run(context) abstractmethod

Source code in griptape/engines/rag/modules/response/base_response_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> BaseArtifact: ...

BaseRetrievalRagModule

Bases: BaseRagModule, ABC

Source code in griptape/engines/rag/modules/retrieval/base_retrieval_rag_module.py
@define(kw_only=True)
class BaseRetrievalRagModule(BaseRagModule, ABC):
    @abstractmethod
    def run(self, context: RagContext) -> Sequence[BaseArtifact]: ...

run(context) abstractmethod

Source code in griptape/engines/rag/modules/retrieval/base_retrieval_rag_module.py
@abstractmethod
def run(self, context: RagContext) -> Sequence[BaseArtifact]: ...
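
Retrieval modules return a sequence of artifacts rather than a mutated context; the retrieval stage typically collects their output into context.text_chunks. A hypothetical custom retriever backed by an in-memory list, shown only to illustrate the subclassing contract:

# Hypothetical example: a retrieval module over an in-memory list of chunks.
from collections.abc import Sequence

from attrs import define, field

from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import BaseRetrievalRagModule


@define(kw_only=True)
class StaticRetrievalRagModule(BaseRetrievalRagModule):
    chunks: list[TextArtifact] = field(factory=list)

    def run(self, context: RagContext) -> Sequence[BaseArtifact]:
        # Naive keyword filter; a real module would query an external store.
        matches = [c for c in self.chunks if context.query.lower() in c.value.lower()]

        return matches or self.chunks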

FootnotePromptResponseRagModule

Bases: PromptResponseRagModule

Source code in griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py
@define(kw_only=True)
class FootnotePromptResponseRagModule(PromptResponseRagModule):
    def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
        return J2("engines/rag/modules/response/footnote_prompt/system.j2").render(
            text_chunk_artifacts=artifacts,
            references=utils.references_from_artifacts(artifacts),
            before_system_prompt="\n\n".join(context.before_query),
            after_system_prompt="\n\n".join(context.after_query),
        )

default_system_template_generator(context, artifacts)

Source code in griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py
def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
    return J2("engines/rag/modules/response/footnote_prompt/system.j2").render(
        text_chunk_artifacts=artifacts,
        references=utils.references_from_artifacts(artifacts),
        before_system_prompt="\n\n".join(context.before_query),
        after_system_prompt="\n\n".join(context.after_query),
    )

PromptResponseRagModule

Bases: BaseResponseRagModule, RuleMixin

Source code in griptape/engines/rag/modules/response/prompt_response_rag_module.py
@define(kw_only=True)
class PromptResponseRagModule(BaseResponseRagModule, RuleMixin):
    prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))
    answer_token_offset: int = field(default=400)
    metadata: Optional[str] = field(default=None)
    generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(
        default=Factory(lambda self: self.default_system_template_generator, takes_self=True),
    )

    def run(self, context: RagContext) -> BaseArtifact:
        query = context.query
        tokenizer = self.prompt_driver.tokenizer
        included_chunks = []
        system_prompt = self.generate_system_template(context, included_chunks)

        for artifact in context.text_chunks:
            included_chunks.append(artifact)

            system_prompt = self.generate_system_template(context, included_chunks)
            message_token_count = self.prompt_driver.tokenizer.count_tokens(
                self.prompt_driver.prompt_stack_to_string(self.generate_prompt_stack(system_prompt, query)),
            )

            if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
                included_chunks.pop()

                system_prompt = self.generate_system_template(context, included_chunks)

                break

        output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact()

        if isinstance(output, TextArtifact):
            return output
        else:
            raise ValueError("Prompt driver did not return a TextArtifact")

    def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
        params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}

        if len(self.all_rulesets) > 0:
            params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets)

        if self.metadata is not None:
            params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)

        return J2("engines/rag/modules/response/prompt/system.j2").render(**params)

answer_token_offset: int = field(default=400) class-attribute instance-attribute

generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(default=Factory(lambda self: self.default_system_template_generator, takes_self=True)) class-attribute instance-attribute

metadata: Optional[str] = field(default=None) class-attribute instance-attribute

prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver)) class-attribute instance-attribute

default_system_template_generator(context, artifacts)

Source code in griptape/engines/rag/modules/response/prompt_response_rag_module.py
def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
    params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}

    if len(self.all_rulesets) > 0:
        params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets)

    if self.metadata is not None:
        params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)

    return J2("engines/rag/modules/response/prompt/system.j2").render(**params)

run(context)

Source code in griptape/engines/rag/modules/response/prompt_response_rag_module.py
def run(self, context: RagContext) -> BaseArtifact:
    query = context.query
    tokenizer = self.prompt_driver.tokenizer
    included_chunks = []
    system_prompt = self.generate_system_template(context, included_chunks)

    for artifact in context.text_chunks:
        included_chunks.append(artifact)

        system_prompt = self.generate_system_template(context, included_chunks)
        message_token_count = self.prompt_driver.tokenizer.count_tokens(
            self.prompt_driver.prompt_stack_to_string(self.generate_prompt_stack(system_prompt, query)),
        )

        if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
            included_chunks.pop()

            system_prompt = self.generate_system_template(context, included_chunks)

            break

    output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact()

    if isinstance(output, TextArtifact):
        return output
    else:
        raise ValueError("Prompt driver did not return a TextArtifact")
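
PromptResponseRagModule packs as many retrieved chunks into the system prompt as the prompt driver's tokenizer allows, reserving answer_token_offset tokens for the answer, and then asks the driver for a text response; it also mixes in RuleMixin, so any rulesets are rendered into the prompt. A minimal sketch, with an illustrative driver, model, and metadata string:

# Minimal sketch: answering a query from already-retrieved chunks.
from griptape.artifacts import TextArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import PromptResponseRagModule

module = PromptResponseRagModule(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
    answer_token_offset=400,
    metadata="Source: internal FAQ",  # optional, rendered into the system prompt
)

context = RagContext(query="What color is the sky?")
context.text_chunks = [TextArtifact("The sky appears blue because of Rayleigh scattering.")]

answer = module.run(context)  # TextArtifact; a non-text result raises ValueError
print(answer.to_text())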

TextChunksRerankRagModule

Bases: BaseRerankRagModule

Source code in griptape/engines/rag/modules/retrieval/text_chunks_rerank_rag_module.py
@define(kw_only=True)
class TextChunksRerankRagModule(BaseRerankRagModule):
    def run(self, context: RagContext) -> Sequence[BaseArtifact]:
        return self.rerank_driver.run(context.query, context.text_chunks)

run(context)

Source code in griptape/engines/rag/modules/retrieval/text_chunks_rerank_rag_module.py
def run(self, context: RagContext) -> Sequence[BaseArtifact]:
    return self.rerank_driver.run(context.query, context.text_chunks)
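
TextChunksRerankRagModule delegates to its rerank driver, reordering context.text_chunks against the query. A sketch using Cohere's rerank driver; the driver choice, model name, and placeholder API key are assumptions.

# Sketch: reranking retrieved chunks with a rerank driver.
from griptape.artifacts import TextArtifact
from griptape.drivers import CohereRerankDriver
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import TextChunksRerankRagModule

module = TextChunksRerankRagModule(
    rerank_driver=CohereRerankDriver(model="rerank-english-v3.0", api_key="..."),
)

context = RagContext(query="How do I install griptape?")
context.text_chunks = [
    TextArtifact("Griptape was announced in 2023."),
    TextArtifact("Install griptape with `pip install griptape`."),
]

ranked = module.run(context)  # most relevant chunks first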

TextChunksResponseRagModule

Bases: BaseResponseRagModule

Source code in griptape/engines/rag/modules/response/text_chunks_response_rag_module.py
@define(kw_only=True)
class TextChunksResponseRagModule(BaseResponseRagModule):
    def run(self, context: RagContext) -> BaseArtifact:
        return ListArtifact(context.text_chunks)

run(context)

Source code in griptape/engines/rag/modules/response/text_chunks_response_rag_module.py
def run(self, context: RagContext) -> BaseArtifact:
    return ListArtifact(context.text_chunks)
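
TextChunksResponseRagModule skips generation entirely and returns the retrieved chunks wrapped in a ListArtifact, which is useful when the caller wants the raw chunks rather than an LLM answer. Quick sketch:

# Sketch: returning the retrieved chunks as-is instead of generating an answer.
from griptape.artifacts import TextArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import TextChunksResponseRagModule

context = RagContext(query="anything")
context.text_chunks = [TextArtifact("chunk one"), TextArtifact("chunk two")]

chunks = TextChunksResponseRagModule().run(context)  # ListArtifact wrapping both chunks
print(chunks.to_text())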

TextLoaderRetrievalRagModule

Bases: BaseRetrievalRagModule

Source code in griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py
@define(kw_only=True)
class TextLoaderRetrievalRagModule(BaseRetrievalRagModule):
    loader: BaseTextLoader = field()
    vector_store_driver: BaseVectorStoreDriver = field()
    source: Any = field()
    query_params: dict[str, Any] = field(factory=dict)
    process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
        default=Factory(lambda: lambda es: [e.to_artifact() for e in es]),
    )

    def run(self, context: RagContext) -> Sequence[TextArtifact]:
        namespace = uuid.uuid4().hex
        context_source = self.get_context_param(context, "source")
        source = self.source if context_source is None else context_source

        query_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))

        query_params["namespace"] = namespace

        loader_output = self.loader.load(source)

        self.vector_store_driver.upsert_text_artifacts({namespace: loader_output})

        return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params))

loader: BaseTextLoader = field() class-attribute instance-attribute

process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(default=Factory(lambda: lambda es: [e.to_artifact() for e in es])) class-attribute instance-attribute

query_params: dict[str, Any] = field(factory=dict) class-attribute instance-attribute

source: Any = field() class-attribute instance-attribute

vector_store_driver: BaseVectorStoreDriver = field() class-attribute instance-attribute

run(context)

Source code in griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py
def run(self, context: RagContext) -> Sequence[TextArtifact]:
    namespace = uuid.uuid4().hex
    context_source = self.get_context_param(context, "source")
    source = self.source if context_source is None else context_source

    query_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))

    query_params["namespace"] = namespace

    loader_output = self.loader.load(source)

    self.vector_store_driver.upsert_text_artifacts({namespace: loader_output})

    return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params))
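
TextLoaderRetrievalRagModule loads the source at query time, upserts the loader output into the vector store under a random namespace, and then queries that namespace. A sketch; the loader, drivers, and URL are illustrative, and loader output handling may differ across versions. The source and query_params can also be overridden per run through the context parameter helpers on BaseRagModule shown above.

# Sketch: load-at-query-time retrieval over a web page.
from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import TextLoaderRetrievalRagModule
from griptape.loaders import WebLoader

module = TextLoaderRetrievalRagModule(
    loader=WebLoader(),
    vector_store_driver=LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()),
    source="https://www.griptape.ai",
    query_params={"count": 5},
)

chunks = module.run(RagContext(query="What is Griptape?"))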

TranslateQueryRagModule

Bases: BaseQueryRagModule

Source code in griptape/engines/rag/modules/query/translate_query_rag_module.py
@define(kw_only=True)
class TranslateQueryRagModule(BaseQueryRagModule):
    prompt_driver: BasePromptDriver = field()
    language: str = field()
    generate_user_template: Callable[[str, str], str] = field(
        default=Factory(lambda self: self.default_user_template_generator, takes_self=True),
    )

    def run(self, context: RagContext) -> RagContext:
        user_prompt = self.generate_user_template(context.query, self.language)
        output = self.prompt_driver.run(self.generate_prompt_stack(None, user_prompt)).to_artifact()

        context.query = output.to_text()

        return context

    def default_user_template_generator(self, query: str, language: str) -> str:
        return J2("engines/rag/modules/query/translate/user.j2").render(query=query, language=language)

generate_user_template: Callable[[str, str], str] = field(default=Factory(lambda self: self.default_user_template_generator, takes_self=True)) class-attribute instance-attribute

language: str = field() class-attribute instance-attribute

prompt_driver: BasePromptDriver = field() class-attribute instance-attribute

default_user_template_generator(query, language)

Source code in griptape/engines/rag/modules/query/translate_query_rag_module.py
def default_user_template_generator(self, query: str, language: str) -> str:
    return J2("engines/rag/modules/query/translate/user.j2").render(query=query, language=language)

run(context)

Source code in griptape/engines/rag/modules/query/translate_query_rag_module.py
def run(self, context: RagContext) -> RagContext:
    user_prompt = self.generate_user_template(context.query, self.language)
    output = self.prompt_driver.run(self.generate_prompt_stack(None, user_prompt)).to_artifact()

    context.query = output.to_text()

    return context
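
TranslateQueryRagModule rewrites context.query by asking the prompt driver to translate it into the configured language, typically before retrieval runs. A sketch with an illustrative driver and model:

# Sketch: translating the incoming query before retrieval.
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import TranslateQueryRagModule

module = TranslateQueryRagModule(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o-mini"),
    language="english",
)

context = module.run(RagContext(query="¿Qué es Griptape?"))
print(context.query)  # the query, translated to English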

VectorStoreRetrievalRagModule

Bases: BaseRetrievalRagModule

Source code in griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py
@define(kw_only=True)
class VectorStoreRetrievalRagModule(BaseRetrievalRagModule):
    vector_store_driver: BaseVectorStoreDriver = field(
        default=Factory(lambda: Defaults.drivers_config.vector_store_driver)
    )
    query_params: dict[str, Any] = field(factory=dict)
    process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
        default=Factory(lambda: lambda es: [e.to_artifact() for e in es]),
    )

    def run(self, context: RagContext) -> Sequence[TextArtifact]:
        query_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))

        return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params))

process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(default=Factory(lambda: lambda es: [e.to_artifact() for e in es])) class-attribute instance-attribute

query_params: dict[str, Any] = field(factory=dict) class-attribute instance-attribute

vector_store_driver: BaseVectorStoreDriver = field(default=Factory(lambda: Defaults.drivers_config.vector_store_driver)) class-attribute instance-attribute

run(context)

Source code in griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py
def run(self, context: RagContext) -> Sequence[TextArtifact]:
    query_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))

    return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params))
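
VectorStoreRetrievalRagModule queries the configured vector store with context.query, merging query_params with any per-run overrides from the context, and maps the returned entries to TextArtifacts. A sketch; the local driver, embedding driver, namespace, and count are illustrative, and the flat griptape.drivers import paths may differ across versions.

# Sketch: retrieving chunks from a vector store namespace.
from griptape.artifacts import TextArtifact
from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import VectorStoreRetrievalRagModule

vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())
vector_store.upsert_text_artifact(TextArtifact("Griptape ships a RAG engine."), namespace="docs")

module = VectorStoreRetrievalRagModule(
    vector_store_driver=vector_store,
    query_params={"namespace": "docs", "count": 3},
)

chunks = module.run(RagContext(query="Does Griptape support RAG?"))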