Skip to content

OpenAI chat prompt driver

OpenAiChatPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
base_url Optional[str]

An optional OpenAI API base URL.

api_key Optional[str]

An optional OpenAI API key. If not provided, the OPENAI_API_KEY environment variable will be used.

organization Optional[str]

An optional OpenAI organization. If not provided, the OPENAI_ORG_ID environment variable will be used.

client OpenAI

An openai.OpenAI client.

model str

An OpenAI model name.

tokenizer BaseTokenizer

An OpenAiTokenizer.

user str

A user id. Can be used to track requests by user.

response_format Optional[Literal['json_object']]

An optional OpenAI Chat Completion response format. Currently only supports json_object, which enables OpenAI's JSON mode.

seed Optional[int]

An optional OpenAI Chat Completion seed.

ignored_exception_types tuple[type[Exception], ...]

An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.

Source code in griptape/drivers/prompt/openai_chat_prompt_driver.py
@define
class OpenAiChatPromptDriver(BasePromptDriver):
    """Prompt driver that sends prompt stacks to the OpenAI Chat Completions API.

    Attributes:
        base_url: An optional OpenAI API base URL.
        api_key: An optional OpenAI API key. If not provided, the `OPENAI_API_KEY` environment variable will be used.
        organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used.
        client: An `openai.OpenAI` client.
        model: An OpenAI model name.
        tokenizer: An `OpenAiTokenizer`.
        user: A user id. Can be used to track requests by user.
        response_format: An optional OpenAI Chat Completion response format. Currently only supports `json_object`, which enables OpenAI's JSON mode.
        seed: An optional OpenAI Chat Completion seed.
        ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
    """

    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    client: openai.OpenAI = field(
        default=Factory(
            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
            takes_self=True,
        )
    )
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    user: str = field(default="", kw_only=True, metadata={"serializable": True})
    response_format: Optional[Literal["json_object"]] = field(
        default=None, kw_only=True, metadata={"serializable": True}
    )
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    ignored_exception_types: tuple[type[Exception], ...] = field(
        default=Factory(
            lambda: (
                openai.BadRequestError,
                openai.AuthenticationError,
                openai.PermissionDeniedError,
                openai.NotFoundError,
                openai.ConflictError,
                openai.UnprocessableEntityError,
            )
        ),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Send `prompt_stack` as a single non-streaming chat completion request.

        Returns:
            A `TextArtifact` holding the whitespace-stripped content of the only returned choice.

        Raises:
            Exception: If the response does not contain exactly one choice.
        """
        result = self.client.chat.completions.create(**self._base_params(prompt_stack))

        if len(result.choices) == 1:
            return TextArtifact(value=result.choices[0].message.content.strip())
        else:
            # NOTE(review): a response with zero choices also lands here, despite the message wording.
            raise Exception("Completion with more than one choice is not supported yet.")

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Stream a chat completion for `prompt_stack`, yielding one artifact per content delta.

        Raises:
            Exception: If a streamed chunk carries more than one choice.
        """
        result = self.client.chat.completions.create(**self._base_params(prompt_stack), stream=True)

        for chunk in result:
            # Some chunks legitimately carry no choices (e.g. the final usage chunk sent when
            # `stream_options.include_usage` is enabled); skip them rather than misreporting
            # them as a multi-choice completion.
            if not chunk.choices:
                continue
            if len(chunk.choices) > 1:
                raise Exception("Completion with more than one choice is not supported yet.")

            delta = chunk.choices[0].delta

            # Deltas whose content is None produce no artifact.
            if delta.content is not None:
                yield TextArtifact(value=delta.content)

    def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict:
        """Convert a prompt stack input into a Chat Completions message dict.

        System and assistant inputs keep their role; every other input is sent with the `user` role.
        """
        content = prompt_input.content

        if prompt_input.is_system():
            return {"role": "system", "content": content}
        elif prompt_input.is_assistant():
            return {"role": "assistant", "content": content}
        else:
            return {"role": "user", "content": content}

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments passed to `client.chat.completions.create`.

        The caller's `prompt_stack` is not modified. In JSON mode the extra system instruction is
        appended only to the outgoing message list, so repeated calls with the same stack (e.g. a
        retried request) do not accumulate duplicate system inputs.
        """
        params = {
            "model": self.model,
            "temperature": self.temperature,
            "stop": self.tokenizer.stop_sequences,
            "user": self.user,
            "seed": self.seed,
        }

        messages = [self._prompt_stack_input_to_message(prompt_input) for prompt_input in prompt_stack.inputs]

        if self.response_format == "json_object":
            params["response_format"] = {"type": "json_object"}
            # JSON mode still requires a system message instructing the LLM to output JSON.
            # It is added after the converted prompt inputs, without touching `prompt_stack`.
            messages.append({"role": "system", "content": "Provide your response as a valid JSON object."})

        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens

        params["messages"] = messages

        return params

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.OpenAI = field(default=Factory(lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), takes_self=True)) class-attribute instance-attribute

ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (openai.BadRequestError, openai.AuthenticationError, openai.PermissionDeniedError, openai.NotFoundError, openai.ConflictError, openai.UnprocessableEntityError)), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

organization: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

response_format: Optional[Literal['json_object']] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

seed: Optional[int] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: BaseTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

user: str = field(default='', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Run a blocking chat completion and return its single choice as text.

    Raises:
        Exception: When the response does not contain exactly one choice.
    """
    completion = self.client.chat.completions.create(**self._base_params(prompt_stack))
    choices = completion.choices

    # Guard clause: anything other than exactly one choice is rejected.
    if len(choices) != 1:
        raise Exception("Completion with more than one choice is not supported yet.")

    message = choices[0].message
    return TextArtifact(value=message.content.strip())

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Stream a chat completion for `prompt_stack`, yielding one artifact per content delta.

    Raises:
        Exception: If a streamed chunk carries more than one choice.
    """
    result = self.client.chat.completions.create(**self._base_params(prompt_stack), stream=True)

    for chunk in result:
        # Some chunks legitimately carry no choices (e.g. the final usage chunk sent when
        # `stream_options.include_usage` is enabled); skip them rather than misreporting
        # them as a multi-choice completion.
        if not chunk.choices:
            continue
        if len(chunk.choices) > 1:
            raise Exception("Completion with more than one choice is not supported yet.")

        delta = chunk.choices[0].delta

        # Deltas whose content is None produce no artifact.
        if delta.content is not None:
            yield TextArtifact(value=delta.content)