Skip to content

Vector query engine

VectorQueryEngine

Bases: BaseQueryEngine

Source code in griptape/griptape/engines/query/vector_query_engine.py
@define
class VectorQueryEngine(BaseQueryEngine):
    """Query engine that answers questions from text artifacts stored in a vector store.

    Text is retrieved through ``vector_store_driver``, packed into a prompt rendered by
    ``template_generator`` (within the prompt driver's token budget), and answered by
    ``prompt_driver``.
    """

    # Tokens reserved for the answer when deciding how many retrieved segments fit.
    answer_token_offset: int = field(default=400, kw_only=True)
    # Backend used for similarity search, upserts and loading stored entries.
    vector_store_driver: BaseVectorStoreDriver = field(kw_only=True)
    # Driver that counts prompt tokens and produces the final answer.
    prompt_driver: BasePromptDriver = field(
        default=Factory(lambda: OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)),
        kw_only=True,
    )
    # Template that combines metadata, query, retrieved segments and rulesets into one prompt.
    template_generator: J2 = field(default=Factory(lambda: J2("engines/query/vector_query.j2")), kw_only=True)

    def query(
        self,
        query: str,
        namespace: str | None = None,
        rulesets: list[Ruleset] | None = None,
        metadata: str | None = None,
        top_n: int | None = None,
    ) -> TextArtifact:
        """Answer ``query`` using text artifacts retrieved from the vector store.

        Retrieved ``TextArtifact`` values are added to the prompt in the order the
        vector store driver returns them. Packing stops at the first segment whose
        inclusion would make the prompt token count plus ``answer_token_offset``
        reach ``tokenizer.max_tokens``; that segment and all later ones are dropped.

        Args:
            query: The question to answer; also used as the vector search query.
            namespace: Optional vector store namespace to search.
            rulesets: Optional rulesets rendered into the prompt.
            metadata: Optional extra metadata rendered into the prompt.
            top_n: Optional result limit forwarded to the vector store driver's ``query``.

        Returns:
            The result of ``prompt_driver.run`` for the rendered prompt.
        """
        tokenizer = self.prompt_driver.tokenizer
        result = self.vector_store_driver.query(query, top_n, namespace)
        # Skip entries that carry no serialized artifact (same filter as load_artifacts)
        # rather than raising KeyError on r.meta["artifact"].
        artifacts = [
            artifact
            for artifact in (
                BaseArtifact.from_json(r.meta["artifact"]) for r in result if r.meta and r.meta.get("artifact")
            )
            if isinstance(artifact, TextArtifact)
        ]
        # The rulesets section does not depend on the segments, so render it only once.
        rendered_rulesets = J2("rulesets/rulesets.j2").render(rulesets=rulesets)

        def render_message(segments: list[str]) -> str:
            # Render the full prompt for a given list of context segments.
            return self.template_generator.render(
                metadata=metadata,
                query=query,
                text_segments=segments,
                rulesets=rendered_rulesets,
            )

        text_segments: list[str] = []
        # Start from a context-free prompt so the query is still sent when nothing
        # usable was retrieved (previously an empty string was sent to the driver).
        message = render_message(text_segments)

        for artifact in artifacts:
            candidate_segments = [*text_segments, artifact.value]
            candidate_message = render_message(candidate_segments)
            message_token_count = self.prompt_driver.token_count(
                PromptStack(inputs=[PromptStack.Input(candidate_message, role=PromptStack.USER_ROLE)])
            )

            # Keep headroom for the answer: stop before the segment that exhausts the budget.
            if message_token_count + self.answer_token_offset >= tokenizer.max_tokens:
                break

            text_segments = candidate_segments
            message = candidate_message

        return self.prompt_driver.run(PromptStack(inputs=[PromptStack.Input(message, role=PromptStack.USER_ROLE)]))

    def upsert_text_artifact(self, artifact: TextArtifact, namespace: str | None = None) -> str:
        """Upsert one text artifact through the vector store driver and return the driver's result."""
        result = self.vector_store_driver.upsert_text_artifact(artifact, namespace=namespace)

        return result

    def upsert_text_artifacts(self, artifacts: list[TextArtifact], namespace: str) -> None:
        """Upsert ``artifacts`` under ``namespace`` with a single batched driver call."""
        self.vector_store_driver.upsert_text_artifacts({namespace: artifacts})

    def load_artifacts(self, namespace: str) -> ListArtifact:
        """Return the text artifacts stored in ``namespace``.

        Entries without metadata or without an ``"artifact"`` value are skipped, as are
        deserialized artifacts that are not ``TextArtifact`` instances.
        """
        result = self.vector_store_driver.load_entries(namespace)
        artifacts = [BaseArtifact.from_json(r.meta["artifact"]) for r in result if r.meta and r.meta.get("artifact")]

        return ListArtifact([a for a in artifacts if isinstance(a, TextArtifact)])

answer_token_offset: int = field(default=400, kw_only=True) class-attribute instance-attribute

prompt_driver: BasePromptDriver = field(default=Factory(lambda : OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), kw_only=True) class-attribute instance-attribute

template_generator: J2 = field(default=Factory(lambda : J2('engines/query/vector_query.j2')), kw_only=True) class-attribute instance-attribute

vector_store_driver: BaseVectorStoreDriver = field(kw_only=True) class-attribute instance-attribute

load_artifacts(namespace)

Source code in griptape/griptape/engines/query/vector_query_engine.py
def load_artifacts(self, namespace: str) -> ListArtifact:
    """Collect the text artifacts stored under ``namespace``.

    Entries with no metadata, or whose metadata has no truthy ``"artifact"`` value,
    are ignored; deserialized artifacts that are not ``TextArtifact`` are dropped.
    """
    text_artifacts = []

    for entry in self.vector_store_driver.load_entries(namespace):
        if not (entry.meta and entry.meta.get("artifact")):
            continue

        artifact = BaseArtifact.from_json(entry.meta["artifact"])
        if isinstance(artifact, TextArtifact):
            text_artifacts.append(artifact)

    return ListArtifact(text_artifacts)

query(query, namespace=None, rulesets=None, metadata=None, top_n=None)

Source code in griptape/griptape/engines/query/vector_query_engine.py
def query(
    self,
    query: str,
    namespace: str | None = None,
    rulesets: list[Ruleset] | None = None,
    metadata: str | None = None,
    top_n: int | None = None,
) -> TextArtifact:
    """Answer ``query`` using text artifacts retrieved from the vector store.

    Retrieved ``TextArtifact`` values are added to the prompt in the order the
    vector store driver returns them. Packing stops at the first segment whose
    inclusion would make the prompt token count plus ``answer_token_offset``
    reach ``tokenizer.max_tokens``; that segment and all later ones are dropped.

    Args:
        query: The question to answer; also used as the vector search query.
        namespace: Optional vector store namespace to search.
        rulesets: Optional rulesets rendered into the prompt.
        metadata: Optional extra metadata rendered into the prompt.
        top_n: Optional result limit forwarded to the vector store driver's ``query``.

    Returns:
        The result of ``prompt_driver.run`` for the rendered prompt.
    """
    tokenizer = self.prompt_driver.tokenizer
    result = self.vector_store_driver.query(query, top_n, namespace)
    # Skip entries that carry no serialized artifact (same filter as load_artifacts)
    # rather than raising KeyError on r.meta["artifact"].
    artifacts = [
        artifact
        for artifact in (
            BaseArtifact.from_json(r.meta["artifact"]) for r in result if r.meta and r.meta.get("artifact")
        )
        if isinstance(artifact, TextArtifact)
    ]
    # The rulesets section does not depend on the segments, so render it only once.
    rendered_rulesets = J2("rulesets/rulesets.j2").render(rulesets=rulesets)

    def render_message(segments: list[str]) -> str:
        # Render the full prompt for a given list of context segments.
        return self.template_generator.render(
            metadata=metadata,
            query=query,
            text_segments=segments,
            rulesets=rendered_rulesets,
        )

    text_segments: list[str] = []
    # Start from a context-free prompt so the query is still sent when nothing
    # usable was retrieved (previously an empty string was sent to the driver).
    message = render_message(text_segments)

    for artifact in artifacts:
        candidate_segments = [*text_segments, artifact.value]
        candidate_message = render_message(candidate_segments)
        message_token_count = self.prompt_driver.token_count(
            PromptStack(inputs=[PromptStack.Input(candidate_message, role=PromptStack.USER_ROLE)])
        )

        # Keep headroom for the answer: stop before the segment that exhausts the budget.
        if message_token_count + self.answer_token_offset >= tokenizer.max_tokens:
            break

        text_segments = candidate_segments
        message = candidate_message

    return self.prompt_driver.run(PromptStack(inputs=[PromptStack.Input(message, role=PromptStack.USER_ROLE)]))

upsert_text_artifact(artifact, namespace=None)

Source code in griptape/griptape/engines/query/vector_query_engine.py
def upsert_text_artifact(self, artifact: TextArtifact, namespace: str | None = None) -> str:
    """Store ``artifact`` via the vector store driver and hand back whatever it returns.

    ``namespace`` is forwarded to the driver as a keyword argument, ``None`` included.
    """
    return self.vector_store_driver.upsert_text_artifact(artifact, namespace=namespace)

upsert_text_artifacts(artifacts, namespace)

Source code in griptape/griptape/engines/query/vector_query_engine.py
def upsert_text_artifacts(self, artifacts: list[TextArtifact], namespace: str) -> None:
    """Store ``artifacts`` under ``namespace`` in one batched driver call.

    The driver receives a single-entry mapping from ``namespace`` to the artifact
    list; its return value is discarded.
    """
    batch = {namespace: artifacts}
    self.vector_store_driver.upsert_text_artifacts(batch)