Base prompt driver

BasePromptDriver

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Base class for Prompt Drivers.

Attributes:

temperature (float): The temperature to use for the completion.

max_tokens (Optional[int]): The maximum number of tokens to generate. If not specified, the value is determined automatically based on the tokenizer.

structure (Optional[Structure]): An optional Structure to publish events to.

prompt_stack_to_string (str): A function that converts a PromptStack to a string.

ignored_exception_types (tuple[type[Exception], ...]): A tuple of exception types to ignore.

model (str): The model name.

tokenizer (BaseTokenizer): An instance of BaseTokenizer to use when calculating tokens.

stream (bool): Whether to stream the completion or not. CompletionChunkEvents will be published to the Structure if one is provided.
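For orientation, here is a minimal configuration sketch. It assumes OpenAiChatPromptDriver as one concrete subclass; any BasePromptDriver implementation accepts the same inherited keyword-only attributes.

from griptape.drivers import OpenAiChatPromptDriver

# Illustrative configuration of the inherited base-class attributes.
driver = OpenAiChatPromptDriver(
    model="gpt-4",        # the model name (required)
    temperature=0.1,      # sampling temperature (default 0.1)
    max_tokens=256,       # cap on generated tokens; None defers to the tokenizer
    stream=True,          # publish CompletionChunkEvents to the Structure, if any
)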

Source code in griptape/drivers/prompt/base_prompt_driver.py
@define
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base class for Prompt Drivers.

    Attributes:
        temperature: The temperature to use for the completion.
        max_tokens: The maximum number of tokens to generate. If not specified, the value is determined automatically based on the tokenizer.
        structure: An optional `Structure` to publish events to.
        prompt_stack_to_string: A function that converts a `PromptStack` to a string.
        ignored_exception_types: A tuple of exception types to ignore.
        model: The model name.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
        stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
    """

    temperature: float = field(default=0.1, kw_only=True, metadata={"serializable": True})
    max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    structure: Optional[Structure] = field(default=None, kw_only=True)
    ignored_exception_types: tuple[type[Exception], ...] = field(
        default=Factory(lambda: (ImportError, ValueError)), kw_only=True
    )
    model: str = field(metadata={"serializable": True})
    tokenizer: BaseTokenizer
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})

    def before_run(self, prompt_stack: PromptStack) -> None:
        if self.structure:
            self.structure.publish_event(
                StartPromptEvent(
                    model=self.model,
                    token_count=self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack)),
                    prompt_stack=prompt_stack,
                    prompt=self.prompt_stack_to_string(prompt_stack),
                )
            )

    def after_run(self, result: TextArtifact) -> None:
        if self.structure:
            self.structure.publish_event(
                FinishPromptEvent(
                    model=self.model, result=result.value, token_count=self.tokenizer.count_tokens(result.value)
                )
            )

    def run(self, prompt_stack: PromptStack) -> TextArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompt_stack)

                if self.stream:
                    tokens = []
                    completion_chunks = self.try_stream(prompt_stack)
                    for chunk in completion_chunks:
                        # Guard: structure may be None even when streaming.
                        if self.structure:
                            self.structure.publish_event(CompletionChunkEvent(token=chunk.value))
                        tokens.append(chunk.value)
                    result = TextArtifact(value="".join(tokens).strip())
                else:
                    result = self.try_run(prompt_stack)
                    result.value = result.value.strip()

                self.after_run(result)

                return result
        else:
            raise Exception("prompt driver failed after all retry attempts")

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        """Converts a Prompt Stack to a string for token counting or model input.
        This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens.

        Args:
            prompt_stack: The Prompt Stack to convert to a string.

        Returns:
            A single string representation of the Prompt Stack.
        """
        prompt_lines = []

        for i in prompt_stack.inputs:
            if i.is_user():
                prompt_lines.append(f"User: {i.content}")
            elif i.is_assistant():
                prompt_lines.append(f"Assistant: {i.content}")
            else:
                prompt_lines.append(i.content)

        prompt_lines.append("Assistant:")

        return "\n\n".join(prompt_lines)

    @abstractmethod
    def try_run(self, prompt_stack: PromptStack) -> TextArtifact: ...

    @abstractmethod
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: ...
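To illustrate the contract, a hypothetical subclass only needs to implement try_run and try_stream; retries, event publishing, and whitespace stripping all come from the base class. EchoPromptDriver below is an invented example, not part of griptape:

from collections.abc import Iterator

from attrs import define, field
from griptape.artifacts import TextArtifact
from griptape.drivers import BasePromptDriver
from griptape.utils import PromptStack


@define
class EchoPromptDriver(BasePromptDriver):
    """Hypothetical driver that echoes the rendered prompt back."""

    model: str = field(default="echo", kw_only=True)

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        # Return the whole rendered prompt as the "completion".
        return TextArtifact(value=self.prompt_stack_to_string(prompt_stack))

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        # Yield the prompt word by word to mimic streamed chunks.
        for word in self.prompt_stack_to_string(prompt_stack).split():
            yield TextArtifact(value=word + " ")

Note that tokenizer has no default, so even this toy driver must be constructed with a BaseTokenizer instance.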

ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (ImportError, ValueError)), kw_only=True) class-attribute instance-attribute

max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(metadata={'serializable': True}) class-attribute instance-attribute

stream: bool = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

structure: Optional[Structure] = field(default=None, kw_only=True) class-attribute instance-attribute

temperature: float = field(default=0.1, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: BaseTokenizer instance-attribute

after_run(result)

Source code in griptape/drivers/prompt/base_prompt_driver.py
def after_run(self, result: TextArtifact) -> None:
    if self.structure:
        self.structure.publish_event(
            FinishPromptEvent(
                model=self.model, result=result.value, token_count=self.tokenizer.count_tokens(result.value)
            )
        )

before_run(prompt_stack)

Source code in griptape/drivers/prompt/base_prompt_driver.py
def before_run(self, prompt_stack: PromptStack) -> None:
    if self.structure:
        self.structure.publish_event(
            StartPromptEvent(
                model=self.model,
                token_count=self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack)),
                prompt_stack=prompt_stack,
                prompt=self.prompt_stack_to_string(prompt_stack),
            )
        )
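before_run and after_run are where StartPromptEvent and FinishPromptEvent reach the attached Structure. As a sketch, assuming a griptape version that exposes an EventListener API, a Structure such as an Agent could observe them like this:

from griptape.events import EventListener, FinishPromptEvent, StartPromptEvent
from griptape.structures import Agent

# Illustrative wiring; the exact listener API varies across griptape versions.
# Both events carry token_count, as shown in the source above.
agent = Agent(
    event_listeners=[
        EventListener(
            handler=lambda event: print(type(event).__name__, event.token_count),
            event_types=[StartPromptEvent, FinishPromptEvent],
        )
    ]
)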

prompt_stack_to_string(prompt_stack)

Converts a Prompt Stack to a string for token counting or model input. This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens.

Parameters:

    prompt_stack (PromptStack): The Prompt Stack to convert to a string. Required.

Returns:

    str: A single string representation of the Prompt Stack.

Source code in griptape/drivers/prompt/base_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    """Converts a Prompt Stack to a string for token counting or model input.
    This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens.

    Args:
        prompt_stack: The Prompt Stack to convert to a string.

    Returns:
        A single string representation of the Prompt Stack.
    """
    prompt_lines = []

    for i in prompt_stack.inputs:
        if i.is_user():
            prompt_lines.append(f"User: {i.content}")
        elif i.is_assistant():
            prompt_lines.append(f"Assistant: {i.content}")
        else:
            prompt_lines.append(i.content)

    prompt_lines.append("Assistant:")

    return "\n\n".join(prompt_lines)

run(prompt_stack)

Source code in griptape/drivers/prompt/base_prompt_driver.py
def run(self, prompt_stack: PromptStack) -> TextArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompt_stack)

            if self.stream:
                tokens = []
                completion_chunks = self.try_stream(prompt_stack)
                for chunk in completion_chunks:
                    # Guard: structure may be None even when streaming.
                    if self.structure:
                        self.structure.publish_event(CompletionChunkEvent(token=chunk.value))
                    tokens.append(chunk.value)
                result = TextArtifact(value="".join(tokens).strip())
            else:
                result = self.try_run(prompt_stack)
                result.value = result.value.strip()

            self.after_run(result)

            return result
    else:
        raise Exception("prompt driver failed after all retry attempts")
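Putting it together, run wraps a single attempt in the retry loop from ExponentialBackoffMixin and strips the result. A usage sketch, reusing the hypothetical EchoPromptDriver and the stack built above, with an assumed tokenizer class:

from griptape.tokenizers import OpenAiTokenizer

driver = EchoPromptDriver(tokenizer=OpenAiTokenizer(model="gpt-4"))
result = driver.run(stack)  # retries on transient failures, then strips whitespace
print(result.value)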

try_run(prompt_stack) abstractmethod

Source code in griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_run(self, prompt_stack: PromptStack) -> TextArtifact: ...

try_stream(prompt_stack) abstractmethod

Source code in griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: ...