Skip to content

drivers

__all__ = ['BasePromptDriver', 'OpenAiChatPromptDriver', 'AzureOpenAiChatPromptDriver', 'CoherePromptDriver', 'HuggingFacePipelinePromptDriver', 'HuggingFaceHubPromptDriver', 'AnthropicPromptDriver', 'AmazonSageMakerJumpstartPromptDriver', 'AmazonBedrockPromptDriver', 'GooglePromptDriver', 'DummyPromptDriver', 'OllamaPromptDriver', 'BaseConversationMemoryDriver', 'LocalConversationMemoryDriver', 'AmazonDynamoDbConversationMemoryDriver', 'RedisConversationMemoryDriver', 'GriptapeCloudConversationMemoryDriver', 'BaseEmbeddingDriver', 'OpenAiEmbeddingDriver', 'AzureOpenAiEmbeddingDriver', 'AmazonSageMakerJumpstartEmbeddingDriver', 'AmazonBedrockTitanEmbeddingDriver', 'AmazonBedrockCohereEmbeddingDriver', 'VoyageAiEmbeddingDriver', 'HuggingFaceHubEmbeddingDriver', 'GoogleEmbeddingDriver', 'DummyEmbeddingDriver', 'CohereEmbeddingDriver', 'OllamaEmbeddingDriver', 'BaseVectorStoreDriver', 'LocalVectorStoreDriver', 'PineconeVectorStoreDriver', 'MarqoVectorStoreDriver', 'MongoDbAtlasVectorStoreDriver', 'AzureMongoDbVectorStoreDriver', 'RedisVectorStoreDriver', 'OpenSearchVectorStoreDriver', 'AmazonOpenSearchVectorStoreDriver', 'PgVectorVectorStoreDriver', 'QdrantVectorStoreDriver', 'AstraDbVectorStoreDriver', 'DummyVectorStoreDriver', 'GriptapeCloudKnowledgeBaseVectorStoreDriver', 'BaseSqlDriver', 'AmazonRedshiftSqlDriver', 'SnowflakeSqlDriver', 'SqlDriver', 'BaseImageGenerationModelDriver', 'BedrockStableDiffusionImageGenerationModelDriver', 'BedrockTitanImageGenerationModelDriver', 'BaseDiffusionImageGenerationPipelineDriver', 'StableDiffusion3ImageGenerationPipelineDriver', 'StableDiffusion3Img2ImgImageGenerationPipelineDriver', 'StableDiffusion3ControlNetImageGenerationPipelineDriver', 'BaseImageGenerationDriver', 'BaseMultiModelImageGenerationDriver', 'OpenAiImageGenerationDriver', 'LeonardoImageGenerationDriver', 'AmazonBedrockImageGenerationDriver', 'AzureOpenAiImageGenerationDriver', 'DummyImageGenerationDriver', 'HuggingFacePipelineImageGenerationDriver', 
'BaseImageQueryModelDriver', 'BedrockClaudeImageQueryModelDriver', 'BaseImageQueryDriver', 'OpenAiImageQueryDriver', 'AzureOpenAiImageQueryDriver', 'DummyImageQueryDriver', 'AnthropicImageQueryDriver', 'BaseMultiModelImageQueryDriver', 'AmazonBedrockImageQueryDriver', 'BaseWebScraperDriver', 'TrafilaturaWebScraperDriver', 'MarkdownifyWebScraperDriver', 'ProxyWebScraperDriver', 'BaseWebSearchDriver', 'GoogleWebSearchDriver', 'DuckDuckGoWebSearchDriver', 'BaseEventListenerDriver', 'AmazonSqsEventListenerDriver', 'WebhookEventListenerDriver', 'AwsIotCoreEventListenerDriver', 'GriptapeCloudEventListenerDriver', 'PusherEventListenerDriver', 'BaseFileManagerDriver', 'LocalFileManagerDriver', 'AmazonS3FileManagerDriver', 'BaseRerankDriver', 'CohereRerankDriver', 'BaseTextToSpeechDriver', 'DummyTextToSpeechDriver', 'ElevenLabsTextToSpeechDriver', 'OpenAiTextToSpeechDriver', 'AzureOpenAiTextToSpeechDriver', 'BaseStructureRunDriver', 'GriptapeCloudStructureRunDriver', 'LocalStructureRunDriver', 'BaseAudioTranscriptionDriver', 'DummyAudioTranscriptionDriver', 'OpenAiAudioTranscriptionDriver', 'BaseObservabilityDriver', 'NoOpObservabilityDriver', 'OpenTelemetryObservabilityDriver', 'GriptapeCloudObservabilityDriver', 'DatadogObservabilityDriver'] module-attribute

AmazonBedrockCohereEmbeddingDriver

Bases: BaseEmbeddingDriver

Amazon Bedrock Cohere Embedding Driver.

Attributes:

Name Type Description
model str

Embedding model name. Defaults to DEFAULT_MODEL.

input_type str

Defaults to search_query. Prepends special tokens that tell the model how the text will be used: use search_document when encoding documents whose embeddings you store in a vector database, and search_query when querying your vector database for relevant documents.

session Session

Optionally provide custom boto3.Session.

tokenizer BaseTokenizer

Optionally provide a custom tokenizer. Defaults to an AmazonBedrockTokenizer for the configured model.

bedrock_client Any

Optionally provide custom bedrock-runtime client.

Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@define
class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
    """Amazon Bedrock Cohere Embedding Driver.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        input_type: Defaults to `search_query`. Prepends special tokens to differentiate each type from one another:
            `search_document` when you encode documents for embeddings that you store in a vector database.
            `search_query` when querying your vector DB to find relevant documents.
        session: Optionally provide custom `boto3.Session`.
        tokenizer: Optionally provide a custom tokenizer. Defaults to an `AmazonBedrockTokenizer` for `model`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client. Defaults to one created from `session`.
    """

    DEFAULT_MODEL = "cohere.embed-english-v3"

    model: str = field(default=DEFAULT_MODEL, kw_only=True)
    input_type: str = field(default="search_query", kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
        kw_only=True,
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single chunk of text with the configured Cohere model on Bedrock.

        Args:
            chunk: Text to embed.

        Returns:
            The first (and only) vector from the response's `embeddings` list.

        Raises:
            ValueError: If the response body has no `embeddings` or the list is empty.
        """
        # The request carries a list of texts; a single-item list is sent here.
        payload = {"input_type": self.input_type, "texts": [chunk]}

        response = self.bedrock_client.invoke_model(
            body=json.dumps(payload),
            modelId=self.model,
            accept="*/*",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())

        # Fail with a descriptive error instead of a TypeError/IndexError on a malformed response.
        embeddings = response_body.get("embeddings")
        if not embeddings:
            raise ValueError("Model response did not contain any embeddings")

        return embeddings[0]

DEFAULT_MODEL = 'cohere.embed-english-v3' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

input_type: str = field(default='search_query', kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single chunk of text with the configured Cohere model on Bedrock.

    Args:
        chunk: Text to embed.

    Returns:
        The first (and only) vector from the response's `embeddings` list.

    Raises:
        ValueError: If the response body has no `embeddings` or the list is empty.
    """
    # The request carries a list of texts; a single-item list is sent here.
    payload = {"input_type": self.input_type, "texts": [chunk]}

    response = self.bedrock_client.invoke_model(
        body=json.dumps(payload),
        modelId=self.model,
        accept="*/*",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())

    # Fail with a descriptive error instead of a TypeError/IndexError on a malformed response.
    embeddings = response_body.get("embeddings")
    if not embeddings:
        raise ValueError("Model response did not contain any embeddings")

    return embeddings[0]

AmazonBedrockImageGenerationDriver

Bases: BaseMultiModelImageGenerationDriver

Driver for image generation models provided by Amazon Bedrock.

Attributes:

Name Type Description
model

Bedrock model ID.

session Session

boto3 session.

bedrock_client Any

Bedrock runtime client.

image_width int

Width of output images. Defaults to 512 and must be a multiple of 64.

image_height int

Height of output images. Defaults to 512 and must be a multiple of 64.

seed Optional[int]

Optional seed passed to every generation request; reusing the same seed makes outputs more reproducible.

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
@define
class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
    """Driver for image generation models provided by Amazon Bedrock.

    Request bodies are built and responses parsed by `image_generation_model_driver`;
    this driver only sends the request and wraps the resulting bytes in an `ImageArtifact`.

    Attributes:
        model: Bedrock model ID.
        session: boto3 session.
        bedrock_client: Bedrock runtime client.
        image_width: Width of output images. Defaults to 512 and must be a multiple of 64.
        image_height: Height of output images. Defaults to 512 and must be a multiple of 64.
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client(service_name="bedrock-runtime"), takes_self=True),
    )
    image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
    image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate an image of `image_width` x `image_height` from `prompts`."""
        request = self.image_generation_model_driver.text_to_image_request_parameters(
            prompts,
            self.image_width,
            self.image_height,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return self._to_image_artifact(image_bytes, prompts, self.image_width, self.image_height)

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Generate a variation of `image` guided by `prompts`; the result carries the input's dimensions."""
        request = self.image_generation_model_driver.image_variation_request_parameters(
            prompts,
            image=image,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return self._to_image_artifact(image_bytes, prompts, image.width, image.height)

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Inpaint `image` using `mask` and `prompts`; the result carries the input's dimensions."""
        request = self.image_generation_model_driver.image_inpainting_request_parameters(
            prompts,
            image=image,
            mask=mask,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return self._to_image_artifact(image_bytes, prompts, image.width, image.height)

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Outpaint `image` using `mask` and `prompts`; the result carries the input's dimensions."""
        request = self.image_generation_model_driver.image_outpainting_request_parameters(
            prompts,
            image=image,
            mask=mask,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return self._to_image_artifact(image_bytes, prompts, image.width, image.height)

    def _to_image_artifact(self, image_bytes: bytes, prompts: list[str], width: int, height: int) -> ImageArtifact:
        """Wrap generated bytes in an ImageArtifact labeled as PNG, tagged with the prompts and model."""
        return ImageArtifact(
            value=image_bytes,
            format="png",
            width=width,
            height=height,
            meta={"prompt": ", ".join(prompts), "model": self.model},
        )

    def _make_request(self, request: dict) -> bytes:
        """Invoke the Bedrock model with `request` and extract the generated image bytes.

        Raises:
            ValueError: If the model driver cannot extract an image from the response.
        """
        response = self.bedrock_client.invoke_model(
            body=json.dumps(request),
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )

        response_body = json.loads(response.get("body").read())

        try:
            image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
        except Exception as e:
            # Shared by every operation (text-to-image, variation, in/outpainting), so the
            # message must not name a specific one.
            raise ValueError(f"Image generation failed: {e}") from e

        return image_bytes

bedrock_client: Any = field(default=Factory(lambda self: self.session.client(service_name='bedrock-runtime'), takes_self=True)) class-attribute instance-attribute

image_height: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

image_width: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

seed: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Regenerate the masked region of `image` according to `prompts`.

    Args:
        prompts: Prompts describing the desired content.
        image: Source image.
        mask: Mask selecting the region to repaint.
        negative_prompts: Optional prompts describing what to avoid.

    Returns:
        An ImageArtifact labeled as PNG, carrying the source image's width/height and
        the joined prompts and model name in `meta`.
    """
    parameters = self.image_generation_model_driver.image_inpainting_request_parameters(
        prompts,
        image=image,
        mask=mask,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated_bytes = self._make_request(parameters)

    artifact_meta = {"prompt": ", ".join(prompts), "model": self.model}
    return ImageArtifact(
        value=generated_bytes,
        format="png",
        width=image.width,
        height=image.height,
        meta=artifact_meta,
    )

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Extend `image` beyond the region selected by `mask` according to `prompts`.

    Args:
        prompts: Prompts describing the desired content.
        image: Source image.
        mask: Mask passed to the model driver's outpainting request.
        negative_prompts: Optional prompts describing what to avoid.

    Returns:
        An ImageArtifact labeled as PNG, carrying the source image's width/height and
        the joined prompts and model name in `meta`.
    """
    parameters = self.image_generation_model_driver.image_outpainting_request_parameters(
        prompts,
        image=image,
        mask=mask,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated_bytes = self._make_request(parameters)

    artifact_meta = {"prompt": ", ".join(prompts), "model": self.model}
    return ImageArtifact(
        value=generated_bytes,
        format="png",
        width=image.width,
        height=image.height,
        meta=artifact_meta,
    )

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Produce a variation of `image` guided by `prompts`.

    Args:
        prompts: Prompts describing the desired variation.
        image: Source image.
        negative_prompts: Optional prompts describing what to avoid.

    Returns:
        An ImageArtifact labeled as PNG, carrying the source image's width/height and
        the joined prompts and model name in `meta`.
    """
    parameters = self.image_generation_model_driver.image_variation_request_parameters(
        prompts,
        image=image,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated_bytes = self._make_request(parameters)

    artifact_meta = {"prompt": ", ".join(prompts), "model": self.model}
    return ImageArtifact(
        value=generated_bytes,
        format="png",
        width=image.width,
        height=image.height,
        meta=artifact_meta,
    )

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate a new image from text prompts.

    Args:
        prompts: Prompts describing the desired image.
        negative_prompts: Optional prompts describing what to avoid.

    Returns:
        An ImageArtifact labeled as PNG with the driver's configured `image_width` x
        `image_height`, and the joined prompts and model name in `meta`.
    """
    width = self.image_width
    height = self.image_height

    parameters = self.image_generation_model_driver.text_to_image_request_parameters(
        prompts,
        width,
        height,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated_bytes = self._make_request(parameters)

    artifact_meta = {"prompt": ", ".join(prompts), "model": self.model}
    return ImageArtifact(
        value=generated_bytes,
        format="png",
        width=width,
        height=height,
        meta=artifact_meta,
    )

AmazonBedrockImageQueryDriver

Bases: BaseMultiModelImageQueryDriver

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
@define
class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver):
    """Image query driver for models hosted on Amazon Bedrock.

    Building the request body and interpreting the response are delegated to
    `image_query_model_driver`; this driver only performs the `invoke_model` call.

    Attributes:
        session: boto3 session. Defaults to a new `boto3.Session()`.
        bedrock_client: `bedrock-runtime` client. Defaults to one created from `session`.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
        kw_only=True,
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        """Ask the model `query` about `images`.

        Raises:
            ValueError: If the response body decodes to `None`, or if the model driver
                fails to process the response.
        """
        model_driver = self.image_query_model_driver
        request_body = json.dumps(model_driver.image_query_request_parameters(query, images, self.max_tokens))

        raw_response = self.bedrock_client.invoke_model(
            modelId=self.model,
            contentType="application/json",
            accept="application/json",
            body=request_body,
        )
        parsed_body = json.loads(raw_response.get("body").read())

        # json.loads yields None only for a literal `null` body.
        if parsed_body is None:
            raise ValueError("Model response is empty")

        try:
            return model_driver.process_output(parsed_body)
        except Exception as e:
            raise ValueError(f"Output is unable to be processed as returned {e}") from e

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    """Ask the model `query` about `images` via Bedrock `invoke_model`.

    Raises:
        ValueError: If the response body decodes to `None`, or if the model driver
            fails to process the response.
    """
    model_driver = self.image_query_model_driver
    request_body = json.dumps(model_driver.image_query_request_parameters(query, images, self.max_tokens))

    raw_response = self.bedrock_client.invoke_model(
        modelId=self.model,
        contentType="application/json",
        accept="application/json",
        body=request_body,
    )
    parsed_body = json.loads(raw_response.get("body").read())

    # json.loads yields None only for a literal `null` body.
    if parsed_body is None:
        raise ValueError("Model response is empty")

    try:
        return model_driver.process_output(parsed_body)
    except Exception as e:
        raise ValueError(f"Output is unable to be processed as returned {e}") from e

AmazonBedrockPromptDriver

Bases: BasePromptDriver

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@define
class AmazonBedrockPromptDriver(BasePromptDriver):
    """Prompt driver for models hosted on Amazon Bedrock, using the Converse API.

    Attributes:
        session: boto3 session. Defaults to a new `boto3.Session()`.
        bedrock_client: `bedrock-runtime` client. Defaults to one created from `session`.
        additional_model_request_fields: Extra fields passed through verbatim as
            `additionalModelRequestFields`.
        tokenizer: Defaults to an `AmazonBedrockTokenizer` for `model`.
        use_native_tools: Whether to send the prompt stack's tools as a Bedrock `toolConfig`.
        tool_choice: Value sent as `toolConfig.toolChoice`. Defaults to `{"auto": {}}`.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
        kw_only=True,
    )
    additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Send `prompt_stack` with `converse` and convert the reply into an assistant `Message`."""
        response = self.bedrock_client.converse(**self._base_params(prompt_stack))

        usage = response["usage"]
        output_message = response["output"]["message"]

        return Message(
            content=[self.__to_prompt_stack_message_content(content) for content in output_message["content"]],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Send `prompt_stack` with `converse_stream` and yield `DeltaMessage`s as events arrive.

        Content-block start/delta events yield content deltas; `metadata` events yield token
        usage; any other event type is skipped.

        Raises:
            ValueError: If the response contains no `stream`.
        """
        response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack))

        stream = response.get("stream")
        if stream is None:
            # ValueError subclasses Exception, so existing `except Exception` handlers still apply.
            raise ValueError("model response is empty")

        for event in stream:
            if "contentBlockDelta" in event or "contentBlockStart" in event:
                yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
            elif "metadata" in event:
                usage = event["metadata"]["usage"]
                yield DeltaMessage(
                    usage=DeltaMessage.Usage(
                        input_tokens=usage["inputTokens"],
                        output_tokens=usage["outputTokens"],
                    ),
                )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments shared by `converse` and `converse_stream`."""
        system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]

        messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

        # A None maxTokens is not a valid integer for the Converse API, so only send it
        # when a limit is configured.
        inference_config: dict = {"temperature": self.temperature}
        if self.max_tokens is not None:
            inference_config["maxTokens"] = self.max_tokens

        return {
            "modelId": self.model,
            "messages": messages,
            "system": system_messages,
            "inferenceConfig": inference_config,
            "additionalModelRequestFields": self.additional_model_request_fields,
            **(
                {"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
                if prompt_stack.tools and self.use_native_tools
                else {}
            ),
        }

    def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
        """Convert non-system messages into Converse `messages` entries."""
        return [
            {
                "role": self.__to_bedrock_role(message),
                "content": [self.__to_bedrock_message_content(content) for content in message.content],
            }
            for message in messages
        ]

    def __to_bedrock_role(self, message: Message) -> str:
        """Return "assistant" for assistant messages and "user" for everything else."""
        if message.is_assistant():
            return "assistant"
        else:
            return "user"

    def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Build one Converse `toolSpec` per activity of every tool."""
        return [
            {
                "toolSpec": {
                    "name": tool.to_native_tool_name(activity),
                    "description": tool.activity_description(activity),
                    "inputSchema": {
                        # Activities without a schema are described with an empty one.
                        "json": (tool.activity_schema(activity) or Schema({})).json_schema(
                            "http://json-schema.org/draft-07/schema#",
                        ),
                    },
                },
            }
            for tool in tools
            for activity in tool.activities()
        ]

    def __to_bedrock_message_content(self, content: BaseMessageContent) -> dict:
        """Convert one piece of message content into a Converse content block.

        Raises:
            ValueError: If the content type is not supported.
        """
        if isinstance(content, TextMessageContent):
            return {"text": content.artifact.to_text()}
        elif isinstance(content, ImageMessageContent):
            artifact = content.artifact

            return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
        elif isinstance(content, ActionCallMessageContent):
            action_call = content.artifact.value

            return {
                "toolUse": {
                    "toolUseId": action_call.tag,
                    "name": f"{action_call.name}_{action_call.path}",
                    "input": action_call.input,
                },
            }
        elif isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            # A ListArtifact result is expanded into one content block per member.
            if isinstance(artifact, ListArtifact):
                message_content = [self.__to_bedrock_tool_use_content(member) for member in artifact.value]
            else:
                message_content = [self.__to_bedrock_tool_use_content(artifact)]

            return {
                "toolResult": {
                    "toolUseId": content.action.tag,
                    "content": message_content,
                    "status": "error" if isinstance(artifact, ErrorArtifact) else "success",
                },
            }
        else:
            raise ValueError(f"Unsupported content type: {type(content)}")

    def __to_bedrock_tool_use_content(self, artifact: BaseArtifact) -> dict:
        """Convert a tool-result artifact into an image or text content block.

        Raises:
            ValueError: If the artifact type is not supported.
        """
        if isinstance(artifact, ImageArtifact):
            return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
        elif isinstance(artifact, (TextArtifact, ErrorArtifact, InfoArtifact)):
            return {"text": artifact.to_text()}
        else:
            raise ValueError(f"Unsupported artifact type: {type(artifact)}")

    def __to_prompt_stack_message_content(self, content: dict) -> BaseMessageContent:
        """Convert a Converse output content block into Griptape message content.

        Raises:
            ValueError: If the block is neither text nor a tool use.
        """
        if "text" in content:
            return TextMessageContent(TextArtifact(content["text"]))
        elif "toolUse" in content:
            name, path = ToolAction.from_native_tool_name(content["toolUse"]["name"])
            return ActionCallMessageContent(
                artifact=ActionArtifact(
                    value=ToolAction(
                        tag=content["toolUse"]["toolUseId"],
                        name=name,
                        path=path,
                        input=content["toolUse"]["input"],
                    ),
                ),
            )
        else:
            raise ValueError(f"Unsupported message content type: {content}")

    def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessageContent:
        """Convert a `contentBlockStart`/`contentBlockDelta` stream event into delta content.

        Raises:
            ValueError: If the event or its block type is not supported.
        """
        if "contentBlockStart" in event:
            content_block = event["contentBlockStart"]["start"]

            if "toolUse" in content_block:
                name, path = ToolAction.from_native_tool_name(content_block["toolUse"]["name"])

                return ActionCallDeltaMessageContent(
                    index=event["contentBlockStart"]["contentBlockIndex"],
                    tag=content_block["toolUse"]["toolUseId"],
                    name=name,
                    path=path,
                )
            elif "text" in content_block:
                return TextDeltaMessageContent(
                    content_block["text"],
                    index=event["contentBlockStart"]["contentBlockIndex"],
                )
            else:
                raise ValueError(f"Unsupported message content type: {event}")
        elif "contentBlockDelta" in event:
            content_block_delta = event["contentBlockDelta"]

            if "text" in content_block_delta["delta"]:
                return TextDeltaMessageContent(
                    content_block_delta["delta"]["text"],
                    index=content_block_delta["contentBlockIndex"],
                )
            elif "toolUse" in content_block_delta["delta"]:
                # Tool input arrives as partial JSON fragments across several deltas.
                return ActionCallDeltaMessageContent(
                    index=content_block_delta["contentBlockIndex"],
                    partial_input=content_block_delta["delta"]["toolUse"]["input"],
                )
            else:
                raise ValueError(f"Unsupported message content type: {event}")
        else:
            raise ValueError(f"Unsupported message content type: {event}")

additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True) class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

tool_choice: dict = field(default=Factory(lambda: {'auto': {}}), kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

use_native_tools: bool = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_bedrock_message_content(content)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_message_content(self, content: BaseMessageContent) -> dict:
    """Translate one piece of Griptape message content into a Bedrock Converse content block.

    Text becomes a `text` block, images become an `image` block, action calls become
    `toolUse`, and action results become `toolResult`.

    Raises:
        ValueError: If `content` is of an unsupported type.
    """
    if isinstance(content, TextMessageContent):
        return {"text": content.artifact.to_text()}

    if isinstance(content, ImageMessageContent):
        image_artifact = content.artifact
        return {"image": {"format": image_artifact.format, "source": {"bytes": image_artifact.value}}}

    if isinstance(content, ActionCallMessageContent):
        call = content.artifact.value
        tool_use = {
            "toolUseId": call.tag,
            "name": f"{call.name}_{call.path}",
            "input": call.input,
        }
        return {"toolUse": tool_use}

    if isinstance(content, ActionResultMessageContent):
        result = content.artifact
        # A ListArtifact result is expanded into one block per member artifact.
        members = result.value if isinstance(result, ListArtifact) else [result]
        blocks = [self.__to_bedrock_tool_use_content(member) for member in members]
        return {
            "toolResult": {
                "toolUseId": content.action.tag,
                "content": blocks,
                "status": "error" if isinstance(result, ErrorArtifact) else "success",
            },
        }

    raise ValueError(f"Unsupported content type: {type(content)}")

__to_bedrock_messages(messages)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
    """Convert messages into Converse `{"role", "content"}` entries, preserving order."""
    converted = []
    for msg in messages:
        role = self.__to_bedrock_role(msg)
        blocks = [self.__to_bedrock_message_content(part) for part in msg.content]
        converted.append({"role": role, "content": blocks})
    return converted

__to_bedrock_role(message)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_role(self, message: Message) -> str:
    """Map a message to its Converse role: "assistant" for assistant messages, otherwise "user"."""
    return "assistant" if message.is_assistant() else "user"

__to_bedrock_tool_use_content(artifact)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_tool_use_content(self, artifact: BaseArtifact) -> dict:
    """Turn a tool-result artifact into a Converse content block.

    Image artifacts become `image` blocks; text, error and info artifacts become `text` blocks.

    Raises:
        ValueError: For any other artifact type.
    """
    if isinstance(artifact, ImageArtifact):
        return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}

    textual_types = (TextArtifact, ErrorArtifact, InfoArtifact)
    if not isinstance(artifact, textual_types):
        raise ValueError(f"Unsupported artifact type: {type(artifact)}")

    return {"text": artifact.to_text()}

__to_bedrock_tools(tools)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Build one Converse `toolSpec` entry for every activity of every tool, in order.

    Activities without an input schema are described with an empty `Schema({})`.
    """
    specs = []
    for tool in tools:
        for activity in tool.activities():
            name = tool.to_native_tool_name(activity)
            description = tool.activity_description(activity)
            schema = tool.activity_schema(activity) or Schema({})
            specs.append(
                {
                    "toolSpec": {
                        "name": name,
                        "description": description,
                        "inputSchema": {"json": schema.json_schema("http://json-schema.org/draft-07/schema#")},
                    },
                },
            )
    return specs

__to_prompt_stack_delta_message_content(event)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessageContent:
    """Translate one ``converse_stream`` event into a delta message content.

    ``contentBlockStart`` events open either a tool-use block (an
    ``ActionCallDeltaMessageContent`` carrying the tool-use id, name and path) or a
    text block. ``contentBlockDelta`` events carry either a text fragment or a
    partial tool-input fragment.

    Raises:
        ValueError: for any other event or block shape.
    """
    if "contentBlockStart" in event:
        block_start = event["contentBlockStart"]
        start = block_start["start"]

        if "toolUse" in start:
            tool_use = start["toolUse"]
            name, path = ToolAction.from_native_tool_name(tool_use["name"])
            return ActionCallDeltaMessageContent(
                index=block_start["contentBlockIndex"],
                tag=tool_use["toolUseId"],
                name=name,
                path=path,
            )
        if "text" in start:
            return TextDeltaMessageContent(start["text"], index=block_start["contentBlockIndex"])
        raise ValueError(f"Unsupported message content type: {event}")

    if "contentBlockDelta" in event:
        block_delta = event["contentBlockDelta"]
        delta = block_delta["delta"]

        if "text" in delta:
            return TextDeltaMessageContent(delta["text"], index=block_delta["contentBlockIndex"])
        if "toolUse" in delta:
            return ActionCallDeltaMessageContent(
                index=block_delta["contentBlockIndex"],
                partial_input=delta["toolUse"]["input"],
            )
        raise ValueError(f"Unsupported message content type: {event}")

    raise ValueError(f"Unsupported message content type: {event}")

__to_prompt_stack_message_content(content)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_prompt_stack_message_content(self, content: dict) -> BaseMessageContent:
    """Translate a ``converse`` output content block into message content.

    Text blocks become ``TextMessageContent``; ``toolUse`` blocks become an
    ``ActionCallMessageContent`` wrapping a ``ToolAction``.

    Raises:
        ValueError: if the block is neither text nor tool use.
    """
    if "text" in content:
        return TextMessageContent(TextArtifact(content["text"]))
    if "toolUse" not in content:
        raise ValueError(f"Unsupported message content type: {content}")

    tool_use = content["toolUse"]
    name, path = ToolAction.from_native_tool_name(tool_use["name"])
    action = ToolAction(
        tag=tool_use["toolUseId"],
        name=name,
        path=path,
        input=tool_use["input"],
    )
    return ActionCallMessageContent(artifact=ActionArtifact(value=action))

try_run(prompt_stack)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Run a single non-streaming Bedrock ``converse`` call.

    Returns an assistant ``Message`` built from the response's output content
    blocks, with input/output token counts taken from the response ``usage``.
    """
    response = self.bedrock_client.converse(**self._base_params(prompt_stack))

    usage = response["usage"]
    output_message = response["output"]["message"]

    contents = [self.__to_prompt_stack_message_content(block) for block in output_message["content"]]
    token_usage = Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"])
    return Message(content=contents, role=Message.ASSISTANT_ROLE, usage=token_usage)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream a Bedrock ``converse_stream`` response as ``DeltaMessage`` objects.

    Content-block start/delta events become content deltas and a ``metadata``
    event becomes a usage-only delta; any other event is skipped.

    Raises:
        Exception: if the response has no ``stream``.
    """
    response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack))

    stream = response.get("stream")
    if stream is None:
        raise Exception("model response is empty")

    for event in stream:
        if "contentBlockDelta" in event or "contentBlockStart" in event:
            yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
            continue
        if "metadata" in event:
            token_usage = event["metadata"]["usage"]
            yield DeltaMessage(
                usage=DeltaMessage.Usage(
                    input_tokens=token_usage["inputTokens"],
                    output_tokens=token_usage["outputTokens"],
                ),
            )

AmazonBedrockTitanEmbeddingDriver

Bases: BaseEmbeddingDriver

Amazon Bedrock Titan Embedding Driver.

Attributes:

Name Type Description
model str

Embedding model name. Defaults to DEFAULT_MODEL.

tokenizer BaseTokenizer

Optionally provide a custom tokenizer. Defaults to an AmazonBedrockTokenizer for the configured model.

session Session

Optionally provide custom boto3.Session.

bedrock_client Any

Optionally provide custom bedrock-runtime client.

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@define
class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
    """Amazon Bedrock Titan Embedding Driver.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        tokenizer: Optionally provide a custom tokenizer. Defaults to an `AmazonBedrockTokenizer` for `model`.
        session: Optionally provide custom `boto3.Session`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client. Defaults to one created from `session`.
    """

    DEFAULT_MODEL = "amazon.titan-embed-text-v1"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
        kw_only=True,
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single chunk of text.

        Sends ``{"inputText": chunk}`` as a JSON body to ``invoke_model`` for ``model`` and
        returns the ``embedding`` field of the JSON response body (``None`` if absent).
        """
        payload = {"inputText": chunk}

        response = self.bedrock_client.invoke_model(
            body=json.dumps(payload),
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())

        return response_body.get("embedding")

DEFAULT_MODEL = 'amazon.titan-embed-text-v1' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed one text chunk with the configured Titan model.

    Calls ``invoke_model`` with a JSON body ``{"inputText": chunk}`` and returns the
    ``embedding`` field of the decoded JSON response (``None`` if the field is absent).
    """
    request_body = json.dumps({"inputText": chunk})
    response = self.bedrock_client.invoke_model(
        body=request_body,
        modelId=self.model,
        accept="application/json",
        contentType="application/json",
    )
    raw_body = response.get("body").read()
    decoded = json.loads(raw_body)
    return decoded.get("embedding")

AmazonDynamoDbConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
@define
class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
    """Conversation memory driver that keeps the serialized memory on one DynamoDB item.

    The item is addressed by ``partition_key``/``partition_key_value`` plus, when both
    are set, ``sort_key``/``sort_key_value``. The memory is stored as a JSON string in
    the attribute named by ``value_attribute_key``.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    partition_key: str = field(kw_only=True, metadata={"serializable": True})
    value_attribute_key: str = field(kw_only=True, metadata={"serializable": True})
    partition_key_value: str = field(kw_only=True, metadata={"serializable": True})
    sort_key: Optional[str] = field(default=None, metadata={"serializable": True})
    sort_key_value: Optional[str | int] = field(default=None, metadata={"serializable": True})

    table: Any = field(init=False)

    def __attrs_post_init__(self) -> None:
        dynamodb = self.session.resource("dynamodb")
        self.table = dynamodb.Table(self.table_name)

    def store(self, runs: list[Run], metadata: dict) -> None:
        """Overwrite the stored memory attribute with ``runs`` and ``metadata`` as JSON."""
        item_key = self._get_key()
        serialized_memory = json.dumps(self._to_params_dict(runs, metadata))
        self.table.update_item(
            Key=item_key,
            UpdateExpression="set #attr = :value",
            ExpressionAttributeNames={"#attr": self.value_attribute_key},
            ExpressionAttributeValues={":value": serialized_memory},
        )

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        """Return ``(runs, metadata)`` from the item, or ``([], {})`` if nothing is stored."""
        response = self.table.get_item(Key=self._get_key())
        item = response.get("Item", {})
        if self.value_attribute_key not in item:
            return [], {}
        return self._from_params_dict(json.loads(item[self.value_attribute_key]))

    def _get_key(self) -> dict[str, str | int]:
        """Build the item key; the sort key is included only when both its name and value are set."""
        key: dict[str, str | int] = {self.partition_key: self.partition_key_value}
        if self.sort_key is None or self.sort_key_value is None:
            return key
        key[self.sort_key] = self.sort_key_value
        return key

partition_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

partition_key_value: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

sort_key: Optional[str] = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

sort_key_value: Optional[str | int] = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

table: Any = field(init=False) class-attribute instance-attribute

table_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

value_attribute_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def __attrs_post_init__(self) -> None:
    """Resolve the DynamoDB ``Table`` named ``table_name`` from the session's DynamoDB resource."""
    dynamodb = self.session.resource("dynamodb")
    self.table = dynamodb.Table(self.table_name)

load()

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def load(self) -> tuple[list[Run], dict[str, Any]]:
    """Read the memory item and decode it into ``(runs, metadata)``.

    Returns ``([], {})`` when the item, or its ``value_attribute_key`` attribute, is absent.
    """
    response = self.table.get_item(Key=self._get_key())
    item = response.get("Item", {})
    if self.value_attribute_key not in item:
        return [], {}
    memory_dict = json.loads(item[self.value_attribute_key])
    return self._from_params_dict(memory_dict)

store(runs, metadata)

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def store(self, runs: list[Run], metadata: dict) -> None:
    """Persist ``runs`` and ``metadata`` as one JSON string on the DynamoDB item.

    Uses ``update_item`` with a ``set`` expression on the item addressed by
    ``_get_key()``, overwriting the attribute named by ``value_attribute_key``.
    """
    item_key = self._get_key()
    serialized_memory = json.dumps(self._to_params_dict(runs, metadata))
    self.table.update_item(
        Key=item_key,
        UpdateExpression="set #attr = :value",
        ExpressionAttributeNames={"#attr": self.value_attribute_key},
        ExpressionAttributeValues={":value": serialized_memory},
    )

AmazonOpenSearchVectorStoreDriver

Bases: OpenSearchVectorStoreDriver

A Vector Store Driver for Amazon OpenSearch.

Attributes:

Name Type Description
session Session

The boto3 session to use.

service str

Service name for AWS Signature v4: 'es' for Amazon OpenSearch Service, or 'aoss' for OpenSearch Serverless. Defaults to 'es'.

http_auth str | tuple[str, str]

The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.

client OpenSearch

An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.

Source code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
@define
class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver):
    """A Vector Store Driver for Amazon OpenSearch.

    Attributes:
        session: The boto3 session to use.
        service: Service name for AWS Signature v4: 'es' for Amazon OpenSearch Service, or 'aoss' for
            OpenSearch Serverless. Defaults to 'es'.
        http_auth: The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
        client: An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.
    """

    session: Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    service: str = field(default="es", kw_only=True)
    http_auth: str | tuple[str, str] = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").AWSV4SignerAuth(
                self.session.get_credentials(),
                self.session.region_name,
                self.service,
            ),
            takes_self=True,
        ),
    )

    client: OpenSearch = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").OpenSearch(
                hosts=[{"host": self.host, "port": self.port}],
                http_auth=self.http_auth,
                use_ssl=self.use_ssl,
                verify_certs=self.verify_certs,
                connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection,
            ),
            takes_self=True,
        ),
    )

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector document in ``index_name``.

        The document is ``{"vector": ..., "namespace": ..., "metadata": ...}`` merged with any
        extra ``kwargs`` (which may override those keys).

        When ``service`` is not 'aoss', the document is indexed under ``vector_id``, or under a hash of
        the vector when no id is given. An existing document with that id is therefore replaced.
        When ``service`` is 'aoss' (OpenSearch Serverless), no id is sent: ``vector_id`` is ignored and
        OpenSearch assigns the id.

        Returns:
            The ``_id`` that OpenSearch reports for the indexed document.
        """
        doc = {"vector": vector, "namespace": namespace, "metadata": meta}
        doc.update(kwargs)
        if self.service == "aoss":
            # No document id is sent for 'aoss', so the id (and its hash) would be unused.
            response = self.client.index(index=self.index_name, body=doc)
        else:
            vector_id = vector_id or str_to_hash(str(vector))
            response = self.client.index(index=self.index_name, id=vector_id, body=doc)

        return response["_id"]

client: OpenSearch = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').OpenSearch(hosts=[{'host': self.host, 'port': self.port}], http_auth=self.http_auth, use_ssl=self.use_ssl, verify_certs=self.verify_certs, connection_class=import_optional_dependency('opensearchpy').RequestsHttpConnection), takes_self=True)) class-attribute instance-attribute

http_auth: str | tuple[str, str] = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').AWSV4SignerAuth(self.session.get_credentials(), self.session.region_name, self.service), takes_self=True)) class-attribute instance-attribute

service: str = field(default='es', kw_only=True) class-attribute instance-attribute

session: Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in OpenSearch.

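If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted. Metadata associated with the vector can also be provided. When the service is 'aoss' (OpenSearch Serverless), the vector ID is not sent and OpenSearch assigns the document ID instead.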

Source code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector document in ``index_name``.

    The document is ``{"vector": ..., "namespace": ..., "metadata": ...}`` merged with any
    extra ``kwargs`` (which may override those keys).

    When ``service`` is not 'aoss', the document is indexed under ``vector_id``, or under a hash of
    the vector when no id is given. An existing document with that id is therefore replaced.
    When ``service`` is 'aoss' (OpenSearch Serverless), no id is sent: ``vector_id`` is ignored and
    OpenSearch assigns the id.

    Returns:
        The ``_id`` that OpenSearch reports for the indexed document.
    """
    doc = {"vector": vector, "namespace": namespace, "metadata": meta}
    doc.update(kwargs)
    if self.service == "aoss":
        # No document id is sent for 'aoss', so the id (and its hash) would be unused.
        response = self.client.index(index=self.index_name, body=doc)
    else:
        vector_id = vector_id or str_to_hash(str(vector))
        response = self.client.index(index=self.index_name, id=vector_id, body=doc)

    return response["_id"]

AmazonRedshiftSqlDriver

Bases: BaseSqlDriver

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@define
class AmazonRedshiftSqlDriver(BaseSqlDriver):
    """SQL driver backed by the Amazon Redshift Data API (``redshift-data`` client).

    Attributes:
        database: Name of the database to run statements against.
        session: boto3 session used to create the default ``client``.
        cluster_identifier: Provisioned cluster identifier. Exactly one of this or ``workgroup_name`` is required.
        workgroup_name: Redshift Serverless workgroup name. Exactly one of this or ``cluster_identifier`` is required.
        db_user: Optional database user, sent as ``DbUser``.
        database_credentials_secret_arn: Optional secret ARN, sent as ``SecretArn``.
        wait_for_query_completion_sec: Seconds to sleep between ``describe_statement`` polls.
        client: The ``redshift-data`` client. Defaults to one created from ``session``.
    """

    database: str = field(kw_only=True)
    session: boto3.Session = field(kw_only=True)
    cluster_identifier: Optional[str] = field(default=None, kw_only=True)
    workgroup_name: Optional[str] = field(default=None, kw_only=True)
    db_user: Optional[str] = field(default=None, kw_only=True)
    database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True)
    wait_for_query_completion_sec: float = field(default=0.3, kw_only=True)
    client: Any = field(
        default=Factory(lambda self: self.session.client("redshift-data"), takes_self=True),
        kw_only=True,
    )

    @workgroup_name.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None:
        """Require exactly one of ``cluster_identifier`` or ``workgroup_name``.

        Raises:
            ValueError: if neither or both are set.
        """
        if not self.cluster_identifier and not self.workgroup_name:
            raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
        if self.cluster_identifier and self.workgroup_name:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

    @classmethod
    def _process_rows_from_records(cls, records: list) -> list[list]:
        """Unwrap Data API ``Field`` dicts into plain cell values.

        Each field holds a single typed member, e.g. ``{"stringValue": "x"}`` or ``{"longValue": 1}``.
        SQL NULL arrives as ``{"isNull": True}``. It is mapped to ``None``; without this, the
        ``True`` flag itself would leak out as the cell value.
        """
        return [[None if cell.get("isNull") else list(cell.values())[0] for cell in record] for record in records]

    @classmethod
    def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
        """Zip each row's values with the column names into a dict."""
        return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

    @classmethod
    def _process_columns_from_column_metadata(cls, meta: list[dict]) -> list:
        """Extract column names from ``ColumnMetadata`` entries."""
        return [k["name"] for k in meta]

    @classmethod
    def _post_process(cls, meta: list[dict], records: list) -> list[dict[str, Any]]:
        """Turn raw ``ColumnMetadata`` and ``Records`` into a list of ``{column: value}`` dicts."""
        columns = cls._process_columns_from_column_metadata(meta)
        rows = cls._process_rows_from_records(records)
        return cls._process_cells_from_rows_and_columns(columns, rows)

    def _connection_params(self) -> dict[str, str]:
        """Data API parameters naming the cluster/workgroup and credentials; only the ones that are set."""
        params: dict[str, str] = {}
        if self.workgroup_name:
            params["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            params["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            params["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            params["SecretArn"] = self.database_credentials_secret_arn
        return params

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        """Run ``query`` and wrap each row in a ``RowResult``; ``None`` when there are no rows."""
        rows = self.execute_query_raw(query)
        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        else:
            return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        """Run ``query`` through the Data API and return every result row as a dict.

        Polls ``describe_statement`` every ``wait_for_query_completion_sec`` seconds while the
        statement is SUBMITTED/PICKED/STARTED, then reads all pages of ``get_statement_result``.

        Returns:
            ``{column_name: value}`` dicts if the statement FINISHED, otherwise ``None``.
        """
        function_kwargs = {"Sql": query, "Database": self.database, **self._connection_params()}

        response = self.client.execute_statement(**function_kwargs)
        response_id = response["Id"]

        statement = self.client.describe_statement(Id=response_id)

        while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
            time.sleep(self.wait_for_query_completion_sec)
            statement = self.client.describe_statement(Id=response_id)

        if statement["Status"] != "FINISHED":
            # FAILED, ABORTED, or any other terminal status: no rows.
            return None

        statement_result = self.client.get_statement_result(Id=response_id)
        # Copy so that extending does not mutate the client's response.
        results = list(statement_result.get("Records", []))

        while "NextToken" in statement_result:
            statement_result = self.client.get_statement_result(
                Id=response_id,
                NextToken=statement_result["NextToken"],
            )
            # Take records from the page just fetched. Reading them from the
            # execute_statement response (which has none) silently dropped later pages.
            results.extend(statement_result.get("Records", []))

        return self._post_process(statement_result["ColumnMetadata"], results)

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        """Describe ``table_name`` and return ``str()`` of its list of column names."""
        function_kwargs = {"Database": self.database, "Table": table_name}
        if schema:
            function_kwargs["Schema"] = schema
        function_kwargs.update(self._connection_params())
        response = self.client.describe_table(**function_kwargs)
        return str([col["name"] for col in response["ColumnList"]])

client: Any = field(default=Factory(lambda self: self.session.client('redshift-data'), takes_self=True), kw_only=True) class-attribute instance-attribute

cluster_identifier: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

database: str = field(kw_only=True) class-attribute instance-attribute

database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

db_user: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(kw_only=True) class-attribute instance-attribute

wait_for_query_completion_sec: float = field(default=0.3, kw_only=True) class-attribute instance-attribute

workgroup_name: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

execute_query(query)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    """Run ``query`` and wrap every returned row in a ``BaseSqlDriver.RowResult``.

    Returns ``None`` when ``execute_query_raw`` yields no rows (``None`` or empty).
    """
    rows = self.execute_query_raw(query)
    if not rows:
        return None
    return [BaseSqlDriver.RowResult(row) for row in rows]

execute_query_raw(query)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    """Run ``query`` through the Redshift Data API and return every result row as a dict.

    Submits the statement, then polls ``describe_statement`` every
    ``wait_for_query_completion_sec`` seconds while it is SUBMITTED/PICKED/STARTED,
    then reads all pages of ``get_statement_result``.

    Returns:
        ``{column_name: value}`` dicts (via ``_post_process``) if the statement FINISHED,
        otherwise ``None`` (e.g. FAILED or ABORTED).
    """
    function_kwargs = {"Sql": query, "Database": self.database}
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.execute_statement(**function_kwargs)
    response_id = response["Id"]

    statement = self.client.describe_statement(Id=response_id)

    while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
        time.sleep(self.wait_for_query_completion_sec)
        statement = self.client.describe_statement(Id=response_id)

    if statement["Status"] != "FINISHED":
        # FAILED, ABORTED, or any other terminal status: no rows.
        return None

    statement_result = self.client.get_statement_result(Id=response_id)
    # Copy so that extending does not mutate the client's response.
    results = list(statement_result.get("Records", []))

    while "NextToken" in statement_result:
        statement_result = self.client.get_statement_result(
            Id=response_id,
            NextToken=statement_result["NextToken"],
        )
        # Take records from the page just fetched. Reading them from the
        # execute_statement response (which has none) silently dropped later pages.
        results.extend(statement_result.get("Records", []))

    return self._post_process(statement_result["ColumnMetadata"], results)

get_table_schema(table_name, schema=None)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    """Describe ``table_name`` via the Data API and return its column names.

    The result is ``str()`` of the list of column names, e.g. ``"['id', 'name']"``.
    Optional request parameters are sent only when they are truthy.
    """
    optional_params = {
        "Schema": schema,
        "WorkgroupName": self.workgroup_name,
        "ClusterIdentifier": self.cluster_identifier,
        "DbUser": self.db_user,
        "SecretArn": self.database_credentials_secret_arn,
    }
    request = {"Database": self.database, "Table": table_name}
    request.update({key: value for key, value in optional_params.items() if value})
    description = self.client.describe_table(**request)
    column_names = [column["name"] for column in description["ColumnList"]]
    return str(column_names)

validate_params(_, workgroup_name)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@workgroup_name.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None:
    """Require exactly one of ``cluster_identifier`` or ``workgroup_name`` to be set.

    Raises:
        ValueError: if neither is set, or if both are set.
    """
    has_cluster = bool(self.cluster_identifier)
    has_workgroup = bool(self.workgroup_name)
    if not (has_cluster or has_workgroup):
        raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
    if has_cluster and has_workgroup:
        raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

AmazonS3FileManagerDriver

Bases: BaseFileManagerDriver

AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

Attributes:

Name Type Description
session Session

The boto3 session to use for S3 operations.

bucket str

The name of the S3 bucket.

workdir str

The absolute working directory (must start with "/"). List, load, and save operations will be performed relative to this directory.

s3_client Any

The S3 client to use for S3 operations.

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@define
class AmazonS3FileManagerDriver(BaseFileManagerDriver):
    """AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

    Attributes:
        session: The boto3 session to use for S3 operations.
        bucket: The name of the S3 bucket.
        workdir: The absolute working directory (must start with "/"). List, load, and save
            operations will be performed relative to this directory.
        s3_client: The S3 client to use for S3 operations.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bucket: str = field(kw_only=True)
    workdir: str = field(default="/", kw_only=True)
    s3_client: Any = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True)

    @workdir.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_workdir(self, _: Attribute, workdir: str) -> None:
        """Reject a relative ``workdir``.

        Raises:
            ValueError: if ``workdir`` does not start with "/".
        """
        if not workdir.startswith("/"):
            raise ValueError("Workdir must be an absolute path")

    def try_list_files(self, path: str) -> list[str]:
        """Return the names of the files and sub-directories directly under ``path``.

        Raises:
            NotADirectoryError: if ``path`` is an existing object (a file) with nothing stored under it.
            FileNotFoundError: if nothing exists at or under ``path``.
        """
        full_key = self._to_dir_full_key(path)
        files_and_dirs = self._list_files_and_dirs(full_key)
        if len(files_and_dirs) == 0:
            # Nothing is stored under "<key>/". The path is a file only if an object with
            # exactly this key exists; such an object lists as "" relative to its own key.
            # A bare prefix match would also hit siblings such as "foobar.txt" for "foo".
            if "" in self._list_files_and_dirs(full_key.rstrip("/"), max_items=1):
                raise NotADirectoryError
            raise FileNotFoundError
        return files_and_dirs

    def try_load_file(self, path: str) -> bytes:
        """Return the bytes of the object at ``path``.

        Raises:
            IsADirectoryError: if ``path`` resolves to a directory.
            FileNotFoundError: if S3 reports the key as missing.
        """
        botocore = import_optional_dependency("botocore")
        full_key = self._to_full_key(path)

        if self._is_a_directory(full_key):
            raise IsADirectoryError

        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=full_key)
            return response["Body"].read()
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                raise FileNotFoundError from e
            raise e

    def try_save_file(self, path: str, value: bytes) -> None:
        """Write ``value`` to the object at ``path``.

        Raises:
            IsADirectoryError: if ``path`` resolves to a directory.
        """
        full_key = self._to_full_key(path)
        if self._is_a_directory(full_key):
            raise IsADirectoryError
        self.s3_client.put_object(Bucket=self.bucket, Key=full_key, Body=value)

    def _to_full_key(self, path: str) -> str:
        """Resolve ``path`` against ``workdir`` into an S3 key with no leading slash.

        The joined path is normalized by ``_normpath``; a trailing slash on ``path`` is kept.
        """
        path = path.lstrip("/")
        full_key = f"{self.workdir}/{path}"
        # Need to keep the trailing slash if it was there,
        # because it means the path is a directory.
        ended_with_slash = path.endswith("/")

        full_key = self._normpath(full_key)

        if ended_with_slash:
            full_key += "/"
        return full_key.lstrip("/")

    def _to_dir_full_key(self, path: str) -> str:
        """Like ``_to_full_key``, but guarantees a trailing slash (except for the root "")."""
        full_key = self._to_full_key(path)
        # S3 "directories" always end with a slash, except for the root.
        if full_key != "" and not full_key.endswith("/"):
            full_key += "/"
        return full_key

    def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
        """List entry names under prefix ``full_key``, relative to that prefix.

        Sub-directories (common prefixes) come first in each page, with the trailing "/"
        stripped, followed by object keys. ``max_items`` (optional kwarg) caps the paginator.
        """
        max_items = kwargs.get("max_items")
        pagination_config = {}
        if max_items is not None:
            pagination_config["MaxItems"] = max_items

        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(
            Bucket=self.bucket,
            Prefix=full_key,
            Delimiter="/",
            PaginationConfig=pagination_config,
        )
        files_and_dirs = []
        for page in pages:
            for obj in page.get("CommonPrefixes", []):
                prefix = obj.get("Prefix")
                directory = prefix[len(full_key) :].rstrip("/")
                files_and_dirs.append(directory)

            for obj in page.get("Contents", []):
                key = obj.get("Key")
                file = key[len(full_key) :]
                files_and_dirs.append(file)
        return files_and_dirs

    def _is_a_directory(self, full_key: str) -> bool:
        """Return whether ``full_key`` denotes a directory.

        Returns True for the root "", for any key ending in "/", and for a key that has no
        object of its own but has something stored under "<key>/".
        """
        botocore = import_optional_dependency("botocore")
        if full_key == "" or full_key.endswith("/"):
            return True

        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=full_key)
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                # Only objects under "<key>/" make this a directory. Listing the bare key
                # as a prefix would also match siblings such as "foobar.txt" for "foo".
                return len(self._list_files_and_dirs(f"{full_key}/", max_items=1)) > 0
            else:
                raise e

        return False

    def _normpath(self, path: str) -> str:
        """Collapse "\\", empty, "." and ".." segments; ".." above the top is dropped."""
        unix_path = path.replace("\\", "/")
        parts = unix_path.split("/")
        stack = []

        for part in parts:
            if part == "" or part == ".":
                continue
            if part == "..":
                if stack:
                    stack.pop()
            else:
                stack.append(part)

        return "/".join(stack)

bucket: str = field(kw_only=True) class-attribute instance-attribute

s3_client: Any = field(default=Factory(lambda self: self.session.client('s3'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

workdir: str = field(default='/', kw_only=True) class-attribute instance-attribute

try_list_files(path)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_list_files(self, path: str) -> list[str]:
    """Return the names of the files and sub-directories directly under ``path``.

    Raises:
        NotADirectoryError: if ``path`` is an existing object (a file) with nothing stored under it.
        FileNotFoundError: if nothing exists at or under ``path``.
    """
    full_key = self._to_dir_full_key(path)
    files_and_dirs = self._list_files_and_dirs(full_key)
    if files_and_dirs:
        return files_and_dirs
    # Nothing is stored under "<key>/". The path is a file only if an object with exactly
    # this key exists; such an object lists as "" relative to its own key. A bare
    # prefix match would also hit siblings such as "foobar.txt" for "foo".
    if "" in self._list_files_and_dirs(full_key.rstrip("/"), max_items=1):
        raise NotADirectoryError
    raise FileNotFoundError

try_load_file(path)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_load_file(self, path: str) -> bytes:
    """Download and return the full contents of the S3 object addressed by ``path``.

    Raises:
        IsADirectoryError: if ``path`` resolves to a directory-like key.
        FileNotFoundError: if S3 reports the key as missing (``NoSuchKey`` / ``404``).
        botocore.exceptions.ClientError: any other client error is re-raised.
    """
    botocore = import_optional_dependency("botocore")
    object_key = self._to_full_key(path)

    if self._is_a_directory(object_key):
        raise IsADirectoryError

    missing_codes = {"NoSuchKey", "404"}
    try:
        s3_object = self.s3_client.get_object(Bucket=self.bucket, Key=object_key)
        return s3_object["Body"].read()
    except botocore.exceptions.ClientError as err:
        if err.response["Error"]["Code"] not in missing_codes:
            raise err
        raise FileNotFoundError from err

try_save_file(path, value)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_save_file(self, path: str, value: bytes) -> None:
    """Upload ``value`` as the S3 object addressed by ``path``.

    Raises:
        IsADirectoryError: if ``path`` resolves to a directory-like key; nothing
            is uploaded in that case.
    """
    object_key = self._to_full_key(path)
    if not self._is_a_directory(object_key):
        self.s3_client.put_object(Bucket=self.bucket, Key=object_key, Body=value)
        return
    raise IsADirectoryError

validate_workdir(_, workdir)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@workdir.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_workdir(self, _: Attribute, workdir: str) -> None:
    """attrs validator for ``workdir``: only absolute paths (leading ``/``) are accepted.

    Raises:
        ValueError: if ``workdir`` does not start with ``/``.
    """
    if workdir.startswith("/"):
        return
    raise ValueError("Workdir must be an absolute path")

AmazonSageMakerJumpstartEmbeddingDriver

Bases: BaseEmbeddingDriver

Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that invokes an Amazon SageMaker JumpStart endpoint.

    Attributes:
        session: boto3 session used to build the default ``sagemaker-runtime`` client.
        sagemaker_client: Client whose ``invoke_endpoint`` is called; defaults to
            ``session.client("sagemaker-runtime")``.
        endpoint: Name of the endpoint passed as ``EndpointName``.
        custom_attributes: Value passed as ``CustomAttributes``
            (defaults to ``"accept_eula=true"``).
        inference_component_name: Sent as ``InferenceComponentName`` only when not ``None``.
    """

    # Field order matters: the ``sagemaker_client`` factory reads ``self.session``.
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True),
        kw_only=True,
    )
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed ``chunk`` through the configured endpoint.

        The request body is ``{"text_inputs": chunk, "mode": "embedding"}`` as
        UTF-8 encoded JSON.

        Returns:
            The response's ``embedding`` value, or its first element when that
            element is itself a list.

        Raises:
            ValueError: if the response lacks an ``embedding`` key or it is empty.
        """
        request_kwargs = {
            "EndpointName": self.endpoint,
            "ContentType": "application/json",
            "Body": json.dumps({"text_inputs": chunk, "mode": "embedding"}).encode("utf-8"),
            "CustomAttributes": self.custom_attributes,
        }
        if self.inference_component_name is not None:
            request_kwargs["InferenceComponentName"] = self.inference_component_name

        raw_response = self.sagemaker_client.invoke_endpoint(**request_kwargs)
        decoded = json.loads(raw_response.get("Body").read().decode("utf-8"))

        if "embedding" not in decoded:
            raise ValueError("invalid response from model")
        vector = decoded["embedding"]
        if not vector:
            raise ValueError("model response is empty")

        # Some models wrap the vector in an outer list; unwrap the first entry.
        head = vector[0]
        return head if isinstance(head, list) else vector

custom_attributes: str = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda self: self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed ``chunk`` by invoking the configured SageMaker JumpStart endpoint.

    The request body is ``{"text_inputs": chunk, "mode": "embedding"}`` as UTF-8
    encoded JSON; ``InferenceComponentName`` is only sent when
    ``inference_component_name`` is not ``None``.

    Returns:
        The response's ``embedding`` value, or its first element when that
        element is itself a list.

    Raises:
        ValueError: if the response lacks an ``embedding`` key or it is empty.
    """
    request_kwargs = {
        "EndpointName": self.endpoint,
        "ContentType": "application/json",
        "Body": json.dumps({"text_inputs": chunk, "mode": "embedding"}).encode("utf-8"),
        "CustomAttributes": self.custom_attributes,
    }
    if self.inference_component_name is not None:
        request_kwargs["InferenceComponentName"] = self.inference_component_name

    raw_response = self.sagemaker_client.invoke_endpoint(**request_kwargs)
    decoded = json.loads(raw_response.get("Body").read().decode("utf-8"))

    if "embedding" not in decoded:
        raise ValueError("invalid response from model")
    vector = decoded["embedding"]
    if not vector:
        raise ValueError("model response is empty")

    # Some models wrap the vector in an outer list; unwrap the first entry.
    head = vector[0]
    return head if isinstance(head, list) else vector

AmazonSageMakerJumpstartPromptDriver

Bases: BasePromptDriver

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@define
class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True),
        kw_only=True,
    )
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )

    @stream.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_stream(self, _: Attribute, stream: bool) -> None:  # noqa: FBT001
        if stream:
            raise ValueError("streaming is not supported")

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        payload = {
            "inputs": self.prompt_stack_to_string(prompt_stack),
            "parameters": {**self._base_params(prompt_stack)},
        }

        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=json.dumps(payload),
            CustomAttributes=self.custom_attributes,
            **(
                {"InferenceComponentName": self.inference_component_name}
                if self.inference_component_name is not None
                else {}
            ),
        )

        decoded_body = json.loads(response["Body"].read().decode("utf8"))

        if isinstance(decoded_body, list):
            if decoded_body:
                generated_text = decoded_body[0]["generated_text"]
            else:
                raise ValueError("model response is empty")
        else:
            generated_text = decoded_body["generated_text"]

        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
        output_tokens = len(self.tokenizer.tokenizer.encode(generated_text))

        return Message(
            content=[TextMessageContent(TextArtifact(generated_text))],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        raise NotImplementedError("streaming is not supported")

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_tokens,
            "do_sample": True,
            "eos_token_id": self.tokenizer.tokenizer.eos_token_id,
            "stop_strings": self.tokenizer.stop_sequences,
            "return_full_text": False,
        }

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        messages = []

        for message in prompt_stack.messages:
            messages.append({"role": message.role, "content": message.to_text()})

        return messages

    def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
        messages = self._prompt_stack_to_messages(prompt_stack)

        tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

        if isinstance(tokens, list):
            return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].
        else:
            raise ValueError("Invalid o