Skip to content

openai_assistant_driver

OpenAiAssistantDriver

Bases: BaseAssistantDriver

Source code in griptape/drivers/assistant/openai_assistant_driver.py
@define
class OpenAiAssistantDriver(BaseAssistantDriver):
    """Assistant driver that runs an OpenAI Assistant on an existing thread.

    Each ``try_run`` call does three things:
      1. Posts the given artifacts to ``thread_id`` as one user message.
      2. Streams a run of ``assistant_id`` on that thread through ``event_handler``.
      3. Returns the text of the run's final messages as a ``TextArtifact``.
    """

    class EventHandler(AssistantEventHandler):
        """Stream handler that republishes assistant output on the EventBus as TextChunkEvents."""

        @override
        def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
            """Publish each streamed text fragment that carries a value."""
            if delta.value is not None:
                EventBus.publish_event(TextChunkEvent(token=delta.value))

        @override
        def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
            """Publish streamed code-interpreter input and its "logs" outputs.

            Deltas for other tool types, and outputs whose type is not "logs", are ignored.
            """
            if delta.type == "code_interpreter" and delta.code_interpreter is not None:
                if delta.code_interpreter.input:
                    EventBus.publish_event(TextChunkEvent(token=delta.code_interpreter.input))
                if delta.code_interpreter.outputs:
                    # Separator so streamed code and its output are distinguishable in the token stream.
                    EventBus.publish_event(TextChunkEvent(token="\n\noutput >"))
                    for output in delta.code_interpreter.outputs:
                        if output.type == "logs" and output.logs:
                            EventBus.publish_event(TextChunkEvent(token=output.logs))

    # Connection settings forwarded to openai.OpenAI by `client`.
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # Marked non-serializable (presumably to keep the secret out of serialized configs).
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # Thread that user messages are posted to and runs are started on. Has no default, so it
    # must be passed; the type admits None, but try_run has no fallback for a missing thread.
    thread_id: Optional[str] = field(kw_only=True)
    # Assistant that executes each run.
    assistant_id: str = field(kw_only=True)
    # Handler attached to every run's stream.
    # NOTE(review): the same instance is reused across try_run calls; openai's
    # AssistantEventHandler may refuse to be attached to more than one stream — confirm that
    # repeated runs on one driver work.
    event_handler: AssistantEventHandler = field(
        default=Factory(lambda: OpenAiAssistantDriver.EventHandler()), kw_only=True, metadata={"serializable": False}
    )

    # Optional pre-built client (constructor alias `client`); backs the `client` lazy property.
    _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> openai.OpenAI:
        """Build the OpenAI client from ``base_url``, ``api_key`` and ``organization``."""
        return openai.OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            organization=self.organization,
        )

    def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
        """Post ``args`` as one user message, run the assistant, and return its reply.

        Args:
            *args: Artifacts whose ``value`` attributes are joined with newlines into the
                user message. Assumes each ``value`` is a ``str``; ``str.join`` raises
                ``TypeError`` otherwise.

        Returns:
            A ``TextArtifact`` holding the text of the run's final messages, one message
            per line. Non-text content blocks in those messages are skipped.
        """
        prompt = "\n".join(arg.value for arg in args)
        self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=prompt)
        with self.client.beta.threads.runs.stream(
            thread_id=self.thread_id,
            assistant_id=self.assistant_id,
            event_handler=self.event_handler,
        ) as stream:
            # Consume the whole stream before reading the run's final messages.
            stream.until_done()
            last_messages = stream.get_final_messages()

        # Message content may include non-text blocks (e.g. image files) that have no
        # `.text`; keep only "text" blocks instead of failing with AttributeError.
        message_text = "\n".join(
            "".join(block.text.value for block in message.content if block.type == "text")
            for message in last_messages
        )

        return TextArtifact(message_text)

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) — class attribute, instance attribute

assistant_id: str = field(kw_only=True) — class attribute, instance attribute

base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) — class attribute, instance attribute

event_handler: AssistantEventHandler = field(default=Factory(lambda: OpenAiAssistantDriver.EventHandler()), kw_only=True, metadata={'serializable': False}) — class attribute, instance attribute

organization: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) — class attribute, instance attribute

thread_id: Optional[str] = field(kw_only=True) — class attribute, instance attribute

EventHandler

Bases: AssistantEventHandler

Source code in griptape/drivers/assistant/openai_assistant_driver.py
class EventHandler(AssistantEventHandler):
    """Republish streamed assistant output on the EventBus as TextChunkEvents."""

    @override
    def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Emit the fragment carried by a text delta, skipping deltas with no value."""
        token = delta.value
        if token is None:
            return
        EventBus.publish_event(TextChunkEvent(token=token))

    @override
    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Emit code-interpreter input, then a separator and any "logs" outputs.

        Tool deltas of any other type are ignored.
        """
        if delta.type != "code_interpreter":
            return
        interpreter = delta.code_interpreter
        if interpreter is None:
            return

        if interpreter.input:
            EventBus.publish_event(TextChunkEvent(token=interpreter.input))

        if not interpreter.outputs:
            return
        # Separator between the streamed code and what it printed.
        EventBus.publish_event(TextChunkEvent(token="\n\noutput >"))
        log_texts = (item.logs for item in interpreter.outputs if item.type == "logs" and item.logs)
        for log_text in log_texts:
            EventBus.publish_event(TextChunkEvent(token=log_text))
on_text_delta(delta, snapshot)
Source code in griptape/drivers/assistant/openai_assistant_driver.py
@override
def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
    """Forward a streamed text fragment to the EventBus; deltas without a value are dropped."""
    token = delta.value
    if token is None:
        return
    EventBus.publish_event(TextChunkEvent(token=token))
on_tool_call_delta(delta, snapshot)
Source code in griptape/drivers/assistant/openai_assistant_driver.py
@override
def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
    """Forward code-interpreter input and "logs" outputs to the EventBus.

    Deltas for other tool types are ignored. When outputs are present, a
    separator token is published before the log text.
    """
    if delta.type != "code_interpreter":
        return
    interpreter = delta.code_interpreter
    if interpreter is None:
        return

    if interpreter.input:
        EventBus.publish_event(TextChunkEvent(token=interpreter.input))

    if not interpreter.outputs:
        return
    EventBus.publish_event(TextChunkEvent(token="\n\noutput >"))
    log_texts = (item.logs for item in interpreter.outputs if item.type == "logs" and item.logs)
    for log_text in log_texts:
        EventBus.publish_event(TextChunkEvent(token=log_text))

client()

Source code in griptape/drivers/assistant/openai_assistant_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    """Construct the OpenAI SDK client from this driver's connection settings.

    Returns:
        An ``openai.OpenAI`` configured with ``base_url``, ``api_key`` and ``organization``.
    """
    connection_settings = {
        "base_url": self.base_url,
        "api_key": self.api_key,
        "organization": self.organization,
    }
    return openai.OpenAI(**connection_settings)

try_run(*args)

Source code in griptape/drivers/assistant/openai_assistant_driver.py
def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
    """Post ``args`` as one user message, run the assistant, and return its reply.

    Args:
        *args: Artifacts whose ``value`` attributes are joined with newlines into the
            user message. Assumes each ``value`` is a ``str``; ``str.join`` raises
            ``TypeError`` otherwise.

    Returns:
        A ``TextArtifact`` holding the text of the run's final messages, one message
        per line. Non-text content blocks in those messages are skipped.
    """
    prompt = "\n".join(arg.value for arg in args)
    self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=prompt)
    with self.client.beta.threads.runs.stream(
        thread_id=self.thread_id,
        assistant_id=self.assistant_id,
        event_handler=self.event_handler,
    ) as stream:
        # Consume the whole stream before reading the run's final messages.
        stream.until_done()
        last_messages = stream.get_final_messages()

    # Message content may include non-text blocks (e.g. image files) that have no
    # `.text`; keep only "text" blocks instead of failing with AttributeError.
    message_text = "\n".join(
        "".join(block.text.value for block in message.content if block.type == "text")
        for message in last_messages
    )

    return TextArtifact(message_text)