Drivers

__all__ = ['BasePromptDriver', 'OpenAiChatPromptDriver', 'OpenAiCompletionPromptDriver', 'AzureOpenAiChatPromptDriver', 'AzureOpenAiCompletionPromptDriver', 'CoherePromptDriver', 'HuggingFacePipelinePromptDriver', 'HuggingFaceHubPromptDriver', 'AnthropicPromptDriver', 'AmazonSageMakerPromptDriver', 'AmazonBedrockPromptDriver', 'BaseMultiModelPromptDriver', 'BaseConversationMemoryDriver', 'LocalConversationMemoryDriver', 'AmazonDynamoDbConversationMemoryDriver', 'BaseEmbeddingDriver', 'OpenAiEmbeddingDriver', 'AzureOpenAiEmbeddingDriver', 'BedrockTitanEmbeddingDriver', 'BaseVectorStoreDriver', 'LocalVectorStoreDriver', 'PineconeVectorStoreDriver', 'MarqoVectorStoreDriver', 'MongoDbAtlasVectorStoreDriver', 'RedisVectorStoreDriver', 'OpenSearchVectorStoreDriver', 'AmazonOpenSearchVectorStoreDriver', 'PgVectorVectorStoreDriver', 'BaseSqlDriver', 'AmazonRedshiftSqlDriver', 'SnowflakeSqlDriver', 'SqlDriver', 'BasePromptModelDriver', 'SageMakerLlamaPromptModelDriver', 'SageMakerFalconPromptModelDriver', 'BedrockTitanPromptModelDriver', 'BedrockClaudePromptModelDriver', 'BedrockJurassicPromptModelDriver'] module-attribute

AmazonBedrockPromptDriver

Bases: BaseMultiModelPromptDriver

Source code in griptape/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@define
class AmazonBedrockPromptDriver(BaseMultiModelPromptDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
        payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
        if isinstance(model_input, dict):
            payload.update(model_input)

        response = self.bedrock_client.invoke_model(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = response["body"].read()

        if response_body:
            return self.prompt_model_driver.process_output(response_body)
        else:
            raise Exception("model response is empty")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
        payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
        if isinstance(model_input, dict):
            payload.update(model_input)

        response = self.bedrock_client.invoke_model_with_response_stream(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = response["body"]
        if response_body:
            for chunk in response_body:
                chunk_bytes = chunk["chunk"]["bytes"]
                yield self.prompt_model_driver.process_output(chunk_bytes)
        else:
            raise Exception("model response is empty")
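
A minimal construction sketch, assuming AWS credentials with Bedrock access are configured; the model id is illustrative and `prompt_stack` stands in for a PromptStack built elsewhere:

import boto3
from griptape.drivers import AmazonBedrockPromptDriver, BedrockClaudePromptModelDriver

driver = AmazonBedrockPromptDriver(
    model="anthropic.claude-v2",  # illustrative Bedrock model id
    session=boto3.Session(region_name="us-east-1"),  # assumes configured credentials
    prompt_model_driver=BedrockClaudePromptModelDriver(),
)
result = driver.run(prompt_stack)  # returns a TextArtifact
print(result.value)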

bedrock_client: Any = field(default=Factory(lambda : self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
    payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
    if isinstance(model_input, dict):
        payload.update(model_input)

    response = self.bedrock_client.invoke_model(
        modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
    )

    response_body = response["body"].read()

    if response_body:
        return self.prompt_model_driver.process_output(response_body)
    else:
        raise Exception("model response is empty")

try_stream(prompt_stack)

Source code in griptape/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack)
    payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)}
    if isinstance(model_input, dict):
        payload.update(model_input)

    response = self.bedrock_client.invoke_model_with_response_stream(
        modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
    )

    response_body = response["body"]
    if response_body:
        for chunk in response_body:
            chunk_bytes = chunk["chunk"]["bytes"]
            yield self.prompt_model_driver.process_output(chunk_bytes)
    else:
        raise Exception("model response is empty")

AmazonDynamoDbConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source code in griptape/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
@define
class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    table_name: str = field(kw_only=True)
    partition_key: str = field(kw_only=True)
    value_attribute_key: str = field(kw_only=True)
    partition_key_value: str = field(kw_only=True)

    table: Any = field(init=False)

    def __attrs_post_init__(self) -> None:
        dynamodb = self.session.resource("dynamodb")

        self.table = dynamodb.Table(self.table_name)

    def store(self, memory: ConversationMemory) -> None:
        self.table.update_item(
            Key={self.partition_key: self.partition_key_value},
            UpdateExpression="set #attr = :value",
            ExpressionAttributeNames={"#attr": self.value_attribute_key},
            ExpressionAttributeValues={":value": memory.to_json()},
        )

    def load(self) -> Optional[ConversationMemory]:
        response = self.table.get_item(Key={self.partition_key: self.partition_key_value})

        if "Item" in response and self.value_attribute_key in response["Item"]:
            memory_value = response["Item"][self.value_attribute_key]

            memory = ConversationMemory.from_json(memory_value)

            memory.driver = self

            return memory
        else:
            return None
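
A construction sketch; the table and key names are illustrative, and the DynamoDB table is assumed to already exist:

from griptape.drivers import AmazonDynamoDbConversationMemoryDriver

driver = AmazonDynamoDbConversationMemoryDriver(
    table_name="conversations",
    partition_key="id",
    partition_key_value="session-1",
    value_attribute_key="memory",
)

memory = driver.load()  # returns None until something has been stored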

partition_key: str = field(kw_only=True) class-attribute instance-attribute

partition_key_value: str = field(kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

table: Any = field(init=False) class-attribute instance-attribute

table_name: str = field(kw_only=True) class-attribute instance-attribute

value_attribute_key: str = field(kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def __attrs_post_init__(self) -> None:
    dynamodb = self.session.resource("dynamodb")

    self.table = dynamodb.Table(self.table_name)

load()

Source code in griptape/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def load(self) -> Optional[ConversationMemory]:
    response = self.table.get_item(Key={self.partition_key: self.partition_key_value})

    if "Item" in response and self.value_attribute_key in response["Item"]:
        memory_value = response["Item"][self.value_attribute_key]

        memory = ConversationMemory.from_json(memory_value)

        memory.driver = self

        return memory
    else:
        return None

store(memory)

Source code in griptape/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def store(self, memory: ConversationMemory) -> None:
    self.table.update_item(
        Key={self.partition_key: self.partition_key_value},
        UpdateExpression="set #attr = :value",
        ExpressionAttributeNames={"#attr": self.value_attribute_key},
        ExpressionAttributeValues={":value": memory.to_json()},
    )

AmazonOpenSearchVectorStoreDriver

Bases: OpenSearchVectorStoreDriver

A Vector Store Driver for Amazon OpenSearch.

Attributes:

session (Session): The boto3 session to use.
http_auth (Optional[str | Tuple[str, str]]): The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
client (Optional[OpenSearch]): An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.

Source code in griptape/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
@define
class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver):
    """A Vector Store Driver for Amazon OpenSearch.

    Attributes:
        session: The boto3 session to use.
        http_auth: The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
        client: An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.
    """

    session: Session = field(kw_only=True)

    http_auth: Optional[str | Tuple[str, str]] = field(
        default=Factory(
            lambda self: import_optional_dependency("requests_aws4auth").AWS4Auth(
                self.session.get_credentials().access_key,
                self.session.get_credentials().secret_key,
                self.session.region_name,
                "es",
            ),
            takes_self=True,
        )
    )

    client: Optional[OpenSearch] = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").OpenSearch(
                hosts=[{"host": self.host, "port": self.port}],
                http_auth=self.http_auth,
                use_ssl=self.use_ssl,
                verify_certs=self.verify_certs,
                connection_class=import_optional_dependency("opensearchpy").RequestsHttpConnection,
            ),
            takes_self=True,
        )
    )
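
A construction sketch; the endpoint is illustrative, `host` is inherited from OpenSearchVectorStoreDriver, `index_name` is assumed to be a field of that parent, and `embedding_driver` comes from BaseVectorStoreDriver:

import boto3
from griptape.drivers import AmazonOpenSearchVectorStoreDriver, OpenAiEmbeddingDriver

driver = AmazonOpenSearchVectorStoreDriver(
    session=boto3.Session(),
    host="search-my-domain.us-east-1.es.amazonaws.com",  # illustrative endpoint
    index_name="griptape-index",  # assumed parent-class field
    embedding_driver=OpenAiEmbeddingDriver(),  # assumes OPENAI_API_KEY and driver defaults
)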

client: Optional[OpenSearch] = field(default=Factory(lambda : import_optional_dependency('opensearchpy').OpenSearch(hosts=[{'host': self.host, 'port': self.port}], http_auth=self.http_auth, use_ssl=self.use_ssl, verify_certs=self.verify_certs, connection_class=import_optional_dependency('opensearchpy').RequestsHttpConnection), takes_self=True)) class-attribute instance-attribute

http_auth: Optional[str | Tuple[str, str]] = field(default=Factory(lambda : import_optional_dependency('requests_aws4auth').AWS4Auth(self.session.get_credentials().access_key, self.session.get_credentials().secret_key, self.session.region_name, 'es'), takes_self=True)) class-attribute instance-attribute

session: Session = field(kw_only=True) class-attribute instance-attribute

AmazonRedshiftSqlDriver

Bases: BaseSqlDriver

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
@define
class AmazonRedshiftSqlDriver(BaseSqlDriver):
    database: str = field(kw_only=True)
    session: boto3.Session = field(kw_only=True)
    cluster_identifier: Optional[str] = field(default=None, kw_only=True)
    workgroup_name: Optional[str] = field(default=None, kw_only=True)
    db_user: Optional[str] = field(default=None, kw_only=True)
    database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True)
    wait_for_query_completion_sec: float = field(default=0.3, kw_only=True)
    client: Any = field(
        default=Factory(lambda self: self.session.client("redshift-data"), takes_self=True), kw_only=True
    )

    @workgroup_name.validator
    def validate_params(self, _, workgroup_name: Optional[str]) -> None:
        if not self.cluster_identifier and not self.workgroup_name:
            raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
        elif self.cluster_identifier and self.workgroup_name:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

    @classmethod
    def _process_rows_from_records(cls, records) -> list[list]:
        return [[c[list(c.keys())[0]] for c in r] for r in records]

    @classmethod
    def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
        return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

    @classmethod
    def _process_columns_from_column_metadata(cls, meta) -> list:
        return [k["name"] for k in meta]

    @classmethod
    def _post_process(cls, meta, records) -> list[dict[str, Any]]:
        columns = cls._process_columns_from_column_metadata(meta)
        rows = cls._process_rows_from_records(records)
        return cls._process_cells_from_rows_and_columns(columns, rows)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        rows = self.execute_query_raw(query)
        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        else:
            return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        function_kwargs = {"Sql": query, "Database": self.database}
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn

        response = self.client.execute_statement(**function_kwargs)
        response_id = response["Id"]

        statement = self.client.describe_statement(Id=response_id)

        while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
            time.sleep(self.wait_for_query_completion_sec)
            statement = self.client.describe_statement(Id=response_id)

        if statement["Status"] == "FINISHED":
            statement_result = self.client.get_statement_result(Id=response_id)
            results = statement_result.get("Records", [])

            while "NextToken" in statement_result:
                statement_result = self.client.get_statement_result(
                    Id=response_id, NextToken=statement_result["NextToken"]
                )
                results = results + statement_result.get("Records", [])

            return self._post_process(statement_result["ColumnMetadata"], results)

        elif statement["Status"] in ["FAILED", "ABORTED"]:
            return None

    def get_table_schema(self, table: str, schema: Optional[str] = None) -> Optional[str]:
        function_kwargs = {"Database": self.database, "Table": table}
        if schema:
            function_kwargs["Schema"] = schema
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn
        response = self.client.describe_table(**function_kwargs)
        return [col["name"] for col in response["ColumnList"]]
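
A usage sketch, assuming a provisioned cluster; supply either `cluster_identifier` or `workgroup_name`, never both:

import boto3
from griptape.drivers import AmazonRedshiftSqlDriver

driver = AmazonRedshiftSqlDriver(
    database="dev",  # illustrative names throughout
    session=boto3.Session(),
    cluster_identifier="my-cluster",
    db_user="awsuser",
)

rows = driver.execute_query("SELECT 1 AS one")
if rows:
    print(rows[0].cells)  # {'one': 1}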

client: Any = field(default=Factory(lambda : self.session.client('redshift-data'), takes_self=True), kw_only=True) class-attribute instance-attribute

cluster_identifier: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

database: str = field(kw_only=True) class-attribute instance-attribute

database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

db_user: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(kw_only=True) class-attribute instance-attribute

wait_for_query_completion_sec: float = field(default=0.3, kw_only=True) class-attribute instance-attribute

workgroup_name: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

execute_query(query)

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    rows = self.execute_query_raw(query)
    if rows:
        return [BaseSqlDriver.RowResult(row) for row in rows]
    else:
        return None

execute_query_raw(query)

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
    function_kwargs = {"Sql": query, "Database": self.database}
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.execute_statement(**function_kwargs)
    response_id = response["Id"]

    statement = self.client.describe_statement(Id=response_id)

    while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
        time.sleep(self.wait_for_query_completion_sec)
        statement = self.client.describe_statement(Id=response_id)

    if statement["Status"] == "FINISHED":
        statement_result = self.client.get_statement_result(Id=response_id)
        results = statement_result.get("Records", [])

        while "NextToken" in statement_result:
            statement_result = self.client.get_statement_result(
                Id=response_id, NextToken=statement_result["NextToken"]
            )
            results = results + statement_result.get("Records", [])

        return self._post_process(statement_result["ColumnMetadata"], results)

    elif statement["Status"] in ["FAILED", "ABORTED"]:
        return None

get_table_schema(table, schema=None)

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
def get_table_schema(self, table: str, schema: Optional[str] = None) -> Optional[str]:
    function_kwargs = {"Database": self.database, "Table": table}
    if schema:
        function_kwargs["Schema"] = schema
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn
    response = self.client.describe_table(**function_kwargs)
    return [col["name"] for col in response["ColumnList"]]

validate_params(_, workgroup_name)

Source code in griptape/griptape/drivers/sql/amazon_redshift_sql_driver.py
@workgroup_name.validator
def validate_params(self, _, workgroup_name: Optional[str]) -> None:
    if not self.cluster_identifier and not self.workgroup_name:
        raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
    elif self.cluster_identifier and self.workgroup_name:
        raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

AmazonSageMakerPromptDriver

Bases: BaseMultiModelPromptDriver

Source code in griptape/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@define
class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    sagemaker_client: Any = field(
        default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True
    )
    custom_attributes: str = field(default="accept_eula=true", kw_only=True)
    stream: bool = field(default=False, kw_only=True)

    @stream.validator
    def validate_stream(self, _, stream):
        if stream:
            raise ValueError("streaming is not supported")

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        payload = {
            "inputs": self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack),
            "parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack),
        }
        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.model,
            ContentType="application/json",
            Body=json.dumps(payload),
            CustomAttributes=self.custom_attributes,
        )

        decoded_body = json.loads(response["Body"].read().decode("utf8"))

        if decoded_body:
            return self.prompt_model_driver.process_output(decoded_body)
        else:
            raise Exception("model response is empty")

    def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
        raise NotImplementedError("streaming is not supported")
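
A construction sketch; `model` is passed to SageMaker as the endpoint name, the endpoint name here is illustrative, and `prompt_stack` stands in for a PromptStack built elsewhere:

import boto3
from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerLlamaPromptModelDriver

driver = AmazonSageMakerPromptDriver(
    model="my-llama-endpoint",  # illustrative SageMaker endpoint name
    session=boto3.Session(),
    prompt_model_driver=SageMakerLlamaPromptModelDriver(),
)
result = driver.run(prompt_stack)  # streaming is rejected by the validator above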

custom_attributes: str = field(default='accept_eula=true', kw_only=True) class-attribute instance-attribute

sagemaker_client: Any = field(default=Factory(lambda : self.session.client('sagemaker-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    payload = {
        "inputs": self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack),
        "parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack),
    }
    response = self.sagemaker_client.invoke_endpoint(
        EndpointName=self.model,
        ContentType="application/json",
        Body=json.dumps(payload),
        CustomAttributes=self.custom_attributes,
    )

    decoded_body = json.loads(response["Body"].read().decode("utf8"))

    if decoded_body:
        return self.prompt_model_driver.process_output(decoded_body)
    else:
        raise Exception("model response is empty")

try_stream(_)

Source code in griptape/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
    raise NotImplementedError("streaming is not supported")

validate_stream(_, stream)

Source code in griptape/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py
@stream.validator
def validate_stream(self, _, stream):
    if stream:
        raise ValueError("streaming is not supported")

AnthropicPromptDriver

Bases: BasePromptDriver

Attributes:

api_key (str): Anthropic API key.
model (str): Anthropic model name.
tokenizer (AnthropicTokenizer): Custom AnthropicTokenizer.

Source code in griptape/griptape/drivers/prompt/anthropic_prompt_driver.py
@define
class AnthropicPromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        tokenizer: Custom `AnthropicTokenizer`.
    """

    api_key: str = field(kw_only=True)
    model: str = field(kw_only=True)
    tokenizer: AnthropicTokenizer = field(
        default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        anthropic = import_optional_dependency("anthropic")

        response = anthropic.Anthropic(api_key=self.api_key).completions.create(**self._base_params(prompt_stack))
        return TextArtifact(value=response.completion)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        anthropic = import_optional_dependency("anthropic")

        response = anthropic.Anthropic(api_key=self.api_key).completions.create(
            **self._base_params(prompt_stack), stream=True
        )

        for chunk in response:
            yield TextArtifact(value=chunk.completion)

    def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_assistant():
                prompt_lines.append(f"Assistant: {i.content}")
            else:
                prompt_lines.append(f"Human: {i.content}")

        prompt_lines.append("Assistant:")

        return "\n\n" + "\n\n".join(prompt_lines)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_string(prompt_stack)

        return {
            "prompt": self.prompt_stack_to_string(prompt_stack),
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "max_tokens_to_sample": self.max_output_tokens(prompt),
        }
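
A construction sketch; the model name is illustrative for the completions-style API used above, and `prompt_stack` stands in for a PromptStack built elsewhere:

import os
from griptape.drivers import AnthropicPromptDriver

driver = AnthropicPromptDriver(
    api_key=os.environ["ANTHROPIC_API_KEY"],
    model="claude-2",  # illustrative model name
)
print(driver.run(prompt_stack).value)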

api_key: str = field(kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True) class-attribute instance-attribute

tokenizer: AnthropicTokenizer = field(default=Factory(lambda : AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

default_prompt_stack_to_string_converter(prompt_stack)

Source code in griptape/griptape/drivers/prompt/anthropic_prompt_driver.py
def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_assistant():
            prompt_lines.append(f"Assistant: {i.content}")
        else:
            prompt_lines.append(f"Human: {i.content}")

    prompt_lines.append("Assistant:")

    return "\n\n" + "\n\n".join(prompt_lines)

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/anthropic_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    anthropic = import_optional_dependency("anthropic")

    response = anthropic.Anthropic(api_key=self.api_key).completions.create(**self._base_params(prompt_stack))
    return TextArtifact(value=response.completion)

try_stream(prompt_stack)

Source code in griptape/griptape/drivers/prompt/anthropic_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    anthropic = import_optional_dependency("anthropic")

    response = anthropic.Anthropic(api_key=self.api_key).completions.create(
        **self._base_params(prompt_stack), stream=True
    )

    for chunk in response:
        yield TextArtifact(value=chunk.completion)

AzureOpenAiChatPromptDriver

Bases: OpenAiChatPromptDriver

Attributes:

azure_deployment (str): An Azure OpenAi deployment id.
azure_endpoint (str): An Azure OpenAi endpoint.
azure_ad_token (Optional[str]): An optional Azure Active Directory token.
azure_ad_token_provider (Optional[str]): An optional Azure Active Directory token provider.
api_version (str): An Azure OpenAi API version.
client (AzureOpenAI): An openai.AzureOpenAI client.

Source code in griptape/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@define
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True)
    azure_endpoint: str = field(kw_only=True)
    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
    api_version: str = field(default="2023-05-15", kw_only=True)
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        params = super()._base_params(prompt_stack)
        # TODO: Add `seed` parameter once Azure supports it.
        del params["seed"]

        return params
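
A construction sketch; the deployment, endpoint, and model names are illustrative, and `api_key`, `model`, and `organization` are fields inherited from OpenAiChatPromptDriver:

import os
from griptape.drivers import AzureOpenAiChatPromptDriver

driver = AzureOpenAiChatPromptDriver(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    model="gpt-4",  # illustrative
    azure_deployment="my-gpt4-deployment",  # illustrative
    azure_endpoint="https://my-resource.openai.azure.com",  # illustrative
)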

api_version: str = field(default='2023-05-15', kw_only=True) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda : openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

AzureOpenAiCompletionPromptDriver

Bases: OpenAiCompletionPromptDriver

Attributes:

azure_deployment (str): An Azure OpenAi deployment id.
azure_endpoint (str): An Azure OpenAi endpoint.
azure_ad_token (Optional[str]): An optional Azure Active Directory token.
azure_ad_token_provider (Optional[str]): An optional Azure Active Directory token provider.
api_version (str): An Azure OpenAi API version.
client (AzureOpenAI): An openai.AzureOpenAI client.

Source code in griptape/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py
@define
class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True)
    azure_endpoint: str = field(kw_only=True)
    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
    api_version: str = field(default="2023-05-15", kw_only=True)
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
            ),
            takes_self=True,
        )
    )

api_version: str = field(default='2023-05-15', kw_only=True) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda : openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment), takes_self=True)) class-attribute instance-attribute

AzureOpenAiEmbeddingDriver

Bases: OpenAiEmbeddingDriver

Attributes:

azure_deployment (str): An Azure OpenAi deployment id.
azure_endpoint (str): An Azure OpenAi endpoint.
azure_ad_token (Optional[str]): An optional Azure Active Directory token.
azure_ad_token_provider (Optional[str]): An optional Azure Active Directory token provider.
api_version (str): An Azure OpenAi API version.
tokenizer (OpenAiTokenizer): An OpenAiTokenizer.
client (AzureOpenAI): An openai.AzureOpenAI client.

Source code in griptape/griptape/drivers/embedding/azure_openai_embedding_driver.py
@define
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
    """
    Attributes:
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        tokenizer: An `OpenAiTokenizer`.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(kw_only=True)
    azure_endpoint: str = field(kw_only=True)
    azure_ad_token: Optional[str] = field(kw_only=True, default=None)
    azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None)
    api_version: str = field(default="2023-05-15", kw_only=True)
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    client: openai.AzureOpenAI = field(
        default=Factory(
            lambda self: openai.AzureOpenAI(
                organization=self.organization,
                api_key=self.api_key,
                api_version=self.api_version,
                azure_endpoint=self.azure_endpoint,
                azure_deployment=self.azure_deployment,
                azure_ad_token=self.azure_ad_token,
                azure_ad_token_provider=self.azure_ad_token_provider,
            ),
            takes_self=True,
        )
    )
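
A construction sketch; names are illustrative, and `api_key` and `model` are fields inherited from OpenAiEmbeddingDriver:

import os
from griptape.drivers import AzureOpenAiEmbeddingDriver

driver = AzureOpenAiEmbeddingDriver(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    model="text-embedding-ada-002",  # illustrative
    azure_deployment="my-embedding-deployment",  # illustrative
    azure_endpoint="https://my-resource.openai.azure.com",  # illustrative
)
vector = driver.embed_string("Hello!")  # embed_string comes from BaseEmbeddingDriver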

api_version: str = field(default='2023-05-15', kw_only=True) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True) class-attribute instance-attribute

client: openai.AzureOpenAI = field(default=Factory(lambda : openai.AzureOpenAI(organization=self.organization, api_key=self.api_key, api_version=self.api_version, azure_endpoint=self.azure_endpoint, azure_deployment=self.azure_deployment, azure_ad_token=self.azure_ad_token, azure_ad_token_provider=self.azure_ad_token_provider), takes_self=True)) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda : OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

BaseConversationMemoryDriver

Bases: ABC

Source code in griptape/griptape/drivers/memory/conversation/base_conversation_memory_driver.py
class BaseConversationMemoryDriver(ABC):
    @abstractmethod
    def store(self, *args, **kwargs) -> None:
        ...

    @abstractmethod
    def load(self, *args, **kwargs) -> Optional[ConversationMemory]:
        ...
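
A minimal sketch of a concrete implementation that keeps the serialized memory in a plain attribute; the ConversationMemory import path is assumed, and the `to_json`/`from_json` round trip mirrors the DynamoDB driver above:

from typing import Optional
from attr import define, field
from griptape.drivers import BaseConversationMemoryDriver
from griptape.memory.structure import ConversationMemory  # assumed import path

@define
class InMemoryConversationMemoryDriver(BaseConversationMemoryDriver):
    value: Optional[str] = field(default=None, kw_only=True)

    def store(self, memory: ConversationMemory) -> None:
        self.value = memory.to_json()

    def load(self) -> Optional[ConversationMemory]:
        if self.value is None:
            return None
        memory = ConversationMemory.from_json(self.value)
        memory.driver = self  # reattach, as the DynamoDB driver does
        return memory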

load(*args, **kwargs) abstractmethod

Source code in griptape/griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def load(self, *args, **kwargs) -> Optional[ConversationMemory]:
    ...

store(*args, **kwargs) abstractmethod

Source code in griptape/griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def store(self, *args, **kwargs) -> None:
    ...

BaseEmbeddingDriver

Bases: ExponentialBackoffMixin, ABC

Attributes:

dimensions (int): Vector dimensions.

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
@define
class BaseEmbeddingDriver(ExponentialBackoffMixin, ABC):
    """
    Attributes:
        dimensions: Vector dimensions.
    """

    dimensions: int = field(kw_only=True)
    tokenizer: BaseTokenizer = field(kw_only=True)
    chunker: BaseChunker = field(init=False)

    def __attrs_post_init__(self) -> None:
        self.chunker = TextChunker(tokenizer=self.tokenizer)

    def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
        return self.embed_string(artifact.to_text())

    def embed_string(self, string: str) -> list[float]:
        for attempt in self.retrying():
            with attempt:
                if self.tokenizer.count_tokens(string) > self.tokenizer.max_tokens:
                    return self._embed_long_string(string)
                else:
                    return self.try_embed_chunk(string)

        else:
            raise RuntimeError("Failed to embed string.")

    @abstractmethod
    def try_embed_chunk(self, chunk: str) -> list[float]:
        ...

    def _embed_long_string(self, string: str) -> list[float]:
        """Embeds a string that is too long to embed in one go.

        Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
        """
        chunks = self.chunker.chunk(string)

        embedding_chunks = []
        length_chunks = []
        for chunk in chunks:
            embedding_chunks.append(self.try_embed_chunk(chunk.value))
            length_chunks.append(len(chunk))

        # generate weighted averages
        embedding_chunks = np.average(embedding_chunks, axis=0, weights=length_chunks)

        # normalize length to 1
        embedding_chunks = embedding_chunks / np.linalg.norm(embedding_chunks)

        return embedding_chunks.tolist()
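
A minimal sketch of a concrete driver with a deterministic fake embedding, useful for tests; only `try_embed_chunk` must be supplied, and `dimensions` is capped at 32 here because SHA-256 yields 32 bytes:

import hashlib
from attr import define
from griptape.drivers import BaseEmbeddingDriver

@define
class FakeEmbeddingDriver(BaseEmbeddingDriver):
    def try_embed_chunk(self, chunk: str) -> list[float]:
        digest = hashlib.sha256(chunk.encode("utf-8")).digest()
        return [b / 255 for b in digest[: self.dimensions]]  # assumes dimensions <= 32

# Construction still requires the base fields, e.g.:
# driver = FakeEmbeddingDriver(dimensions=32, tokenizer=some_tokenizer)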

chunker: BaseChunker = field(init=False) class-attribute instance-attribute

dimensions: int = field(kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
def __attrs_post_init__(self) -> None:
    self.chunker = TextChunker(tokenizer=self.tokenizer)

embed_string(string)

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
def embed_string(self, string: str) -> list[float]:
    for attempt in self.retrying():
        with attempt:
            if self.tokenizer.count_tokens(string) > self.tokenizer.max_tokens:
                return self._embed_long_string(string)
            else:
                return self.try_embed_chunk(string)

    else:
        raise RuntimeError("Failed to embed string.")

embed_text_artifact(artifact)

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
def embed_text_artifact(self, artifact: TextArtifact) -> list[float]:
    return self.embed_string(artifact.to_text())

try_embed_chunk(chunk) abstractmethod

Source code in griptape/griptape/drivers/embedding/base_embedding_driver.py
@abstractmethod
def try_embed_chunk(self, chunk: str) -> list[float]:
    ...

BaseMultiModelPromptDriver

Bases: BasePromptDriver, ABC

Prompt Driver for platforms like Amazon SageMaker and Amazon Bedrock that host many LLM models.

Instances of this Prompt Driver require a Prompt Model Driver, which converts the prompt stack into a model input and parameters and processes the model output.

Attributes:

model: Name of the model to use.
tokenizer (Optional[BaseTokenizer]): Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver.
prompt_model_driver (BasePromptModelDriver): Prompt Model Driver to use.

Source code in griptape/griptape/drivers/prompt/base_multi_model_prompt_driver.py
@define
class BaseMultiModelPromptDriver(BasePromptDriver, ABC):
    """Prompt Driver for platforms like Amazon SageMaker, and Amazon Bedrock that host many LLM models.

    Instances of this Prompt Driver require a Prompt Model Driver which is used to convert the prompt stack
    into a model input and parameters, and to process the model output.

    Attributes:
        model: Name of the model to use.
        tokenizer: Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver.
        prompt_model_driver: Prompt Model Driver to use.
    """

    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    prompt_model_driver: BasePromptModelDriver = field(kw_only=True)
    stream: bool = field(default=False, kw_only=True)

    @stream.validator
    def validate_stream(self, _, stream):
        if stream and not self.prompt_model_driver.supports_streaming:
            raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming")

    def __attrs_post_init__(self) -> None:
        self.prompt_model_driver.prompt_driver = self

        if not self.tokenizer:
            self.tokenizer = self.prompt_model_driver.tokenizer

prompt_model_driver: BasePromptModelDriver = field(kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True) class-attribute instance-attribute

tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/griptape/drivers/prompt/base_multi_model_prompt_driver.py
def __attrs_post_init__(self) -> None:
    self.prompt_model_driver.prompt_driver = self

    if not self.tokenizer:
        self.tokenizer = self.prompt_model_driver.tokenizer

validate_stream(_, stream)

Source code in griptape/griptape/drivers/prompt/base_multi_model_prompt_driver.py
@stream.validator
def validate_stream(self, _, stream):
    if stream and not self.prompt_model_driver.supports_streaming:
        raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming")

BasePromptDriver

Bases: ExponentialBackoffMixin, ABC

Base class for Prompt Drivers.

Attributes:

temperature (float): The temperature to use for the completion.
max_tokens (Optional[int]): The maximum number of tokens to generate. If not specified, the value is determined automatically by the tokenizer.
structure (Optional[Structure]): An optional Structure to publish events to.
prompt_stack_to_string (Callable[[PromptStack], str]): A function that converts a PromptStack to a string.
ignored_exception_types (Tuple[Type[Exception], ...]): A tuple of exception types to ignore.
model (str): The model name.
tokenizer (BaseTokenizer): An instance of BaseTokenizer to use when calculating tokens.
stream (bool): Whether to stream the completion or not. CompletionChunkEvents will be published to the Structure if one is provided.

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
@define
class BasePromptDriver(ExponentialBackoffMixin, ABC):
    """Base class for Prompt Drivers.

    Attributes:
        temperature: The temperature to use for the completion.
        max_tokens: The maximum number of tokens to generate. If not specified, the value is determined automatically by the tokenizer.
        structure: An optional `Structure` to publish events to.
        prompt_stack_to_string: A function that converts a `PromptStack` to a string.
        ignored_exception_types: A tuple of exception types to ignore.
        model: The model name.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
        stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
    """

    temperature: float = field(default=0.1, kw_only=True)
    max_tokens: Optional[int] = field(default=None, kw_only=True)
    structure: Optional[Structure] = field(default=None, kw_only=True)
    prompt_stack_to_string: Callable[[PromptStack], str] = field(
        default=Factory(lambda self: self.default_prompt_stack_to_string_converter, takes_self=True), kw_only=True
    )
    ignored_exception_types: Tuple[Type[Exception], ...] = field(default=Factory(lambda: (ImportError,)), kw_only=True)
    model: str
    tokenizer: BaseTokenizer
    stream: bool = field(default=False, kw_only=True)

    def max_output_tokens(self, text: str | list) -> int:
        tokens_left = self.tokenizer.count_tokens_left(text)

        if self.max_tokens:
            return min(self.max_tokens, tokens_left)
        else:
            return tokens_left

    def token_count(self, prompt_stack: PromptStack) -> int:
        return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack))

    def before_run(self, prompt_stack: PromptStack) -> None:
        if self.structure:
            self.structure.publish_event(StartPromptEvent(token_count=self.token_count(prompt_stack)))

    def after_run(self, result: TextArtifact) -> None:
        if self.structure:
            self.structure.publish_event(FinishPromptEvent(token_count=result.token_count(self.tokenizer)))

    def run(self, prompt_stack: PromptStack) -> TextArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompt_stack)

                if self.stream:
                    tokens = []
                    completion_chunks = self.try_stream(prompt_stack)
                    for chunk in completion_chunks:
                        self.structure.publish_event(CompletionChunkEvent(token=chunk.value))
                        tokens.append(chunk.value)
                    result = TextArtifact(value="".join(tokens).strip())
                else:
                    result = self.try_run(prompt_stack)
                    result.value = result.value.strip()

                self.after_run(result)

                return result
        else:
            raise Exception("prompt driver failed after all retry attempts")

    def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_user():
                prompt_lines.append(f"User: {i.content}")
            elif i.is_assistant():
                prompt_lines.append(f"Assistant: {i.content}")
            else:
                prompt_lines.append(i.content)

        prompt_lines.append("Assistant:")

        return "\n\n".join(prompt_lines)

    @abstractmethod
    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        ...

    @abstractmethod
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        ...
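
A minimal sketch of a concrete subclass; only `try_run` and `try_stream` need to be supplied, and the echo behavior plus the PromptStack import path are illustrative:

from typing import Iterator
from attr import define
from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.utils import PromptStack  # assumed import path

@define
class EchoPromptDriver(BasePromptDriver):
    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        return TextArtifact(value=self.prompt_stack_to_string(prompt_stack))

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        for word in self.prompt_stack_to_string(prompt_stack).split():
            yield TextArtifact(value=word)

# `model` and `tokenizer` are still required, e.g.:
# driver = EchoPromptDriver(model="echo", tokenizer=some_tokenizer)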

ignored_exception_types: Tuple[Type[Exception], ...] = field(default=Factory(lambda : (ImportError,)), kw_only=True) class-attribute instance-attribute

max_tokens: Optional[int] = field(default=None, kw_only=True) class-attribute instance-attribute

model: str instance-attribute

prompt_stack_to_string: Callable[[PromptStack], str] = field(default=Factory(lambda : self.default_prompt_stack_to_string_converter, takes_self=True), kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True) class-attribute instance-attribute

structure: Optional[Structure] = field(default=None, kw_only=True) class-attribute instance-attribute

temperature: float = field(default=0.1, kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer instance-attribute

after_run(result)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def after_run(self, result: TextArtifact) -> None:
    if self.structure:
        self.structure.publish_event(FinishPromptEvent(token_count=result.token_count(self.tokenizer)))

before_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def before_run(self, prompt_stack: PromptStack) -> None:
    if self.structure:
        self.structure.publish_event(StartPromptEvent(token_count=self.token_count(prompt_stack)))

default_prompt_stack_to_string_converter(prompt_stack)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str:
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_user():
            prompt_lines.append(f"User: {i.content}")
        elif i.is_assistant():
            prompt_lines.append(f"Assistant: {i.content}")
        else:
            prompt_lines.append(i.content)

    prompt_lines.append("Assistant:")

    return "\n\n".join(prompt_lines)

max_output_tokens(text)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def max_output_tokens(self, text: str | list) -> int:
    tokens_left = self.tokenizer.count_tokens_left(text)

    if self.max_tokens:
        return min(self.max_tokens, tokens_left)
    else:
        return tokens_left

run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def run(self, prompt_stack: PromptStack) -> TextArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompt_stack)

            if self.stream:
                tokens = []
                completion_chunks = self.try_stream(prompt_stack)
                for chunk in completion_chunks:
                    self.structure.publish_event(CompletionChunkEvent(token=chunk.value))
                    tokens.append(chunk.value)
                result = TextArtifact(value="".join(tokens).strip())
            else:
                result = self.try_run(prompt_stack)
                result.value = result.value.strip()

            self.after_run(result)

            return result
    else:
        raise Exception("prompt driver failed after all retry attempts")

token_count(prompt_stack)

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
def token_count(self, prompt_stack: PromptStack) -> int:
    return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack))

try_run(prompt_stack) abstractmethod

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    ...

try_stream(prompt_stack) abstractmethod

Source code in griptape/griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    ...

BasePromptModelDriver

Bases: ABC

Source code in griptape/griptape/drivers/prompt_model/base_prompt_model_driver.py
@define
class BasePromptModelDriver(ABC):
    max_tokens: int = field(default=600, kw_only=True)
    prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True)
    supports_streaming: bool = field(default=True, kw_only=True)

    @property
    @abstractmethod
    def tokenizer(self) -> BaseTokenizer:
        ...

    @abstractmethod
    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str | list | dict:
        ...

    @abstractmethod
    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        ...

    @abstractmethod
    def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
        ...
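
A skeletal sketch of a concrete prompt model driver; the payload shapes are illustrative, the tokenizer body is left open, and the import paths for BaseTokenizer and PromptStack are assumed:

import json
from attr import define
from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptModelDriver
from griptape.tokenizers import BaseTokenizer  # assumed import path
from griptape.utils import PromptStack  # assumed import path

@define
class JsonPromptModelDriver(BasePromptModelDriver):
    @property
    def tokenizer(self) -> BaseTokenizer:
        ...  # return a tokenizer matching the hosted model

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        # Render the stack with the owning prompt driver's converter.
        return {"prompt": self.prompt_driver.prompt_stack_to_string(prompt_stack)}

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        return {"max_tokens": self.max_tokens, "temperature": self.prompt_driver.temperature}

    def process_output(self, output: bytes) -> TextArtifact:
        return TextArtifact(value=json.loads(output)["completion"])  # illustrative response shape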

max_tokens: int = field(default=600, kw_only=True) class-attribute instance-attribute

prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

supports_streaming: bool = field(default=True, kw_only=True) class-attribute instance-attribute

tokenizer: BaseTokenizer abstractmethod property

process_output(output) abstractmethod

Source code in griptape/griptape/drivers/prompt_model/base_prompt_model_driver.py
@abstractmethod
def process_output(self, output: list[dict] | str | bytes) -> TextArtifact:
    ...

prompt_stack_to_model_input(prompt_stack) abstractmethod

Source code in griptape/griptape/drivers/prompt_model/base_prompt_model_driver.py
@abstractmethod
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str | list | dict:
    ...

prompt_stack_to_model_params(prompt_stack) abstractmethod

Source code in griptape/griptape/drivers/prompt_model/base_prompt_model_driver.py
@abstractmethod
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    ...

BaseSqlDriver

Bases: ABC

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@define
class BaseSqlDriver(ABC):
    @dataclass
    class RowResult:
        cells: dict[str, Any]

    @abstractmethod
    def execute_query(self, query: str) -> Optional[list[RowResult]]:
        ...

    @abstractmethod
    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        ...

    @abstractmethod
    def get_table_schema(self, table: str, schema: Optional[str] = None) -> Optional[str]:
        ...

RowResult dataclass

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@dataclass
class RowResult:
    cells: dict[str, Any]

cells: dict[str, Any] instance-attribute

execute_query(query) abstractmethod

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query(self, query: str) -> Optional[list[RowResult]]:
    ...

execute_query_raw(query) abstractmethod

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
    ...

get_table_schema(table, schema=None) abstractmethod

Source code in griptape/griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def get_table_schema(self, table: str, schema: Optional[str] = None) -> Optional[str]:
    ...

BaseVectorStoreDriver

Bases: ABC

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@define
class BaseVectorStoreDriver(ABC):
    DEFAULT_QUERY_COUNT = 5

    @dataclass
    class QueryResult:
        id: str
        vector: list[float]
        score: float
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    @dataclass
    class Entry:
        id: str
        vector: list[float]
        meta: Optional[dict] = None
        namespace: Optional[str] = None

    embedding_driver: BaseEmbeddingDriver = field(kw_only=True)
    futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)

    def upsert_text_artifacts(
        self, artifacts: dict[str, list[TextArtifact]], meta: Optional[dict] = None, **kwargs
    ) -> None:
        utils.execute_futures_dict(
            {
                namespace: self.futures_executor.submit(self.upsert_text_artifact, a, namespace, meta, **kwargs)
                for namespace, artifact_list in artifacts.items()
                for a in artifact_list
            }
        )

    def upsert_text_artifact(
        self, artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs
    ) -> str:
        if not meta:
            meta = {}

        meta["artifact"] = artifact.to_json()

        if artifact.embedding:
            vector = artifact.embedding
        else:
            vector = artifact.generate_embedding(self.embedding_driver)

        return self.upsert_vector(vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs)

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs
    ) -> str:
        return self.upsert_vector(
            self.embedding_driver.embed_string(string),
            vector_id=vector_id,
            namespace=namespace,
            meta=meta if meta else {},
            **kwargs
        )

    @abstractmethod
    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs
    ) -> str:
        ...

    @abstractmethod
    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry:
        ...

    @abstractmethod
    def load_entries(self, namespace: Optional[str] = None) -> list[Entry]:
        ...

    @abstractmethod
    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs
    ) -> list[QueryResult]:
        ...

DEFAULT_QUERY_COUNT = 5 class-attribute instance-attribute

embedding_driver: BaseEmbeddingDriver = field(kw_only=True) class-attribute instance-attribute

futures_executor: futures.Executor = field(default=Factory(lambda : futures.ThreadPoolExecutor()), kw_only=True) class-attribute instance-attribute

Entry dataclass

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@dataclass
class Entry:
    id: str
    vector: list[float]
    meta: Optional[dict] = None
    namespace: Optional[str] = None
id: str instance-attribute
meta: Optional[dict] = None class-attribute instance-attribute
namespace: Optional[str] = None class-attribute instance-attribute
vector: list[float] instance-attribute

QueryResult dataclass

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@dataclass
class QueryResult:
    id: str
    vector: list[float]
    score: float
    meta: Optional[dict] = None
    namespace: Optional[str] = None
id: str instance-attribute
meta: Optional[dict] = None class-attribute instance-attribute
namespace: Optional[str] = None class-attribute instance-attribute
score: float instance-attribute
vector: list[float] instance-attribute

load_entries(namespace=None) abstractmethod

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entries(self, namespace: Optional[str] = None) -> list[Entry]:
    ...

load_entry(vector_id, namespace=None) abstractmethod

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Entry:
    ...

query(query, count=None, namespace=None, include_vectors=False, **kwargs) abstractmethod

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs
) -> list[QueryResult]:
    ...

upsert_text(string, vector_id=None, namespace=None, meta=None, **kwargs)

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
def upsert_text(
    self,
    string: str,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs
) -> str:
    return self.upsert_vector(
        self.embedding_driver.embed_string(string),
        vector_id=vector_id,
        namespace=namespace,
        meta=meta if meta else {},
        **kwargs
    )

upsert_text_artifact(artifact, namespace=None, meta=None, **kwargs)

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
def upsert_text_artifact(
    self, artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs
) -> str:
    if not meta:
        meta = {}

    meta["artifact"] = artifact.to_json()

    if artifact.embedding:
        vector = artifact.embedding
    else:
        vector = artifact.generate_embedding(self.embedding_driver)

    return self.upsert_vector(vector, vector_id=artifact.id, namespace=namespace, meta=meta, **kwargs)

upsert_text_artifacts(artifacts, meta=None, **kwargs)

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
def upsert_text_artifacts(
    self, artifacts: dict[str, list[TextArtifact]], meta: Optional[dict] = None, **kwargs
) -> None:
    utils.execute_futures_dict(
        {
            # Key each future by namespace and artifact id so multiple
            # artifacts in one namespace do not overwrite each other.
            f"{namespace}-{a.id}": self.futures_executor.submit(
                self.upsert_text_artifact, a, namespace, meta, **kwargs
            )
            for namespace, artifact_list in artifacts.items()
            for a in artifact_list
        }
    )
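
For reference, utils.execute_futures_dict blocks until every submitted future completes and resolves each to its result; a rough standalone equivalent (an approximation, not the library's exact code) is:

from concurrent import futures

def execute_futures_dict(fs_dict: dict[str, futures.Future]) -> dict:
    # Wait for every future, then map each key to its completed result.
    futures.wait(fs_dict.values(), return_when=futures.ALL_COMPLETED)
    return {key: future.result() for key, future in fs_dict.items()}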

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs) abstractmethod

Source code in griptape/griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs
) -> str:
    ...
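
A short end-to-end usage sketch of this interface, using the LocalVectorStoreDriver documented later in this section. The OpenAiEmbeddingDriver default configuration (an OPENAI_API_KEY in the environment) is an assumption, not something the source above requires.

from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver

driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())

# upsert_text embeds the string with the configured embedding driver, then
# delegates storage to the subclass's upsert_vector implementation.
vector_id = driver.upsert_text("Griptape is a Python framework.", namespace="docs")

# query embeds the query string and returns QueryResult objects ranked by score.
for result in driver.query("What is Griptape?", count=3, namespace="docs"):
    print(result.id, result.score)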

BedrockClaudePromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
@define
class BedrockClaudePromptModelDriver(BasePromptModelDriver):
    top_p: float = field(default=0.999, kw_only=True)
    top_k: int = field(default=250, kw_only=True)
    _tokenizer: BedrockClaudeTokenizer = field(default=None, kw_only=True)
    prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> BedrockClaudeTokenizer:
        """Returns the tokenizer for this driver.

        We need to pass the `session` field from the Prompt Driver to the
        Tokenizer. However, the Prompt Driver is not initialized until after
        the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
        field a @property that is only initialized when it is first accessed.
        This ensures that by the time we need to initialize the Tokenizer, the
        Prompt Driver has already been initialized.

        See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

        Returns:
            BedrockClaudeTokenizer: The tokenizer for this driver.
        """
        if self._tokenizer:
            return self._tokenizer
        else:
            self._tokenizer = BedrockClaudeTokenizer(model=self.prompt_driver.model)
            return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_assistant():
                prompt_lines.append(f"Assistant: {i.content}")
            else:
                prompt_lines.append(f"Human: {i.content}")

        prompt_lines.append("Assistant:")

        return {"prompt": "\n\n" + "\n\n".join(prompt_lines)}

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_model_input(prompt_stack)["prompt"]

        return {
            "max_tokens_to_sample": self.prompt_driver.max_output_tokens(prompt),
            "stop_sequences": self.tokenizer.stop_sequences,
            "temperature": self.prompt_driver.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def process_output(self, response_body: bytes) -> TextArtifact:
        body = json.loads(response_body.decode())

        return TextArtifact(body["completion"])

prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

tokenizer: BedrockClaudeTokenizer property

Returns the tokenizer for this driver.

We need to pass the session field from the Prompt Driver to the Tokenizer. However, the Prompt Driver is not initialized until after the Prompt Model Driver is initialized. To resolve this, we make the tokenizer field a @property that is only initialized when it is first accessed. This ensures that by the time we need to initialize the Tokenizer, the Prompt Driver has already been initialized.

See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

Returns:

Name Type Description
BedrockClaudeTokenizer BedrockClaudeTokenizer

The tokenizer for this driver.
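
The deferred-initialization pattern described above, reduced to a generic sketch (illustrative code, not from griptape):

from typing import Optional

class Helper:
    def __init__(self, owner_name: str):
        self.owner_name = owner_name

class Child:
    def __init__(self):
        self.parent: Optional["Parent"] = None  # wired in by Parent below
        self._helper: Optional[Helper] = None

    @property
    def helper(self) -> Helper:
        # Constructed lazily on first access, by which point self.parent is set.
        if self._helper is None:
            self._helper = Helper(owner_name=self.parent.name)
        return self._helper

class Parent:
    def __init__(self, name: str, child: Child):
        self.name = name
        self.child = child
        child.parent = self  # resolves the initialization-order dependency

print(Parent("p1", Child()).child.helper.owner_name)  # -> p1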

top_k: int = field(default=250, kw_only=True) class-attribute instance-attribute

top_p: float = field(default=0.999, kw_only=True) class-attribute instance-attribute

process_output(response_body)

Source code in griptape/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
def process_output(self, response_body: bytes) -> TextArtifact:
    body = json.loads(response_body.decode())

    return TextArtifact(body["completion"])

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_assistant():
            prompt_lines.append(f"Assistant: {i.content}")
        else:
            prompt_lines.append(f"Human: {i.content}")

    prompt_lines.append("Assistant:")

    return {"prompt": "\n\n" + "\n\n".join(prompt_lines)}

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_stack_to_model_input(prompt_stack)["prompt"]

    return {
        "max_tokens_to_sample": self.prompt_driver.max_output_tokens(prompt),
        "stop_sequences": self.tokenizer.stop_sequences,
        "temperature": self.prompt_driver.temperature,
        "top_p": self.top_p,
        "top_k": self.top_k,
    }

BedrockJurassicPromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
@define
class BedrockJurassicPromptModelDriver(BasePromptModelDriver):
    top_p: float = field(default=0.9, kw_only=True)
    _tokenizer: BedrockJurassicTokenizer = field(default=None, kw_only=True)
    prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)
    supports_streaming: bool = field(default=False, kw_only=True)

    @property
    def tokenizer(self) -> BedrockJurassicTokenizer:
        """Returns the tokenizer for this driver.

        We need to pass the `session` field from the Prompt Driver to the
        Tokenizer. However, the Prompt Driver is not initialized until after
        the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
        field a @property that is only initialized when it is first accessed.
        This ensures that by the time we need to initialize the Tokenizer, the
        Prompt Driver has already been initialized.

        See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

        Returns:
            BedrockJurassicTokenizer: The tokenizer for this driver.
        """
        if self._tokenizer:
            return self._tokenizer
        else:
            self._tokenizer = BedrockJurassicTokenizer(
                model=self.prompt_driver.model, session=self.prompt_driver.session
            )
            return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_user():
                prompt_lines.append(f"User: {i.content}")
            elif i.is_assistant():
                prompt_lines.append(f"Bot: {i.content}")
            elif i.is_system():
                prompt_lines.append(f"Instructions: {i.content}")
            else:
                prompt_lines.append(i.content)
        prompt_lines.append("Bot:")

        prompt = "\n".join(prompt_lines)

        return {"prompt": prompt}

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_model_input(prompt_stack)["prompt"]

        return {
            "maxTokens": self.prompt_driver.max_output_tokens(prompt),
            "temperature": self.prompt_driver.temperature,
            "stopSequences": self.tokenizer.stop_sequences,
            "countPenalty": {"scale": 0},
            "presencePenalty": {"scale": 0},
            "frequencyPenalty": {"scale": 0},
        }

    def process_output(self, response_body: str) -> TextArtifact:
        body = json.loads(response_body)

        return TextArtifact(body["completions"][0]["data"]["text"])

prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

supports_streaming: bool = field(default=False, kw_only=True) class-attribute instance-attribute

tokenizer: BedrockJurassicTokenizer property

Returns the tokenizer for this driver.

We need to pass the session field from the Prompt Driver to the Tokenizer. However, the Prompt Driver is not initialized until after the Prompt Model Driver is initialized. To resolve this, we make the tokenizer field a @property that is only initialized when it is first accessed. This ensures that by the time we need to initialize the Tokenizer, the Prompt Driver has already been initialized.

See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

Returns:

Name Type Description
BedrockJurassicTokenizer BedrockJurassicTokenizer

The tokenizer for this driver.

top_p: float = field(default=0.9, kw_only=True) class-attribute instance-attribute

process_output(response_body)

Source code in griptape/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
def process_output(self, response_body: str) -> TextArtifact:
    body = json.loads(response_body)

    return TextArtifact(body["completions"][0]["data"]["text"])

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_user():
            prompt_lines.append(f"User: {i.content}")
        elif i.is_assistant():
            prompt_lines.append(f"Bot: {i.content}")
        elif i.is_system():
            prompt_lines.append(f"Instructions: {i.content}")
        else:
            prompt_lines.append(i.content)
    prompt_lines.append("Bot:")

    prompt = "\n".join(prompt_lines)

    return {"prompt": prompt}

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_stack_to_model_input(prompt_stack)["prompt"]

    return {
        "maxTokens": self.prompt_driver.max_output_tokens(prompt),
        "temperature": self.prompt_driver.temperature,
        "stopSequences": self.tokenizer.stop_sequences,
        "countPenalty": {"scale": 0},
        "presencePenalty": {"scale": 0},
        "frequencyPenalty": {"scale": 0},
    }

BedrockTitanEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes:

Name Type Description
model str

Embedding model name. Defaults to DEFAULT_MODEL.

dimensions int

Vector dimensions. Defaults to DEFAULT_MAX_TOKENS.

tokenizer BedrockTitanTokenizer

Optionally provide custom BedrockTitanTokenizer.

session Session

Optionally provide custom boto3.Session.

bedrock_client Any

Optionally provide custom bedrock-runtime client.

Source code in griptape/griptape/drivers/embedding/bedrock_titan_embedding_driver.py
@define
class BedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
    """
    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        dimensions: Vector dimensions. Defaults to DEFAULT_MAX_TOKENS.
        tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
        session: Optionally provide custom `boto3.Session`.
        bedrock_client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "amazon.titan-embed-text-v1"
    DEFAULT_MAX_TOKENS = 1536

    model: str = field(default=DEFAULT_MODEL, kw_only=True)
    dimensions: int = field(default=DEFAULT_MAX_TOKENS, kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BedrockTitanTokenizer = field(
        default=Factory(lambda self: BedrockTitanTokenizer(model=self.model, session=self.session), takes_self=True),
        kw_only=True,
    )
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        payload = {"inputText": chunk}

        response = self.bedrock_client.invoke_model(
            body=json.dumps(payload), modelId=self.model, accept="application/json", contentType="application/json"
        )
        response_body = json.loads(response.get("body").read())

        return response_body.get("embedding")

DEFAULT_MAX_TOKENS = 1536 class-attribute instance-attribute

DEFAULT_MODEL = 'amazon.titan-embed-text-v1' class-attribute instance-attribute

bedrock_client: Any = field(default=Factory(lambda : self.session.client('bedrock-runtime'), takes_self=True), kw_only=True) class-attribute instance-attribute

dimensions: int = field(default=DEFAULT_MAX_TOKENS, kw_only=True) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True) class-attribute instance-attribute

session: boto3.Session = field(default=Factory(lambda : import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

tokenizer: BedrockTitanTokenizer = field(default=Factory(lambda : BedrockTitanTokenizer(model=self.model, session=self.session), takes_self=True), kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk)

Source code in griptape/griptape/drivers/embedding/bedrock_titan_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    payload = {"inputText": chunk}

    response = self.bedrock_client.invoke_model(
        body=json.dumps(payload), modelId=self.model, accept="application/json", contentType="application/json"
    )
    response_body = json.loads(response.get("body").read())

    return response_body.get("embedding")
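
A brief usage sketch (assumes AWS credentials for boto3 are configured in the environment and that the account has access to the Titan embedding model):

from griptape.drivers import BedrockTitanEmbeddingDriver

driver = BedrockTitanEmbeddingDriver()  # defaults to amazon.titan-embed-text-v1

# try_embed_chunk posts {"inputText": ...} to the model and returns the
# "embedding" array from the JSON response body.
vector = driver.try_embed_chunk("hello world")
print(len(vector))  # 1536 for the default model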

BedrockTitanPromptModelDriver

Bases: BasePromptModelDriver

Source code in griptape/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
@define
class BedrockTitanPromptModelDriver(BasePromptModelDriver):
    top_p: float = field(default=0.9, kw_only=True)
    _tokenizer: BedrockTitanTokenizer = field(default=None, kw_only=True)
    prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)

    @property
    def tokenizer(self) -> BedrockTitanTokenizer:
        """Returns the tokenizer for this driver.

        We need to pass the `session` field from the Prompt Driver to the
        Tokenizer. However, the Prompt Driver is not initialized until after
        the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer`
        field a @property that is only initialized when it is first accessed.
        This ensures that by the time we need to initialize the Tokenizer, the
        Prompt Driver has already been initialized.

        See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

        Returns:
            BedrockTitanTokenizer: The tokenizer for this driver.
        """
        if self._tokenizer:
            return self._tokenizer
        else:
            self._tokenizer = BedrockTitanTokenizer(model=self.prompt_driver.model, session=self.prompt_driver.session)
            return self._tokenizer

    def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_user():
                prompt_lines.append(f"User: {i.content}")
            elif i.is_assistant():
                prompt_lines.append(f"Bot: {i.content}")
            elif i.is_system():
                prompt_lines.append(f"Instructions: {i.content}")
            else:
                prompt_lines.append(i.content)
        prompt_lines.append("Bot:")

        prompt = "\n\n".join(prompt_lines)

        return {"inputText": prompt}

    def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_model_input(prompt_stack)["inputText"]

        return {
            "textGenerationConfig": {
                "maxTokenCount": self.prompt_driver.max_output_tokens(prompt),
                "stopSequences": self.tokenizer.stop_sequences,
                "temperature": self.prompt_driver.temperature,
                "topP": self.top_p,
            }
        }

    def process_output(self, response_body: str | bytes) -> TextArtifact:
        # When streaming, the response body comes back as bytes.
        if isinstance(response_body, bytes):
            response_body = response_body.decode()

        body = json.loads(response_body)

        if self.prompt_driver.stream:
            return TextArtifact(body["outputText"])
        else:
            return TextArtifact(body["results"][0]["outputText"])

prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

tokenizer: BedrockTitanTokenizer property

Returns the tokenizer for this driver.

We need to pass the session field from the Prompt Driver to the Tokenizer. However, the Prompt Driver is not initialized until after the Prompt Model Driver is initialized. To resolve this, we make the tokenizer field a @property that is only initialized when it is first accessed. This ensures that by the time we need to initialize the Tokenizer, the Prompt Driver has already been initialized.

See this thread for more information: https://github.com/griptape-ai/griptape/issues/244

Returns:

Name Type Description
BedrockTitanTokenizer BedrockTitanTokenizer

The tokenizer for this driver.

top_p: float = field(default=0.9, kw_only=True) class-attribute instance-attribute

process_output(response_body)

Source code in griptape/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
def process_output(self, response_body: str | bytes) -> TextArtifact:
    # When streaming, the response body comes back as bytes.
    if isinstance(response_body, bytes):
        response_body = response_body.decode()

    body = json.loads(response_body)

    if self.prompt_driver.stream:
        return TextArtifact(body["outputText"])
    else:
        return TextArtifact(body["results"][0]["outputText"])

prompt_stack_to_model_input(prompt_stack)

Source code in griptape/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict:
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_user():
            prompt_lines.append(f"User: {i.content}")
        elif i.is_assistant():
            prompt_lines.append(f"Bot: {i.content}")
        elif i.is_system():
            prompt_lines.append(f"Instructions: {i.content}")
        else:
            prompt_lines.append(i.content)
    prompt_lines.append("Bot:")

    prompt = "\n\n".join(prompt_lines)

    return {"inputText": prompt}

prompt_stack_to_model_params(prompt_stack)

Source code in griptape/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py
def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
    prompt = self.prompt_stack_to_model_input(prompt_stack)["inputText"]

    return {
        "textGenerationConfig": {
            "maxTokenCount": self.prompt_driver.max_output_tokens(prompt),
            "stopSequences": self.tokenizer.stop_sequences,
            "temperature": self.prompt_driver.temperature,
            "topP": self.top_p,
        }
    }

CoherePromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_key str

Cohere API key.

model str

Cohere model name.

client Client

Custom cohere.Client.

tokenizer CohereTokenizer

Custom CohereTokenizer.

Source code in griptape/griptape/drivers/prompt/cohere_prompt_driver.py
@define
class CoherePromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
        tokenizer: Custom `CohereTokenizer`.
    """

    api_key: str = field(kw_only=True)
    model: str = field(kw_only=True)
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
        kw_only=True,
    )
    tokenizer: CohereTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        result = self.client.generate(**self._base_params(prompt_stack))

        if result.generations:
            if len(result.generations) == 1:
                generation = result.generations[0]

                return TextArtifact(value=generation.text.strip())
            else:
                raise Exception("completion with more than one choice is not supported yet")
        else:
            raise Exception("model response is empty")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        result = self.client.generate(**self._base_params(prompt_stack), stream=True)

        for chunk in result:
            yield TextArtifact(value=chunk.text)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        prompt = self.prompt_stack_to_string(prompt_stack)
        return {
            "prompt": prompt,
            "model": self.model,
            "temperature": self.temperature,
            "end_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_output_tokens(prompt),
        }

api_key: str = field(kw_only=True) class-attribute instance-attribute

client: Client = field(default=Factory(lambda : import_optional_dependency('cohere').Client(self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True) class-attribute instance-attribute

tokenizer: CohereTokenizer = field(default=Factory(lambda : CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/cohere_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    result = self.client.generate(**self._base_params(prompt_stack))

    if result.generations:
        if len(result.generations) == 1:
            generation = result.generations[0]

            return TextArtifact(value=generation.text.strip())
        else:
            raise Exception("completion with more than one choice is not supported yet")
    else:
        raise Exception("model response is empty")

try_stream(prompt_stack)

Source code in griptape/griptape/drivers/prompt/cohere_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    result = self.client.generate(**self._base_params(prompt_stack), stream=True)

    for chunk in result:
        yield TextArtifact(value=chunk.text)
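
Construction sketch; the "command" model name and reading the key from the environment are assumptions, not part of the source above:

import os

from griptape.drivers import CoherePromptDriver

driver = CoherePromptDriver(api_key=os.environ["COHERE_API_KEY"], model="command")
# try_run sends the flattened prompt plus temperature, end sequences, and a
# computed max_tokens (see _base_params above) to Cohere's generate endpoint.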

HuggingFaceHubPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_token str

Hugging Face Hub API token.

use_gpu bool

Use GPU during model run.

params dict

Custom model run parameters.

model str

Hugging Face Hub model name.

client InferenceApi

Custom InferenceApi.

tokenizer HuggingFaceTokenizer

Custom HuggingFaceTokenizer.

Source code in griptape/griptape/drivers/prompt/hugging_face_hub_prompt_driver.py
@define
class HuggingFaceHubPromptDriver(BasePromptDriver):
    """
    Attributes:
        api_token: Hugging Face Hub API token.
        use_gpu: Use GPU during model run.
        params: Custom model run parameters.
        model: Hugging Face Hub model name.
        client: Custom `InferenceApi`.
        tokenizer: Custom `HuggingFaceTokenizer`.

    """

    SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
    MAX_NEW_TOKENS = 250
    DEFAULT_PARAMS = {"return_full_text": False, "max_new_tokens": MAX_NEW_TOKENS}

    api_token: str = field(kw_only=True)
    use_gpu: bool = field(default=False, kw_only=True)
    params: dict = field(factory=dict, kw_only=True)
    model: str = field(kw_only=True)
    client: InferenceApi = field(
        default=Factory(
            lambda self: import_optional_dependency("huggingface_hub").InferenceApi(
                repo_id=self.model, token=self.api_token, gpu=self.use_gpu
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model),
                max_tokens=self.MAX_NEW_TOKENS,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    stream: bool = field(default=False, kw_only=True)

    @stream.validator
    def validate_stream(self, _, stream):
        if stream:
            raise ValueError("streaming is not supported")

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        prompt = self.prompt_stack_to_string(prompt_stack)

        if self.client.task in self.SUPPORTED_TASKS:
            response = self.client(inputs=prompt, params=self.DEFAULT_PARAMS | self.params)

            if len(response) == 1:
                return TextArtifact(value=response[0]["generated_text"].strip())
            else:
                raise Exception("completion with more than one choice is not supported yet")
        else:
            raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

    def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
        raise NotImplementedError("streaming is not supported")

DEFAULT_PARAMS = {'return_full_text': False, 'max_new_tokens': MAX_NEW_TOKENS} class-attribute instance-attribute

MAX_NEW_TOKENS = 250 class-attribute instance-attribute

SUPPORTED_TASKS = ['text2text-generation', 'text-generation'] class-attribute instance-attribute

api_token: str = field(kw_only=True) class-attribute instance-attribute

client: InferenceApi = field(default=Factory(lambda : import_optional_dependency('huggingface_hub').InferenceApi(repo_id=self.model, token=self.api_token, gpu=self.use_gpu), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True) class-attribute instance-attribute

params: dict = field(factory=dict, kw_only=True) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True) class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda : HuggingFaceTokenizer(tokenizer=import_optional_dependency('transformers').AutoTokenizer.from_pretrained(self.model), max_tokens=self.MAX_NEW_TOKENS), takes_self=True), kw_only=True) class-attribute instance-attribute

use_gpu: bool = field(default=False, kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/hugging_face_hub_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    prompt = self.prompt_stack_to_string(prompt_stack)

    if self.client.task in self.SUPPORTED_TASKS:
        response = self.client(inputs=prompt, params=self.DEFAULT_PARAMS | self.params)

        if len(response) == 1:
            return TextArtifact(value=response[0]["generated_text"].strip())
        else:
            raise Exception("completion with more than one choice is not supported yet")
    else:
        raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

try_stream(_)

Source code in griptape/griptape/drivers/prompt/hugging_face_hub_prompt_driver.py
def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
    raise NotImplementedError("streaming is not supported")

validate_stream(_, stream)

Source code in griptape/griptape/drivers/prompt/hugging_face_hub_prompt_driver.py
@stream.validator
def validate_stream(self, _, stream):
    if stream:
        raise ValueError("streaming is not supported")
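
Note how try_run merges DEFAULT_PARAMS | self.params: dict union (Python 3.9+) lets user-supplied params override the defaults, because the right-hand operand wins on key conflicts. A standalone check:

DEFAULT_PARAMS = {"return_full_text": False, "max_new_tokens": 250}
params = {"max_new_tokens": 64, "temperature": 0.2}

print(DEFAULT_PARAMS | params)
# {'return_full_text': False, 'max_new_tokens': 64, 'temperature': 0.2}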

HuggingFacePipelinePromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
params dict

Custom model run parameters.

model str

Hugging Face Hub model name.

tokenizer HuggingFaceTokenizer

Custom HuggingFaceTokenizer.

Source code in griptape/griptape/drivers/prompt/hugging_face_pipeline_prompt_driver.py
@define
class HuggingFacePipelinePromptDriver(BasePromptDriver):
    """
    Attributes:
        params: Custom model run parameters.
        model: Hugging Face Hub model name.
        tokenizer: Custom `HuggingFaceTokenizer`.

    """

    SUPPORTED_TASKS = ["text2text-generation", "text-generation"]
    DEFAULT_PARAMS = {"return_full_text": False, "num_return_sequences": 1}

    model: str = field(kw_only=True)
    params: dict = field(factory=dict, kw_only=True)
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(
                tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model)
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        prompt = self.prompt_stack_to_string(prompt_stack)
        pipeline = import_optional_dependency("transformers").pipeline

        generator = pipeline(
            tokenizer=self.tokenizer.tokenizer,
            model=self.model,
            max_new_tokens=self.tokenizer.count_tokens_left(prompt),
        )

        if generator.task in self.SUPPORTED_TASKS:
            extra_params = {"pad_token_id": self.tokenizer.tokenizer.eos_token_id}

            response = generator(prompt, **(self.DEFAULT_PARAMS | extra_params | self.params))

            if len(response) == 1:
                return TextArtifact(value=response[0]["generated_text"].strip())
            else:
                raise Exception("completion with more than one choice is not supported yet")
        else:
            raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

    def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
        raise NotImplementedError("streaming is not supported")

DEFAULT_PARAMS = {'return_full_text': False, 'num_return_sequences': 1} class-attribute instance-attribute

SUPPORTED_TASKS = ['text2text-generation', 'text-generation'] class-attribute instance-attribute

model: str = field(kw_only=True) class-attribute instance-attribute

params: dict = field(factory=dict, kw_only=True) class-attribute instance-attribute

tokenizer: HuggingFaceTokenizer = field(default=Factory(lambda : HuggingFaceTokenizer(tokenizer=import_optional_dependency('transformers').AutoTokenizer.from_pretrained(self.model)), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/griptape/drivers/prompt/hugging_face_pipeline_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    prompt = self.prompt_stack_to_string(prompt_stack)
    pipeline = import_optional_dependency("transformers").pipeline

    generator = pipeline(
        tokenizer=self.tokenizer.tokenizer,
        model=self.model,
        max_new_tokens=self.tokenizer.count_tokens_left(prompt),
    )

    if generator.task in self.SUPPORTED_TASKS:
        extra_params = {"pad_token_id": self.tokenizer.tokenizer.eos_token_id}

        response = generator(prompt, **(self.DEFAULT_PARAMS | extra_params | self.params))

        if len(response) == 1:
            return TextArtifact(value=response[0]["generated_text"].strip())
        else:
            raise Exception("completion with more than one choice is not supported yet")
    else:
        raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}")

try_stream(_)

Source code in griptape/griptape/drivers/prompt/hugging_face_pipeline_prompt_driver.py
def try_stream(self, _: PromptStack) -> Iterator[TextArtifact]:
    raise NotImplementedError("streaming is not supported")
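
Construction sketch; "gpt2" is an arbitrary example model, not a recommendation from the source, and the first run downloads model weights locally:

from griptape.drivers import HuggingFacePipelinePromptDriver

driver = HuggingFacePipelinePromptDriver(model="gpt2", params={"max_new_tokens": 32})
# try_run builds a transformers pipeline on each call and runs the flattened
# prompt through it locally; no API token is required.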

LocalConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source code in griptape/griptape/drivers/memory/conversation/local_conversation_memory_driver.py
@define
class LocalConversationMemoryDriver(BaseConversationMemoryDriver):
    file_path: str = field(default="griptape_memory.json", kw_only=True)

    def store(self, memory: ConversationMemory) -> None:
        with open(self.file_path, "w") as file:
            file.write(memory.to_json())

    def load(self) -> Optional[ConversationMemory]:
        if not os.path.exists(self.file_path):
            return None
        with open(self.file_path, "r") as file:
            memory = ConversationMemory.from_json(file.read())

            memory.driver = self

            return memory

file_path: str = field(default='griptape_memory.json', kw_only=True) class-attribute instance-attribute

load()

Source code in griptape/griptape/drivers/memory/conversation/local_conversation_memory_driver.py
def load(self) -> Optional[ConversationMemory]:
    if not os.path.exists(self.file_path):
        return None
    with open(self.file_path, "r") as file:
        memory = ConversationMemory.from_json(file.read())

        memory.driver = self

        return memory

store(memory)

Source code in griptape/griptape/drivers/memory/conversation/local_conversation_memory_driver.py
def store(self, memory: ConversationMemory) -> None:
    with open(self.file_path, "w") as file:
        file.write(memory.to_json())
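
A round-trip sketch; the ConversationMemory import path and its driver keyword are assumptions about the surrounding API, not shown in the source above:

from griptape.drivers import LocalConversationMemoryDriver
from griptape.memory.structure import ConversationMemory

driver = LocalConversationMemoryDriver(file_path="griptape_memory.json")
driver.store(ConversationMemory(driver=driver))

restored = driver.load()  # returns None if the file does not exist yet
assert restored is not None and restored.driver is driver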

LocalVectorStoreDriver

Bases: BaseVectorStoreDriver

Source code in griptape/griptape/drivers/vector/local_vector_store_driver.py
@define
class LocalVectorStoreDriver(BaseVectorStoreDriver):
    entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict, kw_only=True)
    relatedness_fn: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)), kw_only=True)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        vector_id = vector_id if vector_id else utils.str_to_hash(str(vector))

        self.entries[self._namespaced_vector_id(vector_id, namespace)] = self.Entry(
            id=vector_id, vector=vector, meta=meta, namespace=namespace
        )

        return vector_id

    def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        return self.entries.get(self._namespaced_vector_id(vector_id, namespace), None)

    def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]

    def query(
        self,
        query: str,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.QueryResult]:
        query_embedding = self.embedding_driver.embed_string(query)

        if namespace:
            entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")}
        else:
            entries = self.entries

        entries_and_relatednesses = [
            (entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in entries.values()
        ]
        entries_and_relatednesses.sort(key=lambda x: x[1], reverse=True)

        result = [
            BaseVectorStoreDriver.QueryResult(id=er[0].id, vector=er[0].vector, score=er[1], meta=er[0].meta)
            for er in entries_and_relatednesses
        ][:count]

        if include_vectors:
            return result
        else:
            return [
                BaseVectorStoreDriver.QueryResult(id=r.id, vector=[], score=r.score, meta=r.meta, namespace=r.namespace)
                for r in result
            ]

    def _namespaced_vector_id(self, vector_id: str, namespace: Optional[str]):
        return vector_id if namespace is None else f"{namespace}-{vector_id}"

entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict, kw_only=True) class-attribute instance-attribute

relatedness_fn: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)), kw_only=True) class-attribute instance-attribute

load_entries(namespace=None)

Source code in griptape/griptape/drivers/vector/local_vector_store_driver.py
def load_entries(self, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]

load_entry(vector_id, namespace=None)

Source code in griptape/griptape/drivers/vector/local_vector_store_driver.py
def load_entry(self, vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    return self.entries.get(self._namespaced_vector_id(vector_id, namespace), None)

query(query, count=None, namespace=None, include_vectors=False, **kwargs)

Source code in griptape/griptape/drivers/vector/local_vector_store_driver.py
def query(
    self,
    query: str,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.QueryResult]:
    query_embedding = self.embedding_driver.embed_string(query)

    if namespace:
        entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")}
    else:
        entries = self.entries

    entries_and_relatednesses = [
        (entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in entries.values()
    ]
    entries_and_relatednesses.sort(key=lambda x: x[1], reverse=True)

    result = [
        BaseVectorStoreDriver.QueryResult(id=er[0].id, vector=er[0].vector, score=er[1], meta=er[0].meta)
        for er in entries_and_relatednesses
    ][:count]

    if include_vectors:
        return result
    else:
        return [
            BaseVectorStoreDriver.QueryResult(id=r.id, vector=[], score=r.score, meta=r.meta, namespace=r.namespace)
            for r in result
        ]

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Source code in griptape/griptape/drivers/vector/local_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    vector_id = vector_id if vector_id else utils.str_to_hash(str(vector))

    self.entries[self._namespaced_vector_id(vector_id, namespace)] = self.Entry(
        id=vector_id, vector=vector, meta=meta, namespace=namespace
    )

    return vector_id
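
The default relatedness_fn is cosine similarity; a standalone check of the same formula (assuming the numpy dot and norm used by the source):

from numpy import dot
from numpy.linalg import norm

def cosine(x, y):
    return dot(x, y) / (norm(x) * norm(y))

print(cosine([1.0, 0.0], [1.0, 0.0]))  # 1.0: identical direction
print(cosine([1.0, 0.0], [0.0, 1.0]))  # 0.0: orthogonal vectors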

MarqoVectorStoreDriver

Bases: BaseVectorStoreDriver

A Vector Store Driver for Marqo.

Attributes:

Name Type Description
api_key str

The API key for the Marqo API.

url str

The URL to the Marqo API.

mq Optional[Client]

An optional Marqo client. Defaults to a new client with the given URL and API key.

index str

The name of the index to use.

Source code in griptape/griptape/drivers/vector/marqo_vector_store_driver.py
@define
class MarqoVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for Marqo.

    Attributes:
        api_key: The API key for the Marqo API.
        url: The URL to the Marqo API.
        mq: An optional Marqo client. Defaults to a new client with the given URL and API key.
        index: The name of the index to use.
    """

    api_key: str = field(kw_only=True)
    url: str = field(kw_only=True)
    mq: Optional[marqo.Client] = field(
        default=Factory(
            lambda self: import_optional_dependency("marqo").Client(self.url, api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )
    index: str = field(kw_only=True)

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Upsert a text document into the Marqo index.

        Args:
            string: The string to be indexed.
            vector_id: The ID for the vector. If None, Marqo will generate an ID.
            namespace: An optional namespace for the document.
            meta: An optional dictionary of metadata for the document.

        Returns: