Skip to content

Prompt task

PromptTask

Bases: BaseTextInputTask

Source code in griptape/tasks/prompt_task.py
@define
class PromptTask(BaseTextInputTask):
    """Task that runs a prompt through a prompt driver and stores the resulting artifact.

    The prompt is assembled into a ``PromptStack`` from the generated system
    template, the task input, the previous output (if any), and the
    structure's conversation memory (when the task belongs to a structure
    that has one).

    Attributes:
        prompt_driver: Driver used to run the prompt (stored as ``_prompt_driver``).
            When omitted, it is resolved lazily from ``structure.config.prompt_driver``.
        generate_system_template: Callable that receives this task and returns the
            system prompt text. Defaults to ``default_system_template_generator``.
        output: Artifact produced by the most recent ``run`` call; it cannot be
            passed to ``__init__``.
    """

    # Stored privately so the ``prompt_driver`` property can fall back to the
    # structure's configured driver; callers still pass it as ``prompt_driver=``.
    _prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True, alias="prompt_driver")
    # ``takes_self=True`` lets the default bind to this instance's own generator method.
    generate_system_template: Callable[[PromptTask], str] = field(
        default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True
    )

    output: Optional[BaseArtifact] = field(default=None, init=False)

    @property
    def prompt_stack(self) -> PromptStack:
        """Build the prompt stack sent to the prompt driver.

        Order: system prompt, user input, then the previous output as assistant
        input when one exists. Conversation memory, if available, is inserted at
        index 1, directly after the system prompt.
        """
        stack = PromptStack()
        # A task that is not attached to a structure has no conversation memory;
        # guard here instead of raising AttributeError on ``None``.
        memory = self.structure.conversation_memory if self.structure is not None else None

        stack.add_system_input(self.generate_system_template(self))

        stack.add_user_input(self.input.to_text())

        if self.output:
            stack.add_assistant_input(self.output.to_text())

        if memory:
            # inserting at index 1 to place memory right after system prompt
            memory.add_to_prompt_stack(stack, 1)

        return stack

    @property
    def prompt_driver(self) -> BasePromptDriver:
        """Return the prompt driver, falling back to the structure's configured one.

        The fallback driver is cached in ``_prompt_driver`` on first access.

        Raises:
            ValueError: If no driver was given and the task has no structure.
        """
        if self._prompt_driver is None:
            if self.structure is not None:
                self._prompt_driver = self.structure.config.prompt_driver
            else:
                raise ValueError("Prompt Driver is not set")
        return self._prompt_driver

    def preprocess(self, structure: Structure) -> PromptTask:
        """Run base preprocessing, then point the prompt driver at ``structure``.

        Returns:
            This task (``self``).
        """
        super().preprocess(structure)
        if self.prompt_driver is not None:
            self.prompt_driver.structure = structure

        return self

    def default_system_template_generator(self, _: PromptTask) -> str:
        """Render the default system prompt with this task's rulesets embedded.

        The second argument is ignored; it exists so this bound method matches the
        ``Callable[[PromptTask], str]`` signature of ``generate_system_template``.
        """
        return J2("tasks/prompt_task/system.j2").render(
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets)
        )

    def run(self) -> BaseArtifact:
        """Run the prompt stack through the prompt driver and store the result in ``output``."""
        self.output = self.prompt_driver.run(self.prompt_stack)

        return self.output

generate_system_template: Callable[[PromptTask], str] = field(default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True) (class attribute, instance attribute)

output: Optional[BaseArtifact] = field(default=None, init=False) (class attribute, instance attribute)

prompt_driver: BasePromptDriver (property)

prompt_stack: PromptStack (property)

default_system_template_generator(_)

Source code in griptape/tasks/prompt_task.py
def default_system_template_generator(self, _: PromptTask) -> str:
    """Render the default system prompt, embedding this task's rendered rulesets.

    The second argument is unused; it is present so the bound method fits the
    ``Callable[[PromptTask], str]`` shape expected for a system-template generator.
    """
    system_template = J2("tasks/prompt_task/system.j2")
    rendered_rulesets = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets)

    return system_template.render(rulesets=rendered_rulesets)

preprocess(structure)

Source code in griptape/tasks/prompt_task.py
def preprocess(self, structure: Structure) -> PromptTask:
    """Run base-class preprocessing, then attach the prompt driver to ``structure``.

    Returns:
        This task (``self``).
    """
    super().preprocess(structure)

    driver = self.prompt_driver
    if driver is not None:
        driver.structure = structure

    return self

run()

Source code in griptape/tasks/prompt_task.py
def run(self) -> BaseArtifact:
    """Send the task's prompt stack to its prompt driver and keep the result.

    Returns:
        The artifact produced by the driver, which is also stored on ``self.output``.
    """
    driver = self.prompt_driver
    result = driver.run(self.prompt_stack)
    self.output = result

    return self.output