
cohere_prompt_driver

CoherePromptDriver

Bases: BasePromptDriver

Cohere Prompt Driver.

Attributes:

Name     Type    Description
api_key  str     Cohere API key.
model    str     Cohere model name.
client   Client  Custom cohere.Client.
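
A driver is typically constructed with an API key and a model name and then handed to a Griptape structure or task. A minimal sketch, assuming the COHERE_API_KEY environment variable is set and using "command-r" as an illustrative model name (the import path follows the pattern used by Griptape's other prompt drivers):

import os

from griptape.drivers import CoherePromptDriver

# Illustrative configuration; any Cohere chat model name can be used.
driver = CoherePromptDriver(
    api_key=os.environ["COHERE_API_KEY"],
    model="command-r",
)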

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
@define(kw_only=True)
class CoherePromptDriver(BasePromptDriver):
    """Cohere Prompt Driver.

    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
    """

    api_key: str = field(metadata={"serializable": False})
    model: str = field(metadata={"serializable": True})
    force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
    )

    @lazy_property()
    def client(self) -> Client:
        return import_optional_dependency("cohere").Client(self.api_key)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        result = self.client.chat(**self._base_params(prompt_stack))
        usage = result.meta.tokens

        return Message(
            content=self.__to_prompt_stack_message_content(result),
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        result = self.client.chat_stream(**self._base_params(prompt_stack))

        for event in result:
            if event.event_type == "stream-end":
                usage = event.response.meta.tokens

                yield DeltaMessage(
                    usage=DeltaMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
                )
            elif event.event_type == "text-generation" or event.event_type == "tool-calls-chunk":
                yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        # Current message
        last_input = prompt_stack.messages[-1]
        user_message = ""
        tool_results = []
        if last_input is not None:
            message = self.__to_cohere_messages([prompt_stack.messages[-1]])

            if "message" in message[0]:
                user_message = message[0]["message"]
            if "tool_results" in message[0]:
                tool_results = message[0]["tool_results"]

        # History messages
        history_messages = self.__to_cohere_messages(
            [message for message in prompt_stack.messages[:-1] if not message.is_system()],
        )

        # System message (preamble)
        system_messages = prompt_stack.system_messages
        preamble = system_messages[0].to_text() if system_messages else None

        return {
            "message": user_message,
            "chat_history": history_messages,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_tokens,
            **({"tool_results": tool_results} if tool_results else {}),
            **(
                {"tools": self.__to_cohere_tools(prompt_stack.tools), "force_single_step": self.force_single_step}
                if prompt_stack.tools and self.use_native_tools
                else {}
            ),
            **({"preamble": preamble} if preamble else {}),
        }

    def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
        cohere_messages = []

        for message in messages:
            cohere_message: dict = {"role": self.__to_cohere_role(message), "message": message.to_text()}

            if message.has_any_content_type(ActionResultMessageContent):
                cohere_message["tool_results"] = [
                    self.__to_cohere_message_content(action_result)
                    for action_result in message.get_content_type(ActionResultMessageContent)
                ]
            else:
                if message.has_any_content_type(ActionCallMessageContent):
                    cohere_message["tool_calls"] = [
                        self.__to_cohere_message_content(action_call)
                        for action_call in message.get_content_type(ActionCallMessageContent)
                    ]

            cohere_messages.append(cohere_message)

        return cohere_messages

    def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict:
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return {"name": action.to_native_tool_name(), "parameters": action.input}
        elif isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            if isinstance(artifact, ListArtifact):
                message_content = [{"text": artifact.to_text()} for artifact in artifact.value]
            else:
                message_content = [{"text": artifact.to_text()}]

            return {
                "call": {"name": content.action.to_native_tool_name(), "parameters": content.action.input},
                "outputs": message_content,
            }
        elif isinstance(content, ActionResultMessageContent):
            return {"text": content.artifact.to_text()}
        else:
            raise ValueError(f"Unsupported content type: {type(content)}")

    def __to_cohere_role(self, message: Message) -> str:
        if message.is_system():
            return "SYSTEM"
        elif message.is_assistant():
            return "CHATBOT"
        else:
            if message.has_any_content_type(ActionResultMessageContent):
                return "TOOL"
            else:
                return "USER"

    def __to_cohere_tools(self, tools: list[BaseTool]) -> list[dict]:
        tool_definitions = []

        for tool in tools:
            for activity in tool.activities():
                activity_schema = tool.activity_schema(activity)
                if activity_schema is not None:
                    properties_values = activity_schema.json_schema("Parameters Schema")["properties"]["values"]

                    properties = properties_values["properties"]
                else:
                    properties_values = {}
                    properties = {}

                tool_definitions.append(
                    {
                        "name": tool.to_native_tool_name(activity),
                        "description": tool.activity_description(activity),
                        "parameter_definitions": {
                            property_name: {
                                "type": property_value["type"],
                                "required": property_name in properties_values["required"],
                                **(
                                    {"description": property_value["description"]}
                                    if "description" in property_value
                                    else {}
                                ),
                            }
                            for property_name, property_value in properties.items()
                        },
                    },
                )

        return tool_definitions

    def __to_prompt_stack_message_content(self, response: NonStreamedChatResponse) -> list[BaseMessageContent]:
        content = []
        if response.text:
            content.append(TextMessageContent(TextArtifact(response.text)))
        if response.tool_calls is not None:
            content.extend(
                [
                    ActionCallMessageContent(
                        ActionArtifact(
                            ToolAction(
                                tag=tool_call.name,
                                name=ToolAction.from_native_tool_name(tool_call.name)[0],
                                path=ToolAction.from_native_tool_name(tool_call.name)[1],
                                input=tool_call.parameters,
                            ),
                        ),
                    )
                    for tool_call in response.tool_calls
                ],
            )

        return content

    def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessageContent:
        if event.event_type == "text-generation":
            return TextDeltaMessageContent(event.text, index=0)
        elif event.event_type == "tool-calls-chunk":
            if event.tool_call_delta is not None:
                tool_call_delta = event.tool_call_delta
                if tool_call_delta.name is not None:
                    name, path = ToolAction.from_native_tool_name(tool_call_delta.name)

                    return ActionCallDeltaMessageContent(tag=tool_call_delta.name, name=name, path=path)
                else:
                    return ActionCallDeltaMessageContent(partial_input=tool_call_delta.parameters)

            else:
                return TextDeltaMessageContent(event.text)
        else:
            raise ValueError(f"Unsupported event type: {event.event_type}")

api_key: str = field(metadata={'serializable': False})

force_single_step: bool = field(default=False, kw_only=True, metadata={'serializable': True})

model: str = field(metadata={'serializable': True})

tokenizer: BaseTokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True))

use_native_tools: bool = field(default=True, kw_only=True, metadata={'serializable': True})

__to_cohere_message_content(content)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict:
    if isinstance(content, ActionCallMessageContent):
        action = content.artifact.value

        return {"name": action.to_native_tool_name(), "parameters": action.input}
    elif isinstance(content, ActionResultMessageContent):
        artifact = content.artifact

        if isinstance(artifact, ListArtifact):
            message_content = [{"text": artifact.to_text()} for artifact in artifact.value]
        else:
            message_content = [{"text": artifact.to_text()}]

        return {
            "call": {"name": content.action.to_native_tool_name(), "parameters": content.action.input},
            "outputs": message_content,
        }
    elif isinstance(content, ActionResultMessageContent):
        return {"text": content.artifact.to_text()}
    else:
        raise ValueError(f"Unsupported content type: {type(content)}")
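
For an ActionResultMessageContent, the returned dict follows Cohere's tool-result shape, pairing the originating call with its outputs. A hypothetical example (tool name, parameters, and output are illustrative):

{
    "call": {"name": "Calculator_calculate", "parameters": {"expression": "2 + 2"}},
    "outputs": [{"text": "4"}],
}

An ActionCallMessageContent instead maps to a flat {"name": ..., "parameters": ...} dict, as in the first branch above.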

__to_cohere_messages(messages)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
    cohere_messages = []

    for message in messages:
        cohere_message: dict = {"role": self.__to_cohere_role(message), "message": message.to_text()}

        if message.has_any_content_type(ActionResultMessageContent):
            cohere_message["tool_results"] = [
                self.__to_cohere_message_content(action_result)
                for action_result in message.get_content_type(ActionResultMessageContent)
            ]
        else:
            if message.has_any_content_type(ActionCallMessageContent):
                cohere_message["tool_calls"] = [
                    self.__to_cohere_message_content(action_call)
                    for action_call in message.get_content_type(ActionCallMessageContent)
                ]

        cohere_messages.append(cohere_message)

    return cohere_messages
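
A plain user message maps to a simple role/message pair; the tool_results and tool_calls keys are only added when the corresponding content types are present. A hypothetical entry for illustration:

{"role": "USER", "message": "What is the capital of France?"}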

__to_cohere_role(message)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_role(self, message: Message) -> str:
    if message.is_system():
        return "SYSTEM"
    elif message.is_assistant():
        return "CHATBOT"
    else:
        if message.has_any_content_type(ActionResultMessageContent):
            return "TOOL"
        else:
            return "USER"

__to_cohere_tools(tools)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_tools(self, tools: list[BaseTool]) -> list[dict]:
    tool_definitions = []

    for tool in tools:
        for activity in tool.activities():
            activity_schema = tool.activity_schema(activity)
            if activity_schema is not None:
                properties_values = activity_schema.json_schema("Parameters Schema")["properties"]["values"]

                properties = properties_values["properties"]
            else:
                properties_values = {}
                properties = {}

            tool_definitions.append(
                {
                    "name": tool.to_native_tool_name(activity),
                    "description": tool.activity_description(activity),
                    "parameter_definitions": {
                        property_name: {
                            "type": property_value["type"],
                            "required": property_name in properties_values["required"],
                            **(
                                {"description": property_value["description"]}
                                if "description" in property_value
                                else {}
                            ),
                        }
                        for property_name, property_value in properties.items()
                    },
                },
            )

    return tool_definitions
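
Each activity of each tool becomes one definition in the returned list. A hypothetical entry for a calculator-style tool with a single required string parameter (all names and descriptions are illustrative):

{
    "name": "Calculator_calculate",
    "description": "Can be used to evaluate math expressions.",
    "parameter_definitions": {
        "expression": {
            "type": "string",
            "required": True,
            "description": "A math expression to evaluate.",
        },
    },
}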

__to_prompt_stack_delta_message_content(event)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessageContent:
    if event.event_type == "text-generation":
        return TextDeltaMessageContent(event.text, index=0)
    elif event.event_type == "tool-calls-chunk":
        if event.tool_call_delta is not None:
            tool_call_delta = event.tool_call_delta
            if tool_call_delta.name is not None:
                name, path = ToolAction.from_native_tool_name(tool_call_delta.name)

                return ActionCallDeltaMessageContent(tag=tool_call_delta.name, name=name, path=path)
            else:
                return ActionCallDeltaMessageContent(partial_input=tool_call_delta.parameters)

        else:
            return TextDeltaMessageContent(event.text)
    else:
        raise ValueError(f"Unsupported event type: {event.event_type}")

__to_prompt_stack_message_content(response)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_prompt_stack_message_content(self, response: NonStreamedChatResponse) -> list[BaseMessageContent]:
    content = []
    if response.text:
        content.append(TextMessageContent(TextArtifact(response.text)))
    if response.tool_calls is not None:
        content.extend(
            [
                ActionCallMessageContent(
                    ActionArtifact(
                        ToolAction(
                            tag=tool_call.name,
                            name=ToolAction.from_native_tool_name(tool_call.name)[0],
                            path=ToolAction.from_native_tool_name(tool_call.name)[1],
                            input=tool_call.parameters,
                        ),
                    ),
                )
                for tool_call in response.tool_calls
            ],
        )

    return content

client()

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
@lazy_property()
def client(self) -> Client:
    return import_optional_dependency("cohere").Client(self.api_key)
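
The client is built lazily from api_key on first access, so the optional cohere dependency is only imported when the driver is actually used. A preconfigured cohere.Client can also be passed via the client alias, in which case the lazy construction is skipped; a minimal sketch, assuming the cohere package is installed:

import cohere

from griptape.drivers import CoherePromptDriver

driver = CoherePromptDriver(
    api_key="unused",  # still required by the field definition, but not used when client is supplied
    model="command-r",
    client=cohere.Client("YOUR_COHERE_API_KEY"),
)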

try_run(prompt_stack)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    result = self.client.chat(**self._base_params(prompt_stack))
    usage = result.meta.tokens

    return Message(
        content=self.__to_prompt_stack_message_content(result),
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
    )
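
try_run issues a single blocking chat call and wraps the response text, tool calls, and token usage in a Message. A minimal sketch, reusing the driver from the earlier example and assuming Griptape's PromptStack API:

from griptape.common import PromptStack

prompt_stack = PromptStack()
prompt_stack.add_user_message("What is the capital of France?")

message = driver.try_run(prompt_stack)
print(message.to_text())           # assistant text from the Cohere response
print(message.usage.input_tokens)  # token counts taken from result.meta.tokens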

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/cohere_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    result = self.client.chat_stream(**self._base_params(prompt_stack))

    for event in result:
        if event.event_type == "stream-end":
            usage = event.response.meta.tokens

            yield DeltaMessage(
                usage=DeltaMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
            )
        elif event.event_type == "text-generation" or event.event_type == "tool-calls-chunk":
            yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
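
try_stream yields DeltaMessage objects as events arrive: text and tool-call chunks carry content, and the final stream-end event carries token usage. A minimal sketch, reusing the driver and prompt_stack from the try_run example; the TextDeltaMessageContent import path is assumed to be griptape.common:

from griptape.common import TextDeltaMessageContent

for delta in driver.try_stream(prompt_stack):
    if isinstance(delta.content, TextDeltaMessageContent):
        print(delta.content.text, end="", flush=True)
    elif delta.usage is not None and delta.usage.output_tokens is not None:
        # populated by the stream-end event
        print(f"\n[input={delta.usage.input_tokens}, output={delta.usage.output_tokens}]")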