Skip to content

Loaders

__all__ = ['BaseLoader', 'TextLoader', 'PdfLoader', 'WebLoader', 'SqlLoader', 'CsvLoader', 'DataFrameLoader', 'FileLoader', 'EmailLoader'] module-attribute

BaseLoader

Bases: ABC

Source code in griptape/griptape/loaders/base_loader.py
@define
class BaseLoader(ABC):
    """Abstract base class for loaders that turn raw sources into artifacts.

    Subclasses implement load() for a single source and load_collection()
    for many sources loaded concurrently.
    """

    # Thread pool used by subclasses' load_collection implementations.
    futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True)

    @abstractmethod
    def load(self, *args, **kwargs) -> BaseArtifact | list[BaseArtifact]:
        """Load a single source and return the resulting artifact(s)."""
        ...

    @abstractmethod
    def load_collection(self, *args, **kwargs) -> dict[str, list[BaseArtifact | list[BaseArtifact]]]:
        """Load multiple sources, returning results keyed by a per-source hash."""
        ...

futures_executor: futures.Executor = field(default=Factory(lambda : futures.ThreadPoolExecutor()), kw_only=True) class-attribute instance-attribute

load(*args, **kwargs) abstractmethod

Source code in griptape/griptape/loaders/base_loader.py
@abstractmethod
def load(self, *args, **kwargs) -> BaseArtifact | list[BaseArtifact]:
    """Load a single source and return the resulting artifact(s)."""
    ...

load_collection(*args, **kwargs) abstractmethod

Source code in griptape/griptape/loaders/base_loader.py
@abstractmethod
def load_collection(self, *args, **kwargs) -> dict[str, list[BaseArtifact | list[BaseArtifact]]]:
    """Load multiple sources, returning results keyed by a per-source hash."""
    ...

CsvLoader

Bases: BaseLoader

Source code in griptape/griptape/loaders/csv_loader.py
@define
class CsvLoader(BaseLoader):
    """Loads CSV files into one CsvRowArtifact per row.

    Attributes:
        embedding_driver: Optional driver used to embed each row artifact.
        delimiter: Field delimiter passed to csv.DictReader.
    """

    embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
    delimiter: str = field(default=",", kw_only=True)

    def load(self, filename: str) -> list[CsvRowArtifact]:
        """Load a single CSV file into a list of row artifacts."""
        return self._load_file(filename)

    def load_collection(self, filenames: list[str]) -> dict[str, list[CsvRowArtifact]]:
        """Load many CSV files concurrently, keyed by each filename's hash."""
        return utils.execute_futures_dict(
            {
                utils.str_to_hash(filename): self.futures_executor.submit(self._load_file, filename)
                for filename in filenames
            }
        )

    def _load_file(self, filename: str) -> list[CsvRowArtifact]:
        """Read the CSV and build one CsvRowArtifact per row.

        The file is read as UTF-8 and closed before any embedding work begins.
        """
        with open(filename, "r", encoding="utf-8") as csv_file:
            reader = csv.DictReader(csv_file, delimiter=self.delimiter)
            artifacts = [CsvRowArtifact(row) for row in reader]

        if self.embedding_driver:
            for artifact in artifacts:
                artifact.generate_embedding(self.embedding_driver)

        return artifacts

delimiter: str = field(default=',', kw_only=True) class-attribute instance-attribute

embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

load(filename)

Source code in griptape/griptape/loaders/csv_loader.py
def load(self, filename: str) -> list[CsvRowArtifact]:
    """Load a single CSV file into a list of row artifacts."""
    return self._load_file(filename)

load_collection(filenames)

Source code in griptape/griptape/loaders/csv_loader.py
def load_collection(self, filenames: list[str]) -> dict[str, list[CsvRowArtifact]]:
    """Load many CSV files concurrently, keyed by each filename's hash."""
    return utils.execute_futures_dict(
        {
            utils.str_to_hash(filename): self.futures_executor.submit(self._load_file, filename)
            for filename in filenames
        }
    )

DataFrameLoader

Bases: BaseLoader

Source code in griptape/griptape/loaders/dataframe_loader.py
@define
class DataFrameLoader(BaseLoader):
    """Loads pandas DataFrames into one CsvRowArtifact per row.

    Attributes:
        embedding_driver: Optional driver used to embed each row artifact.
    """

    embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)

    def load(self, dataframe: DataFrame) -> list[CsvRowArtifact]:
        """Convert a single DataFrame into a list of row artifacts."""
        return self._load_file(dataframe)

    def load_collection(self, dataframes: list[DataFrame]) -> dict[str, list[CsvRowArtifact]]:
        """Convert many DataFrames concurrently, keyed by a content hash."""
        return utils.execute_futures_dict(
            {
                self._dataframe_to_hash(dataframe): self.futures_executor.submit(self._load_file, dataframe)
                for dataframe in dataframes
            }
        )

    def _load_file(self, dataframe: DataFrame) -> list[CsvRowArtifact]:
        """Build one CsvRowArtifact per row using records orientation."""
        artifacts = [CsvRowArtifact(row) for row in dataframe.to_dict(orient="records")]

        if self.embedding_driver:
            for artifact in artifacts:
                artifact.generate_embedding(self.embedding_driver)

        return artifacts

    def _dataframe_to_hash(self, dataframe: DataFrame) -> str:
        """Return a SHA-256 hex digest of the DataFrame's contents (index included)."""
        return hashlib.sha256(pd.util.hash_pandas_object(dataframe, index=True).values).hexdigest()

embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

load(dataframe)

Source code in griptape/griptape/loaders/dataframe_loader.py
def load(self, dataframe: DataFrame) -> list[CsvRowArtifact]:
    """Convert a single DataFrame into a list of row artifacts."""
    return self._load_file(dataframe)

load_collection(dataframes)

Source code in griptape/griptape/loaders/dataframe_loader.py
def load_collection(self, dataframes: list[DataFrame]) -> dict[str, list[CsvRowArtifact]]:
    """Convert many DataFrames concurrently, keyed by a content hash."""
    return utils.execute_futures_dict(
        {
            self._dataframe_to_hash(dataframe): self.futures_executor.submit(self._load_file, dataframe)
            for dataframe in dataframes
        }
    )

EmailLoader

Bases: BaseLoader

Source code in griptape/griptape/loaders/email_loader.py
@define
class EmailLoader(BaseLoader):
    """Loads plain-text emails from an IMAP server into artifacts."""

    @define(frozen=True)
    class EmailQuery:
        """An email retrieval query.

        Attributes:
            label: Label to retrieve emails from such as 'INBOX' or 'SENT'.
            key: Optional key for filtering such as 'FROM' or 'SUBJECT'.
            search_criteria: Optional search criteria to filter emails by key.
            max_count: Optional max email count.
        """

        label: str = field(kw_only=True)
        key: Optional[str] = field(default=None, kw_only=True)
        search_criteria: Optional[str] = field(default=None, kw_only=True)
        max_count: Optional[int] = field(default=None, kw_only=True)

    # IMAP SSL connection settings; login uses username/password.
    imap_url: str = field(kw_only=True)
    username: str = field(kw_only=True)
    password: str = field(kw_only=True)

    def load(self, query: EmailQuery) -> ListArtifact:
        """Retrieve the emails matching a single query."""
        return self._retrieve_email(query)

    def load_collection(self, queries: list[EmailQuery]) -> dict[str, ListArtifact | ErrorArtifact]:
        """Retrieve emails for many queries concurrently, keyed by each query's hash.

        Duplicate queries are collapsed via set() before dispatch (EmailQuery is
        frozen, so instances are hashable).
        """
        return utils.execute_futures_dict(
            {
                utils.str_to_hash(str(query)): self.futures_executor.submit(self._retrieve_email, query)
                for query in set(queries)
            }
        )

    def _retrieve_email(self, query: EmailQuery) -> ListArtifact | ErrorArtifact:
        """Fetch matching messages, newest first, as a ListArtifact of TextArtifacts.

        Any failure is logged and returned as an ErrorArtifact instead of raised.
        """
        # astuple relies on EmailQuery's field declaration order.
        label, key, search_criteria, max_count = astuple(query)

        list_artifact = ListArtifact()
        try:
            with imaplib.IMAP4_SSL(self.imap_url) as client:
                client.login(self.username, self.password)

                # Read-only select so fetching does not mark messages as seen.
                mailbox = client.select(f'"{label}"', readonly=True)
                if mailbox[0] != "OK":
                    raise Exception(mailbox[1][0].decode())

                if key and search_criteria:
                    _typ, [message_numbers] = client.search(None, key, f'"{search_criteria}"')
                    messages_count = self._count_messages(message_numbers)
                else:
                    # SELECT's untagged response carries the mailbox message count.
                    messages_count = int(mailbox[1][0])

                # Walk message numbers from highest (newest) downward, stopping
                # after max_count messages when a limit was given.
                top_n = max(0, messages_count - max_count) if max_count else 0
                for i in range(messages_count, top_n, -1):
                    result, data = client.fetch(str(i), "(RFC822)")
                    message = mailparser.parse_from_bytes(data[0][1])
                    # Note: mailparser only populates the text_plain field
                    # if the message content type is explicitly set to 'text/plain'.
                    if message.text_plain:
                        list_artifact.value.append(TextArtifact("\n".join(message.text_plain)))

                client.close()

                return list_artifact
        except Exception as e:
            logging.error(e)
            return ErrorArtifact(f"error retrieving email: {e}")

    def _count_messages(self, message_numbers: bytes):
        """Count message ids in a space-delimited IMAP SEARCH response."""
        return len(list(filter(None, message_numbers.decode().split(" "))))

imap_url: str = field(kw_only=True) class-attribute instance-attribute

password: str = field(kw_only=True) class-attribute instance-attribute

username: str = field(kw_only=True) class-attribute instance-attribute

EmailQuery

An email retrieval query.

Attributes:

Name Type Description
label str

Label to retrieve emails from such as 'INBOX' or 'SENT'.

key Optional[str]

Optional key for filtering such as 'FROM' or 'SUBJECT'.

search_criteria Optional[str]

Optional search criteria to filter emails by key.

max_count Optional[int]

Optional max email count.

Source code in griptape/griptape/loaders/email_loader.py
@define(frozen=True)
class EmailQuery:
    """An email retrieval query.

    Attributes:
        label: Label to retrieve emails from such as 'INBOX' or 'SENT'.
        key: Optional key for filtering such as 'FROM' or 'SUBJECT'.
        search_criteria: Optional search criteria to filter emails by key.
        max_count: Optional max email count.
    """

    label: str = field(kw_only=True)
    key: Optional[str] = field(default=None, kw_only=True)
    search_criteria: Optional[str] = field(default=None, kw_only=True)
    max_count: Optional[int] = field(default=None, kw_only=True)
key: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute
label: str = field(kw_only=True) class-attribute instance-attribute
max_count: Optional[int] = field(default=None, kw_only=True) class-attribute instance-attribute
search_criteria: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

load(query)

Source code in griptape/griptape/loaders/email_loader.py
def load(self, query: EmailQuery) -> ListArtifact:
    """Retrieve the emails matching a single query."""
    return self._retrieve_email(query)

load_collection(queries)

Source code in griptape/griptape/loaders/email_loader.py
def load_collection(self, queries: list[EmailQuery]) -> dict[str, ListArtifact | ErrorArtifact]:
    """Retrieve emails for many queries concurrently, keyed by each query's hash.

    Duplicate queries are collapsed via set() before dispatch.
    """
    return utils.execute_futures_dict(
        {
            utils.str_to_hash(str(query)): self.futures_executor.submit(self._retrieve_email, query)
            for query in set(queries)
        }
    )

FileLoader

Bases: BaseLoader

Source code in griptape/griptape/loaders/file_loader.py
@define
class FileLoader(BaseLoader):
    """Loads files from disk into text or blob artifacts.

    Attributes:
        encoding: When set, file bytes are decoded with it into a TextArtifact;
            when None, the raw bytes become a BlobArtifact.
    """

    encoding: Optional[str] = field(default=None, kw_only=True)

    def load(self, path: str | Path) -> TextArtifact | BlobArtifact | ErrorArtifact:
        """Load a single file into an artifact."""
        return self.file_to_artifact(path)

    def load_collection(self, paths: list[str | Path]) -> dict[str, TextArtifact | BlobArtifact | ErrorArtifact]:
        """Load many files concurrently, keyed by each path string's hash."""
        jobs = {}
        for path in paths:
            jobs[utils.str_to_hash(str(path))] = self.futures_executor.submit(self.file_to_artifact, path)

        return utils.execute_futures_dict(jobs)

    def file_to_artifact(self, path: str | Path) -> TextArtifact | BlobArtifact | ErrorArtifact:
        """Read a file and wrap it in the appropriate artifact type.

        Failures are returned as ErrorArtifacts rather than raised.
        """
        file_name = os.path.basename(path)

        try:
            with open(path, "rb") as file:
                contents = file.read()
                if self.encoding is None:
                    return BlobArtifact(contents, name=file_name, dir_name=os.path.dirname(path))
                return TextArtifact(contents.decode(self.encoding), name=file_name)
        except FileNotFoundError:
            return ErrorArtifact(f"file {file_name} not found")
        except Exception as e:
            return ErrorArtifact(f"error loading file: {e}")

encoding: Optional[str] = field(default=None, kw_only=True) class-attribute instance-attribute

file_to_artifact(path)

Source code in griptape/griptape/loaders/file_loader.py
def file_to_artifact(self, path: str | Path) -> TextArtifact | BlobArtifact | ErrorArtifact:
    """Read a file and wrap it in a Text, Blob, or Error artifact.

    With an encoding set, the bytes are decoded into a TextArtifact;
    otherwise the raw bytes become a BlobArtifact. Failures are returned
    as ErrorArtifacts rather than raised.
    """
    file_name = os.path.basename(path)

    try:
        with open(path, "rb") as file:
            if self.encoding:
                return TextArtifact(file.read().decode(self.encoding), name=file_name)
            else:
                return BlobArtifact(file.read(), name=file_name, dir_name=os.path.dirname(path))
    except FileNotFoundError:
        return ErrorArtifact(f"file {file_name} not found")
    except Exception as e:
        return ErrorArtifact(f"error loading file: {e}")

load(path)

Source code in griptape/griptape/loaders/file_loader.py
def load(self, path: str | Path) -> TextArtifact | BlobArtifact | ErrorArtifact:
    """Load a single file into an artifact."""
    return self.file_to_artifact(path)

load_collection(paths)

Source code in griptape/griptape/loaders/file_loader.py
def load_collection(self, paths: list[str | Path]) -> dict[str, TextArtifact | BlobArtifact | ErrorArtifact]:
    """Load many files concurrently, keyed by each path string's hash."""
    return utils.execute_futures_dict(
        {utils.str_to_hash(str(path)): self.futures_executor.submit(self.file_to_artifact, path) for path in paths}
    )

PdfLoader

Bases: TextLoader

Source code in griptape/griptape/loaders/pdf_loader.py
@define
class PdfLoader(TextLoader):
    """Loads PDFs by extracting page text and chunking it with a PdfChunker."""

    # PdfChunker wired to the inherited tokenizer/max_tokens settings.
    chunker: PdfChunker = field(
        default=Factory(lambda self: PdfChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), takes_self=True),
        kw_only=True,
    )

    def load(self, stream: str | IO | Path, password: Optional[str] = None) -> list[TextArtifact]:
        """Extract and chunk text from a single, optionally password-protected PDF."""
        return self._load_pdf(stream, password)

    def load_collection(
        self, streams: list[str | IO | Path], password: Optional[str] = None
    ) -> dict[str, list[TextArtifact]]:
        """Load many PDFs concurrently, keyed by a hash of each stream."""
        jobs = {}
        for stream in streams:
            # bytes streams are hashed by their decoded contents; everything else by str().
            if isinstance(stream, bytes):
                key = str_to_hash(stream.decode())
            else:
                key = str_to_hash(str(stream))

            jobs[key] = self.futures_executor.submit(self._load_pdf, stream, password)

        return execute_futures_dict(jobs)

    def _load_pdf(self, stream: str | IO | Path, password: Optional[str]) -> list[TextArtifact]:
        """Join page text with newlines and run it through text_to_artifacts."""
        reader = PdfReader(stream, strict=True, password=password)
        pages = [page.extract_text() for page in reader.pages]

        return self.text_to_artifacts("\n".join(pages))

chunker: PdfChunker = field(default=Factory(lambda : PdfChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

load(stream, password=None)

Source code in griptape/griptape/loaders/pdf_loader.py
def load(self, stream: str | IO | Path, password: Optional[str] = None) -> list[TextArtifact]:
    """Extract and chunk text from a single, optionally password-protected PDF."""
    return self._load_pdf(stream, password)

load_collection(streams, password=None)

Source code in griptape/griptape/loaders/pdf_loader.py
def load_collection(
    self, streams: list[str | IO | Path], password: Optional[str] = None
) -> dict[str, list[TextArtifact]]:
    """Load many PDFs concurrently, keyed by a hash of each stream."""
    return execute_futures_dict(
        {
            # bytes streams are hashed by their decoded contents; everything else by str().
            str_to_hash(s.decode())
            if isinstance(s, bytes)
            else str_to_hash(str(s)): self.futures_executor.submit(self._load_pdf, s, password)
            for s in streams
        }
    )

SqlLoader

Bases: BaseLoader

Source code in griptape/griptape/loaders/sql_loader.py
@define
class SqlLoader(BaseLoader):
    """Loads SQL query results into one CsvRowArtifact per row.

    Attributes:
        sql_driver: Driver that executes the SELECT queries.
        embedding_driver: Optional driver used to embed each row artifact.
    """

    sql_driver: BaseSqlDriver = field(kw_only=True)
    embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)

    def load(self, select_query: str) -> list[CsvRowArtifact]:
        """Run a single query and return its rows as artifacts."""
        return self._load_query(select_query)

    def load_collection(self, select_queries: list[str]) -> dict[str, list[CsvRowArtifact]]:
        """Run many queries concurrently, keyed by each query string's hash."""
        return utils.execute_futures_dict(
            {
                utils.str_to_hash(query): self.futures_executor.submit(self._load_query, query)
                for query in select_queries
            }
        )

    def _load_query(self, query: str) -> list[CsvRowArtifact]:
        """Execute the query and wrap each result row's cells in a CsvRowArtifact.

        Returns an empty list when the driver yields no rows.
        """
        rows = self.sql_driver.execute_query(query)

        # execute_query may yield a falsy result (e.g. no rows); both cases map to [].
        artifacts = [CsvRowArtifact(row.cells) for row in rows] if rows else []

        if self.embedding_driver:
            for artifact in artifacts:
                artifact.generate_embedding(self.embedding_driver)

        return artifacts

embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

sql_driver: BaseSqlDriver = field(kw_only=True) class-attribute instance-attribute

load(select_query)

Source code in griptape/griptape/loaders/sql_loader.py
def load(self, select_query: str) -> list[CsvRowArtifact]:
    """Run a single query and return its rows as artifacts."""
    return self._load_query(select_query)

load_collection(select_queries)

Source code in griptape/griptape/loaders/sql_loader.py
def load_collection(self, select_queries: list[str]) -> dict[str, list[CsvRowArtifact]]:
    """Run many queries concurrently, keyed by each query string's hash."""
    return utils.execute_futures_dict(
        {
            utils.str_to_hash(query): self.futures_executor.submit(self._load_query, query)
            for query in select_queries
        }
    )

TextLoader

Bases: BaseLoader

Source code in griptape/griptape/loaders/text_loader.py
@define
class TextLoader(BaseLoader):
    """Loads raw text or text files, chunking them into TextArtifacts.

    Attributes:
        tokenizer: Tokenizer used to size chunks.
        max_tokens: Per-chunk token budget (a fraction of the tokenizer's max).
        chunker: TextChunker wired to the tokenizer/max_tokens settings.
        embedding_driver: Optional driver used to embed each chunk.
        encoding: Text encoding used to read Path inputs and stamped on chunks.
    """

    # Fraction of the tokenizer's max_tokens allotted to each chunk.
    MAX_TOKEN_RATIO = 0.5

    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), kw_only=True
    )
    max_tokens: int = field(
        default=Factory(lambda self: round(self.tokenizer.max_tokens * self.MAX_TOKEN_RATIO), takes_self=True),
        kw_only=True,
    )
    chunker: TextChunker = field(
        default=Factory(
            lambda self: TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), takes_self=True
        ),
        kw_only=True,
    )
    embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)
    encoding: str = field(default="utf-8", kw_only=True)

    def load(self, text: str | Path) -> list[TextArtifact]:
        """Chunk a single text (or text file) into artifacts."""
        return self.text_to_artifacts(text)

    def load_collection(self, texts: list[str | Path]) -> dict[str, list[TextArtifact]]:
        """Chunk many texts concurrently, keyed by each text's hash."""
        return utils.execute_futures_dict(
            {utils.str_to_hash(str(text)): self.futures_executor.submit(self.text_to_artifacts, text) for text in texts}
        )

    def text_to_artifacts(self, text: str | Path) -> list[TextArtifact]:
        """Chunk text (or a file's contents) into encoded TextArtifacts.

        Path inputs are read with self.encoding; chunks are optionally
        embedded and each artifact's encoding attribute is set.
        """
        if isinstance(text, Path):
            with open(text, "r", encoding=self.encoding) as file:
                body = file.read()
        else:
            body = text

        chunks = self.chunker.chunk(body) if self.chunker else [TextArtifact(body)]

        if self.embedding_driver:
            for chunk in chunks:
                chunk.generate_embedding(self.embedding_driver)

        for chunk in chunks:
            chunk.encoding = self.encoding

        return chunks

MAX_TOKEN_RATIO = 0.5 class-attribute instance-attribute

chunker: TextChunker = field(default=Factory(lambda : TextChunker(tokenizer=self.tokenizer, max_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) class-attribute instance-attribute

encoding: str = field(default='utf-8', kw_only=True) class-attribute instance-attribute

max_tokens: int = field(default=Factory(lambda : round(self.tokenizer.max_tokens * self.MAX_TOKEN_RATIO), takes_self=True), kw_only=True) class-attribute instance-attribute

tokenizer: OpenAiTokenizer = field(default=Factory(lambda : OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)), kw_only=True) class-attribute instance-attribute

load(text)

Source code in griptape/griptape/loaders/text_loader.py
def load(self, text: str | Path) -> list[TextArtifact]:
    """Chunk a single text (or text file) into artifacts."""
    return self.text_to_artifacts(text)

load_collection(texts)

Source code in griptape/griptape/loaders/text_loader.py
def load_collection(self, texts: list[str | Path]) -> dict[str, list[TextArtifact]]:
    """Chunk many texts concurrently, keyed by each text's hash."""
    return utils.execute_futures_dict(
        {utils.str_to_hash(str(text)): self.futures_executor.submit(self.text_to_artifacts, text) for text in texts}
    )

text_to_artifacts(text)

Source code in griptape/griptape/loaders/text_loader.py
def text_to_artifacts(self, text: str | Path) -> list[TextArtifact]:
    """Chunk text (or a text file's contents) into encoded TextArtifacts.

    Path inputs are read with self.encoding. Chunks are optionally
    embedded, and each artifact's encoding attribute is set.
    """
    artifacts = []

    if isinstance(text, Path):
        with open(text, "r", encoding=self.encoding) as file:
            body = file.read()
    else:
        body = text

    if self.chunker:
        chunks = self.chunker.chunk(body)
    else:
        chunks = [TextArtifact(body)]

    if self.embedding_driver:
        for chunk in chunks:
            chunk.generate_embedding(self.embedding_driver)

    for chunk in chunks:
        chunk.encoding = self.encoding
        artifacts.append(chunk)

    return artifacts

WebLoader

Bases: TextLoader

Source code in griptape/griptape/loaders/web_loader.py
@define
class WebLoader(TextLoader):
    """Loads web pages by fetching them with trafilatura and chunking the text."""

    def load(self, url: str, include_links: bool = True) -> list[TextArtifact]:
        """Fetch a single URL and return its extracted text as artifacts."""
        return self._load_page_to_artifacts(url, include_links)

    def load_collection(self, urls: list[str], include_links: bool = True) -> dict[str, list[TextArtifact]]:
        """Fetch many URLs concurrently, keyed by each URL's hash."""
        return execute_futures_dict(
            {str_to_hash(u): self.futures_executor.submit(self._load_page_to_artifacts, u, include_links) for u in urls}
        )

    def _load_page_to_artifacts(self, url: str, include_links: bool = True) -> list[TextArtifact]:
        """Extract a page and chunk its 'text' field into artifacts."""
        return self.text_to_artifacts(self.extract_page(url, include_links).get("text"))

    def extract_page(self, url: str, include_links: bool = True) -> dict:
        """Fetch a URL and return trafilatura's JSON extraction as a dict.

        Raises:
            Exception: If the URL cannot be fetched or no content can be extracted.
        """
        config = trafilatura.settings.use_config()
        page = trafilatura.fetch_url(url, no_ssl=True)

        # This disables signal, so that trafilatura can work on any thread:
        # More info: https://trafilatura.readthedocs.io/usage-python.html#disabling-signal
        config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

        # Disable error logging in trafilatura as it sometimes logs errors from lxml, even though
        # the end result of page parsing is successful.
        logging.getLogger("trafilatura").setLevel(logging.FATAL)

        if page is None:
            raise Exception("can't access URL")

        extracted = trafilatura.extract(page, include_links=include_links, output_format="json", config=config)

        # extract() returns None when no usable content is found; fail with a clear
        # message instead of letting json.loads(None) raise an opaque TypeError.
        if extracted is None:
            raise Exception("can't extract page content")

        return json.loads(extracted)

extract_page(url, include_links=True)

Source code in griptape/griptape/loaders/web_loader.py
def extract_page(self, url: str, include_links: bool = True) -> dict:
    """Fetch a URL and return trafilatura's JSON extraction as a dict.

    Raises:
        Exception: If the URL cannot be fetched.
    """
    config = trafilatura.settings.use_config()
    page = trafilatura.fetch_url(url, no_ssl=True)

    # This disables signal, so that trafilatura can work on any thread:
    # More info: https://trafilatura.readthedocs.io/usage-python.html#disabling-signal
    config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

    # Disable error logging in trafilatura as it sometimes logs errors from lxml, even though
    # the end result of page parsing is successful.
    logging.getLogger("trafilatura").setLevel(logging.FATAL)

    if page is None:
        raise Exception("can't access URL")
    else:
        # NOTE(review): trafilatura.extract can return None when no content is
        # extracted, which would make json.loads raise TypeError — confirm upstream.
        return json.loads(
            trafilatura.extract(page, include_links=include_links, output_format="json", config=config)
        )

load(url, include_links=True)

Source code in griptape/griptape/loaders/web_loader.py
def load(self, url: str, include_links: bool = True) -> list[TextArtifact]:
    """Fetch a single URL and return its extracted text as artifacts."""
    return self._load_page_to_artifacts(url, include_links)

load_collection(urls, include_links=True)

Source code in griptape/griptape/loaders/web_loader.py
def load_collection(self, urls: list[str], include_links: bool = True) -> dict[str, list[TextArtifact]]:
    """Fetch many URLs concurrently, keyed by each URL's hash."""
    return execute_futures_dict(
        {str_to_hash(u): self.futures_executor.submit(self._load_page_to_artifacts, u, include_links) for u in urls}
    )