Skip to content

Base multi model prompt driver

BaseMultiModelPromptDriver

Bases: BasePromptDriver, ABC

Prompt Driver for platforms, such as Amazon SageMaker and Amazon Bedrock, that host many LLMs.

Instances of this Prompt Driver require a Prompt Model Driver, which converts the prompt stack into model input and parameters and processes the model output.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `model` | | Name of the model to use. |
| `tokenizer` | `Optional[BaseTokenizer]` | Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver. |
| `prompt_model_driver` | `BasePromptModelDriver` | Prompt Model Driver to use. |

Source code in griptape/drivers/prompt/base_multi_model_prompt_driver.py
@define
class BaseMultiModelPromptDriver(BasePromptDriver, ABC):
    """Prompt Driver for platforms, such as Amazon SageMaker and Amazon Bedrock, that host many LLMs.

    Instances of this Prompt Driver require a Prompt Model Driver, which converts the prompt stack
    into model input and parameters and processes the model output.

    Attributes:
        model: Name of the model to use.
        tokenizer: Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver.
        prompt_model_driver: Prompt Model Driver to use.
        stream: Streaming flag. Setting it to True requires a Prompt Model Driver whose
            ``supports_streaming`` is truthy; otherwise ``validate_stream`` raises ``ValueError``.
    """

    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    prompt_model_driver: BasePromptModelDriver = field(kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})

    @stream.validator  # pyright: ignore
    def validate_stream(self, _, stream):
        """Reject a truthy ``stream`` when the Prompt Model Driver cannot stream.

        Args:
            _: The attribute being validated (unused).
            stream: The proposed value of ``stream``.

        Raises:
            ValueError: If ``stream`` is truthy and ``prompt_model_driver.supports_streaming`` is falsy.
        """
        # Guard clause: the model driver is only consulted when streaming was requested.
        if not stream or self.prompt_model_driver.supports_streaming:
            return
        raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming")

    def __attrs_post_init__(self) -> None:
        """Wire the Prompt Model Driver to this driver and default the tokenizer.

        Sets ``prompt_model_driver.prompt_driver`` to ``self``. Then, if no tokenizer was
        supplied (``tokenizer is None``), adopts ``prompt_model_driver.tokenizer``.
        """
        # Back-reference first: presumably some model drivers derive their tokenizer from the
        # owning prompt driver, so it must be set before the tokenizer is read — verify before
        # reordering.
        self.prompt_model_driver.prompt_driver = self

        # Explicit None check: only a missing tokenizer falls back to the model driver's,
        # never a caller-supplied tokenizer that merely happens to be falsy.
        if self.tokenizer is None:
            self.tokenizer = self.prompt_model_driver.tokenizer

`prompt_model_driver: BasePromptModelDriver = field(kw_only=True, metadata={'serializable': True})` (class attribute, instance attribute)

`stream: bool = field(default=False, kw_only=True, metadata={'serializable': True})` (class attribute, instance attribute)

`tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)` (class attribute, instance attribute)

`__attrs_post_init__()`

Source code in griptape/drivers/prompt/base_multi_model_prompt_driver.py
def __attrs_post_init__(self) -> None:
    """Wire the Prompt Model Driver to this driver and default the tokenizer.

    Sets ``prompt_model_driver.prompt_driver`` to ``self``. Then, if no tokenizer was
    supplied (``tokenizer is None``), adopts ``prompt_model_driver.tokenizer``.
    """
    # Back-reference first: presumably some model drivers derive their tokenizer from the
    # owning prompt driver, so it must be set before the tokenizer is read — verify before
    # reordering.
    self.prompt_model_driver.prompt_driver = self

    # Explicit None check: only a missing tokenizer falls back to the model driver's,
    # never a caller-supplied tokenizer that merely happens to be falsy.
    if self.tokenizer is None:
        self.tokenizer = self.prompt_model_driver.tokenizer

validate_stream(_, stream)

Source code in griptape/drivers/prompt/base_multi_model_prompt_driver.py
@stream.validator  # pyright: ignore
def validate_stream(self, _, stream):
    """Raise ``ValueError`` if streaming is requested but the Prompt Model Driver cannot stream.

    Args:
        _: The attribute being validated (unused).
        stream: The proposed value of ``stream``.
    """
    # Guard clause: nothing to check unless streaming was requested.
    if not stream or self.prompt_model_driver.supports_streaming:
        return
    raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming")