Bases: ArtifactFileOutputMixin, RuleMixin, BaseTask, ABC
Provides a base class for image generation-related tasks.
Attributes:

| Name | Type | Description |
|------|------|-------------|
| `negative_rulesets` | `list[Ruleset]` | List of negatively-weighted rulesets applied to the text prompt, if supported by the driver. |
| `negative_rules` | `list[Rule]` | List of negatively-weighted rules applied to the text prompt, if supported by the driver. |
| `output_dir` | `Optional[str]` | If provided, the generated image will be written to disk in `output_dir`. |
| `output_file` | `Optional[str]` | If provided, the generated image will be written to disk as `output_file`. |
Source code in griptape/tasks/base_image_generation_task.py
@define
class BaseImageGenerationTask(ArtifactFileOutputMixin, RuleMixin, BaseTask, ABC):
    """Provides a base class for image generation-related tasks.

    Attributes:
        image_generation_driver: Driver used to generate images. Defaults to
            `Defaults.drivers_config.image_generation_driver`.
        negative_rulesets: List of negatively-weighted rulesets applied to the text prompt, if supported by the driver.
        negative_rules: List of negatively-weighted rules applied to the text prompt, if supported by the driver.
        output_dir: If provided, the generated image will be written to disk in output_dir.
        output_file: If provided, the generated image will be written to disk as output_file.
    """

    # Name given to the ruleset synthesized from `negative_rules`.
    DEFAULT_NEGATIVE_RULESET_NAME = "Negative Ruleset"

    image_generation_driver: BaseImageGenerationDriver = field(
        default=Factory(lambda: Defaults.drivers_config.image_generation_driver),
        kw_only=True,
    )
    # Stored privately; exposed through the `negative_rulesets` property, which also
    # folds in `negative_rules`. The constructor keyword stays `negative_rulesets`.
    _negative_rulesets: list[Ruleset] = field(factory=list, kw_only=True, alias="negative_rulesets")
    negative_rules: list[Rule] = field(factory=list, kw_only=True)

    @property
    def negative_rulesets(self) -> list[Ruleset]:
        """Return the configured negative rulesets, plus one built from `negative_rules`.

        If `negative_rules` is non-empty, a ruleset named `DEFAULT_NEGATIVE_RULESET_NAME`
        containing those rules is appended to the result.

        Returns:
            A new list on every call; the stored rulesets are never modified.
        """
        # Copy first: appending to `self._negative_rulesets` directly would add another
        # "Negative Ruleset" to the stored list on every property access, duplicating
        # negative prompts on each subsequent run.
        negative_rulesets = list(self._negative_rulesets)

        if self.negative_rules:
            negative_rulesets.append(Ruleset(name=self.DEFAULT_NEGATIVE_RULESET_NAME, rules=self.negative_rules))

        return negative_rulesets

    def _read_from_file(self, path: str) -> ImageArtifact:
        """Load an image artifact from `path` using `ImageLoader`, logging the absolute path."""
        logger.info("Reading image from %s", os.path.abspath(path))
        return ImageLoader().load(Path(path))

    def _get_prompts(self, prompt: str) -> list[str]:
        """Return `prompt` followed by the value of every rule in `self.rulesets`.

        `rulesets` is not defined in this class; presumably it is provided by `RuleMixin`.
        """
        return [prompt, *[rule.value for ruleset in self.rulesets for rule in ruleset.rules]]

    def _get_negative_prompts(self) -> list[str]:
        """Return the value of every rule in `negative_rulesets`, in order."""
        return [rule.value for ruleset in self.negative_rulesets for rule in ruleset.rules]
|
DEFAULT_NEGATIVE_RULESET_NAME = 'Negative Ruleset'
class-attribute
instance-attribute
image_generation_driver: BaseImageGenerationDriver = field(default=Factory(lambda: Defaults.drivers_config.image_generation_driver), kw_only=True)
class-attribute
instance-attribute
negative_rules: list[Rule] = field(factory=list, kw_only=True)
class-attribute
instance-attribute
negative_rulesets: list[Ruleset]
property