
engines

__all__ = ['BaseEvalEngine', 'BaseExtractionEngine', 'BaseSummaryEngine', 'CsvExtractionEngine', 'EvalEngine', 'JsonExtractionEngine', 'PromptSummaryEngine', 'RagEngine'] module-attribute

BaseEvalEngine

Bases: ABC

Source code in griptape/engines/eval/base_eval_engine.py
@define
class BaseEvalEngine(ABC): ...

BaseExtractionEngine

Bases: ABC

Source code in griptape/engines/extraction/base_extraction_engine.py
@define
class BaseExtractionEngine(ABC):
    max_token_multiplier: float = field(default=0.5, kw_only=True)
    chunk_joiner: str = field(default="\n\n", kw_only=True)
    prompt_driver: BasePromptDriver = field(
        default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True
    )
    chunker: BaseChunker = field(
        default=Factory(
            lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )

    @max_token_multiplier.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_max_token_multiplier(self, _: Attribute, max_token_multiplier: int) -> None:
        if max_token_multiplier > 1:
            raise ValueError("has to be less than or equal to 1")
        if max_token_multiplier <= 0:
            raise ValueError("has to be greater than 0")

    @property
    def max_chunker_tokens(self) -> int:
        return round(self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier)

    @property
    def min_response_tokens(self) -> int:
        return round(
            self.prompt_driver.tokenizer.max_input_tokens
            - self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier,
        )

    def extract_text(
        self,
        text: str,
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact:
        return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs)

    @abstractmethod
    def extract_artifacts(
        self,
        artifacts: ListArtifact[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact: ...
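
The max_token_multiplier field splits the prompt driver's input window between the chunked text and the headroom reserved for the model's response. A worked sketch of the two properties, assuming a hypothetical tokenizer with max_input_tokens = 8000 and the default multiplier of 0.5:

max_input_tokens = 8000      # hypothetical tokenizer limit
max_token_multiplier = 0.5   # default

max_chunker_tokens = round(max_input_tokens * max_token_multiplier)                      # 4000 tokens per chunk
min_response_tokens = round(max_input_tokens - max_input_tokens * max_token_multiplier)  # 4000 tokens kept free for the response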

chunk_joiner = field(default='\n\n', kw_only=True) class-attribute instance-attribute

chunker = field(default=Factory(lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

max_chunker_tokens property

max_token_multiplier = field(default=0.5, kw_only=True) class-attribute instance-attribute

min_response_tokens property

prompt_driver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True) class-attribute instance-attribute

extract_artifacts(artifacts, *, rulesets=None, **kwargs) abstractmethod

Source code in griptape/engines/extraction/base_extraction_engine.py
@abstractmethod
def extract_artifacts(
    self,
    artifacts: ListArtifact[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact: ...

extract_text(text, *, rulesets=None, **kwargs)

Source code in griptape/engines/extraction/base_extraction_engine.py
def extract_text(
    self,
    text: str,
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact:
    return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs)

validate_max_token_multiplier(_, max_token_multiplier)

Source code in griptape/engines/extraction/base_extraction_engine.py
@max_token_multiplier.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_max_token_multiplier(self, _: Attribute, max_token_multiplier: int) -> None:
    if max_token_multiplier > 1:
        raise ValueError("has to be less than or equal to 1")
    if max_token_multiplier <= 0:
        raise ValueError("has to be greater than 0")

BaseSummaryEngine

Bases: ABC

Source code in griptape/engines/summary/base_summary_engine.py
@define
class BaseSummaryEngine(ABC):
    def summarize_text(self, text: str, *, rulesets: Optional[list[Ruleset]] = None) -> str:
        return self.summarize_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets).value

    @abstractmethod
    def summarize_artifacts(
        self,
        artifacts: ListArtifact,
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> TextArtifact: ...

summarize_artifacts(artifacts, *, rulesets=None) abstractmethod

Source code in griptape/engines/summary/base_summary_engine.py
@abstractmethod
def summarize_artifacts(
    self,
    artifacts: ListArtifact,
    *,
    rulesets: Optional[list[Ruleset]] = None,
) -> TextArtifact: ...

summarize_text(text, *, rulesets=None)

Source code in griptape/engines/summary/base_summary_engine.py
def summarize_text(self, text: str, *, rulesets: Optional[list[Ruleset]] = None) -> str:
    return self.summarize_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets).value
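
summarize_text wraps the string in a ListArtifact and returns the value of the TextArtifact produced by summarize_artifacts, so a subclass only needs to implement the latter. A minimal sketch, with the class name and the first-sentence heuristic as illustrative assumptions:

from typing import Optional

from attrs import define

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.engines import BaseSummaryEngine
from griptape.rules import Ruleset


@define
class FirstSentenceSummaryEngine(BaseSummaryEngine):  # hypothetical example subclass
    def summarize_artifacts(
        self,
        artifacts: ListArtifact,
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> TextArtifact:
        # Trivial "summary": keep only the first sentence of each artifact.
        first_sentences = [a.to_text().split(". ")[0] for a in artifacts.value]
        return TextArtifact(". ".join(first_sentences))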

CsvExtractionEngine

Bases: BaseExtractionEngine

Source code in griptape/engines/extraction/csv_extraction_engine.py
@define
class CsvExtractionEngine(BaseExtractionEngine):
    column_names: list[str] = field(kw_only=True)
    generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
    generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
    format_header: Callable[[list[str]], str] = field(
        default=Factory(lambda: lambda value: ",".join(value)), kw_only=True
    )
    format_row: Callable[[dict], str] = field(
        default=Factory(lambda: lambda value: ",".join([value or "" for value in value.values()])), kw_only=True
    )

    def extract_artifacts(
        self,
        artifacts: ListArtifact[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact[TextArtifact]:
        return ListArtifact(
            self._extract_rec(
                cast("list[TextArtifact]", artifacts.value),
                [TextArtifact(self.format_header(self.column_names))],
                rulesets=rulesets,
            ),
            item_separator="\n",
        )

    def text_to_csv_rows(self, text: str) -> list[TextArtifact]:
        rows = []

        with io.StringIO(text) as f:
            for row in csv.DictReader(f):
                rows.append(TextArtifact(self.format_row(row)))

        return rows

    def _extract_rec(
        self,
        artifacts: list[TextArtifact],
        rows: list[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> list[TextArtifact]:
        artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
        system_prompt = self.generate_system_template.render(
            column_names=self.column_names,
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
        )
        user_prompt = self.generate_user_template.render(
            text=artifacts_text,
        )

        if (
            self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt)
            >= self.min_response_tokens
        ):
            rows.extend(
                self.text_to_csv_rows(
                    self.prompt_driver.run(
                        PromptStack(
                            messages=[
                                Message(system_prompt, role=Message.SYSTEM_ROLE),
                                Message(user_prompt, role=Message.USER_ROLE),
                            ]
                        )
                    ).value,
                ),
            )

            return rows
        chunks = self.chunker.chunk(artifacts_text)
        partial_text = self.generate_user_template.render(
            text=chunks[0].value,
        )

        rows.extend(
            self.text_to_csv_rows(
                self.prompt_driver.run(
                    PromptStack(
                        messages=[
                            Message(system_prompt, role=Message.SYSTEM_ROLE),
                            Message(partial_text, role=Message.USER_ROLE),
                        ]
                    )
                ).value,
            ),
        )

        return self._extract_rec(chunks[1:], rows, rulesets=rulesets)
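
Example: a minimal usage sketch. The column names and input text are illustrative, and prompt_driver falls back to Defaults.drivers_config.prompt_driver unless overridden:

from griptape.engines import CsvExtractionEngine

engine = CsvExtractionEngine(column_names=["name", "age"])  # column_names is required and keyword-only

# extract_text (inherited from BaseExtractionEngine) wraps the string in a
# ListArtifact and delegates to extract_artifacts.
results = engine.extract_text("Alice is 30 years old. Bob just turned 25.")

for row in results.value:
    print(row.value)  # header row first ("name,age"), then one CSV row per extracted record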

column_names = field(kw_only=True) class-attribute instance-attribute

format_header = field(default=Factory(lambda: lambda value: ','.join(value)), kw_only=True) class-attribute instance-attribute

format_row = field(default=Factory(lambda: lambda value: ','.join([value or '' for value in value.values()])), kw_only=True) class-attribute instance-attribute

generate_system_template = field(default=Factory(lambda: J2('engines/extraction/csv/system.j2')), kw_only=True) class-attribute instance-attribute

generate_user_template = field(default=Factory(lambda: J2('engines/extraction/csv/user.j2')), kw_only=True) class-attribute instance-attribute

_extract_rec(artifacts, rows, *, rulesets=None)

Source code in griptape/engines/extraction/csv_extraction_engine.py
def _extract_rec(
    self,
    artifacts: list[TextArtifact],
    rows: list[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
    artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
    system_prompt = self.generate_system_template.render(
        column_names=self.column_names,
        rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
    )
    user_prompt = self.generate_user_template.render(
        text=artifacts_text,
    )

    if (
        self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt)
        >= self.min_response_tokens
    ):
        rows.extend(
            self.text_to_csv_rows(
                self.prompt_driver.run(
                    PromptStack(
                        messages=[
                            Message(system_prompt, role=Message.SYSTEM_ROLE),
                            Message(user_prompt, role=Message.USER_ROLE),
                        ]
                    )
                ).value,
            ),
        )

        return rows
    chunks = self.chunker.chunk(artifacts_text)
    partial_text = self.generate_user_template.render(
        text=chunks[0].value,
    )

    rows.extend(
        self.text_to_csv_rows(
            self.prompt_driver.run(
                PromptStack(
                    messages=[
                        Message(system_prompt, role=Message.SYSTEM_ROLE),
                        Message(partial_text, role=Message.USER_ROLE),
                    ]
                )
            ).value,
        ),
    )

    return self._extract_rec(chunks[1:], rows, rulesets=rulesets)

extract_artifacts(artifacts, *, rulesets=None, **kwargs)

Source code in griptape/engines/extraction/csv_extraction_engine.py
def extract_artifacts(
    self,
    artifacts: ListArtifact[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact[TextArtifact]:
    return ListArtifact(
        self._extract_rec(
            cast("list[TextArtifact]", artifacts.value),
            [TextArtifact(self.format_header(self.column_names))],
            rulesets=rulesets,
        ),
        item_separator="\n",
    )

text_to_csv_rows(text)

Source code in griptape/engines/extraction/csv_extraction_engine.py
def text_to_csv_rows(self, text: str) -> list[TextArtifact]:
    rows = []

    with io.StringIO(text) as f:
        for row in csv.DictReader(f):
            rows.append(TextArtifact(self.format_row(row)))

    return rows
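
For instance, with the default format_row, a model response consisting of a header line followed by data lines parses roughly as follows (a sketch, assuming engine is the CsvExtractionEngine constructed in the example above):

rows = engine.text_to_csv_rows("name,age\nAlice,30\nBob,25")
print([r.value for r in rows])  # ['Alice,30', 'Bob,25'] -- the header line is consumed by csv.DictReader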

EvalEngine

Bases: BaseEvalEngine, SerializableMixin

Source code in griptape/engines/eval/eval_engine.py
@define(kw_only=True)
class EvalEngine(BaseEvalEngine, SerializableMixin):
    id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
    name: str = field(
        default=Factory(lambda self: self.id, takes_self=True),
        metadata={"serializable": True},
    )
    criteria: Optional[str] = field(default=None, metadata={"serializable": True})
    evaluation_steps: Optional[list[str]] = field(default=None, metadata={"serializable": True})
    prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))
    generate_steps_system_template: J2 = field(default=Factory(lambda: J2("engines/eval/steps/system.j2")))
    generate_steps_user_template: J2 = field(default=Factory(lambda: J2("engines/eval/steps/user.j2")))
    generate_results_system_template: J2 = field(default=Factory(lambda: J2("engines/eval/results/system.j2")))
    generate_results_user_template: J2 = field(default=Factory(lambda: J2("engines/eval/results/user.j2")))

    @criteria.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_criteria(self, _: Attribute, value: Optional[str]) -> None:
        if value is None:
            if self.evaluation_steps is None:
                raise ValueError("either criteria or evaluation_steps must be specified")
            return

        if self.evaluation_steps is not None:
            raise ValueError("can't have both criteria and evaluation_steps specified")

        if not value:
            raise ValueError("criteria must not be empty")

    @evaluation_steps.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_evaluation_steps(self, _: Attribute, value: Optional[list[str]]) -> None:
        if value is None:
            if self.criteria is None:
                raise ValueError("either evaluation_steps or criteria must be specified")
            return

        if self.criteria is not None:
            raise ValueError("can't have both evaluation_steps and criteria specified")

        if not value:
            raise ValueError("evaluation_steps must not be empty")

    def evaluate(self, input: str, actual_output: str, **kwargs) -> tuple[float, str]:  # noqa: A002
        evaluation_params = {
            key.replace("_", " ").title(): value
            for key, value in {"input": input, "actual_output": actual_output, **kwargs}.items()
        }

        if self.evaluation_steps is None:
            # Need to disable validators to allow for both `criteria` and `evaluation_steps` to be set
            with validators.disabled():
                self.evaluation_steps = self._generate_steps(evaluation_params)

        return self._generate_results(evaluation_params)

    def _generate_steps(self, evaluation_params: dict[str, str]) -> list[str]:
        system_prompt = self.generate_steps_system_template.render(
            evaluation_params=", ".join(param for param in evaluation_params),
            criteria=self.criteria,
        )
        user_prompt = self.generate_steps_user_template.render()

        result = self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(user_prompt, role=Message.USER_ROLE),
                ],
                output_schema=STEPS_SCHEMA,
            ),
        ).to_artifact()

        parsed_result = json.loads(result.value)

        return parsed_result["steps"]

    def _generate_results(self, evaluation_params: dict[str, str]) -> tuple[float, str]:
        system_prompt = self.generate_results_system_template.render(
            evaluation_params=", ".join(param for param in evaluation_params),
            evaluation_steps=self.evaluation_steps,
            evaluation_text="\n\n".join(f"{key}: {value}" for key, value in evaluation_params.items()),
        )
        user_prompt = self.generate_results_user_template.render()

        result = self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(user_prompt, role=Message.USER_ROLE),
                ],
                output_schema=RESULTS_SCHEMA,
            ),
        ).to_text()

        parsed_result = json.loads(result)

        # Better to have the LLM deal strictly with integers to avoid ambiguities with floating point precision.
        # We want the user to receive a float, however.
        score = float(parsed_result["score"]) / 10
        reason = parsed_result["reason"]

        return score, reason
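
Example: a minimal usage sketch. The criteria and inputs are illustrative; exactly one of criteria or evaluation_steps must be provided (see the validators below), and prompt_driver falls back to Defaults.drivers_config.prompt_driver:

from griptape.engines import EvalEngine

engine = EvalEngine(criteria="Is the output factually consistent with the input?")

score, reason = engine.evaluate(
    input="The sky is blue because of Rayleigh scattering.",
    actual_output="The sky appears blue due to Rayleigh scattering of sunlight.",
)

print(score)   # a float: the model's integer score divided by 10
print(reason)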

criteria = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

evaluation_steps = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

generate_results_system_template = field(default=Factory(lambda: J2('engines/eval/results/system.j2'))) class-attribute instance-attribute

generate_results_user_template = field(default=Factory(lambda: J2('engines/eval/results/user.j2'))) class-attribute instance-attribute

generate_steps_system_template = field(default=Factory(lambda: J2('engines/eval/steps/system.j2'))) class-attribute instance-attribute

generate_steps_user_template = field(default=Factory(lambda: J2('engines/eval/steps/user.j2'))) class-attribute instance-attribute

id = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

name = field(default=Factory(lambda self: self.id, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

prompt_driver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver)) class-attribute instance-attribute

_generate_results(evaluation_params)

Source code in griptape/engines/eval/eval_engine.py
def _generate_results(self, evaluation_params: dict[str, str]) -> tuple[float, str]:
    system_prompt = self.generate_results_system_template.render(
        evaluation_params=", ".join(param for param in evaluation_params),
        evaluation_steps=self.evaluation_steps,
        evaluation_text="\n\n".join(f"{key}: {value}" for key, value in evaluation_params.items()),
    )
    user_prompt = self.generate_results_user_template.render()

    result = self.prompt_driver.run(
        PromptStack(
            messages=[
                Message(system_prompt, role=Message.SYSTEM_ROLE),
                Message(user_prompt, role=Message.USER_ROLE),
            ],
            output_schema=RESULTS_SCHEMA,
        ),
    ).to_text()

    parsed_result = json.loads(result)

    # Better to have the LLM deal strictly with integers to avoid ambiguities with floating point precision.
    # We want the user to receive a float, however.
    score = float(parsed_result["score"]) / 10
    reason = parsed_result["reason"]

    return score, reason

_generate_steps(evaluation_params)

Source code in griptape/engines/eval/eval_engine.py
def _generate_steps(self, evaluation_params: dict[str, str]) -> list[str]:
    system_prompt = self.generate_steps_system_template.render(
        evaluation_params=", ".join(param for param in evaluation_params),
        criteria=self.criteria,
    )
    user_prompt = self.generate_steps_user_template.render()

    result = self.prompt_driver.run(
        PromptStack(
            messages=[
                Message(system_prompt, role=Message.SYSTEM_ROLE),
                Message(user_prompt, role=Message.USER_ROLE),
            ],
            output_schema=STEPS_SCHEMA,
        ),
    ).to_artifact()

    parsed_result = json.loads(result.value)

    return parsed_result["steps"]

evaluate(input, actual_output, **kwargs)

Source code in griptape/engines/eval/eval_engine.py
def evaluate(self, input: str, actual_output: str, **kwargs) -> tuple[float, str]:  # noqa: A002
    evaluation_params = {
        key.replace("_", " ").title(): value
        for key, value in {"input": input, "actual_output": actual_output, **kwargs}.items()
    }

    if self.evaluation_steps is None:
        # Need to disable validators to allow for both `criteria` and `evaluation_steps` to be set
        with validators.disabled():
            self.evaluation_steps = self._generate_steps(evaluation_params)

    return self._generate_results(evaluation_params)

validate_criteria(_, value)

Source code in griptape/engines/eval/eval_engine.py
@criteria.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_criteria(self, _: Attribute, value: Optional[str]) -> None:
    if value is None:
        if self.evaluation_steps is None:
            raise ValueError("either criteria or evaluation_steps must be specified")
        return

    if self.evaluation_steps is not None:
        raise ValueError("can't have both criteria and evaluation_steps specified")

    if not value:
        raise ValueError("criteria must not be empty")

validate_evaluation_steps(_, value)

Source code in griptape/engines/eval/eval_engine.py
@evaluation_steps.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_evaluation_steps(self, _: Attribute, value: Optional[list[str]]) -> None:
    if value is None:
        if self.criteria is None:
            raise ValueError("either evaluation_steps or criteria must be specified")
        return

    if self.criteria is not None:
        raise ValueError("can't have both evaluation_steps and criteria specified")

    if not value:
        raise ValueError("evaluation_steps must not be empty")

JsonExtractionEngine

Bases: BaseExtractionEngine

Source code in griptape/engines/extraction/json_extraction_engine.py
@define
class JsonExtractionEngine(BaseExtractionEngine):
    JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

    template_schema: dict = field(kw_only=True)
    generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True)
    generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)

    def extract_artifacts(
        self,
        artifacts: ListArtifact[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact[JsonArtifact]:
        return ListArtifact(
            self._extract_rec(cast("list[TextArtifact]", artifacts.value), [], rulesets=rulesets),
            item_separator="\n",
        )

    def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
        json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

        if json_matches:
            return [JsonArtifact(e) for e in json.loads(json_matches[-1])]
        return []

    def _extract_rec(
        self,
        artifacts: list[TextArtifact],
        extractions: list[JsonArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> list[JsonArtifact]:
        artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
        system_prompt = self.generate_system_template.render(
            json_template_schema=json.dumps(self.template_schema),
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
        )
        user_prompt = self.generate_user_template.render(
            text=artifacts_text,
        )

        if (
            self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
            >= self.min_response_tokens
        ):
            extractions.extend(
                self.json_to_text_artifacts(
                    self.prompt_driver.run(
                        PromptStack(
                            messages=[
                                Message(system_prompt, role=Message.SYSTEM_ROLE),
                                Message(user_prompt, role=Message.USER_ROLE),
                            ]
                        )
                    ).value
                ),
            )

            return extractions
        chunks = self.chunker.chunk(artifacts_text)
        partial_text = self.generate_user_template.render(
            text=chunks[0].value,
        )

        extractions.extend(
            self.json_to_text_artifacts(
                self.prompt_driver.run(
                    PromptStack(
                        messages=[
                            Message(system_prompt, role=Message.SYSTEM_ROLE),
                            Message(partial_text, role=Message.USER_ROLE),
                        ]
                    )
                ).value,
            ),
        )

        return self._extract_rec(chunks[1:], extractions, rulesets=rulesets)
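
Example: a minimal usage sketch. template_schema is a required keyword-only dict that is rendered into the system prompt via json.dumps; the JSON-Schema-style shape used here is an illustrative assumption:

from griptape.engines import JsonExtractionEngine

engine = JsonExtractionEngine(
    template_schema={
        "type": "array",
        "items": {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        },
    },
)

results = engine.extract_text("Alice is 30 years old. Bob just turned 25.")

for artifact in results.value:
    print(artifact.value)  # one JsonArtifact per extracted object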

JSON_PATTERN = '(?s)[^\\[]*(\\[.*\\])' class-attribute instance-attribute

generate_system_template = field(default=Factory(lambda: J2('engines/extraction/json/system.j2')), kw_only=True) class-attribute instance-attribute

generate_user_template = field(default=Factory(lambda: J2('engines/extraction/json/user.j2')), kw_only=True) class-attribute instance-attribute

template_schema = field(kw_only=True) class-attribute instance-attribute

_extract_rec(artifacts, extractions, *, rulesets=None)

Source code in griptape/engines/extraction/json_extraction_engine.py
def _extract_rec(
    self,
    artifacts: list[TextArtifact],
    extractions: list[JsonArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
) -> list[JsonArtifact]:
    artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
    system_prompt = self.generate_system_template.render(
        json_template_schema=json.dumps(self.template_schema),
        rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
    )
    user_prompt = self.generate_user_template.render(
        text=artifacts_text,
    )

    if (
        self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
        >= self.min_response_tokens
    ):
        extractions.extend(
            self.json_to_text_artifacts(
                self.prompt_driver.run(
                    PromptStack(
                        messages=[
                            Message(system_prompt, role=Message.SYSTEM_ROLE),
                            Message(user_prompt, role=Message.USER_ROLE),
                        ]
                    )
                ).value
            ),
        )

        return extractions
    chunks = self.chunker.chunk(artifacts_text)
    partial_text = self.generate_user_template.render(
        text=chunks[0].value,
    )

    extractions.extend(
        self.json_to_text_artifacts(
            self.prompt_driver.run(
                PromptStack(
                    messages=[
                        Message(system_prompt, role=Message.SYSTEM_ROLE),
                        Message(partial_text, role=Message.USER_ROLE),
                    ]
                )
            ).value,
        ),
    )

    return self._extract_rec(chunks[1:], extractions, rulesets=rulesets)

extract_artifacts(artifacts, *, rulesets=None, **kwargs)

Source code in griptape/engines/extraction/json_extraction_engine.py
def extract_artifacts(
    self,
    artifacts: ListArtifact[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact[JsonArtifact]:
    return ListArtifact(
        self._extract_rec(cast("list[TextArtifact]", artifacts.value), [], rulesets=rulesets),
        item_separator="\n",
    )

json_to_text_artifacts(json_input)

Source code in griptape/engines/extraction/json_extraction_engine.py
def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
    json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

    if json_matches:
        return [JsonArtifact(e) for e in json.loads(json_matches[-1])]
    return []
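
For instance, JSON_PATTERN keeps the last bracketed JSON array found in the model response (a sketch, assuming engine is the JsonExtractionEngine from the example above):

artifacts = engine.json_to_text_artifacts('Here you go: [{"name": "Alice", "age": 30}]')
print([a.value for a in artifacts])  # [{'name': 'Alice', 'age': 30}]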

PromptSummaryEngine

Bases: BaseSummaryEngine

Source code in griptape/engines/summary/prompt_summary_engine.py
@define
class PromptSummaryEngine(BaseSummaryEngine):
    chunk_joiner: str = field(default="\n\n", kw_only=True)
    max_token_multiplier: float = field(default=0.5, kw_only=True)
    generate_system_template: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True)
    generate_user_template: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True)
    prompt_driver: BasePromptDriver = field(
        default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True
    )
    chunker: BaseChunker = field(
        default=Factory(
            lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )

    @max_token_multiplier.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_allowlist(self, _: Attribute, max_token_multiplier: int) -> None:
        if max_token_multiplier > 1:
            raise ValueError("has to be less than or equal to 1")
        if max_token_multiplier <= 0:
            raise ValueError("has to be greater than 0")

    @property
    def max_chunker_tokens(self) -> int:
        return round(self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier)

    @property
    def min_response_tokens(self) -> int:
        return round(
            self.prompt_driver.tokenizer.max_input_tokens
            - self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier,
        )

    def summarize_artifacts(self, artifacts: ListArtifact, *, rulesets: Optional[list[Ruleset]] = None) -> TextArtifact:
        return self.summarize_artifacts_rec(cast("list[TextArtifact]", artifacts.value), None, rulesets=rulesets)

    def summarize_artifacts_rec(
        self,
        artifacts: list[TextArtifact],
        summary: Optional[str] = None,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> TextArtifact:
        if not artifacts:
            if summary is None:
                raise ValueError("No artifacts to summarize")
            return TextArtifact(summary)

        artifacts_text = self.chunk_joiner.join([a.to_text() for a in artifacts])

        system_prompt = self.generate_system_template.render(
            summary=summary,
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
        )

        user_prompt = self.generate_user_template.render(text=artifacts_text)

        if (
            self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
            >= self.min_response_tokens
        ):
            result = self.prompt_driver.run(
                PromptStack(
                    messages=[
                        Message(system_prompt, role=Message.SYSTEM_ROLE),
                        Message(user_prompt, role=Message.USER_ROLE),
                    ],
                ),
            ).to_artifact()

            if isinstance(result, TextArtifact):
                return result
            raise ValueError("Prompt driver did not return a TextArtifact")
        chunks = self.chunker.chunk(artifacts_text)

        partial_text = self.generate_user_template.render(text=chunks[0].value)

        return self.summarize_artifacts_rec(
            chunks[1:],
            self.prompt_driver.run(
                PromptStack(
                    messages=[
                        Message(system_prompt, role=Message.SYSTEM_ROLE),
                        Message(partial_text, role=Message.USER_ROLE),
                    ],
                ),
            ).value,
            rulesets=rulesets,
        )
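
Example: a minimal usage sketch. The input text is illustrative and prompt_driver falls back to Defaults.drivers_config.prompt_driver unless overridden:

from griptape.engines import PromptSummaryEngine

engine = PromptSummaryEngine()

summary = engine.summarize_text(
    "Long article text goes here. If it exceeds the driver's input window, "
    "it is chunked and summarized recursively."
)

print(summary)  # a plain string: the value of the TextArtifact returned by summarize_artifacts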

chunk_joiner = field(default='\n\n', kw_only=True) class-attribute instance-attribute

chunker = field(default=Factory(lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

generate_system_template = field(default=Factory(lambda: J2('engines/summary/system.j2')), kw_only=True) class-attribute instance-attribute

generate_user_template = field(default=Factory(lambda: J2('engines/summary/user.j2')), kw_only=True) class-attribute instance-attribute

max_chunker_tokens property

max_token_multiplier = field(default=0.5, kw_only=True) class-attribute instance-attribute

min_response_tokens property

prompt_driver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True) class-attribute instance-attribute

summarize_artifacts(artifacts, *, rulesets=None)

Source code in griptape/engines/summary/prompt_summary_engine.py
def summarize_artifacts(self, artifacts: ListArtifact, *, rulesets: Optional[list[Ruleset]] = None) -> TextArtifact:
    return self.summarize_artifacts_rec(cast("list[TextArtifact]", artifacts.value), None, rulesets=rulesets)

summarize_artifacts_rec(artifacts, summary=None, rulesets=None)

Source code in griptape/engines/summary/prompt_summary_engine.py
def summarize_artifacts_rec(
    self,
    artifacts: list[TextArtifact],
    summary: Optional[str] = None,
    rulesets: Optional[list[Ruleset]] = None,
) -> TextArtifact:
    if not artifacts:
        if summary is None:
            raise ValueError("No artifacts to summarize")
        return TextArtifact(summary)

    artifacts_text = self.chunk_joiner.join([a.to_text() for a in artifacts])

    system_prompt = self.generate_system_template.render(
        summary=summary,
        rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
    )

    user_prompt = self.generate_user_template.render(text=artifacts_text)

    if (
        self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
        >= self.min_response_tokens
    ):
        result = self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(user_prompt, role=Message.USER_ROLE),
                ],
            ),
        ).to_artifact()

        if isinstance(result, TextArtifact):
            return result
        raise ValueError("Prompt driver did not return a TextArtifact")
    chunks = self.chunker.chunk(artifacts_text)

    partial_text = self.generate_user_template.render(text=chunks[0].value)

    return self.summarize_artifacts_rec(
        chunks[1:],
        self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(partial_text, role=Message.USER_ROLE),
                ],
            ),
        ).value,
        rulesets=rulesets,
    )

validate_allowlist(_, max_token_multiplier)

Source code in griptape/engines/summary/prompt_summary_engine.py
@max_token_multiplier.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_allowlist(self, _: Attribute, max_token_multiplier: int) -> None:
    if max_token_multiplier > 1:
        raise ValueError("has to be less than or equal to 1")
    if max_token_multiplier <= 0:
        raise ValueError("has to be greater than 0")

RagEngine

Source code in griptape/engines/rag/rag_engine.py
@define(kw_only=True)
class RagEngine:
    query_stage: Optional[QueryRagStage] = field(default=None)
    retrieval_stage: Optional[RetrievalRagStage] = field(default=None)
    response_stage: Optional[ResponseRagStage] = field(default=None)

    def __attrs_post_init__(self) -> None:
        modules = []

        if self.query_stage is not None:
            modules.extend(self.query_stage.modules)

        if self.retrieval_stage is not None:
            modules.extend(self.retrieval_stage.modules)

        if self.response_stage is not None:
            modules.extend(self.response_stage.modules)

        module_names = [m.name for m in modules]

        if len(module_names) > len(set(module_names)):
            raise ValueError("module names have to be unique")

    def process_query(self, query: str) -> RagContext:
        return self.process(RagContext(query=query))

    def process(self, context: RagContext) -> RagContext:
        if self.query_stage:
            context = self.query_stage.run(context)

        if self.retrieval_stage:
            context = self.retrieval_stage.run(context)

        if self.response_stage:
            context = self.response_stage.run(context)

        return context
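
Example: a minimal usage sketch, taken only from the calls shown above. All three stages default to None, so an engine built without them simply passes the context through; in practice you would supply configured QueryRagStage / RetrievalRagStage / ResponseRagStage instances (their constructors are documented with the stage classes, not on this page):

from griptape.engines import RagEngine

engine = RagEngine()  # no stages configured -- purely to illustrate the call flow

# process_query wraps the query string in a RagContext and runs whichever stages
# are configured, in order: query -> retrieval -> response.
context = engine.process_query("What engines does griptape provide?")
print(context.query)  # the returned RagContext carries the original query and each stage's results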

query_stage = field(default=None) class-attribute instance-attribute

response_stage = field(default=None) class-attribute instance-attribute

retrieval_stage = field(default=None) class-attribute instance-attribute

__attrs_post_init__()

Source code in griptape/engines/rag/rag_engine.py
def __attrs_post_init__(self) -> None:
    modules = []

    if self.query_stage is not None:
        modules.extend(self.query_stage.modules)

    if self.retrieval_stage is not None:
        modules.extend(self.retrieval_stage.modules)

    if self.response_stage is not None:
        modules.extend(self.response_stage.modules)

    module_names = [m.name for m in modules]

    if len(module_names) > len(set(module_names)):
        raise ValueError("module names have to be unique")

process(context)

Source code in griptape/engines/rag/rag_engine.py
def process(self, context: RagContext) -> RagContext:
    if self.query_stage:
        context = self.query_stage.run(context)

    if self.retrieval_stage:
        context = self.retrieval_stage.run(context)

    if self.response_stage:
        context = self.response_stage.run(context)

    return context

process_query(query)

Source code in griptape/engines/rag/rag_engine.py
def process_query(self, query: str) -> RagContext:
    return self.process(RagContext(query=query))