Skip to content

Cohere prompt driver

CoherePromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
api_key str

Cohere API key.

model str

Cohere model name.

client Client

Custom cohere.Client.

tokenizer CohereTokenizer

Custom CohereTokenizer.

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
@define
class CoherePromptDriver(BasePromptDriver):
    """Prompt driver that talks to the Cohere generate API.

    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
        tokenizer: Custom `CohereTokenizer`.
    """

    api_key: str = field(kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    # Lazily builds a cohere.Client from api_key unless a custom client is injected.
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
        kw_only=True,
    )
    tokenizer: CohereTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Run a non-streaming generation and return the single completion.

        Raises:
            Exception: If the response has no generations, or more than one.
        """
        result = self.client.generate(**self._base_params(prompt_stack))

        # Guard clauses replace the original nested if/else; behavior and
        # exception messages are unchanged.
        if not result.generations:
            raise Exception("model response is empty")
        if len(result.generations) > 1:
            raise Exception("completion with more than one choice is not supported yet")

        return TextArtifact(value=result.generations[0].text.strip())

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Stream a generation, yielding one TextArtifact per streamed chunk."""
        result = self.client.generate(**self._base_params(prompt_stack), stream=True)

        for chunk in result:
            yield TextArtifact(value=chunk.text)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments shared by run and stream calls."""
        # Render the prompt stack once and reuse it: the original rendered it
        # twice (once for "prompt", once for max_output_tokens).
        prompt = self.prompt_stack_to_string(prompt_stack)
        return {
            "prompt": prompt,
            "model": self.model,
            "temperature": self.temperature,
            "end_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_output_tokens(prompt),
        }

api_key: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: Client = field(default=Factory(lambda self: import_optional_dependency('cohere').Client(self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: CohereTokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Execute a single (non-streaming) Cohere generation.

    Raises:
        Exception: When the model returns no generations, or more than one.
    """
    response = self.client.generate(**self._base_params(prompt_stack))
    generations = response.generations

    if not generations:
        raise Exception("model response is empty")

    if len(generations) == 1:
        return TextArtifact(value=generations[0].text.strip())

    raise Exception("completion with more than one choice is not supported yet")

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Stream a Cohere generation, yielding a TextArtifact for each chunk."""
    stream = self.client.generate(**self._base_params(prompt_stack), stream=True)

    yield from (TextArtifact(value=chunk.text) for chunk in stream)