Skip to content

openai

__all__ = ['OpenAiEmbeddingDriver', 'AzureOpenAiEmbeddingDriver'] module-attribute

AzureOpenAiEmbeddingDriver

Bases: OpenAiEmbeddingDriver

Azure OpenAi Embedding Driver.

Attributes:

Name Type Description
azure_deployment str

An optional Azure OpenAi deployment id. Defaults to the model name.

azure_endpoint str

An Azure OpenAi endpoint.

azure_ad_token Optional[str]

An optional Azure Active Directory token.

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider.

api_version str

An Azure OpenAi API version.

tokenizer OpenAiTokenizer

An OpenAiTokenizer.

client AzureOpenAI

An openai.AzureOpenAI client.

Source code in griptape/drivers/embedding/azure_openai_embedding_driver.py
@define
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
    """Azure OpenAi Embedding Driver.

    Extends `OpenAiEmbeddingDriver` with Azure-specific connection settings and
    builds an `openai.AzureOpenAI` client from them.

    Attributes:
        azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
        azure_endpoint: An Azure OpenAi endpoint. Required: no default is defined.
        azure_ad_token: An optional Azure Active Directory token. Defaults to `None`.
        azure_ad_token_provider: An optional Azure Active Directory token provider. Defaults to `None`.
        api_version: An Azure OpenAi API version. Defaults to `2024-10-21`.
        tokenizer: An `OpenAiTokenizer`. Defaults to one constructed for `model`.
        client: An `openai.AzureOpenAI` client. Defaults to `None`.
    """

    azure_deployment: str = field(
        kw_only=True,
        default=Factory(lambda self: self.model, takes_self=True),
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    # Credentials are flagged as non-serializable in the field metadata.
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True,
        default=None,
        metadata={"serializable": False},
    )
    api_version: str = field(default="2024-10-21", kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    # The default is None, so the annotation has to admit None as well.
    # Exposed to callers as the `client` init argument via `alias`.
    _client: Optional[openai.AzureOpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        """Create an `openai.AzureOpenAI` client from this driver's settings.

        NOTE(review): presumably `lazy_property` only calls this when no client
        was supplied and stores the result in `_client` — confirm against its
        definition.

        Returns:
            A new `openai.AzureOpenAI` instance configured with the organization,
            API key, API version, endpoint, deployment and Azure AD credentials.
        """
        return openai.AzureOpenAI(
            organization=self.organization,
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )

api_version: str = field(default='2024-10-21', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_ad_token_provider: Optional[Callable[[], str]] = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

azure_deployment: str = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

azure_endpoint: str = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

client()

Source code in griptape/drivers/embedding/azure_openai_embedding_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    """Assemble an `openai.AzureOpenAI` client from this driver's connection settings."""
    settings = {
        "organization": self.organization,
        "api_key": self.api_key,
        "api_version": self.api_version,
        "azure_endpoint": self.azure_endpoint,
        "azure_deployment": self.azure_deployment,
        "azure_ad_token": self.azure_ad_token,
        "azure_ad_token_provider": self.azure_ad_token_provider,
    }
    return openai.AzureOpenAI(**settings)

OpenAiEmbeddingDriver

Bases: BaseEmbeddingDriver

OpenAI Embedding Driver.

Attributes:

Name Type Description
model str

OpenAI embedding model name. Defaults to text-embedding-3-small.

base_url Optional[str]

API URL. Defaults to OpenAI's v1 API URL.

api_key Optional[str]

API key to pass directly. Defaults to OPENAI_API_KEY environment variable.

organization Optional[str]

OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.

tokenizer OpenAiTokenizer

Optionally provide custom OpenAiTokenizer.

client OpenAI

Optionally provide custom openai.OpenAI client.

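azure_deployment str

An Azure OpenAi deployment id. (AzureOpenAiEmbeddingDriver only.)

azure_endpoint str

An Azure OpenAi endpoint. (AzureOpenAiEmbeddingDriver only.)

azure_ad_token Optional[str]

An optional Azure Active Directory token. (AzureOpenAiEmbeddingDriver only.)

azure_ad_token_provider Optional[Callable[[], str]]

An optional Azure Active Directory token provider. (AzureOpenAiEmbeddingDriver only.)

api_version str

An Azure OpenAi API version. (AzureOpenAiEmbeddingDriver only.)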

Source code in griptape/drivers/embedding/openai_embedding_driver.py
@define
class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
    """OpenAI Embedding Driver.

    Attributes:
        model: OpenAI embedding model name. Defaults to `text-embedding-3-small`.
        base_url: API URL. Defaults to `None`, which leaves the `openai` client on its own
            default (expected to be OpenAI's v1 API URL).
        api_key: API key to pass directly. Defaults to `None`, in which case the `openai`
            client is expected to read the `OPENAI_API_KEY` environment variable.
        organization: OpenAI organization. Defaults to `None`, in which case the `openai`
            client is expected to fall back to its environment variable
            (`OPENAI_ORG_ID` in openai-python v1 — TODO confirm for the pinned version).
        tokenizer: Optionally provide custom `OpenAiTokenizer`. Defaults to one constructed for `model`.
        client: Optionally provide custom `openai.OpenAI` client. Defaults to `None`.
    """

    DEFAULT_MODEL = "text-embedding-3-small"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # The API key is flagged as non-serializable in the field metadata.
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    # The default is None, so the annotation has to admit None as well.
    # Exposed to callers as the `client` init argument via `alias`.
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        """Create an `openai.OpenAI` client from `api_key`, `base_url` and `organization`.

        NOTE(review): presumably `lazy_property` only calls this when no client
        was supplied and stores the result in `_client` — confirm against its
        definition.
        """
        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single chunk of text with the configured model.

        Args:
            chunk: The text to embed.

        Returns:
            The embedding of the first item in the embeddings API response.
        """
        # Address a performance issue in older ada models
        # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        if self.model.endswith("001"):
            chunk = chunk.replace("\n", " ")
        return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

    def _params(self, chunk: str) -> dict:
        """Return the keyword arguments for the embeddings request: the input text and the model."""
        return {"input": chunk, "model": self.model}

DEFAULT_MODEL = 'text-embedding-3-small' class-attribute instance-attribute

api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

organization: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

client()

Source code in griptape/drivers/embedding/openai_embedding_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    """Assemble an `openai.OpenAI` client from the driver's key, base URL and organization."""
    settings = {
        "api_key": self.api_key,
        "base_url": self.base_url,
        "organization": self.organization,
    }
    return openai.OpenAI(**settings)

try_embed_chunk(chunk)

Source code in griptape/drivers/embedding/openai_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed one chunk of text and return the first embedding from the response."""
    # Models whose name ends in "001" get newlines flattened to spaces, to
    # address a performance issue in older ada models:
    # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
    is_legacy_model = self.model.endswith("001")
    text = chunk.replace("\n", " ") if is_legacy_model else chunk

    response = self.client.embeddings.create(**self._params(text))
    return response.data[0].embedding