Skip to content

base_tokenizer

BaseTokenizer

Bases: ABC

Source code in griptape/tokenizers/base_tokenizer.py
@define()
class BaseTokenizer(ABC):
    """Abstract base for tokenizers that count tokens and track per-model token limits.

    Subclasses implement ``count_tokens`` and may populate the
    ``MODEL_PREFIXES_TO_MAX_INPUT_TOKENS`` / ``MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS``
    tables (model-name prefix -> token limit) so that unset limits can be
    derived from ``model`` at construction time.
    """

    # Fallback limits used when ``model`` matches no prefix in the tables below.
    DEFAULT_MAX_INPUT_TOKENS = 4096
    DEFAULT_MAX_OUTPUT_TOKENS = 1000
    # Model-name prefix -> token limit; empty here, populated by subclasses.
    MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {}
    MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {}

    model: str = field(kw_only=True)
    stop_sequences: list[str] = field(default=Factory(list), kw_only=True)
    # ``None`` means "derive from ``model``" (see ``__attrs_post_init__``).
    max_input_tokens: int = field(kw_only=True, default=None)
    max_output_tokens: int = field(kw_only=True, default=None)

    def __attrs_post_init__(self) -> None:
        """Fill in any token limit left as ``None`` using the model-prefix tables."""
        if hasattr(self, "model"):
            if self.max_input_tokens is None:
                self.max_input_tokens = self._default_max_input_tokens()

            if self.max_output_tokens is None:
                self.max_output_tokens = self._default_max_output_tokens()

    def count_input_tokens_left(self, text: str) -> int:
        """Return ``max_input_tokens`` minus the token count of ``text``, floored at 0."""
        return max(0, self.max_input_tokens - self.count_tokens(text))

    def count_output_tokens_left(self, text: str) -> int:
        """Return ``max_output_tokens`` minus the token count of ``text``, floored at 0."""
        return max(0, self.max_output_tokens - self.count_tokens(text))

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text``; implemented by concrete tokenizers."""
        ...

    def _default_max_input_tokens(self) -> int:
        """Look up the input-token limit for ``self.model``, falling back to the default."""
        return self._lookup_model_limit(
            self.MODEL_PREFIXES_TO_MAX_INPUT_TOKENS,
            "MODEL_PREFIXES_TO_MAX_INPUT_TOKENS",
            self.DEFAULT_MAX_INPUT_TOKENS,
        )

    def _default_max_output_tokens(self) -> int:
        """Look up the output-token limit for ``self.model``, falling back to the default."""
        return self._lookup_model_limit(
            self.MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS,
            "MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS",
            self.DEFAULT_MAX_OUTPUT_TOKENS,
        )

    def _lookup_model_limit(self, prefixes_to_limits: dict[str, int], table_name: str, default: int) -> int:
        """Return the limit for the longest prefix of ``self.model`` in ``prefixes_to_limits``.

        Args:
            prefixes_to_limits: Mapping of model-name prefix to token limit.
            table_name: Name of the table, used only in the warning message.
            default: Value returned (with a warning) when no prefix matches.

        Returns:
            The limit of the most specific (longest) matching prefix, or ``default``.
        """
        matches = [prefix for prefix in prefixes_to_limits if self.model.startswith(prefix)]

        if not matches:
            logging.warning(
                "Model %s not found in %s, using default value of %s.",
                self.model,
                table_name,
                default,
            )
            return default

        # Prefer the most specific prefix so the result does not depend on the
        # table's insertion order (e.g. "gpt-4-32k" must beat "gpt-4").
        return prefixes_to_limits[max(matches, key=len)]

DEFAULT_MAX_INPUT_TOKENS = 4096 (class attribute, instance attribute)

DEFAULT_MAX_OUTPUT_TOKENS = 1000 (class attribute, instance attribute)

MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {} (class attribute, instance attribute)

MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {} (class attribute, instance attribute)

max_input_tokens: int = field(kw_only=True, default=None) (class attribute, instance attribute)

max_output_tokens: int = field(kw_only=True, default=None) (class attribute, instance attribute)

model: str = field(kw_only=True) (class attribute, instance attribute)

stop_sequences: list[str] = field(default=Factory(list), kw_only=True) (class attribute, instance attribute)

__attrs_post_init__()

Source code in griptape/tokenizers/base_tokenizer.py
def __attrs_post_init__(self) -> None:
    if hasattr(self, "model"):
        if self.max_input_tokens is None:
            self.max_input_tokens = self._default_max_input_tokens()

        if self.max_output_tokens is None:
            self.max_output_tokens = self._default_max_output_tokens()

count_input_tokens_left(text)

Source code in griptape/tokenizers/base_tokenizer.py
def count_input_tokens_left(self, text: str) -> int:
    """Return how many input tokens remain once ``text`` is counted.

    Args:
        text: Text whose token count is subtracted from ``self.max_input_tokens``.

    Returns:
        The remaining budget, or 0 if ``text`` uses up or exceeds it.
    """
    remaining = self.max_input_tokens - self.count_tokens(text)
    return remaining if remaining > 0 else 0

count_output_tokens_left(text)

Source code in griptape/tokenizers/base_tokenizer.py
def count_output_tokens_left(self, text: str) -> int:
    """Return how many output tokens remain once ``text`` is counted.

    Args:
        text: Text whose token count is subtracted from ``self.max_output_tokens``.

    Returns:
        The remaining budget, or 0 if ``text`` uses up or exceeds it.
    """
    remaining = self.max_output_tokens - self.count_tokens(text)
    return remaining if remaining > 0 else 0

count_tokens(text) (abstract method)

Source code in griptape/tokenizers/base_tokenizer.py
@abstractmethod
def count_tokens(self, text: str) -> int:
    """Return the number of tokens in ``text``; concrete tokenizers must override this."""