Skip to content

huggingface_pipeline_image_generation_driver

HuggingFacePipelineImageGenerationDriver

Bases: BaseImageGenerationDriver, ABC

Image generation driver for models hosted by Hugging Face's Diffusion Pipeline.

For more information, see the Hugging Face documentation for Diffusers: https://huggingface.co/docs/diffusers/en/index

Attributes:

Name Type Description
pipeline_driver BaseDiffusionImageGenerationPipelineDriver

A pipeline image generation model driver typed for the specific pipeline required by the model.

device Optional[str]

The hardware device used for inference. For example, "cpu", "cuda", or "mps".

output_format str

The format the generated image is returned in. Defaults to "png".

Source code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
@define
class HuggingFacePipelineImageGenerationDriver(BaseImageGenerationDriver, ABC):
    """Image generation driver for models hosted by Hugging Face's Diffusion Pipeline.

    For more information, see the Hugging Face documentation for Diffusers:
            https://huggingface.co/docs/diffusers/en/index

    Attributes:
        pipeline_driver: A pipeline image generation model driver typed for the specific pipeline required by the model.
        device: The hardware device used for inference. For example, "cpu", "cuda", or "mps".
        output_format: The format the generated image is returned in. Defaults to "png".
    """

    pipeline_driver: BaseDiffusionImageGenerationPipelineDriver = field(kw_only=True, metadata={"serializable": True})
    device: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    output_format: str = field(default="png", kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate an image from text prompts.

        Args:
            prompts: Prompt fragments; they are joined with ", " into a single prompt string.
            negative_prompts: Forwarded to the pipeline driver's ``make_additional_params``.

        Returns:
            An ImageArtifact holding the first image of the pipeline output, encoded in ``output_format``.
        """
        pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)

        prompt = ", ".join(prompts)
        output_image = pipeline(
            prompt, **self.pipeline_driver.make_additional_params(negative_prompts, self.device)
        ).images[0]

        return self._to_image_artifact(output_image, prompt)

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        """Generate a variation of an existing image, guided by text prompts.

        Args:
            prompts: Prompt fragments; they are joined with ", " into a single prompt string.
            image: Artifact whose ``value`` bytes are opened as the input image.
            negative_prompts: Forwarded to the pipeline driver's ``make_additional_params``.

        Returns:
            An ImageArtifact holding the first image of the pipeline output, encoded in ``output_format``.
        """
        pil_image = import_optional_dependency("PIL.Image")

        pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)

        prompt = ", ".join(prompts)
        input_image = pil_image.open(io.BytesIO(image.value))
        # The size of the input image drives the size of the output image.
        # Resize the input image to the configured dimensions.
        output_width, output_height = self.pipeline_driver.output_image_dimensions
        if input_image.height != output_height or input_image.width != output_width:
            input_image = input_image.resize((output_width, output_height))

        output_image = pipeline(
            prompt,
            **self.pipeline_driver.make_image_param(input_image),
            **self.pipeline_driver.make_additional_params(negative_prompts, self.device),
        ).images[0]

        return self._to_image_artifact(output_image, prompt)

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Inpainting entry point; not implemented by this driver.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Inpainting is not supported by this driver.")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Outpainting entry point; not implemented by this driver.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Outpainting is not supported by this driver.")

    def _to_image_artifact(self, output_image, prompt: str) -> ImageArtifact:
        """Encode a generated image in ``output_format`` and wrap it in an ImageArtifact.

        Shared by the text-to-image and image-variation paths so both build their result the same way.

        Args:
            output_image: Image object returned by the pipeline; must provide ``save``, ``height`` and ``width``.
            prompt: The joined prompt string, recorded under the ``"prompt"`` key of the artifact's ``meta``.

        Returns:
            An ImageArtifact carrying the encoded bytes, the lower-cased format name and the image dimensions.
        """
        buffer = io.BytesIO()
        # The format name is upper-cased for the image's save call and lower-cased on the artifact.
        output_image.save(buffer, format=self.output_format.upper())

        return ImageArtifact(
            value=buffer.getvalue(),
            format=self.output_format.lower(),
            height=output_image.height,
            width=output_image.width,
            meta={"prompt": prompt},
        )

device: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

output_format: str = field(default='png', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

pipeline_driver: BaseDiffusionImageGenerationPipelineDriver = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject inpainting requests; this driver does not implement them.

    Raises:
        NotImplementedError: Always, regardless of the arguments given.
    """
    unsupported_message = "Inpainting is not supported by this driver."
    raise NotImplementedError(unsupported_message)

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject outpainting requests; this driver does not implement them.

    Raises:
        NotImplementedError: Always, regardless of the arguments given.
    """
    unsupported_message = "Outpainting is not supported by this driver."
    raise NotImplementedError(unsupported_message)

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    """Generate a variation of an existing image, guided by text prompts.

    Args:
        prompts: Prompt fragments; they are joined with ", " into a single prompt string.
        image: Artifact whose ``value`` bytes are opened as the input image.
        negative_prompts: Forwarded to the pipeline driver's ``make_additional_params``.

    Returns:
        An ImageArtifact holding the first image of the pipeline output, encoded in ``output_format``.
    """
    pil_image = import_optional_dependency("PIL.Image")

    driver = self.pipeline_driver
    pipeline = driver.prepare_pipeline(self.model, self.device)

    prompt = ", ".join(prompts)
    source_image = pil_image.open(io.BytesIO(image.value))

    # The pipeline sizes its output after its input, so bring the input to the
    # driver's configured dimensions whenever they differ.
    target_width, target_height = driver.output_image_dimensions
    if (source_image.width, source_image.height) != (target_width, target_height):
        source_image = source_image.resize((target_width, target_height))

    image_params = driver.make_image_param(source_image)
    extra_params = driver.make_additional_params(negative_prompts, self.device)
    result = pipeline(prompt, **image_params, **extra_params)
    generated = result.images[0]  # only the first generated image is returned

    encoded = io.BytesIO()
    generated.save(encoded, format=self.output_format.upper())

    return ImageArtifact(
        value=encoded.getvalue(),
        format=self.output_format.lower(),
        height=generated.height,
        width=generated.width,
        meta={"prompt": prompt},
    )

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate an image from text prompts.

    Args:
        prompts: Prompt fragments; they are joined with ", " into a single prompt string.
        negative_prompts: Forwarded to the pipeline driver's ``make_additional_params``.

    Returns:
        An ImageArtifact holding the first image of the pipeline output, encoded in ``output_format``.
    """
    driver = self.pipeline_driver
    pipeline = driver.prepare_pipeline(self.model, self.device)

    prompt = ", ".join(prompts)
    extra_params = driver.make_additional_params(negative_prompts, self.device)
    result = pipeline(prompt, **extra_params)
    generated = result.images[0]  # only the first generated image is returned

    encoded = io.BytesIO()
    generated.save(encoded, format=self.output_format.upper())

    return ImageArtifact(
        value=encoded.getvalue(),
        format=self.output_format.lower(),
        height=generated.height,
        width=generated.width,
        meta={"prompt": prompt},
    )