Skip to content

azure_openai_chat_prompt_driver

AzureOpenAiChatPromptDriver

Bases: OpenAiChatPromptDriver

Azure OpenAi Chat Prompt Driver.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `azure_deployment` | `str` | An optional Azure OpenAI deployment ID. Defaults to the model name. |
| `azure_endpoint` | `str` | An Azure OpenAI endpoint. |
| `azure_ad_token` | `Optional[str]` | An optional Azure Active Directory token. |
| `azure_ad_token_provider` | `Optional[Callable[[], str]]` | An optional Azure Active Directory token provider. |
| `api_version` | `str` | An Azure OpenAI API version. Defaults to `2024-10-21`. |
| `client` | `AzureOpenAI` | An `openai.AzureOpenAI` client. |

Source code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@define
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
    """Chat prompt driver backed by an Azure OpenAI deployment.

    Reuses `OpenAiChatPromptDriver` but talks to Azure through `openai.AzureOpenAI`,
    and removes request parameters that the configured Azure API version does not
    accept (or that Azure does not accept at all).

    Attributes:
        azure_deployment: Optional Azure OpenAI deployment id; falls back to `model` when omitted.
        azure_endpoint: The Azure OpenAI endpoint (required).
        azure_ad_token: Optional Azure Active Directory token; excluded from serialization.
        azure_ad_token_provider: Optional zero-argument callable returning an Azure Active
            Directory token; excluded from serialization.
        api_version: Azure OpenAI API version string; defaults to "2024-10-21".
        client: An `openai.AzureOpenAI` client. If none is passed in, the `client`
            lazy property builds one from the fields above.
    """

    azure_deployment: str = field(
        default=Factory(lambda driver: driver.model, takes_self=True),
        kw_only=True,
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(
        kw_only=True,
        metadata={"serializable": True},
    )
    azure_ad_token: Optional[str] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": False},
    )
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": False},
    )
    api_version: str = field(
        default="2024-10-21",
        kw_only=True,
        metadata={"serializable": True},
    )
    # Stored privately; callers pass it in as `client=...` thanks to the alias.
    _client: Optional[openai.AzureOpenAI] = field(
        default=None,
        kw_only=True,
        alias="client",
        metadata={"serializable": False},
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        """Build an `openai.AzureOpenAI` client from this driver's connection settings."""
        # All auth options are forwarded as-is (unset ones are None); how they are
        # combined is left to openai.AzureOpenAI.
        client_kwargs = {
            "organization": self.organization,
            "api_key": self.api_key,
            "api_version": self.api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.azure_deployment,
            "azure_ad_token": self.azure_ad_token,
            "azure_ad_token_provider": self.azure_ad_token_provider,
        }
        return openai.AzureOpenAI(**client_kwargs)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Return the parent driver's request params minus keys Azure would reject.

        Each key is removed only if present:
          * ``seed`` when api_version < "2024-02-01";
          * ``stream_options`` and ``parallel_tool_calls`` when api_version < "2024-10-21";
          * ``modalities`` unconditionally.

        Versions are compared as plain strings, which orders same-format
        ``YYYY-MM-DD`` dates chronologically.
        """
        params = super()._base_params(prompt_stack)

        rejected_keys = []
        if self.api_version < "2024-02-01":
            rejected_keys.append("seed")
        if self.api_version < "2024-10-21":
            rejected_keys += ["stream_options", "parallel_tool_calls"]
        # TODO: Add once Azure supports modalities
        rejected_keys.append("modalities")

        for key in rejected_keys:
            params.pop(key, None)
        return params

_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

api_version = field(default='2024-10-21', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_ad_token_provider = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_base_params(prompt_stack)

Source code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Return the parent driver's request params minus keys Azure would reject.

    Each key is removed only if present:
      * ``seed`` when api_version < "2024-02-01";
      * ``stream_options`` and ``parallel_tool_calls`` when api_version < "2024-10-21";
      * ``modalities`` unconditionally.

    Versions are compared as plain strings, which orders same-format
    ``YYYY-MM-DD`` dates chronologically.
    """
    params = super()._base_params(prompt_stack)

    rejected_keys = []
    if self.api_version < "2024-02-01":
        rejected_keys.append("seed")
    if self.api_version < "2024-10-21":
        rejected_keys += ["stream_options", "parallel_tool_calls"]
    # TODO: Add once Azure supports modalities
    rejected_keys.append("modalities")

    for key in rejected_keys:
        params.pop(key, None)
    return params

client()

Source code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    """Build an `openai.AzureOpenAI` client from this driver's connection settings."""
    # All auth options are forwarded as-is (unset ones are None); how they are
    # combined is left to openai.AzureOpenAI.
    client_kwargs = {
        "organization": self.organization,
        "api_key": self.api_key,
        "api_version": self.api_version,
        "azure_endpoint": self.azure_endpoint,
        "azure_deployment": self.azure_deployment,
        "azure_ad_token": self.azure_ad_token,
        "azure_ad_token_provider": self.azure_ad_token_provider,
    }
    return openai.AzureOpenAI(**client_kwargs)