
Cohere prompt driver

CoherePromptDriver

Bases: BasePromptDriver

Attributes:

Name     Type    Description
api_key  str     Cohere API key.
model    str     Cohere model name.
client   Client  Custom cohere.Client.
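
A minimal construction sketch, assuming the driver is importable from griptape.drivers and that the COHERE_API_KEY environment variable holds a valid key; the model name "command" is only an example:

import os

from griptape.drivers import CoherePromptDriver

# api_key and model are the only fields that must be supplied; client and
# tokenizer fall back to the Factory defaults shown in the source below.
driver = CoherePromptDriver(
    api_key=os.environ["COHERE_API_KEY"],  # not serialized (serializable=False)
    model="command",                       # example model name
)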

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
@define
class CoherePromptDriver(BasePromptDriver):
    """
    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
    """

    api_key: str = field(kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
        kw_only=True,
    )
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        result = self.client.chat(**self._base_params(prompt_stack))

        return TextArtifact(value=result.text)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        result = self.client.chat_stream(**self._base_params(prompt_stack))

        for event in result:
            if event.event_type == "text-generation":
                yield TextArtifact(value=event.text)

    def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict:
        if prompt_input.is_system():
            return {"role": "SYSTEM", "text": prompt_input.content}
        elif prompt_input.is_user():
            return {"role": "USER", "text": prompt_input.content}
        else:
            return {"role": "ASSISTANT", "text": prompt_input.content}

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        user_message = prompt_stack.inputs[-1].content

        history_messages = [self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs[:-1]]

        return {
            "message": user_message,
            "chat_history": history_messages,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_tokens,
        }
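
To illustrate how _base_params splits a PromptStack into Cohere's message and chat_history arguments, here is a rough sketch; it assumes PromptStack is importable from griptape.utils and exposes add_system_input / add_user_input helpers:

import os

from griptape.drivers import CoherePromptDriver
from griptape.utils import PromptStack

driver = CoherePromptDriver(api_key=os.environ["COHERE_API_KEY"], model="command")

stack = PromptStack()
stack.add_system_input("You are a terse assistant.")
stack.add_user_input("What is the capital of France?")

params = driver._base_params(stack)
# params["message"]      -> "What is the capital of France?"  (the last input)
# params["chat_history"] -> [{"role": "SYSTEM", "text": "You are a terse assistant."}]
# plus temperature, stop_sequences, and max_tokens from the driver and tokenizer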

api_key: str = field(kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

client: Client = field(default=Factory(lambda self: import_optional_dependency('cohere').Client(self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    result = self.client.chat(**self._base_params(prompt_stack))

    return TextArtifact(value=result.text)
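
A usage sketch for a single, non-streaming call, under the same assumptions about import paths and PromptStack helpers as above; try_run returns a TextArtifact whose value holds the completion text:

import os

from griptape.drivers import CoherePromptDriver
from griptape.utils import PromptStack

driver = CoherePromptDriver(api_key=os.environ["COHERE_API_KEY"], model="command")

stack = PromptStack()
stack.add_user_input("Give me one fun fact about octopuses.")

artifact = driver.try_run(stack)  # one call to Cohere's chat endpoint
print(artifact.value)             # completion text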

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    result = self.client.chat_stream(**self._base_params(prompt_stack))

    for event in result:
        if event.event_type == "text-generation":
            yield TextArtifact(value=event.text)
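
A streaming sketch under the same assumptions; try_stream yields one TextArtifact per "text-generation" event, so joining the chunks reconstructs the full completion:

import os

from griptape.drivers import CoherePromptDriver
from griptape.utils import PromptStack

driver = CoherePromptDriver(api_key=os.environ["COHERE_API_KEY"], model="command")

stack = PromptStack()
stack.add_user_input("Tell me a short story about a lighthouse.")

chunks = [artifact.value for artifact in driver.try_stream(stack)]  # one chunk per event
print("".join(chunks))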