Skip to content

Drivers

__all__ = ['BasePromptDriver', 'OpenAiChatPromptDriver', 'OpenAiCompletionPromptDriver', 'AzureOpenAiChatPromptDriver', 'AzureOpenAiCompletionPromptDriver', 'CoherePromptDriver', 'HuggingFacePipelinePromptDriver', 'HuggingFaceHubPromptDriver', 'AnthropicPromptDriver', 'AmazonSageMakerPromptDriver', 'AmazonBedrockPromptDriver', 'BaseMultiModelPromptDriver', 'BaseConversationMemoryDriver', 'LocalConversationMemoryDriver', 'AmazonDynamoDbConversationMemoryDriver', 'BaseEmbeddingDriver', 'OpenAiEmbeddingDriver', 'AzureOpenAiEmbeddingDriver', 'BaseMultiModelEmbeddingDriver', 'AmazonSageMakerEmbeddingDriver', 'AmazonBedrockTitanEmbeddingDriver', 'HuggingFaceHubEmbeddingDriver', 'BaseEmbeddingModelDriver', 'SageMakerHuggingFaceEmbeddingModelDriver', 'SageMakerTensorFlowHubEmbeddingModelDriver', 'BaseVectorStoreDriver', 'LocalVectorStoreDriver', 'PineconeVectorStoreDriver', 'MarqoVectorStoreDriver', 'MongoDbAtlasVectorStoreDriver', 'RedisVectorStoreDriver', 'OpenSearchVectorStoreDriver', 'AmazonOpenSearchVectorStoreDriver', 'PgVectorVectorStoreDriver', 'BaseSqlDriver', 'AmazonRedshiftSqlDriver', 'SnowflakeSqlDriver', 'SqlDriver', 'BasePromptModelDriver', 'SageMakerLlamaPromptModelDriver', 'SageMakerFalconPromptModelDriver', 'BedrockTitanPromptModelDriver', 'BedrockClaudePromptModelDriver', 'BedrockJurassicPromptModelDriver', 'BedrockLlamaPromptModelDriver', 'BaseImageGenerationModelDriver', 'BedrockStableDiffusionImageGenerationModelDriver', 'BedrockTitanImageGenerationModelDriver', 'BaseImageGenerationDriver', 'BaseMultiModelImageGenerationDriver', 'OpenAiImageGenerationDriver', 'LeonardoImageGenerationDriver', 'AmazonBedrockImageGenerationDriver', 'AzureOpenAiImageGenerationDriver'] module-attribute

AmazonBedrockImageGenerationDriver

Bases: BaseMultiModelImageGenerationDriver

Driver for image generation models provided by Amazon Bedrock.

Attributes:

- `model` — Bedrock model ID.
- `session` (`Session`) — boto3 session.
- `bedrock_client` (`Any`) — Bedrock runtime client.
- `image_width` (`int`) — Width of output images. Defaults to 512 and must be a multiple of 64.
- `image_height` (`int`) — Height of output images. Defaults to 512 and must be a multiple of 64.
- `seed` (`int | None`) — Optionally provide a consistent seed to generation requests, increasing consistency in output.

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
@define
class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
    """Driver for image generation models provided by Amazon Bedrock.

    Request construction and response parsing are delegated to the
    model-specific `image_generation_model_driver`; this class performs the
    Bedrock runtime invocation.

    Attributes:
        model: Bedrock model ID.
        session: boto3 session.
        bedrock_client: Bedrock runtime client.
        image_width: Width of output images. Defaults to 512 and must be a multiple of 64.
        image_height: Height of output images. Defaults to 512 and must be a multiple of 64.
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client(service_name="bedrock-runtime"), takes_self=True)
    )
    image_width: int = field(default=512, kw_only=True)
    image_height: int = field(default=512, kw_only=True)
    seed: int | None = field(default=None, kw_only=True)

    def try_text_to_image(self, prompts: list[str], negative_prompts: list[str] | None = None) -> ImageArtifact:
        """Generate an image from text prompts at the driver's configured dimensions."""
        request = self.image_generation_model_driver.text_to_image_request_parameters(
            prompts, self.image_width, self.image_height, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            mime_type="image/png",
            width=self.image_width,
            height=self.image_height,
            model=self.model,
        )

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        """Generate a variation of `image` guided by `prompts`."""
        request = self.image_generation_model_driver.image_variation_request_parameters(
            prompts, image=image, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        # Dimensions come from the input image, not image_width/image_height,
        # since the model edits the provided image.
        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            mime_type="image/png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_inpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        """Regenerate the region of `image` selected by `mask` from `prompts`."""
        request = self.image_generation_model_driver.image_inpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            mime_type="image/png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def try_image_outpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        """Extend `image` beyond the region selected by `mask`, guided by `prompts`."""
        request = self.image_generation_model_driver.image_outpainting_request_parameters(
            prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            prompt=", ".join(prompts),
            value=image_bytes,
            mime_type="image/png",
            width=image.width,
            height=image.height,
            model=self.model,
        )

    def _make_request(self, request: dict) -> bytes:
        """Invoke the Bedrock model with `request` and return the generated image bytes.

        Raises:
            ValueError: if the model driver cannot extract an image from the response.
        """
        response = self.bedrock_client.invoke_model(
            body=json.dumps(request), modelId=self.model, accept="application/json", contentType="application/json"
        )

        response_body = json.loads(response.get("body").read())

        try:
            image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
        except Exception as e:
            # Fixed: this helper serves all generation operations (text-to-image,
            # variation, in- and outpainting), so the message must not claim
            # "Inpainting". Chain the cause for easier debugging.
            raise ValueError(f"Image generation failed: {e}") from e

        return image_bytes

bedrock_client: Any = field(default=Factory(lambda : self.session.client(service_name='bedrock-runtime'), takes_self=True)) class-attribute instance-attribute

image_height: int = field(default=512, kw_only=True) class-attribute instance-attribute

image_width: int = field(default=512, kw_only=True) class-attribute instance-attribute

seed: int | None = field(default=None, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_inpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    """Regenerate the region of `image` selected by `mask` from `prompts`.

    Args:
        prompts: Positive prompts describing the desired content.
        image: Source image to modify.
        mask: Mask artifact selecting the region to repaint.
        negative_prompts: Optional prompts describing content to avoid.

    Returns:
        An ImageArtifact with PNG bytes; dimensions mirror the input image.
    """
    request = self.image_generation_model_driver.image_inpainting_request_parameters(
        prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    # Output dimensions are taken from the input image, not the driver's
    # configured width/height, since the model edits the provided image.
    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        mime_type="image/png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_outpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    """Extend `image` beyond the region selected by `mask`, guided by `prompts`.

    Args:
        prompts: Positive prompts describing the desired content.
        image: Source image to extend.
        mask: Mask artifact passed through to the model driver's request builder.
        negative_prompts: Optional prompts describing content to avoid.

    Returns:
        An ImageArtifact with PNG bytes; dimensions mirror the input image.
    """
    request = self.image_generation_model_driver.image_outpainting_request_parameters(
        prompts, image=image, mask=mask, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        mime_type="image/png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    """Generate a variation of `image` guided by `prompts`.

    Args:
        prompts: Positive prompts describing the desired content.
        image: Source image to vary.
        negative_prompts: Optional prompts describing content to avoid.

    Returns:
        An ImageArtifact with PNG bytes; dimensions mirror the input image.
    """
    request = self.image_generation_model_driver.image_variation_request_parameters(
        prompts, image=image, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        mime_type="image/png",
        width=image.width,
        height=image.height,
        model=self.model,
    )

try_text_to_image(prompts, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: list[str] | None = None) -> ImageArtifact:
    """Generate an image from text prompts at the driver's configured dimensions.

    Args:
        prompts: Positive prompts; joined with ", " for the artifact's prompt field.
        negative_prompts: Optional prompts describing content to avoid.

    Returns:
        An ImageArtifact with PNG bytes sized image_width x image_height.
    """
    request = self.image_generation_model_driver.text_to_image_request_parameters(
        prompts, self.image_width, self.image_height, negative_prompts=negative_prompts, seed=self.seed
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        prompt=", ".join(prompts),
        value=image_bytes,
        mime_type="image/png",
        width=self.image_width,
        height=self.image_height,
        model=self.model,
    )

AmazonBedrockPromptDriver

Bases: BaseMultiModelPromptDriver

Source code in griptape/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@define
class AmazonBedrockPromptDriver(BaseMultiModelPromptDriver):
    """Prompt driver for models hosted on Amazon Bedrock.

    Prompt formatting and output parsing are delegated to the model-specific
    `prompt_model_driver` (provided by BaseMultiModelPromptDriver); this class
    performs the Bedrock runtime calls.

    Attributes:
        session: boto3 session used to build the default Bedrock runtime client.
        bedrock_client: Bedrock runtime client.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def _build_payload(self, prompt_stack: PromptStack) -> dict:
        # Shared by try_run and try_stream: merge model params with the model
        # input when the input is a dict (non-dict inputs are params-only).
        model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
        payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
        if isinstance(model_input, dict):
            payload.update(model_input)
        return payload

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Invoke the model once and return the processed output.

        Raises:
            Exception: if the model returns an empty body.
        """
        payload = self._build_payload(prompt_stack)

        response = self.bedrock_client.invoke_model(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = response["body"].read()

        if response_body:
            return self.prompt_model_driver.process_output(response_body)
        else:
            raise Exception("model response is empty")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Invoke the model with a response stream, yielding one artifact per chunk.

        Raises:
            Exception: if the response contains no body stream.
        """
        payload = self._build_payload(prompt_stack)

        response = self.bedrock_client.invoke_model_with_response_stream(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = response["body"]
        if response_body:
            # Iterate the stream we already captured instead of re-indexing the response.
            for chunk in response_body:
                chunk_bytes = chunk["chunk"]["bytes"]
                yield self.prompt_model_driver.process_output(chunk_bytes)
        else:
            raise Exception("model response is empty")

bedrock_client: Any = field(default=Factory(lambda : self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Run a single, non-streaming Bedrock invocation for `prompt_stack`.

    Raises:
        Exception: if the model returns an empty body.
    """
    # The model-specific prompt driver produces both the input and the params;
    # non-dict inputs are ignored here (params alone form the payload).
    model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
    payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
    if isinstance(model_input, dict):
        payload.update(model_input)

    response = self.bedrock_client.invoke_model(
        modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
    )

    response_body = response["body"].read()

    if response_body:
        return self.prompt_model_driver.process_output(response_body)
    else:
        raise Exception("model response is empty")

try_stream(prompt_stack)

Source code in griptape/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Stream a Bedrock invocation, yielding one processed artifact per chunk.

    Raises:
        Exception: if the response contains no body stream.
    """
    model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
    payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
    if isinstance(model_input, dict):
        payload.update(model_input)

    response = self.bedrock_client.invoke_model_with_response_stream(
        modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
    )

    # NOTE(review): response_body is only used for the truthiness check; the
    # loop re-indexes response["body"], which refers to the same stream object.
    response_body = response["body"]
    if response_body:
        for chunk in response["body"]:
            chunk_bytes = chunk["chunk"]["bytes"]
            yield self.prompt_model_driver.process_output(chunk_bytes)
    else:
        raise Exception("model response is empty")

AmazonBedrockTitanEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

- `model` (`str`) — Embedding model name. Defaults to DEFAULT_MODEL.
- `tokenizer` (`BedrockTitanTokenizer`) — Optionally provide custom BedrockTitanTokenizer.
- `session` (`Session`) — Optionally provide custom boto3.Session.
- `bedrock_client` (`Any`) — Optionally provide custom bedrock-runtime client.

Source code in griptape/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@define
class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver backed by the Amazon Bedrock Titan embedding model.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
        session: Optionally provide custom `boto3.Session`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "amazon.titan-embed-text-v1"

    model: str = field(default=DEFAULT_MODEL, kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BedrockTitanTokenizer = field(
        default=Factory(lambda self: BedrockTitanTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed one text chunk and return its embedding vector."""
        request_body = json.dumps({"inputText": chunk})

        response = self.bedrock_client.invoke_model(
            body=request_body, modelId=self.model, accept="application/json", contentType="application/json"
        )

        # The embedding vector lives under the "embedding" key of the JSON body.
        parsed = json.loads(response.get("body").read())
        return parsed.get("embedding")

DEFAULT_MODEL = 'amazon.titan-embed-text-v1' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda : self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BedrockTitanTokenizer = field(default=Factory(lambda : BedrockTitanTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed one text chunk via the Bedrock Titan embedding model.

    Returns:
        The embedding vector from the response's "embedding" key.
    """
    payload = {"inputText": chunk}

    response = self.bedrock_client.invoke_model(
        body=json.dumps(payload), modelId=self.model, accept="application/json", contentType="application/json"
    )
    response_body = json.loads(response.get("body").read())

    return response_body.get("embedding")

AmazonDynamoDbConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source code in griptape/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
@define
class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
    """Persists conversation memory as a single JSON attribute on a DynamoDB item."""

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    table_name: str = field(kw_only=True)
    partition_key: str = field(kw_only=True)
    value_attribute_key: str = field(kw_only=True)
    partition_key_value: str = field(kw_only=True)

    table: Any = field(init=False)

    def __attrs_post_init__(self) -> None:
        # Resolve the table handle once, after attrs has populated the fields.
        self.table = self.session.resource("dynamodb").Table(self.table_name)

    def store(self, memory: ConversationMemory) -> None:
        """Serialize `memory` to JSON and upsert it onto the configured item attribute."""
        item_key = {self.partition_key: self.partition_key_value}

        self.table.update_item(
            Key=item_key,
            UpdateExpression="set #attr = :value",
            ExpressionAttributeNames={"#attr": self.value_attribute_key},
            ExpressionAttributeValues={":value": memory.to_json()},
        )

    def load(self) -> ConversationMemory | None:
        """Fetch and deserialize the stored memory.

        Returns None when the item or its value attribute does not exist.
        """
        response = self.table.get_item(Key={self.partition_key: self.partition_key_value})
        item = response.get("Item")

        if item is None or self.value_attribute_key not in item:
            return None

        memory = ConversationMemory.from_json(item[self.value_attribute_key])
        memory.driver = self
        return memory

partition_key: str = field(kw_only=True) class-attribute instance-attribute

partition_key_value: str = field(kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

table: Any = field(init=False) class-attribute instance-attribute

table_name: str = field(kw_only=True) class-attribute instance-attribute

value_attribute_key: str = field(kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def __attrs_post_init__(self) -> None:
    """Resolve the DynamoDB table handle once all attrs fields are populated."""
    dynamodb = self.session.resource("dynamodb")

    self.table = dynamodb.Table(self.table_name)

load()

Source code in griptape/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def load(self) -> ConversationMemory | None:
    """Fetch and deserialize stored memory; None if the item or attribute is missing."""
    response = self.table.get_item(Key={self.partition_key: self.partition_key_value})

    if "Item" in response and self.value_attribute_key in response["Item"]:
        memory_value = response["Item"][self.value_attribute_key]

        memory = ConversationMemory.from_json(memory_value)

        # Re-attach this driver so subsequent saves go back to the same table/item.
        memory.driver = self

        return memory
    else:
        return None

store(memory)

Source code in griptape/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def store(self, memory: ConversationMemory) -> None:
    self.table.update_item(
        Key={self.partition_key: self.partition_key_value},
        UpdateExpression="set #attr = :value",
        ExpressionAttributeNames={"#attr": self.value_attribute_key},
        ExpressionAttributeValues={":value": memory.to_json()},
    )

AmazonOpenSearchVectorStoreDriver

Bases: OpenSearchVectorStoreDriver

A Vector Store Driver for Amazon OpenSearch.

Attributes:

- `session` (`Session`) — The boto3 session to use.
- `http_auth` (`str | tuple[str, str] | None`) — The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
- `client` (`OpenSearch | None`) — An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.

Source code in griptape/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
@define
class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver):
    """A Vector Store Driver for Amazon OpenSearch.

    Attributes:
        session: The boto3 session to use.
        http_auth: The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
        client: An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.
    """

    # Required boto3 session; supplies credentials and region for the defaults below.
    session: Session = field(kw_only=True)

    # Default: SigV4 signing for the "es" service, built from the session's
    # credentials via requests_aws4auth. NOTE(review): assumes the session can
    # resolve credentials and a region — confirm at call sites.
    http_auth: str | tuple[str, str] | None = field(
        default=Factory(
            lambda self: import_optional_dependency("requests_aws4auth").AWS4Auth(
                self.session.get_credentials().access_key,
                self.session.get_credentials().secret_key,
                self.session.region_name,
                "es",
            ),
            takes_self=True,
        )
    )

    # Default client targets self.host/self.port (presumably inherited from
    # OpenSearchVectorStoreDriver) and uses RequestsHttpConnection so the
    # http_auth signer above is applied to each request.
    client: OpenSearch | None = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").OpenSearch(
                hosts=[{"host": self.host, "port": self.port}],
                http_auth=self.http_auth,
                use_ssl=self.use_ssl,
                verify_certs=self.verify_certs,
                connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection,
            ),
            takes_self=True,
        )
    )

client: OpenSearch | None = field(default=Factory(lambda : import_optional_dependency('opensearchpy').OpenSearch(hosts=[{'host': self.host, 'port': self.port}], http_auth=self.http_auth, use_ssl=self.use_ssl, verify_certs=self.verify_certs, connection_class=import_optional_dependency('opensearchpy').RequestsHttpConnection), takes_self=True)) class-attribute instance-attribute

http_auth: str | tuple[str, str] | None = field(default=Factory(lambda : import_optional_dependency('requests_aws4auth').AWS4Auth(self.session.get_credentials().access_key, self.session.get_credentials().secret_key, self.session.region_name, 'es'), takes_self=True)) class-attribute instance-attribute

session: Session = field(kw_only=True) class-attribute instance-attribute

AmazonRedshiftSqlDriver

Bases: BaseSqlDriver

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
@define
class AmazonRedshiftSqlDriver(BaseSqlDriver):
    """SQL driver backed by the Amazon Redshift Data API.

    Exactly one of `cluster_identifier` (provisioned clusters) or
    `workgroup_name` (Redshift Serverless) must be supplied; the validator
    below enforces this at construction time.
    """

    database: str = field(kw_only=True)
    session: boto3.Session = field(kw_only=True)
    cluster_identifier: str | None = field(default=None, kw_only=True)
    workgroup_name: str | None = field(default=None, kw_only=True)
    db_user: str | None = field(default=None, kw_only=True)
    database_credentials_secret_arn: str | None = field(default=None, kw_only=True)
    wait_for_query_completion_sec: float = field(default=0.3, kw_only=True)
    client: Any = field(
        default=Factory(lambda self: self.session.client("redshift-data"), takes_self=True), kw_only=True
    )

    @workgroup_name.validator  # pyright: ignore
    def validate_params(self, _, workgroup_name: str | None) -> None:
        """Ensure exactly one of `cluster_identifier`/`workgroup_name` is set."""
        if not self.cluster_identifier and not self.workgroup_name:
            raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
        elif self.cluster_identifier and self.workgroup_name:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

    @classmethod
    def _process_rows_from_records(cls, records) -> list[list]:
        # Each cell is a single-key dict (e.g. {"stringValue": "x"}); unwrap the value.
        return [[c[list(c.keys())[0]] for c in r] for r in records]

    @classmethod
    def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
        # Pair column names with positional row values to form dict rows.
        return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

    @classmethod
    def _process_columns_from_column_metadata(cls, meta) -> list:
        return [k["name"] for k in meta]

    @classmethod
    def _post_process(cls, meta, records) -> list[dict[str, Any]]:
        """Convert Data API column metadata plus raw records into dict rows."""
        columns = cls._process_columns_from_column_metadata(meta)
        rows = cls._process_rows_from_records(records)
        return cls._process_cells_from_rows_and_columns(columns, rows)

    def execute_query(self, query: str) -> list[BaseSqlDriver.RowResult] | None:
        """Execute `query`, wrapping each dict row in a RowResult; None on no rows/failure."""
        rows = self.execute_query_raw(query)
        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        else:
            return None

    def execute_query_raw(self, query: str) -> list[dict[str, Any]] | None:
        """Execute `query` via the Data API, polling until the statement completes.

        Returns:
            Dict rows on success; None when the statement fails or is aborted.
        """
        function_kwargs = {"Sql": query, "Database": self.database}
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn

        response = self.client.execute_statement(**function_kwargs)
        response_id = response["Id"]

        statement = self.client.describe_statement(Id=response_id)

        # Poll until the statement leaves every in-flight state.
        while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
            time.sleep(self.wait_for_query_completion_sec)
            statement = self.client.describe_statement(Id=response_id)

        if statement["Status"] == "FINISHED":
            statement_result = self.client.get_statement_result(Id=response_id)
            results = statement_result.get("Records", [])

            while "NextToken" in statement_result:
                statement_result = self.client.get_statement_result(
                    Id=response_id, NextToken=statement_result["NextToken"]
                )
                # Fixed: accumulate records from the paginated statement_result,
                # not from the execute_statement response (which has no Records),
                # so multi-page result sets are no longer truncated to page one.
                results = results + statement_result.get("Records", [])

            return self._post_process(statement_result["ColumnMetadata"], results)

        elif statement["Status"] in ["FAILED", "ABORTED"]:
            return None

    def get_table_schema(self, table: str, schema: str | None = None) -> list[str] | None:
        """Return the column names of `table`, optionally scoped to `schema`.

        The annotation previously claimed `str | None`; the method has always
        returned a list of column names, so the annotation now reflects that.
        """
        function_kwargs = {"Database": self.database, "Table": table}
        if schema:
            function_kwargs["Schema"] = schema
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn
        response = self.client.describe_table(**function_kwargs)
        return [col["name"] for col in response["ColumnList"]]

client: Any = field(default=Factory(lambda : self.session.client('redshift-data'), takes_self=True), kw_only=True) class-attribute instance-attribute

cluster_identifier: str | None = field(default=None, kw_only=True) class-attribute instance-attribute

database: str = field(kw_only=True) class-attribute instance-attribute

database_credentials_secret_arn: str | None = field(default=None, kw_only=True) class-attribute instance-attribute

db_user: str | None = field(default=None, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(kw_only=True) class-attribute instance-attribute

wait_for_query_completion_sec: float = field(default=0.3, kw_only=True) class-attribute instance-attribute

workgroup_name: str | None = field(default=None, kw_only=True) class-attribute instance-attribute

execute_query(query)

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query(self, query: str) -> list[BaseSqlDriver.RowResult] | None:
    """Execute `query`, wrapping each dict row in a RowResult; None on no rows or failure."""
    rows = self.execute_query_raw(query)
    if rows:
        return [BaseSqlDriver.RowResult(row) for row in rows]
    else:
        return None

execute_query_raw(query)

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query_raw(self, query: str) -> list[dict[str, Any]] | None:
    """Execute `query` via the Redshift Data API, polling until completion.

    Returns:
        Dict rows on success; None when the statement fails or is aborted.
    """
    function_kwargs = {"Sql": query, "Database": self.database}
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.execute_statement(**function_kwargs)
    response_id = response["Id"]

    statement = self.client.describe_statement(Id=response_id)

    # Poll until the statement leaves every in-flight state.
    while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
        time.sleep(self.wait_for_query_completion_sec)
        statement = self.client.describe_statement(Id=response_id)

    if statement["Status"] == "FINISHED":
        statement_result = self.client.get_statement_result(Id=response_id)
        results = statement_result.get("Records", [])

        while "NextToken" in statement_result:
            statement_result = self.client.get_statement_result(
                Id=response_id, NextToken=statement_result["NextToken"]
            )
            # Fixed: accumulate records from the paginated statement_result,
            # not from the execute_statement response (which has no Records),
            # so multi-page result sets are no longer truncated to page one.
            results = results + statement_result.get("Records", [])

        return self._post_process(statement_result["ColumnMetadata"], results)

    elif statement["Status"] in ["FAILED", "ABORTED"]:
        return None

get_table_schema(table, schema=None)

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
def get_table_schema(self, table: str, schema: str | None = None) -> list[str] | None:
    """Return the column names of `table`, optionally scoped to `schema`.

    The annotation previously claimed `str | None`; the method has always
    returned a list of column names, so the annotation now reflects that.
    """
    function_kwargs = {"Database": self.database, "Table": table}
    if schema:
        function_kwargs["Schema"] = schema
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn
    response = self.client.describe_table(**function_kwargs)
    return [col["name"] for col in response["ColumnList"]]

validate_params(_, workgroup_name)

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
@workgroup_name.validator  # pyright: ignore
def validate_params(self, _, workgroup_name: str | None) -> None:
    """attrs validator: exactly one of `cluster_identifier` (provisioned) or
    `workgroup_name` (serverless) must be provided."""
    if not self.cluster_identifier and not self.workgroup_name:
        raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
    elif self.cluster_identifier and self.workgroup_name:
        raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

AmazonSageMakerEmbeddingDriver

Bases: BaseMultiModelEmbeddingDriver

Source code in griptape/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py
@define
class AmazonSageMakerEmbeddingDriver(BaseMultiModelEmbeddingDriver):
    """Embedding driver backed by a SageMaker inference endpoint.

    Payload construction and response parsing are delegated to the configured
    `embedding_model_driver`, so this class can front any embedding model
    hosted on a SageMaker endpoint.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
    )
    embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True)

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed one text chunk by invoking the SageMaker endpoint."""
        model_params = self.embedding_model_driver.chunk_to_model_params(chunk)
        raw_response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.model, ContentType="application/x-text", Body=json.dumps(model_params).encode("utf-8")
        )
        parsed = json.loads(raw_response.get("Body").read().decode("utf-8"))
        return self.embedding_model_driver.process_output(parsed)

embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda : self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed one text chunk by invoking the configured SageMaker endpoint.

    The endpoint payload and response format are owned by
    `self.embedding_model_driver`.
    """
    model_params = self.embedding_model_driver.chunk_to_model_params(chunk)
    raw_response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.model,
        ContentType="application/x-text",
        Body=json.dumps(model_params).encode("utf-8"),
    )
    parsed = json.loads(raw_response.get("Body").read().decode("utf-8"))
    return self.embedding_model_driver.process_output(parsed)

AmazonSageMakerPromptDriver

Bases: BaseMultiModelPromptDriver

Source code in griptape/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@define
class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver):
    """Prompt driver for models hosted on SageMaker inference endpoints.

    Prompt formatting and response parsing are delegated to the configured
    `prompt_model_driver`. Streaming is not supported.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
    )
    custom_attributes: str = field(default="accept_eula=true", kw_only=True)
    stream: bool = field(default=False, kw_only=True)

    @stream.validator  # pyright: ignore
    def validate_stream(self, _, stream):
        # Endpoint invocation here is strictly request/response.
        if stream:
            raise ValueError("streaming is not supported")

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Invoke the endpoint once and return the parsed completion.

        Raises:
            Exception: If the endpoint returns an empty body.
        """
        request_body = json.dumps(
            {
                "inputs": self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack),
                "parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack),
            }
        )
        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.model,
            ContentType="application/json",
            Body=request_body,
            CustomAttributes=self.custom_attributes,
        )
        decoded_body = json.loads(response["Body"].read().decode("utf8"))
        if not decoded_body:
            raise Exception("model response is empty")
        return self.prompt_model_driver.process_output(decoded_body)

    def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
        raise NotImplementedError("streaming is not supported")

custom_attributes: str = field(default='accept_eula=true', kw_only=True) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda : self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Invoke the SageMaker endpoint once and return the parsed completion.

    Raises:
        Exception: If the endpoint returns an empty body.
    """
    request_body = json.dumps(
        {
            "inputs": self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack),
            "parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack),
        }
    )
    response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.model,
        ContentType="application/json",
        Body=request_body,
        CustomAttributes=self.custom_attributes,
    )
    decoded_body = json.loads(response["Body"].read().decode("utf8"))
    if not decoded_body:
        raise Exception("model response is empty")
    return self.prompt_model_driver.process_output(decoded_body)

try_stream(_)

Source code in griptape/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
    # Streaming is not supported by this driver; the `stream` field validator
    # rejects stream=True at construction time as well.
    raise NotImplementedError("streaming is not supported")

validate_stream(_, stream)

Source code in griptape/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@stream.validator  # pyright: ignore
def validate_stream(self, _, stream):
    # Reject stream=True at construction time: this driver is request/response only.
    if stream:
        raise ValueError("streaming is not supported")

AnthropicPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_key str

Anthropic API key.

model str

Anthropic model name.

client Anthropic

Custom Anthropic client.

tokenizer AnthropicTokenizer

Custom AnthropicTokenizer.

Source code in griptape/griptape/drivers/prompt/anthropic_prompt_driver.py
@define
class AnthropicPromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
        tokenizer: Custom `AnthropicTokenizer`.
    """

    api_key: str = field(kw_only=True)
    model: str = field(kw_only=True)
    client: Anthropic = field(
        default=Factory(
            lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )
    tokenizer: AnthropicTokenizer = field(
        default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Issue one non-streaming completion request and return its text."""
        response = self.client.completions.create(**self._base_params(prompt_stack))

        return TextArtifact(value=response.completion)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Issue a streaming completion request, yielding one artifact per chunk."""
        response = self.client.completions.create(**self._base_params(prompt_stack), stream=True)

        for chunk in response:
            yield TextArtifact(value=chunk.completion)

    def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
        """Render a prompt stack into Anthropic's Human/Assistant prompt format."""
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_assistant():
                prompt_lines.append(f"\n\nAssistant: {i.content}")
            elif i.is_user():
                prompt_lines.append(f"\n\nHuman: {i.content}")
            elif i.is_system():
                # claude-2.1 accepts a bare system prompt; older models emulate
                # it with a Human/Assistant preamble.
                if self.model == "claude-2.1":
                    prompt_lines.append(f"{i.content}")
                else:
                    prompt_lines.append(f"\n\nHuman: {i.content}")
                    prompt_lines.append("\n\nAssistant:")
            else:
                prompt_lines.append(f"\n\nHuman: {i.content}")

        prompt_lines.append("\n\nAssistant:")

        return "".join(prompt_lines)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build keyword arguments shared by streaming and non-streaming calls."""
        # Render the prompt stack once and reuse it for both the "prompt" value
        # and the token budget. The original rendered it twice, doing the full
        # conversion work a second time for the same result.
        prompt = self.prompt_stack_to_string(prompt_stack)

        return {
            "prompt": prompt,
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "max_tokens_to_sample": self.max_output_tokens(prompt),
        }

api_key: str = field(kw_only=True) class-attribute instance-attribute

client: Anthropic = field(default=Factory(lambda : import_optional_dependency('anthropic').Anthropic(api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True) class-attribute instance-attribute

tokenizer: AnthropicTokenizer = field(default=Factory(lambda : AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

default_prompt_stack_to_string_converter(prompt_stack)

Source code in griptape/griptape/drivers/prompt/anthropic_prompt_driver.py
def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
    """Render a prompt stack into Anthropic's Human/Assistant text format."""
    segments = []

    for entry in prompt_stack.inputs:
        if entry.is_assistant():
            segments.append(f"\n\nAssistant: {entry.content}")
        elif entry.is_user():
            segments.append(f"\n\nHuman: {entry.content}")
        elif entry.is_system():
            # claude-2.1 accepts a bare system prompt; older models emulate it
            # with a Human/Assistant preamble.
            if self.model == "claude-2.1":
                segments.append(f"{entry.content}")
            else:
                segments.append(f"\n\nHuman: {entry.content}")
                segments.append("\n\nAssistant:")
        else:
            segments.append(f"\n\nHuman: {entry.content}")

    # Always end with an Assistant turn to prompt the model's reply.
    segments.append("\n\nAssistant:")

    return "".join(segments)

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/anthropic_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Issue one non-streaming completion request and wrap the resulting text."""
    completion = self.client.completions.create(**self._base_params(prompt_stack))
    return TextArtifact(value=completion.completion)

try_stream(prompt_stack)

Source code in griptape/griptape/drivers/prompt/anthropic_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Issue a streaming completion request, yielding one artifact per event."""
    event_stream = self.client.completions.create(**self._base_params(prompt_stack), stream=True)
    for event in event_stream:
        yield TextArtifact(value=event.completion)

AzureOpenAiChatPromptDriver

Bases: OpenAiChatPromptDriver

Attributes:

Name Type Description
azure_deployment str

An Azure OpenAi deployment id.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[str]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@define
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True)
    azure_endpoint: str = field(kw_only=True)
    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
    api_version: str = field(default="2023-05-15", kw_only=True)
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build request params, stripping options Azure does not accept."""
        params = super()._base_params(prompt_stack)
        # TODO: Add `seed` parameter once Azure supports it.
        # pop() (vs del) stays safe if the parent ever stops emitting `seed`.
        params.pop("seed", None)

        return params

api_version: str = field(default='2023-05-15', kw_only=True) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda : openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

AzureOpenAiCompletionPromptDriver

Bases: OpenAiCompletionPromptDriver

Attributes:

Name Type Description
azure_deployment str

An Azure OpenAi deployment id.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[str]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py
@define
class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True)
    azure_endpoint: str = field(kw_only=True)
    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
    api_version: str = field(default="2023-05-15", kw_only=True)
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                # Fix: forward the AD-token fields to the client. The original
                # declared `azure_ad_token`/`azure_ad_token_provider` but never
                # passed them, unlike the sibling Azure chat/embedding/image
                # drivers, silently breaking AD-token authentication.
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2023-05-15', kw_only=True) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda : openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment), takes_self=True)) class-attribute instance-attribute

AzureOpenAiEmbeddingDriver

Bases: OpenAiEmbeddingDriver

Attributes:

Name Type Description
azure_deployment str

An Azure OpenAi deployment id.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token str | None

An optional Azure Active Directory token.

azure_ad_token_provider str | None

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

tokenizer OpenAiTokenizer

An OpenAiTokenizer.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/griptape/drivers/embedding/azure_openai_embedding_driver.py
@define
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        tokenizer: An `OpenAiTokenizer`.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True)
    azure_endpoint: str = field(kw_only=True)
    # AD-token auth is optional; the inherited api_key may be used instead.
    azure_ad_token: str | None = field(kw_only=True, default=None)
    azure_ad_token_provider: str | None = field(kw_only=True, default=None)
    api_version: str = field(default="2023-05-15", kw_only=True)
    # Tokenizer defaults to one built for the configured model.
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    # Default client factory wires all of the connection fields above.
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2023-05-15', kw_only=True) class-attribute instance-attribute

azure_ad_token: str | None = field(kw_only=True, default=None) class-attribute instance-attribute

azure_ad_token_provider: str | None = field(kw_only=True, default=None) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda : openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda : OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

AzureOpenAiImageGenerationDriver

Bases: OpenAiImageGenerationDriver

Driver for Azure-hosted OpenAI image generation API.

Attributes:

Name Type Description
azure_deployment str

An Azure OpenAi deployment id.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token str | None

An optional Azure Active Directory token.

azure_ad_token_provider str | None

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/griptape/drivers/image_generation/azure_openai_image_generation_driver.py
@define
class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
    """Driver for Azure-hosted OpenAI image generation API.

    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True)
    azure_endpoint: str = field(kw_only=True)
    # AD-token auth is optional; the inherited api_key may be used instead.
    azure_ad_token: str | None = field(kw_only=True, default=None)
    azure_ad_token_provider: str | None = field(kw_only=True, default=None)
    # Newer default than the text drivers' "2023-05-15"; image APIs are preview.
    api_version: str = field(default="2023-12-01-preview", kw_only=True)
    # Default client factory wires all of the connection fields above.
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2023-12-01-preview', kw_only=True) class-attribute instance-attribute

azure_ad_token: str | None = field(kw_only=True, default=None) class-attribute instance-attribute

azure_ad_token_provider: str | None = field(kw_only=True, default=None) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda : openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

BaseConversationMemoryDriver

Bases: ABC

Source code in griptape/griptape/drivers/memory/conversation/base_conversation_memory_driver.py
class BaseConversationMemoryDriver(ABC):
    """Abstract interface for persisting and retrieving conversation memory.

    Argument shapes are implementation-defined; see concrete drivers.
    """

    @abstractmethod
    def store(self, *args, **kwargs) -> None:
        """Persist conversation memory to the backing store."""
        ...

    @abstractmethod
    def load(self, *args, **kwargs) -> ConversationMemory | None:
        """Load conversation memory from the backing store, or None if absent."""
        ...

load(*args, **kwargs) abstractmethod

Source code in griptape/griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def load(self, *args, **kwargs) -> ConversationMemory | None:
    """Load conversation memory from the backing store, or None if absent."""
    ...

store(*args, **kwargs) abstractmethod

Source code in griptape/griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def store(self, *args, **kwargs) -> None:
    """Persist conversation memory to the backing store."""
    ...

BaseEmbeddingDriver

Bases: ExponentialBackoffMixin, ABC

Attributes:

Name Type Description
model str

The name of the model to use.

tokenizer BaseTokenizer | None

An instance of BaseTokenizer to use when calculating tokens.

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
@define
class BaseEmbeddingDriver(ExponentialBackoffMixin, ABC):
    """
    Attributes:
        model: The name of the model to use.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
    """

    model: str = field(kw_only=True)
    tokenizer: BaseTokenizer | None = field(default=None, kw_only=True)
    # NOTE(review): only assigned in __attrs_post_init__ when a tokenizer is
    # configured; accessing it without one raises AttributeError.
    chunker: BaseChunker = field(init=False)

    def __attrs_post_init__(self) -> None:
        # Build a chunker only when token counting is possible.
        if self.tokenizer:
            self.chunker = TextChunker(tokenizer=self.tokenizer)

    def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
        """Embed the artifact's text content (delegates to embed_string)."""
        return self.embed_string(artifact.to_text())

    def embed_string(self, string: str) -> list[float]:
        """Embed a string, chunking it first when it exceeds the tokenizer limit.

        Retries transient failures via ExponentialBackoffMixin.retrying().
        """
        for attempt in self.retrying():
            with attempt:
                if self.tokenizer and self.tokenizer.count_tokens(string) > self.tokenizer.max_tokens:
                    return self._embed_long_string(string)
                else:
                    return self.try_embed_chunk(string)

        # for/else: reached only if retrying() stops yielding attempts
        # without a successful return.
        else:
            raise RuntimeError("Failed to embed string.")

    @abstractmethod
    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single chunk; implemented by concrete drivers."""
        ...

    def _embed_long_string(self, string: str) -> list[float]:
        """Embeds a string that is too long to embed in one go.

        Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
        """
        chunks = self.chunker.chunk(string)

        embedding_chunks = []
        length_chunks = []
        for chunk in chunks:
            embedding_chunks.append(self.try_embed_chunk(chunk.value))
            length_chunks.append(len(chunk))

        # generate weighted averages
        embedding_chunks = np.average(embedding_chunks, axis=0, weights=length_chunks)

        # normalize length to 1
        embedding_chunks = embedding_chunks / np.linalg.norm(embedding_chunks)

        return embedding_chunks.tolist()

chunker: BaseChunker = field(init=False) class-attribute instance-attribute

model: str = field(kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer | None = field(default=None, kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
def __attrs_post_init__(self) -> None:
    # Build a chunker only when a tokenizer is configured; otherwise
    # self.chunker remains unset (AttributeError if accessed).
    if self.tokenizer:
        self.chunker = TextChunker(tokenizer=self.tokenizer)

embed_string(string)

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
def embed_string(self, string: str) -> list[float]:
    """Embed `string`, chunking first when it exceeds the tokenizer's max tokens.

    Retries via ExponentialBackoffMixin.retrying(); raises RuntimeError if
    no attempt returns successfully.
    """
    for attempt in self.retrying():
        with attempt:
            if self.tokenizer and self.tokenizer.count_tokens(string) > self.tokenizer.max_tokens:
                return self._embed_long_string(string)
            else:
                return self.try_embed_chunk(string)

    # for/else: reached only when the retry loop ends without returning.
    else:
        raise RuntimeError("Failed to embed string.")

embed_text_artifact(artifact)

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
    """Embed the artifact's text content (delegates to embed_string)."""
    return self.embed_string(artifact.to_text())

try_embed_chunk(chunk) abstractmethod

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
@abstractmethod
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single chunk; implemented by concrete drivers."""
    ...

BaseEmbeddingModelDriver

Bases: ABC

Source code in griptape/griptape/drivers/embedding_model/base_embedding_model_driver.py
@define
class BaseEmbeddingModelDriver(ABC):
    """Translates between raw text chunks and a specific model's I/O format."""

    @abstractmethod
    def chunk_to_model_params(self, chunk: str) -> dict:
        """Build the model-specific request payload for one text chunk."""
        ...

    @abstractmethod
    def process_output(self, output: dict) -> list[float]:
        """Extract the embedding vector from the model's raw response."""
        ...

chunk_to_model_params(chunk) abstractmethod

Source code in griptape/griptape/drivers/embedding_model/base_embedding_model_driver.py
7
8
9
@abstractmethod
def chunk_to_model_params(self, chunk: str) -> dict:
    """Build the model-specific request payload for one text chunk."""
    ...

process_output(output) abstractmethod

Source code in griptape/griptape/drivers/embedding_model/base_embedding_model_driver.py
@abstractmethod
def process_output(self, output: dict) -> list[float]:
    """Extract the embedding vector from the model's raw response."""
    ...

BaseImageGenerationDriver

Bases: ExponentialBackoffMixin, ABC

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
@define
class BaseImageGenerationDriver(ExponentialBackoffMixin, ABC):
    """Base class for image generation drivers.

    Each public run_* method wraps the corresponding abstract try_* hook in
    retry logic and publishes start/finish events to the attached structure.
    """

    model: str = field(kw_only=True)
    # Optional owning structure; used only for event publishing.
    structure: Structure | None = field(default=None, kw_only=True)

    def before_run(self, prompts: list[str], negative_prompts: list[str] | None = None) -> None:
        """Publish a StartImageGenerationEvent when a structure is attached."""
        if self.structure:
            self.structure.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts))

    def after_run(self) -> None:
        """Publish a FinishImageGenerationEvent when a structure is attached."""
        if self.structure:
            self.structure.publish_event(FinishImageGenerationEvent())

    def run_text_to_image(self, prompts: list[str], negative_prompts: list[str] | None = None) -> ImageArtifact:
        """Generate an image from text prompts, retrying transient failures."""
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_text_to_image(prompts, negative_prompts)
                self.after_run()

                return result

        # for/else: reached only if retrying() yields no successful attempt.
        else:
            raise Exception("Failed to run text to image generation")

    def run_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        """Generate a variation of `image` guided by prompts, with retries."""
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_image_variation(prompts, image, negative_prompts)
                self.after_run()

                return result

        # for/else: reached only if retrying() yields no successful attempt.
        else:
            raise Exception("Failed to generate image variations")

    def run_image_inpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        """Regenerate the masked region of `image` from prompts, with retries."""
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_image_inpainting(prompts, image, mask, negative_prompts)
                self.after_run()

                return result

        # for/else: reached only if retrying() yields no successful attempt.
        else:
            raise Exception("Failed to run image inpainting")

    def run_image_outpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        """Extend `image` beyond the masked region from prompts, with retries."""
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_image_outpainting(prompts, image, mask, negative_prompts)
                self.after_run()

                return result

        # for/else: reached only if retrying() yields no successful attempt.
        else:
            raise Exception("Failed to run image outpainting")

    @abstractmethod
    def try_text_to_image(self, prompts: list[str], negative_prompts: list[str] | None = None) -> ImageArtifact:
        """Provider-specific text-to-image call; implemented by subclasses."""
        ...

    @abstractmethod
    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        """Provider-specific image-variation call; implemented by subclasses."""
        ...

    @abstractmethod
    def try_image_inpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        """Provider-specific inpainting call; implemented by subclasses."""
        ...

    @abstractmethod
    def try_image_outpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
    ) -> ImageArtifact:
        """Provider-specific outpainting call; implemented by subclasses."""
        ...

model: str = field(kw_only=True) class-attribute instance-attribute

structure: Structure | None = field(default=None, kw_only=True) class-attribute instance-attribute

after_run()

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
def after_run(self) -> None:
    """Publish a FinishImageGenerationEvent when a structure is attached."""
    if self.structure:
        self.structure.publish_event(FinishImageGenerationEvent())

before_run(prompts, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
def before_run(self, prompts: list[str], negative_prompts: list[str] | None = None) -> None:
    """Publish a StartImageGenerationEvent when a structure is attached."""
    if self.structure:
        self.structure.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts))

run_image_inpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_inpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    """Regenerate the masked region of `image` from prompts, with retries."""
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_image_inpainting(prompts, image, mask, negative_prompts)
            self.after_run()

            return result

    # for/else: reached only if retrying() yields no successful attempt.
    else:
        raise Exception("Failed to run image inpainting")

run_image_outpainting(prompts, image, mask, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_outpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_image_outpainting(prompts, image, mask, negative_prompts)
            self.after_run()

            return result

    else:
        raise Exception("Failed to run image outpainting")

run_image_variation(prompts, image, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_image_variation(prompts, image, negative_prompts)
            self.after_run()

            return result

    else:
        raise Exception("Failed to generate image variations")

run_text_to_image(prompts, negative_prompts=None)

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
def run_text_to_image(self, prompts: list[str], negative_prompts: list[str] | None = None) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_text_to_image(prompts, negative_prompts)
            self.after_run()

            return result

    else:
        raise Exception("Failed to run text to image generation")

try_image_inpainting(prompts, image, mask, negative_prompts=None) abstractmethod

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_inpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    ...

try_image_outpainting(prompts, image, mask, negative_prompts=None) abstractmethod

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_outpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    ...

try_image_variation(prompts, image, negative_prompts=None) abstractmethod

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: list[str] | None = None
) -> ImageArtifact:
    ...

try_text_to_image(prompts, negative_prompts=None) abstractmethod

Source code in griptape/griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_text_to_image(self, prompts: list[str], negative_prompts: list[str] | None = None) -> ImageArtifact:
    ...

BaseImageGenerationModelDriver

Bases: ABC

Source code in griptape/griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@define
class BaseImageGenerationModelDriver(ABC):
    """Base class for Image Generation Model Drivers.

    Implementations translate generic image generation requests into the
    parameter format required by a specific hosted model, and extract the
    generated image bytes from that model's response.
    """

    @abstractmethod
    def get_generated_image(self, response: dict) -> bytes:
        """Extract and return the generated image bytes from a model response dict."""
        ...

    @abstractmethod
    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: list[str] | None = None,
        seed: int | None = None,
    ) -> dict[str, Any]:
        """Build the model-specific request parameters for a text-to-image generation."""
        ...

    @abstractmethod
    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: list[str] | None = None,
        seed: int | None = None,
    ) -> dict[str, Any]:
        """Build the model-specific request parameters for an image variation generation."""
        ...

    @abstractmethod
    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: list[str] | None = None,
        seed: int | None = None,
    ) -> dict[str, Any]:
        """Build the model-specific request parameters for an inpainting generation (mask selects the region to modify)."""
        ...

    @abstractmethod
    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: list[str] | None = None,
        seed: int | None = None,
    ) -> dict[str, Any]:
        """Build the model-specific request parameters for an outpainting generation (mask selects the region to preserve — TODO confirm per model)."""
        ...

get_generated_image(response) abstractmethod

Source code in griptape/griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def get_generated_image(self, response: dict) -> bytes:
    ...

image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None) abstractmethod

Source code in griptape/griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: list[str] | None = None,
    seed: int | None = None,
) -> dict[str, Any]:
    ...

image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None) abstractmethod

Source code in griptape/griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: list[str] | None = None,
    seed: int | None = None,
) -> dict[str, Any]:
    ...

image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None) abstractmethod

Source code in griptape/griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: list[str] | None = None,
    seed: int | None = None,
) -> dict[str, Any]:
    ...

text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None) abstractmethod

Source code in griptape/griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: list[str] | None = None,
    seed: int | None = None,
) -> dict[str, Any]:
    ...

BaseMultiModelEmbeddingDriver

Bases: BaseEmbeddingDriver, ABC

Source code in griptape/griptape/drivers/embedding/base_multi_model_embedding_driver.py
@define
class BaseMultiModelEmbeddingDriver(BaseEmbeddingDriver, ABC):
    """Embedding Driver for platforms that host many embedding models.

    Instances of this Embedding Driver require an Embedding Model Driver which is used
    to structure requests in the format required by the underlying model.
    """

    # Embedding Model Driver to use.
    embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True)

embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True) class-attribute instance-attribute

BaseMultiModelImageGenerationDriver

Bases: BaseImageGenerationDriver, ABC

Image Generation Driver for platforms like Amazon Bedrock that host many LLM models.

Instances of this Image Generation Driver require an Image Generation Model Driver which is used to structure the image generation request in the format required by the model and to process the output.

Attributes:

Name Type Description
image_generation_model_driver BaseImageGenerationModelDriver

Image Model Driver to use.

Source code in griptape/griptape/drivers/image_generation/base_multi_model_image_generation_driver.py
@define
class BaseMultiModelImageGenerationDriver(BaseImageGenerationDriver, ABC):
    """Image Generation Driver for platforms like Amazon Bedrock that host many LLM models.

    Instances of this Image Generation Driver require an Image Generation Model Driver which is used to structure the
    image generation request in the format required by the model and to process the output.

    Attributes:
        image_generation_model_driver: Image Model Driver to use.
    """

    # Adapts generic generation requests to the hosted model's format.
    image_generation_model_driver: BaseImageGenerationModelDriver = field(kw_only=True)

image_generation_model_driver: BaseImageGenerationModelDriver = field(kw_only=True) class-attribute instance-attribute

BaseMultiModelPromptDriver

Bases: BasePromptDriver, ABC

Prompt Driver for platforms like Amazon SageMaker, and Amazon Bedrock that host many LLM models.

Instances of this Prompt Driver require a Prompt Model Driver which is used to convert the prompt stack into a model input and parameters, and to process the model output.

Attributes:

Name Type Description
model

Name of the model to use.

tokenizer BaseTokenizer | None

Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver.

prompt_model_driver BasePromptModelDriver

Prompt Model Driver to use.

Source code in griptape/griptape/drivers/prompt/base_multi_model_prompt_driver.py
@define
class BaseMultiModelPromptDriver(BasePromptDriver, ABC):
    """Prompt Driver for platforms like Amazon SageMaker, and Amazon Bedrock that host many LLM models.

    Instances of this Prompt Driver require a Prompt Model Driver which is used to convert the prompt stack
    into a model input and parameters, and to process the model output.

    Attributes:
        model: Name of the model to use.
        tokenizer: Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver.
        prompt_model_driver: Prompt Model Driver to use.
    """

    tokenizer: BaseTokenizer | None = field(default=None, kw_only=True)
    prompt_model_driver: BasePromptModelDriver = field(kw_only=True)
    stream: bool = field(default=False, kw_only=True)

    @stream.validator  # pyright: ignore
    def validate_stream(self, _, stream):
        # attrs validator: reject stream=True when the configured model driver cannot stream.
        if stream and not self.prompt_model_driver.supports_streaming:
            raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming")

    def __attrs_post_init__(self) -> None:
        # Give the model driver a back-reference so it can read this driver's settings.
        self.prompt_model_driver.prompt_driver = self

        # Fall back to the model driver's tokenizer when none was supplied explicitly.
        if not self.tokenizer:
            self.tokenizer = self.prompt_model_driver.tokenizer

prompt_model_driver: BasePromptModelDriver = field(kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer | None = field(default=None, kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/griptape/drivers/prompt/base_multi_model_prompt_driver.py
def __attrs_post_init__(self) -> None:
    self.prompt_model_driver.prompt_driver = self

    if not self.tokenizer:
        self.tokenizer = self.prompt_model_driver.tokenizer

validate_stream(_, stream)

Source code in griptape/griptape/drivers/prompt/base_multi_model_prompt_driver.py
@stream.validator  # pyright: ignore
def validate_stream(self, _, stream):
    if stream and not self.prompt_model_driver.supports_streaming:
        raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming")

BasePromptDriver

Bases: ExponentialBackoffMixin, ABC

Base class for Prompt Drivers.

Attributes:

Name Type Description
temperature float

The temperature to use for the completion.

max_tokens int | None

The maximum number of tokens to generate. If not specified, the value will be determined automatically by the tokenizer.

structure Structure | None

An optional Structure to publish events to.

prompt_stack_to_string Callable[[PromptStack], str]

A function that converts a PromptStack to a string.

ignored_exception_types tuple[type[Exception], ...]

A tuple of exception types to ignore.

model str

The model name.

tokenizer BaseTokenizer

An instance of BaseTokenizer to use when calculating tokens.

stream bool

Whether to stream the completion or not. CompletionChunkEvents will be published to the Structure if one is provided.

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
@define
class BasePromptDriver(ExponentialBackoffMixin, ABC):
    """Base class for Prompt Drivers.

    Attributes:
        temperature: The temperature to use for the completion.
        max_tokens: The maximum number of tokens to generate. If not specified, the value will be determined automatically by the tokenizer.
        structure: An optional `Structure` to publish events to.
        prompt_stack_to_string: A function that converts a `PromptStack` to a string.
        ignored_exception_types: A tuple of exception types to ignore.
        model: The model name.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
        stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
    """

    temperature: float = field(default=0.1, kw_only=True)
    max_tokens: int | None = field(default=None, kw_only=True)
    structure: Structure | None = field(default=None, kw_only=True)
    prompt_stack_to_string: Callable[[PromptStack], str] = field(
        default=Factory(lambda self: self.default_prompt_stack_to_string_converter, takes_self=True), kw_only=True
    )
    # The trailing comma makes the default an actual one-element tuple, matching the
    # annotation; the bare parenthesized form `(ImportError)` evaluates to the type itself.
    ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (ImportError,)), kw_only=True)
    model: str
    tokenizer: BaseTokenizer
    stream: bool = field(default=False, kw_only=True)

    def max_output_tokens(self, text: str | list) -> int:
        """Return the number of output tokens available for `text`, capped by `max_tokens` when set."""
        tokens_left = self.tokenizer.count_tokens_left(text)

        if self.max_tokens:
            return min(self.max_tokens, tokens_left)
        else:
            return tokens_left

    def token_count(self, prompt_stack: PromptStack) -> int:
        """Return the token count of the stringified prompt stack."""
        return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack))

    def before_run(self, prompt_stack: PromptStack) -> None:
        """Publish a `StartPromptEvent` to the structure, if one is attached."""
        if self.structure:
            self.structure.publish_event(
                StartPromptEvent(
                    token_count=self.token_count(prompt_stack),
                    prompt_stack=prompt_stack,
                    prompt=self.prompt_stack_to_string(prompt_stack),
                )
            )

    def after_run(self, result: TextArtifact) -> None:
        """Publish a `FinishPromptEvent` to the structure, if one is attached."""
        if self.structure:
            self.structure.publish_event(
                FinishPromptEvent(token_count=result.token_count(self.tokenizer), result=result.value)
            )

    def run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Run the prompt with retries, returning the stripped completion as a `TextArtifact`.

        Raises:
            Exception: if every retry attempt fails.
        """
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompt_stack)

                if self.stream:
                    tokens = []
                    completion_chunks = self.try_stream(prompt_stack)
                    for chunk in completion_chunks:
                        # Only publish chunk events when a structure is attached; the
                        # unguarded call previously raised AttributeError for structure=None.
                        if self.structure:
                            self.structure.publish_event(CompletionChunkEvent(token=chunk.value))
                        tokens.append(chunk.value)
                    result = TextArtifact(value="".join(tokens).strip())
                else:
                    result = self.try_run(prompt_stack)
                    result.value = result.value.strip()

                self.after_run(result)

                return result
        else:
            raise Exception("prompt driver failed after all retry attempts")

    def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
        """Render the prompt stack as a plain transcript ending with an "Assistant:" cue."""
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_user():
                prompt_lines.append(f"User: {i.content}")
            elif i.is_assistant():
                prompt_lines.append(f"Assistant: {i.content}")
            else:
                prompt_lines.append(i.content)

        prompt_lines.append("Assistant:")

        return "\n\n".join(prompt_lines)

    @abstractmethod
    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Run the prompt once; implemented by concrete drivers."""
        ...

    @abstractmethod
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Stream the prompt once, yielding completion chunks; implemented by concrete drivers."""
        ...

ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda : ImportError), kw_only=True) class-attribute instance-attribute

max_tokens: int | None = field(default=None, kw_only=True) class-attribute instance-attribute

model: str instance-attribute

prompt_stack_to_string: Callable[[PromptStack], str] = field(default=Factory(lambda : self.default_prompt_stack_to_string_converter, takes_self=True), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True) class-attribute instance-attribute

structure: Structure | None = field(default=None, kw_only=True) class-attribute instance-attribute

temperature: float = field(default=0.1, kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer instance-attribute

after_run(result)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def after_run(self, result: TextArtifact) -> None:
    if self.structure:
        self.structure.publish_event(
            FinishPromptEvent(token_count=result.token_count(self.tokenizer), result=result.value)
        )

before_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def before_run(self, prompt_stack: PromptStack) -> None:
    if self.structure:
        self.structure.publish_event(
            StartPromptEvent(
                token_count=self.token_count(prompt_stack),
                prompt_stack=prompt_stack,
                prompt=self.prompt_stack_to_string(prompt_stack),
            )
        )

default_prompt_stack_to_string_converter(prompt_stack)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_user():
            prompt_lines.append(f"User: {i.content}")
        elif i.is_assistant():
            prompt_lines.append(f"Assistant: {i.content}")
        else:
            prompt_lines.append(i.content)

    prompt_lines.append("Assistant:")

    return "\n\n".join(prompt_lines)

max_output_tokens(text)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def max_output_tokens(self, text: str | list) -> int:
    tokens_left = self.tokenizer.count_tokens_left(text)

    if self.max_tokens:
        return min(self.max_tokens, tokens_left)
    else:
        return tokens_left

run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def run(self, prompt_stack: PromptStack) -> TextArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompt_stack)

            if self.stream:
                tokens = []
                completion_chunks = self.try_stream(prompt_stack)
                for chunk in completion_chunks:
                    self.structure.publish_event(CompletionChunkEvent(token=chunk.value))
                    tokens.append(chunk.value)
                result = TextArtifact(value="".join(tokens).strip())
            else:
                result = self.try_run(prompt_stack)
                result.value = result.value.strip()

            self.after_run(result)

            return result
    else:
        raise Exception("prompt driver failed after all retry attempts")

token_count(prompt_stack)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def token_count(self, prompt_stack: PromptStack) -> int:
    return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack))

try_run(prompt_stack) abstractmethod

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    ...

try_stream(prompt_stack) abstractmethod

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    ...

BasePromptModelDriver

Bases: ABC

Source code in griptape/griptape/drivers/prompt_model/base_prompt_model_driver.py
@define
class BasePromptModelDriver(ABC):
    """Base class for Prompt Model Drivers used by multi-model Prompt Drivers.

    Implementations convert a `PromptStack` into model-specific input and
    parameters, and convert the model's raw output back into a `TextArtifact`.
    """

    # Maximum number of tokens the model should generate.
    max_tokens: int = field(default=600, kw_only=True)
    # Back-reference to the owning Prompt Driver; assigned by the Prompt Driver itself.
    prompt_driver: BasePromptDriver | None = field(default=None, kw_only=True)
    # Whether the underlying model supports streaming responses.
    supports_streaming: bool = field(default=True, kw_only=True)

    @property
    @abstractmethod
    def tokenizer(self) -> BaseTokenizer:
        """Tokenizer appropriate for the underlying model."""
        ...

    @abstractmethod
    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str | list | dict:
        """Convert the prompt stack into the model's expected input payload."""
        ...

    @abstractmethod
    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        """Build the model-specific inference parameters for the prompt stack."""
        ...

    @abstractmethod
    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        """Convert the model's raw output into a `TextArtifact`."""
        ...

max_tokens: int = field(default=600, kw_only=True) class-attribute instance-attribute

prompt_driver: BasePromptDriver | None = field(default=None, kw_only=True) class-attribute instance-attribute

supports_streaming: bool = field(default=True, kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer abstractmethod property

process_output(output) abstractmethod

Source code in griptape/griptape/drivers/prompt_model/base_prompt_model_driver.py
@abstractmethod
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    ...

prompt_stack_to_model_input(prompt_stack) abstractmethod

Source code in griptape/griptape/drivers/prompt_model/base_prompt_model_driver.py
@abstractmethod
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str | list | dict:
    ...

prompt_stack_to_model_params(prompt_stack) abstractmethod

Source code in griptape/griptape/drivers/prompt_model/base_prompt_model_driver.py
@abstractmethod
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    ...

BaseSqlDriver

Bases: ABC

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@define
class BaseSqlDriver(ABC):
    """Abstract base class for SQL drivers."""

    @dataclass
    class RowResult:
        # Mapping of column name to cell value for a single result row.
        cells: dict[str, Any]

    @abstractmethod
    def execute_query(self, query: str) -> Optional[list[RowResult]]:
        """Execute `query`, returning result rows as `RowResult`s, if any."""
        ...

    @abstractmethod
    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        """Execute `query`, returning result rows as plain dicts, if any."""
        ...

    @abstractmethod
    def get_table_schema(self, table: str, schema: Optional[str] = None) -> Optional[str]:
        """Return a textual description of `table`'s schema, or None if unavailable."""
        ...

RowResult dataclass

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@dataclass
class RowResult:
    cells: dict[str, Any]
cells: dict[str, Any] instance-attribute

execute_query(query) abstractmethod

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query(self, query: str) -> Optional[list[RowResult]]:
    ...

execute_query_raw(query) abstractmethod

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
    ...

get_table_schema(table, schema=None) abstractmethod

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def get_table_schema(self, table: str, schema: Optional[str] = None) -> Optional[str]:
    ...

BaseVectorStoreDriver

Bases: ABC

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@define
class BaseVectorStoreDriver(ABC):
    """Abstract base for vector store drivers.

    Concrete drivers implement vector upsert, load, and query; this base adds
    text and artifact convenience upserts built on top of `upsert_vector`.
    """

    # Default result count for `query` implementations.
    DEFAULT_QUERY_COUNT = 5

    @dataclass
    class QueryResult:
        id: str
        vector: list[float]
        score: float
        meta: dict | None = None
        namespace: str | None = None

    @dataclass
    class Entry:
        id: str
        vector: list[float]
        meta: dict | None = None
        namespace: str | None = None

    embedding_driver: BaseEmbeddingDriver = field(kw_only=True)
    futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)

    def upsert_text_artifacts(
        self, artifacts: dict[str, list[TextArtifact]], meta: dict | None = None, **kwargs
    ) -> None:
        """Submit an upsert for every artifact in every namespace and wait for completion."""
        submitted = {}

        for namespace, artifact_list in artifacts.items():
            for artifact in artifact_list:
                # NOTE(review): keyed by namespace, so only the last future per namespace is
                # retained in the dict (all submissions still run) — preserves original behavior.
                submitted[namespace] = self.futures_executor.submit(
                    self.upsert_text_artifact, artifact, namespace, meta, **kwargs
                )

        utils.execute_futures_dict(submitted)

    def upsert_text_artifact(
        self, artifact: TextArtifact, namespace: str | None = None, meta: dict | None = None, **kwargs
    ) -> str:
        """Embed `artifact` (unless it already carries an embedding) and upsert it."""
        meta = meta if meta else {}

        # Store the serialized artifact alongside the vector so it can be reloaded later.
        meta["artifact"] = artifact.to_json()

        vector = artifact.embedding if artifact.embedding else artifact.generate_embedding(self.embedding_driver)

        return self.upsert_vector(vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs)

    def upsert_text(
        self,
        string: str,
        vector_id: str | None = None,
        namespace: str | None = None,
        meta: dict | None = None,
        **kwargs,
    ) -> str:
        """Embed `string` with the embedding driver and upsert the resulting vector."""
        vector = self.embedding_driver.embed_string(string)

        return self.upsert_vector(
            vector, vector_id=vector_id, namespace=namespace, meta=meta if meta else {}, **kwargs
        )

    @abstractmethod
    def upsert_vector(
        self,
        vector: list[float],
        vector_id: str | None = None,
        namespace: str | None = None,
        meta: dict | None = None,
        **kwargs,
    ) -> str:
        """Store `vector` and return its id; implemented by concrete drivers."""
        ...

    @abstractmethod
    def load_entry(self, vector_id: str, namespace: str | None = None) -> Entry | None:
        """Load a single entry by vector id, if present."""
        ...

    @abstractmethod
    def load_entries(self, namespace: str | None = None) -> list[Entry]:
        """Load all entries, optionally restricted to a namespace."""
        ...

    @abstractmethod
    def query(
        self,
        query: str,
        count: int | None = None,
        namespace: str | None = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[QueryResult]:
        """Query the store for results matching `query`; implemented by concrete drivers."""
        ...

DEFAULT_QUERY_COUNT = 5 class-attribute instance-attribute

embedding_driver: BaseEmbeddingDriver = field(kw_only=True) class-attribute instance-attribute

futures_executor: futures.Executor = field(default=Factory(lambda : futures.ThreadPoolExecutor()), kw_only=True) class-attribute instance-attribute

Entry dataclass

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@dataclass
class Entry:
    id: str
    vector: list[float]
    meta: dict | None = None
    namespace: str | None = None
id: str instance-attribute
meta: dict | None = None class-attribute instance-attribute
namespace: str | None = None class-attribute instance-attribute
vector: list[float] instance-attribute

QueryResult dataclass

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@dataclass
class QueryResult:
    id: str
    vector: list[float]
    score: float
    meta: dict | None = None
    namespace: str | None = None
id: str instance-attribute
meta: dict | None = None class-attribute instance-attribute
namespace: str | None = None class-attribute instance-attribute
score: float instance-attribute
vector: list[float] instance-attribute

load_entries(namespace=None) abstractmethod

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entries(self, namespace: str | None = None) -> list[Entry]:
    ...

load_entry(vector_id, namespace=None) abstractmethod

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entry(self, vector_id: str, namespace: str | None = None) -> Entry | None:
    ...

query(query, count=None, namespace=None, include_vectors=False, **kwargs) abstractmethod

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def query(
    self,
    query: str,
    count: int | None = None,
    namespace: str | None = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[QueryResult]:
    ...

upsert_text(string, vector_id=None, namespace=None, meta=None, **kwargs)

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
def upsert_text(
    self,
    string: str,
    vector_id: str | None = None,
    namespace: str | None = None,
    meta: dict | None = None,
    **kwargs,
) -> str:
    return self.upsert_vector(
        self.embedding_driver.embed_string(string),
        vector_id=vector_id,
        namespace=namespace,
        meta=meta if meta else {},
        **kwargs,
    )

upsert_text_artifact(artifact, namespace=None, meta=None, **kwargs)

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
def upsert_text_artifact(
    self, artifact: TextArtifact, namespace: str | None = None, meta: dict | None = None, **kwargs
) -> str:
    if not meta:
        meta = {}

    meta["artifact"] = artifact.to_json()

    if artifact.embedding:
        vector = artifact.embedding
    else:
        vector = artifact.generate_embedding(self.embedding_driver)

    return self.upsert_vector(vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs)

upsert_text_artifacts(artifacts, meta=None, **kwargs)

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py