Skip to content

Drivers

__all__ = ['BasePromptDriver', 'OpenAiChatPromptDriver', 'AzureOpenAiChatPromptDriver', 'CoherePromptDriver', 'HuggingFacePipelinePromptDriver', 'HuggingFaceHubPromptDriver', 'AnthropicPromptDriver', 'AmazonSageMakerJumpstartPromptDriver', 'AmazonBedrockPromptDriver', 'GooglePromptDriver', 'DummyPromptDriver', 'BaseConversationMemoryDriver', 'LocalConversationMemoryDriver', 'AmazonDynamoDbConversationMemoryDriver', 'RedisConversationMemoryDriver', 'BaseEmbeddingDriver', 'OpenAiEmbeddingDriver', 'AzureOpenAiEmbeddingDriver', 'AmazonSageMakerJumpstartEmbeddingDriver', 'AmazonBedrockTitanEmbeddingDriver', 'AmazonBedrockCohereEmbeddingDriver', 'VoyageAiEmbeddingDriver', 'HuggingFaceHubEmbeddingDriver', 'GoogleEmbeddingDriver', 'DummyEmbeddingDriver', 'CohereEmbeddingDriver', 'BaseVectorStoreDriver', 'LocalVectorStoreDriver', 'PineconeVectorStoreDriver', 'MarqoVectorStoreDriver', 'MongoDbAtlasVectorStoreDriver', 'AzureMongoDbVectorStoreDriver', 'RedisVectorStoreDriver', 'OpenSearchVectorStoreDriver', 'AmazonOpenSearchVectorStoreDriver', 'PgVectorVectorStoreDriver', 'DummyVectorStoreDriver', 'BaseSqlDriver', 'AmazonRedshiftSqlDriver', 'SnowflakeSqlDriver', 'SqlDriver', 'BaseImageGenerationModelDriver', 'BedrockStableDiffusionImageGenerationModelDriver', 'BedrockTitanImageGenerationModelDriver', 'BaseImageGenerationDriver', 'BaseMultiModelImageGenerationDriver', 'OpenAiImageGenerationDriver', 'LeonardoImageGenerationDriver', 'AmazonBedrockImageGenerationDriver', 'AzureOpenAiImageGenerationDriver', 'DummyImageGenerationDriver', 'BaseImageQueryModelDriver', 'BedrockClaudeImageQueryModelDriver', 'BaseImageQueryDriver', 'OpenAiImageQueryDriver', 'AzureOpenAiImageQueryDriver', 'DummyImageQueryDriver', 'AnthropicImageQueryDriver', 'BaseMultiModelImageQueryDriver', 'AmazonBedrockImageQueryDriver', 'BaseWebScraperDriver', 'TrafilaturaWebScraperDriver', 'MarkdownifyWebScraperDriver', 'BaseEventListenerDriver', 'AmazonSqsEventListenerDriver', 'WebhookEventListenerDriver', 
'AwsIotCoreEventListenerDriver', 'GriptapeCloudEventListenerDriver', 'PusherEventListenerDriver', 'BaseFileManagerDriver', 'LocalFileManagerDriver', 'AmazonS3FileManagerDriver', 'BaseTextToSpeechDriver', 'DummyTextToSpeechDriver', 'ElevenLabsTextToSpeechDriver', 'OpenAiTextToSpeechDriver', 'BaseStructureRunDriver', 'GriptapeCloudStructureRunDriver', 'LocalStructureRunDriver', 'BaseAudioTranscriptionDriver', 'DummyAudioTranscriptionDriver', 'OpenAiAudioTranscriptionDriver'] module-attribute

AmazonBedrockCohereEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

Name Type Description
model str

Embedding model name. Defaults to DEFAULT_MODEL.

input_type str

Defaults to search_query. Prepends special tokens to differentiate each type from one another: search_document when you encode documents for embeddings that you store in a vector database. search_query when querying your vector DB to find relevant documents.

session Session

Optionally provide custom boto3.Session.

tokenizer BaseTokenizer

Optionally provide custom BedrockCohereTokenizer.

bedrock_client Any

Optionally provide custom bedrock-runtime client.

Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@define
class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
    """Amazon Bedrock embedding driver for Cohere embedding models.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        input_type: Defaults to `search_query`. Prepends special tokens to differentiate each type from one another:
            `search_document` when you encode documents for embeddings that you store in a vector database.
            `search_query` when querying your vector DB to find relevant documents.
        session: Optionally provide custom `boto3.Session`.
        tokenizer: Optionally provide custom `BedrockCohereTokenizer`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "cohere.embed-english-v3"

    # NOTE(review): `model` carries no serializable metadata here, unlike the Titan
    # embedding driver — confirm this asymmetry is intentional.
    model: str = field(default=DEFAULT_MODEL, kw_only=True)
    input_type: str = field(default="search_query", kw_only=True)
    # boto3 is resolved lazily so it is only required when this driver is instantiated.
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single text chunk via Bedrock and return its embedding vector."""
        payload = {"input_type": self.input_type, "texts": [chunk]}

        response = self.bedrock_client.invoke_model(
            body=json.dumps(payload), modelId=self.model, accept="*/*", contentType="application/json"
        )
        response_body = json.loads(response.get("body").read())

        # Cohere returns one embedding per input text; exactly one text was sent.
        return response_body.get("embeddings")[0]

DEFAULT_MODEL = 'cohere.embed-english-v3' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

input_type: str = field(default='search_query', kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single text chunk via the Bedrock Cohere model and return its vector."""
    request_body = json.dumps({"input_type": self.input_type, "texts": [chunk]})

    raw_response = self.bedrock_client.invoke_model(
        body=request_body, modelId=self.model, accept="*/*", contentType="application/json"
    )
    parsed = json.loads(raw_response.get("body").read())

    # One embedding comes back per input text; exactly one text was sent.
    return parsed.get("embeddings")[0]

AmazonBedrockImageGenerationDriver

Bases: BaseMultiModelImageGenerationDriver

Driver for image generation models provided by Amazon Bedrock.

Attributes:

Name Type Description
model

Bedrock model ID.

session Session

boto3 session.

bedrock_client Any

Bedrock runtime client.

image_width int

Width of output images. Defaults to 512 and must be a multiple of 64.

image_height int

Height of output images. Defaults to 512 and must be a multiple of 64.

seed Optional[int]

Optionally provide a consistent seed to generation requests, increasing consistency in output.

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
@define
class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
    """Driver for image generation models provided by Amazon Bedrock.

    Attributes:
        model: Bedrock model ID.
        session: boto3 session.
        bedrock_client: Bedrock runtime client.
        image_width: Width of output images. Defaults to 512 and must be a multiple of 64.
        image_height: Height of output images. Defaults to 512 and must be a multiple of 64.
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    # NOTE(review): unlike the sibling Bedrock drivers, this field is not kw_only;
    # left unchanged because adding kw_only=True would break positional callers.
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client(service_name="bedrock-runtime"), takes_self=True)
    )
    image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
    image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate an image from text prompts at the configured output dimensions."""
        request = self.image_generation_model_driver.text_to_image_request_parameters(
            prompts, self.image_width, self.image_height, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=self.image_width,
            height=self.image_height,
            model=self.model,
        )

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        """Generate a prompt-guided variation of `image`; output keeps the input's dimensions."""
        request = self.image_generation_model_driver.image_variation_request_parameters(
            prompts, image=image, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Inpaint `image` using `mask` and prompts; output keeps the input's dimensions."""
        request = self.image_generation_model_driver.image_inpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Outpaint `image` using `mask` and prompts; output keeps the input's dimensions."""
        request = self.image_generation_model_driver.image_outpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def _make_request(self, request: dict) -> bytes:
        """Invoke the Bedrock model with `request` and return the generated image bytes.

        Raises:
            ValueError: If the model driver cannot extract an image from the response.
        """
        response = self.bedrock_client.invoke_model(
            body=json.dumps(request), modelId=self.model, accept="application/json", contentType="application/json"
        )

        response_body = json.loads(response.get("body").read())

        try:
            image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
        except Exception as e:
            # Bug fix: this helper serves all generation paths (text-to-image, variation,
            # in/outpainting), so the message no longer says "Inpainting"; the original
            # exception is chained for debuggability.
            raise ValueError(f"Image generation failed: {e}") from e

        return image_bytes

bedrock_client: Any = field(default=Factory(lambda self: self.session.client(service_name='bedrock-runtime'), takes_self=True)) class-attribute instance-attribute

image_height: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

image_width: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

seed: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Inpaint `image` using `mask` and prompts; the result keeps the input's dimensions."""
    parameters = self.image_generation_model_driver.image_inpainting_request_parameters(
        prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
    )

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=self._make_request(parameters),
        format="png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Outpaint `image` using `mask` and prompts; the result keeps the input's dimensions."""
    parameters = self.image_generation_model_driver.image_outpainting_request_parameters(
        prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
    )

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=self._make_request(parameters),
        format="png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    """Generate a prompt-guided variation of `image`; the result keeps the input's dimensions."""
    parameters = self.image_generation_model_driver.image_variation_request_parameters(
        prompts, image=image, negative_prompts=negative_prompts, seed=self.seed
    )

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=self._make_request(parameters),
        format="png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate an image from text prompts at the driver's configured width/height."""
    parameters = self.image_generation_model_driver.text_to_image_request_parameters(
        prompts, self.image_width, self.image_height, negative_prompts=negative_prompts, seed=self.seed
    )

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=self._make_request(parameters),
        format="png",
        width=self.image_width,
        height=self.image_height,
        model=self.model,
    )

AmazonBedrockImageQueryDriver

Bases: BaseMultiModelImageQueryDriver

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
@define
class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver):
    """Multi-model image query driver backed by the Amazon Bedrock `invoke_model` API."""

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        """Ask the model `query` about `images` and return the processed text answer.

        Raises:
            ValueError: If the model response is empty or cannot be processed.
        """
        payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens)

        response = self.bedrock_client.invoke_model(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = json.loads(response.get("body").read())

        if response_body is None:
            raise ValueError("Model response is empty")

        try:
            return self.image_query_model_driver.process_output(response_body)
        except Exception as e:
            # Improvement: chain the original exception so the underlying parse
            # failure is preserved as __cause__.
            raise ValueError(f"Output is unable to be processed as returned {e}") from e

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    """Ask the model `query` about `images` and return the processed text answer.

    Raises:
        ValueError: If the model response is empty or cannot be processed.
    """
    payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens)

    response = self.bedrock_client.invoke_model(
        modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
    )

    response_body = json.loads(response.get("body").read())

    if response_body is None:
        raise ValueError("Model response is empty")

    try:
        return self.image_query_model_driver.process_output(response_body)
    except Exception as e:
        # Improvement: chain the original exception so the underlying parse
        # failure is preserved as __cause__.
        raise ValueError(f"Output is unable to be processed as returned {e}") from e

AmazonBedrockPromptDriver

Bases: BasePromptDriver

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@define
class AmazonBedrockPromptDriver(BasePromptDriver):
    """Prompt driver backed by the Amazon Bedrock Converse API."""

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )
    # Extra model-specific request fields, forwarded verbatim as additionalModelRequestFields.
    additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Run the prompt stack through `converse` and return the reply as a TextArtifact."""
        response = self.bedrock_client.converse(**self._base_params(prompt_stack))

        output_message = response["output"]["message"]
        # Only the first content block's text is surfaced.
        output_content = output_message["content"][0]["text"]

        return TextArtifact(output_content)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Stream the prompt stack through `converse_stream`, yielding one artifact per text delta."""
        response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack))

        stream = response.get("stream")
        if stream is not None:
            for event in stream:
                # Only text deltas are yielded; other stream event types are ignored.
                if "contentBlockDelta" in event:
                    yield TextArtifact(event["contentBlockDelta"]["delta"]["text"])
        else:
            raise Exception("model response is empty")

    def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict:
        """Convert one prompt-stack input into Converse message form.

        System inputs map to the bare `{"text": ...}` shape used by the `system`
        parameter; user/assistant inputs map to role'd content messages.
        """
        content = [{"text": prompt_input.content}]

        if prompt_input.is_system():
            return {"text": prompt_input.content}
        elif prompt_input.is_assistant():
            return {"role": "assistant", "content": content}
        else:
            return {"role": "user", "content": content}

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments shared by `converse` and `converse_stream`."""
        # System inputs with empty content are filtered out.
        system_messages = [
            self._prompt_stack_input_to_message(input)
            for input in prompt_stack.inputs
            if input.is_system() and input.content
        ]
        messages = [
            self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs if not input.is_system()
        ]

        return {
            "modelId": self.model,
            "messages": messages,
            "system": system_messages,
            "inferenceConfig": {"temperature": self.temperature},
            "additionalModelRequestFields": self.additional_model_request_fields,
        }

additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True) class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Run the prompt stack through the Converse API and return the reply text."""
    response = self.bedrock_client.converse(**self._base_params(prompt_stack))

    # Converse returns a single assistant message; surface its first content block's text.
    return TextArtifact(response["output"]["message"]["content"][0]["text"])

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Stream the prompt stack through ConverseStream, yielding one artifact per text delta."""
    response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack))

    stream = response.get("stream")
    if stream is None:
        raise Exception("model response is empty")

    for event in stream:
        # Only text deltas are surfaced; other stream event types are skipped.
        if "contentBlockDelta" in event:
            yield TextArtifact(event["contentBlockDelta"]["delta"]["text"])

AmazonBedrockTitanEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

Name Type Description
model str

Embedding model name. Defaults to DEFAULT_MODEL.

tokenizer BaseTokenizer

Optionally provide custom BedrockTitanTokenizer.

session Session

Optionally provide custom boto3.Session.

bedrock_client Any

Optionally provide custom bedrock-runtime client.

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@define
class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
    """Amazon Bedrock embedding driver for Titan embedding models.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
        session: Optionally provide custom `boto3.Session`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "amazon.titan-embed-text-v1"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    # boto3 is resolved lazily so it is only required when this driver is instantiated.
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single text chunk via Bedrock and return its embedding vector."""
        payload = {"inputText": chunk}

        response = self.bedrock_client.invoke_model(
            body=json.dumps(payload), modelId=self.model, accept="application/json", contentType="application/json"
        )
        response_body = json.loads(response.get("body").read())

        # Titan returns a single vector under "embedding" (singular), unlike Cohere.
        return response_body.get("embedding")

DEFAULT_MODEL = 'amazon.titan-embed-text-v1' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single text chunk via the Bedrock Titan model and return its vector."""
    request_body = json.dumps({"inputText": chunk})

    raw_response = self.bedrock_client.invoke_model(
        body=request_body, modelId=self.model, accept="application/json", contentType="application/json"
    )

    # Titan returns a single vector under the "embedding" key.
    return json.loads(raw_response.get("body").read()).get("embedding")

AmazonDynamoDbConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
@define
class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
    """Stores conversation memory as a JSON string in a single Amazon DynamoDB item attribute."""

    # boto3 is resolved lazily so it is only required when this driver is instantiated.
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    # Table and key configuration identifying the item that holds the memory JSON.
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    partition_key: str = field(kw_only=True, metadata={"serializable": True})
    value_attribute_key: str = field(kw_only=True, metadata={"serializable": True})
    partition_key_value: str = field(kw_only=True, metadata={"serializable": True})

    # Resolved DynamoDB Table resource; populated in __attrs_post_init__, not by callers.
    table: Any = field(init=False)

    def __attrs_post_init__(self) -> None:
        """Resolve the DynamoDB Table resource once attrs initialization is done."""
        dynamodb = self.session.resource("dynamodb")

        self.table = dynamodb.Table(self.table_name)

    def store(self, memory: BaseConversationMemory) -> None:
        """Serialize `memory` to JSON and upsert it into the configured item attribute."""
        self.table.update_item(
            Key={self.partition_key: self.partition_key_value},
            UpdateExpression="set #attr = :value",
            # Expression placeholders avoid clashes with DynamoDB reserved words.
            ExpressionAttributeNames={"#attr": self.value_attribute_key},
            ExpressionAttributeValues={":value": memory.to_json()},
        )

    def load(self) -> Optional[BaseConversationMemory]:
        """Load the stored conversation memory, or return None when nothing is stored."""
        response = self.table.get_item(Key={self.partition_key: self.partition_key_value})

        if "Item" in response and self.value_attribute_key in response["Item"]:
            memory_value = response["Item"][self.value_attribute_key]

            memory = BaseConversationMemory.from_json(memory_value)

            # Re-attach this driver so subsequent stores write back to the same item.
            memory.driver = self

            return memory
        else:
            return None

partition_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

partition_key_value: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

table: Any = field(init=False) class-attribute instance-attribute

table_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

value_attribute_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def __attrs_post_init__(self) -> None:
    """Resolve the DynamoDB Table resource once attrs initialization is done."""
    self.table = self.session.resource("dynamodb").Table(self.table_name)

load()

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def load(self) -> Optional[BaseConversationMemory]:
    """Load the stored conversation memory, or return None when nothing is stored."""
    response = self.table.get_item(Key={self.partition_key: self.partition_key_value})

    # Guard clauses: missing item or missing attribute both mean "no stored memory".
    if "Item" not in response or self.value_attribute_key not in response["Item"]:
        return None

    memory = BaseConversationMemory.from_json(response["Item"][self.value_attribute_key])
    # Re-attach this driver so subsequent stores write back to the same item.
    memory.driver = self
    return memory

store(memory)

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def store(self, memory: BaseConversationMemory) -> None:
    """Serialize `memory` to JSON and upsert it into the configured item attribute."""
    update_arguments = {
        "Key": {self.partition_key: self.partition_key_value},
        "UpdateExpression": "set #attr = :value",
        # Expression placeholders avoid clashes with DynamoDB reserved words.
        "ExpressionAttributeNames": {"#attr": self.value_attribute_key},
        "ExpressionAttributeValues": {":value": memory.to_json()},
    }
    self.table.update_item(**update_arguments)

AmazonOpenSearchVectorStoreDriver

Bases: OpenSearchVectorStoreDriver

A Vector Store Driver for Amazon OpenSearch.

Attributes:

Name Type Description
session Session

The boto3 session to use.

service str

Service name for AWS Signature v4. Values can be 'es' or 'aoss' for OpenSearch Serverless. Defaults to 'es'.

http_auth str | tuple[str, str]

The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.

client OpenSearch

An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.

Source code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
@define
class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver):
    """A Vector Store Driver for Amazon OpenSearch.

    Attributes:
        session: The boto3 session to use.
        service: Service name for AWS Signature v4. Values can be 'es' or 'aoss' for OpenSearch Serverless. Defaults to 'es'.
        http_auth: The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
        client: An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.
    """

    session: Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    service: str = field(default="es", kw_only=True)
    # Default auth signs requests with SigV4 using the session's credentials and region.
    http_auth: str | tuple[str, str] = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").AWSV4SignerAuth(
                self.session.get_credentials(), self.session.region_name, self.service
            ),
            takes_self=True,
        )
    )

    client: OpenSearch = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").OpenSearch(
                hosts=[{"host": self.host, "port": self.port}],
                http_auth=self.http_auth,
                use_ssl=self.use_ssl,
                verify_certs=self.verify_certs,
                connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection,
            ),
            takes_self=True,
        )
    )

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in OpenSearch.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
        Metadata associated with the vector can also be provided.
        """

        # A missing/empty id is derived deterministically from the vector contents.
        vector_id = vector_id if vector_id else str_to_hash(str(vector))
        doc = {"vector": vector, "namespace": namespace, "metadata": meta}
        doc.update(kwargs)
        # The 'aoss' (Serverless) path indexes without an explicit document id.
        if self.service == "aoss":
            response = self.client.index(index=self.index_name, body=doc)
        else:
            response = self.client.index(index=self.index_name, id=vector_id, body=doc)

        return response["_id"]

client: OpenSearch = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').OpenSearch(hosts=[{'host': self.host, 'port': self.port}], http_auth=self.http_auth, use_ssl=self.use_ssl, verify_certs=self.verify_certs, connection_class=import_optional_dependency('opensearchpy').RequestsHttpConnection), takes_self=True)) class-attribute instance-attribute

http_auth: str | tuple[str, str] = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').AWSV4SignerAuth(self.session.get_credentials(), self.session.region_name, self.service), takes_self=True)) class-attribute instance-attribute

service: str = field(default='es', kw_only=True) class-attribute instance-attribute

session: Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in OpenSearch.

If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted. Metadata associated with the vector can also be provided.

Source code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in OpenSearch.

    If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
    Metadata associated with the vector can also be provided.
    """
    if not vector_id:
        # Derive a deterministic id from the vector contents when none is supplied.
        vector_id = str_to_hash(str(vector))

    document = {"vector": vector, "namespace": namespace, "metadata": meta, **kwargs}

    # The "aoss" (serverless) service does not take a client-supplied document id.
    id_kwargs = {} if self.service == "aoss" else {"id": vector_id}
    response = self.client.index(index=self.index_name, body=document, **id_kwargs)

    return response["_id"]

AmazonRedshiftSqlDriver

Bases: BaseSqlDriver

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@define
class AmazonRedshiftSqlDriver(BaseSqlDriver):
    """SQL driver that executes queries through the Amazon Redshift Data API.

    Attributes:
        database: Name of the Redshift database to query.
        session: boto3 session used to create the Redshift Data API client.
        cluster_identifier: Identifier of a provisioned cluster (mutually exclusive with `workgroup_name`).
        workgroup_name: Name of a Redshift Serverless workgroup (mutually exclusive with `cluster_identifier`).
        db_user: Optional database user name.
        database_credentials_secret_arn: Optional Secrets Manager ARN holding database credentials.
        wait_for_query_completion_sec: Polling interval, in seconds, while waiting for a statement to finish.
        client: Redshift Data API client.
    """

    database: str = field(kw_only=True)
    session: boto3.Session = field(kw_only=True)
    cluster_identifier: Optional[str] = field(default=None, kw_only=True)
    workgroup_name: Optional[str] = field(default=None, kw_only=True)
    db_user: Optional[str] = field(default=None, kw_only=True)
    database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True)
    wait_for_query_completion_sec: float = field(default=0.3, kw_only=True)
    client: Any = field(
        default=Factory(lambda self: self.session.client("redshift-data"), takes_self=True), kw_only=True
    )

    @workgroup_name.validator  # pyright: ignore
    def validate_params(self, _, workgroup_name: Optional[str]) -> None:
        """Ensures exactly one of `cluster_identifier` or `workgroup_name` is configured."""
        if not self.cluster_identifier and not self.workgroup_name:
            raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
        elif self.cluster_identifier and self.workgroup_name:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

    @classmethod
    def _process_rows_from_records(cls, records) -> list[list]:
        # Each Data API cell is a single-key dict (e.g. {"stringValue": "x"}); unwrap to the bare value.
        return [[c[list(c.keys())[0]] for c in r] for r in records]

    @classmethod
    def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
        return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

    @classmethod
    def _process_columns_from_column_metadata(cls, meta) -> list:
        return [k["name"] for k in meta]

    @classmethod
    def _post_process(cls, meta, records) -> list[dict[str, Any]]:
        """Converts raw column metadata and records into a list of column-name-keyed row dicts."""
        columns = cls._process_columns_from_column_metadata(meta)
        rows = cls._process_rows_from_records(records)
        return cls._process_cells_from_rows_and_columns(columns, rows)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        """Runs `query` and wraps each resulting row dict in a `RowResult`; None when there are no rows."""
        rows = self.execute_query_raw(query)
        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        else:
            return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        """Executes `query` via the Data API, polling until completion.

        Returns:
            Row dicts keyed by column name on success, or None when the statement
            FAILED or was ABORTED.
        """
        function_kwargs = {"Sql": query, "Database": self.database}
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn

        response = self.client.execute_statement(**function_kwargs)
        response_id = response["Id"]

        statement = self.client.describe_statement(Id=response_id)

        # Poll until the statement leaves its in-progress states.
        while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
            time.sleep(self.wait_for_query_completion_sec)
            statement = self.client.describe_statement(Id=response_id)

        if statement["Status"] == "FINISHED":
            statement_result = self.client.get_statement_result(Id=response_id)
            results = statement_result.get("Records", [])

            while "NextToken" in statement_result:
                statement_result = self.client.get_statement_result(
                    Id=response_id, NextToken=statement_result["NextToken"]
                )
                # Bug fix: accumulate records from the paginated result page, not from
                # the `execute_statement` response (which never contains "Records").
                results = results + statement_result.get("Records", [])

            return self._post_process(statement_result["ColumnMetadata"], results)

        elif statement["Status"] in ["FAILED", "ABORTED"]:
            return None

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        """Returns a string listing of `table_name`'s column names via the Data API."""
        function_kwargs = {"Database": self.database, "Table": table_name}
        if schema:
            function_kwargs["Schema"] = schema
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn
        response = self.client.describe_table(**function_kwargs)
        return str([col["name"] for col in response["ColumnList"]])

client: Any = field(default=Factory(lambda self: self.session.client('redshift-data'), takes_self=True), kw_only=True) class-attribute instance-attribute

cluster_identifier: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

database: str = field(kw_only=True) class-attribute instance-attribute

database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

db_user: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(kw_only=True) class-attribute instance-attribute

wait_for_query_completion_sec: float = field(default=0.3, kw_only=True) class-attribute instance-attribute

workgroup_name: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

execute_query(query)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    """Runs `query` and wraps each returned row in a `BaseSqlDriver.RowResult`.

    Returns None when the query produced no rows or failed.
    """
    rows = self.execute_query_raw(query)
    return [BaseSqlDriver.RowResult(row) for row in rows] if rows else None

execute_query_raw(query)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    """Executes `query` through the Redshift Data API and returns rows as dicts.

    Polls `describe_statement` until the statement leaves its in-progress states,
    then pages through `get_statement_result` accumulating every record.

    Returns:
        A list of column-name-keyed row dicts on success, or None when the
        statement FAILED or was ABORTED.
    """
    function_kwargs = {"Sql": query, "Database": self.database}
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.execute_statement(**function_kwargs)
    response_id = response["Id"]

    statement = self.client.describe_statement(Id=response_id)

    # Poll until the statement leaves its in-progress states.
    while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
        time.sleep(self.wait_for_query_completion_sec)
        statement = self.client.describe_statement(Id=response_id)

    if statement["Status"] == "FINISHED":
        statement_result = self.client.get_statement_result(Id=response_id)
        results = statement_result.get("Records", [])

        while "NextToken" in statement_result:
            statement_result = self.client.get_statement_result(
                Id=response_id, NextToken=statement_result["NextToken"]
            )
            # Bug fix: accumulate records from the paginated result page, not from
            # the `execute_statement` response (which never contains "Records").
            results = results + statement_result.get("Records", [])

        return self._post_process(statement_result["ColumnMetadata"], results)

    elif statement["Status"] in ["FAILED", "ABORTED"]:
        return None

get_table_schema(table_name, schema=None)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    """Returns a string listing of column names for `table_name`, looked up via the Data API."""
    function_kwargs = {"Database": self.database, "Table": table_name}
    if schema:
        function_kwargs["Schema"] = schema

    # Forward only the identity/credential parameters that are actually configured.
    optional_params = {
        "WorkgroupName": self.workgroup_name,
        "ClusterIdentifier": self.cluster_identifier,
        "DbUser": self.db_user,
        "SecretArn": self.database_credentials_secret_arn,
    }
    function_kwargs.update({key: value for key, value in optional_params.items() if value})

    response = self.client.describe_table(**function_kwargs)
    return str([column["name"] for column in response["ColumnList"]])

validate_params(_, workgroup_name)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@workgroup_name.validator  # pyright: ignore
def validate_params(self, _, workgroup_name: Optional[str]) -> None:
    """Ensures exactly one of `cluster_identifier` or `workgroup_name` is configured."""
    has_cluster = bool(self.cluster_identifier)
    has_workgroup = bool(self.workgroup_name)
    if not has_cluster and not has_workgroup:
        raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
    if has_cluster and has_workgroup:
        raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

AmazonS3FileManagerDriver

Bases: BaseFileManagerDriver

AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

Attributes:

Name Type Description
session Session

The boto3 session to use for S3 operations.

bucket str

The name of the S3 bucket.

workdir str

The absolute working directory (must start with "/"). List, load, and save operations will be performed relative to this directory.

s3_client Any

The S3 client to use for S3 operations.

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@define
class AmazonS3FileManagerDriver(BaseFileManagerDriver):
    """
    AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

    Attributes:
        session: The boto3 session to use for S3 operations.
        bucket: The name of the S3 bucket.
        workdir: The absolute working directory (must start with "/"). List, load, and save
            operations will be performed relative to this directory.
        s3_client: The S3 client to use for S3 operations.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bucket: str = field(kw_only=True)
    workdir: str = field(default="/", kw_only=True)
    s3_client: Any = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True)

    @workdir.validator  # pyright: ignore
    def validate_workdir(self, _, workdir: str) -> None:
        """Rejects non-absolute working directories at construction time."""
        if not Path(workdir).is_absolute():
            raise ValueError("Workdir must be an absolute path")

    def try_list_files(self, path: str) -> list[str]:
        """Lists the immediate children (files and "directories") of `path`.

        Raises:
            NotADirectoryError: if `path` exists as a file rather than a directory.
            FileNotFoundError: if nothing exists at `path`.
        """
        full_key = self._to_dir_full_key(path)
        files_and_dirs = self._list_files_and_dirs(full_key)
        if len(files_and_dirs) == 0:
            # An empty listing is ambiguous: the key may exist as a *file* (without
            # the trailing slash) or not exist at all; probe without the slash to
            # tell the two cases apart.
            if len(self._list_files_and_dirs(full_key.rstrip("/"), max_items=1)) > 0:
                raise NotADirectoryError
            else:
                raise FileNotFoundError
        return files_and_dirs

    def try_load_file(self, path: str) -> bytes:
        """Returns the raw bytes of the S3 object at `path`.

        Raises:
            IsADirectoryError: if `path` resolves to a directory key.
            FileNotFoundError: if no object exists at `path`.
        """
        botocore = import_optional_dependency("botocore")
        full_key = self._to_full_key(path)

        if self._is_a_directory(full_key):
            raise IsADirectoryError

        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=full_key)
            return response["Body"].read()
        except botocore.exceptions.ClientError as e:
            # Translate S3 "missing key" errors into the builtin filesystem exception.
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                raise FileNotFoundError
            else:
                raise e

    def try_save_file(self, path: str, value: bytes):
        """Writes `value` to the S3 object at `path`; refuses to write to a directory key."""
        full_key = self._to_full_key(path)
        if self._is_a_directory(full_key):
            raise IsADirectoryError
        self.s3_client.put_object(Bucket=self.bucket, Key=full_key, Body=value)

    def _to_full_key(self, path: str) -> str:
        """Joins `path` onto `workdir` and normalizes it into an S3 key (no leading slash)."""
        path = path.lstrip("/")
        full_key = os.path.join(self.workdir, path)
        # Need to keep the trailing slash if it was there,
        # because it means the path is a directory.
        ended_with_slash = path.endswith("/")
        full_key = os.path.normpath(full_key)
        if ended_with_slash:
            full_key += "/"
        return full_key.lstrip("/")

    def _to_dir_full_key(self, path: str) -> str:
        """Like `_to_full_key`, but guarantees the directory-style trailing slash."""
        full_key = self._to_full_key(path)
        # S3 "directories" always end with a slash, except for the root.
        if full_key != "" and not full_key.endswith("/"):
            full_key += "/"
        return full_key

    def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
        """Returns names (relative to `full_key`) of objects and common prefixes under it.

        Accepts `max_items` in kwargs to cap the number of results via S3 pagination.
        """
        max_items = kwargs.get("max_items")
        pagination_config = {}
        if max_items is not None:
            pagination_config["MaxItems"] = max_items

        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(
            Bucket=self.bucket, Prefix=full_key, Delimiter="/", PaginationConfig=pagination_config
        )
        files_and_dirs = []
        for page in pages:
            # CommonPrefixes are the immediate "subdirectories" under the prefix.
            for obj in page.get("CommonPrefixes", []):
                prefix = obj.get("Prefix")
                dir = prefix[len(full_key) :].rstrip("/")
                files_and_dirs.append(dir)

            # Contents are the objects ("files") directly under the prefix.
            for obj in page.get("Contents", []):
                key = obj.get("Key")
                file = key[len(full_key) :]
                files_and_dirs.append(file)
        return files_and_dirs

    def _is_a_directory(self, full_key: str) -> bool:
        """Returns True when `full_key` denotes a directory: the root, a trailing-slash key,
        or a non-existent key that has objects beneath it."""
        botocore = import_optional_dependency("botocore")
        if full_key == "" or full_key.endswith("/"):
            return True

        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=full_key)
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                # No object at this exact key: treat it as a directory iff something
                # exists beneath it.
                return len(self._list_files_and_dirs(full_key, max_items=1)) > 0
            else:
                raise e

        return False

bucket: str = field(kw_only=True) class-attribute instance-attribute

s3_client: Any = field(default=Factory(lambda self: self.session.client('s3'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

workdir: str = field(default='/', kw_only=True) class-attribute instance-attribute

try_list_files(path)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_list_files(self, path: str) -> list[str]:
    """Lists files and subdirectories under `path`, raising if it is a file or does not exist."""
    full_key = self._to_dir_full_key(path)
    entries = self._list_files_and_dirs(full_key)
    if entries:
        return entries
    # Empty listing: distinguish "path is actually a file" from "path does not exist"
    # by probing the key without its trailing slash.
    if self._list_files_and_dirs(full_key.rstrip("/"), max_items=1):
        raise NotADirectoryError
    raise FileNotFoundError

try_load_file(path)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_load_file(self, path: str) -> bytes:
    """Downloads the S3 object at `path` and returns its raw bytes."""
    botocore = import_optional_dependency("botocore")
    full_key = self._to_full_key(path)

    if self._is_a_directory(full_key):
        raise IsADirectoryError

    try:
        return self.s3_client.get_object(Bucket=self.bucket, Key=full_key)["Body"].read()
    except botocore.exceptions.ClientError as e:
        # Translate S3 "missing key" errors into the builtin filesystem exception.
        if e.response["Error"]["Code"] not in {"NoSuchKey", "404"}:
            raise e
        raise FileNotFoundError

try_save_file(path, value)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_save_file(self, path: str, value: bytes):
    """Uploads `value` to S3 at `path`, refusing to overwrite a directory key."""
    destination_key = self._to_full_key(path)
    if self._is_a_directory(destination_key):
        raise IsADirectoryError
    self.s3_client.put_object(Bucket=self.bucket, Key=destination_key, Body=value)

validate_workdir(_, workdir)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@workdir.validator  # pyright: ignore
def validate_workdir(self, _, workdir: str) -> None:
    """Rejects relative working directories; `workdir` must be an absolute path."""
    if Path(workdir).is_absolute():
        return
    raise ValueError("Workdir must be an absolute path")

AmazonSageMakerJumpstartEmbeddingDriver

Bases: BaseEmbeddingDriver

Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that computes embeddings via a SageMaker runtime endpoint."""

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
    )
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embeds `chunk` by invoking the configured endpoint and returns the embedding vector.

        Raises:
            ValueError: when the response lacks an "embedding" key or the embedding is empty.
        """
        payload = {"text_inputs": chunk, "mode": "embedding"}

        component_kwargs = (
            {}
            if self.inference_component_name is None
            else {"InferenceComponentName": self.inference_component_name}
        )
        endpoint_response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=json.dumps(payload).encode("utf-8"),
            CustomAttributes=self.custom_attributes,
            **component_kwargs,
        )

        response = json.loads(endpoint_response.get("Body").read().decode("utf-8"))

        if "embedding" not in response:
            raise ValueError("invalid response from model")

        embedding = response["embedding"]
        if not embedding:
            raise ValueError("model response is empty")

        # A nested list means the endpoint returned a batch; unwrap the first item.
        return embedding[0] if isinstance(embedding[0], list) else embedding

custom_attributes: str = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda self: self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embeds `chunk` by invoking the configured SageMaker endpoint.

    Raises:
        ValueError: when the response lacks an "embedding" key or the embedding is empty.
    """
    payload = {"text_inputs": chunk, "mode": "embedding"}

    component_kwargs = (
        {}
        if self.inference_component_name is None
        else {"InferenceComponentName": self.inference_component_name}
    )
    endpoint_response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.endpoint,
        ContentType="application/json",
        Body=json.dumps(payload).encode("utf-8"),
        CustomAttributes=self.custom_attributes,
        **component_kwargs,
    )

    response = json.loads(endpoint_response.get("Body").read().decode("utf-8"))

    if "embedding" not in response:
        raise ValueError("invalid response from model")

    embedding = response["embedding"]
    if not embedding:
        raise ValueError("model response is empty")

    # A nested list means the endpoint returned a batch; unwrap the first item.
    return embedding[0] if isinstance(embedding[0], list) else embedding

AmazonSageMakerJumpstartPromptDriver

Bases: BasePromptDriver

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@define
class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
    """Prompt driver backed by a SageMaker JumpStart inference endpoint.

    Attributes:
        session: boto3 session used to build the SageMaker runtime client.
        sagemaker_client: SageMaker runtime client used to invoke the endpoint.
        endpoint: Name of the SageMaker endpoint to invoke.
        custom_attributes: Custom attributes forwarded on each invocation (accepts the EULA by default).
        inference_component_name: Optional inference component to target on the endpoint.
        stream: Must remain False; streaming is not supported by this driver.
        max_tokens: Maximum number of new tokens to generate.
        tokenizer: Hugging Face tokenizer used to render the chat template.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
    )
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True
        ),
        kw_only=True,
    )

    @stream.validator  # pyright: ignore
    def validate_stream(self, _, stream):
        """Rejects `stream=True` at construction time; this driver cannot stream."""
        if stream:
            raise ValueError("streaming is not supported")

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Invokes the endpoint with the rendered chat prompt and returns the generated text.

        Raises:
            ValueError: if the model returns an empty list response.
        """
        payload = {"inputs": self._to_model_input(prompt_stack), "parameters": self._to_model_params(prompt_stack)}

        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=json.dumps(payload),
            CustomAttributes=self.custom_attributes,
            # Only pass InferenceComponentName when one is configured.
            **(
                {"InferenceComponentName": self.inference_component_name}
                if self.inference_component_name is not None
                else {}
            ),
        )

        decoded_body = json.loads(response["Body"].read().decode("utf8"))

        # The decoded body may be a list of generations or a single generation object.
        if isinstance(decoded_body, list):
            if decoded_body:
                return TextArtifact(decoded_body[0]["generated_text"])
            else:
                raise ValueError("model response is empty")
        else:
            return TextArtifact(decoded_body["generated_text"])

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Always raises; streaming is not supported by this driver."""
        raise NotImplementedError("streaming is not supported")

    def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict:
        """Maps one prompt stack input to a chat-template message dict."""
        return {"role": prompt_input.role, "content": prompt_input.content}

    def _to_model_input(self, prompt_stack: PromptStack) -> str:
        """Renders the prompt stack into a single prompt string via the tokenizer's chat template."""
        prompt = self.tokenizer.tokenizer.apply_chat_template(
            [self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs],
            tokenize=False,
            add_generation_prompt=True,
        )

        # apply_chat_template's return type is not guaranteed to be str; reject anything else.
        if isinstance(prompt, str):
            return prompt
        else:
            raise ValueError("Invalid output type.")

    def _to_model_params(self, prompt_stack: PromptStack) -> dict:
        """Builds the generation parameters sent alongside the prompt."""
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_tokens,
            "do_sample": True,
            "eos_token_id": self.tokenizer.tokenizer.eos_token_id,
            "stop_strings": self.tokenizer.stop_sequences,
            "return_full_text": False,
        }

custom_attributes: str = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

max_tokens: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda self: self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    payload = {"inputs": self._to_model_input(prompt_stack), "parameters": self._to_model_params(prompt_stack)}

    response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.endpoint,
        ContentType="application/json",
        Body=json.dumps(payload),
        CustomAttributes=self.custom_attributes,
        **(
            {"InferenceComponentName": self.inference_component_name}
            if self.inference_component_name is not None
            else {}
        ),
    )

    decoded_body = json.loads(response["Body"].read().decode("utf8"))

    if isinstance(decoded_body, list):
        if decoded_body:
            return TextArtifact(decoded_body[0]["generated_text"])
        else:
            raise ValueError("model response is empty")
    else:
        return TextArtifact(decoded_body["generated_text"])

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Always raises; streaming is not supported by this driver."""
    raise NotImplementedError("streaming is not supported")

validate_stream(_, stream)

Source code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@stream.validator  # pyright: ignore
def validate_stream(self, _, stream):
    """Rejects `stream=True`; this driver cannot stream responses."""
    if not stream:
        return
    raise ValueError("streaming is not supported")

AmazonSqsEventListenerDriver

Bases: BaseEventListenerDriver

Source code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
@define
class AmazonSqsEventListenerDriver(BaseEventListenerDriver):
    """Event listener driver that publishes event payloads to an Amazon SQS queue."""

    queue_url: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sqs_client: Any = field(default=Factory(lambda self: self.session.client("sqs"), takes_self=True))

    def try_publish_event_payload(self, event_payload: dict) -> None:
        """Sends a single event payload to the queue as a JSON message."""
        message_body = json.dumps(event_payload)
        self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=message_body)

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        """Sends a batch of payloads in one SQS request, keyed by each payload's `id` field."""
        entries = []
        for event_payload in event_payload_batch:
            entries.append({"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)})

        self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)

queue_url: str = field(kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

sqs_client: Any = field(default=Factory(lambda self: self.session.client('sqs'), takes_self=True)) class-attribute instance-attribute

try_publish_event_payload(event_payload)

Source code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    """Serializes `event_payload` to JSON and sends it to the configured SQS queue."""
    message_body = json.dumps(event_payload)
    self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=message_body)

try_publish_event_payload_batch(event_payload_batch)

Source code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """Sends every payload in `event_payload_batch` to SQS in a single batch request.

    Each entry's `Id` comes from the payload's own `id` field.
    """
    entries = []
    for payload in event_payload_batch:
        entries.append({"Id": str(payload["id"]), "MessageBody": json.dumps(payload)})

    self.sqs_client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)

AnthropicImageQueryDriver

Bases: BaseImageQueryDriver

Attributes:

Name Type Description
api_key Optional[str]

Anthropic API key.

model str

Anthropic model name.

client Any

Custom Anthropic client.

Source code in griptape/drivers/image_query/anthropic_image_query_driver.py
@define
class AnthropicImageQueryDriver(BaseImageQueryDriver):
    """
    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Any = field(
        default=Factory(
            lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        """Sends `query` plus `images` to the Anthropic messages API and returns the first text block."""
        if self.max_tokens is None:
            raise TypeError("max_output_tokens can't be empty")

        response = self.client.messages.create(**self._base_params(query, images))

        if not response.content:
            raise ValueError("Response content is empty")

        return TextArtifact(response.content[0].text)

    def _base_params(self, text_query: str, images: list[ImageArtifact]):
        """Builds the request body: one image message per artifact, then the text query."""
        content = [self._construct_image_message(image) for image in images]
        content.append(self._construct_text_message(text_query))

        return {
            "model": self.model,
            "messages": self._construct_messages(content),
            "max_tokens": self.max_tokens,
        }

    def _construct_image_message(self, image_data: ImageArtifact) -> dict:
        """Wraps an image artifact as a base64 image content block."""
        return {
            "source": {"data": image_data.base64, "media_type": image_data.mime_type, "type": "base64"},
            "type": "image",
        }

    def _construct_text_message(self, query: str) -> dict:
        """Wraps the query string as a text content block."""
        return {"text": query, "type": "text"}

    def _construct_messages(self, content: list) -> list:
        """Packs the content blocks into a single user message."""
        return [{"content": content, "role": "user"}]

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

client: Any = field(default=Factory(lambda self: import_optional_dependency('anthropic').Anthropic(api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/anthropic_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    """Sends `query` plus `images` to the Anthropic messages API and returns the first text block.

    Raises:
        TypeError: when `max_tokens` is not configured.
        ValueError: when the API response has no content blocks.
    """
    if self.max_tokens is None:
        raise TypeError("max_output_tokens can't be empty")

    response = self.client.messages.create(**self._base_params(query, images))

    if not response.content:
        raise ValueError("Response content is empty")

    return TextArtifact(response.content[0].text)

AnthropicPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_key Optional[str]

Anthropic API key.

model str

Anthropic model name.

client Any

Custom Anthropic client.

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
@define
class AnthropicPromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
    """

    # API key is excluded from serialization so it never leaks into saved configs.
    api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    # Client is built lazily from `api_key`; `takes_self=True` lets the factory
    # read this instance's own fields.
    client: Any = field(
        default=Factory(
            lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )
    # Tokenizer defaults to one matched to the configured model.
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
    top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        # Single (non-streaming) request; only the first content block's text is returned.
        response = self.client.messages.create(**self._base_params(prompt_stack))

        return TextArtifact(value=response.content[0].text)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        # Streaming request: yield one artifact per text-delta event, ignoring
        # other event types (message start/stop, etc.).
        response = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

        for chunk in response:
            if chunk.type == "content_block_delta":
                yield TextArtifact(value=chunk.delta.text)

    def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict:
        # Map a prompt-stack input onto a chat-style {role, content} dict.
        content = prompt_input.content

        if prompt_input.is_system():
            return {"role": "system", "content": content}
        elif prompt_input.is_assistant():
            return {"role": "assistant", "content": content}
        else:
            return {"role": "user", "content": content}

    def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        # Anthropic takes the system prompt as a top-level `system` parameter,
        # not as a message, so system inputs are filtered out of `messages`
        # and the first one (if any) is passed separately.
        messages = [
            self._prompt_stack_input_to_message(prompt_input)
            for prompt_input in prompt_stack.inputs
            if not prompt_input.is_system()
        ]
        system = next((self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs if i.is_system()), None)

        if system is None:
            return {"messages": messages}
        else:
            return {"messages": messages, "system": system["content"]}

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        # Assemble the keyword arguments for `client.messages.create`.
        return {
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
            **self._prompt_stack_to_model_input(prompt_stack),
        }

api_key: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

client: Any = field(default=Factory(lambda self: import_optional_dependency('anthropic').Anthropic(api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

max_tokens: int = field(default=1000, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

top_k: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_p: float = field(default=0.999, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Run the prompt stack as one non-streaming request; return the first content block's text."""
    params = self._base_params(prompt_stack)
    response = self.client.messages.create(**params)
    first_block = response.content[0]

    return TextArtifact(value=first_block.text)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Stream the response, yielding one artifact per content-block delta event."""
    stream = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

    for event in stream:
        if event.type != "content_block_delta":
            continue
        yield TextArtifact(value=event.delta.text)

AwsIotCoreEventListenerDriver

Bases: BaseEventListenerDriver

Source code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
@define
class AwsIotCoreEventListenerDriver(BaseEventListenerDriver):
    # Publishes framework event payloads as JSON messages to an AWS IoT Core
    # topic through the boto3 `iot-data` client.
    # NOTE(review): `iot_endpoint` is declared but not referenced by the visible
    # code — confirm whether the session/client is expected to use it.
    iot_endpoint: str = field(kw_only=True)
    topic: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    # NOTE(review): unlike every other field, this one is not kw_only — confirm
    # whether that is intentional.
    iotdata_client: Any = field(default=Factory(lambda self: self.session.client("iot-data"), takes_self=True))

    def try_publish_event_payload(self, event_payload: dict) -> None:
        # One event -> one JSON object message on `topic`.
        self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload))

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        # The whole batch is serialized into a single JSON-array message.
        self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload_batch))

iot_endpoint: str = field(kw_only=True) class-attribute instance-attribute

iotdata_client: Any = field(default=Factory(lambda self: self.session.client('iot-data'), takes_self=True)) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

topic: str = field(kw_only=True) class-attribute instance-attribute

try_publish_event_payload(event_payload)

Source code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload))

try_publish_event_payload_batch(event_payload_batch)

Source code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """Serialize the whole batch as one JSON array and publish it on the configured topic."""
    serialized = json.dumps(event_payload_batch)
    self.iotdata_client.publish(topic=self.topic, payload=serialized)

AzureMongoDbVectorStoreDriver

Bases: MongoDbAtlasVectorStoreDriver

A Vector Store Driver for CosmosDB with MongoDB vCore API.

Source code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
@define
class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver):
    """A Vector Store Driver for CosmosDB with MongoDB vCore API."""

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        offset: Optional[int] = None,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Queries the MongoDB collection for documents that match the provided query string.

        Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
        """
        collection = self.get_collection()

        # Using the embedding driver to convert the query string into a vector
        vector = self.embedding_driver.embed_string(query)

        count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        # NOTE(review): `offset` is normalized here but never applied to the
        # pipeline below — confirm whether offsetting was intended.
        offset = offset if offset else 0

        pipeline = []

        # Cosmos DB vector search stage; `k` caps the candidate pool at
        # MAX_NUM_CANDIDATES regardless of the requested count multiplier.
        pipeline.append(
            {
                "$search": {
                    "cosmosSearch": {
                        "vector": vector,
                        "path": self.vector_path,
                        "k": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                    },
                    "returnStoredSource": True,
                }
            }
        )

        # Optional post-search filter restricting results to one namespace.
        if namespace:
            pipeline.append({"$match": {"namespace": namespace}})

        # Expose the search score alongside the full stored document.
        pipeline.append({"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}})

        return [
            BaseVectorStoreDriver.QueryResult(
                id=str(doc["_id"]),
                vector=doc[self.vector_path] if include_vectors else [],
                score=doc["similarityScore"],
                meta=doc["document"]["meta"],
                namespace=namespace,
            )
            for doc in collection.aggregate(pipeline)
        ]

query(query, count=None, namespace=None, include_vectors=False, offset=None, **kwargs)

Queries the MongoDB collection for documents that match the provided query string.

Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.

Source code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    offset: Optional[int] = None,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Queries the MongoDB collection for documents that match the provided query string.

    Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
    """
    collection = self.get_collection()

    # Embed the query text so it can be compared against stored vectors.
    query_vector = self.embedding_driver.embed_string(query)

    result_count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    offset = offset if offset else 0

    # Cap the candidate pool that Cosmos DB's vector search considers.
    num_candidates = min(result_count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES)

    search_stage = {
        "$search": {
            "cosmosSearch": {
                "vector": query_vector,
                "path": self.vector_path,
                "k": num_candidates,
            },
            "returnStoredSource": True,
        }
    }
    project_stage = {"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}}

    pipeline = [search_stage]
    if namespace:
        pipeline.append({"$match": {"namespace": namespace}})
    pipeline.append(project_stage)

    return [
        BaseVectorStoreDriver.QueryResult(
            id=str(doc["_id"]),
            vector=doc[self.vector_path] if include_vectors else [],
            score=doc["similarityScore"],
            meta=doc["document"]["meta"],
            namespace=namespace,
        )
        for doc in collection.aggregate(pipeline)
    ]

AzureOpenAiChatPromptDriver

Bases: OpenAiChatPromptDriver

Attributes:

Name Type Description
azure_deployment str

An optional Azure OpenAi deployment id. Defaults to the model name.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@define
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
    """
    Attributes:
        azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    # Deployment id falls back to the model name when not supplied.
    azure_deployment: str = field(
        kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
    )
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    # AD credentials are never serialized.
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
    # Client is built lazily from this instance's Azure configuration fields.
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        # Start from the base OpenAI chat params, then strip what Azure rejects.
        params = super()._base_params(prompt_stack)
        # TODO: Add `seed` parameter once Azure supports it.
        del params["seed"]

        return params

api_version: str = field(default='2023-05-15', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

AzureOpenAiEmbeddingDriver

Bases: OpenAiEmbeddingDriver

Attributes:

Name Type Description
azure_deployment str

An optional Azure OpenAi deployment id. Defaults to the model name.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

tokenizer OpenAiTokenizer

An OpenAiTokenizer.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/drivers/embedding/azure_openai_embedding_driver.py
@define
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
    """
    Attributes:
        azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        tokenizer: An `OpenAiTokenizer`.
        client: An `openai.AzureOpenAI` client.
    """

    # Deployment id falls back to the model name when not supplied.
    azure_deployment: str = field(
        kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
    )
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    # AD credentials are never serialized.
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
    # Tokenizer defaults to one matched to the configured model.
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    # Client is built lazily from this instance's Azure configuration fields.
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2023-05-15', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

AzureOpenAiImageGenerationDriver

Bases: OpenAiImageGenerationDriver

Driver for Azure-hosted OpenAI image generation API.

Attributes:

Name Type Description
azure_deployment str

An optional Azure OpenAi deployment id. Defaults to the model name.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/drivers/image_generation/azure_openai_image_generation_driver.py
@define
class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
    """Driver for Azure-hosted OpenAI image generation API.

    Attributes:
        azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    # Deployment id falls back to the model name when not supplied.
    azure_deployment: str = field(
        kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
    )
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    # AD credentials are never serialized.
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    # Note: image generation uses a newer API version default than the chat
    # and embedding drivers in this file ("2024-02-01" vs "2023-05-15").
    api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True})
    # Client is built lazily from this instance's Azure configuration fields.
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2024-02-01', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

AzureOpenAiImageQueryDriver

Bases: OpenAiImageQueryDriver

Driver for Azure-hosted OpenAI image query API.

Attributes:

Name Type Description
azure_deployment str

An optional Azure OpenAi deployment id. Defaults to the model name.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/drivers/image_query/azure_openai_image_query_driver.py
@define
class AzureOpenAiImageQueryDriver(OpenAiImageQueryDriver):
    """Driver for Azure-hosted OpenAI image query API.

    Attributes:
        azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.