Skip to content

Amazon Bedrock image query driver

AmazonBedrockImageQueryDriver

Bases: BaseMultiModelImageQueryDriver

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
@define
class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver):
    """Image query driver that sends requests through the Amazon Bedrock runtime client.

    Attributes:
        session: boto3 session used to build the default client. Defaults to a new
            `boto3.Session()` created via `import_optional_dependency`.
        bedrock_client: Client used to invoke the model. Defaults to
            `session.client("bedrock-runtime")`.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        """Query the configured model about the given images.

        The request payload is built by `image_query_model_driver`, sent as JSON via
        `bedrock_client.invoke_model`, and the JSON response body is handed back to
        `image_query_model_driver.process_output`.

        Args:
            query: The question or instruction to send with the images.
            images: The images the query refers to.

        Returns:
            Whatever `image_query_model_driver.process_output` returns for the response body.

        Raises:
            ValueError: If the decoded response body is `None`, or if processing the
                response body fails (the original exception is attached as `__cause__`).
        """
        payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens)

        response = self.bedrock_client.invoke_model(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = json.loads(response.get("body").read())

        # A body containing the JSON literal `null` decodes to None.
        if response_body is None:
            raise ValueError("Model response is empty")

        try:
            return self.image_query_model_driver.process_output(response_body)
        except Exception as e:
            # Chain explicitly so the underlying failure is preserved as the cause.
            raise ValueError(f"Output is unable to be processed as returned {e}") from e

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    """Query the configured model about the given images.

    The request payload is built by `image_query_model_driver`, sent as JSON via
    `bedrock_client.invoke_model`, and the JSON response body is handed back to
    `image_query_model_driver.process_output`.

    Args:
        query: The question or instruction to send with the images.
        images: The images the query refers to.

    Returns:
        Whatever `image_query_model_driver.process_output` returns for the response body.

    Raises:
        ValueError: If the decoded response body is `None`, or if processing the
            response body fails (the original exception is attached as `__cause__`).
    """
    payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens)

    response = self.bedrock_client.invoke_model(
        modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
    )

    response_body = json.loads(response.get("body").read())

    # A body containing the JSON literal `null` decodes to None.
    if response_body is None:
        raise ValueError("Model response is empty")

    try:
        return self.image_query_model_driver.process_output(response_body)
    except Exception as e:
        # Chain explicitly so the underlying failure is preserved as the cause.
        raise ValueError(f"Output is unable to be processed as returned {e}") from e