Skip to content

Image loader

ImageLoader

Bases: BaseLoader

Loads images into image artifacts.

Attributes:

Name Type Description
format Optional[str]

If provided, attempts to ensure image artifacts are in this format when loaded. For example, when set to 'PNG', loading image.jpg will return an ImageArtifact containing the image bytes in PNG format.

Source code in griptape/loaders/image_loader.py
@define
class ImageLoader(BaseLoader):
    """Loader that turns raw image bytes into ``ImageArtifact`` objects.

    Attributes:
        format: Optional target image format. When set, every loaded image is re-encoded into this
            format before the artifact is built (e.g. with 'PNG', the bytes of a JPEG source end up
            stored as PNG bytes). When left as None, the source bytes are kept untouched.
    """

    format: Optional[str] = field(default=None, kw_only=True)

    # Lower-case image format name -> MIME type, consulted by `_get_mime_type`.
    FORMAT_TO_MIME_TYPE = {
        "bmp": "image/bmp",
        "gif": "image/gif",
        "jpeg": "image/jpeg",
        "png": "image/png",
        "tiff": "image/tiff",
        "webp": "image/webp",
    }

    def load(self, source: bytes, *args, **kwargs) -> ImageArtifact:
        """Build an ``ImageArtifact`` from encoded image bytes.

        Args:
            source: The encoded image data.
            *args: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Returns:
            An ``ImageArtifact`` holding the image bytes (re-encoded when ``self.format`` is set),
            the lower-cased format name reported by PIL, and the image width and height.
        """
        # PIL is looked up through the optional-dependency helper rather than a top-level import.
        pil_image = import_optional_dependency("PIL.Image")
        decoded = pil_image.open(BytesIO(source))
        payload = source

        target_format = self.format
        if target_format is not None:
            # Re-encode into the requested format, then reopen the result so the reported
            # format/size describe the converted image instead of the original one.
            buffer = BytesIO()
            decoded.save(buffer, format=target_format)
            decoded = pil_image.open(buffer)
            payload = buffer.getvalue()

        return ImageArtifact(
            payload,
            format=decoded.format.lower(),
            width=decoded.width,
            height=decoded.height,
        )

    def _get_mime_type(self, image_format: str | None) -> str:
        """Translate an image format name into its MIME type.

        Args:
            image_format: Format name; matched case-insensitively against ``FORMAT_TO_MIME_TYPE``.

        Returns:
            The MIME type string registered for the format.

        Raises:
            ValueError: If ``image_format`` is None or is not a key of ``FORMAT_TO_MIME_TYPE``.
        """
        if image_format is None:
            raise ValueError("image_format is None")

        normalized = image_format.lower()
        if normalized in self.FORMAT_TO_MIME_TYPE:
            return self.FORMAT_TO_MIME_TYPE[normalized]

        # The message reports the format exactly as the caller supplied it.
        raise ValueError(f"Unsupported image format {image_format}")

    def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ImageArtifact]:
        """Load several images by delegating to the parent class's ``load_collection``.

        Args:
            sources: Encoded image data, one ``bytes`` object per image.
            *args: Forwarded unchanged to the parent implementation.
            **kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            The parent's result, narrowed for type checkers to ``dict[str, ImageArtifact]``.
        """
        artifact_mapping = dict[str, ImageArtifact]
        loaded = super().load_collection(sources, *args, **kwargs)
        # `cast` is a typing-only narrowing; the object is returned as-is.
        return cast(artifact_mapping, loaded)

FORMAT_TO_MIME_TYPE = {'bmp': 'image/bmp', 'gif': 'image/gif', 'jpeg': 'image/jpeg', 'png': 'image/png', 'tiff': 'image/tiff', 'webp': 'image/webp'} class-attribute instance-attribute

format: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

load(source, *args, **kwargs)

Source code in griptape/loaders/image_loader.py
def load(self, source: bytes, *args, **kwargs) -> ImageArtifact:
    """Build an ``ImageArtifact`` from encoded image bytes.

    Args:
        source: The encoded image data.
        *args: Accepted but not used by this method.
        **kwargs: Accepted but not used by this method.

    Returns:
        An ``ImageArtifact`` holding the image bytes (re-encoded when ``self.format`` is set),
        the lower-cased format name reported by PIL, and the image width and height.
    """
    # PIL is looked up through the optional-dependency helper rather than a top-level import.
    pil_image = import_optional_dependency("PIL.Image")
    decoded = pil_image.open(BytesIO(source))
    payload = source

    target_format = self.format
    if target_format is not None:
        # Re-encode into the requested format, then reopen the result so the reported
        # format/size describe the converted image instead of the original one.
        buffer = BytesIO()
        decoded.save(buffer, format=target_format)
        decoded = pil_image.open(buffer)
        payload = buffer.getvalue()

    return ImageArtifact(
        payload,
        format=decoded.format.lower(),
        width=decoded.width,
        height=decoded.height,
    )

load_collection(sources, *args, **kwargs)

Source code in griptape/loaders/image_loader.py
def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ImageArtifact]:
    """Load several images by delegating to the parent class's ``load_collection``.

    Args:
        sources: Encoded image data, one ``bytes`` object per image.
        *args: Forwarded unchanged to the parent implementation.
        **kwargs: Forwarded unchanged to the parent implementation.

    Returns:
        The parent's result, narrowed for type checkers to ``dict[str, ImageArtifact]``.
    """
    artifact_mapping = dict[str, ImageArtifact]
    loaded = super().load_collection(sources, *args, **kwargs)
    # `cast` is a typing-only narrowing; the object is returned as-is.
    return cast(artifact_mapping, loaded)