
Drivers

__all__ = ['BasePromptDriver', 'OpenAiChatPromptDriver', 'OpenAiCompletionPromptDriver', 'AzureOpenAiChatPromptDriver', 'AzureOpenAiCompletionPromptDriver', 'CoherePromptDriver', 'HuggingFacePipelinePromptDriver', 'HuggingFaceHubPromptDriver', 'AnthropicPromptDriver', 'AmazonSageMakerPromptDriver', 'AmazonBedrockPromptDriver', 'GooglePromptDriver', 'BaseMultiModelPromptDriver', 'DummyPromptDriver', 'BaseConversationMemoryDriver', 'LocalConversationMemoryDriver', 'AmazonDynamoDbConversationMemoryDriver', 'BaseEmbeddingDriver', 'OpenAiEmbeddingDriver', 'AzureOpenAiEmbeddingDriver', 'BaseMultiModelEmbeddingDriver', 'AmazonSageMakerEmbeddingDriver', 'AmazonBedrockTitanEmbeddingDriver', 'AmazonBedrockCohereEmbeddingDriver', 'VoyageAiEmbeddingDriver', 'HuggingFaceHubEmbeddingDriver', 'GoogleEmbeddingDriver', 'DummyEmbeddingDriver', 'BaseEmbeddingModelDriver', 'SageMakerHuggingFaceEmbeddingModelDriver', 'SageMakerTensorFlowHubEmbeddingModelDriver', 'BaseVectorStoreDriver', 'LocalVectorStoreDriver', 'PineconeVectorStoreDriver', 'MarqoVectorStoreDriver', 'MongoDbAtlasVectorStoreDriver', 'AzureMongoDbVectorStoreDriver', 'RedisVectorStoreDriver', 'OpenSearchVectorStoreDriver', 'AmazonOpenSearchVectorStoreDriver', 'PgVectorVectorStoreDriver', 'DummyVectorStoreDriver', 'BaseSqlDriver', 'AmazonRedshiftSqlDriver', 'SnowflakeSqlDriver', 'SqlDriver', 'BasePromptModelDriver', 'SageMakerLlamaPromptModelDriver', 'SageMakerFalconPromptModelDriver', 'BedrockTitanPromptModelDriver', 'BedrockClaudePromptModelDriver', 'BedrockJurassicPromptModelDriver', 'BedrockLlamaPromptModelDriver', 'BaseImageGenerationModelDriver', 'BedrockStableDiffusionImageGenerationModelDriver', 'BedrockTitanImageGenerationModelDriver', 'BaseImageGenerationDriver', 'BaseMultiModelImageGenerationDriver', 'OpenAiImageGenerationDriver', 'LeonardoImageGenerationDriver', 'AmazonBedrockImageGenerationDriver', 'AzureOpenAiImageGenerationDriver', 'DummyImageGenerationDriver', 'BaseImageQueryModelDriver', 'BedrockClaudeImageQueryModelDriver', 'BaseImageQueryDriver', 'OpenAiVisionImageQueryDriver', 'DummyImageQueryDriver', 'AnthropicImageQueryDriver', 'BaseMultiModelImageQueryDriver', 'AmazonBedrockImageQueryDriver', 'BaseWebScraperDriver', 'TrafilaturaWebScraperDriver', 'MarkdownifyWebScraperDriver', 'BaseEventListenerDriver', 'AmazonSqsEventListenerDriver', 'WebhookEventListenerDriver', 'AwsIotCoreEventListenerDriver', 'GriptapeCloudEventListenerDriver', 'BaseFileManagerDriver', 'LocalFileManagerDriver', 'AmazonS3FileManagerDriver', 'BaseStructureRunDriver', 'GriptapeCloudStructureRunDriver', 'LocalStructureRunDriver'] module-attribute

AmazonBedrockCohereEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

    model (str): Embedding model name. Defaults to DEFAULT_MODEL.
    input_type (str): Defaults to search_query. Prepends special tokens to differentiate each type:
        use search_document when encoding documents to store in a vector database, and search_query
        when querying your vector database for relevant documents.
    session (Session): Optionally provide a custom boto3.Session.
    tokenizer (BedrockCohereTokenizer): Optionally provide a custom BedrockCohereTokenizer.
    bedrock_client (Any): Optionally provide a custom bedrock-runtime client.

Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@define
class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
    """
    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        input_type: Defaults to `search_query`. Prepends special tokens to differentiate each type from one another:
            `search_document` when you encode documents for embeddings that you store in a vector database.
            `search_query` when querying your vector DB to find relevant documents.
        session: Optionally provide custom `boto3.Session`.
        tokenizer: Optionally provide custom `BedrockCohereTokenizer`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "cohere.embed-english-v3"

    model: str = field(default=DEFAULT_MODEL, kw_only=True)
    input_type: str = field(default="search_query", kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BedrockCohereTokenizer = field(
        default=Factory(lambda self: BedrockCohereTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        payload = {"input_type": self.input_type, "texts": [chunk]}

        response = self.bedrock_client.invoke_model(
            body=json.dumps(payload), modelId=self.model, accept="*/*", contentType="application/json"
        )
        response_body = json.loads(response.get("body").read())

        return response_body.get("embeddings")[0]

DEFAULT_MODEL = 'cohere.embed-english-v3' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

input_type: str = field(default='search_query', kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BedrockCohereTokenizer = field(default=Factory(lambda self: BedrockCohereTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    payload = {"input_type": self.input_type, "texts": [chunk]}

    response = self.bedrock_client.invoke_model(
        body=json.dumps(payload), modelId=self.model, accept="*/*", contentType="application/json"
    )
    response_body = json.loads(response.get("body").read())

    return response_body.get("embeddings")[0]
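
A minimal usage sketch, assuming AWS credentials are configured and Bedrock access to the Cohere embedding model is enabled; the region name is illustrative:

import boto3

from griptape.drivers import AmazonBedrockCohereEmbeddingDriver

# Use input_type="search_document" when indexing documents and the default
# "search_query" when embedding a query against an existing index.
driver = AmazonBedrockCohereEmbeddingDriver(
    input_type="search_document",
    session=boto3.Session(region_name="us-west-2"),  # optional custom session
)

vector = driver.try_embed_chunk("Griptape is a framework for building LLM applications.")
print(len(vector))  # dimensionality of the returned embedding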

AmazonBedrockImageGenerationDriver

Bases: BaseMultiModelImageGenerationDriver

Driver for image generation models provided by Amazon Bedrock.

Attributes:

    model: Bedrock model ID.
    session (Session): boto3 session.
    bedrock_client (Any): Bedrock runtime client.
    image_width (int): Width of output images. Defaults to 512 and must be a multiple of 64.
    image_height (int): Height of output images. Defaults to 512 and must be a multiple of 64.
    seed (Optional[int]): Optionally provide a consistent seed to generation requests, increasing consistency in output.

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
@define
class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
    """Driver for image generation models provided by Amazon Bedrock.

    Attributes:
        model: Bedrock model ID.
        session: boto3 session.
        bedrock_client: Bedrock runtime client.
        image_width: Width of output images. Defaults to 512 and must be a multiple of 64.
        image_height: Height of output images. Defaults to 512 and must be a multiple of 64.
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client(service_name="bedrock-runtime"), takes_self=True)
    )
    image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
    image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        request = self.image_generation_model_driver.text_to_image_request_parameters(
            prompts, self.image_width, self.image_height, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=self.image_width,
            height=self.image_height,
            model=self.model,
        )

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        request = self.image_generation_model_driver.image_variation_request_parameters(
            prompts, image=image, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        request = self.image_generation_model_driver.image_inpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        request = self.image_generation_model_driver.image_outpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def _make_request(self, request: dict) -> bytes:
        response = self.bedrock_client.invoke_model(
            body=json.dumps(request), modelId=self.model, accept="application/json", contentType="application/json"
        )

        response_body = json.loads(response.get("body").read())

        try:
            image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
        except Exception as e:
            raise ValueError(f"Inpainting generation failed: {e}")

        return image_bytes

bedrock_client: Any = field(default=Factory(lambda self: self.session.client(service_name='bedrock-runtime'), takes_self=True)) class-attribute instance-attribute

image_height: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

image_width: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

seed: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    request = self.image_generation_model_driver.image_inpainting_request_parameters(
        prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        format="png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    request = self.image_generation_model_driver.image_outpainting_request_parameters(
        prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        format="png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    request = self.image_generation_model_driver.image_variation_request_parameters(
        prompts, image=image, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        format="png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    request = self.image_generation_model_driver.text_to_image_request_parameters(
        prompts, self.image_width, self.image_height, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        format="png",
        width=self.image_width,
        height=self.image_height,
        model=self.model,
    )
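
A minimal text-to-image sketch, assuming Bedrock access to a Stable Diffusion model; the model ID and output file name are illustrative:

from griptape.drivers import (
    AmazonBedrockImageGenerationDriver,
    BedrockStableDiffusionImageGenerationModelDriver,
)

driver = AmazonBedrockImageGenerationDriver(
    model="stability.stable-diffusion-xl-v1",  # illustrative Bedrock model ID
    image_generation_model_driver=BedrockStableDiffusionImageGenerationModelDriver(),
    image_width=512,   # must be a multiple of 64
    image_height=512,  # must be a multiple of 64
    seed=42,           # optional: a fixed seed makes output more repeatable
)

image = driver.try_text_to_image(prompts=["a lighthouse at dawn, oil painting"])

with open("lighthouse.png", "wb") as f:
    f.write(image.value)  # ImageArtifact.value holds the raw PNG bytes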

AmazonBedrockImageQueryDriver

Bases: BaseMultiModelImageQueryDriver

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
@define
class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens)

        response = self.bedrock_client.invoke_model(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = json.loads(response.get("body").read())

        if response_body is None:
            raise ValueError("Model response is empty")

        try:
            return self.image_query_model_driver.process_output(response_body)
        except Exception as e:
            raise ValueError(f"Output is unable to be processed as returned {e}")

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens)

    response = self.bedrock_client.invoke_model(
        modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
    )

    response_body = json.loads(response.get("body").read())

    if response_body is None:
        raise ValueError("Model response is empty")

    try:
        return self.image_query_model_driver.process_output(response_body)
    except Exception as e:
        raise ValueError(f"Output is unable to be processed as returned {e}")

AmazonBedrockPromptDriver

Bases: BaseMultiModelPromptDriver

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@define
class AmazonBedrockPromptDriver(BaseMultiModelPromptDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
        payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
        if isinstance(model_input, dict):
            payload.update(model_input)

        response = self.bedrock_client.invoke_model(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = response["body"].read()

        if response_body:
            return self.prompt_model_driver.process_output(response_body)
        else:
            raise Exception("model response is empty")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
        payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
        if isinstance(model_input, dict):
            payload.update(model_input)

        response = self.bedrock_client.invoke_model_with_response_stream(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = response["body"]
        if response_body:
            for chunk in response_body:
                chunk_bytes = chunk["chunk"]["bytes"]
                yield self.prompt_model_driver.process_output(chunk_bytes)
        else:
            raise Exception("model response is empty")

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
    payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
    if isinstance(model_input, dict):
        payload.update(model_input)

    response = self.bedrock_client.invoke_model(
        modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
    )

    response_body = response["body"].read()

    if response_body:
        return self.prompt_model_driver.process_output(response_body)
    else:
        raise Exception("model response is empty")

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
    payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
    if isinstance(model_input, dict):
        payload.update(model_input)

    response = self.bedrock_client.invoke_model_with_response_stream(
        modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
    )

    response_body = response["body"]
    if response_body:
        for chunk in response_body:
            chunk_bytes = chunk["chunk"]["bytes"]
            yield self.prompt_model_driver.process_output(chunk_bytes)
    else:
        raise Exception("model response is empty")
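
A minimal usage sketch, assuming PromptStack is importable from griptape.utils; the Bedrock model ID is illustrative:

from griptape.drivers import AmazonBedrockPromptDriver, BedrockClaudePromptModelDriver
from griptape.utils import PromptStack

driver = AmazonBedrockPromptDriver(
    model="anthropic.claude-v2",  # illustrative Bedrock model ID
    prompt_model_driver=BedrockClaudePromptModelDriver(),
)

stack = PromptStack()
stack.add_user_input("Summarize the benefits of vector databases in one sentence.")

print(driver.try_run(stack).value)  # try_stream(stack) yields TextArtifact chunks instead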

AmazonBedrockTitanEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

    model (str): Embedding model name. Defaults to DEFAULT_MODEL.
    tokenizer (BedrockTitanTokenizer): Optionally provide a custom BedrockTitanTokenizer.
    session (Session): Optionally provide a custom boto3.Session.
    bedrock_client (Any): Optionally provide a custom bedrock-runtime client.

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@define
class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
    """
    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
        session: Optionally provide custom `boto3.Session`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "amazon.titan-embed-text-v1"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BedrockTitanTokenizer = field(
        default=Factory(lambda self: BedrockTitanTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        payload = {"inputText": chunk}

        response = self.bedrock_client.invoke_model(
            body=json.dumps(payload), modelId=self.model, accept="application/json", contentType="application/json"
        )
        response_body = json.loads(response.get("body").read())

        return response_body.get("embedding")

DEFAULT_MODEL = 'amazon.titan-embed-text-v1' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda self: self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BedrockTitanTokenizer = field(default=Factory(lambda self: BedrockTitanTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    payload = {"inputText": chunk}

    response = self.bedrock_client.invoke_model(
        body=json.dumps(payload), modelId=self.model, accept="application/json", contentType="application/json"
    )
    response_body = json.loads(response.get("body").read())

    return response_body.get("embedding")
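
A minimal usage sketch; with no arguments the driver falls back to DEFAULT_MODEL:

from griptape.drivers import AmazonBedrockTitanEmbeddingDriver

driver = AmazonBedrockTitanEmbeddingDriver()  # uses amazon.titan-embed-text-v1

vector = driver.try_embed_chunk("hello world")
print(len(vector))  # 1536 dimensions for amazon.titan-embed-text-v1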

AmazonDynamoDbConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
@define
class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    partition_key: str = field(kw_only=True, metadata={"serializable": True})
    value_attribute_key: str = field(kw_only=True, metadata={"serializable": True})
    partition_key_value: str = field(kw_only=True, metadata={"serializable": True})

    table: Any = field(init=False)

    def __attrs_post_init__(self) -> None:
        dynamodb = self.session.resource("dynamodb")

        self.table = dynamodb.Table(self.table_name)

    def store(self, memory: BaseConversationMemory) -> None:
        self.table.update_item(
            Key={self.partition_key: self.partition_key_value},
            UpdateExpression="set #attr = :value",
            ExpressionAttributeNames={"#attr": self.value_attribute_key},
            ExpressionAttributeValues={":value": memory.to_json()},
        )

    def load(self) -> Optional[BaseConversationMemory]:
        response = self.table.get_item(Key={self.partition_key: self.partition_key_value})

        if "Item" in response and self.value_attribute_key in response["Item"]:
            memory_value = response["Item"][self.value_attribute_key]

            memory = BaseConversationMemory.from_json(memory_value)

            memory.driver = self

            return memory
        else:
            return None

partition_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

partition_key_value: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

table: Any = field(init=False) class-attribute instance-attribute

table_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

value_attribute_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def __attrs_post_init__(self) -> None:
    dynamodb = self.session.resource("dynamodb")

    self.table = dynamodb.Table(self.table_name)

load()

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def load(self) -> Optional[BaseConversationMemory]:
    response = self.table.get_item(Key={self.partition_key: self.partition_key_value})

    if "Item" in response and self.value_attribute_key in response["Item"]:
        memory_value = response["Item"][self.value_attribute_key]

        memory = BaseConversationMemory.from_json(memory_value)

        memory.driver = self

        return memory
    else:
        return None

store(memory)

Source code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def store(self, memory: BaseConversationMemory) -> None:
    self.table.update_item(
        Key={self.partition_key: self.partition_key_value},
        UpdateExpression="set #attr = :value",
        ExpressionAttributeNames={"#attr": self.value_attribute_key},
        ExpressionAttributeValues={":value": memory.to_json()},
    )
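
A minimal usage sketch; the table and key names are illustrative, the table is assumed to already exist with "id" as its partition key, and ConversationMemory is assumed to accept the driver keyword that load() sets above:

from griptape.drivers import AmazonDynamoDbConversationMemoryDriver
from griptape.memory.structure import ConversationMemory

memory_driver = AmazonDynamoDbConversationMemoryDriver(
    table_name="griptape_memory",       # illustrative table name
    partition_key="id",
    partition_key_value="session-123",  # one item per conversation/session
    value_attribute_key="memory",
)

memory = memory_driver.load()  # returns None when nothing is stored yet
if memory is None:
    memory = ConversationMemory(driver=memory_driver)  # assumed constructor kwarg

memory_driver.store(memory)  # serializes memory.to_json() into the value attribute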

AmazonOpenSearchVectorStoreDriver

Bases: OpenSearchVectorStoreDriver

A Vector Store Driver for Amazon OpenSearch.

Attributes:

    session (Session): The boto3 session to use.
    service (str): Service name for AWS Signature v4: 'es' for OpenSearch, or 'aoss' for OpenSearch Serverless. Defaults to 'es'.
    http_auth (str | tuple[str, str]): The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
    client (OpenSearch): An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.

Source code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
@define
class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver):
    """A Vector Store Driver for Amazon OpenSearch.

    Attributes:
        session: The boto3 session to use.
        service: Service name for AWS Signature v4: 'es' for OpenSearch, or 'aoss' for OpenSearch Serverless. Defaults to 'es'.
        http_auth: The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
        client: An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.
    """

    session: Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    service: str = field(default="es", kw_only=True)
    http_auth: str | tuple[str, str] = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").AWSV4SignerAuth(
                self.session.get_credentials(), self.session.region_name, self.service
            ),
            takes_self=True,
        )
    )

    client: OpenSearch = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").OpenSearch(
                hosts=[{"host": self.host, "port": self.port}],
                http_auth=self.http_auth,
                use_ssl=self.use_ssl,
                verify_certs=self.verify_certs,
                connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection,
            ),
            takes_self=True,
        )
    )

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in OpenSearch.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
        Metadata associated with the vector can also be provided.
        """

        vector_id = vector_id if vector_id else str_to_hash(str(vector))
        doc = {"vector": vector, "namespace": namespace, "metadata": meta}
        doc.update(kwargs)
        if self.service == "aoss":
            response = self.client.index(index=self.index_name, body=doc)
        else:
            response = self.client.index(index=self.index_name, id=vector_id, body=doc)

        return response["_id"]

client: OpenSearch = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').OpenSearch(hosts=[{'host': self.host, 'port': self.port}], http_auth=self.http_auth, use_ssl=self.use_ssl, verify_certs=self.verify_certs, connection_class=import_optional_dependency('opensearchpy').RequestsHttpConnection), takes_self=True)) class-attribute instance-attribute

http_auth: str | tuple[str, str] = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').AWSV4SignerAuth(self.session.get_credentials(), self.session.region_name, self.service), takes_self=True)) class-attribute instance-attribute

service: str = field(default='es', kw_only=True) class-attribute instance-attribute

session: Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in OpenSearch.

If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted. Metadata associated with the vector can also be provided.

Source code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in OpenSearch.

    If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
    Metadata associated with the vector can also be provided.
    """

    vector_id = vector_id if vector_id else str_to_hash(str(vector))
    doc = {"vector": vector, "namespace": namespace, "metadata": meta}
    doc.update(kwargs)
    if self.service == "aoss":
        response = self.client.index(index=self.index_name, body=doc)
    else:
        response = self.client.index(index=self.index_name, id=vector_id, body=doc)

    return response["_id"]
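
A minimal usage sketch; the endpoint and index names are illustrative, the index is assumed to already exist, and an embedding driver is assumed to be required by the vector store base class:

from griptape.drivers import AmazonBedrockTitanEmbeddingDriver, AmazonOpenSearchVectorStoreDriver

driver = AmazonOpenSearchVectorStoreDriver(
    host="search-my-domain.us-west-2.es.amazonaws.com",  # illustrative endpoint
    index_name="griptape-vectors",                       # assumed to exist
    embedding_driver=AmazonBedrockTitanEmbeddingDriver(),
)

# With service="aoss" (OpenSearch Serverless), the document id is generated server-side.
vector_id = driver.upsert_vector([0.1, 0.2, 0.3], namespace="docs", meta={"source": "readme"})
print(vector_id)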

AmazonRedshiftSqlDriver

Bases: BaseSqlDriver

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@define
class AmazonRedshiftSqlDriver(BaseSqlDriver):
    database: str = field(kw_only=True)
    session: boto3.Session = field(kw_only=True)
    cluster_identifier: Optional[str] = field(default=None, kw_only=True)
    workgroup_name: Optional[str] = field(default=None, kw_only=True)
    db_user: Optional[str] = field(default=None, kw_only=True)
    database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True)
    wait_for_query_completion_sec: float = field(default=0.3, kw_only=True)
    client: Any = field(
        default=Factory(lambda self: self.session.client("redshift-data"), takes_self=True), kw_only=True
    )

    @workgroup_name.validator  # pyright: ignore
    def validate_params(self, _, workgroup_name: Optional[str]) -> None:
        if not self.cluster_identifier and not self.workgroup_name:
            raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
        elif self.cluster_identifier and self.workgroup_name:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

    @classmethod
    def _process_rows_from_records(cls, records) -> list[list]:
        return [[c[list(c.keys())[0]] for c in r] for r in records]

    @classmethod
    def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
        return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

    @classmethod
    def _process_columns_from_column_metadata(cls, meta) -> list:
        return [k["name"] for k in meta]

    @classmethod
    def _post_process(cls, meta, records) -> list[dict[str, Any]]:
        columns = cls._process_columns_from_column_metadata(meta)
        rows = cls._process_rows_from_records(records)
        return cls._process_cells_from_rows_and_columns(columns, rows)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        rows = self.execute_query_raw(query)
        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        else:
            return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        function_kwargs = {"Sql": query, "Database": self.database}
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn

        response = self.client.execute_statement(**function_kwargs)
        response_id = response["Id"]

        statement = self.client.describe_statement(Id=response_id)

        while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
            time.sleep(self.wait_for_query_completion_sec)
            statement = self.client.describe_statement(Id=response_id)

        if statement["Status"] == "FINISHED":
            statement_result = self.client.get_statement_result(Id=response_id)
            results = statement_result.get("Records", [])

            while "NextToken" in statement_result:
                statement_result = self.client.get_statement_result(
                    Id=response_id, NextToken=statement_result["NextToken"]
                )
                results = results + statement_result.get("Records", [])

            return self._post_process(statement_result["ColumnMetadata"], results)

        elif statement["Status"] in ["FAILED", "ABORTED"]:
            return None

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        function_kwargs = {"Database": self.database, "Table": table_name}
        if schema:
            function_kwargs["Schema"] = schema
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn
        response = self.client.describe_table(**function_kwargs)
        return str([col["name"] for col in response["ColumnList"]])

client: Any = field(default=Factory(lambda self: self.session.client('redshift-data'), takes_self=True), kw_only=True) class-attribute instance-attribute

cluster_identifier: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

database: str = field(kw_only=True) class-attribute instance-attribute

database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

db_user: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(kw_only=True) class-attribute instance-attribute

wait_for_query_completion_sec: float = field(default=0.3, kw_only=True) class-attribute instance-attribute

workgroup_name: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

execute_query(query)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    rows = self.execute_query_raw(query)
    if rows:
        return [BaseSqlDriver.RowResult(row) for row in rows]
    else:
        return None

execute_query_raw(query)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    function_kwargs = {"Sql": query, "Database": self.database}
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.execute_statement(**function_kwargs)
    response_id = response["Id"]

    statement = self.client.describe_statement(Id=response_id)

    while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
        time.sleep(self.wait_for_query_completion_sec)
        statement = self.client.describe_statement(Id=response_id)

    if statement["Status"] == "FINISHED":
        statement_result = self.client.get_statement_result(Id=response_id)
        results = statement_result.get("Records", [])

        while "NextToken" in statement_result:
            statement_result = self.client.get_statement_result(
                Id=response_id, NextToken=statement_result["NextToken"]
            )
            results = results + statement_result.get("Records", [])

        return self._post_process(statement_result["ColumnMetadata"], results)

    elif statement["Status"] in ["FAILED", "ABORTED"]:
        return None

get_table_schema(table_name, schema=None)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    function_kwargs = {"Database": self.database, "Table": table_name}
    if schema:
        function_kwargs["Schema"] = schema
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn
    response = self.client.describe_table(**function_kwargs)
    return str([col["name"] for col in response["ColumnList"]])

validate_params(_, workgroup_name)

Source code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@workgroup_name.validator  # pyright: ignore
def validate_params(self, _, workgroup_name: Optional[str]) -> None:
    if not self.cluster_identifier and not self.workgroup_name:
        raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
    elif self.cluster_identifier and self.workgroup_name:
        raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

AmazonS3FileManagerDriver

Bases: BaseFileManagerDriver

AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

Attributes:

    session (Session): The boto3 session to use for S3 operations.
    bucket (str): The name of the S3 bucket.
    workdir (str): The absolute working directory (must start with "/"). List, load, and save operations will be performed relative to this directory.
    s3_client (Any): The S3 client to use for S3 operations.

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@define
class AmazonS3FileManagerDriver(BaseFileManagerDriver):
    """
    AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

    Attributes:
        session: The boto3 session to use for S3 operations.
        bucket: The name of the S3 bucket.
        workdir: The absolute working directory (must start with "/"). List, load, and save
            operations will be performed relative to this directory.
        s3_client: The S3 client to use for S3 operations.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bucket: str = field(kw_only=True)
    workdir: str = field(default="/", kw_only=True)
    s3_client: Any = field(default=Factory(lambda self: self.session.client("s3"), takes_self=True), kw_only=True)

    @workdir.validator  # pyright: ignore
    def validate_workdir(self, _, workdir: str) -> None:
        if not Path(workdir).is_absolute():
            raise ValueError("Workdir must be an absolute path")

    def try_list_files(self, path: str) -> list[str]:
        full_key = self._to_dir_full_key(path)
        files_and_dirs = self._list_files_and_dirs(full_key)
        if len(files_and_dirs) == 0:
            if len(self._list_files_and_dirs(full_key.rstrip("/"), max_items=1)) > 0:
                raise NotADirectoryError
            else:
                raise FileNotFoundError
        return files_and_dirs

    def try_load_file(self, path: str) -> bytes:
        botocore = import_optional_dependency("botocore")
        full_key = self._to_full_key(path)

        if self._is_a_directory(full_key):
            raise IsADirectoryError

        try:
            response = self.s3_client.get_object(Bucket=self.bucket, Key=full_key)
            return response["Body"].read()
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                raise FileNotFoundError
            else:
                raise e

    def try_save_file(self, path: str, value: bytes):
        full_key = self._to_full_key(path)
        if self._is_a_directory(full_key):
            raise IsADirectoryError
        self.s3_client.put_object(Bucket=self.bucket, Key=full_key, Body=value)

    def _to_full_key(self, path: str) -> str:
        path = path.lstrip("/")
        full_key = os.path.join(self.workdir, path)
        # Need to keep the trailing slash if it was there,
        # because it means the path is a directory.
        ended_with_slash = path.endswith("/")
        full_key = os.path.normpath(full_key)
        if ended_with_slash:
            full_key += "/"
        return full_key.lstrip("/")

    def _to_dir_full_key(self, path: str) -> str:
        full_key = self._to_full_key(path)
        # S3 "directories" always end with a slash, except for the root.
        if full_key != "" and not full_key.endswith("/"):
            full_key += "/"
        return full_key

    def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
        max_items = kwargs.get("max_items")
        pagination_config = {}
        if max_items is not None:
            pagination_config["MaxItems"] = max_items

        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(
            Bucket=self.bucket, Prefix=full_key, Delimiter="/", PaginationConfig=pagination_config
        )
        files_and_dirs = []
        for page in pages:
            for obj in page.get("CommonPrefixes", []):
                prefix = obj.get("Prefix")
                dir = prefix[len(full_key) :].rstrip("/")
                files_and_dirs.append(dir)

            for obj in page.get("Contents", []):
                key = obj.get("Key")
                file = key[len(full_key) :]
                files_and_dirs.append(file)
        return files_and_dirs

    def _is_a_directory(self, full_key: str) -> bool:
        botocore = import_optional_dependency("botocore")
        if full_key == "" or full_key.endswith("/"):
            return True

        try:
            self.s3_client.head_object(Bucket=self.bucket, Key=full_key)
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                return len(self._list_files_and_dirs(full_key, max_items=1)) > 0
            else:
                raise e

        return False

bucket: str = field(kw_only=True) class-attribute instance-attribute

s3_client: Any = field(default=Factory(lambda self: self.session.client('s3'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

workdir: str = field(default='/', kw_only=True) class-attribute instance-attribute

try_list_files(path)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_list_files(self, path: str) -> list[str]:
    full_key = self._to_dir_full_key(path)
    files_and_dirs = self._list_files_and_dirs(full_key)
    if len(files_and_dirs) == 0:
        if len(self._list_files_and_dirs(full_key.rstrip("/"), max_items=1)) > 0:
            raise NotADirectoryError
        else:
            raise FileNotFoundError
    return files_and_dirs

try_load_file(path)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_load_file(self, path: str) -> bytes:
    botocore = import_optional_dependency("botocore")
    full_key = self._to_full_key(path)

    if self._is_a_directory(full_key):
        raise IsADirectoryError

    try:
        response = self.s3_client.get_object(Bucket=self.bucket, Key=full_key)
        return response["Body"].read()
    except botocore.exceptions.ClientError as e:
        if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
            raise FileNotFoundError
        else:
            raise e

try_save_file(path, value)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_save_file(self, path: str, value: bytes):
    full_key = self._to_full_key(path)
    if self._is_a_directory(full_key):
        raise IsADirectoryError
    self.s3_client.put_object(Bucket=self.bucket, Key=full_key, Body=value)

validate_workdir(_, workdir)

Source code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@workdir.validator  # pyright: ignore
def validate_workdir(self, _, workdir: str) -> None:
    if not Path(workdir).is_absolute():
        raise ValueError("Workdir must be an absolute path")

AmazonSageMakerEmbeddingDriver

Bases: BaseMultiModelEmbeddingDriver

Source code in griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py
@define
class AmazonSageMakerEmbeddingDriver(BaseMultiModelEmbeddingDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
    )
    embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True)

    def try_embed_chunk(self, chunk: str) -> list[float]:
        payload = self.embedding_model_driver.chunk_to_model_params(chunk)
        endpoint_response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.model, ContentType="application/x-text", Body=json.dumps(payload).encode("utf-8")
        )

        response = json.loads(endpoint_response.get("Body").read().decode("utf-8"))
        return self.embedding_model_driver.process_output(response)

embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda self: self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    payload = self.embedding_model_driver.chunk_to_model_params(chunk)
    endpoint_response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.model, ContentType="application/x-text", Body=json.dumps(payload).encode("utf-8")
    )

    response = json.loads(endpoint_response.get("Body").read().decode("utf-8"))
    return self.embedding_model_driver.process_output(response)
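
A minimal usage sketch; here model is a SageMaker endpoint name rather than a model ID, and the endpoint name is illustrative:

from griptape.drivers import (
    AmazonSageMakerEmbeddingDriver,
    SageMakerHuggingFaceEmbeddingModelDriver,
)

driver = AmazonSageMakerEmbeddingDriver(
    model="my-embedding-endpoint",  # illustrative SageMaker endpoint name
    embedding_model_driver=SageMakerHuggingFaceEmbeddingModelDriver(),
)

vector = driver.try_embed_chunk("hello world")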

AmazonSageMakerPromptDriver

Bases: BaseMultiModelPromptDriver

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@define
class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
    )
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})

    @stream.validator  # pyright: ignore
    def validate_stream(self, _, stream):
        if stream:
            raise ValueError("streaming is not supported")

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        payload = {
            "inputs": self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack),
            "parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack),
        }
        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.model,
            ContentType="application/json",
            Body=json.dumps(payload),
            CustomAttributes=self.custom_attributes,
        )

        decoded_body = json.loads(response["Body"].read().decode("utf8"))

        if decoded_body:
            return self.prompt_model_driver.process_output(decoded_body)
        else:
            raise Exception("model response is empty")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        raise NotImplementedError("streaming is not supported")

custom_attributes: str = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda self: self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    payload = {
        "inputs": self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack),
        "parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack),
    }
    response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.model,
        ContentType="application/json",
        Body=json.dumps(payload),
        CustomAttributes=self.custom_attributes,
    )

    decoded_body = json.loads(response["Body"].read().decode("utf8"))

    if decoded_body:
        return self.prompt_model_driver.process_output(decoded_body)
    else:
        raise Exception("model response is empty")

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    raise NotImplementedError("streaming is not supported")

validate_stream(_, stream)

Source code in griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@stream.validator  # pyright: ignore
def validate_stream(self, _, stream):
    if stream:
        raise ValueError("streaming is not supported")

AmazonSqsEventListenerDriver

Bases: BaseEventListenerDriver

Source code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
@define
class AmazonSqsEventListenerDriver(BaseEventListenerDriver):
    queue_url: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sqs_client: Any = field(default=Factory(lambda self: self.session.client("sqs"), takes_self=True))

    def try_publish_event_payload(self, event_payload: dict) -> None:
        self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))

queue_url: str = field(kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

sqs_client: Any = field(default=Factory(lambda self: self.session.client('sqs'), takes_self=True)) class-attribute instance-attribute

try_publish_event_payload(event_payload)

Source code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))
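
A minimal usage sketch; the queue URL and payload are illustrative:

from griptape.drivers import AmazonSqsEventListenerDriver

driver = AmazonSqsEventListenerDriver(
    queue_url="https://sqs.us-west-2.amazonaws.com/123456789012/my-queue"  # illustrative
)

driver.try_publish_event_payload({"event_type": "FinishStructureRunEvent"})  # illustrative payload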

AnthropicImageQueryDriver

Bases: BaseImageQueryDriver

Attributes:

    api_key (Optional[str]): Anthropic API key.
    model (str): Anthropic model name.
    client (Any): Custom Anthropic client.

Source code in griptape/drivers/image_query/anthropic_image_query_driver.py
@define
class AnthropicImageQueryDriver(BaseImageQueryDriver):
    """
    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Any = field(
        default=Factory(
            lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        if self.max_tokens is None:
            raise TypeError("max_output_tokens can't be empty")

        response = self.client.messages.create(**self._base_params(query, images))
        content_blocks = response.content

        if len(content_blocks) < 1:
            raise ValueError("Response content is empty")

        text_content = content_blocks[0].text

        return TextArtifact(text_content)

    def _base_params(self, text_query: str, images: list[ImageArtifact]):
        content = [self._construct_image_message(image) for image in images]
        content.append(self._construct_text_message(text_query))
        messages = self._construct_messages(content)
        params = {"model": self.model, "messages": messages, "max_tokens": self.max_tokens}

        return params

    def _construct_image_message(self, image_data: ImageArtifact) -> dict:
        data = image_data.base64
        type = image_data.mime_type

        return {"source": {"data": data, "media_type": type, "type": "base64"}, "type": "image"}

    def _construct_text_message(self, query: str) -> dict:
        return {"text": query, "type": "text"}

    def _construct_messages(self, content: list) -> list:
        return [{"content": content, "role": "user"}]

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

client: Any = field(default=Factory(lambda self: import_optional_dependency('anthropic').Anthropic(api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/anthropic_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    if self.max_tokens is None:
        raise TypeError("max_output_tokens can't be empty")

    response = self.client.messages.create(**self._base_params(query, images))
    content_blocks = response.content

    if len(content_blocks) < 1:
        raise ValueError("Response content is empty")

    text_content = content_blocks[0].text

    return TextArtifact(text_content)
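
A usage sketch, assuming ANTHROPIC_API_KEY is set in the environment and that ImageLoader().load() returns an ImageArtifact from raw bytes; the image path is a placeholder:

from griptape.drivers import AnthropicImageQueryDriver
from griptape.loaders import ImageLoader

driver = AnthropicImageQueryDriver(model="claude-3-sonnet-20240229")

with open("photo.png", "rb") as f:  # placeholder path
    image = ImageLoader().load(f.read())

# query() (inherited from BaseImageQueryDriver, below) wraps try_query() with retries.
print(driver.query("What is shown in this image?", [image]).value)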

AnthropicPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_key Optional[str]

Anthropic API key.

model str

Anthropic model name.

client Any

Custom Anthropic client.

tokenizer AnthropicTokenizer

Custom AnthropicTokenizer.

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
@define
class AnthropicPromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
        tokenizer: Custom `AnthropicTokenizer`.
    """

    api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Any = field(
        default=Factory(
            lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )
    tokenizer: AnthropicTokenizer = field(
        default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
    top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        response = self.client.messages.create(**self._base_params(prompt_stack))

        return TextArtifact(value=response.content[0].text)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        response = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

        for chunk in response:
            if chunk.type == "content_block_delta":
                yield TextArtifact(value=chunk.delta.text)

    def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        messages = [
            {"role": self.__to_anthropic_role(prompt_input), "content": prompt_input.content}
            for prompt_input in prompt_stack.inputs
            if not prompt_input.is_system()
        ]
        system = next((i for i in prompt_stack.inputs if i.is_system()), None)

        if system is None:
            return {"messages": messages}
        else:
            return {"messages": messages, "system": system.content}

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        return {
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_output_tokens(self.prompt_stack_to_string(prompt_stack)),
            "top_p": self.top_p,
            "top_k": self.top_k,
            **self._prompt_stack_to_model_input(prompt_stack),
        }

    def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
        if prompt_input.is_system():
            return "system"
        elif prompt_input.is_assistant():
            return "assistant"
        else:
            return "user"

api_key: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

client: Any = field(default=Factory(lambda self: import_optional_dependency('anthropic').Anthropic(api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: AnthropicTokenizer = field(default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

top_k: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_p: float = field(default=0.999, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_anthropic_role(prompt_input)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
    if prompt_input.is_system():
        return "system"
    elif prompt_input.is_assistant():
        return "assistant"
    else:
        return "user"

try_run(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    response = self.client.messages.create(**self._base_params(prompt_stack))

    return TextArtifact(value=response.content[0].text)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/anthropic_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    response = self.client.messages.create(**self._base_params(prompt_stack), stream=True)

    for chunk in response:
        if chunk.type == "content_block_delta":
            yield TextArtifact(value=chunk.delta.text)
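
A usage sketch, assuming ANTHROPIC_API_KEY is set and that PromptStack exposes the add_system_input/add_user_input helpers used here:

from griptape.drivers import AnthropicPromptDriver
from griptape.utils import PromptStack

driver = AnthropicPromptDriver(model="claude-3-haiku-20240307")

stack = PromptStack()
stack.add_system_input("You are a terse assistant.")
stack.add_user_input("Name three prime numbers.")

# try_run() blocks until the full completion arrives; iterate over
# try_stream(stack) instead to consume content_block_delta chunks as they arrive.
print(driver.try_run(stack).value)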

AwsIotCoreEventListenerDriver

Bases: BaseEventListenerDriver

Source code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
@define
class AwsIotCoreEventListenerDriver(BaseEventListenerDriver):
    iot_endpoint: str = field(kw_only=True)
    topic: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    iotdata_client: Any = field(default=Factory(lambda self: self.session.client("iot-data"), takes_self=True))

    def try_publish_event_payload(self, event_payload: dict) -> None:
        self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload))

iot_endpoint: str = field(kw_only=True) class-attribute instance-attribute

iotdata_client: Any = field(default=Factory(lambda self: self.session.client('iot-data'), takes_self=True)) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

topic: str = field(kw_only=True) class-attribute instance-attribute

try_publish_event_payload(event_payload)

Source code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    self.iotdata_client.publish(topic=self.topic, payload=json.dumps(event_payload))
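
As with the SQS driver above, a short usage sketch; the endpoint and topic are placeholders:

from griptape.drivers import AwsIotCoreEventListenerDriver

driver = AwsIotCoreEventListenerDriver(
    iot_endpoint="abc123-ats.iot.us-east-1.amazonaws.com",  # placeholder
    topic="griptape/events",  # placeholder
)

driver.publish_event({"event": "example"})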

AzureMongoDbVectorStoreDriver

Bases: MongoDbAtlasVectorStoreDriver

A Vector Store Driver for Azure Cosmos DB with the MongoDB vCore API.

Source code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
@define
class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver):
    """A Vector Store Driver for CosmosDB with MongoDB vCore API."""

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        offset: Optional[int] = None,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Queries the MongoDB collection for documents that match the provided query string.

        Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
        """
        collection = self.get_collection()

        # Using the embedding driver to convert the query string into a vector
        vector = self.embedding_driver.embed_string(query)

        count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        offset = offset if offset else 0

        pipeline = []

        pipeline.append(
            {
                "$search": {
                    "cosmosSearch": {
                        "vector": vector,
                        "path": self.vector_path,
                        "k": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                    },
                    "returnStoredSource": True,
                }
            }
        )

        if namespace:
            pipeline.append({"$match": {"namespace": namespace}})

        pipeline.append({"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}})

        return [
            BaseVectorStoreDriver.QueryResult(
                id=str(doc["_id"]),
                vector=doc[self.vector_path] if include_vectors else [],
                score=doc["similarityScore"],
                meta=doc["document"]["meta"],
                namespace=namespace,
            )
            for doc in collection.aggregate(pipeline)
        ]

query(query, count=None, namespace=None, include_vectors=False, offset=None, **kwargs)

Queries the MongoDB collection for documents that match the provided query string.

Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.

Source code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    offset: Optional[int] = None,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Queries the MongoDB collection for documents that match the provided query string.

    Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
    """
    collection = self.get_collection()

    # Using the embedding driver to convert the query string into a vector
    vector = self.embedding_driver.embed_string(query)

    count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    offset = offset if offset else 0

    pipeline = []

    pipeline.append(
        {
            "$search": {
                "cosmosSearch": {
                    "vector": vector,
                    "path": self.vector_path,
                    "k": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                },
                "returnStoredSource": True,
            }
        }
    )

    if namespace:
        pipeline.append({"$match": {"namespace": namespace}})

    pipeline.append({"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}})

    return [
        BaseVectorStoreDriver.QueryResult(
            id=str(doc["_id"]),
            vector=doc[self.vector_path] if include_vectors else [],
            score=doc["similarityScore"],
            meta=doc["document"]["meta"],
            namespace=namespace,
        )
        for doc in collection.aggregate(pipeline)
    ]
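
A usage sketch; the connection string is a placeholder, and the remaining constructor fields (database_name, collection_name, index_name, vector_path) are assumed to follow the parent MongoDbAtlasVectorStoreDriver:

from griptape.drivers import AzureMongoDbVectorStoreDriver, OpenAiEmbeddingDriver

driver = AzureMongoDbVectorStoreDriver(
    connection_string="mongodb+srv://user:password@host/?tls=true",  # placeholder
    database_name="griptape",
    collection_name="vectors",
    index_name="vector-index",
    vector_path="vector",
    embedding_driver=OpenAiEmbeddingDriver(),
)

# Each QueryResult carries the document id, the cosmosSearch score, and metadata.
for result in driver.query("What is a vector store?", count=3):
    print(result.id, result.score)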

AzureOpenAiChatPromptDriver

Bases: OpenAiChatPromptDriver

Attributes:

Name Type Description
azure_deployment str

An Azure OpenAi deployment id.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@define
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        params = super()._base_params(prompt_stack)
        # TODO: Add `seed` parameter once Azure supports it.
        del params["seed"]

        return params

api_version: str = field(default='2023-05-15', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute
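
A construction sketch with placeholder resource names; api_key and organization are inherited from OpenAiChatPromptDriver:

import os

from griptape.drivers import AzureOpenAiChatPromptDriver

driver = AzureOpenAiChatPromptDriver(
    model="gpt-4",  # used for tokenization and request parameters
    azure_deployment="my-gpt-4-deployment",  # placeholder
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
)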

AzureOpenAiCompletionPromptDriver

Bases: OpenAiCompletionPromptDriver

Attributes:

Name Type Description
azure_deployment str

An Azure OpenAi deployment id.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/drivers/prompt/azure_openai_completion_prompt_driver.py
@define
class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2023-05-15', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

AzureOpenAiEmbeddingDriver

Bases: OpenAiEmbeddingDriver

Attributes:

Name Type Description
azure_deployment str

An Azure OpenAi deployment id.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

tokenizer OpenAiTokenizer

An OpenAiTokenizer.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/drivers/embedding/azure_openai_embedding_driver.py
@define
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        tokenizer: An `OpenAiTokenizer`.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2023-05-15', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute
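
Construction mirrors the chat and completion drivers above; embed_string() is inherited from BaseEmbeddingDriver (below). Deployment and endpoint values are placeholders:

import os

from griptape.drivers import AzureOpenAiEmbeddingDriver

driver = AzureOpenAiEmbeddingDriver(
    model="text-embedding-ada-002",
    azure_deployment="my-embedding-deployment",  # placeholder
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
)

vector = driver.embed_string("hello world")
print(len(vector))  # embedding dimensionality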

AzureOpenAiImageGenerationDriver

Bases: OpenAiImageGenerationDriver

Driver for Azure-hosted OpenAI image generation API.

Attributes:

Name Type Description
azure_deployment str

An Azure OpenAi deployment id.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/drivers/image_generation/azure_openai_image_generation_driver.py
@define
class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
    """Driver for Azure-hosted OpenAI image generation API.

    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True})
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2024-02-01', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda self: openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

BaseConversationMemoryDriver

Bases: SerializableMixin, ABC

Source code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
class BaseConversationMemoryDriver(SerializableMixin, ABC):
    @abstractmethod
    def store(self, memory: BaseConversationMemory) -> None:
        ...

    @abstractmethod
    def load(self) -> Optional[BaseConversationMemory]:
        ...

load() abstractmethod

Source code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def load(self) -> Optional[BaseConversationMemory]:
    ...

store(memory) abstractmethod

Source code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def store(self, memory: BaseConversationMemory) -> None:
    ...
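
A minimal sketch of a custom driver that persists memory to a JSON file, mirroring the shape of LocalConversationMemoryDriver; to_json() and from_dict() are assumed from SerializableMixin, and the default file path is a placeholder:

import json
from typing import Optional

from attrs import define, field
from griptape.drivers import BaseConversationMemoryDriver
from griptape.memory.structure import BaseConversationMemory

@define
class JsonFileConversationMemoryDriver(BaseConversationMemoryDriver):
    file_path: str = field(default="memory.json", kw_only=True)  # placeholder

    def store(self, memory: BaseConversationMemory) -> None:
        with open(self.file_path, "w") as f:
            f.write(memory.to_json())  # to_json() assumed from SerializableMixin

    def load(self) -> Optional[BaseConversationMemory]:
        try:
            with open(self.file_path) as f:
                return BaseConversationMemory.from_dict(json.load(f))
        except FileNotFoundError:
            return None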

BaseEmbeddingDriver

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Attributes:

Name Type Description
model str

The name of the model to use.

tokenizer Optional[BaseTokenizer]

An instance of BaseTokenizer to use when calculating tokens.

Source code in griptape/drivers/embedding/base_embedding_driver.py
@define
class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """
    Attributes:
        model: The name of the model to use.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    chunker: BaseChunker = field(init=False)

    def __attrs_post_init__(self) -> None:
        if self.tokenizer:
            self.chunker = TextChunker(tokenizer=self.tokenizer)

    def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
        return self.embed_string(artifact.to_text())

    def embed_string(self, string: str) -> list[float]:
        for attempt in self.retrying():
            with attempt:
                if self.tokenizer and self.tokenizer.count_tokens(string) > self.tokenizer.max_input_tokens:
                    return self._embed_long_string(string)
                else:
                    return self.try_embed_chunk(string)

        else:
            raise RuntimeError("Failed to embed string.")

    @abstractmethod
    def try_embed_chunk(self, chunk: str) -> list[float]:
        ...

    def _embed_long_string(self, string: str) -> list[float]:
        """Embeds a string that is too long to embed in one go.

        Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
        """
        chunks = self.chunker.chunk(string)

        embedding_chunks = []
        length_chunks = []
        for chunk in chunks:
            embedding_chunks.append(self.try_embed_chunk(chunk.value))
            length_chunks.append(len(chunk))

        # generate weighted averages
        embedding_chunks = np.average(embedding_chunks, axis=0, weights=length_chunks)

        # normalize length to 1
        embedding_chunks = embedding_chunks / np.linalg.norm(embedding_chunks)

        return embedding_chunks.tolist()

chunker: BaseChunker = field(init=False) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/drivers/embedding/base_embedding_driver.py
def __attrs_post_init__(self) -> None:
    if self.tokenizer:
        self.chunker = TextChunker(tokenizer=self.tokenizer)

embed_string(string)

Source code in griptape/drivers/embedding/base_embedding_driver.py
def embed_string(self, string: str) -> list[float]:
    for attempt in self.retrying():
        with attempt:
            if self.tokenizer and self.tokenizer.count_tokens(string) > self.tokenizer.max_input_tokens:
                return self._embed_long_string(string)
            else:
                return self.try_embed_chunk(string)

    else:
        raise RuntimeError("Failed to embed string.")

embed_text_artifact(artifact)

Source code in griptape/drivers/embedding/base_embedding_driver.py
def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
    return self.embed_string(artifact.to_text())

try_embed_chunk(chunk) abstractmethod

Source code in griptape/drivers/embedding/base_embedding_driver.py
@abstractmethod
def try_embed_chunk(self, chunk: str) -> list[float]:
    ...
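
Only try_embed_chunk() is abstract, so a custom driver needs a single method; embed_string() adds retries and, when a tokenizer is configured, routes oversized inputs through _embed_long_string(). A toy sketch with a deterministic stand-in for a real model call:

import hashlib

from attrs import define
from griptape.drivers import BaseEmbeddingDriver

@define
class ToyEmbeddingDriver(BaseEmbeddingDriver):
    def try_embed_chunk(self, chunk: str) -> list[float]:
        # Hash-derived pseudo-embedding; stands in for a real model call.
        digest = hashlib.sha256(chunk.encode()).digest()
        return [b / 255 for b in digest]

driver = ToyEmbeddingDriver(model="toy")
print(driver.embed_string("hello")[:4])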

BaseEmbeddingModelDriver

Bases: ABC

Source code in griptape/drivers/embedding_model/base_embedding_model_driver.py
@define
class BaseEmbeddingModelDriver(ABC):
    @abstractmethod
    def chunk_to_model_params(self, chunk: str) -> dict:
        ...

    @abstractmethod
    def process_output(self, output: dict) -> list[float]:
        ...

chunk_to_model_params(chunk) abstractmethod

Source code in griptape/drivers/embedding_model/base_embedding_model_driver.py
@abstractmethod
def chunk_to_model_params(self, chunk: str) -> dict:
    ...

process_output(output) abstractmethod

Source code in griptape/drivers/embedding_model/base_embedding_model_driver.py
@abstractmethod
def process_output(self, output: dict) -> list[float]:
    ...

BaseEventListenerDriver

Bases: ABC

Source code in griptape/drivers/event_listener/base_event_listener_driver.py
@define
class BaseEventListenerDriver(ABC):
    futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)

    def publish_event(self, event: BaseEvent | dict) -> None:
        if isinstance(event, dict):
            self.futures_executor.submit(self._safe_try_publish_event_payload, event)
        else:
            self.futures_executor.submit(self._safe_try_publish_event_payload, event.to_dict())

    @abstractmethod
    def try_publish_event_payload(self, event_payload: dict) -> None:
        ...

    def _safe_try_publish_event_payload(self, event_payload: dict) -> None:
        try:
            self.try_publish_event_payload(event_payload)
        except Exception as e:
            logger.error(e)

futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) class-attribute instance-attribute

publish_event(event)

Source code in griptape/drivers/event_listener/base_event_listener_driver.py
def publish_event(self, event: BaseEvent | dict) -> None:
    if isinstance(event, dict):
        self.futures_executor.submit(self._safe_try_publish_event_payload, event)
    else:
        self.futures_executor.submit(self._safe_try_publish_event_payload, event.to_dict())

try_publish_event_payload(event_payload) abstractmethod

Source code in griptape/drivers/event_listener/base_event_listener_driver.py
@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None:
    ...
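
The concrete event listener drivers above all follow this pattern: implement try_publish_event_payload() and let publish_event() handle serialization and threading. A stdout sketch:

import json

from attrs import define
from griptape.drivers import BaseEventListenerDriver

@define
class StdOutEventListenerDriver(BaseEventListenerDriver):
    def try_publish_event_payload(self, event_payload: dict) -> None:
        print(json.dumps(event_payload))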

BaseFileManagerDriver

Bases: ABC

BaseFileManagerDriver can be used to list, load, and save files.

Attributes:

Name Type Description
default_loader BaseLoader

The default loader to use for loading file contents into artifacts.

loaders dict[str, BaseLoader]

Dictionary of file-extension-specific loaders to use for loading file contents into artifacts.

Source code in griptape/drivers/file_manager/base_file_manager_driver.py
@define
class BaseFileManagerDriver(ABC):
    """
    BaseFileManagerDriver can be used to list, load, and save files.

    Attributes:
        default_loader: The default loader to use for loading file contents into artifacts.
        loaders: Dictionary of file-extension-specific loaders to use for loading file contents into artifacts.
    """

    default_loader: loaders.BaseLoader = field(default=Factory(lambda: loaders.BlobLoader()), kw_only=True)
    loaders: dict[str, loaders.BaseLoader] = field(
        default=Factory(
            lambda: {
                "pdf": loaders.PdfLoader(),
                "csv": loaders.CsvLoader(),
                "txt": loaders.TextLoader(),
                "html": loaders.TextLoader(),
                "json": loaders.TextLoader(),
                "yaml": loaders.TextLoader(),
                "xml": loaders.TextLoader(),
                "png": loaders.ImageLoader(),
                "jpg": loaders.ImageLoader(),
                "jpeg": loaders.ImageLoader(),
                "webp": loaders.ImageLoader(),
                "gif": loaders.ImageLoader(),
                "bmp": loaders.ImageLoader(),
                "tiff": loaders.ImageLoader(),
            }
        ),
        kw_only=True,
    )

    def list_files(self, path: str) -> TextArtifact | ErrorArtifact:
        try:
            entries = self.try_list_files(path)
            return TextArtifact("\n".join([e for e in entries]))
        except FileNotFoundError:
            return ErrorArtifact("Path not found")
        except NotADirectoryError:
            return ErrorArtifact("Path is not a directory")
        except Exception as e:
            return ErrorArtifact(f"Failed to list files: {str(e)}")

    @abstractmethod
    def try_list_files(self, path: str) -> list[str]:
        ...

    def load_file(self, path: str) -> BaseArtifact:
        try:
            extension = path.split(".")[-1]
            loader = self.loaders.get(extension) or self.default_loader
            source = self.try_load_file(path)
            result = loader.load(source)

            if isinstance(result, BaseArtifact):
                return result
            else:
                return ListArtifact(result)
        except FileNotFoundError:
            return ErrorArtifact("Path not found")
        except IsADirectoryError:
            return ErrorArtifact("Path is a directory")
        except NotADirectoryError:
            return ErrorArtifact("Not a directory")
        except Exception as e:
            return ErrorArtifact(f"Failed to load file: {str(e)}")

    @abstractmethod
    def try_load_file(self, path: str) -> bytes:
        ...

    def save_file(self, path: str, value: bytes | str) -> InfoArtifact | ErrorArtifact:
        try:
            extension = path.split(".")[-1]
            loader = self.loaders.get(extension) or self.default_loader
            encoding = None if loader is None else loader.encoding

            if isinstance(value, str):
                if encoding is None:
                    value = value.encode()
                else:
                    value = value.encode(encoding=encoding)
            elif isinstance(value, bytearray) or isinstance(value, memoryview):
                raise ValueError(f"Unsupported type: {type(value)}")

            self.try_save_file(path, value)

            return InfoArtifact("Successfully saved file")
        except IsADirectoryError:
            return ErrorArtifact("Path is a directory")
        except Exception as e:
            return ErrorArtifact(f"Failed to save file: {str(e)}")

    @abstractmethod
    def try_save_file(self, path: str, value: bytes):
        ...

default_loader: loaders.BaseLoader = field(default=Factory(lambda: loaders.BlobLoader()), kw_only=True) class-attribute instance-attribute

loaders: dict[str, loaders.BaseLoader] = field(default=Factory(lambda: {'pdf': loaders.PdfLoader(), 'csv': loaders.CsvLoader(), 'txt': loaders.TextLoader(), 'html': loaders.TextLoader(), 'json': loaders.TextLoader(), 'yaml': loaders.TextLoader(), 'xml': loaders.TextLoader(), 'png': loaders.ImageLoader(), 'jpg': loaders.ImageLoader(), 'jpeg': loaders.ImageLoader(), 'webp': loaders.ImageLoader(), 'gif': loaders.ImageLoader(), 'bmp': loaders.ImageLoader(), 'tiff': loaders.ImageLoader()}), kw_only=True) class-attribute instance-attribute

list_files(path)

Source code in griptape/drivers/file_manager/base_file_manager_driver.py
def list_files(self, path: str) -> TextArtifact | ErrorArtifact:
    try:
        entries = self.try_list_files(path)
        return TextArtifact("\n".join([e for e in entries]))
    except FileNotFoundError:
        return ErrorArtifact("Path not found")
    except NotADirectoryError:
        return ErrorArtifact("Path is not a directory")
    except Exception as e:
        return ErrorArtifact(f"Failed to list files: {str(e)}")

load_file(path)

Source code in griptape/drivers/file_manager/base_file_manager_driver.py
def load_file(self, path: str) -> BaseArtifact:
    try:
        extension = path.split(".")[-1]
        loader = self.loaders.get(extension) or self.default_loader
        source = self.try_load_file(path)
        result = loader.load(source)

        if isinstance(result, BaseArtifact):
            return result
        else:
            return ListArtifact(result)
    except FileNotFoundError:
        return ErrorArtifact("Path not found")
    except IsADirectoryError:
        return ErrorArtifact("Path is a directory")
    except NotADirectoryError:
        return ErrorArtifact("Not a directory")
    except Exception as e:
        return ErrorArtifact(f"Failed to load file: {str(e)}")

save_file(path, value)

Source code in griptape/drivers/file_manager/base_file_manager_driver.py
def save_file(self, path: str, value: bytes | str) -> InfoArtifact | ErrorArtifact:
    try:
        extension = path.split(".")[-1]
        loader = self.loaders.get(extension) or self.default_loader
        encoding = None if loader is None else loader.encoding

        if isinstance(value, str):
            if encoding is None:
                value = value.encode()
            else:
                value = value.encode(encoding=encoding)
        elif isinstance(value, bytearray) or isinstance(value, memoryview):
            raise ValueError(f"Unsupported type: {type(value)}")

        self.try_save_file(path, value)

        return InfoArtifact("Successfully saved file")
    except IsADirectoryError:
        return ErrorArtifact("Path is a directory")
    except Exception as e:
        return ErrorArtifact(f"Failed to save file: {str(e)}")

try_list_files(path) abstractmethod

Source code in griptape/drivers/file_manager/base_file_manager_driver.py
@abstractmethod
def try_list_files(self, path: str) -> list[str]:
    ...

try_load_file(path) abstractmethod

Source code in griptape/drivers/file_manager/base_file_manager_driver.py
@abstractmethod
def try_load_file(self, path: str) -> bytes:
    ...

try_save_file(path, value) abstractmethod

Source code in griptape/drivers/file_manager/base_file_manager_driver.py
@abstractmethod
def try_save_file(self, path: str, value: bytes):
    ...
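
The extension-to-loader dispatch can be extended on a concrete driver. A sketch using LocalFileManagerDriver (listed in __all__ above), assuming it constructs with a default working directory:

from griptape import loaders
from griptape.drivers import LocalFileManagerDriver

driver = LocalFileManagerDriver()

# Route Markdown files through the text loader alongside the defaults.
driver.loaders["md"] = loaders.TextLoader()

print(driver.list_files(".").value)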

BaseImageGenerationDriver

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
@define
class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    model: str = field(kw_only=True, metadata={"serializable": True})
    structure: Optional[Structure] = field(default=None, kw_only=True)

    def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None:
        if self.structure:
            self.structure.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts))

    def after_run(self) -> None:
        if self.structure:
            self.structure.publish_event(FinishImageGenerationEvent())

    def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_text_to_image(prompts, negative_prompts)
                self.after_run()

                return result

        else:
            raise Exception("Failed to run text to image generation")

    def run_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_image_variation(prompts, image, negative_prompts)
                self.after_run()

                return result

        else:
            raise Exception("Failed to generate image variations")

    def run_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_image_inpainting(prompts, image, mask, negative_prompts)
                self.after_run()

                return result

        else:
            raise Exception("Failed to run image inpainting")

    def run_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_image_outpainting(prompts, image, mask, negative_prompts)
                self.after_run()

                return result

        else:
            raise Exception("Failed to run image outpainting")

    @abstractmethod
    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        ...

    @abstractmethod
    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        ...

    @abstractmethod
    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        ...

    @abstractmethod
    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        ...

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structure: Optional[Structure] = field(default=None, kw_only=True) class-attribute instance-attribute

after_run()

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
def after_run(self) -> None:
    if self.structure:
        self.structure.publish_event(FinishImageGenerationEvent())

before_run(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None:
    if self.structure:
        self.structure.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts))

run_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_image_inpainting(prompts, image, mask, negative_prompts)
            self.after_run()

            return result

    else:
        raise Exception("Failed to run image inpainting")

run_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_image_outpainting(prompts, image, mask, negative_prompts)
            self.after_run()

            return result

    else:
        raise Exception("Failed to run image outpainting")

run_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_image_variation(prompts, image, negative_prompts)
            self.after_run()

            return result

    else:
        raise Exception("Failed to generate image variations")

run_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_text_to_image(prompts, negative_prompts)
            self.after_run()

            return result

    else:
        raise Exception("Failed to run text to image generation")

try_image_inpainting(prompts, image, mask, negative_prompts=None) abstractmethod

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    ...

try_image_outpainting(prompts, image, mask, negative_prompts=None) abstractmethod

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    ...

try_image_variation(prompts, image, negative_prompts=None) abstractmethod

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    ...

try_text_to_image(prompts, negative_prompts=None) abstractmethod

Source code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    ...
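
The run_* wrappers add retries and event publication around the abstract try_* methods. A text-to-image sketch with the OpenAI driver, assuming OPENAI_API_KEY is set; the output path is a placeholder:

from griptape.drivers import OpenAiImageGenerationDriver

driver = OpenAiImageGenerationDriver(model="dall-e-3")

image = driver.run_text_to_image(["a lighthouse at dawn"])

with open("lighthouse.png", "wb") as f:
    f.write(image.value)  # ImageArtifact.value holds the raw image bytes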

BaseImageGenerationModelDriver

Bases: SerializableMixin, ABC

Source code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@define
class BaseImageGenerationModelDriver(SerializableMixin, ABC):
    @abstractmethod
    def get_generated_image(self, response: dict) -> bytes:
        ...

    @abstractmethod
    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]:
        ...

    @abstractmethod
    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]:
        ...

    @abstractmethod
    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]:
        ...

    @abstractmethod
    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]:
        ...

get_generated_image(response) abstractmethod

Source code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def get_generated_image(self, response: dict) -> bytes:
    ...

image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None) abstractmethod

Source code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]:
    ...

image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None) abstractmethod

Source code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]:
    ...

image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None) abstractmethod

Source code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]:
    ...

text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None) abstractmethod

Source code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]:
    ...

BaseImageQueryDriver

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Source code in griptape/drivers/image_query/base_image_query_driver.py
@define
class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    structure: Optional[Structure] = field(default=None, kw_only=True)
    max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True})

    def before_run(self, query: str, images: list[ImageArtifact]) -> None:
        if self.structure:
            self.structure.publish_event(
                StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images])
            )

    def after_run(self, result: str) -> None:
        if self.structure:
            self.structure.publish_event(FinishImageQueryEvent(result=result))

    def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(query, images)

                result = self.try_query(query, images)

                self.after_run(result.value)

                return result
        else:
            raise Exception("image query driver failed after all retry attempts")

    @abstractmethod
    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        ...

max_tokens: int = field(default=256, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structure: Optional[Structure] = field(default=None, kw_only=True) class-attribute instance-attribute

after_run(result)

Source code in griptape/drivers/image_query/base_image_query_driver.py
def after_run(self, result: str) -> None:
    if self.structure:
        self.structure.publish_event(FinishImageQueryEvent(result=result))

before_run(query, images)

Source code in griptape/drivers/image_query/base_image_query_driver.py
def before_run(self, query: str, images: list[ImageArtifact]) -> None:
    if self.structure:
        self.structure.publish_event(
            StartImageQueryEvent(query=query, images_info=[image.to_text() for image in images])
        )

query(query, images)

Source code in griptape/drivers/image_query/base_image_query_driver.py
def query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(query, images)

            result = self.try_query(query, images)

            self.after_run(result.value)

            return result
    else:
        raise Exception("image query driver failed after all retry attempts")

try_query(query, images) abstractmethod

Source code in griptape/drivers/image_query/base_image_query_driver.py
@abstractmethod
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    ...

BaseImageQueryModelDriver

Bases: SerializableMixin, ABC

Source code in griptape/drivers/image_query_model/base_image_query_model_driver.py
@define
class BaseImageQueryModelDriver(SerializableMixin, ABC):
    @abstractmethod
    def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict:
        ...

    @abstractmethod
    def process_output(self, output: dict) -> TextArtifact:
        ...

image_query_request_parameters(query, images, max_tokens) abstractmethod

Source code in griptape/drivers/image_query_model/base_image_query_model_driver.py
@abstractmethod
def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict:
    ...

process_output(output) abstractmethod

Source code in griptape/drivers/image_query_model/base_image_query_model_driver.py
@abstractmethod
def process_output(self, output: dict) -> TextArtifact:
    ...

BaseMultiModelEmbeddingDriver

Bases: BaseEmbeddingDriver, ABC

Source code in griptape/drivers/embedding/base_multi_model_embedding_driver.py
@define
class BaseMultiModelEmbeddingDriver(BaseEmbeddingDriver, ABC):
    embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True)

embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True) class-attribute instance-attribute

BaseMultiModelImageGenerationDriver

Bases: BaseImageGenerationDriver, ABC

Image Generation Driver for platforms like Amazon Bedrock that host many models.

Instances of this Image Generation Driver require an Image Generation Model Driver, which is used to structure the image generation request in the format required by the model and to process the output.

Attributes:

Name Type Description
image_generation_model_driver BaseImageGenerationModelDriver

Image Model Driver to use.

Source code in griptape/drivers/image_generation/base_multi_model_image_generation_driver.py
@define
class BaseMultiModelImageGenerationDriver(BaseImageGenerationDriver, ABC):
    """Image Generation Driver for platforms like Amazon Bedrock that host many LLM models.

    Instances of this Image Generation Driver require an Image Generation Model Driver, which is used to structure the
    image generation request in the format required by the model and to process the output.

    Attributes:
        image_generation_model_driver: Image Model Driver to use.
    """

    image_generation_model_driver: BaseImageGenerationModelDriver = field(kw_only=True, metadata={"serializable": True})

image_generation_model_driver: BaseImageGenerationModelDriver = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute
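
For example, pairing this driver with the Stable Diffusion model driver on Bedrock might look like the following sketch; the model id is an illustrative example only.

from griptape.drivers import (
    AmazonBedrockImageGenerationDriver,
    BedrockStableDiffusionImageGenerationModelDriver,
)

driver = AmazonBedrockImageGenerationDriver(
    model="stability.stable-diffusion-xl-v1",  # example Bedrock model id
    image_generation_model_driver=BedrockStableDiffusionImageGenerationModelDriver(),
)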

BaseMultiModelImageQueryDriver

Bases: BaseImageQueryDriver, ABC

Image Query Driver for platforms like Amazon Bedrock that host many LLM models.

Instances of this Image Query Driver require an Image Query Model Driver, which is used to structure the image query request in the format required by the model and to process the output.

Attributes:

Name Type Description
model str

Model name to use

image_query_model_driver BaseImageQueryModelDriver

Image Model Driver to use.

Source code in griptape/drivers/image_query/base_multi_model_image_query_driver.py
@define
class BaseMultiModelImageQueryDriver(BaseImageQueryDriver, ABC):
    """Image Query Driver for platforms like Amazon Bedrock that host many LLM models.

    Instances of this Image Query Driver require an Image Query Model Driver, which is used to structure the
    image query request in the format required by the model and to process the output.

    Attributes:
        model: Model name to use
        image_query_model_driver: Image Model Driver to use.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    image_query_model_driver: BaseImageQueryModelDriver = field(kw_only=True, metadata={"serializable": True})

image_query_model_driver: BaseImageQueryModelDriver = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute
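
A sketch of wiring this up on Bedrock with the Claude image query model driver; the model id is an illustrative example only.

from griptape.drivers import AmazonBedrockImageQueryDriver, BedrockClaudeImageQueryModelDriver

driver = AmazonBedrockImageQueryDriver(
    model="anthropic.claude-3-sonnet-20240229-v1:0",  # example Bedrock model id
    image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
)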

BaseMultiModelPromptDriver

Bases: BasePromptDriver, ABC

Prompt Driver for platforms like Amazon SageMaker and Amazon Bedrock that host many LLM models.

Instances of this Prompt Driver require a Prompt Model Driver which is used to convert the prompt stack into a model input and parameters, and to process the model output.

Attributes:

Name Type Description
model

Name of the model to use.

tokenizer Optional[BaseTokenizer]

Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver.

prompt_model_driver BasePromptModelDriver

Prompt Model Driver to use.

Source code in griptape/drivers/prompt/base_multi_model_prompt_driver.py
@define
class BaseMultiModelPromptDriver(BasePromptDriver, ABC):
    """Prompt Driver for platforms like Amazon SageMaker, and Amazon Bedrock that host many LLM models.

    Instances of this Prompt Driver require a Prompt Model Driver which is used to convert the prompt stack
    into a model input and parameters, and to process the model output.

    Attributes:
        model: Name of the model to use.
        tokenizer: Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver.
        prompt_model_driver: Prompt Model Driver to use.
    """

    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    prompt_model_driver: BasePromptModelDriver = field(kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})

    @stream.validator  # pyright: ignore
    def validate_stream(self, _, stream):
        if stream and not self.prompt_model_driver.supports_streaming:
            raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming")

    def __attrs_post_init__(self) -> None:
        self.prompt_model_driver.prompt_driver = self

        if not self.tokenizer:
            self.tokenizer = self.prompt_model_driver.tokenizer

prompt_model_driver: BasePromptModelDriver = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/drivers/prompt/base_multi_model_prompt_driver.py
def __attrs_post_init__(self) -> None:
    self.prompt_model_driver.prompt_driver = self

    if not self.tokenizer:
        self.tokenizer = self.prompt_model_driver.tokenizer

validate_stream(_, stream)

Source code in griptape/drivers/prompt/base_multi_model_prompt_driver.py
@stream.validator  # pyright: ignore
def validate_stream(self, _, stream):
    if stream and not self.prompt_model_driver.supports_streaming:
        raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming")
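
A construction sketch on Bedrock (model id illustrative): __attrs_post_init__ points the model driver back at this prompt driver and adopts its tokenizer, and passing stream=True with a model driver whose supports_streaming is False raises a ValueError.

from griptape.drivers import AmazonBedrockPromptDriver, BedrockClaudePromptModelDriver

driver = AmazonBedrockPromptDriver(
    model="anthropic.claude-v2",  # example Bedrock model id
    prompt_model_driver=BedrockClaudePromptModelDriver(),
)
# After construction, driver.prompt_model_driver.prompt_driver is driver,
# and driver.tokenizer defaults to the model driver's tokenizer.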

BasePromptDriver

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Base class for Prompt Drivers.

Attributes:

Name Type Description
temperature float

The temperature to use for the completion.

max_tokens Optional[int]

The maximum number of tokens to generate. If not specified, the value will be determined automatically by the tokenizer.

structure Optional[Structure]

An optional Structure to publish events to.

prompt_stack_to_string Callable[[PromptStack], str]

A function that converts a PromptStack to a string.

ignored_exception_types tuple[type[Exception], ...]

A tuple of exception types to ignore.

model str

The model name.

tokenizer BaseTokenizer

An instance of BaseTokenizer to use when calculating tokens.

stream bool

Whether to stream the completion or not. CompletionChunkEvents will be published to the Structure if one is provided.

Source code in griptape/drivers/prompt/base_prompt_driver.py
@define
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base class for Prompt Drivers.

    Attributes:
        temperature: The temperature to use for the completion.
        max_tokens: The maximum number of tokens to generate. If not specified, the value will be determined automatically by the tokenizer.
        structure: An optional `Structure` to publish events to.
        prompt_stack_to_string: A function that converts a `PromptStack` to a string.
        ignored_exception_types: A tuple of exception types to ignore.
        model: The model name.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
        stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
    """

    temperature: float = field(default=0.1, kw_only=True, metadata={"serializable": True})
    max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    structure: Optional[Structure] = field(default=None, kw_only=True)
    prompt_stack_to_string: Callable[[PromptStack], str] = field(
        default=Factory(lambda self: self.default_prompt_stack_to_string_converter, takes_self=True), kw_only=True
    )
    ignored_exception_types: tuple[type[Exception], ...] = field(
        default=Factory(lambda: (ImportError, ValueError)), kw_only=True
    )
    model: str = field(metadata={"serializable": True})
    tokenizer: BaseTokenizer
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})

    def max_output_tokens(self, text: str | list) -> int:
        tokens_left = self.tokenizer.count_output_tokens_left(text)

        if self.max_tokens:
            return min(self.max_tokens, tokens_left)
        else:
            return tokens_left

    def token_count(self, prompt_stack: PromptStack) -> int:
        return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack))

    def before_run(self, prompt_stack: PromptStack) -> None:
        if self.structure:
            self.structure.publish_event(
                StartPromptEvent(
                    model=self.model,
                    token_count=self.token_count(prompt_stack),
                    prompt_stack=prompt_stack,
                    prompt=self.prompt_stack_to_string(prompt_stack),
                )
            )

    def after_run(self, result: TextArtifact) -> None:
        if self.structure:
            self.structure.publish_event(
                FinishPromptEvent(model=self.model, token_count=result.token_count(self.tokenizer), result=result.value)
            )

    def run(self, prompt_stack: PromptStack) -> TextArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompt_stack)

                if self.stream:
                    tokens = []
                    completion_chunks = self.try_stream(prompt_stack)
                    for chunk in completion_chunks:
                        self.structure.publish_event(CompletionChunkEvent(token=chunk.value))
                        tokens.append(chunk.value)
                    result = TextArtifact(value="".join(tokens).strip())
                else:
                    result = self.try_run(prompt_stack)
                    result.value = result.value.strip()

                self.after_run(result)

                return result
        else:
            raise Exception("prompt driver failed after all retry attempts")

    def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_user():
                prompt_lines.append(f"User: {i.content}")
            elif i.is_assistant():
                prompt_lines.append(f"Assistant: {i.content}")
            else:
                prompt_lines.append(i.content)

        prompt_lines.append("Assistant:")

        return "\n\n".join(prompt_lines)

    @abstractmethod
    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        ...

    @abstractmethod
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        ...

ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (ImportError, ValueError)), kw_only=True) class-attribute instance-attribute

max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(metadata={'serializable': True}) class-attribute instance-attribute

prompt_stack_to_string: Callable[[PromptStack], str] = field(default=Factory(lambda self: self.default_prompt_stack_to_string_converter, takes_self=True), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structure: Optional[Structure] = field(default=None, kw_only=True) class-attribute instance-attribute

temperature: float = field(default=0.1, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: BaseTokenizer instance-attribute

after_run(result)

Source code in griptape/drivers/prompt/base_prompt_driver.py
def after_run(self, result: TextArtifact) -> None:
    if self.structure:
        self.structure.publish_event(
            FinishPromptEvent(model=self.model, token_count=result.token_count(self.tokenizer), result=result.value)
        )

before_run(prompt_stack)

Source code in griptape/drivers/prompt/base_prompt_driver.py
def before_run(self, prompt_stack: PromptStack) -> None:
    if self.structure:
        self.structure.publish_event(
            StartPromptEvent(
                model=self.model,
                token_count=self.token_count(prompt_stack),
                prompt_stack=prompt_stack,
                prompt=self.prompt_stack_to_string(prompt_stack),
            )
        )

default_prompt_stack_to_string_converter(prompt_stack)

Source code in griptape/drivers/prompt/base_prompt_driver.py
def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_user():
            prompt_lines.append(f"User: {i.content}")
        elif i.is_assistant():
            prompt_lines.append(f"Assistant: {i.content}")
        else:
            prompt_lines.append(i.content)

    prompt_lines.append("Assistant:")

    return "\n\n".join(prompt_lines)

max_output_tokens(text)

Source code in griptape/drivers/prompt/base_prompt_driver.py
def max_output_tokens(self, text: str | list) -> int:
    tokens_left = self.tokenizer.count_output_tokens_left(text)

    if self.max_tokens:
        return min(self.max_tokens, tokens_left)
    else:
        return tokens_left

run(prompt_stack)

Source code in griptape/drivers/prompt/base_prompt_driver.py
def run(self, prompt_stack: PromptStack) -> TextArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompt_stack)

            if self.stream:
                tokens = []
                completion_chunks = self.try_stream(prompt_stack)
                for chunk in completion_chunks:
                    self.structure.publish_event(CompletionChunkEvent(token=chunk.value))
                    tokens.append(chunk.value)
                result = TextArtifact(value="".join(tokens).strip())
            else:
                result = self.try_run(prompt_stack)
                result.value = result.value.strip()

            self.after_run(result)

            return result
    else:
        raise Exception("prompt driver failed after all retry attempts")

token_count(prompt_stack)

Source code in griptape/drivers/prompt/base_prompt_driver.py
def token_count(self, prompt_stack: PromptStack) -> int:
    return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack))

try_run(prompt_stack) abstractmethod

Source code in griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    ...

try_stream(prompt_stack) abstractmethod

Source code in griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    ...
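
Since run() already handles retries, event publishing, and stream assembly, a custom driver only implements try_run() and try_stream(). A minimal, hypothetical sketch follows; the EchoPromptDriver name is illustrative, and PromptStack is assumed importable from griptape.utils in this version. Constructing it still requires a tokenizer, since BasePromptDriver declares that field without a default.

from __future__ import annotations

from collections.abc import Iterator

from attrs import define

from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.utils import PromptStack


@define
class EchoPromptDriver(BasePromptDriver):
    """Toy driver that echoes the rendered prompt back as the completion."""

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        return TextArtifact(self.prompt_stack_to_string(prompt_stack))

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        # Yield one character at a time to mimic token streaming.
        for char in self.prompt_stack_to_string(prompt_stack):
            yield TextArtifact(char)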

BasePromptModelDriver

Bases: SerializableMixin, ABC

Source code in griptape/drivers/prompt_model/base_prompt_model_driver.py
@define
class BasePromptModelDriver(SerializableMixin, ABC):
    max_tokens: Optional[int] = field(default=None, kw_only=True)
    prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True)
    supports_streaming: bool = field(default=True, kw_only=True)

    @property
    @abstractmethod
    def tokenizer(self) -> BaseTokenizer:
        ...

    @abstractmethod
    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str | list | dict:
        ...

    @abstractmethod
    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        ...

    @abstractmethod
    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        ...

max_tokens: Optional[int] = field(default=None, kw_only=True) class-attribute instance-attribute

prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

supports_streaming: bool = field(default=True, kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer abstractmethod property

process_output(output) abstractmethod

Source code in griptape/drivers/prompt_model/base_prompt_model_driver.py
@abstractmethod
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    ...

prompt_stack_to_model_input(prompt_stack) abstractmethod

Source code in griptape/drivers/prompt_model/base_prompt_model_driver.py
@abstractmethod
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str | list | dict:
    ...

prompt_stack_to_model_params(prompt_stack) abstractmethod

Source code in griptape/drivers/prompt_model/base_prompt_model_driver.py
@abstractmethod
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    ...

BaseSqlDriver

Bases: ABC

Source code in griptape/drivers/sql/base_sql_driver.py
@define
class BaseSqlDriver(ABC):
    @dataclass
    class RowResult:
        cells: dict[str, Any]

    @abstractmethod
    def execute_query(self, query: str) -> Optional[list[RowResult]]:
        ...

    @abstractmethod
    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        ...

    @abstractmethod
    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        ...

RowResult dataclass

Source code in griptape/drivers/sql/base_sql_driver.py
@dataclass
class RowResult:
    cells: dict[str, Any]

cells: dict[str, Any] instance-attribute

__init__(cells)

execute_query(query) abstractmethod

Source code in griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query(self, query: str) -> Optional[list[RowResult]]:
    ...

execute_query_raw(query) abstractmethod

Source code in griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
    ...

get_table_schema(table_name, schema=None) abstractmethod

Source code in griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    ...
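
To make the contract concrete, here is a minimal, hypothetical adapter over the standard-library sqlite3 module; the SqliteDriver name and its connection field are illustrative and distinct from the shipped SqlDriver.

import sqlite3
from typing import Any, Optional

from attrs import define, field
from griptape.drivers import BaseSqlDriver


@define
class SqliteDriver(BaseSqlDriver):
    connection: sqlite3.Connection = field(kw_only=True)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        rows = self.execute_query_raw(query)
        return None if rows is None else [BaseSqlDriver.RowResult(cells=row) for row in rows]

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        cursor = self.connection.execute(query)
        if cursor.description is None:  # statements like INSERT return no rows
            return None
        columns = [column[0] for column in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        rows = self.connection.execute(f"PRAGMA table_info({table_name})").fetchall()
        return str(rows) if rows else None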

BaseStructureRunDriver

Bases: ABC

Source code in griptape/drivers/structure_run/base_structure_run_driver.py
@define
class BaseStructureRunDriver(ABC):
    def run(self, *args: BaseArtifact) -> BaseArtifact:
        return self.try_run(*args)

    @abstractmethod
    def try_run(self, *args: BaseArtifact) -> BaseArtifact:
        ...

run(*args)

Source code in griptape/drivers/structure_run/base_structure_run_driver.py
def run(self, *args: BaseArtifact) -> BaseArtifact:
    return self.try_run(*args)

try_run(*args) abstractmethod

Source code in griptape/drivers/structure_run/base_structure_run_driver.py
@abstractmethod
def try_run(self, *args: BaseArtifact) -> BaseArtifact:
    ...
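
run() simply delegates to try_run(), so a subclass only implements the latter. A toy, hypothetical sketch:

from attrs import define

from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.drivers import BaseStructureRunDriver


@define
class EchoStructureRunDriver(BaseStructureRunDriver):
    """Toy driver that returns its first input artifact unchanged."""

    def try_run(self, *args: BaseArtifact) -> BaseArtifact:
        return args[0] if args else TextArtifact("")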

BaseVectorStoreDriver

Bases: SerializableMixin, ABC

Source code in griptape/drivers/vector/base_vector_store_driver.py
@define
class BaseVectorStoreDriver(SerializableMixin, ABC):
    DEFAULT_QUERY_COUNT = 5

    @dataclass
    class QueryResult:
        id: str
        vector: Optional[list[float]]
        score: float
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    @dataclass
    class Entry:
        id: str
        vector: list[float]
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True})
    futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)

    def upsert_text_artifacts(
        self, artifacts: dict[str, list[TextArtifact]], meta: Optional[dict] = None, **kwargs
    ) -> None:
        utils.execute_futures_dict(
            {
                namespace: self.futures_executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
                for namespace, artifact_list in artifacts.items()
                for a in artifact_list
            }
        )

    def upsert_text_artifact(
        self, artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs
    ) -> str:
        if not meta:
            meta = {}

        meta["artifact"] = artifact.to_json()

        if artifact.embedding:
            vector = artifact.embedding
        else:
            vector = artifact.generate_embedding(self.embedding_driver)

        if isinstance(vector, list):
            return self.upsert_vector(vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs)
        else:
            raise ValueError("Vector must be an instance of 'list'.")

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        return self.upsert_vector(
            self.embedding_driver.embed_string(string),
            vector_id=vector_id,
            namespace=namespace,
            meta=meta if meta else {},
            **kwargs,
        )

    @abstractmethod
    def delete_vector(self, vector_id: str) -> None:
        ...

    @abstractmethod
    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        ...

    @abstractmethod
    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[Entry]:
        ...

    @abstractmethod
    def load_entries(self, namespace: Optional[str] = None) -> list[Entry]:
        ...

    @abstractmethod
    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[QueryResult]:
        ...

DEFAULT_QUERY_COUNT = 5 class-attribute instance-attribute

embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) class-attribute instance-attribute

Entry dataclass

Source code in griptape/drivers/vector/base_vector_store_driver.py
@dataclass
class Entry:
    id: str
    vector: list[float]
    meta: Optional[dict] = None
    namespace: Optional[str] = None

id: str instance-attribute

meta: Optional[dict] = None class-attribute instance-attribute

namespace: Optional[str] = None class-attribute instance-attribute

vector: list[float] instance-attribute

__init__(id, vector, meta=None, namespace=None)

QueryResult dataclass

Source code in griptape/drivers/vector/base_vector_store_driver.py
@dataclass
class QueryResult:
    id: str
    vector: Optional[list[float]]
    score: float
    meta: Optional[dict] = None
    namespace: Optional[str] = None

id: str instance-attribute

meta: Optional[dict] = None class-attribute instance-attribute

namespace: Optional[str] = None class-attribute instance-attribute

score: float instance-attribute

vector: Optional[list[float]] instance-attribute

__init__(id, vector, score, meta=None, namespace=None)

delete_vector(vector_id) abstractmethod

Source code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def delete_vector(self, vector_id: str) -> None:
    ...

load_entries(namespace=None) abstractmethod

Source code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entries(self, namespace: Optional[str] = None) -> list[Entry]:
    ...

load_entry(vector_id, namespace=None) abstractmethod

Source code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[Entry]:
    ...

query(query, count=None, namespace=None, include_vectors=False, **kwargs) abstractmethod

Source code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[QueryResult]:
    ...

upsert_text(string, vector_id=None, namespace=None, meta=None, **kwargs)

Source code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_text(
    self,
    string: str,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    return self.upsert_vector(
        self.embedding_driver.embed_string(string),
        vector_id=vector_id,
        namespace=namespace,
        meta=meta if meta else {},
        **kwargs,
    )

upsert_text_artifact(artifact, namespace=None, meta=None, **kwargs)

Source code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_text_artifact(
    self, artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs
) -> str:
    if not meta:
        meta = {}

    meta["artifact"] = artifact.to_json()

    if artifact.embedding:
        vector = artifact.embedding
    else:
        vector = artifact.generate_embedding(self.embedding_driver)

    if isinstance(vector, list):
        return self.upsert_vector(vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs)
    else:
        raise ValueError("Vector must be an instance of 'list'.")

upsert_text_artifacts(artifacts, meta=None, **kwargs)

Source code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_text_artifacts(
    self, artifacts: dict[str, list[TextArtifact]], meta: Optional[dict] = None, **kwargs
) -> None:
    utils.execute_futures_dict(
        {
            namespace: self.futures_executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
            for namespace, artifact_list in artifacts.items()
            for a in artifact_list
        }
    )

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs) abstractmethod

Source code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    ...
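
As a usage sketch, the local implementation exercises the base API end to end; this assumes OPENAI_API_KEY is set, since OpenAiEmbeddingDriver reads it from the environment by default.

from griptape.artifacts import TextArtifact
from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver

vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())

# upsert_text_artifact() embeds the artifact (unless it already carries an
# embedding) and stores the artifact JSON under meta["artifact"].
vector_store.upsert_text_artifact(TextArtifact("Griptape ships many drivers."), namespace="docs")

for result in vector_store.query("drivers", count=3, namespace="docs"):
    print(result.id, result.score)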

BaseWebScraperDriver

Bases: ABC

Source code in griptape/drivers/web_scraper/base_web_scraper_driver.py
class BaseWebScraperDriver(ABC):
    @abstractmethod
    def scrape_url(self, url: str) -> TextArtifact:
        ...

scrape_url(url) abstractmethod

Source code in griptape/drivers/web_scraper/base_web_scraper_driver.py
@abstractmethod
def scrape_url(self, url: str) -> TextArtifact:
    ...
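
A usage sketch with one of the shipped implementations; TrafilaturaWebScraperDriver requires the optional trafilatura dependency.

from griptape.drivers import TrafilaturaWebScraperDriver

scraper = TrafilaturaWebScraperDriver()
article = scraper.scrape_url("https://example.com")  # returns a TextArtifact

print(article.value[:200])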

BedrockClaudeImageQueryModelDriver

Bases: BaseImageQueryModelDriver

Source code in griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py
@define
class BedrockClaudeImageQueryModelDriver(BaseImageQueryModelDriver):
    ANTHROPIC_VERSION = "bedrock-2023-05-31"  # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example

    def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict:
        content = [self._construct_image_message(image) for image in images]
        content.append(self._construct_text_message(query))
        messages = self._construct_messages(content)
        input_params = {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens}

        return input_params

    def process_output(self, output: dict) -> TextArtifact:
        content_blocks = output["content"]
        if len(content_blocks) < 1:
            raise ValueError("Response content is empty")

        text_content = content_blocks[0]["text"]

        return TextArtifact(text_content)

    def _construct_image_message(self, image_data: ImageArtifact) -> dict:
        data = image_data.base64
        type = image_data.mime_type

        return {"source": {"data": data, "media_type": type, "type": "base64"}, "type": "image"}

    def _construct_text_message(self, query: str) -> dict:
        return {"text": query, "type": "text"}

    def _construct_messages(self, content: list) -> list:
        return [{"content": content, "role": "user"}]

ANTHROPIC_VERSION = 'bedrock-2023-05-31' class-attribute instance-attribute

image_query_request_parameters(query, images, max_tokens)

Source code in griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py
def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict:
    content = [self._construct_image_message(image) for image in images]
    content.append(self._construct_text_message(query))
    messages = self._construct_messages(content)
    input_params = {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens}

    return input_params

process_output(output)

Source code in griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py
def process_output(self, output: dict) -> TextArtifact:
    content_blocks = output["content"]
    if len(content_blocks) < 1:
        raise ValueError("Response content is empty")

    text_content = content_blocks[0]["text"]

    return TextArtifact(text_content)
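
A sketch of the request payload this driver produces. The ImageArtifact constructor arguments shown (value, format, width, height) are assumed for this version, and the byte string is a placeholder.

from griptape.artifacts import ImageArtifact
from griptape.drivers import BedrockClaudeImageQueryModelDriver

image = ImageArtifact(value=b"<raw png bytes>", format="png", width=512, height=512)

driver = BedrockClaudeImageQueryModelDriver()
params = driver.image_query_request_parameters("Describe this image.", [image], max_tokens=256)

# params == {
#     "messages": [{"content": [<image block>, <text block>], "role": "user"}],
#     "anthropic_version": "bedrock-2023-05-31",
#     "max_tokens": 256,
# }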

BedrockClaudePromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
@define
class BedrockClaudePromptModelDriver(BasePromptModelDriver):
    ANTHROPIC_VERSION = "bedrock-2023-05-31"  # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example

    top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
    top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
    _tokenizer: BedrockClaudeTokenizer = field(default=None, kw_only=True)
    prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> BedrockClaudeTokenizer:
        """Returns the tokenizer for this driver.

        We need to pass the `session` field from the Prompt Driver to the
        Tokenizer. However, the Prompt Driver is not initialized until after
        the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
        field a @property that is only initialized when it is first accessed.
        This ensures that by the time we need to initialize the Tokenizer, the
        Prompt Driver has already been initialized.

        See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

        Returns:
            BedrockClaudeTokenizer: The tokenizer for this driver.
        """
        if self._tokenizer:
            return self._tokenizer
        else:
            self._tokenizer = BedrockClaudeTokenizer(model=self.prompt_driver.model)
            return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        messages = [
            {"role": self.__to_anthropic_role(prompt_input), "content": prompt_input.content}
            for prompt_input in prompt_stack.inputs
            if not prompt_input.is_system()
        ]
        system = next((i for i in prompt_stack.inputs if i.is_system()), None)

        if system is None:
            return {"messages": messages}
        else:
            return {"messages": messages, "system": system.content}

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        input = self.prompt_stack_to_model_input(prompt_stack)

        return {
            "stop_sequences": self.tokenizer.stop_sequences,
            "temperature": self.prompt_driver.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.prompt_driver.max_output_tokens(self.prompt_driver.prompt_stack_to_string(prompt_stack)),
            "anthropic_version": self.ANTHROPIC_VERSION,
            **input,
        }

    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        if isinstance(output, bytes):
            body = json.loads(output.decode())
        else:
            raise Exception("Output must be bytes.")

        if body["type"] == "content_block_delta":
            return TextArtifact(value=body["delta"]["text"])
        elif body["type"] == "message":
            return TextArtifact(value=body["content"][0]["text"])
        else:
            return TextArtifact(value="")

    def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
        if prompt_input.is_system():
            return "system"
        elif prompt_input.is_assistant():
            return "assistant"
        else:
            return "user"

ANTHROPIC_VERSION = 'bedrock-2023-05-31' class-attribute instance-attribute

prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

tokenizer: BedrockClaudeTokenizer property

Returns the tokenizer for this driver.

We need to pass the session field from the Prompt Driver to the Tokenizer. However, the Prompt Driver is not initialized until after the Prompt Model Driver is initialized. To resolve this, we make the tokenizer field a @property that is only initialized when it is first accessed. This ensures that by the time we need to initialize the Tokenizer, the Prompt Driver has already been initialized.

See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

Returns:

Name Type Description
BedrockClaudeTokenizer BedrockClaudeTokenizer

The tokenizer for this driver.

top_k: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_p: float = field(default=0.999, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_anthropic_role(prompt_input)

Source code in griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str:
    if prompt_input.is_system():
        return "system"
    elif prompt_input.is_assistant():
        return "assistant"
    else:
        return "user"

process_output(output)

Source code in griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    if isinstance(output, bytes):
        body = json.loads(output.decode())
    else:
        raise Exception("Output must be bytes.")

    if body["type"] == "content_block_delta":
        return TextArtifact(value=body["delta"]["text"])
    elif body["type"] == "message":
        return TextArtifact(value=body["content"][0]["text"])
    else:
        return TextArtifact(value="")

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
    messages = [
        {"role": self.__to_anthropic_role(prompt_input), "content": prompt_input.content}
        for prompt_input in prompt_stack.inputs
        if not prompt_input.is_system()
    ]
    system = next((i for i in prompt_stack.inputs if i.is_system()), None)

    if system is None:
        return {"messages": messages}
    else:
        return {"messages": messages, "system": system.content}

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    input = self.prompt_stack_to_model_input(prompt_stack)

    return {
        "stop_sequences": self.tokenizer.stop_sequences,
        "temperature": self.prompt_driver.temperature,
        "top_p": self.top_p,
        "top_k": self.top_k,
        "max_tokens": self.prompt_driver.max_output_tokens(self.prompt_driver.prompt_stack_to_string(prompt_stack)),
        "anthropic_version": self.ANTHROPIC_VERSION,
        **input,
    }
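
A sketch of the Anthropic Messages payload produced from a prompt stack. Note that prompt_stack_to_model_input() needs no prompt_driver, while prompt_stack_to_model_params() does (for temperature and max_output_tokens); PromptStack is assumed importable from griptape.utils in this version.

from griptape.drivers import BedrockClaudePromptModelDriver
from griptape.utils import PromptStack

stack = PromptStack()
stack.add_system_input("You are a terse assistant.")
stack.add_user_input("Hi!")

driver = BedrockClaudePromptModelDriver()
print(driver.prompt_stack_to_model_input(stack))
# {'messages': [{'role': 'user', 'content': 'Hi!'}], 'system': 'You are a terse assistant.'}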

BedrockJurassicPromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
@define
class BedrockJurassicPromptModelDriver(BasePromptModelDriver):
    top_p: float = field(default=0.9, kw_only=True, metadata={"serializable": True})
    _tokenizer: BedrockJurassicTokenizer = field(default=None, kw_only=True)
    prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)
    supports_streaming: bool = field(default=False, kw_only=True)

    @property
    def tokenizer(self) -> BedrockJurassicTokenizer:
        """Returns the tokenizer for this driver.

        We need to pass the `session` field from the Prompt Driver to the
        Tokenizer. However, the Prompt Driver is not initialized until after
        the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
        field a @property that is only initialized when it is first accessed.
        This ensures that by the time we need to initialize the Tokenizer, the
        Prompt Driver has already been initialized.

        See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

        Returns:
            BedrockJurassicTokenizer: The tokenizer for this driver.
        """
        if self._tokenizer:
            return self._tokenizer
        else:
            self._tokenizer = BedrockJurassicTokenizer(model=self.prompt_driver.model)
            return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_user():
                prompt_lines.append(f"User: {i.content}")
            elif i.is_assistant():
                prompt_lines.append(f"Assistant: {i.content}")
            elif i.is_system():
                prompt_lines.append(f"System: {i.content}")
            else:
                prompt_lines.append(i.content)
        prompt_lines.append("Assistant:")

        prompt = "\n".join(prompt_lines)

        return {"prompt": prompt}

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_model_input(prompt_stack)["prompt"]

        return {
            "maxTokens": self.prompt_driver.max_output_tokens(prompt),
            "temperature": self.prompt_driver.temperature,
            "stopSequences": self.tokenizer.stop_sequences,
            "countPenalty": {"scale": 0},
            "presencePenalty": {"scale": 0},
            "frequencyPenalty": {"scale": 0},
        }

    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        if isinstance(output, bytes):
            body = json.loads(output.decode())
        else:
            raise Exception("Output must be bytes.")
        return TextArtifact(body["completions"][0]["data"]["text"])

prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

supports_streaming: bool = field(default=False, kw_only=True) class-attribute instance-attribute

tokenizer: BedrockJurassicTokenizer property

Returns the tokenizer for this driver.

We need to pass the session field from the Prompt Driver to the Tokenizer. However, the Prompt Driver is not initialized until after the Prompt Model Driver is initialized. To resolve this, we make the tokenizer field a @property that is only initialized when it is first accessed. This ensures that by the time we need to initialize the Tokenizer, the Prompt Driver has already been initialized.

See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

Returns:

Name Type Description
BedrockJurassicTokenizer BedrockJurassicTokenizer

The tokenizer for this driver.

top_p: float = field(default=0.9, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

process_output(output)

Source code in griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    if isinstance(output, bytes):
        body = json.loads(output.decode())
    else:
        raise Exception("Output must be bytes.")
    return TextArtifact(body["completions"][0]["data"]["text"])

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_user():
            prompt_lines.append(f"User: {i.content}")
        elif i.is_assistant():
            prompt_lines.append(f"Assistant: {i.content}")
        elif i.is_system():
            prompt_lines.append(f"System: {i.content}")
        else:
            prompt_lines.append(i.content)
    prompt_lines.append("Assistant:")

    prompt = "\n".join(prompt_lines)

    return {"prompt": prompt}

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_stack_to_model_input(prompt_stack)["prompt"]

    return {
        "maxTokens": self.prompt_driver.max_output_tokens(prompt),
        "temperature": self.prompt_driver.temperature,
        "stopSequences": self.tokenizer.stop_sequences,
        "countPenalty": {"scale": 0},
        "presencePenalty": {"scale": 0},
        "frequencyPenalty": {"scale": 0},
    }

BedrockLlamaPromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py
@define
class BedrockLlamaPromptModelDriver(BasePromptModelDriver):
    top_p: float = field(default=0.9, kw_only=True)
    _tokenizer: BedrockLlamaTokenizer = field(default=None, kw_only=True)
    prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> BedrockLlamaTokenizer:
        """Returns the tokenizer for this driver.

        We need to pass the `session` field from the Prompt Driver to the
        Tokenizer. However, the Prompt Driver is not initialized until after
        the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
        field a @property that is only initialized when it is first accessed.
        This ensures that by the time we need to initialize the Tokenizer, the
        Prompt Driver has already been initialized.

        See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

        Returns:
            BedrockLlamaTokenizer: The tokenizer for this driver.
        """
        if self._tokenizer:
            return self._tokenizer
        else:
            self._tokenizer = BedrockLlamaTokenizer(model=self.prompt_driver.model)
            return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
        """
        Converts a `PromptStack` to a string that can be used as the input to the model.

        Prompt structure adapted from https://huggingface.co/blog/llama2#how-to-prompt-llama-2

        Args:
            prompt_stack: The `PromptStack` to convert.
        """
        prompt_lines = []

        inputs = iter(prompt_stack.inputs)
        input_pairs: list[tuple] = list(it.zip_longest(inputs, inputs))
        for input_pair in input_pairs:
            first_input: PromptStack.Input = input_pair[0]
            second_input: Optional[PromptStack.Input] = input_pair[1]

            if first_input.is_system():
                prompt_lines.append(f"<s>[INST] <<SYS>>\n{first_input.content}\n<</SYS>>\n\n")
                if second_input:
                    if second_input.is_user():
                        prompt_lines.append(f"{second_input.content} [/INST]")
                    else:
                        raise Exception("System input must be followed by user input.")
            elif first_input.is_assistant():
                prompt_lines.append(f" {first_input.content} </s>")
                if second_input:
                    if second_input.is_user():
                        prompt_lines.append(f"<s>[INST] {second_input.content} [/INST]")
                    else:
                        raise Exception("Assistant input must be followed by user input.")
            elif first_input.is_user():
                prompt_lines.append(f"<s>[INST] {first_input.content} [/INST]")
                if second_input:
                    if second_input.is_assistant():
                        prompt_lines.append(f" {second_input.content} </s>")
                    else:
                        raise Exception("User input must be followed by assistant input.")

        return "".join(prompt_lines)

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_model_input(prompt_stack)

        return {
            "prompt": prompt,
            "max_gen_len": self.prompt_driver.max_output_tokens(prompt),
            "temperature": self.prompt_driver.temperature,
            "top_p": self.top_p,
        }

    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        # When streaming, the response body comes back as bytes.
        if isinstance(output, bytes):
            output = output.decode()
        elif isinstance(output, list):
            raise Exception("Invalid output format.")

        body = json.loads(output)

        return TextArtifact(body["generation"])

prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

tokenizer: BedrockLlamaTokenizer property

Returns the tokenizer for this driver.

We need to pass the session field from the Prompt Driver to the Tokenizer. However, the Prompt Driver is not initialized until after the Prompt Model Driver is initialized. To resolve this, we make the tokenizer field a @property that is only initialized when it is first accessed. This ensures that by the time we need to initialize the Tokenizer, the Prompt Driver has already been initialized.

See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

Returns:

Name Type Description
BedrockLlamaTokenizer BedrockLlamaTokenizer

The tokenizer for this driver.

top_p: float = field(default=0.9, kw_only=True) class-attribute instance-attribute

process_output(output)

Source code in griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    # When streaming, the response body comes back as bytes.
    if isinstance(output, bytes):
        output = output.decode()
    elif isinstance(output, list):
        raise Exception("Invalid output format.")

    body = json.loads(output)

    return TextArtifact(body["generation"])

prompt_stack_to_model_input(prompt_stack)

Converts a PromptStack to a string that can be used as the input to the model.

Prompt structure adapted from https://huggingface.co/blog/llama2#how-to-prompt-llama-2

Parameters:

Name Type Description Default
prompt_stack PromptStack

The PromptStack to convert.

required
Source code in griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
    """
    Converts a `PromptStack` to a string that can be used as the input to the model.

    Prompt structure adapted from https://huggingface.co/blog/llama2#how-to-prompt-llama-2

    Args:
        prompt_stack: The `PromptStack` to convert.
    """
    prompt_lines = []

    inputs = iter(prompt_stack.inputs)
    input_pairs: list[tuple] = list(it.zip_longest(inputs, inputs))
    for input_pair in input_pairs:
        first_input: PromptStack.Input = input_pair[0]
        second_input: Optional[PromptStack.Input] = input_pair[1]

        if first_input.is_system():
            prompt_lines.append(f"<s>[INST] <<SYS>>\n{first_input.content}\n<</SYS>>\n\n")
            if second_input:
                if second_input.is_user():
                    prompt_lines.append(f"{second_input.content} [/INST]")
                else:
                    raise Exception("System input must be followed by user input.")
        elif first_input.is_assistant():
            prompt_lines.append(f" {first_input.content} </s>")
            if second_input:
                if second_input.is_user():
                    prompt_lines.append(f"<s>[INST] {second_input.content} [/INST]")
                else:
                    raise Exception("Assistant input must be followed by user input.")
        elif first_input.is_user():
            prompt_lines.append(f"<s>[INST] {first_input.content} [/INST]")
            if second_input:
                if second_input.is_assistant():
                    prompt_lines.append(f" {second_input.content} </s>")
                else:
                    raise Exception("User input must be followed by assistant input.")

    return "".join(prompt_lines)

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_stack_to_model_input(prompt_stack)

    return {
        "prompt": prompt,
        "max_gen_len": self.prompt_driver.max_output_tokens(prompt),
        "temperature": self.prompt_driver.temperature,
        "top_p": self.top_p,
    }
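
A sketch of the Llama 2 prompt string assembled by the pairing logic above (PromptStack again assumed importable from griptape.utils):

from griptape.drivers import BedrockLlamaPromptModelDriver
from griptape.utils import PromptStack

stack = PromptStack()
stack.add_system_input("Answer briefly.")
stack.add_user_input("What is 2 + 2?")

driver = BedrockLlamaPromptModelDriver()
print(driver.prompt_stack_to_model_input(stack))
# <s>[INST] <<SYS>>
# Answer briefly.
# <</SYS>>
#
# What is 2 + 2? [/INST]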

BedrockStableDiffusionImageGenerationModelDriver

Bases: BaseImageGenerationModelDriver

Image generation model driver for Stable Diffusion models on Amazon Bedrock.

For more information on all supported parameters, see the Stable Diffusion documentation: https://platform.stability.ai/docs/api-reference#tag/v1generation

Attributes:

Name Type Description
cfg_scale int

Specifies how strictly image generation follows the provided prompt. Defaults to 7.

mask_source str

Specifies mask image configuration for image-to-image generations. Defaults to "MASK_IMAGE_BLACK".

style_preset Optional[str]

If provided, specifies a specific image generation style preset.

clip_guidance_preset Optional[str]

If provided, requests a specific clip guidance preset to be used in the diffusion process.

sampler Optional[str]

If provided, requests a specific sampler to be used in the diffusion process.

steps Optional[int]

If provided, specifies the number of diffusion steps to use in the image generation.

start_schedule Optional[float]

If provided, specifies the start_schedule parameter used to determine the influence of the input image in image-to-image generation.

Source code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
@define
class BedrockStableDiffusionImageGenerationModelDriver(BaseImageGenerationModelDriver):
    """Image generation model driver for Stable Diffusion models on Amazon Bedrock.

    For more information on all supported parameters, see the Stable Diffusion documentation:
        https://platform.stability.ai/docs/api-reference#tag/v1generation

    Attributes:
        cfg_scale: Specifies how strictly image generation follows the provided prompt. Defaults to 7.
        mask_source: Specifies mask image configuration for image-to-image generations. Defaults to "MASK_IMAGE_BLACK".
        style_preset: If provided, specifies a specific image generation style preset.
        clip_guidance_preset: If provided, requests a specific clip guidance preset to be used in the diffusion process.
        sampler: If provided, requests a specific sampler to be used in the diffusion process.
        steps: If provided, specifies the number of diffusion steps to use in the image generation.
        start_schedule: If provided, specifies the start_schedule parameter used to determine the influence of the input
            image in image-to-image generation.
    """

    cfg_scale: int = field(default=7, kw_only=True, metadata={"serializable": True})
    style_preset: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    clip_guidance_preset: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    sampler: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    steps: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    start_schedule: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        return self._request_parameters(
            prompts, width=image_width, height=image_height, negative_prompts=negative_prompts, seed=seed
        )

    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        return self._request_parameters(prompts, image=image, negative_prompts=negative_prompts, seed=seed)

    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        return self._request_parameters(
            prompts,
            image=image,
            mask=mask,
            mask_source="MASK_IMAGE_BLACK",
            negative_prompts=negative_prompts,
            seed=seed,
        )

    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        return self._request_parameters(
            prompts,
            image=image,
            mask=mask,
            mask_source="MASK_IMAGE_WHITE",
            negative_prompts=negative_prompts,
            seed=seed,
        )

    def _request_parameters(
        self,
        prompts: list[str],
        width: Optional[int] = None,
        height: Optional[int] = None,
        image: Optional[ImageArtifact] = None,
        mask: Optional[ImageArtifact] = None,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
        mask_source: Optional[str] = None,
    ) -> dict:
        if negative_prompts is None:
            negative_prompts = []

        text_prompts = [{"text": prompt, "weight": 1.0} for prompt in prompts]
        text_prompts += [{"text": negative_prompt, "weight": -1.0} for negative_prompt in negative_prompts]

        request = {"text_prompts": text_prompts, "cfg_scale": self.cfg_scale}

        if self.style_preset is not None:
            request["style_preset"] = self.style_preset

        if self.clip_guidance_preset is not None:
            request["clip_guidance_preset"] = self.clip_guidance_preset

        if self.sampler is not None:
            request["sampler"] = self.sampler

        if image is not None:
            request["init_image"] = image.base64
            request["width"] = image.width
            request["height"] = image.height
        else:
            request["width"] = width
            request["height"] = height

        if self.steps is not None:
            request["steps"] = self.steps

        if seed is not None:
            request["seed"] = seed

        if mask is not None:
            if not mask_source:
                raise ValueError("mask_source must be provided when mask is provided")

            request["mask_source"] = mask_source
            request["mask_image"] = mask.base64

        if self.start_schedule is not None:
            request["start_schedule"] = self.start_schedule

        return request

    def get_generated_image(self, response: dict) -> bytes:
        image_response = response["artifacts"][0]

        # finishReason may be SUCCESS, CONTENT_FILTERED, or ERROR.
        if image_response.get("finishReason") == "ERROR":
            raise Exception(f"Image generation failed: {image_response.get('finishReason')}")
        elif image_response.get("finishReason") == "CONTENT_FILTERED":
            logging.warning("Image generation triggered content filter and may be blurred")

        return base64.decodebytes(bytes(image_response.get("base64"), "utf-8"))

cfg_scale: int = field(default=7, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

clip_guidance_preset: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

sampler: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

start_schedule: Optional[float] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

steps: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

style_preset: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

get_generated_image(response)

Source code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def get_generated_image(self, response: dict) -> bytes:
    image_response = response["artifacts"][0]

    # finishReason may be SUCCESS, CONTENT_FILTERED, or ERROR.
    if image_response.get("finishReason") == "ERROR":
        raise Exception(f"Image generation failed: {image_response.get('finishReason')}")
    elif image_response.get("finishReason") == "CONTENT_FILTERED":
        logging.warning(f"Image generation triggered content filter and may be blurred")

    return base64.decodebytes(bytes(image_response.get("base64"), "utf-8"))

image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)

Source code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    return self._request_parameters(
        prompts,
        image=image,
        mask=mask,
        mask_source="MASK_IMAGE_BLACK",
        negative_prompts=negative_prompts,
        seed=seed,
    )

image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)

Source code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    return self._request_parameters(
        prompts,
        image=image,
        mask=mask,
        mask_source="MASK_IMAGE_WHITE",
        negative_prompts=negative_prompts,
        seed=seed,
    )

image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None)

Source code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    return self._request_parameters(prompts, image=image, negative_prompts=negative_prompts, seed=seed)

text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None)

Source code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    return self._request_parameters(
        prompts, width=image_width, height=image_height, negative_prompts=negative_prompts, seed=seed
    )
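
For a quick check of the request body this driver produces, the following sketch (not part of the source) constructs the driver and inspects a text-to-image request; all argument values are illustrative:

from griptape.drivers import BedrockStableDiffusionImageGenerationModelDriver

driver = BedrockStableDiffusionImageGenerationModelDriver(style_preset="photographic")

params = driver.text_to_image_request_parameters(
    ["a lighthouse at dusk"],
    image_width=512,
    image_height=512,
    negative_prompts=["blurry"],
    seed=42,
)

# Positive prompts carry weight 1.0 and negative prompts weight -1.0:
print(params["text_prompts"])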

BedrockTitanImageGenerationModelDriver

Bases: BaseImageGenerationModelDriver

Image Generation Model Driver for Amazon Bedrock Titan Image Generator.

Attributes:

Name Type Description
quality str

The quality of the generated image, defaults to standard.

cfg_scale int

Specifies how strictly image generation follows the provided prompt. Defaults to 7; valid range (1.0, 10.0].

outpainting_mode str

Specifies the outpainting mode, defaults to PRECISE.

Source code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
@define
class BedrockTitanImageGenerationModelDriver(BaseImageGenerationModelDriver):
    """Image Generation Model Driver for Amazon Bedrock Titan Image Generator.

    Attributes:
        quality: The quality of the generated image, defaults to standard.
        cfg_scale: Specifies how strictly image generation follows the provided prompt. Defaults to 7; valid range (1.0, 10.0].
        outpainting_mode: Specifies the outpainting mode, defaults to PRECISE.
    """

    quality: str = field(default="standard", kw_only=True, metadata={"serializable": True})
    cfg_scale: int = field(default=7, kw_only=True, metadata={"serializable": True})
    outpainting_mode: str = field(default="PRECISE", kw_only=True, metadata={"serializable": True})

    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        prompt = ", ".join(prompts)

        request = {
            "taskType": "TEXT_IMAGE",
            "textToImageParams": {"text": prompt},
            "imageGenerationConfig": {
                "numberOfImages": 1,
                "quality": self.quality,
                "width": image_width,
                "height": image_height,
                "cfgScale": self.cfg_scale,
            },
        }

        if negative_prompts:
            request["textToImageParams"]["negativeText"] = ", ".join(negative_prompts)

        if seed is not None:
            request["imageGenerationConfig"]["seed"] = seed

        return self._add_common_params(request, image_width, image_height, seed=seed)

    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        prompt = ", ".join(prompts)

        request = {
            "taskType": "IMAGE_VARIATION",
            "imageVariationParams": {"text": prompt, "images": [image.base64]},
            "imageGenerationConfig": {
                "numberOfImages": 1,
                "quality": self.quality,
                "width": image.width,
                "height": image.height,
                "cfgScale": self.cfg_scale,
            },
        }

        if negative_prompts:
            request["imageVariationParams"]["negativeText"] = ", ".join(negative_prompts)

        if seed is not None:
            request["imageGenerationConfig"]["seed"] = seed

        return self._add_common_params(request, image.width, image.height, seed=seed)

    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        prompt = ", ".join(prompts)

        request = {
            "taskType": "INPAINTING",
            "inPaintingParams": {"text": prompt, "image": image.base64, "maskImage": mask.base64},
        }

        if negative_prompts:
            request["inPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

        return self._add_common_params(request, image.width, image.height, seed=seed)

    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        prompt = ", ".join(prompts)

        request = {
            "taskType": "OUTPAINTING",
            "outPaintingParams": {
                "text": prompt,
                "image": image.base64,
                "maskImage": mask.base64,
                "outPaintingMode": self.outpainting_mode,
            },
        }

        if negative_prompts:
            request["outPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

        return self._add_common_params(request, image.width, image.height, seed=seed)

    def get_generated_image(self, response: dict) -> bytes:
        b64_image_data = response["images"][0]

        return base64.decodebytes(bytes(b64_image_data, "utf-8"))

    def _add_common_params(self, request: dict[str, Any], width: int, height: int, seed: Optional[int] = None) -> dict:
        request["imageGenerationConfig"] = {
            "numberOfImages": 1,
            "quality": self.quality,
            "width": width,
            "height": height,
            "cfgScale": self.cfg_scale,
        }

        if seed is not None:
            request["imageGenerationConfig"]["seed"] = seed

        return request

cfg_scale: int = field(default=7, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

outpainting_mode: str = field(default='PRECISE', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

quality: str = field(default='standard', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

get_generated_image(response)

Source code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def get_generated_image(self, response: dict) -> bytes:
    b64_image_data = response["images"][0]

    return base64.decodebytes(bytes(b64_image_data, "utf-8"))

image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)

Source code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    prompt = ", ".join(prompts)

    request = {
        "taskType": "INPAINTING",
        "inPaintingParams": {"text": prompt, "image": image.base64, "maskImage": mask.base64},
    }

    if negative_prompts:
        request["inPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

    return self._add_common_params(request, image.width, image.height, seed=seed)

image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)

Source code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    prompt = ", ".join(prompts)

    request = {
        "taskType": "OUTPAINTING",
        "outPaintingParams": {
            "text": prompt,
            "image": image.base64,
            "maskImage": mask.base64,
            "outPaintingMode": self.outpainting_mode,
        },
    }

    if negative_prompts:
        request["outPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

    return self._add_common_params(request, image.width, image.height, seed=seed)

image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None)

Source code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    prompt = ", ".join(prompts)

    request = {
        "taskType": "IMAGE_VARIATION",
        "imageVariationParams": {"text": prompt, "images": [image.base64]},
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "quality": self.quality,
            "width": image.width,
            "height": image.height,
            "cfgScale": self.cfg_scale,
        },
    }

    if negative_prompts:
        request["imageVariationParams"]["negativeText"] = ", ".join(negative_prompts)

    if seed is not None:
        request["imageGenerationConfig"]["seed"] = seed

    return self._add_common_params(request, image.width, image.height, seed=seed)

text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None)

Source code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    prompt = ", ".join(prompts)

    request = {
        "taskType": "TEXT_IMAGE",
        "textToImageParams": {"text": prompt},
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "quality": self.quality,
            "width": image_width,
            "height": image_height,
            "cfgScale": self.cfg_scale,
        },
    }

    if negative_prompts:
        request["textToImageParams"]["negativeText"] = ", ".join(negative_prompts)

    if seed is not None:
        request["imageGenerationConfig"]["seed"] = seed

    return self._add_common_params(request, image_width, image_height, seed=seed)
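
As a minimal sketch (not part of the source), the Titan request body can be inspected the same way; values here are illustrative:

from griptape.drivers import BedrockTitanImageGenerationModelDriver

driver = BedrockTitanImageGenerationModelDriver(quality="premium")

params = driver.text_to_image_request_parameters(
    ["a watercolor map of the Pacific"], image_width=1024, image_height=1024, seed=7
)

# Titan requests are keyed by task type, and all prompts are joined into one string.
assert params["taskType"] == "TEXT_IMAGE"
assert params["imageGenerationConfig"]["quality"] == "premium"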

BedrockTitanPromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
@define
class BedrockTitanPromptModelDriver(BasePromptModelDriver):
    top_p: float = field(default=0.9, kw_only=True, metadata={"serializable": True})
    _tokenizer: BedrockTitanTokenizer = field(default=None, kw_only=True)
    prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> BedrockTitanTokenizer:
        """Returns the tokenizer for this driver.

        We need to pass the `session` field from the Prompt Driver to the
        Tokenizer. However, the Prompt Driver is not initialized until after
        the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
        field a @property that is only initialized when it is first accessed.
        This ensures that by the time we need to initialize the Tokenizer, the
        Prompt Driver has already been initialized.

        See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

        Returns:
            BedrockTitanTokenizer: The tokenizer for this driver.
        """
        if self._tokenizer:
            return self._tokenizer
        else:
            self._tokenizer = BedrockTitanTokenizer(model=self.prompt_driver.model)
            return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_user():
                prompt_lines.append(f"User: {i.content}")
            elif i.is_assistant():
                prompt_lines.append(f"Bot: {i.content}")
            elif i.is_system():
                prompt_lines.append(f"Instructions: {i.content}")
            else:
                prompt_lines.append(i.content)
        prompt_lines.append("Bot:")

        prompt = "\n\n".join(prompt_lines)

        return {"inputText": prompt}

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_model_input(prompt_stack)["inputText"]

        return {
            "textGenerationConfig": {
                "maxTokenCount": self.prompt_driver.max_output_tokens(prompt),
                "stopSequences": self.tokenizer.stop_sequences,
                "temperature": self.prompt_driver.temperature,
                "topP": self.top_p,
            }
        }

    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        # When streaming, the response body comes back as bytes.
        if isinstance(output, (str, bytes)):
            if isinstance(output, bytes):
                output = output.decode()

            body = json.loads(output)

            if self.prompt_driver.stream:
                return TextArtifact(body["outputText"])
            else:
                return TextArtifact(body["results"][0]["outputText"])
        else:
            raise ValueError("output must be an instance of 'str' or 'bytes'")

prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

tokenizer: BedrockTitanTokenizer property

Returns the tokenizer for this driver.

We need to pass the session field from the Prompt Driver to the Tokenizer. However, the Prompt Driver is not initialized until after the Prompt Model Driver is initialized. To resolve this, we make the tokenizer field a @property that is only initialized when it is first accessed. This ensures that by the time we need to initialize the Tokenizer, the Prompt Driver has already been initialized.

See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

Returns:

Name Type Description
BedrockTitanTokenizer BedrockTitanTokenizer

The tokenizer for this driver.
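
The pattern generalizes beyond this driver. A generic sketch of the same lazy initialization, using hypothetical names rather than the Griptape source:

from attrs import define, field

class Tokenizer:  # hypothetical stand-in for the real tokenizer
    ...

@define
class LazyDriver:  # hypothetical; illustrates the pattern only
    _tokenizer: Tokenizer = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> Tokenizer:
        # Deferred until first access, by which point collaborators exist.
        if self._tokenizer is None:
            self._tokenizer = Tokenizer()
        return self._tokenizer

driver = LazyDriver()
assert driver.tokenizer is driver.tokenizer  # created once, then reused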

top_p: float = field(default=0.9, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

process_output(output)

Source code in griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    # When streaming, the response body comes back as bytes.
    if isinstance(output, (str, bytes)):
        if isinstance(output, bytes):
            output = output.decode()

        body = json.loads(output)

        if self.prompt_driver.stream:
            return TextArtifact(body["outputText"])
        else:
            return TextArtifact(body["results"][0]["outputText"])
    else:
        raise ValueError("output must be an instance of 'str' or 'bytes'")

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_user():
            prompt_lines.append(f"User: {i.content}")
        elif i.is_assistant():
            prompt_lines.append(f"Bot: {i.content}")
        elif i.is_system():
            prompt_lines.append(f"Instructions: {i.content}")
        else:
            prompt_lines.append(i.content)
    prompt_lines.append("Bot:")

    prompt = "\n\n".join(prompt_lines)

    return {"inputText": prompt}

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_stack_to_model_input(prompt_stack)["inputText"]

    return {
        "textGenerationConfig": {
            "maxTokenCount": self.prompt_driver.max_output_tokens(prompt),
            "stopSequences": self.tokenizer.stop_sequences,
            "temperature": self.prompt_driver.temperature,
            "topP": self.top_p,
        }
    }

CoherePromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_key str

Cohere API key.

model str

Cohere model name.

client Client

Custom cohere.Client.

tokenizer CohereTokenizer

Custom CohereTokenizer.

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
@define
class CoherePromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
        tokenizer: Custom `CohereTokenizer`.
    """

    api_key: str = field(kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
        kw_only=True,
    )
    tokenizer: CohereTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        result = self.client.generate(**self._base_params(prompt_stack))

        if result.generations:
            if len(result.generations) == 1:
                generation = result.generations[0]

                return TextArtifact(value=generation.text.strip())
            else:
                raise Exception("completion with more than one choice is not supported yet")
        else:
            raise Exception("model response is empty")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        result = self.client.generate(**self._base_params(prompt_stack), stream=True)

        for chunk in result:
            yield TextArtifact(value=chunk.text)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_string(prompt_stack)

        return {
            "prompt": prompt,
            "model": self.model,
            "temperature": self.temperature,
            "end_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_output_tokens(prompt),
        }

api_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: Client = field(default=Factory(lambda self: import_optional_dependency('cohere').Client(self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: CohereTokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    result = self.client.generate(**self._base_params(prompt_stack))

    if result.generations:
        if len(result.generations) == 1:
            generation = result.generations[0]

            return TextArtifact(value=generation.text.strip())
        else:
            raise Exception("completion with more than one choice is not supported yet")
    else:
        raise Exception("model response is empty")

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    result = self.client.generate(**self._base_params(prompt_stack), stream=True)

    for chunk in result:
        yield TextArtifact(value=chunk.text)
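
End to end, the driver can be exercised as in this minimal sketch; the API key is a placeholder, and the PromptStack import path and helper names are assumed from this version of the library:

from griptape.drivers import CoherePromptDriver
from griptape.utils import PromptStack  # import path assumed

driver = CoherePromptDriver(api_key="...", model="command")

stack = PromptStack()
stack.add_user_input("Summarize the plot of Hamlet in one sentence.")

print(driver.try_run(stack).value)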

DummyEmbeddingDriver

Bases: BaseEmbeddingDriver

Source code in griptape/drivers/embedding/dummy_embedding_driver.py
@define
class DummyEmbeddingDriver(BaseEmbeddingDriver):
    model: str = field(init=False)

    def try_embed_chunk(self, chunk: str) -> list[float]:
        raise DummyException(__class__.__name__, "try_embed_chunk")

model: str = field(init=False) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/dummy_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    raise DummyException(__class__.__name__, "try_embed_chunk")

DummyImageGenerationDriver

Bases: BaseImageGenerationDriver

Source code in griptape/drivers/image_generation/dummy_image_generation_driver.py
@define
class DummyImageGenerationDriver(BaseImageGenerationDriver):
    model: str = field(init=False)

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        raise DummyException(__class__.__name__, "try_text_to_image")

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        raise DummyException(__class__.__name__, "try_image_variation")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise DummyException(__class__.__name__, "try_image_inpainting")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise DummyException(__class__.__name__, "try_image_outpainting")

model: str = field(init=False) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise DummyException(__class__.__name__, "try_image_inpainting")

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise DummyException(__class__.__name__, "try_image_outpainting")

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    raise DummyException(__class__.__name__, "try_image_variation")

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    raise DummyException(__class__.__name__, "try_text_to_image")

DummyImageQueryDriver

Bases: BaseImageQueryDriver

Source code in griptape/drivers/image_query/dummy_image_query_driver.py
@define
class DummyImageQueryDriver(BaseImageQueryDriver):
    model: str = field(init=False)
    max_tokens: int = field(init=False)

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        raise DummyException(__class__.__name__, "try_query")

max_tokens: int = field(init=False) class-attribute instance-attribute

model: str = field(init=False) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/dummy_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    raise DummyException(__class__.__name__, "try_query")

DummyPromptDriver

Bases: BasePromptDriver

Source code in griptape/drivers/prompt/dummy_prompt_driver.py
@define
class DummyPromptDriver(BasePromptDriver):
    model: str = field(init=False)
    tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True)

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        raise DummyException(__class__.__name__, "try_run")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        raise DummyException(__class__.__name__, "try_stream")

model: str = field(init=False) class-attribute instance-attribute

tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/dummy_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    raise DummyException(__class__.__name__, "try_run")

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/dummy_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    raise DummyException(__class__.__name__, "try_stream")

DummyVectorStoreDriver

Bases: BaseVectorStoreDriver

Source code in griptape/drivers/vector/dummy_vector_store_driver.py
@define()
class DummyVectorStoreDriver(BaseVectorStoreDriver):
    embedding_driver: BaseEmbeddingDriver = field(
        kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True}
    )

    def delete_vector(self, vector_id: str) -> None:
        raise DummyException(__class__.__name__, "delete_vector")

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        raise DummyException(__class__.__name__, "upsert_vector")

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        raise DummyException(__class__.__name__, "load_entry")

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        raise DummyException(__class__.__name__, "load_entries")

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        raise DummyException(__class__.__name__, "query")

embedding_driver: BaseEmbeddingDriver = field(kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={'serializable': True}) class-attribute instance-attribute

delete_vector(vector_id)

Source code in griptape/drivers/vector/dummy_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None:
    raise DummyException(__class__.__name__, "delete_vector")

load_entries(namespace=None)

Source code in griptape/drivers/vector/dummy_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    raise DummyException(__class__.__name__, "load_entries")

load_entry(vector_id, namespace=None)

Source code in griptape/drivers/vector/dummy_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    raise DummyException(__class__.__name__, "load_entry")

query(query, count=None, namespace=None, include_vectors=False, **kwargs)

Source code in griptape/drivers/vector/dummy_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    raise DummyException(__class__.__name__, "query")

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Source code in griptape/drivers/vector/dummy_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    raise DummyException(__class__.__name__, "upsert_vector")

GoogleEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

Name Type Description
api_key Optional[str]

Google API key.

model str

Google model name.

task_type str

Embedding model task type (https://ai.google.dev/tutorials/python_quickstart#use_embeddings). Defaults to retrieval_document.

title Optional[str]

Optional title for the content. Only works with retrieval_document task type.

Source code in griptape/drivers/embedding/google_embedding_driver.py
@define
class GoogleEmbeddingDriver(BaseEmbeddingDriver):
    """
    Attributes:
        api_key: Google API key.
        model: Google model name.
        task_type: Embedding model task type (https://ai.google.dev/tutorials/python_quickstart#use_embeddings). Defaults to `retrieval_document`.
        title: Optional title for the content. Only works with `retrieval_document` task type.
    """

    DEFAULT_MODEL = "models/embedding-001"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    task_type: str = field(default="retrieval_document", kw_only=True, metadata={"serializable": True})
    title: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_embed_chunk(self, chunk: str) -> list[float]:
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)

        result = genai.embed_content(model=self.model, content=chunk, task_type=self.task_type, title=self.title)

        return result["embedding"]

    def _params(self, chunk: str) -> dict:
        return {"input": chunk, "model": self.model}

DEFAULT_MODEL = 'models/embedding-001' class-attribute instance-attribute

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

task_type: str = field(default='retrieval_document', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

title: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/google_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    genai = import_optional_dependency("google.generativeai")
    genai.configure(api_key=self.api_key)

    result = genai.embed_content(model=self.model, content=chunk, task_type=self.task_type, title=self.title)

    return result["embedding"]

GooglePromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_key Optional[str]

Google API key.

model str

Google model name.

model_client Any

Custom GenerativeModel client.

tokenizer BaseTokenizer

Custom GoogleTokenizer.

top_p Optional[float]

Optional value for top_p.

top_k Optional[int]

Optional value for top_k.

Source code in griptape/drivers/prompt/google_prompt_driver.py
@define
class GooglePromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Google API key.
        model: Google model name.
        model_client: Custom `GenerativeModel` client.
        tokenizer: Custom `GoogleTokenizer`.
        top_p: Optional value for top_p.
        top_k: Optional value for top_k.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    model_client: Any = field(default=Factory(lambda self: self._default_model_client(), takes_self=True), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

        inputs = self._prompt_stack_to_model_input(prompt_stack)
        response = self.model_client.generate_content(
            inputs,
            generation_config=GenerationConfig(
                stop_sequences=self.tokenizer.stop_sequences,
                max_output_tokens=self.max_output_tokens(inputs),
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
            ),
        )

        return TextArtifact(value=response.text)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

        inputs = self._prompt_stack_to_model_input(prompt_stack)
        response = self.model_client.generate_content(
            inputs,
            stream=True,
            generation_config=GenerationConfig(
                stop_sequences=self.tokenizer.stop_sequences,
                max_output_tokens=self.max_output_tokens(inputs),
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
            ),
        )

        for chunk in response:
            yield TextArtifact(value=chunk.text)

    def _default_model_client(self) -> GenerativeModel:
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)

        return genai.GenerativeModel(self.model)

    def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list[ContentDict]:
        inputs = [
            self.__to_content_dict(prompt_input) for prompt_input in prompt_stack.inputs if not prompt_input.is_system()
        ]

        # Gemini does not have the notion of a system message, so we insert it as part of the first message in the history.
        system = next((i for i in prompt_stack.inputs if i.is_system()), None)
        if system is not None:
            inputs[0]["parts"].insert(0, system.content)

        return inputs

    def __to_content_dict(self, prompt_input: PromptStack.Input) -> ContentDict:
        ContentDict = import_optional_dependency("google.generativeai.types").ContentDict

        return ContentDict({"role": self.__to_google_role(prompt_input), "parts": [prompt_input.content]})

    def __to_google_role(self, prompt_input: PromptStack.Input) -> str:
        if prompt_input.is_system():
            return "user"
        elif prompt_input.is_assistant():
            return "model"
        else:
            return "user"

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model_client: Any = field(default=Factory(lambda self: self._default_model_client(), takes_self=True), kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

top_k: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

top_p: Optional[float] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_content_dict(prompt_input)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_content_dict(self, prompt_input: PromptStack.Input) -> ContentDict:
    ContentDict = import_optional_dependency("google.generativeai.types").ContentDict

    return ContentDict({"role": self.__to_google_role(prompt_input), "parts": [prompt_input.content]})

__to_google_role(prompt_input)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_role(self, prompt_input: PromptStack.Input) -> str:
    if prompt_input.is_system():
        return "user"
    elif prompt_input.is_assistant():
        return "model"
    else:
        return "user"

try_run(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

    inputs = self._prompt_stack_to_model_input(prompt_stack)
    response = self.model_client.generate_content(
        inputs,
        generation_config=GenerationConfig(
            stop_sequences=self.tokenizer.stop_sequences,
            max_output_tokens=self.max_output_tokens(inputs),
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
        ),
    )

    return TextArtifact(value=response.text)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/google_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig

    inputs = self._prompt_stack_to_model_input(prompt_stack)
    response = self.model_client.generate_content(
        inputs,
        stream=True,
        generation_config=GenerationConfig(
            stop_sequences=self.tokenizer.stop_sequences,
            max_output_tokens=self.max_output_tokens(inputs),
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
        ),
    )

    for chunk in response:
        yield TextArtifact(value=chunk.text)
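
The system-message folding described above can be seen in a small sketch; PromptStack helper names are assumed from this version of the library, and the rendered output shown in the comment is illustrative:

from griptape.utils import PromptStack  # import path assumed

stack = PromptStack()
stack.add_system_input("Answer in French.")
stack.add_user_input("What color is the sky?")

# _prompt_stack_to_model_input produces a single user ContentDict whose parts
# begin with the system text, roughly:
# [{"role": "user", "parts": ["Answer in French.", "What color is the sky?"]}]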

GriptapeCloudEventListenerDriver

Bases: BaseEventListenerDriver

Driver for publishing events to Griptape Cloud.

Attributes:

Name Type Description
base_url str

The base URL of Griptape Cloud. Defaults to the GT_CLOUD_BASE_URL environment variable.

api_key str

The API key to authenticate with Griptape Cloud.

headers dict

The headers to use when making requests to Griptape Cloud. Defaults to include the Authorization header.

structure_run_id str

The ID of the Structure Run to publish events to. Defaults to the GT_CLOUD_STRUCTURE_RUN_ID environment variable.

Source code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
@define
class GriptapeCloudEventListenerDriver(BaseEventListenerDriver):
    """Driver for publishing events to Griptape Cloud.

    Attributes:
        base_url: The base URL of Griptape Cloud. Defaults to the GT_CLOUD_BASE_URL environment variable.
        api_key: The API key to authenticate with Griptape Cloud.
        headers: The headers to use when making requests to Griptape Cloud. Defaults to include the Authorization header.
        structure_run_id: The ID of the Structure Run to publish events to. Defaults to the GT_CLOUD_STRUCTURE_RUN_ID environment variable.
    """

    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), kw_only=True
    )
    api_key: str = field(kw_only=True)
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
    )
    structure_run_id: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True)

    @structure_run_id.validator  # pyright: ignore
    def validate_run_id(self, _, structure_run_id: str):
        if structure_run_id is None:
            raise ValueError(
                "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID)."
            )

    def try_publish_event_payload(self, event_payload: dict) -> None:
        url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events")

        response = requests.post(url=url, json=event_payload, headers=self.headers)
        response.raise_for_status()

api_key: str = field(kw_only=True) class-attribute instance-attribute

base_url: str = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')), kw_only=True) class-attribute instance-attribute

headers: dict = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

structure_run_id: str = field(default=Factory(lambda: os.getenv('GT_CLOUD_STRUCTURE_RUN_ID')), kw_only=True) class-attribute instance-attribute

try_publish_event_payload(event_payload)

Source code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{self.structure_run_id}/events")

    response = requests.post(url=url, json=event_payload, headers=self.headers)
    response.raise_for_status()

validate_run_id(_, structure_run_id)

Source code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
@structure_run_id.validator  # pyright: ignore
def validate_run_id(self, _, structure_run_id: str):
    if structure_run_id is None:
        raise ValueError(
            "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID)."
        )
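
A minimal usage sketch (not from the source), assuming GT_CLOUD_STRUCTURE_RUN_ID is set in the environment; the payload shown is illustrative:

from griptape.drivers import GriptapeCloudEventListenerDriver

driver = GriptapeCloudEventListenerDriver(api_key="...")

# POSTs to /api/structure-runs/<structure_run_id>/events with a Bearer header.
driver.try_publish_event_payload({"type": "FinishStructureRunEvent"})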

GriptapeCloudStructureRunDriver

Bases: BaseStructureRunDriver

Source code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
@define
class GriptapeCloudStructureRunDriver(BaseStructureRunDriver):
    base_url: str = field(default="https://cloud.griptape.ai", kw_only=True)
    api_key: str = field(kw_only=True)
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
    )
    structure_id: str = field(kw_only=True)
    structure_run_wait_time_interval: int = field(default=2, kw_only=True)
    structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True)
    async_run: bool = field(default=False, kw_only=True)

    def try_run(self, *args: BaseArtifact) -> BaseArtifact:
        from requests import HTTPError, Response, exceptions, post

        url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs")

        try:
            response: Response = post(url, json={"args": [arg.value for arg in args]}, headers=self.headers)
            response.raise_for_status()
            response_json = response.json()

            if self.async_run:
                return InfoArtifact("Run started successfully")
            else:
                return self._get_structure_run_result(response_json["structure_run_id"])
        except (exceptions.RequestException, HTTPError) as err:
            return ErrorArtifact(str(err))

    def _get_structure_run_result(self, structure_run_id: str) -> InfoArtifact | TextArtifact | ErrorArtifact:
        url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{structure_run_id}")

        result = self._get_structure_run_result_attempt(url)
        status = result["status"]

        wait_attempts = 0
        while status in ("QUEUED", "RUNNING") and wait_attempts < self.structure_run_max_wait_time_attempts:
            # Poll again after the configured interval.
            time.sleep(self.structure_run_wait_time_interval)
            wait_attempts += 1
            result = self._get_structure_run_result_attempt(url)
            status = result["status"]

        # Report a timeout only if the run is still pending; the final poll may
        # have succeeded on the last allowed attempt.
        if status in ("QUEUED", "RUNNING"):
            return ErrorArtifact(
                f"Failed to get Run result after {self.structure_run_max_wait_time_attempts} attempts."
            )

        if status != "SUCCEEDED":
            return ErrorArtifact(result)

        if "output" in result:
            return TextArtifact.from_dict(result["output"])
        else:
            return InfoArtifact("No output found in response")

    def _get_structure_run_result_attempt(self, structure_run_url: str) -> Any:
        from requests import get, Response

        response: Response = get(structure_run_url, headers=self.headers)
        response.raise_for_status()

        return response.json()

api_key: str = field(kw_only=True) class-attribute instance-attribute

async_run: bool = field(default=False, kw_only=True) class-attribute instance-attribute

base_url: str = field(default='https://cloud.griptape.ai', kw_only=True) class-attribute instance-attribute

headers: dict = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

structure_id: str = field(kw_only=True) class-attribute instance-attribute

structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True) class-attribute instance-attribute

structure_run_wait_time_interval: int = field(default=2, kw_only=True) class-attribute instance-attribute

try_run(*args)

Source code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def try_run(self, *args: BaseArtifact) -> BaseArtifact:
    from requests import HTTPError, Response, exceptions, post

    url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs")

    try:
        response: Response = post(url, json={"args": [arg.value for arg in args]}, headers=self.headers)
        response.raise_for_status()
        response_json = response.json()

        if self.async_run:
            return InfoArtifact("Run started successfully")
        else:
            return self._get_structure_run_result(response_json["structure_run_id"])
    except (exceptions.RequestException, HTTPError) as err:
        return ErrorArtifact(str(err))
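
A minimal usage sketch (not from the source), assuming a Structure already deployed to Griptape Cloud; key and ID values are placeholders:

from griptape.artifacts import TextArtifact
from griptape.drivers import GriptapeCloudStructureRunDriver

driver = GriptapeCloudStructureRunDriver(api_key="...", structure_id="...")

# Blocks, polling every structure_run_wait_time_interval seconds (default 2)
# for at most structure_run_max_wait_time_attempts attempts (default 20).
result = driver.try_run(TextArtifact("input for the structure"))
print(result.value)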

HuggingFaceHubEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

Name Type Description
api_token str

Hugging Face Hub API token.

model str

Hugging Face Hub model name.

client InferenceClient

Custom InferenceApi.

Source code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
@define
class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver):
    """
    Attributes:
        api_token: Hugging Face Hub API token.
        model: Hugging Face Hub model name.
        client: Custom `InferenceApi`.
    """

    api_token: str = field(kw_only=True, metadata={"serializable": True})
    client: InferenceClient = field(
        default=Factory(
            lambda self: import_optional_dependency("huggingface_hub").InferenceClient(
                model=self.model, token=self.api_token
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        response = self.client.feature_extraction(chunk)

        return response.flatten().tolist()

api_token: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: InferenceClient = field(default=Factory(lambda self: import_optional_dependency('huggingface_hub').InferenceClient(model=self.model, token=self.api_token), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    response = self.client.feature_extraction(chunk)

    return response.flatten().tolist()
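
A minimal usage sketch (not from the source), assuming a Hugging Face token with access to the model; the model name is illustrative:

from griptape.drivers import HuggingFaceHubEmbeddingDriver

driver = HuggingFaceHubEmbeddingDriver(
    api_token="hf_...", model="sentence-transformers/all-MiniLM-L6-v2"
)

vector = driver.try_embed_chunk("Hello, world!")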

HuggingFaceHubPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_token str

Hugging Face Hub API token.

use_gpu str

Use GPU during model run.

params dict

Custom model run parameters.

model str

Hugging Face Hub model name.

client InferenceClient

Custom InferenceApi.

tokenizer HuggingFaceTokenizer

Custom HuggingFaceTokenizer.

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@define
class HuggingFaceHubPromptDriver(BasePromptDriver):
    """
    Attributes:
        api_token: Hugging Face Hub API token.
        use_gpu: Use GPU during model run.
        params: Custom model run parameters.
        model: Hugging Face Hub model name.
        client: Custom `InferenceApi`.
        tokenizer: Custom `HuggingFaceTokenizer`.

    """

    api_token: str = field(kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: InferenceClient = field(
        default=Factory(
            lambda self: import_optional_dependency("huggingface_hub").InferenceClient(
                model=self.model, token=self.api_token
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
                max_output_tokens=self.max_tokens,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        prompt = self.prompt_stack_to_string(prompt_stack)

        response = self.client.text_generation(
            prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), **self.params
        )

        return TextArtifact(value=response)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        prompt = self.prompt_stack_to_string(prompt_stack)

        response = self.client.text_generation(
            prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), stream=True, **self.params
        )

        for token in response:
            yield TextArtifact(value=token)

api_token: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: InferenceClient = field(default=Factory(lambda self: import_optional_dependency('huggingface_hub').InferenceClient(model=self.model, token=self.api_token), takes_self=True), kw_only=True) class-attribute instance-attribute

max_tokens: int = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

params: dict = field(factory=dict, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(tokenizer=import_optional_dependency('transformers').AutoTokenizer.from_pretrained(self.model), max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    prompt = self.prompt_stack_to_string(prompt_stack)

    response = self.client.text_generation(
        prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), **self.params
    )

    return TextArtifact(value=response)

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    prompt = self.prompt_stack_to_string(prompt_stack)

    response = self.client.text_generation(
        prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), stream=True, **self.params
    )

    for token in response:
        yield TextArtifact(value=token)
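
A minimal streaming sketch (not from the source); the token and model name are placeholders, and the PromptStack helper names are assumed from this version of the library:

from griptape.drivers import HuggingFaceHubPromptDriver
from griptape.utils import PromptStack  # import path assumed

driver = HuggingFaceHubPromptDriver(api_token="hf_...", model="HuggingFaceH4/zephyr-7b-beta")

stack = PromptStack()
stack.add_user_input("Write a haiku about autumn.")

for artifact in driver.try_stream(stack):
    print(artifact.value, end="")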

HuggingFacePipelinePromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
params dict

Custom model run parameters.

model str

Hugging Face Hub model name.

tokenizer HuggingFaceTokenizer

Custom HuggingFaceTokenizer.

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@define
class HuggingFacePipelinePromptDriver(BasePromptDriver):
    """
    Attributes:
        params: Custom model run parameters.
        model: Hugging Face Hub model name.
        tokenizer: Custom `HuggingFaceTokenizer`.

    """

    SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
    DEFAULT_PARAMS = {"return_full_text": False, "num_return_sequences": 1}

    model: str = field(kw_only=True, metadata={"serializable": True})
    params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
                max_output_tokens=self.max_tokens,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        prompt = self.prompt_stack_to_string(prompt_stack)
        pipeline = import_optional_dependency("transformers").pipeline

        generator = pipeline(
            tokenizer=self.tokenizer.tokenizer,
            model=self.model,
            max_new_tokens=self.tokenizer.count_output_tokens_left(prompt),
        )

        if generator.task in self.SUPPORTED_TASKS:
            extra_params = {"pad_token_id": self.tokenizer.tokenizer.eos_token_id}

            response = generator(prompt, **(self.DEFAULT_PARAMS | extra_params | self.params))

            if len(response) == 1:
                return TextArtifact(value=response[0]["generated_text"].strip())
            else:
                raise Exception("completion with more than one choice is not supported yet")
        else:
            raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        raise NotImplementedError("streaming is not supported")

DEFAULT_PARAMS = {'return_full_text': False, 'num_return_sequences': 1} class-attribute instance-attribute

SUPPORTED_TASKS = ['text2text-generation', 'text-generation'] class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

params: dict = field(factory=dict, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(tokenizer=import_optional_dependency('transformers').AutoTokenizer.from_pretrained(self.model), max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    prompt = self.prompt_stack_to_string(prompt_stack)
    pipeline = import_optional_dependency("transformers").pipeline

    generator = pipeline(
        tokenizer=self.tokenizer.tokenizer,
        model=self.model,
        max_new_tokens=self.tokenizer.count_output_tokens_left(prompt),
    )

    if generator.task in self.SUPPORTED_TASKS:
        extra_params = {"pad_token_id": self.tokenizer.tokenizer.eos_token_id}

        response = generator(prompt, **(self.DEFAULT_PARAMS | extra_params | self.params))

        if len(response) == 1:
            return TextArtifact(value=response[0]["generated_text"].strip())
        else:
            raise Exception("completion with more than one choice is not supported yet")
    else:
        raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    raise NotImplementedError("streaming is not supported")
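
A minimal local-inference sketch: gpt2 is used here only as a small, text-generation-tagged model, and the optional transformers dependency must be installed. Note that this driver does not support streaming:

from griptape.utils import PromptStack
from griptape.drivers import HuggingFacePipelinePromptDriver

# The model is downloaded from the Hugging Face Hub on first use.
driver = HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=64)

prompt_stack = PromptStack()
prompt_stack.add_user_input("Once upon a time")

print(driver.try_run(prompt_stack).value)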

LeonardoImageGenerationDriver

Bases: BaseImageGenerationDriver

Driver for the Leonardo image generation API.

Details on Leonardo image generation parameters can be found here: https://docs.leonardo.ai/reference/creategeneration

Attributes:

Name Type Description
model

The ID of the model to use when generating images.

api_key str

The API key to use when making requests to the Leonardo API.

requests_session Session

The requests session to use when making requests to the Leonardo API.

api_base str

The base URL of the Leonardo API.

max_attempts int

The maximum number of times to poll the Leonardo API for a completed image.

image_width int

The width of the generated image in the range [32, 1024] and divisible by 8.

image_height int

The height of the generated image in the range [32, 1024] and divisible by 8.

steps Optional[int]

Optionally specify the number of inference steps to run for each image generation request, [30, 60].

seed Optional[int]

Optionally provide a consistent seed to generation requests, increasing consistency in output.

init_strength Optional[float]

Optionally specify the strength of the initial image, [0.0, 1.0].

Source code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
@define
class LeonardoImageGenerationDriver(BaseImageGenerationDriver):
    """Driver for the Leonardo image generation API.

    Details on Leonardo image generation parameters can be found here:
    https://docs.leonardo.ai/reference/creategeneration

    Attributes:
        model: The ID of the model to use when generating images.
        api_key: The API key to use when making requests to the Leonardo API.
        requests_session: The requests session to use when making requests to the Leonardo API.
        api_base: The base URL of the Leonardo API.
        max_attempts: The maximum number of times to poll the Leonardo API for a completed image.
        image_width: The width of the generated image in the range [32, 1024] and divisible by 8.
        image_height: The height of the generated image in the range [32, 1024] and divisible by 8.
        steps: Optionally specify the number of inference steps to run for each image generation request, [30, 60].
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
        init_strength: Optionally specify the strength of the initial image, [0.0, 1.0].
    """

    api_key: str = field(kw_only=True, metadata={"serializable": True})
    requests_session: requests.Session = field(default=Factory(lambda: requests.Session()), kw_only=True)
    api_base: str = "https://cloud.leonardo.ai/api/rest/v1"
    max_attempts: int = field(default=10, kw_only=True, metadata={"serializable": True})
    image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
    image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
    steps: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    init_strength: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    control_net: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    control_net_type: Optional[Literal["POSE", "CANNY", "DEPTH"]] = field(
        default=None, kw_only=True, metadata={"serializable": True}
    )

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        if negative_prompts is None:
            negative_prompts = []

        generation_id = self._create_generation(prompts=prompts, negative_prompts=negative_prompts)
        image_url = self._get_image_url(generation_id=generation_id)
        image_data = self._download_image(url=image_url)

        return ImageArtifact(
            value=image_data,
            format="png",
            width=self.image_width,
            height=self.image_height,
            model=self.model,
            prompt=", ".join(prompts),
        )

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        if negative_prompts is None:
            negative_prompts = []

        init_image_id = self._upload_init_image(image)
        generation_id = self._create_generation(
            prompts=prompts, negative_prompts=negative_prompts, init_image_id=init_image_id
        )
        image_url = self._get_image_url(generation_id=generation_id)
        image_data = self._download_image(url=image_url)

        return ImageArtifact(
            value=image_data,
            format="png",
            width=self.image_width,
            height=self.image_height,
            model=self.model,
            prompt=", ".join(prompts),
        )

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError(f"{self.__class__.__name__} does not support inpainting")

    def _upload_init_image(self, image: ImageArtifact) -> str:
        request = {"extension": image.mime_type.split("/")[1]}

        prep_response = self._make_api_request("/init-image", request=request)
        if prep_response is None or prep_response["uploadInitImage"] is None:
            raise Exception(f"failed to prepare init image: {prep_response}")

        fields = json.loads(prep_response["uploadInitImage"]["fields"])
        pre_signed_url = prep_response["uploadInitImage"]["url"]
        init_image_id = prep_response["uploadInitImage"]["id"]

        files = {"file": image.value}
        upload_response = requests.post(pre_signed_url, data=fields, files=files)
        if not upload_response.ok:
            raise Exception(f"failed to upload init image: {upload_response.text}")

        return init_image_id

    def _create_generation(
        self, prompts: list[str], negative_prompts: list[str], init_image_id: Optional[str] = None
    ) -> str:
        prompt = ", ".join(prompts)
        negative_prompt = ", ".join(negative_prompts)
        request = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "width": self.image_width,
            "height": self.image_height,
            "num_images": 1,
            "modelId": self.model,
        }

        if init_image_id is not None:
            request["init_image_id"] = init_image_id

        if self.init_strength is not None:
            request["init_strength"] = self.init_strength

        if self.steps:
            request["num_inference_steps"] = self.steps

        if self.seed is not None:
            request["seed"] = self.seed

        if self.control_net:
            request["controlNet"] = self.control_net
            request["controlNetType"] = self.control_net_type

        response = self._make_api_request("/generations", request=request)
        if response is None or response["sdGenerationJob"] is None:
            raise Exception(f"failed to create generation: {response}")

        return response["sdGenerationJob"]["generationId"]

    def _make_api_request(self, endpoint: str, request: dict, method: str = "POST") -> dict:
        url = f"{self.api_base}{endpoint}"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        response = self.requests_session.request(url=url, method=method, json=request, headers=headers)
        if not response.ok:
            raise Exception(f"failed to make API request: {response.text}")

        return response.json()

    def _get_image_url(self, generation_id: str) -> str:
        for attempt in range(self.max_attempts):
            response = self.requests_session.get(
                url=f"{self.api_base}/generations/{generation_id}", headers={"Authorization": f"Bearer {self.api_key}"}
            ).json()

            if response["generations_by_pk"]["status"] == "PENDING":
                time.sleep(attempt + 1)
                continue

            return response["generations_by_pk"]["generated_images"][0]["url"]
        else:
            raise Exception("image generation failed to complete")

    def _download_image(self, url: str) -> bytes:
        response = self.requests_session.get(url=url, headers={"Authorization": f"Bearer {self.api_key}"})

        return response.content

api_base: str = 'https://cloud.leonardo.ai/api/rest/v1' class-attribute instance-attribute

api_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

control_net: bool = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

control_net_type: Optional[Literal['POSE', 'CANNY', 'DEPTH']] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

image_height: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

image_width: int = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

init_strength: Optional[float] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

max_attempts: int = field(default=10, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

requests_session: requests.Session = field(default=Factory(lambda: requests.Session()), kw_only=True) class-attribute instance-attribute

seed: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

steps: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise NotImplementedError(f"{self.__class__.__name__} does not support inpainting")

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    if negative_prompts is None:
        negative_prompts = []

    init_image_id = self._upload_init_image(image)
    generation_id = self._create_generation(
        prompts=prompts, negative_prompts=negative_prompts, init_image_id=init_image_id
    )
    image_url = self._get_image_url(generation_id=generation_id)
    image_data = self._download_image(url=image_url)

    return ImageArtifact(
        value=image_data,
        format="png",
        width=self.image_width,
        height=self.image_height,
        model=self.model,
        prompt=", ".join(prompts),
    )

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    if negative_prompts is None:
        negative_prompts = []

    generation_id = self._create_generation(prompts=prompts, negative_prompts=negative_prompts)
    image_url = self._get_image_url(generation_id=generation_id)
    image_data = self._download_image(url=image_url)

    return ImageArtifact(
        value=image_data,
        format="png",
        width=self.image_width,
        height=self.image_height,
        model=self.model,
        prompt=", ".join(prompts),
    )
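
A minimal text-to-image sketch; the API key and model ID are placeholders for your own Leonardo credentials:

from griptape.drivers import LeonardoImageGenerationDriver

driver = LeonardoImageGenerationDriver(
    api_key="leo-...",  # placeholder API key
    model="...",        # placeholder Leonardo model ID
    image_width=512,
    image_height=512,
    steps=30,
)

image = driver.try_text_to_image(prompts=["a lighthouse at dusk, oil painting"])

# ImageArtifact.value holds the raw PNG bytes.
with open("lighthouse.png", "wb") as f:
    f.write(image.value)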

LocalConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source code in griptape/drivers/memory/conversation/local_conversation_memory_driver.py
@define
class LocalConversationMemoryDriver(BaseConversationMemoryDriver):
    file_path: str = field(default="griptape_memory.json", kw_only=True, metadata={"serializable": True})

    def store(self, memory: BaseConversationMemory) -> None:
        with open(self.file_path, "w") as file:
            file.write(memory.to_json())

    def load(self) -> Optional[BaseConversationMemory]:
        if not os.path.exists(self.file_path):
            return None
        with open(self.file_path) as file:
            memory = BaseConversationMemory.from_json(file.read())

            memory.driver = self

            return memory

file_path: str = field(default='griptape_memory.json', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

load()

Source code in griptape/drivers/memory/conversation/local_conversation_memory_driver.py
def load(self) -> Optional[BaseConversationMemory]:
    if not os.path.exists(self.file_path):
        return None
    with open(self.file_path) as file:
        memory = BaseConversationMemory.from_json(file.read())

        memory.driver = self

        return memory

store(memory)

Source code in griptape/drivers/memory/conversation/local_conversation_memory_driver.py
def store(self, memory: BaseConversationMemory) -> None:
    with open(self.file_path, "w") as file:
        file.write(memory.to_json())
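
A minimal round-trip sketch, assuming ConversationMemory from griptape.memory.structure:

from griptape.drivers import LocalConversationMemoryDriver
from griptape.memory.structure import ConversationMemory

driver = LocalConversationMemoryDriver(file_path="conversation.json")

memory = ConversationMemory(driver=driver)
driver.store(memory)      # serializes the memory to conversation.json

restored = driver.load()  # returns the memory with this driver re-attached, or None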

LocalFileManagerDriver

Bases: BaseFileManagerDriver

LocalFileManagerDriver can be used to list, load, and save files on the local file system.

Attributes:

Name Type Description
workdir str

The absolute working directory. List, load, and save operations will be performed relative to this directory.

Source code in griptape/drivers/file_manager/local_file_manager_driver.py
@define
class LocalFileManagerDriver(BaseFileManagerDriver):
    """
    LocalFileManagerDriver can be used to list, load, and save files on the local file system.

    Attributes:
        workdir: The absolute working directory. List, load, and save operations will be performed relative to this directory.
    """

    workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True)

    @workdir.validator  # pyright: ignore
    def validate_workdir(self, _, workdir: str) -> None:
        if not Path(workdir).is_absolute():
            raise ValueError("Workdir must be an absolute path")

    def try_list_files(self, path: str) -> list[str]:
        full_path = self._full_path(path)
        return os.listdir(full_path)

    def try_load_file(self, path: str) -> bytes:
        full_path = self._full_path(path)
        if self._is_dir(full_path):
            raise IsADirectoryError
        with open(full_path, "rb") as file:
            return file.read()

    def try_save_file(self, path: str, value: bytes):
        full_path = self._full_path(path)
        if self._is_dir(full_path):
            raise IsADirectoryError
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        with open(full_path, "wb") as file:
            file.write(value)

    def _full_path(self, path: str) -> str:
        path = path.lstrip("/")
        full_path = os.path.join(self.workdir, path)
        # Need to keep the trailing slash if it was there,
        # because it means the path is a directory.
        ended_with_slash = path.endswith("/")
        full_path = os.path.normpath(full_path)
        if ended_with_slash:
            full_path = full_path.rstrip("/") + "/"
        return full_path

    def _is_dir(self, full_path: str) -> bool:
        return full_path.endswith("/") or Path(full_path).is_dir()

workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True) class-attribute instance-attribute

try_list_files(path)

Source code in griptape/drivers/file_manager/local_file_manager_driver.py
def try_list_files(self, path: str) -> list[str]:
    full_path = self._full_path(path)
    return os.listdir(full_path)

try_load_file(path)

Source code in griptape/drivers/file_manager/local_file_manager_driver.py
def try_load_file(self, path: str) -> bytes:
    full_path = self._full_path(path)
    if self._is_dir(full_path):
        raise IsADirectoryError
    with open(full_path, "rb") as file:
        return file.read()

try_save_file(path, value)

Source code in griptape/drivers/file_manager/local_file_manager_driver.py
def try_save_file(self, path: str, value: bytes):
    full_path = self._full_path(path)
    if self._is_dir(full_path):
        raise IsADirectoryError
    os.makedirs(os.path.dirname(full_path), exist_ok=True)
    with open(full_path, "wb") as file:
        file.write(value)

validate_workdir(_, workdir)

Source code in griptape/drivers/file_manager/local_file_manager_driver.py
@workdir.validator  # pyright: ignore
def validate_workdir(self, _, workdir: str) -> None:
    if not Path(workdir).is_absolute():
        raise ValueError("Workdir must be an absolute path")
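
A minimal usage sketch; workdir must be an absolute path, so the current working directory is used here:

import os

from griptape.drivers import LocalFileManagerDriver

driver = LocalFileManagerDriver(workdir=os.getcwd())

driver.try_save_file("notes/todo.txt", b"write docs")
print(driver.try_list_files("notes"))          # e.g. ['todo.txt']
print(driver.try_load_file("notes/todo.txt"))  # b'write docs'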

LocalStructureRunDriver

Bases: BaseStructureRunDriver

Source code in griptape/drivers/structure_run/local_structure_run_driver.py
@define
class LocalStructureRunDriver(BaseStructureRunDriver):
    structure_factory_fn: Callable[[], Structure] = field(kw_only=True)

    def try_run(self, *args: BaseArtifact) -> BaseArtifact:
        structure = self.structure_factory_fn().run(*[arg.value for arg in args])

        if structure.output_task.output is not None:
            return structure.output_task.output
        else:
            return InfoArtifact("No output found in response")

structure_factory_fn: Callable[[], Structure] = field(kw_only=True) class-attribute instance-attribute

try_run(*args)

Source code in griptape/drivers/structure_run/local_structure_run_driver.py
def try_run(self, *args: BaseArtifact) -> BaseArtifact:
    structure = self.structure_factory_fn().run(*[arg.value for arg in args])

    if structure.output_task.output is not None:
        return structure.output_task.output
    else:
        return InfoArtifact("No output found in response")
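
A minimal usage sketch; Agent() relies on its default prompt driver, so the corresponding credentials (e.g. OPENAI_API_KEY) are assumed to be configured:

from griptape.artifacts import TextArtifact
from griptape.drivers import LocalStructureRunDriver
from griptape.structures import Agent

# The factory is called on every run, producing a fresh Structure each time.
driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent())

result = driver.try_run(TextArtifact("Summarize why unit tests matter."))
print(result.value)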

LocalVectorStoreDriver

Bases: BaseVectorStoreDriver

Source code in griptape/drivers/vector/local_vector_store_driver.py
@define
class LocalVectorStoreDriver(BaseVectorStoreDriver):
    entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict, kw_only=True)
    relatedness_fn: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)), kw_only=True)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        vector_id = vector_id if vector_id else utils.str_to_hash(str(vector))

        self.entries[self._namespaced_vector_id(vector_id, namespace)] = self.Entry(
            id=vector_id, vector=vector, meta=meta, namespace=namespace
        )

        return vector_id

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        return self.entries.get(self._namespaced_vector_id(vector_id, namespace), None)

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        query_embedding = self.embedding_driver.embed_string(query)

        if namespace:
            entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")}
        else:
            entries = self.entries

        entries_and_relatednesses = [
            (entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in entries.values()
        ]
        entries_and_relatednesses.sort(key=lambda x: x[1], reverse=True)

        result = [
            BaseVectorStoreDriver.QueryResult(id=er[0].id, vector=er[0].vector, score=er[1], meta=er[0].meta)
            for er in entries_and_relatednesses
        ][:count]

        if include_vectors:
            return result
        else:
            return [
                BaseVectorStoreDriver.QueryResult(id=r.id, vector=[], score=r.score, meta=r.meta, namespace=r.namespace)
                for r in result
            ]

    def _namespaced_vector_id(self, vector_id: str, namespace: Optional[str]):
        return vector_id if namespace is None else f"{namespace}-{vector_id}"

    def delete_vector(self, vector_id: str):
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict, kw_only=True) class-attribute instance-attribute

relatedness_fn: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)), kw_only=True) class-attribute instance-attribute

delete_vector(vector_id)

Source code in griptape/drivers/vector/local_vector_store_driver.py
def delete_vector(self, vector_id: str):
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

load_entries(namespace=None)

Source code in griptape/drivers/vector/local_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]

load_entry(vector_id, namespace=None)

Source code in griptape/drivers/vector/local_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    return self.entries.get(self._namespaced_vector_id(vector_id, namespace), None)

query(query, count=None, namespace=None, include_vectors=False, **kwargs)

Source code in griptape/drivers/vector/local_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    query_embedding = self.embedding_driver.embed_string(query)

    if namespace:
        entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")}
    else:
        entries = self.entries

    entries_and_relatednesses = [
        (entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in entries.values()
    ]
    entries_and_relatednesses.sort(key=lambda x: x[1], reverse=True)

    result = [
        BaseVectorStoreDriver.QueryResult(id=er[0].id, vector=er[0].vector, score=er[1], meta=er[0].meta)
        for er in entries_and_relatednesses
    ][:count]

    if include_vectors:
        return result
    else:
        return [
            BaseVectorStoreDriver.QueryResult(id=r.id, vector=[], score=r.score, meta=r.meta, namespace=r.namespace)
            for r in result
        ]

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Source code in griptape/drivers/vector/local_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    vector_id = vector_id if vector_id else utils.str_to_hash(str(vector))

    self.entries[self._namespaced_vector_id(vector_id, namespace)] = self.Entry(
        id=vector_id, vector=vector, meta=meta, namespace=namespace
    )

    return vector_id
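
A minimal usage sketch; OpenAiEmbeddingDriver is used only as a concrete embedding source and assumes OPENAI_API_KEY is set:

from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver

driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())

driver.upsert_text("Griptape is a framework for LLM apps.", namespace="docs")
driver.upsert_text("Bananas are a fruit.", namespace="docs")

# Results are ranked by the cosine-similarity relatedness_fn shown above.
for result in driver.query("What is Griptape?", count=1, namespace="docs"):
    print(result.id, result.score)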

MarkdownifyWebScraperDriver

Bases: BaseWebScraperDriver

Driver to scrape a webpage and return the content in markdown format.

As a prerequisite to using MarkdownifyWebScraperDriver, you need to install the browsers used by playwright. You can do this by running: poetry run playwright install. For more details about playwright, see https://playwright.dev/python/docs/library.

Attributes:

Name Type Description
include_links bool

If True, the driver will include link urls in the markdown output.

exclude_tags list[str]

Optionally provide custom tags to exclude from the scraped content.

exclude_classes list[str]

Optionally provide custom classes to exclude from the scraped content.

exclude_ids list[str]

Optionally provide custom ids to exclude from the scraped content.

timeout Optional[int]

Optionally provide a timeout in milliseconds for the page to continue loading after the browser has emitted the "load" event.

Source code in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
@define
class MarkdownifyWebScraperDriver(BaseWebScraperDriver):
    """Driver to scrape a webpage and return the content in markdown format.

    As a prerequisite to using MarkdownifyWebScraperDriver, you need to install the browsers used by
    playwright. You can do this by running: `poetry run playwright install`.
    For more details about playwright, see https://playwright.dev/python/docs/library.

    Attributes:
        include_links: If `True`, the driver will include link urls in the markdown output.
        exclude_tags: Optionally provide custom tags to exclude from the scraped content.
        exclude_classes: Optionally provide custom classes to exclude from the scraped content.
        exclude_ids: Optionally provide custom ids to exclude from the scraped content.
        timeout: Optionally provide a timeout in milliseconds for the page to continue loading after
            the browser has emitted the "load" event.
    """

    DEFAULT_EXCLUDE_TAGS = ["script", "style", "head"]

    include_links: bool = field(default=True, kw_only=True)
    exclude_tags: list[str] = field(
        default=Factory(lambda self: self.DEFAULT_EXCLUDE_TAGS, takes_self=True), kw_only=True
    )
    exclude_classes: list[str] = field(default=Factory(list), kw_only=True)
    exclude_ids: list[str] = field(default=Factory(list), kw_only=True)
    timeout: Optional[int] = field(default=None, kw_only=True)

    def scrape_url(self, url: str) -> TextArtifact:
        sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright
        BeautifulSoup = import_optional_dependency("bs4").BeautifulSoup
        MarkdownConverter = import_optional_dependency("markdownify").MarkdownConverter

        include_links = self.include_links

        # Custom MarkdownConverter that optionally includes link urls. If
        # include_links is False, only the text of the link is returned.
        class OptionalLinksMarkdownConverter(MarkdownConverter):
            def convert_a(self, el, text, convert_as_inline):
                if include_links:
                    return super().convert_a(el, text, convert_as_inline)
                return text

        with sync_playwright() as p:
            with p.chromium.launch(headless=True) as browser:
                page = browser.new_page()

                def skip_loading_images(route):
                    if route.request.resource_type == "image":
                        return route.abort()
                    route.continue_()

                page.route("**/*", skip_loading_images)

                page.goto(url)

                # Some websites require a delay before the content is fully loaded
                # even after the browser has emitted the "load" event.
                if self.timeout:
                    page.wait_for_timeout(self.timeout)

                content = page.content()

                if not content:
                    raise Exception("can't access URL")

                soup = BeautifulSoup(content, "html.parser")

                # Remove unwanted elements
                exclude_selector = ",".join(
                    self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids]
                )
                if exclude_selector:
                    for s in soup.select(exclude_selector):
                        s.extract()

                text = OptionalLinksMarkdownConverter().convert_soup(soup)

                # Remove leading and trailing whitespace from the entire text
                text = text.strip()

                # Remove trailing whitespace from each line
                text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE)

                # Indent using 2 spaces instead of tabs
                text = re.sub(r"(\n?\s*?)\t", r"\1  ", text)

                # Remove triple+ newlines (keep double newlines for paragraphs)
                text = re.sub(r"\n\n+", "\n\n", text)

                return TextArtifact(text)

DEFAULT_EXCLUDE_TAGS = ['script', 'style', 'head'] class-attribute instance-attribute

exclude_classes: list[str] = field(default=Factory(list), kw_only=True) class-attribute instance-attribute

exclude_ids: list[str] = field(default=Factory(list), kw_only=True) class-attribute instance-attribute

exclude_tags: list[str] = field(default=Factory(lambda self: self.DEFAULT_EXCLUDE_TAGS, takes_self=True), kw_only=True) class-attribute instance-attribute

timeout: Optional[int] = field(default=None, kw_only=True) class-attribute instance-attribute

scrape_url(url)

Source code in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
def scrape_url(self, url: str) -> TextArtifact:
    sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright
    BeautifulSoup = import_optional_dependency("bs4").BeautifulSoup
    MarkdownConverter = import_optional_dependency("markdownify").MarkdownConverter

    include_links = self.include_links

    # Custom MarkdownConverter that optionally includes link urls. If
    # include_links is False, only the text of the link is returned.
    class OptionalLinksMarkdownConverter(MarkdownConverter):
        def convert_a(self, el, text, convert_as_inline):
            if include_links:
                return super().convert_a(el, text, convert_as_inline)
            return text

    with sync_playwright() as p:
        with p.chromium.launch(headless=True) as browser:
            page = browser.new_page()

            def skip_loading_images(route):
                if route.request.resource_type == "image":
                    return route.abort()
                route.continue_()

            page.route("**/*", skip_loading_images)

            page.goto(url)

            # Some websites require a delay before the content is fully loaded
            # even after the browser has emitted the "load" event.
            if self.timeout:
                page.wait_for_timeout(self.timeout)

            content = page.content()

            if not content:
                raise Exception("can't access URL")

            soup = BeautifulSoup(content, "html.parser")

            # Remove unwanted elements
            exclude_selector = ",".join(
                self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids]
            )
            if exclude_selector:
                for s in soup.select(exclude_selector):
                    s.extract()

            text = OptionalLinksMarkdownConverter().convert_soup(soup)

            # Remove leading and trailing whitespace from the entire text
            text = text.strip()

            # Remove trailing whitespace from each line
            text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE)

            # Indent using 2 spaces instead of tabs
            text = re.sub(r"(\n?\s*?)\t", r"\1  ", text)

            # Remove triple+ newlines (keep double newlines for paragraphs)
            text = re.sub(r"\n\n+", "\n\n", text)

            return TextArtifact(text)
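
A minimal usage sketch; it assumes the optional playwright, bs4, and markdownify dependencies plus browsers installed via poetry run playwright install. The excluded class name is purely illustrative:

from griptape.drivers import MarkdownifyWebScraperDriver

driver = MarkdownifyWebScraperDriver(
    include_links=False,
    exclude_classes=["sidebar"],  # hypothetical class to strip
    timeout=2000,                 # wait 2s after the "load" event
)

artifact = driver.scrape_url("https://example.com")
print(artifact.value)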

MarqoVectorStoreDriver

Bases: BaseVectorStoreDriver

A Vector Store Driver for Marqo.

Attributes:

Name Type Description
api_key str

The API key for the Marqo API.

url str

The URL to the Marqo API.

mq Optional[Client]

An optional Marqo client. Defaults to a new client with the given URL and API key.

index str

The name of the index to use.

Source code in griptape/drivers/vector/marqo_vector_store_driver.py
@define
class MarqoVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for Marqo.

    Attributes:
        api_key: The API key for the Marqo API.
        url: The URL to the Marqo API.
        mq: An optional Marqo client. Defaults to a new client with the given URL and API key.
        index: The name of the index to use.
    """

    api_key: str = field(kw_only=True, metadata={"serializable": True})
    url: str = field(kw_only=True, metadata={"serializable": True})
    mq: Optional[marqo.Client] = field(
        default=Factory(
            lambda self: import_optional_dependency("marqo").Client(self.url, api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )
    index: str = field(kw_only=True, metadata={"serializable": True})

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Upsert a text document into the Marqo index.

        Args:
            string: The string to be indexed.
            vector_id: The ID for the vector. If None, Marqo will generate an ID.
            namespace: An optional namespace for the document.
            meta: An optional dictionary of metadata for the document.

        Returns:
            str: The ID of the document that was added.
        """

        doc = {"_id": vector_id, "Description": string}  # Description will be treated as tensor field

        # Non-tensor fields
        if meta:
            doc["meta"] = str(meta)
        if namespace:
            doc["namespace"] = namespace

        response = self.mq.index(self.index).add_documents([doc], tensor_fields=["Description"])
        if isinstance(response, dict) and "items" in response and response["items"]:
            return response["items"][0]["_id"]
        else:
            raise ValueError(f"Failed to upsert text: {response}")

    def upsert_text_artifact(
        self, artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs
    ) -> str:
        """Upsert a text artifact into the Marqo index.

        Args:
            artifact: The text artifact to be indexed.
            namespace: An optional namespace for the artifact.
            meta: An optional dictionary of metadata for the artifact.

        Returns:
            str: The ID of the artifact that was added.
        """

        artifact_json = artifact.to_json()

        doc = {
            "_id": artifact.id,
            "Description": artifact.value,  # Description will be treated as tensor field
            "artifact": str(artifact_json),
            "namespace": namespace,
        }

        response = self.mq.index(self.index).add_documents([doc], tensor_fields=["Description", "artifact"])
        if isinstance(response, dict) and "items" in response and response["items"]:
            return response["items"][0]["_id"]
        else:
            raise ValueError(f"Failed to upsert text: {response}")

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Load a document entry from the Marqo index.

        Args:
            vector_id: The ID of the vector to load.
            namespace: The namespace of the vector to load.

        Returns:
            The loaded Entry if found, otherwise None.
        """
        result = self.mq.index(self.index).get_document(document_id=vector_id, expose_facets=True)

        if result and "_tensor_facets" in result and len(result["_tensor_facets"]) > 0:
            return BaseVectorStoreDriver.Entry(
                id=result["_id"],
                meta={k: v for k, v in result.items() if k not in ["_id"]},
                vector=result["_tensor_facets"][0]["_embedding"],
            )
        else:
            return None

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Load all document entries from the Marqo index.

        Args:
            namespace: The namespace to filter entries by.

        Returns:
            The list of loaded Entries.
        """

        filter_string = f"namespace:{namespace}" if namespace else None

        if filter_string is not None:
            results = self.mq.index(self.index).search("", limit=10000, filter_string=filter_string)
        else:
            results = self.mq.index(self.index).search("", limit=10000)

        # get all _id's from search results
        ids = [r["_id"] for r in results["hits"]]

        # get documents corresponding to the ids
        documents = self.mq.index(self.index).get_documents(document_ids=ids, expose_facets=True)

        # for each document, if it's found, create an Entry object
        entries = []
        for doc in documents["results"]:
            if doc["_found"]:
                entries.append(
                    BaseVectorStoreDriver.Entry(
                        id=doc["_id"],
                        vector=doc["_tensor_facets"][0]["_embedding"],
                        meta={k: v for k, v in doc.items() if k not in ["_id", "_tensor_facets", "_found"]},
                        namespace=doc.get("namespace"),
                    )
                )

        return entries

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        include_metadata: bool = True,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Query the Marqo index for documents.

        Args:
            query: The query string.
            count: The maximum number of results to return.
            namespace: The namespace to filter results by.
            include_vectors: Whether to include vector data in the results.
            include_metadata: Whether to include metadata in the results.

        Returns:
            The list of query results.
        """

        params = {
            "limit": count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
            "attributes_to_retrieve": ["*"] if include_metadata else ["_id"],
            "filter_string": f"namespace:{namespace}" if namespace else None,
        } | kwargs

        results = self.mq.index(self.index).search(query, **params)

        if include_vectors:
            results["hits"] = [
                {**r, **self.mq.index(self.index).get_document(r["_id"], expose_facets=True)} for r in results["hits"]
            ]

        return [
            BaseVectorStoreDriver.QueryResult(
                id=r["_id"],
                vector=r["_tensor_facets"][0]["_embedding"] if include_vectors else [],
                score=r["_score"],
                meta={k: v for k, v in r.items() if k not in ["_score", "_tensor_facets"]},
            )
            for r in results["hits"]
        ]

    def delete_index(self, name: str) -> dict[str, Any]:
        """Delete an index in the Marqo client.

        Args:
            name: The name of the index to delete.
        """

        return self.mq.delete_index(name)

    def get_indexes(self) -> list[str]:
        """Get a list of all indexes in the Marqo client.

        Returns:
            The list of all indexes.
        """

        return [index["index"] for index in self.mq.get_indexes()["results"]]

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Upsert a vector into the Marqo index.

        Args:
            vector: The vector to be indexed.
            vector_id: The ID for the vector. If None, Marqo will generate an ID.
            namespace: An optional namespace for the vector.
            meta: An optional dictionary of metadata for the vector.

        Raises:
            NotImplementedError: This driver does not support upserting a vector directly.

        Returns:
            The ID of the vector that was added.
        """

        raise NotImplementedError(f"{self.__class__.__name__} does not support upserting a vector.")

    def delete_vector(self, vector_id: str):
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

api_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

index: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

mq: Optional[marqo.Client] = field(default=Factory(lambda self: import_optional_dependency('marqo').Client(self.url, api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

url: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

delete_index(name)

Delete an index in the Marqo client.

Parameters:

Name Type Description Default
name str

The name of the index to delete.

required
Source code in griptape/drivers/vector/marqo_vector_store_driver.py
def delete_index(self, name: str) -> dict[str, Any]:
    """Delete an index in the Marqo client.

    Args:
        name: The name of the index to delete.
    """

    return self.mq.delete_index(name)

delete_vector(vector_id)

Source code in griptape/drivers/vector/marqo_vector_store_driver.py
def delete_vector(self, vector_id: str):
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

get_indexes()

Get a list of all indexes in the Marqo client.

Returns:

Type Description
list[str]

The list of all indexes.

Source code in griptape/drivers/vector/marqo_vector_store_driver.py
def get_indexes(self) -> list[str]:
    """Get a list of all indexes in the Marqo client.

    Returns:
        The list of all indexes.
    """

    return [index["index"] for index in self.mq.get_indexes()["results"]]

load_entries(namespace=None)

Load all document entries from the Marqo index.

Parameters:

Name Type Description Default
namespace Optional[str]

The namespace to filter entries by.

None

Returns:

Type Description
list[Entry]

The list of loaded Entries.

Source code in griptape/drivers/vector/marqo_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Load all document entries from the Marqo index.

    Args:
        namespace: The namespace to filter entries by.

    Returns:
        The list of loaded Entries.
    """

    filter_string = f"namespace:{namespace}" if namespace else None

    if filter_string is not None:
        results = self.mq.index(self.index).search("", limit=10000, filter_string=filter_string)
    else:
        results = self.mq.index(self.index).search("", limit=10000)

    # get all _id's from search results
    ids = [r["_id"] for r in results["hits"]]

    # get documents corresponding to the ids
    documents = self.mq.index(self.index).get_documents(document_ids=ids, expose_facets=True)

    # for each document, if it's found, create an Entry object
    entries = []
    for doc in documents["results"]:
        if doc["_found"]:
            entries.append(
                BaseVectorStoreDriver.Entry(
                    id=doc["_id"],
                    vector=doc["_tensor_facets"][0]["_embedding"],
                    meta={k: v for k, v in doc.items() if k not in ["_id", "_tensor_facets", "_found"]},
                    namespace=doc.get("namespace"),
                )
            )

    return entries

load_entry(vector_id, namespace=None)

Load a document entry from the Marqo index.

Parameters:

Name Type Description Default
vector_id str

The ID of the vector to load.

required
namespace Optional[str]

The namespace of the vector to load.

None

Returns:

Type Description
Optional[Entry]

The loaded Entry if found, otherwise None.

Source code in griptape/drivers/vector/marqo_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Load a document entry from the Marqo index.

    Args:
        vector_id: The ID of the vector to load.
        namespace: The namespace of the vector to load.

    Returns:
        The loaded Entry if found, otherwise None.
    """
    result = self.mq.index(self.index).get_document(document_id=vector_id, expose_facets=True)

    if result and "_tensor_facets" in result and len(result["_tensor_facets"]) > 0:
        return BaseVectorStoreDriver.Entry(
            id=result["_id"],
            meta={k: v for k, v in result.items() if k not in ["_id"]},
            vector=result["_tensor_facets"][0]["_embedding"],
        )
    else:
        return None

query(query, count=None, namespace=None, include_vectors=False, include_metadata=True, **kwargs)

Query the Marqo index for documents.

Parameters:

Name Type Description Default
query str

The query string.

required
count Optional[int]

The maximum number of results to return.

None
namespace Optional[str]

The namespace to filter results by.

None
include_vectors bool

Whether to include vector data in the results.

False
include_metadata bool

Whether to include metadata in the results.

True

Returns:

Type Description
list[QueryResult]

The list of query results.

Source code in griptape/drivers/vector/marqo_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    include_metadata: bool = True,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Query the Marqo index for documents.

    Args:
        query: The query string.
        count: The maximum number of results to return.
        namespace: The namespace to filter results by.
        include_vectors: Whether to include vector data in the results.
        include_metadata: Whether to include metadata in the results.

    Returns:
        The list of query results.
    """

    params = {
        "limit": count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        "attributes_to_retrieve": ["*"] if include_metadata else ["_id"],
        "filter_string": f"namespace:{namespace}" if namespace else None,
    } | kwargs

    results = self.mq.index(self.index).search(query, **params)

    if include_vectors:
        results["hits"] = [
            {**r, **self.mq.index(self.index).get_document(r["_id"], expose_facets=True)} for r in results["hits"]
        ]

    return [
        BaseVectorStoreDriver.QueryResult(
            id=r["_id"],
            vector=r["_tensor_facets"][0]["_embedding"] if include_vectors else [],
            score=r["_score"],
            meta={k: v for k, v in r.items() if k not in ["_score", "_tensor_facets"]},
        )
        for r in results["hits"]
    ]

upsert_text(string, vector_id=None, namespace=None, meta=None, **kwargs)

Upsert a text document into the Marqo index.

Parameters:

Name Type Description Default
string str

The string to be indexed.

required
vector_id Optional[str]

The ID for the vector. If None, Marqo will generate an ID.

None
namespace Optional[str]

An optional namespace for the document.

None
meta Optional[dict]

An optional dictionary of metadata for the document.

None

Returns:

Name Type Description
str str

The ID of the document that was added.

Source code in griptape/drivers/vector/marqo_vector_store_driver.py
def upsert_text(
    self,
    string: str,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Upsert a text document into the Marqo index.

    Args:
        string: The string to be indexed.
        vector_id: The ID for the vector. If None, Marqo will generate an ID.
        namespace: An optional namespace for the document.
        meta: An optional dictionary of metadata for the document.

    Returns:
        str: The ID of the document that was added.
    """

    doc = {"_id": vector_id, "Description": string}  # Description will be treated as tensor field

    # Non-tensor fields
    if meta:
        doc["meta"] = str(meta)
    if namespace:
        doc["namespace"] = namespace

    response = self.mq.index(self.index).add_documents([doc], tensor_fields=["Description"])
    if isinstance(response, dict) and "items" in response and response["items"]:
        return response["items"][0]["_id"]
    else:
        raise ValueError(f"Failed to upsert text: {response}")

upsert_text_artifact(artifact, namespace=None, meta=None, **kwargs)

Upsert a text artifact into the Marqo index.

Parameters:

Name Type Description Default
artifact TextArtifact

The text artifact to be indexed.

required
namespace Optional[str]

An optional namespace for the artifact.

None
meta Optional[dict]

An optional dictionary of metadata for the artifact.

None

Returns:

Name Type Description
str str

The ID of the artifact that was added.

Source code in griptape/drivers/vector/marqo_vector_store_driver.py
def upsert_text_artifact(
    self, artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs
) -> str:
    """Upsert a text artifact into the Marqo index.

    Args:
        artifact: The text artifact to be indexed.
        namespace: An optional namespace for the artifact.
        meta: An optional dictionary of metadata for the artifact.

    Returns:
        str: The ID of the artifact that was added.
    """

    artifact_json = artifact.to_json()

    doc = {
        "_id": artifact.id,
        "Description": artifact.value,  # Description will be treated as tensor field
        "artifact": str(artifact_json),
        "namespace": namespace,
    }

    response = self.mq.index(self.index).add_documents([doc], tensor_fields=["Description", "artifact"])
    if isinstance(response, dict) and "items" in response and response["items"]:
        return response["items"][0]["_id"]
    else:
        raise ValueError(f"Failed to upsert text: {response}")

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Upsert a vector into the Marqo index.

Parameters:

Name Type Description Default
vector list[float]

The vector to be indexed.

required
vector_id Optional[str]

The ID for the vector. If None, Marqo will generate an ID.

None
namespace Optional[str]

An optional namespace for the vector.

None
meta Optional[dict]

An optional dictionary of metadata for the vector.

None

Raises:

Type Description
NotImplementedError

This driver does not support upserting a vector directly.

Returns:

Type Description
str

The ID of the vector that was added.

Source code in griptape/drivers/vector/marqo_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Upsert a vector into the Marqo index.

    Args:
        vector: The vector to be indexed.
        vector_id: The ID for the vector. If None, Marqo will generate an ID.
        namespace: An optional namespace for the vector.
        meta: An optional dictionary of metadata for the vector.

    Raises:
        NotImplementedError: This driver does not support upserting a vector directly.

    Returns:
        The ID of the vector that was added.
    """

    raise NotImplementedError(f"{self.__class__.__name__} does not support upserting a vector.")
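
A minimal usage sketch; the URL, API key, and index name are placeholders, and an embedding_driver is passed because the base vector store driver generally requires one even though Marqo embeds documents server-side:

from griptape.drivers import MarqoVectorStoreDriver, OpenAiEmbeddingDriver

driver = MarqoVectorStoreDriver(
    api_key="...",                # placeholder
    url="http://localhost:8882",  # placeholder Marqo endpoint
    index="my-index",             # placeholder index name
    embedding_driver=OpenAiEmbeddingDriver(),
)

driver.upsert_text("Griptape ships many vector store drivers.", vector_id="doc-1")

for result in driver.query("vector store drivers", count=3):
    print(result.id, result.score)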

MongoDbAtlasVectorStoreDriver

Bases: BaseVectorStoreDriver

A Vector Store Driver for MongoDb Atlas.

Attributes:

Name Type Description
connection_string str

The connection string for the MongoDb Atlas cluster.

database_name str

The name of the database to use.

collection_name str

The name of the collection to use.

index_name str

The name of the index to use.

vector_path str

The path to the vector field in the collection.

client MongoClient

An optional MongoDb client to use. Defaults to a new client using the connection string.

Source code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
@define
class MongoDbAtlasVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for MongoDb Atlas.

    Attributes:
        connection_string: The connection string for the MongoDb Atlas cluster.
        database_name: The name of the database to use.
        collection_name: The name of the collection to use.
        index_name: The name of the index to use.
        vector_path: The path to the vector field in the collection.
        client: An optional MongoDb client to use. Defaults to a new client using the connection string.
    """

    MAX_NUM_CANDIDATES = 10000

    connection_string: str = field(kw_only=True, metadata={"serializable": True})
    database_name: str = field(kw_only=True, metadata={"serializable": True})
    collection_name: str = field(kw_only=True, metadata={"serializable": True})
    index_name: str = field(kw_only=True, metadata={"serializable": True})
    vector_path: str = field(kw_only=True, metadata={"serializable": True})
    num_candidates_multiplier: int = field(
        default=10, kw_only=True, metadata={"serializable": True}
    )  # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#fields
    client: MongoClient = field(
        default=Factory(
            lambda self: import_optional_dependency("pymongo").MongoClient(self.connection_string), takes_self=True
        )
    )

    def get_collection(self) -> Collection:
        """Returns the MongoDB Collection instance for the specified database and collection name."""
        return self.client[self.database_name][self.collection_name]

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in the collection.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
        """
        collection = self.get_collection()

        if vector_id is None:
            result = collection.insert_one({self.vector_path: vector, "namespace": namespace, "meta": meta})
            vector_id = str(result.inserted_id)
        else:
            collection.replace_one(
                {"_id": vector_id}, {self.vector_path: vector, "namespace": namespace, "meta": meta}, upsert=True
            )
        return vector_id

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Loads a document entry from the MongoDB collection based on the vector ID.

        Returns:
            The loaded Entry if found; otherwise, None is returned.
        """
        collection = self.get_collection()
        if namespace:
            doc = collection.find_one({"_id": vector_id, "namespace": namespace})
        else:
            doc = collection.find_one({"_id": vector_id})

        if doc is None:
            return None
        else:
            return BaseVectorStoreDriver.Entry(
                id=str(doc["_id"]), vector=doc[self.vector_path], namespace=doc["namespace"], meta=doc["meta"]
            )

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Loads all document entries from the MongoDB collection.

        Entries can optionally be filtered by namespace.
        """
        collection = self.get_collection()
        if namespace is None:
            cursor = collection.find()
        else:
            cursor = collection.find({"namespace": namespace})

        return [
            BaseVectorStoreDriver.Entry(
                id=str(doc["_id"]), vector=doc[self.vector_path], namespace=doc["namespace"], meta=doc["meta"]
            )
            for doc in cursor
        ]

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        offset: Optional[int] = None,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Queries the MongoDB collection for documents that match the provided query string.

        Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
        """
        collection = self.get_collection()

        # Using the embedding driver to convert the query string into a vector
        vector = self.embedding_driver.embed_string(query)

        count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        offset = offset if offset else 0

        pipeline = [
            {
                "$vectorSearch": {
                    "index": self.index_name,
                    "path": self.vector_path,
                    "queryVector": vector,
                    "numCandidates": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                    "limit": count,
                }
            },
            {
                "$project": {
                    "_id": 1,
                    self.vector_path: 1,
                    "namespace": 1,
                    "meta": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]

        if namespace:
            pipeline[0]["$vectorSearch"]["filter"] = {"namespace": namespace}

        results = [
            BaseVectorStoreDriver.QueryResult(
                id=str(doc["_id"]),
                vector=doc[self.vector_path] if include_vectors else [],
                score=doc["score"],
                meta=doc["meta"],
                namespace=namespace,
            )
            for doc in collection.aggregate(pipeline)
        ]

        return results

    def delete_vector(self, vector_id: str):
        """Deletes the vector from the collection."""
        collection = self.get_collection()
        collection.delete_one({"_id": vector_id})
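
A minimal usage sketch for this driver. It assumes a reachable Atlas cluster whose collection already has a vector search index, and it pairs the driver with OpenAiEmbeddingDriver so that query() can embed the query string; every connection setting below is a placeholder.

from griptape.drivers import MongoDbAtlasVectorStoreDriver, OpenAiEmbeddingDriver

# All names below are placeholders; substitute your own Atlas settings.
driver = MongoDbAtlasVectorStoreDriver(
    connection_string="mongodb+srv://user:pass@cluster.example.mongodb.net",
    database_name="my_db",
    collection_name="my_collection",
    index_name="my_vector_index",
    vector_path="vector",
    embedding_driver=OpenAiEmbeddingDriver(),
)

vector_id = driver.upsert_vector([0.1, 0.2, 0.3], namespace="docs", meta={"source": "example"})

for result in driver.query("example query", count=5, namespace="docs"):
    print(result.id, result.score)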

MAX_NUM_CANDIDATES = 10000 class-attribute instance-attribute

client: MongoClient = field(default=Factory(lambda self: import_optional_dependency('pymongo').MongoClient(self.connection_string), takes_self=True)) class-attribute instance-attribute

collection_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

connection_string: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

database_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

index_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

num_candidates_multiplier: int = field(default=10, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

vector_path: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

delete_vector(vector_id)

Deletes the vector from the collection.

Source code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def delete_vector(self, vector_id: str):
    """Deletes the vector from the collection."""
    collection = self.get_collection()
    collection.delete_one({"_id": vector_id})

get_collection()

Returns the MongoDB Collection instance for the specified database and collection name.

Source code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def get_collection(self) -> Collection:
    """Returns the MongoDB Collection instance for the specified database and collection name."""
    return self.client[self.database_name][self.collection_name]

load_entries(namespace=None)

Loads all document entries from the MongoDB collection.

Entries can optionally be filtered by namespace.

Source code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Loads all document entries from the MongoDB collection.

    Entries can optionally be filtered by namespace.
    """
    collection = self.get_collection()
    if namespace is None:
        cursor = collection.find()
    else:
        cursor = collection.find({"namespace": namespace})

    return [
        BaseVectorStoreDriver.Entry(
            id=str(doc["_id"]), vector=doc[self.vector_path], namespace=doc["namespace"], meta=doc["meta"]
        )
        for doc in cursor
    ]

load_entry(vector_id, namespace=None)

Loads a document entry from the MongoDB collection based on the vector ID.

Returns:

Type Description
Optional[Entry]

The loaded Entry if found; otherwise, None is returned.

Source code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Loads a document entry from the MongoDB collection based on the vector ID.

    Returns:
        The loaded Entry if found; otherwise, None is returned.
    """
    collection = self.get_collection()
    if namespace:
        doc = collection.find_one({"_id": vector_id, "namespace": namespace})
    else:
        doc = collection.find_one({"_id": vector_id})

    if doc is None:
        return None
    else:
        return BaseVectorStoreDriver.Entry(
            id=str(doc["_id"]), vector=doc[self.vector_path], namespace=doc["namespace"], meta=doc["meta"]
        )

query(query, count=None, namespace=None, include_vectors=False, offset=None, **kwargs)

Queries the MongoDB collection for documents that match the provided query string.

Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.

Source code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    offset: Optional[int] = None,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Queries the MongoDB collection for documents that match the provided query string.

    Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
    """
    collection = self.get_collection()

    # Using the embedding driver to convert the query string into a vector
    vector = self.embedding_driver.embed_string(query)

    count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    offset = offset if offset else 0

    pipeline = [
        {
            "$vectorSearch": {
                "index": self.index_name,
                "path": self.vector_path,
                "queryVector": vector,
                "numCandidates": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                "limit": count,
            }
        },
        {
            "$project": {
                "_id": 1,
                self.vector_path: 1,
                "namespace": 1,
                "meta": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]

    if namespace:
        pipeline[0]["$vectorSearch"]["filter"] = {"namespace": namespace}

    results = [
        BaseVectorStoreDriver.QueryResult(
            id=str(doc["_id"]),
            vector=doc[self.vector_path] if include_vectors else [],
            score=doc["score"],
            meta=doc["meta"],
            namespace=namespace,
        )
        for doc in collection.aggregate(pipeline)
    ]

    return results

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in the collection.

If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.

Source code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in the collection.

    If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
    """
    collection = self.get_collection()

    if vector_id is None:
        result = collection.insert_one({self.vector_path: vector, "namespace": namespace, "meta": meta})
        vector_id = str(result.inserted_id)
    else:
        collection.replace_one(
            {"_id": vector_id}, {self.vector_path: vector, "namespace": namespace, "meta": meta}, upsert=True
        )
    return vector_id
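
To illustrate the two branches above (a sketch; driver is an already-configured instance):

# Without an ID, MongoDB assigns one via insert_one and the driver returns it as a string.
generated_id = driver.upsert_vector([0.1, 0.2, 0.3])

# With an explicit ID, replace_one(..., upsert=True) inserts the document if it is new
# or overwrites it if a document with that _id already exists.
driver.upsert_vector([0.4, 0.5, 0.6], vector_id="doc-1")
driver.upsert_vector([0.7, 0.8, 0.9], vector_id="doc-1")  # overwrites the previous write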

OpenAiChatPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
base_url Optional[str]

An optional OpenAi API URL.

api_key Optional[str]

An optional OpenAi API key. If not provided, the OPENAI_API_KEY environment variable will be used.

organization Optional[str]

An optional OpenAI organization. If not provided, the OPENAI_ORG_ID environment variable will be used.

client OpenAI

An openai.OpenAI client.

model str

An OpenAI model name.

tokenizer BaseTokenizer

An OpenAiTokenizer.

user str

A user id. Can be used to track requests by user.

response_format Optional[Literal['json_object']]

An optional OpenAi Chat Completion response format. Currently only supports json_object which will enable OpenAi's JSON mode.

seed Optional[int]

An optional OpenAi Chat Completion seed.

ignored_exception_types tuple[type[Exception], ...]

An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.

_ratelimit_request_limit Optional[int]

The maximum number of requests allowed in the current rate limit window.

_ratelimit_requests_remaining Optional[int]

The number of requests remaining in the current rate limit window.

_ratelimit_requests_reset_at Optional[datetime]

The time at which the current rate limit window resets.

_ratelimit_token_limit Optional[int]

The maximum number of tokens allowed in the current rate limit window.

_ratelimit_tokens_remaining Optional[int]

The number of tokens remaining in the current rate limit window.

_ratelimit_tokens_reset_at Optional[datetime]

The time at which the current rate limit window resets.

Source code in griptape/drivers/prompt/openai_chat_prompt_driver.py
@define
class OpenAiChatPromptDriver(BasePromptDriver):
    """
    Attributes:
        base_url: An optional OpenAi API URL.
        api_key: An optional OpenAi API key. If not provided, the `OPENAI_API_KEY` environment variable will be used.
        organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used.
        client: An `openai.OpenAI` client.
        model: An OpenAI model name.
        tokenizer: An `OpenAiTokenizer`.
        user: A user id. Can be used to track requests by user.
        response_format: An optional OpenAi Chat Completion response format. Currently only supports `json_object` which will enable OpenAi's JSON mode.
        seed: An optional OpenAi Chat Completion seed.
        ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
        _ratelimit_request_limit: The maximum number of requests allowed in the current rate limit window.
        _ratelimit_requests_remaining: The number of requests remaining in the current rate limit window.
        _ratelimit_requests_reset_at: The time at which the current rate limit window resets.
        _ratelimit_token_limit: The maximum number of tokens allowed in the current rate limit window.
        _ratelimit_tokens_remaining: The number of tokens remaining in the current rate limit window.
        _ratelimit_tokens_reset_at: The time at which the current rate limit window resets.
    """

    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    client: openai.OpenAI = field(
        default=Factory(
            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
            takes_self=True,
        )
    )
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    user: str = field(default="", kw_only=True, metadata={"serializable": True})
    response_format: Optional[Literal["json_object"]] = field(
        default=None, kw_only=True, metadata={"serializable": True}
    )
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    ignored_exception_types: tuple[type[Exception], ...] = field(
        default=Factory(
            lambda: (
                openai.BadRequestError,
                openai.AuthenticationError,
                openai.PermissionDeniedError,
                openai.NotFoundError,
                openai.ConflictError,
                openai.UnprocessableEntityError,
            )
        ),
        kw_only=True,
    )
    _ratelimit_request_limit: Optional[int] = field(init=False, default=None)
    _ratelimit_requests_remaining: Optional[int] = field(init=False, default=None)
    _ratelimit_requests_reset_at: Optional[datetime] = field(init=False, default=None)
    _ratelimit_token_limit: Optional[int] = field(init=False, default=None)
    _ratelimit_tokens_remaining: Optional[int] = field(init=False, default=None)
    _ratelimit_tokens_reset_at: Optional[datetime] = field(init=False, default=None)

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        result = self.client.chat.completions.with_raw_response.create(**self._base_params(prompt_stack))

        self._extract_ratelimit_metadata(result)

        parsed_result = result.parse()

        if len(parsed_result.choices) == 1:
            return TextArtifact(value=parsed_result.choices[0].message.content.strip())
        else:
            raise Exception("Completion with more than one choice is not supported yet.")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        result = self.client.chat.completions.create(**self._base_params(prompt_stack), stream=True)

        for chunk in result:
            if len(chunk.choices) == 1:
                delta = chunk.choices[0].delta
            else:
                raise Exception("Completion with more than one choice is not supported yet.")

            if delta.content is not None:
                delta_content = delta.content

                yield TextArtifact(value=delta_content)

    def token_count(self, prompt_stack: PromptStack) -> int:
        if isinstance(self.tokenizer, OpenAiTokenizer):
            return self.tokenizer.count_tokens(self._prompt_stack_to_messages(prompt_stack))
        else:
            return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack))

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict[str, Any]]:
        return [{"role": self.__to_openai_role(i), "content": i.content} for i in prompt_stack.inputs]

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        params = {
            "model": self.model,
            "temperature": self.temperature,
            "stop": self.tokenizer.stop_sequences,
            "user": self.user,
            "seed": self.seed,
        }

        if self.response_format == "json_object":
            params["response_format"] = {"type": "json_object"}
            # JSON mode still requires a system input instructing the LLM to output JSON.
            prompt_stack.add_system_input("Provide your response as a valid JSON object.")

        messages = self._prompt_stack_to_messages(prompt_stack)

        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens

        params["messages"] = messages

        return params

    def __to_openai_role(self, prompt_input: PromptStack.Input) -> str:
        if prompt_input.is_system():
            return "system"
        elif prompt_input.is_assistant():
            return "assistant"
        else:
            return "user"

    def _extract_ratelimit_metadata(self, response):
        # The OpenAI SDK's requests session variable is global, so this hook will fire for all API requests.
        # The following headers are not reliably returned in every API call, so we check for the presence of the
        # headers before reading and parsing their values to prevent other SDK users from encountering KeyErrors.
        reset_requests_at = response.headers.get("x-ratelimit-reset-requests")
        if reset_requests_at is not None:
            self._ratelimit_requests_reset_at = dateparser.parse(
                reset_requests_at, settings={"PREFER_DATES_FROM": "future"}
            )

            # The dateparser utility doesn't handle sub-second durations as are sometimes returned by OpenAI's API.
            # If the API returns, for example, "13ms", dateparser.parse() returns None. In this case, we will set
            # the time value to the current time plus a one second buffer.
            if self._ratelimit_requests_reset_at is None:
                self._ratelimit_requests_reset_at = datetime.now() + timedelta(seconds=1)

        reset_tokens_at = response.headers.get("x-ratelimit-reset-tokens")
        if reset_tokens_at is not None:
            self._ratelimit_tokens_reset_at = dateparser.parse(
                reset_tokens_at, settings={"PREFER_DATES_FROM": "future"}
            )

            if self._ratelimit_tokens_reset_at is None:
                self._ratelimit_tokens_reset_at = datetime.now() + timedelta(seconds=1)

        self._ratelimit_request_limit = response.headers.get("x-ratelimit-limit-requests")
        self._ratelimit_requests_remaining = response.headers.get("x-ratelimit-remaining-requests")
        self._ratelimit_token_limit = response.headers.get("x-ratelimit-limit-tokens")
        self._ratelimit_tokens_remaining = response.headers.get("x-ratelimit-remaining-tokens")
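
A minimal usage sketch. It assumes the OPENAI_API_KEY environment variable is set; PromptStack's add_system_input is used by _base_params above, and add_user_input is assumed to be its user-role counterpart (import path assumed to be griptape.utils).

from griptape.drivers import OpenAiChatPromptDriver
from griptape.utils import PromptStack

driver = OpenAiChatPromptDriver(model="gpt-4")

prompt_stack = PromptStack()
prompt_stack.add_system_input("You are a terse assistant.")
prompt_stack.add_user_input("Name one vector database.")

# try_run returns a TextArtifact; .value holds the completion text.
print(driver.try_run(prompt_stack).value)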

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.OpenAI = field(default=Factory(lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), takes_self=True)) class-attribute instance-attribute

ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (openai.BadRequestError, openai.AuthenticationError, openai.PermissionDeniedError, openai.NotFoundError, openai.ConflictError, openai.UnprocessableEntityError)), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

organization: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

response_format: Optional[Literal['json_object']] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

seed: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

user: str = field(default='', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_openai_role(prompt_input)

Source code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def __to_openai_role(self, prompt_input: PromptStack.Input) -> str:
    if prompt_input.is_system():
        return "system"
    elif prompt_input.is_assistant():
        return "assistant"
    else:
        return "user"

token_count(prompt_stack)

Source code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def token_count(self, prompt_stack: PromptStack) -> int:
    if isinstance(self.tokenizer, OpenAiTokenizer):
        return self.tokenizer.count_tokens(self._prompt_stack_to_messages(prompt_stack))
    else:
        return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack))

try_run(prompt_stack)

Source code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    result = self.client.chat.completions.with_raw_response.create(**self._base_params(prompt_stack))

    self._extract_ratelimit_metadata(result)

    parsed_result = result.parse()

    if len(parsed_result.choices) == 1:
        return TextArtifact(value=parsed_result.choices[0].message.content.strip())
    else:
        raise Exception("Completion with more than one choice is not supported yet.")

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    result = self.client.chat.completions.create(**self._base_params(prompt_stack), stream=True)

    for chunk in result:
        if len(chunk.choices) == 1:
            delta = chunk.choices[0].delta
        else:
            raise Exception("Completion with more than one choice is not supported yet.")

        if delta.content is not None:
            delta_content = delta.content

            yield TextArtifact(value=delta_content)
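
Streaming usage is analogous: try_stream yields one TextArtifact per content delta. A sketch, reusing a configured driver and prompt_stack:

# Print the completion incrementally as deltas arrive.
for artifact in driver.try_stream(prompt_stack):
    print(artifact.value, end="", flush=True)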

OpenAiCompletionPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
base_url Optional[str]

An optional OpenAi API URL.

api_key Optional[str]

An optional OpenAi API key. If not provided, the OPENAI_API_KEY environment variable will be used.

organization Optional[str]

An optional OpenAI organization. If not provided, the OPENAI_ORG_ID environment variable will be used.

client OpenAI

An openai.OpenAI client.

model str

An OpenAI model name.

tokenizer OpenAiTokenizer

An OpenAiTokenizer.

user str

A user id. Can be used to track requests by user.

ignored_exception_types tuple[type[Exception], ...]

An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.

Source code in griptape/drivers/prompt/openai_completion_prompt_driver.py
@define
class OpenAiCompletionPromptDriver(BasePromptDriver):
    """
    Attributes:
        base_url: An optional OpenAi API URL.
        api_key: An optional OpenAi API key. If not provided, the `OPENAI_API_KEY` environment variable will be used.
        organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used.
        client: An `openai.OpenAI` client.
        model: An OpenAI model name.
        tokenizer: An `OpenAiTokenizer`.
        user: A user id. Can be used to track requests by user.
        ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
    """

    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    client: openai.OpenAI = field(
        default=Factory(
            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
            takes_self=True,
        )
    )
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    user: str = field(default="", kw_only=True, metadata={"serializable": True})
    ignored_exception_types: tuple[type[Exception], ...] = field(
        default=Factory(
            lambda: (
                openai.BadRequestError,
                openai.AuthenticationError,
                openai.PermissionDeniedError,
                openai.NotFoundError,
                openai.ConflictError,
                openai.UnprocessableEntityError,
            )
        ),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        result = self.client.completions.create(**self._base_params(prompt_stack))

        if len(result.choices) == 1:
            return TextArtifact(value=result.choices[0].text.strip())
        else:
            raise Exception("completion with more than one choice is not supported yet")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        result = self.client.completions.create(**self._base_params(prompt_stack), stream=True)

        for chunk in result:
            if len(chunk.choices) == 1:
                choice = chunk.choices[0]
                delta_content = choice.text
                yield TextArtifact(value=delta_content)

            else:
                raise Exception("completion with more than one choice is not supported yet")

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_string(prompt_stack)

        return {
            "model": self.model,
            "max_tokens": self.max_output_tokens(prompt),
            "temperature": self.temperature,
            "stop": self.tokenizer.stop_sequences,
            "user": self.user,
            "prompt": prompt,
        }
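
A minimal usage sketch. Unlike the chat driver, this driver flattens the PromptStack into a single prompt string via prompt_stack_to_string; the model name below is a placeholder that must be a legacy completions-API model, and OPENAI_API_KEY is assumed to be set.

from griptape.drivers import OpenAiCompletionPromptDriver
from griptape.utils import PromptStack

driver = OpenAiCompletionPromptDriver(model="gpt-3.5-turbo-instruct")  # placeholder model name

prompt_stack = PromptStack()
prompt_stack.add_user_input("Write a haiku about gradient descent.")

print(driver.try_run(prompt_stack).value)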

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.OpenAI = field(default=Factory(lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), takes_self=True)) class-attribute instance-attribute

ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (openai.BadRequestError, openai.AuthenticationError, openai.PermissionDeniedError, openai.NotFoundError, openai.ConflictError, openai.UnprocessableEntityError)), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

organization: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

user: str = field(default='', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/openai_completion_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    result = self.client.completions.create(**self._base_params(prompt_stack))

    if len(result.choices) == 1:
        return TextArtifact(value=result.choices[0].text.strip())
    else:
        raise Exception("completion with more than one choice is not supported yet")

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/openai_completion_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    result = self.client.completions.create(**self._base_params(prompt_stack), stream=True)

    for chunk in result:
        if len(chunk.choices) == 1:
            choice = chunk.choices[0]
            delta_content = choice.text
            yield TextArtifact(value=delta_content)

        else:
            raise Exception("completion with more than one choice is not supported yet")

OpenAiEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

Name Type Description
model str

OpenAI embedding model name. Defaults to text-embedding-3-small.

base_url Optional[str]

API URL. Defaults to OpenAI's v1 API URL.

api_key Optional[str]

API key to pass directly. Defaults to OPENAI_API_KEY environment variable.

organization Optional[str]

OpenAI organization. Defaults to the OPENAI_ORG_ID environment variable.

tokenizer OpenAiTokenizer

Optionally provide custom OpenAiTokenizer.

client OpenAI

Optionally provide custom openai.OpenAI client.

azure_deployment

An Azure OpenAi deployment id. (This and the remaining Azure attributes apply to the AzureOpenAiEmbeddingDriver subclass.)

azure_endpoint

An Azure OpenAi endpoint.

azure_ad_token

An optional Azure Active Directory token.

azure_ad_token_provider

An optional Azure Active Directory token provider.

api_version

An Azure OpenAi API version.

Source code in griptape/drivers/embedding/openai_embedding_driver.py
@define
class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
    """
    Attributes:
        model: OpenAI embedding model name. Defaults to `text-embedding-3-small`.
        base_url: API URL. Defaults to OpenAI's v1 API URL.
        api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
        organization: OpenAI organization. Defaults to the `OPENAI_ORG_ID` environment variable.
        tokenizer: Optionally provide custom `OpenAiTokenizer`.
        client: Optionally provide custom `openai.OpenAI` client.
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
    """

    DEFAULT_MODEL = "text-embedding-3-small"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    client: openai.OpenAI = field(
        default=Factory(
            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
            takes_self=True,
        )
    )
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        # Address a performance issue in older ada models
        # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        if self.model.endswith("001"):
            chunk = chunk.replace("\n", " ")
        return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

    def _params(self, chunk: str) -> dict:
        return {"input": chunk, "model": self.model}
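
A minimal usage sketch. embed_string comes from BaseEmbeddingDriver and ultimately calls try_embed_chunk above; OPENAI_API_KEY is assumed to be set.

from griptape.drivers import OpenAiEmbeddingDriver

driver = OpenAiEmbeddingDriver()  # defaults to text-embedding-3-small

embedding = driver.embed_string("Hello, world!")
print(len(embedding))  # dimensionality of the model's embedding vectors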

DEFAULT_MODEL = 'text-embedding-3-small' class-attribute instance-attribute

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.OpenAI = field(default=Factory(lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), takes_self=True)) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

organization: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/openai_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    # Address a performance issue in older ada models
    # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
    if self.model.endswith("001"):
        chunk = chunk.replace("\n", " ")
    return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

OpenAiImageGenerationDriver

Bases: BaseImageGenerationDriver

Driver for the OpenAI image generation API.

Attributes:

Name Type Description
model

OpenAI model, for example 'dall-e-2' or 'dall-e-3'.

api_type str

OpenAI API type, for example 'open_ai' or 'azure'.

api_version Optional[str]

API version.

base_url Optional[str]

API URL.

api_key Optional[str]

OpenAI API key.

organization Optional[str]

OpenAI organization ID.

style Optional[str]

Optional and only supported for dall-e-3, can be either 'vivid' or 'natural'.

quality Union[Literal['standard'], Literal['hd']]

Optional and only supported for dall-e-3. Accepts 'standard', 'hd'.

image_size Union[Literal['256x256'], Literal['512x512'], Literal['1024x1024'], Literal['1024x1792'], Literal['1792x1024']]

Size of the generated image. Must be one of the following, depending on the requested model:
dall-e-2: [256x256, 512x512, 1024x1024]
dall-e-3: [1024x1024, 1024x1792, 1792x1024]

response_format Literal['b64_json']

The response format. Currently only supports 'b64_json' which will return a base64 encoded image in a JSON object.

Source code in griptape/drivers/image_generation/openai_image_generation_driver.py
@define
class OpenAiImageGenerationDriver(BaseImageGenerationDriver):
    """Driver for the OpenAI image generation API.

    Attributes:
        model: OpenAI model, for example 'dall-e-2' or 'dall-e-3'.
        api_type: OpenAI API type, for example 'open_ai' or 'azure'.
        api_version: API version.
        base_url: API URL.
        api_key: OpenAI API key.
        organization: OpenAI organization ID.
        style: Optional and only supported for dall-e-3, can be either 'vivid' or 'natural'.
        quality: Optional and only supported for dall-e-3. Accepts 'standard', 'hd'.
        image_size: Size of the generated image. Must be one of the following, depending on the requested model:
            dall-e-2: [256x256, 512x512, 1024x1024]
            dall-e-3: [1024x1024, 1024x1792, 1792x1024]
        response_format: The response format. Currently only supports 'b64_json' which will return
            a base64 encoded image in a JSON object.
    """

    api_type: str = field(default=openai.api_type, kw_only=True)
    api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
    client: openai.OpenAI = field(
        default=Factory(
            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
            takes_self=True,
        )
    )
    style: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    quality: Union[Literal["standard"], Literal["hd"]] = field(
        default="standard", kw_only=True, metadata={"serializable": True}
    )
    image_size: (
        Union[Literal["256x256"], Literal["512x512"], Literal["1024x1024"], Literal["1024x1792"], Literal["1792x1024"]]
    ) = field(default="1024x1024", kw_only=True, metadata={"serializable": True})
    response_format: Literal["b64_json"] = field(default="b64_json", kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        prompt = ", ".join(prompts)

        additional_params = {}

        if self.style:
            additional_params["style"] = self.style

        if self.quality:
            additional_params["quality"] = self.quality

        response = self.client.images.generate(
            model=self.model,
            prompt=prompt,
            size=self.image_size,
            response_format=self.response_format,
            n=1,
            **additional_params,
        )

        return self._parse_image_response(response, prompt)

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        image_size = self._dall_e_2_filter_image_size("variation")

        response = self.client.images.create_variation(
            image=image.value, n=1, response_format=self.response_format, size=image_size
        )

        return self._parse_image_response(response, "")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        image_size = self._dall_e_2_filter_image_size("inpainting")

        prompt = ", ".join(prompts)
        response = self.client.images.edit(
            prompt=prompt, image=image.value, mask=mask.value, response_format=self.response_format, size=image_size
        )

        return self._parse_image_response(response, prompt)

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")

    def _image_size_to_ints(self, image_size: str) -> list[int]:
        return [int(x) for x in image_size.split("x")]

    def _dall_e_2_filter_image_size(self, method: str) -> Literal["256x256", "512x512", "1024x1024"]:
        if self.model != "dall-e-2":
            raise NotImplementedError(f"{method} only supports dall-e-2")

        if self.image_size not in {"256x256", "512x512", "1024x1024"}:
            raise ValueError(f"support image sizes for {method} are 256x256, 512x512, and 1024x1024")

        return cast(Literal["256x256", "512x512", "1024x1024"], self.image_size)

    def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageArtifact:
        if response.data is None or response.data[0] is None or response.data[0].b64_json is None:
            raise Exception("Failed to generate image")

        image_data = base64.b64decode(response.data[0].b64_json)
        image_dimensions = self._image_size_to_ints(self.image_size)

        return ImageArtifact(
            value=image_data,
            format="png",
            width=image_dimensions[0],
            height=image_dimensions[1],
            model=self.model,
            prompt=prompt,
        )
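
A minimal usage sketch. try_text_to_image joins the prompt list with commas and returns an ImageArtifact whose value holds the base64-decoded PNG bytes; the output path is a placeholder and OPENAI_API_KEY is assumed to be set.

from griptape.drivers import OpenAiImageGenerationDriver

driver = OpenAiImageGenerationDriver(model="dall-e-3", image_size="1024x1024")

image = driver.try_text_to_image(["a watercolor lighthouse at dawn"])

with open("lighthouse.png", "wb") as f:  # placeholder path
    f.write(image.value)  # decoded PNG bytes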

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

api_type: str = field(default=openai.api_type, kw_only=True) class-attribute instance-attribute

api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.OpenAI = field(default=Factory(lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), takes_self=True)) class-attribute instance-attribute

image_size: Union[Literal['256x256'], Literal['512x512'], Literal['1024x1024'], Literal['1024x1792'], Literal['1792x1024']] = field(default='1024x1024', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

quality: Union[Literal['standard'], Literal['hd']] = field(default='standard', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

response_format: Literal['b64_json'] = field(default='b64_json', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

style: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    image_size = self._dall_e_2_filter_image_size("inpainting")

    prompt = ", ".join(prompts)
    response = self.client.images.edit(
        prompt=prompt, image=image.value, mask=mask.value, response_format=self.response_format, size=image_size
    )

    return self._parse_image_response(response, prompt)

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    image_size = self._dall_e_2_filter_image_size("variation")

    response = self.client.images.create_variation(
        image=image.value, n=1, response_format=self.response_format, size=image_size
    )

    return self._parse_image_response(response, "")

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    prompt = ", ".join(prompts)

    additional_params = {}

    if self.style:
        additional_params["style"] = self.style

    if self.quality:
        additional_params["quality"] = self.quality

    response = self.client.images.generate(
        model=self.model,
        prompt=prompt,
        size=self.image_size,
        response_format=self.response_format,
        n=1,
        **additional_params,
    )

    return self._parse_image_response(response, prompt)

OpenAiVisionImageQueryDriver

Bases: BaseImageQueryDriver

Source code in griptape/drivers/image_query/openai_vision_image_query_driver.py
@define
class OpenAiVisionImageQueryDriver(BaseImageQueryDriver):
    model: str = field(kw_only=True, metadata={"serializable": True})
    api_type: str = field(default=openai.api_type, kw_only=True)
    api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True)
    organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
    image_quality: Literal["auto", "low", "high"] = field(default="auto", kw_only=True, metadata={"serializable": True})
    client: openai.OpenAI = field(
        default=Factory(
            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
            takes_self=True,
        )
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        message_parts: list[ChatCompletionContentPartParam] = [
            ChatCompletionContentPartTextParam(type="text", text=query)
        ]

        for image in images:
            message_parts.append(
                ChatCompletionContentPartImageParam(
                    type="image_url",
                    image_url={"url": f"data:{image.mime_type};base64,{image.base64}", "detail": self.image_quality},
                )
            )

        messages = ChatCompletionUserMessageParam(content=message_parts, role="user")
        params = {"model": self.model, "messages": [messages], "max_tokens": self.max_tokens}

        response = self.client.chat.completions.create(**params)

        if len(response.choices) != 1:
            raise Exception("Image query responses with more than one choice are not supported yet.")

        return TextArtifact(response.choices[0].message.content)
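
A minimal usage sketch. It assumes image is an existing ImageArtifact (for example, one produced by an image loader or a generation driver) and that the model name is a placeholder for any vision-capable chat model; OPENAI_API_KEY is assumed to be set.

from griptape.drivers import OpenAiVisionImageQueryDriver

driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview", max_tokens=256)

# `image` is an existing ImageArtifact; its mime_type and base64 properties
# are used to build the data URL sent to the API.
answer = driver.try_query("What objects are in this image?", [image])
print(answer.value)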

api_key: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

api_type: str = field(default=openai.api_type, kw_only=True) class-attribute instance-attribute

api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.OpenAI = field(default=Factory(lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), takes_self=True)) class-attribute instance-attribute

image_quality: Literal['auto', 'low', 'high'] = field(default='auto', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_query(query, images)

Source code in griptape/drivers/image_query/openai_vision_image_query_driver.py
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
    message_parts: list[ChatCompletionContentPartParam] = [
        ChatCompletionContentPartTextParam(type="text", text=query)
    ]

    for image in images:
        message_parts.append(
            ChatCompletionContentPartImageParam(
                type="image_url",
                image_url={"url": f"data:{image.mime_type};base64,{image.base64}", "detail": self.image_quality},
            )
        )

    messages = ChatCompletionUserMessageParam(content=message_parts, role="user")
    params = {"model": self.model, "messages": [messages], "max_tokens": self.max_tokens}

    response = self.client.chat.completions.create(**params)

    if len(response.choices) != 1:
        raise Exception("Image query responses with more than one choice are not supported yet.")

    return TextArtifact(response.choices[0].message.content)

OpenSearchVectorStoreDriver

Bases: BaseVectorStoreDriver

A Vector Store Driver for OpenSearch.

Attributes:

Name Type Description
host str

The host of the OpenSearch cluster.

port int

The port of the OpenSearch cluster.

http_auth Optional[str | tuple[str, Optional[str]]]

The HTTP authentication credentials to use.

use_ssl bool

Whether to use SSL.

verify_certs bool

Whether to verify SSL certificates.

index_name str

The name of the index to use.

Source code in griptape/drivers/vector/opensearch_vector_store_driver.py
@define
class OpenSearchVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for OpenSearch.

    Attributes:
        host: The host of the OpenSearch cluster.
        port: The port of the OpenSearch cluster.
        http_auth: The HTTP authentication credentials to use.
        use_ssl: Whether to use SSL.
        verify_certs: Whether to verify SSL certificates.
        index_name: The name of the index to use.
    """

    host: str = field(kw_only=True, metadata={"serializable": True})
    port: int = field(default=443, kw_only=True, metadata={"serializable": True})
    http_auth: Optional[str | tuple[str, Optional[str]]] = field(default=None, kw_only=True, metadata={"serializable": True})
    use_ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    verify_certs: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    index_name: str = field(kw_only=True, metadata={"serializable": True})

    client: OpenSearch = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").OpenSearch(
                hosts=[{"host": self.host, "port": self.port}],
                http_auth=self.http_auth,
                use_ssl=self.use_ssl,
                verify_certs=self.verify_certs,
                connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection,
            ),
            takes_self=True,
        )
    )

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in OpenSearch.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
        Metadata associated with the vector can also be provided.
        """

        vector_id = vector_id if vector_id else utils.str_to_hash(str(vector))
        doc = {"vector": vector, "namespace": namespace, "metadata": meta}
        doc.update(kwargs)
        response = self.client.index(index=self.index_name, id=vector_id, body=doc)

        return response["_id"]

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Retrieves a specific vector entry from OpenSearch based on its identifier and optional namespace.

        Returns:
            If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
        """
        try:
            query = {"bool": {"must": [{"term": {"_id": vector_id}}]}}

            if namespace:
                query["bool"]["must"].append({"term": {"namespace": namespace}})

            response = self.client.search(index=self.index_name, body={"query": query, "size": 1})

            if response["hits"]["total"]["value"] > 0:
                vector_data = response["hits"]["hits"][0]["_source"]
                entry = BaseVectorStoreDriver.Entry(
                    id=vector_id,
                    meta=vector_data.get("metadata"),
                    vector=vector_data.get("vector"),
                    namespace=vector_data.get("namespace"),
                )
                return entry
            else:
                return None
        except Exception as e:
            logging.error(f"Error while loading entry: {e}")
            return None

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from OpenSearch that match the optional namespace.

        Returns:
            A list of BaseVectorStoreDriver.Entry objects.
        """

        query_body = {"size": 10000, "query": {"match_all": {}}}

        if namespace:
            query_body["query"] = {"match": {"namespace": namespace}}

        response = self.client.search(index=self.index_name, body=query_body)

        entries = [
            BaseVectorStoreDriver.Entry(
                id=hit["_id"],
                vector=hit["_source"].get("vector"),
                meta=hit["_source"].get("metadata"),
                namespace=hit["_source"].get("namespace"),
            )
            for hit in response["hits"]["hits"]
        ]
        return entries

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        include_metadata=True,
        field_name: str = "vector",
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string.

        Results can be limited using the count parameter and optionally filtered by a namespace.

        Returns:
            A list of BaseVectorStoreDriver.QueryResult objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
        """
        count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        vector = self.embedding_driver.embed_string(query)
        # Base k-NN query
        query_body = {"size": count, "query": {"knn": {field_name: {"vector": vector, "k": count}}}}

        if namespace:
            query_body["query"] = {
                "bool": {
                    "must": [{"match": {"namespace": namespace}}, {"knn": {field_name: {"vector": vector, "k": count}}}]
                }
            }

        response = self.client.search(index=self.index_name, body=query_body)

        return [
            BaseVectorStoreDriver.QueryResult(
                id=hit["_id"],
                namespace=hit["_source"].get("namespace") if namespace else None,
                score=hit["_score"],
                vector=hit["_source"].get("vector") if include_vectors else None,
                meta=hit["_source"].get("metadata") if include_metadata else None,
            )
            for hit in response["hits"]["hits"]
        ]

    def delete_vector(self, vector_id: str):
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
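
A minimal usage sketch against a k-NN-enabled OpenSearch cluster. The host, credentials, and index name are placeholders, and the index is assumed to already define a knn_vector field named vector; the embedding driver lets query() embed the query string.

from griptape.drivers import OpenSearchVectorStoreDriver, OpenAiEmbeddingDriver

driver = OpenSearchVectorStoreDriver(
    host="search-example.us-east-1.es.amazonaws.com",  # placeholder
    http_auth=("admin", "admin-password"),  # placeholder credentials
    index_name="my-vector-index",
    embedding_driver=OpenAiEmbeddingDriver(),
)

driver.upsert_vector([0.1, 0.2, 0.3], vector_id="doc-1", namespace="docs")

for result in driver.query("example query", count=5, namespace="docs"):
    print(result.id, result.score)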

client: OpenSearch = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').OpenSearch(hosts=[{'host': self.host, 'port': self.port}], http_auth=self.http_auth, use_ssl=self.use_ssl, verify_certs=self.verify_certs, connection_class=import_optional_dependency('opensearchpy').RequestsHttpConnection), takes_self=True)) class-attribute instance-attribute

host: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

http_auth: Optional[str | tuple[str, Optional[str]]] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

index_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

port: int = field(default=443, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

use_ssl: bool = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

verify_certs: bool = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

delete_vector(vector_id)

Source code in griptape/drivers/vector/opensearch_vector_store_driver.py
def delete_vector(self, vector_id: str):
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

load_entries(namespace=None)

Retrieves all vector entries from OpenSearch that match the optional namespace.

Returns:

list[Entry]: A list of BaseVectorStoreDriver.Entry objects.

Source code in griptape/drivers/vector/opensearch_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Retrieves all vector entries from OpenSearch that match the optional namespace.

    Returns:
        A list of BaseVectorStoreDriver.Entry objects.
    """

    query_body = {"size": 10000, "query": {"match_all": {}}}

    if namespace:
        query_body["query"] = {"match": {"namespace": namespace}}

    response = self.client.search(index=self.index_name, body=query_body)

    entries = [
        BaseVectorStoreDriver.Entry(
            id=hit["_id"],
            vector=hit["_source"].get("vector"),
            meta=hit["_source"].get("metadata"),
            namespace=hit["_source"].get("namespace"),
        )
        for hit in response["hits"]["hits"]
    ]
    return entries

load_entry(vector_id, namespace=None)

Retrieves a specific vector entry from OpenSearch based on its identifier and optional namespace.

Returns:

Optional[Entry]: If the entry is found, an instance of BaseVectorStoreDriver.Entry; otherwise, None.

Source code in griptape/drivers/vector/opensearch_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Retrieves a specific vector entry from OpenSearch based on its identifier and optional namespace.

    Returns:
        If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
    """
    try:
        query = {"bool": {"must": [{"term": {"_id": vector_id}}]}}

        if namespace:
            query["bool"]["must"].append({"term": {"namespace": namespace}})

        response = self.client.search(index=self.index_name, body={"query": query, "size": 1})

        if response["hits"]["total"]["value"] > 0:
            vector_data = response["hits"]["hits"][0]["_source"]
            entry = BaseVectorStoreDriver.Entry(
                id=vector_id,
                meta=vector_data.get("metadata"),
                vector=vector_data.get("vector"),
                namespace=vector_data.get("namespace"),
            )
            return entry
        else:
            return None
    except Exception as e:
        logging.error(f"Error while loading entry: {e}")
        return None

query(query, count=None, namespace=None, include_vectors=False, include_metadata=True, field_name='vector', **kwargs)

Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string.

Results can be limited using the count parameter and optionally filtered by a namespace.

Returns:

list[QueryResult]: A list of BaseVectorStoreDriver.QueryResult objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.

Source code in griptape/drivers/vector/opensearch_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    include_metadata=True,
    field_name: str = "vector",
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string.

    Results can be limited using the count parameter and optionally filtered by a namespace.

    Returns:
        A list of BaseVectorStoreDriver.QueryResult objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
    """
    count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    vector = self.embedding_driver.embed_string(query)
    # Base k-NN query
    query_body = {"size": count, "query": {"knn": {field_name: {"vector": vector, "k": count}}}}

    if namespace:
        query_body["query"] = {
            "bool": {
                "must": [{"match": {"namespace": namespace}}, {"knn": {field_name: {"vector": vector, "k": count}}}]
            }
        }

    response = self.client.search(index=self.index_name, body=query_body)

    return [
        BaseVectorStoreDriver.QueryResult(
            id=hit["_id"],
            namespace=hit["_source"].get("namespace") if namespace else None,
            score=hit["_score"],
            vector=hit["_source"].get("vector") if include_vectors else None,
            meta=hit["_source"].get("metadata") if include_metadata else None,
        )
        for hit in response["hits"]["hits"]
    ]

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in OpenSearch.

If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted. Metadata associated with the vector can also be provided.

Source code in griptape/drivers/vector/opensearch_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in OpenSearch.

    If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
    Metadata associated with the vector can also be provided.
    """

    vector_id = vector_id if vector_id else utils.str_to_hash(str(vector))
    doc = {"vector": vector, "namespace": namespace, "metadata": meta}
    doc.update(kwargs)
    response = self.client.index(index=self.index_name, id=vector_id, body=doc)

    return response["_id"]

PgVectorVectorStoreDriver

Bases: BaseVectorStoreDriver

A vector store driver for Postgres using the PGVector extension.

Attributes:

connection_string (Optional[str]): An optional string describing the target Postgres database instance.
create_engine_params (dict): Additional configuration params passed when creating the database connection.
engine (Optional[Engine]): An optional sqlalchemy Postgres engine to use.
table_name (str): Optionally specify the name of the table used to store vectors.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
@define
class PgVectorVectorStoreDriver(BaseVectorStoreDriver):
    """A vector store driver to Postgres using the PGVector extension.

    Attributes:
        connection_string: An optional string describing the target Postgres database instance.
        create_engine_params: Additional configuration params passed when creating the database connection.
        engine: An optional sqlalchemy Postgres engine to use.
        table_name: Optionally specify the name of the table used to store vectors.
    """

    connection_string: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    create_engine_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
    engine: Optional[Engine] = field(default=None, kw_only=True)
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    _model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True))

    @connection_string.validator  # pyright: ignore
    def validate_connection_string(self, _, connection_string: Optional[str]) -> None:
        # If an engine is provided, the connection string is not used.
        if self.engine is not None:
            return

        # If an engine is not provided, a connection string is required.
        if connection_string is None:
            raise ValueError("An engine or connection string is required")

        if not connection_string.startswith("postgresql://"):
            raise ValueError("The connection string must describe a Postgres database connection")

    @engine.validator  # pyright: ignore
    def validate_engine(self, _, engine: Optional[Engine]) -> None:
        # If a connection string is provided, an engine does not need to be provided.
        if self.connection_string is not None:
            return

        # If a connection string is not provided, an engine is required.
        if engine is None:
            raise ValueError("An engine or connection string is required")

    def __attrs_post_init__(self) -> None:
        """If an engine is provided, it will be used to connect to the database.
        If not, a connection string is used to create a new database connection here.
        """
        if self.engine is None:
            self.engine = cast(Engine, create_engine(self.connection_string, **self.create_engine_params))

    def setup(
        self, create_schema: bool = True, install_uuid_extension: bool = True, install_vector_extension: bool = True
    ) -> None:
        """Provides a mechanism to initialize the database schema and extensions."""
        if install_uuid_extension:
            self.engine.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')

        if install_vector_extension:
            self.engine.execute('CREATE EXTENSION IF NOT EXISTS "vector";')

        if create_schema:
            self._model.metadata.create_all(self.engine)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in the collection."""
        with Session(self.engine) as session:
            obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

            obj = session.merge(obj)
            session.commit()

            return str(getattr(obj, "id"))

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
        """Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
        with Session(self.engine) as session:
            result = session.get(self._model, vector_id)

            return BaseVectorStoreDriver.Entry(
                id=getattr(result, "id"),
                vector=getattr(result, "vector"),
                namespace=getattr(result, "namespace"),
                meta=getattr(result, "meta"),
            )

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from the collection, optionally filtering to only
        those that match the provided namespace.
        """
        with Session(self.engine) as session:
            query = session.query(self._model)
            if namespace:
                query = query.filter_by(namespace=namespace)

            results = query.all()

            return [
                BaseVectorStoreDriver.Entry(
                    id=str(result.id), vector=result.vector, namespace=result.namespace, meta=result.meta
                )
                for result in results
            ]

    def query(
        self,
        query: str,
        count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        distance_metric: str = "cosine_distance",
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Performs a search on the collection to find vectors similar to the provided input vector,
        optionally filtering to only those that match the provided namespace.
        """
        distance_metrics = {
            "cosine_distance": self._model.vector.cosine_distance,
            "l2_distance": self._model.vector.l2_distance,
            "inner_product": self._model.vector.max_inner_product,
        }

        if distance_metric not in distance_metrics:
            raise ValueError("Invalid distance metric provided")

        op = distance_metrics[distance_metric]

        with Session(self.engine) as session:
            vector = self.embedding_driver.embed_string(query)

            # The query should return both the vector and the distance metric score.
            query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector))  # pyright: ignore

            filter_kwargs: Optional[OrderedDict] = None

            if namespace is not None:
                filter_kwargs = OrderedDict(namespace=namespace)

            if "filter" in kwargs and isinstance(kwargs["filter"], dict):
                filter_kwargs = filter_kwargs or OrderedDict()
                filter_kwargs.update(kwargs["filter"])

            if filter_kwargs is not None:
                query_result = query_result.filter_by(**filter_kwargs)

            results = query_result.limit(count).all()

            return [
                BaseVectorStoreDriver.QueryResult(
                    id=str(result[0].id),
                    vector=result[0].vector if include_vectors else None,
                    score=result[1],
                    meta=result[0].meta,
                    namespace=result[0].namespace,
                )
                for result in results
            ]

    def default_vector_model(self) -> Any:
        Vector = import_optional_dependency("pgvector.sqlalchemy").Vector
        Base = declarative_base()

        @dataclass
        class VectorModel(Base):
            __tablename__ = self.table_name

            id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
            vector = Column(Vector())
            namespace = Column(String)
            meta = Column(JSON)

        return VectorModel

    def delete_vector(self, vector_id: str):
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

connection_string: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

create_engine_params: dict = field(factory=dict, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

engine: Optional[Engine] = field(default=None, kw_only=True) class-attribute instance-attribute

table_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__attrs_post_init__()

If an engine is provided, it will be used to connect to the database. If not, a connection string is used to create a new database connection here.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def __attrs_post_init__(self) -> None:
    """If an engine is provided, it will be used to connect to the database.
    If not, a connection string is used to create a new database connection here.
    """
    if self.engine is None:
        self.engine = cast(Engine, create_engine(self.connection_string, **self.create_engine_params))

default_vector_model()

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def default_vector_model(self) -> Any:
    Vector = import_optional_dependency("pgvector.sqlalchemy").Vector
    Base = declarative_base()

    @dataclass
    class VectorModel(Base):
        __tablename__ = self.table_name

        id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
        vector = Column(Vector())
        namespace = Column(String)
        meta = Column(JSON)

    return VectorModel

delete_vector(vector_id)

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def delete_vector(self, vector_id: str):
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

load_entries(namespace=None)

Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Retrieves all vector entries from the collection, optionally filtering to only
    those that match the provided namespace.
    """
    with Session(self.engine) as session:
        query = session.query(self._model)
        if namespace:
            query = query.filter_by(namespace=namespace)

        results = query.all()

        return [
            BaseVectorStoreDriver.Entry(
                id=str(result.id), vector=result.vector, namespace=result.namespace, meta=result.meta
            )
            for result in results
        ]

load_entry(vector_id, namespace=None)

Retrieves a specific vector entry from the collection based on its identifier and optional namespace.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
    """Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
    with Session(self.engine) as session:
        result = session.get(self._model, vector_id)

        return BaseVectorStoreDriver.Entry(
            id=getattr(result, "id"),
            vector=getattr(result, "vector"),
            namespace=getattr(result, "namespace"),
            meta=getattr(result, "meta"),
        )

query(query, count=BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, namespace=None, include_vectors=False, distance_metric='cosine_distance', **kwargs)

Performs a search on the collection to find vectors similar to the provided query string, optionally filtering to only those that match the provided namespace.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    distance_metric: str = "cosine_distance",
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Performs a search on the collection to find vectors similar to the provided input vector,
    optionally filtering to only those that match the provided namespace.
    """
    distance_metrics = {
        "cosine_distance": self._model.vector.cosine_distance,
        "l2_distance": self._model.vector.l2_distance,
        "inner_product": self._model.vector.max_inner_product,
    }

    if distance_metric not in distance_metrics:
        raise ValueError("Invalid distance metric provided")

    op = distance_metrics[distance_metric]

    with Session(self.engine) as session:
        vector = self.embedding_driver.embed_string(query)

        # The query should return both the vector and the distance metric score.
        query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector))  # pyright: ignore

        filter_kwargs: Optional[OrderedDict] = None

        if namespace is not None:
            filter_kwargs = OrderedDict(namespace=namespace)

        if "filter" in kwargs and isinstance(kwargs["filter"], dict):
            filter_kwargs = filter_kwargs or OrderedDict()
            filter_kwargs.update(kwargs["filter"])

        if filter_kwargs is not None:
            query_result = query_result.filter_by(**filter_kwargs)

        results = query_result.limit(count).all()

        return [
            BaseVectorStoreDriver.QueryResult(
                id=str(result[0].id),
                vector=result[0].vector if include_vectors else None,
                score=result[1],
                meta=result[0].meta,
                namespace=result[0].namespace,
            )
            for result in results
        ]

setup(create_schema=True, install_uuid_extension=True, install_vector_extension=True)

Provides a mechanism to initialize the database schema and extensions.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def setup(
    self, create_schema: bool = True, install_uuid_extension: bool = True, install_vector_extension: bool = True
) -> None:
    """Provides a mechanism to initialize the database schema and extensions."""
    if install_uuid_extension:
        self.engine.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')

    if install_vector_extension:
        self.engine.execute('CREATE EXTENSION IF NOT EXISTS "vector";')

    if create_schema:
        self._model.metadata.create_all(self.engine)

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in the collection.

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in the collection."""
    with Session(self.engine) as session:
        obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

        obj = session.merge(obj)
        session.commit()

        return str(getattr(obj, "id"))

validate_connection_string(_, connection_string)

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
@connection_string.validator  # pyright: ignore
def validate_connection_string(self, _, connection_string: Optional[str]) -> None:
    # If an engine is provided, the connection string is not used.
    if self.engine is not None:
        return

    # If an engine is not provided, a connection string is required.
    if connection_string is None:
        raise ValueError("An engine or connection string is required")

    if not connection_string.startswith("postgresql://"):
        raise ValueError("The connection string must describe a Postgres database connection")

validate_engine(_, engine)

Source code in griptape/drivers/vector/pgvector_vector_store_driver.py
@engine.validator  # pyright: ignore
def validate_engine(self, _, engine: Optional[Engine]) -> None:
    # If a connection string is provided, an engine does not need to be provided.
    if self.connection_string is not None:
        return

    # If a connection string is not provided, an engine is required.
    if engine is None:
        raise ValueError("An engine or connection string is required")

PineconeVectorStoreDriver

Bases: BaseVectorStoreDriver

Source code in griptape/drivers/vector/pinecone_vector_store_driver.py
@define
class PineconeVectorStoreDriver(BaseVectorStoreDriver):
    api_key: str = field(kw_only=True, metadata={"serializable": True})
    index_name: str = field(kw_only=True, metadata={"serializable": True})
    environment: str = field(kw_only=True, metadata={"serializable": True})
    project_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    index: pinecone.Index = field(init=False)

    def __attrs_post_init__(self) -> None:
        pinecone = import_optional_dependency("pinecone").Pinecone(
            api_key=self.api_key, environment=self.environment, project_name=self.project_name
        )

        self.index = pinecone.Index(self.index_name)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        vector_id = vector_id if vector_id else str_to_hash(str(vector))

        params: dict[str, Any] = {"namespace": namespace} | kwargs

        self.index.upsert(vectors=[(vector_id, vector, meta)], **params)

        return vector_id

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()
        vectors = list(result["vectors"].values())

        if len(vectors) > 0:
            vector = vectors[0]

            return BaseVectorStoreDriver.Entry(
                id=vector["id"], meta=vector["metadata"], vector=vector["values"], namespace=result["namespace"]
            )
        else:
            return None

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching
        # all values from a namespace:
        # https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5

        results = self.index.query(
            vector=self.embedding_driver.embed_string(""), top_k=10000, include_metadata=True, namespace=namespace
        )

        return [
            BaseVectorStoreDriver.Entry(
                id=r["id"], vector=r["values"], meta=r["metadata"], namespace=results["namespace"]
            )
            for r in results["matches"]
        ]

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        # PineconeVectorStoreDriver-specific params:
        include_metadata=True,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        vector = self.embedding_driver.embed_string(query)

        params = {
            "top_k": count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
            "namespace": namespace,
            "include_values": include_vectors,
            "include_metadata": include_metadata,
        } | kwargs

        results = self.index.query(vector=vector, **params)

        return [
            BaseVectorStoreDriver.QueryResult(
                id=r["id"], vector=r["values"], score=r["score"], meta=r["metadata"], namespace=results["namespace"]
            )
            for r in results["matches"]
        ]

    def delete_vector(self, vector_id: str):
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

api_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

environment: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

index: pinecone.Index = field(init=False) class-attribute instance-attribute

index_name: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

project_name: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/drivers/vector/pinecone_vector_store_driver.py
def __attrs_post_init__(self) -> None:
    pinecone = import_optional_dependency("pinecone").Pinecone(
        api_key=self.api_key, environment=self.environment, project_name=self.project_name
    )

    self.index = pinecone.Index(self.index_name)

delete_vector(vector_id)

Source code in griptape/drivers/vector/pinecone_vector_store_driver.py
def delete_vector(self, vector_id: str):
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

load_entries(namespace=None)

Source code in griptape/drivers/vector/pinecone_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching
    # all values from a namespace:
    # https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5

    results = self.index.query(
        vector=self.embedding_driver.embed_string(""), top_k=10000, include_metadata=True, namespace=namespace
    )

    return [
        BaseVectorStoreDriver.Entry(
            id=r["id"], vector=r["values"], meta=r["metadata"], namespace=results["namespace"]
        )
        for r in results["matches"]
    ]

load_entry(vector_id, namespace=None)

Source code in griptape/drivers/vector/pinecone_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()
    vectors = list(result["vectors"].values())

    if len(vectors) > 0:
        vector = vectors[0]

        return BaseVectorStoreDriver.Entry(
            id=vector["id"], meta=vector["metadata"], vector=vector["values"], namespace=result["namespace"]
        )
    else:
        return None

query(query, count=None, namespace=None, include_vectors=False, include_metadata=True, **kwargs)

Source code in griptape/drivers/vector/pinecone_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    # PineconeVectorStoreDriver-specific params:
    include_metadata=True,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    vector = self.embedding_driver.embed_string(query)

    params = {
        "top_k": count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        "namespace": namespace,
        "include_values": include_vectors,
        "include_metadata": include_metadata,
    } | kwargs

    results = self.index.query(vector=vector, **params)

    return [
        BaseVectorStoreDriver.QueryResult(
            id=r["id"], vector=r["values"], score=r["score"], meta=r["metadata"], namespace=results["namespace"]
        )
        for r in results["matches"]
    ]

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Source code in griptape/drivers/vector/pinecone_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    vector_id = vector_id if vector_id else str_to_hash(str(vector))

    params: dict[str, Any] = {"namespace": namespace} | kwargs

    self.index.upsert(vectors=[(vector_id, vector, meta)], **params)

    return vector_id
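
A minimal usage sketch; the API key, environment, and index name are placeholders, and the index must already exist with a dimension matching the embedding driver's output.

from griptape.drivers import PineconeVectorStoreDriver, OpenAiEmbeddingDriver

driver = PineconeVectorStoreDriver(
    api_key="my-pinecone-api-key",  # placeholder
    environment="us-east-1-aws",  # placeholder
    index_name="my-index",  # must already exist
    embedding_driver=OpenAiEmbeddingDriver(),
)

vector = driver.embedding_driver.embed_string("Pinecone is a managed vector database")
driver.upsert_vector(vector, namespace="docs", meta={"source": "notes"})

results = driver.query("managed vector database", count=5, namespace="docs")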

RedisVectorStoreDriver

Bases: BaseVectorStoreDriver

A Vector Store Driver for Redis.

This driver interfaces with a Redis instance and utilizes Redis hashes and the RediSearch module to store, retrieve, and query vectors in a structured manner. Proper setup of the Redis instance and RediSearch is necessary for the driver to function correctly.

Attributes:

Name Type Description
host str

The host of the Redis instance.

port int

The port of the Redis instance.

db int

The database of the Redis instance.

password Optional[str]

The password of the Redis instance.

index str

The name of the index to use.

Source code in griptape/drivers/vector/redis_vector_store_driver.py
@define
class RedisVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for Redis.

    This driver interfaces with a Redis instance and utilizes Redis hashes and the RediSearch module to store, retrieve, and query vectors in a structured manner.
    Proper setup of the Redis instance and RediSearch is necessary for the driver to function correctly.

    Attributes:
        host: The host of the Redis instance.
        port: The port of the Redis instance.
        db: The database of the Redis instance.
        password: The password of the Redis instance.
        index: The name of the index to use.
    """

    host: str = field(kw_only=True, metadata={"serializable": True})
    port: int = field(kw_only=True, metadata={"serializable": True})
    db: int = field(kw_only=True, default=0, metadata={"serializable": True})
    password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    index: str = field(kw_only=True, metadata={"serializable": True})

    client: Redis = field(
        default=Factory(
            lambda self: import_optional_dependency("redis").Redis(
                host=self.host, port=self.port, db=self.db, password=self.password, decode_responses=False
            ),
            takes_self=True,
        )
    )

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in Redis.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
        Metadata associated with the vector can also be provided.
        """
        vector_id = vector_id if vector_id else str_to_hash(str(vector))
        key = self._generate_key(vector_id, namespace)
        bytes_vector = json.dumps(vector).encode("utf-8")

        mapping = {}
        mapping["vector"] = np.array(vector, dtype=np.float32).tobytes()
        mapping["vec_string"] = bytes_vector

        if meta:
            mapping["metadata"] = json.dumps(meta)

        self.client.hset(key, mapping=mapping)

        return vector_id

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Retrieves a specific vector entry from Redis based on its identifier and optional namespace.

        Returns:
            If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
        """
        key = self._generate_key(vector_id, namespace)
        result = self.client.hgetall(key)

        # Return None for missing keys, as the docstring promises, instead of raising a KeyError.
        if not result:
            return None

        vector = np.frombuffer(result[b"vector"], dtype=np.float32).tolist()
        meta = json.loads(result[b"metadata"]) if b"metadata" in result else None

        return BaseVectorStoreDriver.Entry(id=vector_id, meta=meta, vector=vector, namespace=namespace)

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from Redis that match the optional namespace.

        Returns:
            A list of `BaseVectorStoreDriver.Entry` objects.
        """
        pattern = f"{namespace}:*" if namespace else "*"
        keys = self.client.keys(pattern)

        entries = []
        for key in keys:
            decoded_key = key.decode("utf-8")
            # Strip the namespace prefix so load_entry does not re-apply it when regenerating the key.
            vector_id = decoded_key.split(":", 1)[1] if namespace else decoded_key
            entry = self.load_entry(vector_id, namespace)
            if entry:
                entries.append(entry)

        return entries

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        """Performs a nearest neighbor search on Redis to find vectors similar to the provided input vector.

        Results can be limited using the count parameter and optionally filtered by a namespace.

        Returns:
            A list of BaseVectorStoreDriver.QueryResult objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
        """
        Query = import_optional_dependency("redis.commands.search.query").Query

        vector = self.embedding_driver.embed_string(query)

        query_expression = (
            Query(f"*=>[KNN {count or 10} @vector $vector as score]")
            .sort_by("score")
            .return_fields("id", "score", "metadata", "vec_string")
            .paging(0, count or 10)
            .dialect(2)
        )

        query_params = {"vector": np.array(vector, dtype=np.float32).tobytes()}

        results = self.client.ft(self.index).search(query_expression, query_params).docs  # pyright: ignore

        query_results = []
        for document in results:
            metadata = getattr(document, "metadata", None)
            namespace = document.id.split(":")[0] if ":" in document.id else None
            vector_id = document.id.split(":")[1] if ":" in document.id else document.id
            vector_float_list = json.loads(document["vec_string"]) if include_vectors else None
            query_results.append(
                BaseVectorStoreDriver.QueryResult(
                    id=vector_id,
                    vector=vector_float_list,
                    score=float(document["score"]),
                    meta=metadata,
                    namespace=namespace,
                )
            )
        return query_results

    def _generate_key(self, vector_id: str, namespace: Optional[str] = None) -> str:
        """Generates a Redis key using the provided vector ID and optionally a namespace."""
        return f"{namespace}:{vector_id}" if namespace else vector_id

    def _get_doc_prefix(self, namespace: Optional[str] = None) -> str:
        """Get the document prefix based on the provided namespace."""
        return f"{namespace}:" if namespace else ""

    def delete_vector(self, vector_id: str):
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

client: Redis = field(default=Factory(lambda self: import_optional_dependency('redis').Redis(host=self.host, port=self.port, db=self.db, password=self.password, decode_responses=False), takes_self=True)) class-attribute instance-attribute

db: int = field(kw_only=True, default=0, metadata={'serializable': True}) class-attribute instance-attribute

host: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

index: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

password: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

port: int = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

delete_vector(vector_id)

Source code in griptape/drivers/vector/redis_vector_store_driver.py
def delete_vector(self, vector_id: str):
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

load_entries(namespace=None)

Retrieves all vector entries from Redis that match the optional namespace.

Returns:

list[Entry]: A list of BaseVectorStoreDriver.Entry objects.

Source code in griptape/drivers/vector/redis_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Retrieves all vector entries from Redis that match the optional namespace.

    Returns:
        A list of `BaseVectorStoreDriver.Entry` objects.
    """
    pattern = f"{namespace}:*" if namespace else "*"
    keys = self.client.keys(pattern)

    entries = []
    for key in keys:
        decoded_key = key.decode("utf-8")
        # Strip the namespace prefix so load_entry does not re-apply it when regenerating the key.
        vector_id = decoded_key.split(":", 1)[1] if namespace else decoded_key
        entry = self.load_entry(vector_id, namespace)
        if entry:
            entries.append(entry)

    return entries

load_entry(vector_id, namespace=None)

Retrieves a specific vector entry from Redis based on its identifier and optional namespace.

Returns:

Optional[Entry]: If the entry is found, an instance of BaseVectorStoreDriver.Entry; otherwise, None.

Source code in griptape/drivers/vector/redis_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Retrieves a specific vector entry from Redis based on its identifier and optional namespace.

    Returns:
        If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
    """
    key = self._generate_key(vector_id, namespace)
    result = self.client.hgetall(key)

    # Return None for missing keys, as the docstring promises, instead of raising a KeyError.
    if not result:
        return None

    vector = np.frombuffer(result[b"vector"], dtype=np.float32).tolist()
    meta = json.loads(result[b"metadata"]) if b"metadata" in result else None

    return BaseVectorStoreDriver.Entry(id=vector_id, meta=meta, vector=vector, namespace=namespace)

query(query, count=None, namespace=None, include_vectors=False, **kwargs)

Performs a nearest neighbor search on Redis to find vectors similar to the provided query string.

Results can be limited using the count parameter and optionally filtered by a namespace.

Returns:

list[QueryResult]: A list of BaseVectorStoreDriver.QueryResult objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.

Source code in griptape/drivers/vector/redis_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    """Performs a nearest neighbor search on Redis to find vectors similar to the provided input vector.

    Results can be limited using the count parameter and optionally filtered by a namespace.

    Returns:
        A list of BaseVectorStoreDriver.QueryResult objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
    """
    Query = import_optional_dependency("redis.commands.search.query").Query

    vector = self.embedding_driver.embed_string(query)

    query_expression = (
        Query(f"*=>[KNN {count or 10} @vector $vector as score]")
        .sort_by("score")
        .return_fields("id", "score", "metadata", "vec_string")
        .paging(0, count or 10)
        .dialect(2)
    )

    query_params = {"vector": np.array(vector, dtype=np.float32).tobytes()}

    results = self.client.ft(self.index).search(query_expression, query_params).docs  # pyright: ignore

    query_results = []
    for document in results:
        metadata = getattr(document, "metadata", None)
        namespace = document.id.split(":")[0] if ":" in document.id else None
        vector_id = document.id.split(":")[1] if ":" in document.id else document.id
        vector_float_list = json.loads(document["vec_string"]) if include_vectors else None
        query_results.append(
            BaseVectorStoreDriver.QueryResult(
                id=vector_id,
                vector=vector_float_list,
                score=float(document["score"]),
                meta=metadata,
                namespace=namespace,
            )
        )
    return query_results

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Inserts or updates a vector in Redis.

If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted. Metadata associated with the vector can also be provided.

Source code in griptape/drivers/vector/redis_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in Redis.

    If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
    Metadata associated with the vector can also be provided.
    """
    vector_id = vector_id if vector_id else str_to_hash(str(vector))
    key = self._generate_key(vector_id, namespace)
    bytes_vector = json.dumps(vector).encode("utf-8")

    mapping = {}
    mapping["vector"] = np.array(vector, dtype=np.float32).tobytes()
    mapping["vec_string"] = bytes_vector

    if meta:
        mapping["metadata"] = json.dumps(meta)

    self.client.hset(key, mapping=mapping)

    return vector_id
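
A minimal usage sketch, assuming a Redis instance with the RediSearch module loaded and an existing index named "idx" whose schema covers the "vector", "vec_string", and "metadata" fields written by upsert_vector.

from griptape.drivers import RedisVectorStoreDriver, OpenAiEmbeddingDriver

driver = RedisVectorStoreDriver(
    host="localhost",
    port=6379,
    index="idx",  # placeholder; must be a pre-created RediSearch index
    embedding_driver=OpenAiEmbeddingDriver(),
)

vector = driver.embedding_driver.embed_string("RediSearch supports vector similarity")
driver.upsert_vector(vector, namespace="docs", meta={"source": "notes"})

results = driver.query("vector similarity", count=5)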

SageMakerFalconPromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
@define
class SageMakerFalconPromptModelDriver(BasePromptModelDriver):
    DEFAULT_MAX_TOKENS = 600

    _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> HuggingFaceTokenizer:
        if self._tokenizer is None:
            self._tokenizer = HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained("tiiuae/falcon-40b"),
                max_output_tokens=self.max_tokens or self.DEFAULT_MAX_TOKENS,
            )
        return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
        return self.prompt_driver.prompt_stack_to_string(prompt_stack)

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_model_input(prompt_stack)
        stop_sequences = self.prompt_driver.tokenizer.stop_sequences

        return {
            "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
            "temperature": self.prompt_driver.temperature,
            "do_sample": True,
            "stop": stop_sequences,
        }

    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        if isinstance(output, list):
            return TextArtifact(output[0]["generated_text"].strip())
        else:
            raise ValueError("output must be an instance of 'list'")

DEFAULT_MAX_TOKENS = 600 class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer property

process_output(output)

Source code in griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    if isinstance(output, list):
        return TextArtifact(output[0]["generated_text"].strip())
    else:
        raise ValueError("output must be an instance of 'list'")

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str:
    return self.prompt_driver.prompt_stack_to_string(prompt_stack)

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_stack_to_model_input(prompt_stack)
    stop_sequences = self.prompt_driver.tokenizer.stop_sequences

    return {
        "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
        "temperature": self.prompt_driver.temperature,
        "do_sample": True,
        "stop": stop_sequences,
    }
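
This model driver is meant to be paired with a multi-model prompt driver. A sketch, assuming the AmazonSageMakerPromptDriver's model field names the endpoint and that the endpoint serves a Falcon model; the endpoint name is a placeholder.

from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver

driver = AmazonSageMakerPromptDriver(
    model="my-falcon-endpoint",  # placeholder SageMaker endpoint name
    prompt_model_driver=SageMakerFalconPromptModelDriver(),
)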

SageMakerHuggingFaceEmbeddingModelDriver

Bases: BaseEmbeddingModelDriver

Source code in griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py
@define
class SageMakerHuggingFaceEmbeddingModelDriver(BaseEmbeddingModelDriver):
    def chunk_to_model_params(self, chunk: str) -> dict:
        return {"text_inputs": chunk}

    def process_output(self, output: dict) -> list[float]:
        return output["embedding"][0]

chunk_to_model_params(chunk)

Source code in griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py
def chunk_to_model_params(self, chunk: str) -> dict:
    return {"text_inputs": chunk}

process_output(output)

Source code in griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py
def process_output(self, output: dict) -> list[float]:
    return output["embedding"][0]

SageMakerLlamaPromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py
@define
class SageMakerLlamaPromptModelDriver(BasePromptModelDriver):
    DEFAULT_MAX_TOKENS = 600

    _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> HuggingFaceTokenizer:
        if self._tokenizer is None:
            self._tokenizer = HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").LlamaTokenizerFast.from_pretrained(
                    "hf-internal-testing/llama-tokenizer"
                ),
                max_output_tokens=self.max_tokens or self.DEFAULT_MAX_TOKENS,
            )
        return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list:
        return [[{"role": i.role, "content": i.content} for i in prompt_stack.inputs]]

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_driver.prompt_stack_to_string(prompt_stack)

        return {
            "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
            "temperature": self.prompt_driver.temperature,
        }

    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        if isinstance(output, list):
            return TextArtifact(output[0]["generation"]["content"].strip())
        else:
            raise ValueError("output must be an instance of 'list'")

DEFAULT_MAX_TOKENS = 600 class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer property

process_output(output)

Source code in griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    if isinstance(output, list):
        return TextArtifact(output[0]["generation"]["content"].strip())
    else:
        raise ValueError("output must be an instance of 'list'")

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list:
    return [[{"role": i.role, "content": i.content} for i in prompt_stack.inputs]]

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_driver.prompt_stack_to_string(prompt_stack)

    return {
        "max_new_tokens": self.prompt_driver.max_output_tokens(prompt),
        "temperature": self.prompt_driver.temperature,
    }
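
Unlike the Falcon driver above, prompt_stack_to_model_input here emits chat-formatted role/content messages rather than a flat string, matching Llama 2 chat endpoints. A sketch with a placeholder endpoint name, under the same wiring assumptions as the Falcon example:

from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerLlamaPromptModelDriver

driver = AmazonSageMakerPromptDriver(
    model="my-llama-2-chat-endpoint",  # placeholder SageMaker endpoint name
    prompt_model_driver=SageMakerLlamaPromptModelDriver(),
)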

SageMakerTensorFlowHubEmbeddingModelDriver

Bases: BaseEmbeddingModelDriver

Source code in griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py
@define
class SageMakerTensorFlowHubEmbeddingModelDriver(BaseEmbeddingModelDriver):
    def chunk_to_model_params(self, chunk: str) -> dict:
        return {"text_inputs": chunk}

    def process_output(self, output: dict) -> list[float]:
        return output["embedding"]

chunk_to_model_params(chunk)

Source code in griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py
def chunk_to_model_params(self, chunk: str) -> dict:
    return {"text_inputs": chunk}

process_output(output)

Source code in griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py
def process_output(self, output: dict) -> list[float]:
    return output["embedding"]

SnowflakeSqlDriver

Bases: BaseSqlDriver

Source code in griptape/drivers/sql/snowflake_sql_driver.py
@define
class SnowflakeSqlDriver(BaseSqlDriver):
    connection_func: Callable[[], SnowflakeConnection] = field(kw_only=True)
    engine: Engine = field(
        default=Factory(
            # Creator bypasses the URL param
            # https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.creator
            lambda self: import_optional_dependency("sqlalchemy").create_engine(
                "snowflake://not@used/db", creator=self.connection_func
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    @connection_func.validator  # pyright: ignore
    def validate_connection_func(self, _, connection_func: Callable[[], SnowflakeConnection]) -> None:
        snowflake_connection = connection_func()
        snowflake = import_optional_dependency("snowflake")

        if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection):
            raise ValueError("The connection_func must return a SnowflakeConnection")
        if not snowflake_connection.schema or not snowflake_connection.database:
            raise ValueError("Provide a schema and database for the Snowflake connection")

    @engine.validator  # pyright: ignore
    def validate_engine_url(self, _, engine: Engine) -> None:
        if not engine.url.render_as_string().startswith("snowflake://"):
            raise ValueError("Provide a Snowflake connection")

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        rows = self.execute_query_raw(query)

        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        else:
            return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        sqlalchemy = import_optional_dependency("sqlalchemy")

        with self.engine.connect() as con:
            results = con.execute(sqlalchemy.text(query))

            if results is not None:
                if results.returns_rows:
                    return [{column: value for column, value in result.items()} for result in results]
                else:
                    return None
            else:
                raise ValueError("No results found")

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        sqlalchemy = import_optional_dependency("sqlalchemy")

        try:
            metadata_obj = sqlalchemy.MetaData()
            metadata_obj.reflect(bind=self.engine)
            table = sqlalchemy.Table(table_name, metadata_obj, schema=schema, autoload=True, autoload_with=self.engine)
            return str([(c.name, c.type) for c in table.columns])
        except sqlalchemy.exc.NoSuchTableError:
            return None

connection_func: Callable[[], SnowflakeConnection] = field(kw_only=True) class-attribute instance-attribute

engine: Engine = field(default=Factory(lambda self: import_optional_dependency('sqlalchemy').create_engine('snowflake://not@used/db', creator=self.connection_func), takes_self=True), kw_only=True) class-attribute instance-attribute

execute_query(query)

Source code in griptape/drivers/sql/snowflake_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    rows = self.execute_query_raw(query)

    if rows:
        return [BaseSqlDriver.RowResult(row) for row in rows]
    else:
        return None

execute_query_raw(query)

Source code in griptape/drivers/sql/snowflake_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
    sqlalchemy = import_optional_dependency("sqlalchemy")

    with self.engine.connect() as con:
        results = con.execute(sqlalchemy.text(query))

        if results is not None:
            if results.returns_rows:
                return [{column: value for column, value in result.items()} for result in results]
            else:
                return None
        else:
            raise ValueError("No results found")

get_table_schema(table_name, schema=None)

Source code in griptape/drivers/sql/snowflake_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    sqlalchemy = import_optional_dependency("sqlalchemy")

    try:
        metadata_obj = sqlalchemy.MetaData()
        metadata_obj.reflect(bind=self.engine)
        table = sqlalchemy.Table(table_name, metadata_obj, schema=schema, autoload=True, autoload_with=self.engine)
        return str([(c.name, c.type) for c in table.columns])
    except sqlalchemy.exc.NoSuchTableError:
        return None

validate_connection_func(_, connection_func)

Source code in griptape/drivers/sql/snowflake_sql_driver.py
@connection_func.validator  # pyright: ignore
def validate_connection_func(self, _, connection_func: Callable[[], SnowflakeConnection]) -> None:
    snowflake_connection = connection_func()
    snowflake = import_optional_dependency("snowflake")

    if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection):
        raise ValueError("The connection_func must return a SnowflakeConnection")
    if not snowflake_connection.schema or not snowflake_connection.database:
        raise ValueError("Provide a schema and database for the Snowflake connection")

validate_engine_url(_, engine)

Source code in griptape/drivers/sql/snowflake_sql_driver.py
@engine.validator  # pyright: ignore
def validate_engine_url(self, _, engine: Engine) -> None:
    if not engine.url.render_as_string().startswith("snowflake://"):
        raise ValueError("Provide a Snowflake connection")

SqlDriver

Bases: BaseSqlDriver

Source code in griptape/drivers/sql/sql_driver.py
@define
class SqlDriver(BaseSqlDriver):
    engine_url: str = field(kw_only=True)
    create_engine_params: dict = field(factory=dict, kw_only=True)
    engine: Engine = field(init=False)

    def __attrs_post_init__(self) -> None:
        sqlalchemy = import_optional_dependency("sqlalchemy")

        self.engine = sqlalchemy.create_engine(self.engine_url, **self.create_engine_params)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        rows = self.execute_query_raw(query)

        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        else:
            return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        sqlalchemy = import_optional_dependency("sqlalchemy")

        with self.engine.begin() as con:
            results = con.execute(sqlalchemy.text(query))

            if results is not None:
                if results.returns_rows:
                    return [{column: value for column, value in result.items()} for result in results]
                else:
                    return None
            else:
                raise ValueError("No result found")

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        sqlalchemy = import_optional_dependency("sqlalchemy")

        try:
            table = sqlalchemy.Table(
                table_name,
                sqlalchemy.MetaData(bind=self.engine),
                schema=schema,
                autoload=True,
                autoload_with=self.engine,
            )
            return str([(c.name, c.type) for c in table.columns])
        except sqlalchemy.exc.NoSuchTableError:
            return None

create_engine_params: dict = field(factory=dict, kw_only=True) class-attribute instance-attribute

engine: Engine = field(init=False) class-attribute instance-attribute

engine_url: str = field(kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/drivers/sql/sql_driver.py
def __attrs_post_init__(self) -> None:
    sqlalchemy = import_optional_dependency("sqlalchemy")

    self.engine = sqlalchemy.create_engine(self.engine_url, **self.create_engine_params)

execute_query(query)

Source code in griptape/drivers/sql/sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    rows = self.execute_query_raw(query)

    if rows:
        return [BaseSqlDriver.RowResult(row) for row in rows]
    else:
        return None

execute_query_raw(query)

Source code in griptape/drivers/sql/sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    sqlalchemy = import_optional_dependency("sqlalchemy")

    with self.engine.begin() as con:
        results = con.execute(sqlalchemy.text(query))

        if results is not None:
            if results.returns_rows:
                return [{column: value for column, value in result.items()} for result in results]
            else:
                return None
        else:
            raise ValueError("No result found")

get_table_schema(table_name, schema=None)

Source code in griptape/drivers/sql/sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    sqlalchemy = import_optional_dependency("sqlalchemy")

    try:
        table = sqlalchemy.Table(
            table_name,
            sqlalchemy.MetaData(bind=self.engine),
            schema=schema,
            autoload=True,
            autoload_with=self.engine,
        )
        return str([(c.name, c.type) for c in table.columns])
    except sqlalchemy.exc.NoSuchTableError:
        return None
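
Note that get_table_schema relies on MetaData(bind=...) and autoload=True, both of which were removed in SQLAlchemy 2.0, so this driver assumes a 1.x install. A minimal end-to-end sketch against an in-memory SQLite database (the table and data are illustrative, not from the docs):

from griptape.drivers import SqlDriver

driver = SqlDriver(engine_url="sqlite:///:memory:")

# DDL/DML statements return no rows, so execute_query yields None here:
driver.execute_query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
driver.execute_query("INSERT INTO users (name) VALUES ('Alice'), ('Bob')")

rows = driver.execute_query("SELECT id, name FROM users")  # list of BaseSqlDriver.RowResult
print(driver.get_table_schema("users"))  # e.g. "[('id', INTEGER()), ('name', TEXT())]"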

TrafilaturaWebScraperDriver

Bases: BaseWebScraperDriver

Source code in griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
@define
class TrafilaturaWebScraperDriver(BaseWebScraperDriver):
    include_links: bool = field(default=True, kw_only=True)

    def scrape_url(self, url: str) -> TextArtifact:
        trafilatura = import_optional_dependency("trafilatura")
        use_config = trafilatura.settings.use_config

        config = use_config()
        page = trafilatura.fetch_url(url, no_ssl=True)

        # This disables signal, so that trafilatura can work on any thread:
        # More info: https://trafilatura.readthedocs.io/usage-python.html#disabling-signal
        config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

        # Disable error logging in trafilatura as it sometimes logs errors from lxml, even though
        # the end result of page parsing is successful.
        logging.getLogger("trafilatura").setLevel(logging.FATAL)

        if page is None:
            raise Exception("can't access URL")
        else:
            extracted_page = trafilatura.extract(
                page, include_links=self.include_links, output_format="json", config=config
            )

        if not extracted_page:
            raise Exception("can't extract page")

        text = json.loads(extracted_page).get("text")

        return TextArtifact(text)

scrape_url(url)

Source code in griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
def scrape_url(self, url: str) -> TextArtifact:
    trafilatura = import_optional_dependency("trafilatura")
    use_config = trafilatura.settings.use_config

    config = use_config()
    page = trafilatura.fetch_url(url, no_ssl=True)

    # This disables signal, so that trafilatura can work on any thread:
    # More info: https://trafilatura.readthedocs.io/usage-python.html#disabling-signal
    config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

    # Disable error logging in trafilatura as it sometimes logs errors from lxml, even though
    # the end result of page parsing is successful.
    logging.getLogger("trafilatura").setLevel(logging.FATAL)

    if page is None:
        raise Exception("can't access URL")
    else:
        extracted_page = trafilatura.extract(
            page, include_links=self.include_links, output_format="json", config=config
        )

    if not extracted_page:
        raise Exception("can't extract page")

    text = json.loads(extracted_page).get("text")

    return TextArtifact(text)
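
A minimal usage sketch; the URL is a placeholder. scrape_url raises if the page cannot be fetched or extracted:

from griptape.drivers import TrafilaturaWebScraperDriver

driver = TrafilaturaWebScraperDriver(include_links=False)

artifact = driver.scrape_url("https://example.com")  # placeholder URL
print(artifact.value)  # the extracted page text, wrapped in a TextArtifact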

VoyageAiEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

Name        Type               Description
model       str                VoyageAI embedding model name. Defaults to voyage-large-2.
api_key     Optional[str]      API key to pass directly. Defaults to VOYAGE_API_KEY environment variable.
tokenizer   VoyageAiTokenizer  Optionally provide custom VoyageAiTokenizer.
client      Any                Optionally provide custom VoyageAI Client.
input_type  str                VoyageAI input type. Defaults to document.

Source code in griptape/drivers/embedding/voyageai_embedding_driver.py
@define
class VoyageAiEmbeddingDriver(BaseEmbeddingDriver):
    """
    Attributes:
        model: VoyageAI embedding model name. Defaults to `voyage-large-2`.
        api_key: API key to pass directly. Defaults to `VOYAGE_API_KEY` environment variable.
        tokenizer: Optionally provide custom `VoyageAiTokenizer`.
        client: Optionally provide custom VoyageAI `Client`.
        input_type: VoyageAI input type. Defaults to `document`.
    """

    DEFAULT_MODEL = "voyage-large-2"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    client: Any = field(
        default=Factory(
            lambda self: import_optional_dependency("voyageai").Client(api_key=self.api_key), takes_self=True
        )
    )
    tokenizer: VoyageAiTokenizer = field(
        default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True),
        kw_only=True,
    )
    input_type: str = field(default="document", kw_only=True, metadata={"serializable": True})

    def try_embed_chunk(self, chunk: str) -> list[float]:
        return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0]

DEFAULT_MODEL = 'voyage-large-2' class-attribute instance-attribute

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

client: Any = field(default=Factory(lambda self: import_optional_dependency('voyageai').Client(api_key=self.api_key), takes_self=True)) class-attribute instance-attribute

input_type: str = field(default='document', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: VoyageAiTokenizer = field(default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/voyageai_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0]
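
A minimal usage sketch: with api_key omitted, the underlying voyageai Client falls back to the VOYAGE_API_KEY environment variable, per the attribute docs above.

from griptape.drivers import VoyageAiEmbeddingDriver

driver = VoyageAiEmbeddingDriver(input_type="document")  # api_key read from VOYAGE_API_KEY

vector = driver.try_embed_chunk("Hello, world!")  # list[float]
print(len(vector))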

WebhookEventListenerDriver

Bases: BaseEventListenerDriver

Source code in griptape/drivers/event_listener/webhook_event_listener_driver.py
@define
class WebhookEventListenerDriver(BaseEventListenerDriver):
    webhook_url: str = field(kw_only=True)
    headers: dict = field(default=None, kw_only=True)

    def try_publish_event_payload(self, event_payload: dict) -> None:
        response = requests.post(url=self.webhook_url, json=event_payload, headers=self.headers)
        response.raise_for_status()

headers: dict = field(default=None, kw_only=True) class-attribute instance-attribute

webhook_url: str = field(kw_only=True) class-attribute instance-attribute

try_publish_event_payload(event_payload)

Source code in griptape/drivers/event_listener/webhook_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    response = requests.post(url=self.webhook_url, json=event_payload, headers=self.headers)
    response.raise_for_status()
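
A minimal sketch; the webhook URL and bearer token are placeholders. try_publish_event_payload raises requests.HTTPError on a non-2xx response via raise_for_status:

from griptape.drivers import WebhookEventListenerDriver

driver = WebhookEventListenerDriver(
    webhook_url="https://hooks.example.com/griptape",  # placeholder endpoint
    headers={"Authorization": "Bearer <token>"},       # optional custom headers
)

driver.try_publish_event_payload({"event_type": "FinishStructureRunEvent"})  # illustrative payload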