Skip to content

prompt_task

logger = logging.getLogger(Defaults.logging_config.logger_name) module-attribute

PromptTask

Bases: RuleMixin, BaseTask

Source code in griptape/tasks/prompt_task.py
@define
class PromptTask(RuleMixin, BaseTask):
    """Task that sends its input to a prompt driver and returns the reply as an artifact.

    Attributes:
        prompt_driver: Driver whose ``run`` is called with the prompt stack in ``try_run``.
            Defaults to ``Defaults.drivers_config.prompt_driver``.
        generate_system_template: Callable that receives the task and returns the system
            message text. Defaults to ``default_generate_system_template``.
        output: Artifact logged by ``after_run``; not an ``__init__`` argument and
            ``None`` until something assigns it.
    """

    prompt_driver: BasePromptDriver = field(
        default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True
    )
    generate_system_template: Callable[[PromptTask], str] = field(
        default=Factory(lambda self: self.default_generate_system_template, takes_self=True),
        kw_only=True,
    )
    # Raw, unprocessed input (exposed to ``__init__`` as ``input``). The default is itself a
    # callable taking the task; it is resolved lazily by ``_process_task_input`` and yields the
    # first element of ``full_context["args"]``, or an empty ``TextArtifact`` when there is none.
    _input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field(
        default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""),
        alias="input",
    )

    @property
    def rulesets(self) -> list:
        """Return the rulesets that apply to this task.

        The structure's rulesets (if the task is attached to a structure that has any) come
        first, followed by the task's own. When the structure and/or the task define loose
        ``rules``, they are combined into one extra ruleset named ``DEFAULT_RULESET_NAME``
        appended at the end.

        Returns:
            A new list on every access; the task's stored rulesets are never modified.
        """
        default_rules = self.rules
        # Copy so the ``append`` below cannot leak a default ruleset into ``self._rulesets``
        # (which would add one more duplicate on every access).
        rulesets = list(self._rulesets)

        if self.structure is not None:
            if self.structure._rulesets:
                rulesets = self.structure._rulesets + self._rulesets
            if self.structure.rules:
                default_rules = self.structure.rules + self.rules

        if default_rules:
            rulesets.append(Ruleset(name=self.DEFAULT_RULESET_NAME, rules=default_rules))

        return rulesets

    @property
    def input(self) -> BaseArtifact:
        """Return the task input converted to an artifact by ``_process_task_input``."""
        return self._process_task_input(self._input)

    @input.setter
    def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact]) -> None:
        """Store a new raw input; it is processed again on the next read of ``input``."""
        self._input = value

    output: Optional[BaseArtifact] = field(default=None, init=False)

    @property
    def prompt_stack(self) -> PromptStack:
        """Build the prompt stack passed to the prompt driver.

        The stack holds, in order of insertion: the system message (only when the generated
        system template is non-empty), the task input as a user message, and the current
        output as an assistant message (only when it is truthy). If the task belongs to a
        structure with conversation memory, the memory then adds itself to the stack.

        Returns:
            A freshly built ``PromptStack``.
        """
        stack = PromptStack()
        memory = self.structure.conversation_memory if self.structure is not None else None

        system_template = self.generate_system_template(self)
        if system_template:
            stack.add_system_message(system_template)

        stack.add_user_message(self.input)

        if self.output:
            stack.add_assistant_message(self.output)

        if memory is not None:
            # insert memory into the stack right before the user messages
            # (index 1 skips the system message when one was added)
            memory.add_to_prompt_stack(self.prompt_driver, stack, 1 if system_template else 0)

        return stack

    def default_generate_system_template(self, _: PromptTask) -> str:
        """Render the default system template with this task's rendered rulesets.

        Args:
            _: Unused; present so the method matches the ``generate_system_template`` signature.

        Returns:
            The rendered ``tasks/prompt_task/system.j2`` template.
        """
        return J2("tasks/prompt_task/system.j2").render(
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets),
        )

    def before_run(self) -> None:
        """Run the parent ``before_run`` hook, then log the task's class name, id and input text."""
        super().before_run()

        logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, self.input.to_text())

    def after_run(self) -> None:
        """Run the parent ``after_run`` hook, then log the task's class name, id and output text.

        ``output`` is declared optional, so an empty string is logged when it is ``None``
        instead of raising ``AttributeError``.
        """
        super().after_run()

        output_text = self.output.to_text() if self.output is not None else ""
        logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, output_text)

    def try_run(self) -> BaseArtifact:
        """Run the prompt driver on the current prompt stack and return its message as an artifact."""
        message = self.prompt_driver.run(self.prompt_stack)

        return message.to_artifact()

    def _process_task_input(
        self,
        task_input: str | tuple | list | BaseArtifact | Callable[[BaseTask], BaseArtifact],
    ) -> BaseArtifact:
        """Convert a raw task input into an artifact.

        Args:
            task_input: The value to convert. Handled by type, checked in this order:
                ``TextArtifact`` (its value is rendered as a template against
                ``full_context``), callable (called with the task, result processed again),
                ``ListArtifact`` (each element processed), any other ``BaseArtifact``
                (returned unchanged), ``list``/``tuple`` (each element processed and wrapped
                in a ``ListArtifact``), anything else (wrapped in a ``TextArtifact`` and
                processed again).

        Returns:
            The resulting artifact.
        """
        if isinstance(task_input, TextArtifact):
            # NOTE: the artifact is updated in place, so the caller's object holds the
            # rendered text afterwards.
            task_input.value = J2().render_from_string(task_input.value, **self.full_context)

            return task_input
        elif isinstance(task_input, Callable):
            return self._process_task_input(task_input(self))
        elif isinstance(task_input, ListArtifact):
            return ListArtifact([self._process_task_input(elem) for elem in task_input.value])
        elif isinstance(task_input, BaseArtifact):
            return task_input
        elif isinstance(task_input, (list, tuple)):
            return ListArtifact([self._process_task_input(elem) for elem in task_input])
        else:
            return self._process_task_input(TextArtifact(task_input))

generate_system_template: Callable[[PromptTask], str] = field(default=Factory(lambda self: self.default_generate_system_template, takes_self=True), kw_only=True) class-attribute instance-attribute

input: BaseArtifact property writable

output: Optional[BaseArtifact] = field(default=None, init=False) class-attribute instance-attribute

prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True) class-attribute instance-attribute

prompt_stack: PromptStack property

rulesets: list property

after_run()

Source code in griptape/tasks/prompt_task.py
def after_run(self) -> None:
    """Run the parent ``after_run`` hook, then log the task's class name, id and output text.

    An empty string is logged when ``output`` is ``None`` instead of raising
    ``AttributeError``.
    """
    super().after_run()

    output_text = self.output.to_text() if self.output is not None else ""
    logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, output_text)

before_run()

Source code in griptape/tasks/prompt_task.py
def before_run(self) -> None:
    """Run the parent ``before_run`` hook, then log the task's class name, id and input text."""
    super().before_run()

    task_name = self.__class__.__name__
    task_id = self.id
    input_text = self.input.to_text()
    logger.info("%s %s\nInput: %s", task_name, task_id, input_text)

default_generate_system_template(_)

Source code in griptape/tasks/prompt_task.py
def default_generate_system_template(self, _: PromptTask) -> str:
    """Render the default system template with this task's rendered rulesets.

    Args:
        _: Unused positional argument.

    Returns:
        The rendered ``tasks/prompt_task/system.j2`` template.
    """
    system_template = J2("tasks/prompt_task/system.j2")
    rendered_rulesets = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)
    return system_template.render(rulesets=rendered_rulesets)

try_run()

Source code in griptape/tasks/prompt_task.py
def try_run(self) -> BaseArtifact:
    """Run the prompt driver on the current prompt stack and return its message as an artifact."""
    return self.prompt_driver.run(self.prompt_stack).to_artifact()