Skip to content

Drivers

__all__ = ['BasePromptDriver', 'OpenAiChatPromptDriver', 'OpenAiCompletionPromptDriver', 'AzureOpenAiChatPromptDriver', 'AzureOpenAiCompletionPromptDriver', 'CoherePromptDriver', 'HuggingFacePipelinePromptDriver', 'HuggingFaceHubPromptDriver', 'AnthropicPromptDriver', 'AmazonSageMakerPromptDriver', 'AmazonBedrockPromptDriver', 'GooglePromptDriver', 'BaseMultiModelPromptDriver', 'DummyPromptDriver', 'BaseConversationMemoryDriver', 'LocalConversationMemoryDriver', 'AmazonDynamoDbConversationMemoryDriver', 'BaseEmbeddingDriver', 'OpenAiEmbeddingDriver', 'AzureOpenAiEmbeddingDriver', 'BaseMultiModelEmbeddingDriver', 'AmazonSageMakerEmbeddingDriver', 'AmazonBedrockTitanEmbeddingDriver', 'AmazonBedrockCohereEmbeddingDriver', 'VoyageAiEmbeddingDriver', 'HuggingFaceHubEmbeddingDriver', 'GoogleEmbeddingDriver', 'DummyEmbeddingDriver', 'BaseEmbeddingModelDriver', 'SageMakerHuggingFaceEmbeddingModelDriver', 'SageMakerTensorFlowHubEmbeddingModelDriver', 'BaseVectorStoreDriver', 'LocalVectorStoreDriver', 'PineconeVectorStoreDriver', 'MarqoVectorStoreDriver', 'MongoDbAtlasVectorStoreDriver', 'AzureMongoDbVectorStoreDriver', 'RedisVectorStoreDriver', 'OpenSearchVectorStoreDriver', 'AmazonOpenSearchVectorStoreDriver', 'PgVectorVectorStoreDriver', 'DummyVectorStoreDriver', 'BaseSqlDriver', 'AmazonRedshiftSqlDriver', 'SnowflakeSqlDriver', 'SqlDriver', 'BasePromptModelDriver', 'SageMakerLlamaPromptModelDriver', 'SageMakerFalconPromptModelDriver', 'BedrockTitanPromptModelDriver', 'BedrockClaudePromptModelDriver', 'BedrockJurassicPromptModelDriver', 'BedrockLlamaPromptModelDriver', 'BaseImageGenerationModelDriver', 'BedrockStableDiffusionImageGenerationModelDriver', 'BedrockTitanImageGenerationModelDriver', 'BaseImageGenerationDriver', 'BaseMultiModelImageGenerationDriver', 'OpenAiImageGenerationDriver', 'LeonardoImageGenerationDriver', 'AmazonBedrockImageGenerationDriver', 'AzureOpenAiImageGenerationDriver', 'DummyImageGenerationDriver', 'BaseImageQueryModelDriver', 
'BedrockClaudeImageQueryModelDriver', 'BaseImageQueryDriver', 'OpenAiVisionImageQueryDriver', 'DummyImageQueryDriver', 'AnthropicImageQueryDriver', 'BaseMultiModelImageQueryDriver', 'AmazonBedrockImageQueryDriver', 'BaseWebScraperDriver', 'TrafilaturaWebScraperDriver', 'MarkdownifyWebScraperDriver', 'BaseEventListenerDriver', 'AmazonSqsEventListenerDriver', 'WebhookEventListenerDriver', 'AwsIotCoreEventListenerDriver', 'GriptapeCloudEventListenerDriver', 'BaseFileManagerDriver', 'LocalFileManagerDriver', 'AmazonS3FileManagerDriver', 'BaseStructureRunDriver', 'GriptapeCloudStructureRunDriver', 'LocalStructureRunDriver'] module-attribute

AmazonBedrockCohereEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

Name Type Description
model str

Embedding model name. Defaults to DEFAULT_MODEL.

input_type str

Defaults to search_query. Prepends special tokens to differentiate each type from one another: search_document when you encode documents for embeddings that you store in a vector database. search_query when querying your vector DB to find relevant documents.

session Session

Optionally provide custom boto3.Session.

tokenizer BedrockCohereTokenizer

Optionally provide custom BedrockCohereTokenizer.

bedrock_client Any

Optionally provide custom bedrock-runtime client.

Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@define
class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver for Cohere embedding models invoked through the Amazon Bedrock runtime.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        input_type: Defaults to `search_query`. Prepends special tokens to differentiate each type from one another:
            `search_document` when you encode documents for embeddings that you store in a vector database.
            `search_query` when querying your vector DB to find relevant documents.
        session: Optionally provide custom `boto3.Session`.
        tokenizer: Optionally provide custom `BedrockCohereTokenizer`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "cohere.embed-english-v3"

    model: str = field(default=DEFAULT_MODEL, kw_only=True)
    input_type: str = field(default="search_query", kw_only=True)
    session: boto3.Session = field(
        default=Factory(lambda: import_optional_dependency("boto3").Session()),
        kw_only=True,
    )
    # The default tokenizer is derived from `model`, so this field must stay declared after `model`.
    tokenizer: BedrockCohereTokenizer = field(
        default=Factory(lambda self: BedrockCohereTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    # The default client is derived from `session`, so this field must stay declared after `session`.
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
        kw_only=True,
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed one text chunk with a single `invoke_model` call.

        The chunk is sent as a one-element `texts` list, so the first entry of the returned
        `embeddings` list is the vector for `chunk`.
        """
        request_body = json.dumps({"input_type": self.input_type, "texts": [chunk]})

        raw_response = self.bedrock_client.invoke_model(
            body=request_body,
            modelId=self.model,
            accept="*/*",
            contentType="application/json",
        )
        decoded = json.loads(raw_response.get("body").read())
        vectors = decoded.get("embeddings")

        return vectors[0]

DEFAULT_MODEL = 'cohere.embed-english-v3' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

input_type: str = field(default='search_query', kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BedrockCohereTokenizer = field(default=Factory(lambda self: BedrockCohereTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Return the embedding vector for `chunk` from a Cohere model on Bedrock.

    The request carries `self.input_type` and a single-element `texts` list. The first
    element of the response's `embeddings` list is therefore the vector for `chunk`.
    """
    request = {"input_type": self.input_type, "texts": [chunk]}

    result = self.bedrock_client.invoke_model(
        body=json.dumps(request),
        modelId=self.model,
        accept="*/*",
        contentType="application/json",
    )
    decoded = json.loads(result.get("body").read())

    # One chunk in, one embedding out: take the first entry.
    return decoded.get("embeddings")[0]

AmazonBedrockImageGenerationDriver

Bases: BaseMultiModelImageGenerationDriver

Driver for image generation models provided by Amazon Bedrock.

Attributes:

Name Type Description
model

Bedrock model ID.

session Session

boto3 session.

bedrock_client Any

Bedrock runtime client.

image_width int

Width of output images. Defaults to 512 and must be a multiple of 64.

image_height int

Height of output images. Defaults to 512 and must be a multiple of 64.

seed Optional[int]

Optionally provide a consistent seed to generation requests, increasing consistency in output.

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
@define
class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
    """Driver for image generation models provided by Amazon Bedrock.

    Request bodies are built, and responses decoded, by the configured `image_generation_model_driver`.
    This driver performs the `invoke_model` round trip and wraps the result in an `ImageArtifact`.

    Attributes:
        model: Bedrock model ID.
        session: boto3 session.
        bedrock_client: Bedrock runtime client.
        image_width: Width of output images. Defaults to 512 and must be a multiple of 64.
        image_height: Height of output images. Defaults to 512 and must be a multiple of 64.
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    # Unlike the sibling fields this one is not kw_only; kept as-is so existing call sites are unaffected.
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client(service_name="bedrock-runtime"), takes_self=True)
    )
    image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
    image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate a new image from text prompts.

        The returned artifact is labelled PNG and sized to `image_width` x `image_height`.
        """
        request = self.image_generation_model_driver.text_to_image_request_parameters(
            prompts, self.image_width, self.image_height, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=self.image_width,
            height=self.image_height,
            model=self.model,
        )

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        """Generate a variation of `image`; the result reuses the input image's width and height."""
        request = self.image_generation_model_driver.image_variation_request_parameters(
            prompts, image=image, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Inpaint `image` using `mask`; the result reuses the input image's width and height."""
        request = self.image_generation_model_driver.image_inpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Outpaint `image` using `mask`; the result reuses the input image's width and height."""
        request = self.image_generation_model_driver.image_outpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def _make_request(self, request: dict) -> bytes:
        """Invoke the Bedrock model with `request` and return the generated image bytes.

        Shared by every generation mode (text-to-image, variation, inpainting, outpainting).

        Raises:
            ValueError: If the image generation model driver cannot extract an image from the response.
        """
        response = self.bedrock_client.invoke_model(
            body=json.dumps(request), modelId=self.model, accept="application/json", contentType="application/json"
        )

        response_body = json.loads(response.get("body").read())

        try:
            image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
        except Exception as e:
            # This helper serves all generation modes, so the message must not name one specific mode.
            # Chaining keeps the model driver's original traceback attached.
            raise ValueError(f"Image generation failed: {e}") from e

        return image_bytes

bedrock_client: Any = field(default=Factory(lambda self: self.session.client(service_name='bedrock-runtime'), takes_self=True)) class-attribute instance-attribute

image_height: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

image_width: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

seed: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Inpaint `image` guided by `prompts` and `mask`.

    Args:
        prompts: Positive prompts; joined with ", " to form the artifact's `prompt`.
        image: Source image; its width and height are copied onto the result.
        mask: Mask image, passed through to the image generation model driver.
        negative_prompts: Optional prompts forwarded to the model driver.

    Returns:
        An `ImageArtifact` labelled as PNG holding the generated bytes.
    """
    params = self.image_generation_model_driver.image_inpainting_request_parameters(
        prompts,
        image=image,
        mask=mask,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated = self._make_request(params)

    return ImageArtifact(
        value=generated,
        prompt=", ".join(prompts),
        format="png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Outpaint `image` guided by `prompts` and `mask`.

    Args:
        prompts: Positive prompts; joined with ", " to form the artifact's `prompt`.
        image: Source image; its width and height are copied onto the result.
        mask: Mask image, passed through to the image generation model driver.
        negative_prompts: Optional prompts forwarded to the model driver.

    Returns:
        An `ImageArtifact` labelled as PNG holding the generated bytes.
    """
    params = self.image_generation_model_driver.image_outpainting_request_parameters(
        prompts,
        image=image,
        mask=mask,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated = self._make_request(params)

    return ImageArtifact(
        value=generated,
        prompt=", ".join(prompts),
        format="png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Produce a variation of `image` guided by `prompts`.

    Args:
        prompts: Positive prompts; joined with ", " to form the artifact's `prompt`.
        image: Source image; its width and height are copied onto the result.
        negative_prompts: Optional prompts forwarded to the model driver.

    Returns:
        An `ImageArtifact` labelled as PNG holding the generated bytes.
    """
    params = self.image_generation_model_driver.image_variation_request_parameters(
        prompts,
        image=image,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated = self._make_request(params)

    return ImageArtifact(
        value=generated,
        prompt=", ".join(prompts),
        format="png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_text_to_image(
    self,
    prompts: list[str],
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Generate an image from `prompts` at the driver's configured size.

    Args:
        prompts: Positive prompts; joined with ", " to form the artifact's `prompt`.
        negative_prompts: Optional prompts forwarded to the model driver.

    Returns:
        An `ImageArtifact` labelled as PNG, sized `image_width` x `image_height`.
    """
    width, height = self.image_width, self.image_height
    params = self.image_generation_model_driver.text_to_image_request_parameters(
        prompts,
        width,
        height,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated = self._make_request(params)

    return ImageArtifact(
        value=generated,
        prompt=", ".join(prompts),
        format="png",
        width=self.image_width,
        height=self.image_height,
        model=self.model,
    )

AmazonBedrockImageQueryDriver

Bases: BaseMultiModelImageQueryDriver

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
@define
class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver):
    """Image query driver for models hosted on Amazon Bedrock.

    The request payload is built, and the response decoded, by the configured `image_query_model_driver`.

    Attributes:
        session: boto3 session; a default `boto3.Session()` is created when omitted.
        bedrock_client: `bedrock-runtime` client; created from `session` when omitted.
    """

    session: boto3.Session = field(
        default=Factory(lambda: import_optional_dependency("boto3").Session()),
        kw_only=True,
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
        kw_only=True,
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        """Ask `query` about `images` with one `invoke_model` call.

        Raises:
            ValueError: If the decoded response is JSON `null`, or if the model driver fails to process it.
        """
        request_payload = self.image_query_model_driver.image_query_request_parameters(
            query, images, self.max_tokens
        )

        raw_response = self.bedrock_client.invoke_model(
            modelId=self.model,
            contentType="application/json",
            accept="application/json",
            body=json.dumps(request_payload),
        )
        decoded = json.loads(raw_response.get("body").read())

        if decoded is None:
            raise ValueError("Model response is empty")

        try:
            return self.image_query_model_driver.process_output(decoded)
        except Exception as exc:
            raise ValueError(f"Output is unable to be processed as returned {exc}")

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    """Send `query` and `images` to a Bedrock model and return the processed answer.

    Raises:
        ValueError: If the decoded response is JSON `null`, or if the model driver fails to process it.
    """
    request_payload = self.image_query_model_driver.image_query_request_parameters(
        query, images, self.max_tokens
    )

    raw_response = self.bedrock_client.invoke_model(
        modelId=self.model,
        contentType="application/json",
        accept="application/json",
        body=json.dumps(request_payload),
    )
    decoded = json.loads(raw_response.get("body").read())

    if decoded is None:
        raise ValueError("Model response is empty")

    try:
        return self.image_query_model_driver.process_output(decoded)
    except Exception as exc:
        raise ValueError(f"Output is unable to be processed as returned {exc}")

AmazonBedrockPromptDriver

Bases: BaseMultiModelPromptDriver

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@define
class AmazonBedrockPromptDriver(BaseMultiModelPromptDriver):
    """Prompt driver for models hosted on Amazon Bedrock.

    Request payloads are assembled, and responses parsed, by the configured `prompt_model_driver`.

    Attributes:
        session: boto3 session; a default `boto3.Session()` is created when omitted.
        bedrock_client: `bedrock-runtime` client; created from `session` when omitted.
    """

    session: boto3.Session = field(
        default=Factory(lambda: import_optional_dependency("boto3").Session()),
        kw_only=True,
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Run `prompt_stack` with a single, non-streaming `invoke_model` call.

        Raises:
            Exception: If Bedrock returns an empty body.
        """
        request_body = self._bedrock_payload(prompt_stack)

        response = self.bedrock_client.invoke_model(
            modelId=self.model,
            contentType="application/json",
            accept="application/json",
            body=json.dumps(request_body),
        )
        raw_output = response["body"].read()

        if not raw_output:
            raise Exception("model response is empty")

        return self.prompt_model_driver.process_output(raw_output)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Stream the output for `prompt_stack`, yielding one artifact per event-stream chunk.

        Raises:
            Exception: If Bedrock returns no response body.
        """
        request_body = self._bedrock_payload(prompt_stack)

        response = self.bedrock_client.invoke_model_with_response_stream(
            modelId=self.model,
            contentType="application/json",
            accept="application/json",
            body=json.dumps(request_body),
        )
        event_stream = response["body"]

        if not event_stream:
            raise Exception("model response is empty")

        for event in event_stream:
            yield self.prompt_model_driver.process_output(event["chunk"]["bytes"])

    def _bedrock_payload(self, prompt_stack: PromptStack) -> dict:
        """Build the request body: model params, overlaid with the model input when that input is a dict."""
        model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
        request_body = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
        # Only dict-shaped model input is merged here; other shapes are not added to the body.
        if isinstance(model_input, dict):
            request_body.update(model_input)
        return request_body

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Run `prompt_stack` with a single, non-streaming `invoke_model` call.

    Raises:
        Exception: If Bedrock returns an empty body.
    """
    model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
    request_body = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
    # Only dict-shaped model input is merged into the request body.
    if isinstance(model_input, dict):
        request_body.update(model_input)

    response = self.bedrock_client.invoke_model(
        modelId=self.model,
        contentType="application/json",
        accept="application/json",
        body=json.dumps(request_body),
    )
    raw_output = response["body"].read()

    if not raw_output:
        raise Exception("model response is empty")

    return self.prompt_model_driver.process_output(raw_output)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Stream the output for `prompt_stack`, yielding one artifact per event-stream chunk.

    Raises:
        Exception: If Bedrock returns no response body.
    """
    model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
    request_body = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
    # Only dict-shaped model input is merged into the request body.
    if isinstance(model_input, dict):
        request_body.update(model_input)

    response = self.bedrock_client.invoke_model_with_response_stream(
        modelId=self.model,
        contentType="application/json",
        accept="application/json",
        body=json.dumps(request_body),
    )
    event_stream = response["body"]

    if not event_stream:
        raise Exception("model response is empty")

    for event in event_stream:
        yield self.prompt_model_driver.process_output(event["chunk"]["bytes"])

AmazonBedrockTitanEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

Name Type Description
model str

Embedding model name. Defaults to DEFAULT_MODEL.

tokenizer BedrockTitanTokenizer

Optionally provide custom BedrockTitanTokenizer.

session Session

Optionally provide custom boto3.Session.

bedrock_client Any

Optionally provide custom bedrock-runtime client.

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@define
class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver for Amazon Titan text embedding models on the Bedrock runtime.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
        session: Optionally provide custom `boto3.Session`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "amazon.titan-embed-text-v1"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    session: boto3.Session = field(
        default=Factory(lambda: import_optional_dependency("boto3").Session()),
        kw_only=True,
    )
    # The default tokenizer is derived from `model`, so this field must stay declared after `model`.
    tokenizer: BedrockTitanTokenizer = field(
        default=Factory(lambda self: BedrockTitanTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    # The default client is derived from `session`, so this field must stay declared after `session`.
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True),
        kw_only=True,
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed one text chunk and return the response's `embedding` value (None if the key is absent)."""
        request_body = json.dumps({"inputText": chunk})

        raw_response = self.bedrock_client.invoke_model(
            body=request_body,
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )
        decoded = json.loads(raw_response.get("body").read())

        return decoded.get("embedding")

DEFAULT_MODEL = 'amazon.titan-embed-text-v1' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BedrockTitanTokenizer = field(default=Factory(lambda self: BedrockTitanTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Return the `embedding` value from a Titan embedding response for `chunk`.

    The request body is `{"inputText": chunk}`. If the response has no `embedding` key,
    `None` is returned.
    """
    request = {"inputText": chunk}

    result = self.bedrock_client.invoke_model(
        body=json.dumps(request),
        modelId=self.model,
        accept="application/json",
        contentType="application/json",
    )
    decoded = json.loads(result.get("body").read())

    return decoded.get("embedding")

AmazonDynamoDbConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
@define
class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
    """Conversation memory driver that keeps serialized memory JSON in one DynamoDB item.

    Attributes:
        session: boto3 session used to create the DynamoDB resource.
        table_name: Name of the DynamoDB table.
        partition_key: Name of the partition key attribute used to address the item.
        value_attribute_key: Name of the item attribute holding the memory JSON.
        partition_key_value: Partition key value of the item that stores this memory.
        table: DynamoDB `Table` handle, assigned in `__attrs_post_init__`.
    """

    session: boto3.Session = field(
        default=Factory(lambda: import_optional_dependency("boto3").Session()),
        kw_only=True,
    )
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    partition_key: str = field(kw_only=True, metadata={"serializable": True})
    value_attribute_key: str = field(kw_only=True, metadata={"serializable": True})
    partition_key_value: str = field(kw_only=True, metadata={"serializable": True})

    table: Any = field(init=False)

    def __attrs_post_init__(self) -> None:
        """Resolve the DynamoDB table handle from the session."""
        self.table = self.session.resource("dynamodb").Table(self.table_name)

    def store(self, memory: BaseConversationMemory) -> None:
        """Write `memory` as JSON into the configured attribute of this driver's item."""
        item_key = {self.partition_key: self.partition_key_value}
        self.table.update_item(
            Key=item_key,
            UpdateExpression="set #attr = :value",
            ExpressionAttributeNames={"#attr": self.value_attribute_key},
            ExpressionAttributeValues={":value": memory.to_json()},
        )

    def load(self) -> Optional[BaseConversationMemory]:
        """Read this driver's item and rebuild the memory, or return None if nothing is stored."""
        response = self.table.get_item(Key={self.partition_key: self.partition_key_value})

        if "Item" not in response or self.value_attribute_key not in response["Item"]:
            return None

        memory = BaseConversationMemory.from_json(response["Item"][self.value_attribute_key])
        # Re-attach this driver so the loaded memory persists back through it.
        memory.driver = self
        return memory

partition_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

partition_key_value: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

table: Any = field(init=False) class-attribute instance-attribute

table_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

value_attribute_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def __attrs_post_init__(self) -> None:
    """After attrs initialisation, resolve `self.table` from the session's DynamoDB resource."""
    self.table = self.session.resource("dynamodb").Table(self.table_name)

load()

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def load(self) -> Optional[BaseConversationMemory]:
    """Fetch this driver's item and rebuild the stored memory.

    Returns:
        The deserialized memory with `driver` set to this driver, or None when the item
        or its value attribute is missing.
    """
    response = self.table.get_item(Key={self.partition_key: self.partition_key_value})

    if "Item" not in response or self.value_attribute_key not in response["Item"]:
        return None

    memory = BaseConversationMemory.from_json(response["Item"][self.value_attribute_key])
    memory.driver = self
    return memory

store(memory)

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def store(self, memory: BaseConversationMemory) -> None:
    """Persist `memory` as JSON into the value attribute of this driver's DynamoDB item."""
    item_key = {self.partition_key: self.partition_key_value}
    self.table.update_item(
        Key=item_key,
        UpdateExpression="set #attr = :value",
        ExpressionAttributeNames={"#attr": self.value_attribute_key},
        ExpressionAttributeValues={":value": memory.to_json()},
    )

AmazonOpenSearchVectorStoreDriver

Bases: OpenSearchVectorStoreDriver

A Vector Store Driver for Amazon OpenSearch.

Attributes:

Name Type Description
session Session

The boto3 session to use.

service str

Service name for AWS Signature v4. Use 'es' for Amazon OpenSearch Service or 'aoss' for Amazon OpenSearch Serverless. Defaults to 'es'.

http_auth str | tuple[str, str]

The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.

client OpenSearch

An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.

Source code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
@define
class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver):
    """A Vector Store Driver for Amazon OpenSearch.

    Attributes:
        session: The boto3 session to use.
        service: Service name for AWS Signature v4. Use 'es' for Amazon OpenSearch Service or 'aoss' for
            Amazon OpenSearch Serverless. Defaults to 'es'.
        http_auth: The HTTP authentication credentials to use. Defaults to an `AWSV4SignerAuth` built from the
            credentials and region of `session` and from `service`.
        client: An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth,
            use_ssl, and verify_certs attributes.
    """

    session: Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    service: str = field(default="es", kw_only=True)
    http_auth: str | tuple[str, str] = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").AWSV4SignerAuth(
                self.session.get_credentials(), self.session.region_name, self.service
            ),
            takes_self=True,
        )
    )

    client: OpenSearch = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").OpenSearch(
                hosts=[{"host": self.host, "port": self.port}],
                http_auth=self.http_auth,
                use_ssl=self.use_ssl,
                verify_certs=self.verify_certs,
                connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection,
            ),
            takes_self=True,
        )
    )

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in OpenSearch.

        For any service other than 'aoss', the document is indexed under `vector_id`, or under a hash of the
        vector when no ID is given. An existing document with that ID is therefore updated.
        When `service` is 'aoss', no ID is sent and OpenSearch assigns one. Every call then indexes a new
        document, and `vector_id` does not identify it.

        Args:
            vector: The embedding to store.
            vector_id: Optional document ID (not sent for 'aoss').
            namespace: Optional namespace stored alongside the vector.
            meta: Optional metadata stored alongside the vector.
            **kwargs: Extra fields merged into the document body; they can overwrite the keys above.

        Returns:
            The `_id` of the indexed document as reported by OpenSearch.
        """

        vector_id = vector_id if vector_id else str_to_hash(str(vector))
        doc = {"vector": vector, "namespace": namespace, "metadata": meta}
        doc.update(kwargs)
        if self.service == "aoss":
            # The ID is deliberately not passed for OpenSearch Serverless.
            response = self.client.index(index=self.index_name, body=doc)
        else:
            response = self.client.index(index=self.index_name, id=vector_id, body=doc)

        return response["_id"]

client: OpenSearch = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').OpenSearch(hosts=[{'host': self.host, 'port': self.port}], http_auth=self.http_auth, use_ssl=self.use_ssl, verify_certs=self.verify_certs, connection_class=import_optional_dependency('opensearchpy').RequestsHttpConnection), takes_self=True)) class-attribute instance-attribute

http_auth: str | tuple[str, str] = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').AWSV4SignerAuth(self.session.get_credentials(), self.session.region_name, self.service), takes_self=True)) class-attribute instance-attribute

service: str = field(default='es', kw_only=True) class-attribute instance-attribute

session: Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in OpenSearch.

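If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted. Metadata associated with the vector can also be provided. When service is 'aoss', the vector ID is not sent, so OpenSearch assigns a new ID and no existing document is updated.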

Source code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Index a vector document in OpenSearch and return the ID the server reports.

    The indexed document has ``vector``, ``namespace`` and ``metadata`` fields; any extra
    keyword arguments are merged into it and win on key collisions. When ``vector_id`` is
    falsy, a hash of ``str(vector)`` is used instead.

    For the ``"aoss"`` service the ID is NOT sent with the request, so the returned ID is the
    one assigned by the server and an existing document is not overwritten by ID.
    """
    document_id = vector_id or str_to_hash(str(vector))
    body = {"vector": vector, "namespace": namespace, "metadata": meta, **kwargs}

    index_kwargs = {"index": self.index_name, "body": body}
    if self.service != "aoss":
        index_kwargs["id"] = document_id

    result = self.client.index(**index_kwargs)
    return result["_id"]

AmazonRedshiftSqlDriver

Bases: BaseSqlDriver

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@define
class AmazonRedshiftSqlDriver(BaseSqlDriver):
    """SQL driver that runs statements through the Amazon Redshift Data API (``redshift-data``).

    Exactly one of ``cluster_identifier`` or ``workgroup_name`` must be set (enforced by
    ``validate_params``).

    Attributes:
        database: Database name sent as ``Database`` with every request.
        session: boto3 session used to create the default ``redshift-data`` client.
        cluster_identifier: Sent as ``ClusterIdentifier`` when set.
        workgroup_name: Sent as ``WorkgroupName`` when set.
        db_user: Sent as ``DbUser`` when set.
        database_credentials_secret_arn: Sent as ``SecretArn`` when set.
        wait_for_query_completion_sec: Seconds to sleep between ``describe_statement`` polls.
        client: ``redshift-data`` client; defaults to one created from ``session``.
    """

    database: str = field(kw_only=True)
    session: boto3.Session = field(kw_only=True)
    cluster_identifier: Optional[str] = field(default=None, kw_only=True)
    workgroup_name: Optional[str] = field(default=None, kw_only=True)
    db_user: Optional[str] = field(default=None, kw_only=True)
    database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True)
    wait_for_query_completion_sec: float = field(default=0.3, kw_only=True)
    client: Any = field(
        default=Factory(lambda self: self.session.client("redshift-data"), takes_self=True), kw_only=True
    )

    @workgroup_name.validator  # pyright: ignore
    def validate_params(self, _, workgroup_name: Optional[str]) -> None:
        """Require exactly one of ``cluster_identifier`` / ``workgroup_name``.

        Raises:
            ValueError: If neither or both are set.
        """
        if not self.cluster_identifier and not self.workgroup_name:
            raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
        elif self.cluster_identifier and self.workgroup_name:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

    @classmethod
    def _process_rows_from_records(cls, records) -> list[list]:
        """Turn ``get_statement_result`` records into lists of plain cell values.

        Each cell is a dict and the value under its first key is used, except that a null
        cell (``{"isNull": True}``) becomes ``None`` instead of leaking ``True`` as the value.
        """
        return [[None if cell.get("isNull") else cell[list(cell)[0]] for cell in record] for record in records]

    @classmethod
    def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
        """Zip each row with the column names into a ``{column: value}`` dict."""
        return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

    @classmethod
    def _process_columns_from_column_metadata(cls, meta) -> list:
        """Extract column names from the ``ColumnMetadata`` entries."""
        return [k["name"] for k in meta]

    @classmethod
    def _post_process(cls, meta, records) -> list[dict[str, Any]]:
        """Convert raw column metadata and records into a list of row dicts."""
        columns = cls._process_columns_from_column_metadata(meta)
        rows = cls._process_rows_from_records(records)
        return cls._process_cells_from_rows_and_columns(columns, rows)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        """Run ``query`` and wrap each row in a ``RowResult``; ``None`` when there are no rows."""
        rows = self.execute_query_raw(query)
        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        else:
            return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        """Run ``query`` and return every result row as a ``{column: value}`` dict.

        The statement is polled with ``describe_statement`` while its status is SUBMITTED,
        PICKED or STARTED. When it is FINISHED, all pages of ``get_statement_result`` are
        collected. Any other final status (e.g. FAILED, ABORTED) returns ``None``.
        """
        function_kwargs = {"Sql": query, "Database": self.database}
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn

        response = self.client.execute_statement(**function_kwargs)
        response_id = response["Id"]

        statement = self.client.describe_statement(Id=response_id)

        while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
            time.sleep(self.wait_for_query_completion_sec)
            statement = self.client.describe_statement(Id=response_id)

        if statement["Status"] != "FINISHED":
            return None

        statement_result = self.client.get_statement_result(Id=response_id)
        records = list(statement_result.get("Records", []))

        # Follow pagination. Each page's records must come from that page's result; reading them
        # from the execute_statement response instead silently dropped every page after the first.
        while "NextToken" in statement_result:
            statement_result = self.client.get_statement_result(
                Id=response_id, NextToken=statement_result["NextToken"]
            )
            records.extend(statement_result.get("Records", []))

        return self._post_process(statement_result["ColumnMetadata"], records)

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        """Return the column names of ``table_name`` (from ``describe_table``) as a string."""
        function_kwargs = {"Database": self.database, "Table": table_name}
        if schema:
            function_kwargs["Schema"] = schema
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn
        response = self.client.describe_table(**function_kwargs)
        return str([col["name"] for col in response["ColumnList"]])

client: Any = field(default=Factory(lambda self: self.session.client('redshift-data'), takes_self=True), kw_only=True) class-attribute instance-attribute

cluster_identifier: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

database: str = field(kw_only=True) class-attribute instance-attribute

database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

db_user: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(kw_only=True) class-attribute instance-attribute

wait_for_query_completion_sec: float = field(default=0.3, kw_only=True) class-attribute instance-attribute

workgroup_name: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

execute_query(query)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    """Run ``query`` and wrap every raw row dict in a ``RowResult``.

    Returns ``None`` (rather than an empty list) when the raw query yields no rows.
    """
    raw_rows = self.execute_query_raw(query)
    if not raw_rows:
        return None
    return [BaseSqlDriver.RowResult(raw_row) for raw_row in raw_rows]

execute_query_raw(query)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    """Run ``query`` through the Redshift Data API and return every row as a dict.

    The statement is submitted with ``execute_statement`` and then polled with
    ``describe_statement`` every ``wait_for_query_completion_sec`` seconds while its status is
    SUBMITTED, PICKED or STARTED. Once FINISHED, all pages of ``get_statement_result`` are
    gathered and handed to ``_post_process``.

    Returns:
        The post-processed rows when the statement finishes; ``None`` for any other final
        status (e.g. FAILED or ABORTED).
    """
    function_kwargs = {"Sql": query, "Database": self.database}
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.execute_statement(**function_kwargs)
    response_id = response["Id"]

    statement = self.client.describe_statement(Id=response_id)

    while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
        time.sleep(self.wait_for_query_completion_sec)
        statement = self.client.describe_statement(Id=response_id)

    if statement["Status"] != "FINISHED":
        return None

    statement_result = self.client.get_statement_result(Id=response_id)
    records = list(statement_result.get("Records", []))

    # Follow pagination. Each page's records must come from that page's result; reading them
    # from the execute_statement response instead silently dropped every page after the first.
    while "NextToken" in statement_result:
        statement_result = self.client.get_statement_result(
            Id=response_id, NextToken=statement_result["NextToken"]
        )
        records.extend(statement_result.get("Records", []))

    return self._post_process(statement_result["ColumnMetadata"], records)

get_table_schema(table_name, schema=None)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    """Describe ``table_name`` and return its column names rendered as a string list.

    Optional request parameters (schema, workgroup, cluster, user, secret ARN) are only sent
    when they hold a truthy value.
    """
    request = {"Database": self.database, "Table": table_name}
    optional_params = (
        ("Schema", schema),
        ("WorkgroupName", self.workgroup_name),
        ("ClusterIdentifier", self.cluster_identifier),
        ("DbUser", self.db_user),
        ("SecretArn", self.database_credentials_secret_arn),
    )
    request.update({param_name: param_value for param_name, param_value in optional_params if param_value})

    description = self.client.describe_table(**request)
    column_names = [column["name"] for column in description["ColumnList"]]
    return str(column_names)

validate_params(_, workgroup_name)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@workgroup_name.validator  # pyright: ignore
def validate_params(self, _, workgroup_name: Optional[str]) -> None:
    """Ensure exactly one of ``cluster_identifier`` or ``workgroup_name`` is configured.

    Raises:
        ValueError: If both are missing or both are set.
    """
    has_cluster = bool(self.cluster_identifier)
    has_workgroup = bool(self.workgroup_name)
    if has_cluster == has_workgroup:
        if has_cluster:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")
        raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")

AmazonS3FileManagerDriver

Bases: BaseFileManagerDriver

AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

Attributes:

Name Type Description
session Session

The boto3 session to use for S3 operations.

bucket str

The name of the S3 bucket.

workdir str

The absolute working directory (must start with "/"). List, load, and save operations will be performed relative to this directory.

s3_client Any

The S3 client to use for S3 operations.

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@define
class AmazonS3FileManagerDriver(BaseFileManagerDriver):
    """
    AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

    Attributes:
        session: The boto3 session to use for S3 operations.
        bucket: The name of the S3 bucket.
        workdir: The absolute working directory (must start with "/"). List, load, and save
            operations will be performed relative to this directory.
        s3_client: The S3 client to use for S3 operations.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bucket: str = field(kw_only=True)
    workdir: str = field(default="/", kw_only=True)
    s3_client: Any = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True)

    @workdir.validator  # pyright: ignore
    def validate_workdir(self, _, workdir: str) -> None:
        """Reject a relative ``workdir``."""
        if not Path(workdir).is_absolute():
            raise ValueError("Workdir must be an absolute path")

    def try_list_files(self, path: str) -> list[str]:
        """List the immediate children (files and sub-directories) of ``path``.

        Raises:
            NotADirectoryError: If nothing lives under ``path`` but an object with exactly
                that key exists.
            FileNotFoundError: If nothing exists at ``path``.
        """
        full_key = self._to_dir_full_key(path)
        files_and_dirs = self._list_files_and_dirs(full_key)
        if files_and_dirs:
            return files_and_dirs

        # Check for an object whose key is exactly the path. A prefix listing of the bare key
        # would also match unrelated siblings (e.g. "foo" matching "foobar.txt").
        file_key = full_key.rstrip("/")
        if file_key:
            botocore = import_optional_dependency("botocore")
            try:
                self.s3_client.head_object(Bucket=self.bucket, Key=file_key)
            except botocore.exceptions.ClientError as e:
                if e.response["Error"]["Code"] not in {"NoSuchKey", "404"}:
                    raise e
            else:
                raise NotADirectoryError
        raise FileNotFoundError

    def try_load_file(self, path: str) -> bytes:
        """Return the bytes stored at ``path``.

        Raises:
            IsADirectoryError: If ``path`` refers to a directory.
            FileNotFoundError: If S3 reports ``NoSuchKey``/``404`` for the key.
        """
        botocore = import_optional_dependency("botocore")
        full_key = self._to_full_key(path)

        if self._is_a_directory(full_key):
            raise IsADirectoryError

        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=full_key)
            return response["Body"].read()
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                raise FileNotFoundError
            else:
                raise e

    def try_save_file(self, path: str, value: bytes):
        """Write ``value`` to ``path``; raises IsADirectoryError if ``path`` is a directory."""
        full_key = self._to_full_key(path)
        if self._is_a_directory(full_key):
            raise IsADirectoryError
        self.s3_client.put_object(Bucket=self.bucket, Key=full_key, Body=value)

    def _to_full_key(self, path: str) -> str:
        """Resolve ``path`` against ``workdir`` into an S3 key without a leading slash."""
        path = path.lstrip("/")
        full_key = os.path.join(self.workdir, path)
        # Need to keep the trailing slash if it was there,
        # because it means the path is a directory.
        ended_with_slash = path.endswith("/")
        full_key = os.path.normpath(full_key)
        if ended_with_slash:
            full_key += "/"
        return full_key.lstrip("/")

    def _to_dir_full_key(self, path: str) -> str:
        """Like ``_to_full_key`` but always ending in "/" (except for the root "")."""
        full_key = self._to_full_key(path)
        # S3 "directories" always end with a slash, except for the root.
        if full_key != "" and not full_key.endswith("/"):
            full_key += "/"
        return full_key

    def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
        """List names directly under the ``full_key`` prefix, relative to that prefix.

        Sub-directories come from ``CommonPrefixes`` (trailing "/" stripped), files from
        ``Contents``. Accepts an optional ``max_items`` keyword passed to the paginator.
        """
        max_items = kwargs.get("max_items")
        pagination_config = {}
        if max_items is not None:
            pagination_config["MaxItems"] = max_items

        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(
            Bucket=self.bucket, Prefix=full_key, Delimiter="/", PaginationConfig=pagination_config
        )
        files_and_dirs = []
        for page in pages:
            for obj in page.get("CommonPrefixes", []):
                prefix = obj.get("Prefix")
                dir_name = prefix[len(full_key) :].rstrip("/")
                files_and_dirs.append(dir_name)

            for obj in page.get("Contents", []):
                key = obj.get("Key")
                file_name = key[len(full_key) :]
                files_and_dirs.append(file_name)
        return files_and_dirs

    def _is_a_directory(self, full_key: str) -> bool:
        """Return True if ``full_key`` denotes a directory rather than an object."""
        botocore = import_optional_dependency("botocore")
        if full_key == "" or full_key.endswith("/"):
            return True

        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=full_key)
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                # No object with this exact key: it is a directory only if something lives under
                # "<key>/". Listing the bare key would also match siblings such as "<key>bar.txt".
                return len(self._list_files_and_dirs(full_key + "/", max_items=1)) > 0
            else:
                raise e

        return False

bucket: str = field(kw_only=True) class-attribute instance-attribute

s3_client: Any = field(default=Factory(lambda self: self.session.client('s3'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

workdir: str = field(default='/', kw_only=True) class-attribute instance-attribute

try_list_files(path)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_list_files(self, path: str) -> list[str]:
    """List the immediate children (files and sub-directories) of ``path``.

    Raises:
        NotADirectoryError: If nothing lives under ``path`` but an object with exactly that
            key exists.
        FileNotFoundError: If nothing exists at ``path``.
    """
    full_key = self._to_dir_full_key(path)
    files_and_dirs = self._list_files_and_dirs(full_key)
    if files_and_dirs:
        return files_and_dirs

    # Check for an object whose key is exactly the path. A prefix listing of the bare key
    # would also match unrelated siblings (e.g. "foo" matching "foobar.txt").
    file_key = full_key.rstrip("/")
    if file_key:
        botocore = import_optional_dependency("botocore")
        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=file_key)
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] not in {"NoSuchKey", "404"}:
                raise e
        else:
            raise NotADirectoryError
    raise FileNotFoundError

try_load_file(path)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_load_file(self, path: str) -> bytes:
    """Read and return the object stored at ``path`` (relative to the working directory).

    Raises:
        IsADirectoryError: When ``path`` resolves to a directory.
        FileNotFoundError: When S3 answers with ``NoSuchKey`` or ``404``.
    """
    botocore = import_optional_dependency("botocore")
    object_key = self._to_full_key(path)

    if self._is_a_directory(object_key):
        raise IsADirectoryError

    try:
        return self.s3_client.get_object(Bucket=self.bucket, Key=object_key)["Body"].read()
    except botocore.exceptions.ClientError as error:
        error_code = error.response["Error"]["Code"]
        if error_code not in {"NoSuchKey", "404"}:
            raise error
        raise FileNotFoundError

try_save_file(path, value)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_save_file(self, path: str, value: bytes):
    """Upload ``value`` to ``path`` (relative to the working directory).

    Raises:
        IsADirectoryError: When ``path`` resolves to a directory.
    """
    object_key = self._to_full_key(path)
    if not self._is_a_directory(object_key):
        self.s3_client.put_object(Bucket=self.bucket, Key=object_key, Body=value)
        return
    raise IsADirectoryError

validate_workdir(_, workdir)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@workdir.validator  # pyright: ignore
def validate_workdir(self, _, workdir: str) -> None:
    """Reject a ``workdir`` that is not an absolute path.

    Raises:
        ValueError: If ``Path(workdir)`` is relative.
    """
    is_absolute = Path(workdir).is_absolute()
    if not is_absolute:
        raise ValueError("Workdir must be an absolute path")

AmazonSageMakerEmbeddingDriver

Bases: BaseMultiModelEmbeddingDriver

Source code in griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py
@define
class AmazonSageMakerEmbeddingDriver(BaseMultiModelEmbeddingDriver):
    """Embedding driver that invokes a SageMaker endpoint (named by ``model``).

    The request payload is built and the response parsed by ``embedding_model_driver``.

    Attributes:
        session: boto3 session used to create the default runtime client.
        sagemaker_client: ``sagemaker-runtime`` client; defaults to one created from ``session``.
        embedding_model_driver: Converts chunks to endpoint payloads and responses to vectors.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
    )
    embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True)

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed one text chunk by invoking the endpoint and decoding its JSON reply."""
        request_body = json.dumps(self.embedding_model_driver.chunk_to_model_params(chunk)).encode("utf-8")
        endpoint_response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.model, ContentType="application/x-text", Body=request_body
        )

        response_text = endpoint_response.get("Body").read().decode("utf-8")
        return self.embedding_model_driver.process_output(json.loads(response_text))

embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda self: self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed one text chunk via the SageMaker endpoint named by ``self.model``.

    The payload comes from ``embedding_model_driver.chunk_to_model_params`` and the decoded
    JSON reply is turned into a vector by ``embedding_model_driver.process_output``.
    """
    request_body = json.dumps(self.embedding_model_driver.chunk_to_model_params(chunk)).encode("utf-8")
    endpoint_response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.model, ContentType="application/x-text", Body=request_body
    )

    response_text = endpoint_response.get("Body").read().decode("utf-8")
    return self.embedding_model_driver.process_output(json.loads(response_text))

AmazonSageMakerPromptDriver

Bases: BaseMultiModelPromptDriver

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@define
class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver):
    """Prompt driver that invokes a SageMaker endpoint (named by ``model``) with a JSON payload.

    Input/parameter conversion and output parsing are delegated to ``prompt_model_driver``.
    Streaming is rejected at construction time.

    Attributes:
        session: boto3 session used to create the default runtime client.
        sagemaker_client: ``sagemaker-runtime`` client; defaults to one created from ``session``.
        custom_attributes: Value sent as ``CustomAttributes`` on every invocation.
        stream: Must stay ``False``; ``True`` is rejected by ``validate_stream``.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
    )
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})

    @stream.validator  # pyright: ignore
    def validate_stream(self, _, stream):
        """Raise ValueError when streaming is requested."""
        if not stream:
            return
        raise ValueError("streaming is not supported")

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Invoke the endpoint once and let the prompt model driver parse the reply.

        Raises:
            Exception: If the decoded JSON body is falsy (e.g. empty list or object).
        """
        model_driver = self.prompt_model_driver
        payload = {
            "inputs": model_driver.prompt_stack_to_model_input(prompt_stack),
            "parameters": model_driver.prompt_stack_to_model_params(prompt_stack),
        }
        endpoint_response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.model,
            ContentType="application/json",
            Body=json.dumps(payload),
            CustomAttributes=self.custom_attributes,
        )

        raw_body = endpoint_response["Body"].read()
        decoded_body = json.loads(raw_body.decode("utf8"))
        if not decoded_body:
            raise Exception("model response is empty")
        return model_driver.process_output(decoded_body)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Always raises NotImplementedError; this driver cannot stream."""
        message = "streaming is not supported"
        raise NotImplementedError(message)

custom_attributes: str = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda self: self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Invoke the SageMaker endpoint once and parse the reply with the prompt model driver.

    Raises:
        Exception: If the decoded JSON body is falsy (e.g. empty list or object).
    """
    model_driver = self.prompt_model_driver
    payload = {
        "inputs": model_driver.prompt_stack_to_model_input(prompt_stack),
        "parameters": model_driver.prompt_stack_to_model_params(prompt_stack),
    }
    endpoint_response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.model,
        ContentType="application/json",
        Body=json.dumps(payload),
        CustomAttributes=self.custom_attributes,
    )

    raw_body = endpoint_response["Body"].read()
    decoded_body = json.loads(raw_body.decode("utf8"))
    if not decoded_body:
        raise Exception("model response is empty")
    return model_driver.process_output(decoded_body)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Always raises NotImplementedError; the SageMaker prompt driver cannot stream."""
    message = "streaming is not supported"
    raise NotImplementedError(message)

validate_stream(_, stream)

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@stream.validator  # pyright: ignore
def validate_stream(self, _, stream):
    """Reject ``stream=True``.

    Raises:
        ValueError: If ``stream`` is truthy.
    """
    if not stream:
        return
    raise ValueError("streaming is not supported")

AmazonSqsEventListenerDriver

Bases: BaseEventListenerDriver

Source code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
@define
class AmazonSqsEventListenerDriver(BaseEventListenerDriver):
    """Event listener driver that sends each event payload as a JSON message to an SQS queue.

    Attributes:
        queue_url: URL of the queue messages are sent to.
        session: boto3 session used to create the default SQS client.
        sqs_client: SQS client; defaults to one created from ``session``.
    """

    queue_url: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sqs_client: Any = field(default=Factory(lambda self: self.session.client("sqs"), takes_self=True))

    def try_publish_event_payload(self, event_payload: dict) -> None:
        """Serialize ``event_payload`` to JSON and send it to ``queue_url``."""
        message_body = json.dumps(event_payload)
        self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=message_body)

queue_url: str = field(kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

sqs_client: Any = field(default=Factory(lambda self: self.session.client('sqs'), takes_self=True)) class-attribute instance-attribute

try_publish_event_payload(event_payload)

Source code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    """Serialize ``event_payload`` to JSON and send it to the configured SQS queue."""
    message_body = json.dumps(event_payload)
    self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=message_body)

AnthropicImageQueryDriver

Bases: BaseImageQueryDriver

Attributes:

Name Type Description
api_key Optional[str]

Anthropic API key.

model str

Anthropic model name.

client Any

Custom Anthropic client.

Source code in griptape/drivers/image_query/anthropic_image_query_driver.py
@define
class AnthropicImageQueryDriver(BaseImageQueryDriver):
    """
    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Any = field(
        default=Factory(
            lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        """Ask the model ``query`` about ``images`` and return the first content block's text.

        Raises:
            TypeError: If ``max_tokens`` is ``None``; it is always included in the request.
            ValueError: If the response contains no content blocks.
        """
        if self.max_tokens is None:
            # The message names the attribute actually checked (previously "max_output_tokens").
            raise TypeError("max_tokens can't be empty")

        response = self.client.messages.create(**self._base_params(query, images))
        content_blocks = response.content

        if len(content_blocks) < 1:
            raise ValueError("Response content is empty")

        return TextArtifact(content_blocks[0].text)

    def _base_params(self, text_query: str, images: list[ImageArtifact]):
        """Build ``messages.create`` kwargs: one user message with all images, then the text."""
        content = [self._construct_image_message(image) for image in images]
        content.append(self._construct_text_message(text_query))
        messages = self._construct_messages(content)
        params = {"model": self.model, "messages": messages, "max_tokens": self.max_tokens}

        return params

    def _construct_image_message(self, image_data: ImageArtifact) -> dict:
        """Wrap an image as a base64 ``image`` content block."""
        data = image_data.base64
        media_type = image_data.mime_type

        return {"source": {"data": data, "media_type": media_type, "type": "base64"}, "type": "image"}

    def _construct_text_message(self, query: str) -> dict:
        """Wrap ``query`` as a ``text`` content block."""
        return {"text": query, "type": "text"}

    def _construct_messages(self, content: list) -> list:
        """Return a single-message conversation with ``content`` sent as the user."""
        return [{"content": content, "role": "user"}]

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

client: Any = field(default=Factory(lambda self: import_optional_dependency('anthropic').Anthropic(api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/anthropic_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    """Ask the model ``query`` about ``images`` and return the first content block's text.

    Raises:
        TypeError: If ``max_tokens`` is ``None``; it is always included in the request.
        ValueError: If the response contains no content blocks.
    """
    if self.max_tokens is None:
        # The message names the attribute actually checked (previously "max_output_tokens").
        raise TypeError("max_tokens can't be empty")

    response = self.client.messages.create(**self._base_params(query, images))
    content_blocks = response.content

    if len(content_blocks) < 1:
        raise ValueError("Response content is empty")

    return TextArtifact(content_blocks[0].text)

AnthropicPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_key Optional[str]

Anthropic API key.

model str

Anthropic model name.

client Any

Custom Anthropic client.

tokenizer AnthropicTokenizer

Custom AnthropicTokenizer.

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
@define
class AnthropicPromptDriver(BasePromptDriver):
    """Prompt driver backed by the Anthropic Messages API.

    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
        tokenizer: Custom `AnthropicTokenizer`.
        top_p: Sent as ``top_p`` with every request.
        top_k: Sent as ``top_k`` with every request.
    """

    api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Any = field(
        default=Factory(
            lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )
    tokenizer: AnthropicTokenizer = field(
        default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
    top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Send one request and wrap the text of the first content block."""
        request = self._base_params(prompt_stack)
        message = self.client.messages.create(**request)
        return TextArtifact(value=message.content[0].text)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Stream the reply, yielding one artifact per ``content_block_delta`` event."""
        events = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

        for event in events:
            if event.type != "content_block_delta":
                continue
            yield TextArtifact(value=event.delta.text)

    def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        """Split the stack into non-system ``messages`` and an optional ``system`` prompt.

        Only the first system input is used as ``system``.
        """
        conversation_inputs = [entry for entry in prompt_stack.inputs if not entry.is_system()]
        messages = [
            {"role": self.__to_anthropic_role(entry), "content": entry.content} for entry in conversation_inputs
        ]
        system_input = next(filter(lambda entry: entry.is_system(), prompt_stack.inputs), None)

        model_input = {"messages": messages}
        if system_input is not None:
            model_input["system"] = system_input.content
        return model_input

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Assemble the keyword arguments for ``messages.create``."""
        params = {
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_output_tokens(self.prompt_stack_to_string(prompt_stack)),
            "top_p": self.top_p,
            "top_k": self.top_k,
        }
        params.update(self._prompt_stack_to_model_input(prompt_stack))
        return params

    def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
        """Map a prompt input to "system", "assistant" or (otherwise) "user"."""
        if prompt_input.is_system():
            return "system"
        return "assistant" if prompt_input.is_assistant() else "user"

api_key: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

client: Any = field(default=Factory(lambda self: import_optional_dependency('anthropic').Anthropic(api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: AnthropicTokenizer = field(default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

top_k: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_p: float = field(default=0.999, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_anthropic_role(prompt_input)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
    """Map a prompt input to the Anthropic role "system", "assistant" or (otherwise) "user"."""
    if prompt_input.is_system():
        return "system"
    return "assistant" if prompt_input.is_assistant() else "user"

try_run(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    response = self.client.messages.create(**self._base_params(prompt_stack))

    return TextArtifact(value=response.content[0].text)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Stream the reply, yielding one artifact per ``content_block_delta`` event."""
    events = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

    for event in events:
        if event.type != "content_block_delta":
            continue
        yield TextArtifact(value=event.delta.text)

AwsIotCoreEventListenerDriver

Bases: BaseEventListenerDriver

Source code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
@define
class AwsIotCoreEventListenerDriver(BaseEventListenerDriver):
    """Event listener driver that publishes each event payload as JSON to an AWS IoT Core topic.

    Attributes:
        iot_endpoint: IoT endpoint setting.
        topic: Topic the payloads are published to.
        session: boto3 session used to create the default ``iot-data`` client.
        iotdata_client: ``iot-data`` client; defaults to one created from ``session``.
    """

    # NOTE(review): `iot_endpoint` is declared but never read in this class; the default client
    # below is created without any endpoint override. Confirm whether that is intended.
    iot_endpoint: str = field(kw_only=True)
    topic: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    iotdata_client: Any = field(default=Factory(lambda self: self.session.client("iot-data"), takes_self=True))

    def try_publish_event_payload(self, event_payload: dict) -> None:
        """Serialize ``event_payload`` to JSON and publish it on ``topic``."""
        serialized_payload = json.dumps(event_payload)
        self.iotdata_client.publish(topic=self.topic, payload=serialized_payload)

iot_endpoint: str = field(kw_only=True) class-attribute instance-attribute

iotdata_client: Any = field(default=Factory(lambda self: self.session.client('iot-data'), takes_self=True)) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

topic: str = field(kw_only=True) class-attribute instance-attribute

try_publish_event_payload(event_payload)

Source code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    """Publish `event_payload`, encoded with `json.dumps`, to the driver's topic."""
    serialized = json.dumps(event_payload)
    self.iotdata_client.publish(topic=self.topic, payload=serialized)

AzureMongoDbVectorStoreDriver

Bases: MongoDbAtlasVectorStoreDriver

A Vector Store Driver for CosmosDB with MongoDB vCore API.

Source code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
@define
class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver):
    """A Vector Store Driver for CosmosDB with MongoDB vCore API."""

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        offset: Optional[int] = None,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Queries the MongoDB collection for documents that match the provided query string.

        The query string is embedded with `embedding_driver` and matched against `vector_path`
        using the `cosmosSearch` operator.

        Args:
            query: Text to search for.
            count: Maximum number of results to return. Falsy values fall back to
                `BaseVectorStoreDriver.DEFAULT_QUERY_COUNT`.
            namespace: When truthy, only documents whose `namespace` field equals it are returned.
            include_vectors: Whether to copy each document's stored vector into its result.
            offset: Number of matching results to skip before collecting up to `count` results.
                Falsy values mean 0.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            A list of `QueryResult`s in the order produced by the aggregation pipeline.
        """
        collection = self.get_collection()

        # Using the embedding driver to convert the query string into a vector
        vector = self.embedding_driver.embed_string(query)

        count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        offset = offset if offset else 0

        pipeline: list[dict] = [
            {
                "$search": {
                    "cosmosSearch": {
                        "vector": vector,
                        "path": self.vector_path,
                        # Over-fetch neighbours so the namespace filter and `offset` below can still
                        # leave up to `count` results.
                        "k": min((offset + count) * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                    },
                    "returnStoredSource": True,
                }
            }
        ]

        if namespace:
            pipeline.append({"$match": {"namespace": namespace}})

        pipeline.append({"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}})

        # Page over the matching documents so that `offset` and `count` are honoured.
        if offset > 0:
            pipeline.append({"$skip": offset})
        pipeline.append({"$limit": count})

        return [
            BaseVectorStoreDriver.QueryResult(
                id=str(doc["_id"]),
                # `$project` nests the stored document under "document", so the vector lives there.
                vector=doc["document"][self.vector_path] if include_vectors else [],
                score=doc["similarityScore"],
                meta=doc["document"]["meta"],
                namespace=namespace,
            )
            for doc in collection.aggregate(pipeline)
        ]

query(query, count=None, namespace=None, include_vectors=False, offset=None, **kwargs)

Queries the MongoDB collection for documents that match the provided query string.

Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.

Source code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    offset: Optional[int] = None,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Queries the MongoDB collection for documents that match the provided query string.

    The query string is embedded with `embedding_driver` and matched against `vector_path`
    using the `cosmosSearch` operator.

    Args:
        query: Text to search for.
        count: Maximum number of results to return. Falsy values fall back to
            `BaseVectorStoreDriver.DEFAULT_QUERY_COUNT`.
        namespace: When truthy, only documents whose `namespace` field equals it are returned.
        include_vectors: Whether to copy each document's stored vector into its result.
        offset: Number of matching results to skip before collecting up to `count` results.
            Falsy values mean 0.
        **kwargs: Accepted for signature compatibility; not used.

    Returns:
        A list of `QueryResult`s in the order produced by the aggregation pipeline.
    """
    collection = self.get_collection()

    # Using the embedding driver to convert the query string into a vector
    vector = self.embedding_driver.embed_string(query)

    count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    offset = offset if offset else 0

    pipeline: list[dict] = [
        {
            "$search": {
                "cosmosSearch": {
                    "vector": vector,
                    "path": self.vector_path,
                    # Over-fetch neighbours so the namespace filter and `offset` below can still
                    # leave up to `count` results.
                    "k": min((offset + count) * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                },
                "returnStoredSource": True,
            }
        }
    ]

    if namespace:
        pipeline.append({"$match": {"namespace": namespace}})

    pipeline.append({"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}})

    # Page over the matching documents so that `offset` and `count` are honoured.
    if offset > 0:
        pipeline.append({"$skip": offset})
    pipeline.append({"$limit": count})

    return [
        BaseVectorStoreDriver.QueryResult(
            id=str(doc["_id"]),
            # `$project` nests the stored document under "document", so the vector lives there.
            vector=doc["document"][self.vector_path] if include_vectors else [],
            score=doc["similarityScore"],
            meta=doc["document"]["meta"],
            namespace=namespace,
        )
        for doc in collection.aggregate(pipeline)
    ]

AzureOpenAiChatPromptDriver

Bases: OpenAiChatPromptDriver

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `azure_deployment` | `str` | An Azure OpenAI deployment ID. |
| `azure_endpoint` | `str` | An Azure OpenAI endpoint. |
| `azure_ad_token` | `Optional[str]` | An optional Azure Active Directory token. |
| `azure_ad_token_provider` | `Optional[Callable[[], str]]` | An optional Azure Active Directory token provider. |
| `api_version` | `str` | An Azure OpenAI API version. |
| `client` | `AzureOpenAI` | An `openai.AzureOpenAI` client. |

Source code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@define
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token. Excluded from serialization.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    # The AD token is a credential: keep it out of serialized output, like the token provider
    # below and like `AzureOpenAiEmbeddingDriver.azure_ad_token`.
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Return the parent driver's request parameters without `seed`."""
        params = super()._base_params(prompt_stack)
        # TODO: Add `seed` parameter once Azure supports it.
        # `pop` with a default tolerates a parent implementation that omits `seed`.
        params.pop("seed", None)

        return params

api_version: str = field(default='2023-05-15', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

AzureOpenAiCompletionPromptDriver

Bases: OpenAiCompletionPromptDriver

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `azure_deployment` | `str` | An Azure OpenAI deployment ID. |
| `azure_endpoint` | `str` | An Azure OpenAI endpoint. |
| `azure_ad_token` | `Optional[str]` | An optional Azure Active Directory token. |
| `azure_ad_token_provider` | `Optional[Callable[[], str]]` | An optional Azure Active Directory token provider. |
| `api_version` | `str` | An Azure OpenAI API version. |
| `client` | `AzureOpenAI` | An `openai.AzureOpenAI` client. |

Source code in griptape/drivers/prompt/azure_openai_completion_prompt_driver.py
@define
class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token. Excluded from serialization.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    # The AD token is a credential: keep it out of serialized output, like the token provider
    # below and like `AzureOpenAiEmbeddingDriver.azure_ad_token`.
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2023-05-15', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

AzureOpenAiEmbeddingDriver

Bases: OpenAiEmbeddingDriver

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `azure_deployment` | `str` | An Azure OpenAI deployment ID. |
| `azure_endpoint` | `str` | An Azure OpenAI endpoint. |
| `azure_ad_token` | `Optional[str]` | An optional Azure Active Directory token. |
| `azure_ad_token_provider` | `Optional[Callable[[], str]]` | An optional Azure Active Directory token provider. |
| `api_version` | `str` | An Azure OpenAI API version. |
| `tokenizer` | `OpenAiTokenizer` | An `OpenAiTokenizer`. |
| `client` | `AzureOpenAI` | An `openai.AzureOpenAI` client. |

Source code in griptape/drivers/embedding/azure_openai_embedding_driver.py
@define
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
    """OpenAI embedding driver that talks to an Azure OpenAI deployment.

    Attributes:
        azure_deployment: Azure OpenAI deployment id.
        azure_endpoint: Azure OpenAI endpoint.
        azure_ad_token: Optional Azure Active Directory token; not serialized.
        azure_ad_token_provider: Optional callable returning an Azure Active Directory token; not serialized.
        api_version: Azure OpenAI API version, "2023-05-15" unless overridden.
        tokenizer: `OpenAiTokenizer` for `model`, created automatically unless one is passed in.
        client: `openai.AzureOpenAI` client assembled from the fields above unless one is passed in.
    """

    azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    # Both token fields are credentials and are kept out of serialized output.
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True,
        default=None,
        metadata={"serializable": False},
    )
    api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        kw_only=True,
        default=Factory(lambda driver: OpenAiTokenizer(model=driver.model), takes_self=True),
    )
    # Built after the fields above, so the factory can read them from the instance.
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda driver: openai.AzureOpenAI(
                organization=driver.organization,
                api_key=driver.api_key,
                api_version=driver.api_version,
                azure_endpoint=driver.azure_endpoint,
                azure_deployment=driver.azure_deployment,
                azure_ad_token=driver.azure_ad_token,
                azure_ad_token_provider=driver.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2023-05-15', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

AzureOpenAiImageGenerationDriver

Bases: OpenAiImageGenerationDriver

Driver for Azure-hosted OpenAI image generation API.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `azure_deployment` | `str` | An Azure OpenAI deployment ID. |
| `azure_endpoint` | `str` | An Azure OpenAI endpoint. |
| `azure_ad_token` | `Optional[str]` | An optional Azure Active Directory token. |
| `azure_ad_token_provider` | `Optional[Callable[[], str]]` | An optional Azure Active Directory token provider. |
| `api_version` | `str` | An Azure OpenAI API version. |
| `client` | `AzureOpenAI` | An `openai.AzureOpenAI` client. |

Source code in griptape/drivers/image_generation/azure_openai_image_generation_driver.py
@define
class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
    """Driver for Azure-hosted OpenAI image generation API.

    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token. Excluded from serialization.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    # The AD token is a credential: keep it out of serialized output, like the token provider
    # below and like `AzureOpenAiEmbeddingDriver.azure_ad_token`.
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True})
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2024-02-01', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

BaseConversationMemoryDriver

Bases: SerializableMixin, ABC

Source code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
class BaseConversationMemoryDriver(SerializableMixin, ABC):
    """Abstract interface for saving and restoring `BaseConversationMemory` objects.

    Concrete subclasses implement `store` and `load`.
    """

    @abstractmethod
    def store(self, memory: BaseConversationMemory) -> None:
        """Persist `memory`; implemented by subclasses."""

    @abstractmethod
    def load(self) -> Optional[BaseConversationMemory]:
        """Load previously stored memory; the Optional return type allows returning None."""

load() abstractmethod

Source code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def load(self) -> Optional[BaseConversationMemory]:
    """Load previously stored conversation memory.

    Implemented by subclasses; the Optional return type allows returning None.
    """

store(memory) abstractmethod

Source code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def store(self, memory: BaseConversationMemory) -> None:
    """Persist `memory`; implemented by subclasses."""

BaseEmbeddingDriver

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `str` | The name of the model to use. |
| `tokenizer` | `Optional[BaseTokenizer]` | An instance of `BaseTokenizer` to use when calculating tokens. |

Source code in griptape/drivers/embedding/base_embedding_driver.py
@define
class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base class for drivers that turn text into embedding vectors.

    Attributes:
        model: The name of the model to use.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens. When set, strings
            whose token count exceeds `tokenizer.max_input_tokens` are chunked and embedded piecewise.
        chunker: A `TextChunker` built from `tokenizer` after initialization; left unassigned when
            no tokenizer is configured.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    chunker: BaseChunker = field(init=False)

    def __attrs_post_init__(self) -> None:
        # The chunker is built from the tokenizer, so it is only created when one is configured.
        if not self.tokenizer:
            return
        self.chunker = TextChunker(tokenizer=self.tokenizer)

    def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
        """Return the embedding of `artifact.to_text()`."""
        text = artifact.to_text()
        return self.embed_string(text)

    def embed_string(self, string: str) -> list[float]:
        """Embed `string`, running each attempt inside the mixin's `retrying()` loop.

        Strings longer than the tokenizer's `max_input_tokens` go through `_embed_long_string`;
        all others are passed directly to `try_embed_chunk`.

        Raises:
            RuntimeError: If the retry loop finishes without returning an embedding.
        """
        for attempt in self.retrying():
            with attempt:
                needs_chunking = (
                    self.tokenizer and self.tokenizer.count_tokens(string) > self.tokenizer.max_input_tokens
                )
                if needs_chunking:
                    return self._embed_long_string(string)
                return self.try_embed_chunk(string)

        raise RuntimeError("Failed to embed string.")

    @abstractmethod
    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single piece of text with the underlying model; implemented by subclasses."""

    def _embed_long_string(self, string: str) -> list[float]:
        """Embeds a string that is too long to embed in one go.

        Each chunk from `self.chunker` is embedded separately. The chunk vectors are averaged,
        weighted by `len(chunk)`, and the average is scaled to unit length.

        Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
        """
        embedded = [(self.try_embed_chunk(piece.value), len(piece)) for piece in self.chunker.chunk(string)]
        vectors = [vector for vector, _ in embedded]
        weights = [weight for _, weight in embedded]

        averaged = np.average(vectors, axis=0, weights=weights)
        unit_vector = averaged / np.linalg.norm(averaged)

        return unit_vector.tolist()

chunker: BaseChunker = field(init=False) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/drivers/embedding/base_embedding_driver.py
def __attrs_post_init__(self) -> None:
    """Create `chunker` from `tokenizer` once attrs has initialised the instance.

    When no tokenizer is set, `chunker` is left unassigned.
    """
    if not self.tokenizer:
        return
    self.chunker = TextChunker(tokenizer=self.tokenizer)

embed_string(string)

Source code in griptape/drivers/embedding/base_embedding_driver.py
def embed_string(self, string: str) -> list[float]:
    """Embed `string`, running each attempt inside the `retrying()` loop.

    Strings longer than the tokenizer's `max_input_tokens` go through `_embed_long_string`;
    all others are passed directly to `try_embed_chunk`.

    Raises:
        RuntimeError: If the retry loop finishes without returning an embedding.
    """
    for attempt in self.retrying():
        with attempt:
            needs_chunking = (
                self.tokenizer and self.tokenizer.count_tokens(string) > self.tokenizer.max_input_tokens
            )
            if needs_chunking:
                return self._embed_long_string(string)
            return self.try_embed_chunk(string)

    raise RuntimeError("Failed to embed string.")

embed_text_artifact(artifact)

Source code in griptape/drivers/embedding/base_embedding_driver.py
def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
    """Return the embedding of `artifact.to_text()`, as produced by `embed_string`."""
    text = artifact.to_text()
    return self.embed_string(text)

try_embed_chunk(chunk) abstractmethod

Source code in griptape/drivers/embedding/base_embedding_driver.py
@abstractmethod
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single piece of text with the underlying model; implemented by subclasses."""

BaseEmbeddingModelDriver

Bases: ABC

Source code in griptape/drivers/embedding_model/base_embedding_model_driver.py
@define
class BaseEmbeddingModelDriver(ABC):
    @abstractmethod
    def chunk_to_model_params(self, chunk: str) -> dict:
        ...

    @abstractmethod
    def process_output(self, output