
prompt_response_rag_module

PromptResponseRagModule

Bases: BaseResponseRagModule, RuleMixin

Source code in griptape/engines/rag/modules/response/prompt_response_rag_module.py
@define(kw_only=True)
class PromptResponseRagModule(BaseResponseRagModule, RuleMixin):
    prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))
    answer_token_offset: int = field(default=400)
    metadata: Optional[str] = field(default=None)
    generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(
        default=Factory(lambda self: self.default_generate_system_template, takes_self=True),
    )

    def run(self, context: RagContext) -> BaseArtifact:
        query = context.query
        tokenizer = self.prompt_driver.tokenizer
        included_chunks = []
        system_prompt = self.generate_system_template(context, included_chunks)

        # Greedily pack retrieved chunks into the system prompt until the
        # prompt plus the reserved answer budget would exceed the model's
        # input window.
        for artifact in context.text_chunks:
            included_chunks.append(artifact)

            system_prompt = self.generate_system_template(context, included_chunks)
            message_token_count = self.prompt_driver.tokenizer.count_tokens(
                self.prompt_driver.prompt_stack_to_string(self.generate_prompt_stack(system_prompt, query)),
            )

            if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
                # Over budget: drop the chunk that tipped it over and stop packing.
                included_chunks.pop()

                system_prompt = self.generate_system_template(context, included_chunks)

                break

        output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact()

        if isinstance(output, TextArtifact):
            return output
        else:
            raise ValueError("Prompt driver did not return a TextArtifact")

    def default_generate_system_template(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
        params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}

        # Rulesets and metadata are rendered only when present, then merged
        # into the response system template.
        if len(self.rulesets) > 0:
            params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)

        if self.metadata is not None:
            params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)

        return J2("engines/rag/modules/response/prompt/system.j2").render(**params)
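A minimal usage sketch, assuming griptape's public import paths (griptape.artifacts, griptape.engines.rag, griptape.engines.rag.modules) and that a default prompt driver is configured; the query and chunk text are illustrative:

from griptape.artifacts import TextArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import PromptResponseRagModule

# In a full RAG pipeline a retrieval stage would populate text_chunks;
# here they are supplied by hand to exercise the response module alone.
context = RagContext(
    query="What is the refund policy?",
    text_chunks=[
        TextArtifact("Refunds are accepted within 30 days of purchase."),
        TextArtifact("All refunds require proof of purchase."),
    ],
)

module = PromptResponseRagModule()  # falls back to Defaults.drivers_config.prompt_driver
answer = module.run(context)        # returns a TextArtifact, or raises ValueError
print(answer.to_text())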

answer_token_offset: int = field(default=400)

generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(default=Factory(lambda self: self.default_generate_system_template, takes_self=True))

metadata: Optional[str] = field(default=None)

prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))
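generate_system_template is an ordinary callable field, so the default J2-based builder can be replaced wholesale; a sketch in which only the signature is fixed by the field type and the prompt wording is illustrative:

from griptape.artifacts import TextArtifact
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import PromptResponseRagModule

def custom_system_template(context: RagContext, artifacts: list[TextArtifact]) -> str:
    # Must match Callable[[RagContext, list[TextArtifact]], str].
    chunks = "\n\n".join(a.to_text() for a in artifacts)
    return f"Answer strictly from the context below.\n\nContext:\n{chunks}"

module = PromptResponseRagModule(generate_system_template=custom_system_template)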

default_generate_system_template(context, artifacts)

Builds the system prompt for the given context, rendering any rulesets and metadata into the response system template alongside the included text chunks. Source shown in the class listing above.
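Both optional inputs feed the same template: rulesets (from RuleMixin) render through rulesets/rulesets.j2 and metadata through metadata/system.j2. A sketch, assuming griptape's Rule and Ruleset classes; the strings are illustrative:

from griptape.engines.rag.modules import PromptResponseRagModule
from griptape.rules import Rule, Ruleset

module = PromptResponseRagModule(
    metadata="Source: 2024 employee handbook, HR section.",
    rulesets=[Ruleset(name="Tone", rules=[Rule("Answer in one short paragraph.")])],
)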

run(context)

Greedily packs context.text_chunks into the system prompt while reserving answer_token_offset tokens of the driver's input window for the answer, then runs the prompt driver and returns its output as a TextArtifact, raising ValueError for any other artifact type. Source shown in the class listing above.
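If answers are getting cut off because packed chunks leave too little room, raising answer_token_offset reserves a larger share of the tokenizer's max_input_tokens for generation (400 is the default; 800 below is illustrative):

from griptape.engines.rag.modules import PromptResponseRagModule

# Stop packing chunks earlier, leaving ~800 tokens of headroom for the answer.
module = PromptResponseRagModule(answer_token_offset=800)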