Skip to content

griptape_cloud_structure_run_driver

logger = logging.getLogger(Defaults.logging_config.logger_name) module-attribute

GriptapeCloudStructureRunDriver

Bases: BaseStructureRunDriver

Source code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
@define
class GriptapeCloudStructureRunDriver(BaseStructureRunDriver):
    """Structure run driver that starts a run of a Structure over HTTP and reads its events.

    A run is created with a POST to ``/api/structures/{structure_id}/runs``. Unless
    ``async_run`` is set, the run's ``events/stream`` endpoint is then consumed to the
    end and the last output (or error) seen in that stream is returned.

    Attributes:
        base_url: URL the ``/api/...`` paths are resolved against.
        api_key: Key placed in the default ``Authorization: Bearer`` header.
        headers: Headers sent with every request; defaults to the bearer header.
        structure_id: ID of the Structure whose runs are created.
        structure_run_wait_time_interval: Not read by any method in this class.
        structure_run_max_wait_time_attempts: Not read by any method in this class.
        async_run: When true, ``try_run`` returns as soon as the run is created.
    """

    base_url: str = field(default="https://cloud.griptape.ai", kw_only=True)
    api_key: str = field(kw_only=True)
    # Built from the instance so the default header always carries this instance's api_key.
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        kw_only=True,
    )
    structure_id: str = field(kw_only=True)
    structure_run_wait_time_interval: int = field(default=2, kw_only=True)
    structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True)
    async_run: bool = field(default=False, kw_only=True)

    def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
        """Create a run and return either its result or a "started" notice.

        Args:
            *args: Artifacts whose ``value`` attributes become the run's arguments.

        Returns:
            The run's output artifact, or an ``InfoArtifact`` when ``async_run`` is set.
        """
        run_id = self._create_run(*args)

        if not self.async_run:
            return self._get_run_result(run_id)
        return InfoArtifact("Run started successfully")

    def _create_run(self, *args: BaseArtifact) -> str:
        """POST a new run for ``structure_id`` and return its ``structure_run_id``.

        Raises:
            requests.HTTPError: Via ``raise_for_status`` on an error response.
        """
        # The path starts with "/", so urljoin replaces any path component of base_url.
        root = self.base_url.strip("/")
        url = urljoin(root, f"/api/structures/{self.structure_id}/runs")

        # NOTE(review): self.env is not declared in this class; presumably inherited
        # from BaseStructureRunDriver - confirm.
        env_vars = [
            {"name": name, "value": env_value, "source": "manual"} for name, env_value in self.env.items()
        ]
        body = {"args": [artifact.value for artifact in args], "env_vars": env_vars}

        # NOTE(review): no timeout is passed, so this call can block indefinitely.
        response = requests.post(url, json=body, headers=self.headers)
        response.raise_for_status()

        return response.json()["structure_run_id"]

    def _get_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact:
        """Consume the run's event stream and return the last output found in it.

        USER events are republished on the ``EventBus``; a ``FinishStructureRunEvent``
        supplies the output. A SYSTEM ``StructureRunError`` event yields an
        ``ErrorArtifact`` instead. Later events overwrite earlier results.

        Raises:
            ValueError: If the stream ends without either kind of event.
        """
        result = None

        for event in self._get_run_events(structure_run_id):
            event_type = event["type"]
            payload = event.get("payload", {})
            origin = event["origin"]

            if origin == "SYSTEM":
                if event_type == "StructureRunError":
                    result = ErrorArtifact(payload["status_detail"]["error"])
                continue
            if origin != "USER":
                continue

            try:
                if "span_id" in payload:
                    # Move span_id from the top level of the payload into its "meta" dict.
                    span_id = payload.pop("span_id")
                    payload.setdefault("meta", {})["span_id"] = span_id
                EventBus.publish_event(BaseEvent.from_dict(payload))
            except ValueError as e:
                # Best effort: an event that cannot be deserialized is logged, not fatal.
                logger.warning("Failed to deserialize event: %s", e)

            if event_type == "FinishStructureRunEvent":
                result = BaseArtifact.from_dict(payload["output_task_output"])

        if result is None:
            raise ValueError("Output not found.")

        return result

    def _get_run_events(self, structure_run_id: str) -> Iterator[dict]:
        """Yield the JSON object carried by each ``data:`` line of the run's event stream.

        Raises:
            requests.HTTPError: Via ``raise_for_status`` on an error response.
        """
        root = self.base_url.strip("/")
        url = urljoin(root, f"/api/structure-runs/{structure_run_id}/events/stream")

        with requests.get(url, headers=self.headers, stream=True) as response:
            response.raise_for_status()
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                text = raw_line.decode("utf-8")
                if not text.startswith("data:"):
                    continue
                yield json.loads(text.removeprefix("data:").strip())

api_key = field(kw_only=True) class-attribute instance-attribute

async_run = field(default=False, kw_only=True) class-attribute instance-attribute

base_url = field(default='https://cloud.griptape.ai', kw_only=True) class-attribute instance-attribute

headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

structure_id = field(kw_only=True) class-attribute instance-attribute

structure_run_max_wait_time_attempts = field(default=20, kw_only=True) class-attribute instance-attribute

structure_run_wait_time_interval = field(default=2, kw_only=True) class-attribute instance-attribute

_create_run(*args)

Source code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def _create_run(self, *args: BaseArtifact) -> str:
    """POST a new run for ``self.structure_id`` and return its ``structure_run_id``.

    The request body carries each argument's ``value`` under ``"args"`` and one
    ``{"name", "value", "source": "manual"}`` entry per item of ``self.env``
    under ``"env_vars"``.

    Raises:
        requests.HTTPError: Via ``raise_for_status`` on an error response.
    """
    # The path starts with "/", so urljoin replaces any path component of base_url.
    root = self.base_url.strip("/")
    url = urljoin(root, f"/api/structures/{self.structure_id}/runs")

    env_vars = [
        {"name": name, "value": env_value, "source": "manual"} for name, env_value in self.env.items()
    ]
    body = {"args": [artifact.value for artifact in args], "env_vars": env_vars}

    # NOTE(review): no timeout is passed, so this call can block indefinitely.
    response = requests.post(url, json=body, headers=self.headers)
    response.raise_for_status()

    return response.json()["structure_run_id"]

_get_run_events(structure_run_id)

Source code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def _get_run_events(self, structure_run_id: str) -> Iterator[dict]:
    """Yield the JSON object carried by each ``data:`` line of the run's event stream.

    Empty lines and lines that do not start with ``data:`` are skipped.

    Raises:
        requests.HTTPError: Via ``raise_for_status`` on an error response.
    """
    root = self.base_url.strip("/")
    url = urljoin(root, f"/api/structure-runs/{structure_run_id}/events/stream")

    with requests.get(url, headers=self.headers, stream=True) as response:
        response.raise_for_status()
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            text = raw_line.decode("utf-8")
            if not text.startswith("data:"):
                continue
            yield json.loads(text.removeprefix("data:").strip())

_get_run_result(structure_run_id)

Source code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def _get_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact:
    """Consume the run's event stream and return the last output found in it.

    USER events are republished on the ``EventBus``; a ``FinishStructureRunEvent``
    supplies the output. A SYSTEM ``StructureRunError`` event yields an
    ``ErrorArtifact`` instead. Later events overwrite earlier results.

    Raises:
        ValueError: If the stream ends without either kind of event.
    """
    result = None

    for event in self._get_run_events(structure_run_id):
        event_type = event["type"]
        payload = event.get("payload", {})
        origin = event["origin"]

        if origin == "SYSTEM":
            if event_type == "StructureRunError":
                result = ErrorArtifact(payload["status_detail"]["error"])
            continue
        if origin != "USER":
            continue

        try:
            if "span_id" in payload:
                # Move span_id from the top level of the payload into its "meta" dict.
                span_id = payload.pop("span_id")
                payload.setdefault("meta", {})["span_id"] = span_id
            EventBus.publish_event(BaseEvent.from_dict(payload))
        except ValueError as e:
            # Best effort: an event that cannot be deserialized is logged, not fatal.
            logger.warning("Failed to deserialize event: %s", e)

        if event_type == "FinishStructureRunEvent":
            result = BaseArtifact.from_dict(payload["output_task_output"])

    if result is None:
        raise ValueError("Output not found.")

    return result

try_run(*args)

Source code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
    """Create a run and return either its result or a "started" notice.

    Args:
        *args: Artifacts forwarded unchanged to ``self._create_run``.

    Returns:
        Whatever ``self._get_run_result`` returns for the new run, or an
        ``InfoArtifact`` when ``self.async_run`` is set.
    """
    run_id = self._create_run(*args)

    if not self.async_run:
        return self._get_run_result(run_id)
    return InfoArtifact("Run started successfully")