@define
class OpenAiChatPromptDriver(BasePromptDriver):
"""OpenAI Chat Prompt Driver.
Attributes:
base_url: An optional OpenAi API URL.
api_key: An optional OpenAi API key. If not provided, the `OPENAI_API_KEY` environment variable will be used.
organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used.
client: An `openai.OpenAI` client.
model: An OpenAI model name.
tokenizer: An `OpenAiTokenizer`.
user: A user id. Can be used to track requests by user.
response_format: An optional OpenAi Chat Completion response format. Currently only supports `json_object` which will enable OpenAi's JSON mode.
seed: An optional OpenAi Chat Completion seed.
ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
parallel_tool_calls: A flag to enable parallel tool calls. Defaults to `True`.
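
    Example:
        # A minimal construction sketch; `model` is the only required field,
        # and "gpt-4o" is just an illustrative model name.
        driver = OpenAiChatPromptDriver(model="gpt-4o")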
"""
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
tokenizer: BaseTokenizer = field(
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
user: str = field(default="", kw_only=True, metadata={"serializable": True})
response_format: Optional[dict] = field(
default=None,
kw_only=True,
metadata={"serializable": True},
)
seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False})
reasoning_effort: Literal["low", "medium", "high"] = field(
default="medium", kw_only=True, metadata={"serializable": True}
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: StructuredOutputStrategy = field(
default="native", kw_only=True, metadata={"serializable": True}
)
parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True})
ignored_exception_types: tuple[type[Exception], ...] = field(
default=Factory(
lambda: (
openai.BadRequestError,
openai.AuthenticationError,
openai.PermissionDeniedError,
openai.NotFoundError,
openai.ConflictError,
openai.UnprocessableEntityError,
),
),
kw_only=True,
)
modalities: list[str] = field(factory=list, kw_only=True, metadata={"serializable": True})
audio: dict = field(
default=Factory(lambda: {"voice": "alloy", "format": "pcm16"}), kw_only=True, metadata={"serializable": True}
)
_client: Optional[openai.OpenAI] = field(
default=None, kw_only=True, alias="client", metadata={"serializable": False}
)
@lazy_property()
def client(self) -> openai.OpenAI:
return openai.OpenAI(
base_url=self.base_url,
api_key=self.api_key,
organization=self.organization,
)
@property
def is_reasoning_model(self) -> bool:
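        # Heuristic: OpenAI's reasoning models are the o-series (o1, o1-mini, o3-mini, ...),
        # so any model name starting with "o" is treated as a reasoning model.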
return self.model.startswith("o")
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
params = self._base_params(prompt_stack)
logger.debug(params)
result = self.client.chat.completions.create(**params)
logger.debug(result.model_dump())
return self._to_message(result)
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
params = self._base_params(prompt_stack)
logger.debug({"stream": True, **params})
result = self.client.chat.completions.create(**params, stream=True)
return self._to_delta_message_stream(result)
def _to_message(self, result: ChatCompletion) -> Message:
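        # Only a single choice is requested (OpenAI's default n=1), so multiple
        # choices indicate an unsupported configuration.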
if len(result.choices) == 1:
choice_message = result.choices[0].message
message = Message(
content=self.__to_prompt_stack_message_content(choice_message),
role=Message.ASSISTANT_ROLE,
)
if result.usage is not None:
message.usage = Message.Usage(
input_tokens=result.usage.prompt_tokens,
output_tokens=result.usage.completion_tokens,
)
return message
raise Exception("Completion with more than one choice is not supported yet.")
def _to_delta_message_stream(self, result: Stream[ChatCompletionChunk]) -> Iterator[DeltaMessage]:
for message in result:
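            # With `stream_options.include_usage`, the final chunk carries token
            # usage and an empty `choices` list.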
if message.usage is not None:
yield DeltaMessage(
usage=DeltaMessage.Usage(
input_tokens=message.usage.prompt_tokens,
output_tokens=message.usage.completion_tokens,
),
)
if message.choices:
choice = message.choices[0]
delta = choice.delta
content = self.__to_prompt_stack_delta_message_content(delta)
if content is not None:
yield DeltaMessage(content=content)
def _base_params(self, prompt_stack: PromptStack) -> dict:
params = {
"model": self.model,
**({"user": self.user} if self.user else {}),
**({"seed": self.seed} if self.seed is not None else {}),
**({"modalities": self.modalities} if self.modalities and not self.is_reasoning_model else {}),
**(
{"reasoning_effort": self.reasoning_effort}
if self.is_reasoning_model and self.model != "o1-mini"
else {}
),
**({"temperature": self.temperature} if not self.is_reasoning_model else {}),
**({"audio": self.audio} if "audio" in self.modalities else {}),
**(
{"stop": self.tokenizer.stop_sequences}
if not self.is_reasoning_model and self.tokenizer.stop_sequences
else {}
),
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
**({"stream_options": {"include_usage": True}} if self.stream else {}),
**self.extra_params,
}
if prompt_stack.tools and self.use_native_tools:
params["tool_choice"] = self.tool_choice
params["parallel_tool_calls"] = self.parallel_tool_calls
if prompt_stack.output_schema is not None:
if self.structured_output_strategy == "native":
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": "Output",
"schema": prompt_stack.to_output_json_schema(),
"strict": True,
},
}
elif self.structured_output_strategy == "tool" and self.use_native_tools:
params["tool_choice"] = "required"
if self.response_format is not None:
if self.response_format == {"type": "json_object"}:
params["response_format"] = self.response_format
# JSON mode still requires a system message instructing the LLM to output JSON.
prompt_stack.add_system_message("Provide your response as a valid JSON object.")
else:
params["response_format"] = self.response_format
if prompt_stack.tools and self.use_native_tools:
params["tools"] = self.__to_openai_tools(prompt_stack.tools)
messages = self.__to_openai_messages(prompt_stack.messages)
params["messages"] = messages
return params
def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
openai_messages = []
for message in messages:
            # If the message only contains textual content, we can send it as a single content string.
if message.has_all_content_type(TextMessageContent):
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
# Action results must be sent as separate messages.
elif action_result_contents := message.get_content_type(ActionResultMessageContent):
openai_messages.extend(
{
"role": self.__to_openai_role(message, action_result_content),
"content": self.__to_openai_message_content(action_result_content),
"tool_call_id": action_result_content.action.tag,
}
for action_result_content in action_result_contents
)
if message.has_any_content_type(TextMessageContent):
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
else:
openai_message = {
"role": self.__to_openai_role(message),
"content": [],
}
for content in message.content:
if isinstance(content, ActionCallMessageContent):
if "tool_calls" not in openai_message:
openai_message["tool_calls"] = []
openai_message["tool_calls"].append(self.__to_openai_message_content(content))
elif (
isinstance(content, AudioMessageContent)
and message.is_assistant()
and time.time() < content.artifact.meta["expires_at"]
):
# For assistant audio messages, we reference the audio id instead of sending audio message content.
openai_message["audio"] = {
"id": content.artifact.meta["audio_id"],
}
else:
openai_message["content"].append(self.__to_openai_message_content(content))
                # Some OpenAI-compatible services don't accept an empty array for content.
if not openai_message["content"]:
del openai_message["content"]
openai_messages.append(openai_message)
return openai_messages
def __to_openai_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
if message.is_system():
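            # OpenAI's reasoning models use the "developer" role in place of "system".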
if self.is_reasoning_model:
return "developer"
return "system"
if message.is_assistant():
return "assistant"
if isinstance(message_content, ActionResultMessageContent):
return "tool"
return "user"
def __to_openai_tools(self, tools: list[BaseTool]) -> list[dict]:
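        # Expose every activity of every tool as a separate OpenAI function definition.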
return [
{
"function": {
"name": tool.to_native_tool_name(activity),
"description": tool.activity_description(activity),
"parameters": tool.to_activity_json_schema(activity, "Parameters Schema"),
},
"type": "function",
}
for tool in tools
for activity in tool.activities()
]
def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict:
if isinstance(content, TextMessageContent):
return {"type": "text", "text": content.artifact.to_text()}
if isinstance(content, ImageMessageContent):
if isinstance(content.artifact, ImageArtifact):
return {
"type": "image_url",
"image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"},
}
if isinstance(content.artifact, ImageUrlArtifact):
return {
"type": "image_url",
"image_url": {"url": content.artifact.value},
}
raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
if isinstance(content, AudioMessageContent):
artifact = content.artifact
metadata = artifact.meta
            # If there's an expiration timestamp, we can assume it's an assistant message.
if "expires_at" in metadata:
# If it's expired, we send the transcript instead.
if time.time() >= metadata["expires_at"]:
return {
"type": "text",
"text": artifact.meta.get("transcript"),
}
                # This should never occur: non-expired assistant audio content
                # should already have been referenced by its audio id above.
raise ValueError("Assistant audio messages should be sent as audio ids.")
            # If there's no expiration timestamp, we can assume it's a user message, where we send the audio every time.
return {
"type": "input_audio",
"input_audio": {
"data": base64.b64encode(artifact.value).decode("utf-8"),
"format": artifact.format,
},
}
if isinstance(content, ActionCallMessageContent):
action = content.artifact.value
return {
"type": "function",
"id": action.tag,
"function": {"name": action.to_native_tool_name(), "arguments": json.dumps(action.input)},
}
if isinstance(content, ActionResultMessageContent):
return content.artifact.to_text()
return {"type": "text", "text": content.artifact.to_text()}
def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) -> list[BaseMessageContent]:
content = []
if response.content is not None:
content.append(TextMessageContent(TextArtifact(response.content)))
if hasattr(response, "audio") and response.audio is not None:
content.append(
AudioMessageContent(
AudioArtifact(
value=base64.b64decode(response.audio.data),
format="wav",
meta={
"audio_id": response.audio.id,
"transcript": response.audio.transcript,
"expires_at": response.audio.expires_at,
},
)
)
)
if response.tool_calls is not None:
content.extend(
[
ActionCallMessageContent(
ActionArtifact(
ToolAction(
tag=tool_call.id,
name=ToolAction.from_native_tool_name(tool_call.function.name)[0],
path=ToolAction.from_native_tool_name(tool_call.function.name)[1],
input=json.loads(tool_call.function.arguments),
),
),
)
for tool_call in response.tool_calls
],
)
return content
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> Optional[BaseDeltaMessageContent]:
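        # Deltas are handled in priority order: text first, then tool calls, then audio.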
if content_delta.content is not None:
return TextDeltaMessageContent(content_delta.content)
if content_delta.tool_calls is not None:
tool_calls = content_delta.tool_calls
if len(tool_calls) == 1:
tool_call = tool_calls[0]
index = tool_call.index
if tool_call.function is not None:
function_name = tool_call.function.name
return ActionCallDeltaMessageContent(
index=index,
tag=tool_call.id,
name=ToolAction.from_native_tool_name(function_name)[0] if function_name else None,
path=ToolAction.from_native_tool_name(function_name)[1] if function_name else None,
partial_input=tool_call.function.arguments,
)
raise ValueError(f"Unsupported tool call delta: {tool_call}")
raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
        # OpenAI doesn't provide types for audio deltas, so we need to use hasattr and getattr.
if hasattr(content_delta, "audio") and getattr(content_delta, "audio") is not None:
audio_chunk: dict = getattr(content_delta, "audio")
return AudioDeltaMessageContent(
id=audio_chunk.get("id"),
data=audio_chunk.get("data"),
expires_at=audio_chunk.get("expires_at"),
transcript=audio_chunk.get("transcript"),
)
return None
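# A minimal usage sketch. Assumptions: `PromptStack` lives in `griptape.common`
# (import paths vary across griptape versions) and `Message.to_text()` returns
# the message's text content.
#
#     from griptape.common import PromptStack
#
#     driver = OpenAiChatPromptDriver(model="gpt-4o")
#     prompt_stack = PromptStack()
#     prompt_stack.add_user_message("Write a haiku about the sea.")
#     print(driver.try_run(prompt_stack).to_text())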