Skip to content

Base chunker


Bases: ABC

Source code in `griptape/chunkers/base_chunker.py`
class BaseChunker(ABC):
    """Split text into chunks whose token counts fit within ``max_tokens``.

    A chunk over budget is split at the separator position that best balances
    token counts between the two halves, trying ``separators`` in order. If no
    separator produces at least two pieces, the chunk is cut in half by
    character count. Both halves are then processed recursively.

    NOTE(review): ``field``/``Factory`` are attrs constructs and only take
    effect when the class is processed by an attrs class decorator (e.g.
    ``@define``). No decorator is visible in this view — confirm it is applied.
    """

    DEFAULT_SEPARATORS = [ChunkSeparator(" ")]

    # Separators tried in list order when splitting an over-budget chunk.
    separators: list[ChunkSeparator] = field(
        default=Factory(lambda self: self.DEFAULT_SEPARATORS, takes_self=True), kw_only=True
    )
    # Tokenizer used to measure every chunk and sub-chunk.
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), kw_only=True
    )
    # Token budget per chunk; defaults to the tokenizer's ``max_input_tokens``.
    max_tokens: int = field(
        default=Factory(lambda self: self.tokenizer.max_input_tokens, takes_self=True), kw_only=True
    )

    def chunk(self, text: TextArtifact | str) -> list[TextArtifact]:
        """Split ``text`` into chunks and wrap each one in a ``TextArtifact``.

        Args:
            text: A ``TextArtifact`` (its ``value`` is chunked) or a plain string.

        Returns:
            The chunks produced by ``_chunk_recursively``, in order, as ``TextArtifact`` objects.
        """
        text = text.value if isinstance(text, TextArtifact) else text

        return [TextArtifact(c) for c in self._chunk_recursively(text)]

    def _chunk_recursively(self, chunk: str, current_separator: Optional[ChunkSeparator] = None) -> list[str]:
        """Recursively split ``chunk`` until every piece fits within ``max_tokens``.

        Args:
            chunk: The text to split.
            current_separator: The separator used by the parent split, if any. Only this
                separator and the ones after it in ``separators`` are tried, so recursion
                never goes back to separators earlier in the list.

        Returns:
            The pieces of ``chunk`` in order. Halves produced by a separator split are
            stripped of surrounding whitespace before being processed. A piece of at most
            one character is returned as-is even if it still exceeds ``max_tokens``,
            because it cannot be divided any further.
        """
        token_count = self.tokenizer.count_tokens(chunk)

        if token_count <= self.max_tokens:
            return [chunk]

        # If a separator is provided, only use it and the separators after it.
        if current_separator:
            separators = self.separators[self.separators.index(current_separator) :]
        else:
            separators = self.separators

        for separator in separators:
            subchunks = list(filter(None, chunk.split(separator.value)))

            # Only a split that yields at least two non-empty pieces is useful.
            if len(subchunks) > 1:
                balance_index = self._find_balance_index(subchunks, separator, token_count // 2)
                first_subchunk, second_subchunk = self._join_halves(subchunks, separator, balance_index)

                first_chunks = self._chunk_recursively(first_subchunk.strip(), separator)
                second_chunks = self._chunk_recursively(second_subchunk.strip(), separator)

                # Concatenation already covers the cases where either side is empty.
                return first_chunks + second_chunks

        # No separator split the chunk. A chunk of at most one character cannot be
        # halved (chunk[0:] would be the same string, recursing forever), so keep it.
        if len(chunk) <= 1:
            return [chunk]

        # Otherwise split it in half by character count.
        midpoint = len(chunk) // 2
        return self._chunk_recursively(chunk[:midpoint]) + self._chunk_recursively(chunk[midpoint:])

    def _find_balance_index(self, subchunks: list[str], separator: ChunkSeparator, half_token_count: int) -> int:
        """Return the index of the last subchunk in the first half of the most balanced split.

        Token counts are accumulated subchunk by subchunk, each with its separator
        reattached. The split point whose running total is closest to
        ``half_token_count`` wins; ties keep the earliest index. The final subchunk is
        never a candidate, because splitting after it would leave the second half empty.
        """
        balance_index = 0
        balance_diff = float("inf")
        tokens_count = 0

        for index, subchunk in enumerate(subchunks[:-1]):
            piece = separator.value + subchunk if separator.is_prefix else subchunk + separator.value
            tokens_count += self.tokenizer.count_tokens(piece)

            diff = abs(tokens_count - half_token_count)
            if diff < balance_diff:
                balance_index = index
                balance_diff = diff

        return balance_index

    @staticmethod
    def _join_halves(subchunks: list[str], separator: ChunkSeparator, balance_index: int) -> tuple[str, str]:
        """Rejoin ``subchunks`` into two strings split after ``balance_index``, restoring the separator."""
        first = subchunks[: balance_index + 1]
        second = subchunks[balance_index + 1 :]

        if separator.is_prefix:
            # A prefix separator is placed before each half.
            return separator.value + separator.value.join(first), separator.value + separator.value.join(second)

        # A non-prefix separator is placed after the first half.
        return separator.value.join(first) + separator.value, separator.value.join(second)

`DEFAULT_SEPARATORS = [ChunkSeparator(' ')]` — class attribute, instance attribute

`max_tokens: int = field(default=Factory(lambda self: self.tokenizer.max_input_tokens, takes_self=True), kw_only=True)` — class attribute, instance attribute

`separators: list[ChunkSeparator] = field(default=Factory(lambda self: self.DEFAULT_SEPARATORS, takes_self=True), kw_only=True)` — class attribute, instance attribute

`tokenizer: BaseTokenizer = field(default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), kw_only=True)` — class attribute, instance attribute


Source code in `griptape/chunkers/base_chunker.py`
def chunk(self, text: TextArtifact | str) -> list[TextArtifact]:
    """Split ``text`` into chunks and wrap each resulting piece in a ``TextArtifact``.

    Args:
        text: A ``TextArtifact`` (its ``value`` is what gets chunked) or a plain string.

    Returns:
        One ``TextArtifact`` per piece returned by ``self._chunk_recursively``, in order.
    """
    if isinstance(text, TextArtifact):
        raw_text = text.value
    else:
        raw_text = text

    pieces = self._chunk_recursively(raw_text)
    return list(map(TextArtifact, pieces))