Skip to content

OpenAI completion prompt driver

OpenAiCompletionPromptDriver

Bases: BasePromptDriver

Attributes:

Name Type Description
base_url Optional[str]

An optional OpenAI API URL.

api_key Optional[str]

An optional OpenAI API key. If not provided, the OPENAI_API_KEY environment variable will be used.

organization Optional[str]

An optional OpenAI organization. If not provided, the OPENAI_ORG_ID environment variable will be used.

client OpenAI

An openai.OpenAI client.

model str

An OpenAI model name.

tokenizer OpenAiTokenizer

An OpenAiTokenizer.

user str

A user id. Can be used to track requests by user.

ignored_exception_types tuple[type[Exception], ...]

An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.

Source code in griptape/drivers/prompt/openai_completion_prompt_driver.py
@define
class OpenAiCompletionPromptDriver(BasePromptDriver):
    """Prompt driver that calls the completions endpoint of an `openai.OpenAI` client.

    Attributes:
        base_url: An optional OpenAI API URL.
        api_key: An optional OpenAI API key. If not provided, the `OPENAI_API_KEY` environment variable will be used.
        organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used.
        client: An `openai.OpenAI` client. Built from `api_key`, `base_url` and `organization` when not supplied.
        model: An OpenAI model name.
        tokenizer: An `OpenAiTokenizer`. Built from `model` when not supplied.
        user: A user id. Can be used to track requests by user.
        ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
    """

    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # The API key is explicitly marked as non-serializable.
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    client: openai.OpenAI = field(
        default=Factory(
            lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
            takes_self=True,
        )
    )
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
    )
    user: str = field(default="", kw_only=True, metadata={"serializable": True})
    ignored_exception_types: tuple[type[Exception], ...] = field(
        default=Factory(
            lambda: (
                openai.BadRequestError,
                openai.AuthenticationError,
                openai.PermissionDeniedError,
                openai.NotFoundError,
                openai.ConflictError,
                openai.UnprocessableEntityError,
            )
        ),
        kw_only=True,
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
        """Request a single completion for the prompt stack.

        Args:
            prompt_stack: The prompt stack to render into the completion prompt.

        Returns:
            A `TextArtifact` holding the text of the only choice, stripped of
            leading and trailing whitespace.

        Raises:
            Exception: If the response contains no choices, or more than one.
        """
        result = self.client.completions.create(**self._base_params(prompt_stack))
        choices = result.choices

        # An empty response is reported on its own so that it is not mistaken
        # for the unsupported multi-choice case.
        if not choices:
            raise Exception("completion returned no choices")
        if len(choices) > 1:
            raise Exception("completion with more than one choice is not supported yet")

        return TextArtifact(value=choices[0].text.strip())

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
        """Request a streamed completion and yield its text piece by piece.

        Args:
            prompt_stack: The prompt stack to render into the completion prompt.

        Yields:
            One `TextArtifact` per streamed chunk, holding that chunk's text
            unmodified (no stripping, since chunks are fragments).

        Raises:
            Exception: If a chunk contains no choices, or more than one.
        """
        result = self.client.completions.create(**self._base_params(prompt_stack), stream=True)

        for chunk in result:
            choices = chunk.choices

            # An empty chunk is reported on its own so that it is not mistaken
            # for the unsupported multi-choice case.
            if not choices:
                raise Exception("completion chunk returned no choices")
            if len(choices) > 1:
                raise Exception("completion with more than one choice is not supported yet")

            yield TextArtifact(value=choices[0].text)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments shared by the blocking and streaming calls.

        Args:
            prompt_stack: The prompt stack to render into the completion prompt.

        Returns:
            A dict of keyword arguments for `client.completions.create`.
        """
        # The rendered prompt is needed twice: as the request payload and as
        # the input to max_output_tokens.
        prompt = self.prompt_stack_to_string(prompt_stack)

        return {
            "model": self.model,
            "max_tokens": self.max_output_tokens(prompt),
            "temperature": self.temperature,
            "stop": self.tokenizer.stop_sequences,
            "user": self.user,
            "prompt": prompt,
        }

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client: openai.OpenAI = field(default=Factory(lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), takes_self=True)) class-attribute instance-attribute

ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (openai.BadRequestError, openai.AuthenticationError, openai.PermissionDeniedError, openai.NotFoundError, openai.ConflictError, openai.UnprocessableEntityError)), kw_only=True) class-attribute instance-attribute

model: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

organization: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

user: str = field(default='', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_run(prompt_stack)

Source code in griptape/drivers/prompt/openai_completion_prompt_driver.py
def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
    """Request a single completion for the prompt stack.

    Args:
        prompt_stack: The prompt stack to render into the completion prompt.

    Returns:
        A `TextArtifact` holding the text of the only choice, stripped of
        leading and trailing whitespace.

    Raises:
        Exception: If the response contains no choices, or more than one.
    """
    result = self.client.completions.create(**self._base_params(prompt_stack))
    choices = result.choices

    # An empty response is reported on its own so that it is not mistaken
    # for the unsupported multi-choice case.
    if not choices:
        raise Exception("completion returned no choices")
    if len(choices) > 1:
        raise Exception("completion with more than one choice is not supported yet")

    return TextArtifact(value=choices[0].text.strip())

try_stream(prompt_stack)

Source code in griptape/drivers/prompt/openai_completion_prompt_driver.py
def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
    """Request a streamed completion and yield its text piece by piece.

    Args:
        prompt_stack: The prompt stack to render into the completion prompt.

    Yields:
        One `TextArtifact` per streamed chunk, holding that chunk's text
        unmodified (no stripping, since chunks are fragments).

    Raises:
        Exception: If a chunk contains no choices, or more than one.
    """
    result = self.client.completions.create(**self._base_params(prompt_stack), stream=True)

    for chunk in result:
        choices = chunk.choices

        # An empty chunk is reported on its own so that it is not mistaken
        # for the unsupported multi-choice case.
        if not choices:
            raise Exception("completion chunk returned no choices")
        if len(choices) > 1:
            raise Exception("completion with more than one choice is not supported yet")

        yield TextArtifact(value=choices[0].text)