@define
class Agent(Structure):
    """A Structure that holds and runs exactly one task.

    When no ``tasks`` are supplied, ``__attrs_post_init__`` builds a single
    ``PromptTask`` from ``input``, ``prompt_driver``/``stream``, ``tools``,
    ``output_schema``, ``max_meta_memory_entries`` and ``max_subtasks``.
    """

    # Input for the auto-built PromptTask. The default callable reads the first entry of
    # ``task.full_context["args"]``, falling back to an empty TextArtifact when there are none.
    input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field(
        default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""),
    )
    # Streaming flag for the default prompt driver. Ignored when ``prompt_driver`` is given.
    # If left as None and the Agent builds its own task, it is filled in from the driver in use.
    stream: bool | None = field(default=None, kw_only=True)
    # Prompt driver for the auto-built task. None means: copy the global default driver.
    prompt_driver: BasePromptDriver | None = field(default=None, kw_only=True)
    output_schema: Schema | type[BaseModel] | None = field(default=None, kw_only=True)
    tools: list[BaseTool] = field(factory=list, kw_only=True)
    max_meta_memory_entries: int | None = field(default=20, kw_only=True)
    # Forwarded to the PromptTask only when not None, so the task's own default applies otherwise.
    max_subtasks: int | None = field(default=None, kw_only=True)
    # Must remain False; rejected by ``validate_fail_fast``.
    fail_fast: bool = field(default=False, kw_only=True)
    # Explicit tasks. When non-empty, the fields above that configure the auto-built task are unused.
    _tasks: list[BaseTask | list[BaseTask]] = field(
        factory=list, kw_only=True, alias="tasks", metadata={"serializable": True}
    )

    @fail_fast.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None:  # noqa: FBT001
        """Reject ``fail_fast=True``: with a single task there is nothing to fail fast across.

        Raises:
            ValueError: If ``fail_fast`` is truthy.
        """
        if fail_fast:
            raise ValueError("Agents cannot fail fast, as they can only have 1 task.")

    @prompt_driver.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_prompt_driver(self, _: Attribute, prompt_driver: BasePromptDriver | None) -> None:
        """Warn when both ``prompt_driver`` and ``stream`` are set, because ``stream`` is then ignored."""
        if prompt_driver is not None and self.stream is not None:
            warnings.warn(
                "`Agent.prompt_driver` is set, but `Agent.stream` was provided. `Agent.stream` will be ignored. This will be an error in the future.",
                UserWarning,
                stacklevel=2,
            )

    @_tasks.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_tasks(self, _: Attribute, tasks: list) -> None:
        """Warn when explicit ``tasks`` are given together with ``prompt_driver``, which is then ignored."""
        if tasks and self.prompt_driver is not None:
            warnings.warn(
                "`Agent.tasks` is set, but `Agent.prompt_driver` was provided. `Agent.prompt_driver` will be ignored. This will be an error in the future.",
                UserWarning,
                stacklevel=2,
            )

    def __attrs_post_init__(self) -> None:
        """Run Structure's post-init, then build the default task if none was supplied."""
        super().__attrs_post_init__()
        if not self.tasks:
            self._init_task()

    @property
    def task(self) -> BaseTask:
        """The Agent's single task (the first entry of ``tasks``)."""
        return self.tasks[0]

    def add_task(self, task: BaseTask) -> BaseTask:
        """Make ``task`` the Agent's only task, discarding any previous one.

        The task list is cleared before ``task.preprocess(self)`` runs and
        ``task`` is appended afterwards.

        Returns:
            The task that was added.
        """
        self._tasks.clear()
        task.preprocess(self)
        self._tasks.append(task)
        return task

    def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]:
        """Delegate to ``Structure.add_tasks`` after rejecting more than one argument.

        NOTE(review): a single list argument passes this check. Presumably
        ``Structure.add_tasks`` re-dispatches its elements; confirm that this
        still enforces the one-task limit.

        Raises:
            ValueError: If more than one positional argument is given.
        """
        if len(tasks) > 1:
            raise ValueError("Agents can only have one task.")
        return super().add_tasks(*tasks)

    @observable
    def try_run(self, *args) -> Agent:
        """Run the single task and return ``self``. ``args`` is not used by this method."""
        self.task.run()
        return self

    def _init_task(self) -> None:
        """Build the Agent's PromptTask from its configuration and register it via ``add_task``.

        When ``prompt_driver`` is None, a copy of the global default driver is
        created with ``stream`` applied. Otherwise the supplied driver is used
        as-is. In both cases a None ``stream`` is filled in from the driver that
        is actually used, so the two never disagree.
        """
        if self.prompt_driver is None:
            # Validators stay disabled while the default driver is read, copied and
            # assigned, so none of these steps runs attrs validation (including
            # ``validate_prompt_driver``).
            with validators.disabled():
                default_driver = Defaults.drivers_config.prompt_driver
                if self.stream is None:
                    self.stream = default_driver.stream
                prompt_driver = evolve(default_driver, stream=self.stream)
                self.prompt_driver = prompt_driver
        else:
            prompt_driver = self.prompt_driver
            if self.stream is None:
                # Mirror the caller's driver rather than the unrelated global default.
                with validators.disabled():
                    self.stream = prompt_driver.stream

        task_kwargs = {
            "prompt_driver": prompt_driver,
            "tools": self.tools,
            "output_schema": self.output_schema,
            "max_meta_memory_entries": self.max_meta_memory_entries,
        }
        if self.max_subtasks is not None:
            task_kwargs["max_subtasks"] = self.max_subtasks

        self.add_task(PromptTask(self.input, **task_kwargs))