Skip to content

Amazon bedrock image generation driver

AmazonBedrockImageGenerationDriver

Bases: BaseMultiModelImageGenerationDriver

Driver for image generation models provided by Amazon Bedrock.

Attributes:

Name Type Description
model

Bedrock model ID.

session Session

boto3 session.

bedrock_client Any

Bedrock runtime client.

image_width int

Width of output images. Defaults to 512 and must be a multiple of 64.

image_height int

Height of output images. Defaults to 512 and must be a multiple of 64.

seed int | None

Optionally provide a consistent seed to generation requests, increasing consistency in output.

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
@define
class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
    """Driver for image generation models provided by Amazon Bedrock.

    Attributes:
        model: Bedrock model ID.
        session: boto3 session.
        bedrock_client: Bedrock runtime client.
        image_width: Width of output images. Defaults to 512 and must be a multiple of 64.
        image_height: Height of output images. Defaults to 512 and must be a multiple of 64.
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client(service_name="bedrock-runtime"), takes_self=True)
    )
    image_width: int = field(default=512, kw_only=True)
    image_height: int = field(default=512, kw_only=True)
    seed: int | None = field(default=None, kw_only=True)

    def try_text_to_image(self, prompts: list[str], negative_prompts: list[str] | None = None) -> ImageArtifact:
        request = self.image_generation_model_driver.text_to_image_request_parameters(
            prompts, self.image_width, self.image_height, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            mime_type="image/png",
            width=self.image_width,
            height=self.image_height,
            model=self.model,
        )

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        request = self.image_generation_model_driver.image_variation_request_parameters(
            prompts, image=image, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            mime_type="image/png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_inpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        request = self.image_generation_model_driver.image_inpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            mime_type="image/png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_outpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        request = self.image_generation_model_driver.image_outpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            mime_type="image/png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def _make_request(self, request: dict) -> bytes:
        response = self.bedrock_client.invoke_model(
            body=json.dumps(request), modelId=self.model, accept="application/json", contentType="application/json"
        )

        response_body = json.loads(response.get("body").read())

        try:
            image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
        except Exception as e:
            raise ValueError(f"Inpainting generation failed: {e}")

        return image_bytes

bedrock_client: Any = field(default=Factory(lambda : self.session.client(service_name='bedrock-runtime'), takes_self=True)) class-attribute instance-attribute

image_height: int = field(default=512, kw_only=True) class-attribute instance-attribute

image_width: int = field(default=512, kw_only=True) class-attribute instance-attribute

seed: int | None = field(default=None, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_inpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    request = self.image_generation_model_driver.image_inpainting_request_parameters(
        prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        mime_type="image/png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_outpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    request = self.image_generation_model_driver.image_outpainting_request_parameters(
        prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        mime_type="image/png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    request = self.image_generation_model_driver.image_variation_request_parameters(
        prompts, image=image, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        mime_type="image/png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: list[str] | None = None) -> ImageArtifact:
    request = self.image_generation_model_driver.text_to_image_request_parameters(
        prompts, self.image_width, self.image_height, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        mime_type="image/png",
        width=self.image_width,
        height=self.image_height,
        model=self.model,
    )