Skip to content

output_schema_validation_subtask

logger = logging.getLogger(Defaults.logging_config.logger_name) module-attribute

OutputSchemaValidationSubtask

Bases: BaseSubtask

Source code in griptape/tasks/output_schema_validation_subtask.py
@define
class OutputSchemaValidationSubtask(BaseSubtask):
    """Subtask that validates a parent task's output against an output schema.

    Supports both `schema.Schema` schemas and pydantic `BaseModel` classes.
    Validation failures are captured in `_validation_errors` instead of being
    raised, so the parent task can feed the error text back to the LLM for a
    retry via the assistant/user subtask templates.
    """

    # Artifact to validate (attrs alias lets callers pass `input=`).
    _input: BaseArtifact = field(alias="input")
    # Schema the input must conform to: a `schema.Schema` or a pydantic model class.
    output_schema: Union[Schema, type[BaseModel]] = field(kw_only=True)
    # How structured output was requested; determines how `input` is parsed in `attach_to`.
    structured_output_strategy: StructuredOutputStrategy = field(
        default="rule", kw_only=True, metadata={"serializable": True}
    )
    # Callable rendering the assistant-side retry message for this subtask.
    generate_assistant_subtask_template: Callable[[OutputSchemaValidationSubtask], str] = field(
        default=Factory(lambda self: self.default_generate_assistant_subtask_template, takes_self=True),
        kw_only=True,
    )
    # Callable rendering the user-side retry message for this subtask.
    generate_user_subtask_template: Callable[[OutputSchemaValidationSubtask], str] = field(
        default=Factory(lambda self: self.default_generate_user_subtask_template, takes_self=True),
        kw_only=True,
    )
    # Populated by `attach_to` when validation fails; `None` means success.
    _validation_errors: str | None = field(default=None, init=False)

    @property
    def input(self) -> BaseArtifact:
        return self._input

    @input.setter
    def input(self, value: BaseArtifact) -> None:
        self._input = value

    @property
    def validation_errors(self) -> str | None:
        return self._validation_errors

    def attach_to(self, parent_task: BaseTask) -> None:
        """Attach to `parent_task` and validate its output against the schema.

        On success, `self.output` is set to the validated artifact; on failure,
        the error text is stored in `self._validation_errors` and `self.output`
        is left unset.
        """
        super().attach_to(parent_task)
        try:
            # With `native` or `rule` strategies, the output will be a json string that can be parsed.
            # With the `tool` strategy, the output will already be a `JsonArtifact`.
            if self.structured_output_strategy in ("native", "rule"):
                if isinstance(self.output_schema, Schema):
                    self.output_schema.validate(json.loads(self.input.value))
                    self.output = JsonArtifact(self.input.value)
                else:
                    model = TypeAdapter(self.output_schema).validate_json(self.input.value)
                    self.output = ModelArtifact(model)
            else:
                self.output = self.input
        except json.JSONDecodeError as e:
            # Malformed JSON from the LLM must surface as a validation failure
            # (so it can be fed back for a retry), not as an unhandled exception.
            self._validation_errors = str(e)
        except SchemaError as e:
            self._validation_errors = str(e)
        except ValidationError as e:
            self._validation_errors = str(e.errors())

    def before_run(self) -> None:
        """Log the value that is about to be validated."""
        logger.info("%s Validating: %s", self.__class__.__name__, self.input.value)

    def try_run(self) -> BaseArtifact:
        """Return the validated input, or an `ErrorArtifact` describing the failure."""
        if self._validation_errors is None:
            return self._input
        return ErrorArtifact(
            value=f"Validation error: {self._validation_errors}",
        )

    def after_run(self) -> None:
        """Log the validation outcome."""
        if self._validation_errors is None:
            logger.info("%s Validation successful", self.__class__.__name__)
        else:
            logger.error("%s Validation error: %s", self.__class__.__name__, self._validation_errors)

    def add_to_prompt_stack(self, stack: PromptStack) -> None:
        """Append this subtask's assistant/user retry exchange to `stack`.

        No-op when validation never produced an output.
        """
        if self.output is None:
            return
        stack.add_assistant_message(self.generate_assistant_subtask_template(self))
        stack.add_user_message(self.generate_user_subtask_template(self))

    def default_generate_assistant_subtask_template(self, subtask: OutputSchemaValidationSubtask) -> str:
        """Render the default assistant-side template for `subtask`."""
        return J2("tasks/prompt_task/assistant_output_schema_validation_subtask.j2").render(
            subtask=subtask,
        )

    def default_generate_user_subtask_template(self, subtask: OutputSchemaValidationSubtask) -> str:
        """Render the default user-side template for `subtask`."""
        return J2("tasks/prompt_task/user_output_schema_validation_subtask.j2").render(
            subtask=subtask,
        )

_input = field(alias='input') class-attribute instance-attribute

_validation_errors = field(default=None, init=False) class-attribute instance-attribute

generate_assistant_subtask_template = field(default=Factory(lambda self: self.default_generate_assistant_subtask_template, takes_self=True), kw_only=True) class-attribute instance-attribute

generate_user_subtask_template = field(default=Factory(lambda self: self.default_generate_user_subtask_template, takes_self=True), kw_only=True) class-attribute instance-attribute

input property writable

output_schema = field(kw_only=True) class-attribute instance-attribute

structured_output_strategy = field(default='rule', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

validation_errors property

add_to_prompt_stack(stack)

Source code in griptape/tasks/output_schema_validation_subtask.py
def add_to_prompt_stack(self, stack: PromptStack) -> None:
    """Append this subtask's assistant/user exchange to `stack`.

    Skipped entirely when validation never produced an output.
    """
    if self.output is not None:
        stack.add_assistant_message(self.generate_assistant_subtask_template(self))
        stack.add_user_message(self.generate_user_subtask_template(self))

after_run()

Source code in griptape/tasks/output_schema_validation_subtask.py
def after_run(self) -> None:
    """Log the outcome of the validation pass."""
    if self._validation_errors is not None:
        logger.error("%s Validation error: %s", type(self).__name__, self._validation_errors)
    else:
        logger.info("%s Validation successful", type(self).__name__)

attach_to(parent_task)

Source code in griptape/tasks/output_schema_validation_subtask.py
def attach_to(self, parent_task: BaseTask) -> None:
    """Attach to `parent_task` and validate its output against the schema.

    On success, `self.output` is set to the validated artifact; on failure,
    the error text is stored in `self._validation_errors`.
    """
    super().attach_to(parent_task)
    try:
        # With `native` or `rule` strategies, the output will be a json string that can be parsed.
        # With the `tool` strategy, the output will already be a `JsonArtifact`.
        if self.structured_output_strategy in ("native", "rule"):
            if isinstance(self.output_schema, Schema):
                self.output_schema.validate(json.loads(self.input.value))
                self.output = JsonArtifact(self.input.value)
            else:
                model = TypeAdapter(self.output_schema).validate_json(self.input.value)
                self.output = ModelArtifact(model)
        else:
            self.output = self.input
    except json.JSONDecodeError as e:
        # Malformed JSON from the LLM must surface as a validation failure
        # (so it can be fed back for a retry), not as an unhandled exception.
        self._validation_errors = str(e)
    except SchemaError as e:
        self._validation_errors = str(e)
    except ValidationError as e:
        self._validation_errors = str(e.errors())

before_run()

Source code in griptape/tasks/output_schema_validation_subtask.py
def before_run(self) -> None:
    """Log the value that is about to be validated."""
    logger.info("%s Validating: %s", type(self).__name__, self.input.value)

default_generate_assistant_subtask_template(subtask)

Source code in griptape/tasks/output_schema_validation_subtask.py
def default_generate_assistant_subtask_template(self, subtask: OutputSchemaValidationSubtask) -> str:
    """Render the default assistant-side template for `subtask`."""
    template = J2("tasks/prompt_task/assistant_output_schema_validation_subtask.j2")
    return template.render(subtask=subtask)

default_generate_user_subtask_template(subtask)

Source code in griptape/tasks/output_schema_validation_subtask.py
def default_generate_user_subtask_template(self, subtask: OutputSchemaValidationSubtask) -> str:
    """Render the default user-side template for `subtask`."""
    template = J2("tasks/prompt_task/user_output_schema_validation_subtask.j2")
    return template.render(subtask=subtask)

try_run()

Source code in griptape/tasks/output_schema_validation_subtask.py
def try_run(self) -> BaseArtifact:
    """Return the validated input, or an `ErrorArtifact` describing the failure."""
    if self._validation_errors is not None:
        return ErrorArtifact(value=f"Validation error: {self._validation_errors}")
    return self._input