Bases: CodeExecutionTask[Union[InfoArtifact, ListArtifact[InfoArtifact]]]
Source code in griptape/tasks/branch_task.py
@define
class BranchTask(CodeExecutionTask[Union[InfoArtifact, ListArtifact[InfoArtifact]]]):
    """Task that runs a user callback to pick which of its children continue.

    The callback's result names one or more child task ids. Every returned id
    must belong to this task's children; when the task is attached to a
    structure, the children that were not named are marked as skipped.

    Attributes:
        on_run: Callable invoked with this task. It returns either a single
            InfoArtifact or a ListArtifact of InfoArtifacts whose values are
            the ids of the child tasks to keep.
    """

    on_run: Callable[[BranchTask], Union[InfoArtifact, ListArtifact[InfoArtifact]]] = field(kw_only=True)

    def try_run(self) -> InfoArtifact | ListArtifact[InfoArtifact]:
        """Invoke `on_run`, validate the chosen child ids, and skip the rest.

        Returns:
            The artifact produced by `on_run`, unchanged.

        Raises:
            ValueError: If any returned id is not one of this task's child ids.
        """
        outcome = self.on_run(self)

        # Normalize both return shapes (single artifact / list) into a set of ids.
        chosen_ids = (
            {item.value for item in outcome} if isinstance(outcome, ListArtifact) else {outcome.value}
        )

        if any(chosen_id not in self.child_ids for chosen_id in chosen_ids):
            raise ValueError(f"Branch task returned invalid child task id {chosen_ids}")

        # Children are only marked as skipped when the task belongs to a structure.
        if self.structure is None:
            return outcome

        for child in self.children:
            if child.id in chosen_ids:
                continue
            child.state = BaseTask.State.SKIPPED

        return outcome
|
on_run = field(kw_only=True)
class-attribute
instance-attribute
try_run()
Source code in griptape/tasks/branch_task.py
def try_run(self) -> InfoArtifact | ListArtifact[InfoArtifact]:
    """Invoke `on_run`, validate the chosen child ids, and skip the rest.

    Returns:
        The artifact produced by `on_run`, unchanged.

    Raises:
        ValueError: If any returned id is not one of this task's child ids.
    """
    outcome = self.on_run(self)

    # Normalize both return shapes (single artifact / list) into a set of ids.
    chosen_ids = (
        {item.value for item in outcome} if isinstance(outcome, ListArtifact) else {outcome.value}
    )

    if any(chosen_id not in self.child_ids for chosen_id in chosen_ids):
        raise ValueError(f"Branch task returned invalid child task id {chosen_ids}")

    # Children are only marked as skipped when the task belongs to a structure.
    if self.structure is None:
        return outcome

    for child in self.children:
        if child.id in chosen_ids:
            continue
        child.state = BaseTask.State.SKIPPED

    return outcome
|