Bases: BaseEmbeddingDriver
OpenAI Embedding Driver.
Attributes:
Name |
Type |
Description |
model |
str
|
OpenAI embedding model name. Defaults to `text-embedding-3-small`.
|
base_url |
Optional[str]
|
API URL. Defaults to OpenAI's v1 API URL.
|
api_key |
Optional[str]
|
API key to pass directly. Defaults to OPENAI_API_KEY environment variable.
|
organization |
Optional[str]
|
OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
|
tokenizer |
OpenAiTokenizer
|
Optionally provide a custom `OpenAiTokenizer`.
|
client |
OpenAI
|
Optionally provide custom openai.OpenAI client.
|
azure_deployment |
str
|
An Azure OpenAI deployment ID.
|
azure_endpoint |
str
|
An Azure OpenAI endpoint.
|
azure_ad_token |
Optional[str]
|
An optional Azure Active Directory token.
|
azure_ad_token_provider |
Optional[Callable]
|
An optional Azure Active Directory token provider.
|
api_version |
str
|
An Azure OpenAI API version.
|
Source code in griptape/drivers/embedding/openai_embedding_driver.py
@define
class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
    """OpenAI Embedding Driver.

    Generates text embeddings via OpenAI's embeddings endpoint.

    Attributes:
        model: OpenAI embedding model name. Defaults to `text-embedding-3-small`.
        base_url: API URL. Defaults to OpenAI's v1 API URL.
        api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
        organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
        tokenizer: Optionally provide custom `OpenAiTokenizer`.
        client: Optionally provide custom `openai.OpenAI` client.
    """

    DEFAULT_MODEL = "text-embedding-3-small"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # Marked non-serializable so the secret is never written into persisted config.
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # Default tokenizer is built for the configured model; `takes_self=True`
    # lets the factory read `self.model` at construction time.
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    # Backing field for the lazily-built client; callers may inject one via `client=`.
    _client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> openai.OpenAI:
        # Constructed on first access from the driver's credentials/config.
        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Embed a single text chunk and return its embedding vector."""
        # Address a performance issue in older ada models
        # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        if self.model.endswith("001"):
            chunk = chunk.replace("\n", " ")
        return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

    def _params(self, chunk: str) -> dict:
        # Request payload for the embeddings API call.
        return {"input": chunk, "model": self.model}
|
DEFAULT_MODEL = 'text-embedding-3-small'
class-attribute
instance-attribute
api_key: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': False})
class-attribute
instance-attribute
base_url: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute
instance-attribute
model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True})
class-attribute
instance-attribute
organization: Optional[str] = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute
instance-attribute
tokenizer: OpenAiTokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True)
class-attribute
instance-attribute
client()
Source code in griptape/drivers/embedding/openai_embedding_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    # Lazily construct the OpenAI client from the driver's credentials/config
    # on first access (a caller-injected client, if any, is handled by
    # `lazy_property` via the backing field).
    return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
try_embed_chunk(chunk)
Source code in griptape/drivers/embedding/openai_embedding_driver.py
def try_embed_chunk(self, chunk: str) -> list[float]:
    """Embed a single text chunk and return its embedding vector."""
    # Address a performance issue in older ada models
    # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
    if self.model.endswith("001"):
        chunk = chunk.replace("\n", " ")
    # The API returns a list of results; a single input yields exactly one.
    return self.client.embeddings.create(**self._params(chunk)).data[0].embedding