Skip to content

Vector query engine

VectorQueryEngine

Bases: BaseQueryEngine

Source code in griptape/engines/query/vector_query_engine.py
@define
class VectorQueryEngine(BaseQueryEngine):
    """Query engine that answers questions using text artifacts retrieved from a vector store.

    Attributes:
        answer_token_offset: Number of tokens added to the prompt's token count before it is
            compared against the tokenizer's ``max_input_tokens`` in ``query``.
        vector_store_driver: Driver used to query, upsert, and load vector store entries.
        prompt_driver: Driver used to count prompt tokens and to run the final prompt.
        user_template_generator: Template rendered with the query to produce the user message.
        system_template_generator: Template rendered with the rulesets, metadata, and text
            segments to produce the system message.
    """

    answer_token_offset: int = field(default=400, kw_only=True)
    vector_store_driver: BaseVectorStoreDriver = field(kw_only=True)
    prompt_driver: BasePromptDriver = field(kw_only=True)
    user_template_generator: J2 = field(default=Factory(lambda: J2("engines/query/user.j2")), kw_only=True)
    system_template_generator: J2 = field(default=Factory(lambda: J2("engines/query/system.j2")), kw_only=True)

    def query(
        self,
        query: str,
        namespace: Optional[str] = None,
        *,
        rulesets: Optional[list[Ruleset]] = None,
        metadata: Optional[str] = None,
        top_n: Optional[int] = None,
        filter: Optional[dict] = None,
    ) -> TextArtifact:
        """Answer ``query`` with the prompt driver, using vector store results as context.

        Text artifacts returned by the vector store are added to the system message one at a
        time, in the order returned, until adding the next one would make the prompt's token
        count plus ``answer_token_offset`` reach the tokenizer's ``max_input_tokens``.

        Args:
            query: The question; used for the vector store lookup and the user message.
            namespace: Namespace passed to the vector store driver.
            rulesets: Rulesets rendered into the system message.
            metadata: Metadata passed to the system template.
            top_n: Result count passed to the vector store driver.
            filter: Filter passed to the vector store driver.

        Returns:
            The result of running the prompt driver on the system and user messages.
        """
        tokenizer = self.prompt_driver.tokenizer
        result = self.vector_store_driver.query(query, top_n, namespace, filter=filter)
        # Skip entries without a serialized artifact, mirroring load_artifacts, instead of
        # raising KeyError on meta dicts that lack the "artifact" key.
        artifacts = [
            artifact
            for artifact in [
                BaseArtifact.from_json(r.meta["artifact"]) for r in result if r.meta and r.meta.get("artifact")
            ]
            if isinstance(artifact, TextArtifact)
        ]

        # Neither of these depends on the retrieved artifacts, so render them once.
        rendered_rulesets = J2("rulesets/rulesets.j2").render(rulesets=rulesets)
        user_message = self.user_template_generator.render(query=query)

        def render_system_message(segments: list) -> str:
            return self.system_template_generator.render(
                rulesets=rendered_rulesets,
                metadata=metadata,
                text_segments=segments,
            )

        def build_prompt_stack(system_message: str) -> PromptStack:
            return PromptStack(
                inputs=[
                    PromptStack.Input(system_message, role=PromptStack.SYSTEM_ROLE),
                    PromptStack.Input(user_message, role=PromptStack.USER_ROLE),
                ]
            )

        text_segments = []
        # Start from a system message with no segments so that the question, rulesets, and
        # metadata still reach the model when nothing was retrieved. The overflow path has
        # always been able to render the template with an empty segment list.
        system_message = render_system_message(text_segments)

        for artifact in artifacts:
            text_segments.append(artifact.value)
            candidate_message = render_system_message(text_segments)
            message_token_count = self.prompt_driver.token_count(build_prompt_stack(candidate_message))

            if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
                # This segment does not fit; keep the last system message that did.
                text_segments.pop()
                break

            system_message = candidate_message

        return self.prompt_driver.run(build_prompt_stack(system_message))

    def upsert_text_artifact(self, artifact: TextArtifact, namespace: Optional[str] = None) -> str:
        """Upsert a single text artifact through the vector store driver and return its result."""
        return self.vector_store_driver.upsert_text_artifact(artifact, namespace=namespace)

    def upsert_text_artifacts(self, artifacts: list[TextArtifact], namespace: str) -> None:
        """Upsert ``artifacts`` through the vector store driver, keyed by ``namespace``."""
        self.vector_store_driver.upsert_text_artifacts({namespace: artifacts})

    def load_artifacts(self, namespace: str) -> ListArtifact:
        """Load the entries stored under ``namespace`` and return their text artifacts.

        Entries without an "artifact" value in their meta, and artifacts that are not
        ``TextArtifact`` instances, are left out.
        """
        result = self.vector_store_driver.load_entries(namespace)
        artifacts = [BaseArtifact.from_json(r.meta["artifact"]) for r in result if r.meta and r.meta.get("artifact")]

        return ListArtifact([a for a in artifacts if isinstance(a, TextArtifact)])

answer_token_offset: int = field(default=400, kw_only=True) class-attribute instance-attribute

prompt_driver: BasePromptDriver = field(kw_only=True) class-attribute instance-attribute

system_template_generator: J2 = field(default=Factory(lambda: J2('engines/query/system.j2')), kw_only=True) class-attribute instance-attribute

user_template_generator: J2 = field(default=Factory(lambda: J2('engines/query/user.j2')), kw_only=True) class-attribute instance-attribute

vector_store_driver: BaseVectorStoreDriver = field(kw_only=True) class-attribute instance-attribute

load_artifacts(namespace)

Source code in griptape/engines/query/vector_query_engine.py
def load_artifacts(self, namespace: str) -> ListArtifact:
    """Load the entries stored under ``namespace`` and return their text artifacts.

    Entries without an "artifact" value in their meta, and artifacts that are not
    ``TextArtifact`` instances, are left out.
    """
    deserialized = []
    for entry in self.vector_store_driver.load_entries(namespace):
        meta = entry.meta
        if meta and meta.get("artifact"):
            deserialized.append(BaseArtifact.from_json(meta["artifact"]))

    text_artifacts = []
    for candidate in deserialized:
        if isinstance(candidate, TextArtifact):
            text_artifacts.append(candidate)

    return ListArtifact(text_artifacts)

query(query, namespace=None, *, rulesets=None, metadata=None, top_n=None, filter=None)

Source code in griptape/engines/query/vector_query_engine.py
def query(
    self,
    query: str,
    namespace: Optional[str] = None,
    *,
    rulesets: Optional[list[Ruleset]] = None,
    metadata: Optional[str] = None,
    top_n: Optional[int] = None,
    filter: Optional[dict] = None,
) -> TextArtifact:
    """Answer ``query`` with the prompt driver, using vector store results as context.

    Text artifacts returned by the vector store are added to the system message one at a
    time, in the order returned, until adding the next one would make the prompt's token
    count plus ``answer_token_offset`` reach the tokenizer's ``max_input_tokens``.

    Args:
        query: The question; used for the vector store lookup and the user message.
        namespace: Namespace passed to the vector store driver.
        rulesets: Rulesets rendered into the system message.
        metadata: Metadata passed to the system template.
        top_n: Result count passed to the vector store driver.
        filter: Filter passed to the vector store driver.

    Returns:
        The result of running the prompt driver on the system and user messages.
    """
    tokenizer = self.prompt_driver.tokenizer
    result = self.vector_store_driver.query(query, top_n, namespace, filter=filter)
    # Skip entries without a serialized artifact instead of raising KeyError on meta
    # dicts that lack the "artifact" key.
    artifacts = [
        artifact
        for artifact in [
            BaseArtifact.from_json(r.meta["artifact"]) for r in result if r.meta and r.meta.get("artifact")
        ]
        if isinstance(artifact, TextArtifact)
    ]

    # Neither of these depends on the retrieved artifacts, so render them once.
    rendered_rulesets = J2("rulesets/rulesets.j2").render(rulesets=rulesets)
    user_message = self.user_template_generator.render(query=query)

    def render_system_message(segments: list) -> str:
        return self.system_template_generator.render(
            rulesets=rendered_rulesets,
            metadata=metadata,
            text_segments=segments,
        )

    def build_prompt_stack(system_message: str) -> PromptStack:
        return PromptStack(
            inputs=[
                PromptStack.Input(system_message, role=PromptStack.SYSTEM_ROLE),
                PromptStack.Input(user_message, role=PromptStack.USER_ROLE),
            ]
        )

    text_segments = []
    # Start from a system message with no segments so that the question, rulesets, and
    # metadata still reach the model when nothing was retrieved. The overflow path has
    # always been able to render the template with an empty segment list.
    system_message = render_system_message(text_segments)

    for artifact in artifacts:
        text_segments.append(artifact.value)
        candidate_message = render_system_message(text_segments)
        message_token_count = self.prompt_driver.token_count(build_prompt_stack(candidate_message))

        if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
            # This segment does not fit; keep the last system message that did.
            text_segments.pop()
            break

        system_message = candidate_message

    return self.prompt_driver.run(build_prompt_stack(system_message))

upsert_text_artifact(artifact, namespace=None)

Source code in griptape/engines/query/vector_query_engine.py
def upsert_text_artifact(self, artifact: TextArtifact, namespace: Optional[str] = None) -> str:
    """Upsert a single text artifact through the vector store driver and return its result."""
    driver = self.vector_store_driver
    return driver.upsert_text_artifact(artifact, namespace=namespace)

upsert_text_artifacts(artifacts, namespace)

Source code in griptape/engines/query/vector_query_engine.py
def upsert_text_artifacts(self, artifacts: list[TextArtifact], namespace: str) -> None:
    """Upsert ``artifacts`` through the vector store driver, keyed by ``namespace``."""
    artifacts_by_namespace = {namespace: artifacts}
    self.vector_store_driver.upsert_text_artifacts(artifacts_by_namespace)