Skip to content

snowflake_sql_driver

SnowflakeSqlDriver

Bases: BaseSqlDriver

Source code in griptape/drivers/sql/snowflake_sql_driver.py
@define
class SnowflakeSqlDriver(BaseSqlDriver):
    """SQL driver that runs queries against Snowflake through a SQLAlchemy engine.

    Attributes:
        get_connection: Zero-argument factory returning a ``SnowflakeConnection``. It is
            called once by the validator and is passed to SQLAlchemy as the engine's
            ``creator``, which SQLAlchemy calls to open new DBAPI connections.
        engine: Optional pre-built SQLAlchemy ``Engine`` (constructor alias ``engine``).
            When omitted, the ``engine`` property builds one from ``get_connection``.
            Excluded from serialization.
    """

    get_connection: Callable[[], SnowflakeConnection] = field(kw_only=True)
    # Backing field for the `engine` lazy property; marked non-serializable via metadata.
    _engine: Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

    @get_connection.validator  # pyright: ignore[reportFunctionMemberAccess]
    def validate_get_connection(self, _: Attribute, get_connection: Callable[[], SnowflakeConnection]) -> None:
        """Check that ``get_connection`` yields a Snowflake connection with a schema and database.

        ``get_connection()`` is invoked once here purely to inspect its result; the returned
        connection is not closed. NOTE(review): if the factory opens a fresh session per call,
        this leaves one session open per driver instance. Confirm before adding a close(),
        because a factory that returns a shared connection would be broken by closing it.

        Raises:
            ValueError: If the factory does not return a
                ``snowflake.connector.SnowflakeConnection``, or the connection lacks a
                schema or database.
        """
        snowflake_connection = get_connection()
        # Dependency check on the top-level package (presumably raises ImportError if missing).
        import_optional_dependency("snowflake")
        # Import the submodule explicitly. Importing a parent package does not by itself load
        # (or bind as an attribute) its `connector` submodule. Reaching it through the parent
        # could therefore raise AttributeError instead of the ValueError below.
        snowflake_connector = import_optional_dependency("snowflake.connector")

        if not isinstance(snowflake_connection, snowflake_connector.SnowflakeConnection):
            raise ValueError("The get_connection function must return a SnowflakeConnection")
        if not snowflake_connection.schema or not snowflake_connection.database:
            raise ValueError("Provide a schema and database for the Snowflake connection")

    @lazy_property()
    def engine(self) -> Engine:
        """Build the SQLAlchemy engine this driver connects through.

        Because ``creator`` is supplied, SQLAlchemy obtains connections by calling
        ``self.get_connection`` rather than by connecting to the URL. The URL's
        ``snowflake://`` scheme selects the dialect; its host and database parts are
        placeholders.
        """
        return import_optional_dependency("sqlalchemy").create_engine(
            "snowflake://not@used/db",
            creator=self.get_connection,
        )

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        """Run ``query`` and wrap each returned row in ``BaseSqlDriver.RowResult``.

        Returns:
            The wrapped rows, or ``None`` when the query produced no rows (no result set,
            or an empty one).
        """
        rows = self.execute_query_raw(query)

        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        else:
            return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        """Execute ``query`` as raw SQL text and return each row as a column-name -> value dict.

        ``query`` is executed verbatim (wrapped in ``sqlalchemy.text``), so it must come from
        a trusted source.

        Returns:
            A (possibly empty) list of row dicts when the statement returns rows, or ``None``
            when SQLAlchemy reports no result set (``returns_rows`` is False).
        """
        sqlalchemy = import_optional_dependency("sqlalchemy")

        with self.engine.connect() as con:
            results = con.execute(sqlalchemy.text(query))

            if not results.returns_rows:
                return None
            # Materialize inside the `with` so rows are fetched before the connection closes.
            return [dict(result._mapping) for result in results]

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        """Describe a table's columns as the string form of a list of ``(name, type)`` pairs.

        Args:
            table_name: Name of the table to reflect.
            schema: Schema containing the table; ``None`` uses the connection's default.

        Returns:
            ``str`` of the ``(column name, column type)`` list, or ``None`` if the table does
            not exist.
        """
        sqlalchemy = import_optional_dependency("sqlalchemy")

        try:
            # Reflect only the requested table. `autoload_with` alone triggers reflection;
            # the `autoload=True` flag is deprecated in SQLAlchemy 1.4 and rejected in 2.0.
            # A full `MetaData.reflect()` would load every table in the schema to read one.
            table = sqlalchemy.Table(table_name, sqlalchemy.MetaData(), schema=schema, autoload_with=self.engine)
        except sqlalchemy.exc.NoSuchTableError:
            return None
        return str([(c.name, c.type) for c in table.columns])

get_connection: Callable[[], SnowflakeConnection] = field(kw_only=True) — class attribute, instance attribute

engine()

Source code in griptape/drivers/sql/snowflake_sql_driver.py
@lazy_property()
def engine(self) -> Engine:
    """Create the SQLAlchemy engine this driver connects through.

    Supplying ``creator`` makes SQLAlchemy obtain each new connection by calling
    ``self.get_connection`` instead of connecting to the URL. The URL's
    ``snowflake://`` scheme selects the dialect; its host and database parts are
    placeholders.
    """
    sqlalchemy = import_optional_dependency("sqlalchemy")
    placeholder_url = "snowflake://not@used/db"
    return sqlalchemy.create_engine(placeholder_url, creator=self.get_connection)

execute_query(query)

Source code in griptape/drivers/sql/snowflake_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    """Run ``query`` and wrap every returned row in ``BaseSqlDriver.RowResult``.

    Returns:
        The wrapped rows, or ``None`` whenever ``execute_query_raw`` yields nothing
        (``None`` or an empty list).
    """
    raw_rows = self.execute_query_raw(query)
    if not raw_rows:
        return None
    return [BaseSqlDriver.RowResult(raw_row) for raw_row in raw_rows]

execute_query_raw(query)

Source code in griptape/drivers/sql/snowflake_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
    """Execute ``query`` as raw SQL text and return each row as a column-name -> value dict.

    ``query`` is executed verbatim (wrapped in ``sqlalchemy.text``), so it must come from a
    trusted source.

    Returns:
        A (possibly empty) list of row dicts when the statement returns rows, or ``None``
        when SQLAlchemy reports no result set (``returns_rows`` is False).
    """
    sqlalchemy = import_optional_dependency("sqlalchemy")

    with self.engine.connect() as con:
        # `Connection.execute` always returns a result object, so no None check is needed.
        results = con.execute(sqlalchemy.text(query))

        if not results.returns_rows:
            return None
        # Materialize inside the `with` so rows are fetched before the connection closes.
        return [dict(result._mapping) for result in results]

get_table_schema(table_name, schema=None)

Source code in griptape/drivers/sql/snowflake_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    """Describe a table's columns as the string form of a list of ``(name, type)`` pairs.

    Args:
        table_name: Name of the table to reflect.
        schema: Schema containing the table; ``None`` uses the connection's default.

    Returns:
        ``str`` of the ``(column name, column type)`` list, or ``None`` if the table does
        not exist.
    """
    sqlalchemy = import_optional_dependency("sqlalchemy")

    try:
        # Reflect only the requested table. `autoload_with` alone triggers reflection;
        # the `autoload=True` flag is deprecated in SQLAlchemy 1.4 and rejected in 2.0.
        # A full `MetaData.reflect()` would load every table in the schema to read one.
        table = sqlalchemy.Table(table_name, sqlalchemy.MetaData(), schema=schema, autoload_with=self.engine)
    except sqlalchemy.exc.NoSuchTableError:
        return None
    return str([(c.name, c.type) for c in table.columns])

validate_get_connection(_, get_connection)

Source code in griptape/drivers/sql/snowflake_sql_driver.py
@get_connection.validator  # pyright: ignore[reportFunctionMemberAccess]
def validate_get_connection(self, _: Attribute, get_connection: Callable[[], SnowflakeConnection]) -> None:
    """Check that ``get_connection`` yields a Snowflake connection with a schema and database.

    ``get_connection()`` is invoked once here purely to inspect its result; the returned
    connection is not closed. NOTE(review): if the factory opens a fresh session per call,
    this leaves one session open per driver instance. Confirm before adding a close(),
    because a factory that returns a shared connection would be broken by closing it.

    Raises:
        ValueError: If the factory does not return a
            ``snowflake.connector.SnowflakeConnection``, or the connection lacks a schema
            or database.
    """
    snowflake_connection = get_connection()
    # Dependency check on the top-level package (presumably raises ImportError if missing).
    import_optional_dependency("snowflake")
    # Import the submodule explicitly. Importing a parent package does not by itself load
    # (or bind as an attribute) its `connector` submodule. Reaching it through the parent
    # could therefore raise AttributeError instead of the ValueError below.
    snowflake_connector = import_optional_dependency("snowflake.connector")

    if not isinstance(snowflake_connection, snowflake_connector.SnowflakeConnection):
        raise ValueError("The get_connection function must return a SnowflakeConnection")
    if not snowflake_connection.schema or not snowflake_connection.database:
        raise ValueError("Provide a schema and database for the Snowflake connection")