Skip to content

griptape_cloud

__all__ = ['GriptapeCloudImageGenerationDriver'] module-attribute

GriptapeCloudImageGenerationDriver

Bases: BaseImageGenerationDriver

Source code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
@define
class GriptapeCloudImageGenerationDriver(BaseImageGenerationDriver):
    """Image generation driver that delegates text-to-image requests to Griptape Cloud.

    Only text-to-image is supported: the variation, inpainting and outpainting
    entry points raise ``NotImplementedError``.

    Construction reads ``GT_CLOUD_API_KEY`` from the environment when no
    ``api_key`` is passed, so building an instance without either raises ``KeyError``.
    """

    # Optional model name, forwarded verbatim in ``driver_configuration``.
    model: Optional[str] = field(default=None, kw_only=True)
    # Service root; taken from GT_CLOUD_BASE_URL, else the public Griptape Cloud URL.
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    # Bearer token; falls back to the GT_CLOUD_API_KEY environment variable (KeyError if unset).
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    # Request headers; by default derived from ``api_key`` once, at construction time.
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
    )
    # The three options below are marked serializable and sent in ``driver_configuration``.
    style: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    quality: Literal["standard", "hd"] = field(default="standard", kw_only=True, metadata={"serializable": True})
    image_size: Literal["1024x1024", "1024x1792", "1792x1024"] = field(
        default="1024x1024", kw_only=True, metadata={"serializable": True}
    )

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """POST the prompts to ``/api/images/generations`` and return the generated image.

        Args:
            prompts: Text prompts sent as the ``prompts`` field of the request body.
            negative_prompts: Accepted for interface compatibility; not included in the request.

        Returns:
            The ``ImageArtifact`` rebuilt from the ``artifact`` key of the JSON response.

        Raises:
            requests.HTTPError: If the service answers with an error status.
        """
        # The endpoint path is absolute, so urljoin keeps only the scheme and host of
        # base_url; any path segment configured in base_url is replaced.
        endpoint = urljoin(self.base_url.strip("/"), "/api/images/generations")
        driver_configuration = {
            "model": self.model,
            "image_size": self.image_size,
            "quality": self.quality,
            "style": self.style,
        }
        request_body = {"prompts": prompts, "driver_configuration": driver_configuration}

        http_response = requests.post(endpoint, headers=self.headers, json=request_body)
        http_response.raise_for_status()
        payload = http_response.json()

        return ImageArtifact.from_dict(payload["artifact"])

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Unsupported by this driver; always raises ``NotImplementedError``."""
        driver_name = self.__class__.__name__
        raise NotImplementedError(f"{driver_name} does not support image variation")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Unsupported by this driver; always raises ``NotImplementedError``."""
        driver_name = self.__class__.__name__
        raise NotImplementedError(f"{driver_name} does not support inpainting")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Unsupported by this driver; always raises ``NotImplementedError``."""
        driver_name = self.__class__.__name__
        raise NotImplementedError(f"{driver_name} does not support outpainting")

api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY'])) class-attribute instance-attribute

base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai'))) class-attribute instance-attribute

headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

image_size = field(default='1024x1024', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model = field(default=None, kw_only=True) class-attribute instance-attribute

quality = field(default='standard', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

style = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject inpainting requests; this driver has no inpainting support.

    Raises:
        NotImplementedError: Always, naming the concrete driver class.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support inpainting")

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject outpainting requests; this driver has no outpainting support.

    Raises:
        NotImplementedError: Always, naming the concrete driver class.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support outpainting")

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject image-variation requests; this driver has no variation support.

    Raises:
        NotImplementedError: Always, naming the concrete driver class.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support image variation")

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """POST the prompts to ``/api/images/generations`` and return the generated image.

    Args:
        prompts: Text prompts sent as the ``prompts`` field of the request body.
        negative_prompts: Accepted for interface compatibility; not included in the request.

    Returns:
        The ``ImageArtifact`` rebuilt from the ``artifact`` key of the JSON response.

    Raises:
        requests.HTTPError: If the service answers with an error status.
    """
    # The endpoint path is absolute, so urljoin keeps only the scheme and host of
    # base_url; any path segment configured in base_url is replaced.
    endpoint = urljoin(self.base_url.strip("/"), "/api/images/generations")
    driver_configuration = {
        "model": self.model,
        "image_size": self.image_size,
        "quality": self.quality,
        "style": self.style,
    }
    request_body = {"prompts": prompts, "driver_configuration": driver_configuration}

    http_response = requests.post(endpoint, headers=self.headers, json=request_body)
    http_response.raise_for_status()
    payload = http_response.json()

    return ImageArtifact.from_dict(payload["artifact"])