Source code for tiled.client.dataframe

import dask
import dask.dataframe.core
from starlette.status import HTTP_404_NOT_FOUND

from ..serialization.table import deserialize_arrow, serialize_arrow
from ..utils import APACHE_ARROW_FILE_MIME_TYPE, UNCHANGED
from .base import BaseClient
from .utils import (
    MSGPACK_MIME_TYPE,
    ClientError,
    client_for_item,
    export_util,
    handle_error,
)

_EXTRA_CHARS_PER_ITEM = len("&column=")
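# Per-column overhead added to a GET URL for each requested column (the column
# name itself is counted separately in _get_partition). Used to decide when the
# column list must be sent in a POST body instead of the query string.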


class _DaskDataFrameClient(BaseClient):
    "Client-side wrapper around an dataframe-like that returns dask dataframes"

    def new_variation(self, structure=UNCHANGED, **kwargs):
        if structure is UNCHANGED:
            structure = self._structure
        return super().new_variation(structure=structure, **kwargs)

    def _repr_pretty_(self, p, cycle):
        """
        Provide "pretty" display in IPython/Jupyter.

        See https://ipython.readthedocs.io/en/stable/config/integrating.html#rich-display
        """
        structure = self.structure()
        if not structure.resizable:
            p.text(f"<{type(self).__name__} {structure.columns}>")
        else:
            # Try to get the column names, but give up quickly to avoid blocking
            # for long.
            TIMEOUT = 0.2  # seconds
            try:
                content = handle_error(
                    self.context.http_client.get(
                        self.uri,
                        headers={"Accept": MSGPACK_MIME_TYPE},
                        params={"fields": "structure"},
                        timeout=TIMEOUT,
                    )
                ).json()
            except TimeoutError:
                p.text(
                    f"<{type(self).__name__} Loading column names took too long; use list(...) >"
                )
            except Exception as err:
                p.text(
                    f"<{type(self).__name__} Loading column names raised error {err!r}>"
                )
            else:
                try:
                    columns = content["data"]["attributes"]["structure"]["columns"]
                except Exception as err:
                    p.text(
                        f"<{type(self).__name__} Loading column names raised error {err!r}>"
                    )
                else:
                    p.text(f"<{type(self).__name__} {columns}>")

    def _ipython_key_completions_(self):
        """
        Provide method for the key-autocompletions in IPython.

        See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion
        """
        structure = self.structure()
        if not structure.resizable:
            # Use cached structure.
            return structure.columns
        try:
            content = handle_error(
                self.context.http_client.get(
                    self.uri,
                    headers={"Accept": MSGPACK_MIME_TYPE},
                    params={"fields": "structure"},
                )
            ).json()
            columns = content["data"]["attributes"]["structure"]["columns"]
        except Exception:
            # Do not print messy traceback from thread. Just fail silently.
            return []
        return columns

    @property
    def columns(self):
        return self.structure().columns

    def _get_partition(self, partition, columns):
        """
        Fetch the actual data for one partition in a partitioned (dask) dataframe.

        See read_partition for a public version of this.
        """
        params = {"partition": partition}
        URL_PATH = self.item["links"]["partition"]
        url_length_for_get_request = len(URL_PATH) + sum(
            _EXTRA_CHARS_PER_ITEM + len(column) for column in (columns or ())
        )
        if url_length_for_get_request > self.URL_CHARACTER_LIMIT:
            content = handle_error(
                self.context.http_client.post(
                    URL_PATH,
                    headers={"Accept": APACHE_ARROW_FILE_MIME_TYPE},
                    json=columns,
                    params=params,
                )
            ).read()
        else:
            if columns:
                # Note: The singular/plural inconsistency here is because
                # ["A", "B"] will be encoded in the URL as column=A&column=B
                params["column"] = columns
            content = handle_error(
                self.context.http_client.get(
                    URL_PATH,
                    headers={"Accept": APACHE_ARROW_FILE_MIME_TYPE},
                    params=params,
                )
            ).read()
        return deserialize_arrow(content)
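
    # Worked example of the length check above (illustrative column names only):
    # with columns=["temperature", "pressure"], the estimated GET URL length is
    #
    #     len(URL_PATH) + (len("&column=") + len("temperature"))
    #                   + (len("&column=") + len("pressure"))
    #
    # If that total exceeds self.URL_CHARACTER_LIMIT (unlikely for two short
    # names, but possible for hundreds of columns), the column list is sent in
    # the JSON body of a POST instead of in the query string. Either way the
    # response is Apache Arrow bytes, decoded by deserialize_arrow.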

    def read_partition(self, partition, columns=None):
        """
        Access one partition in a partitioned (dask) dataframe.

        Optionally select a subset of the columns.
        """
        structure = self.structure()
        npartitions = structure.npartitions
        if not (0 <= partition < npartitions):
            raise IndexError(f"partition {partition} out of range")
        meta = structure.meta
        if columns is not None:
            meta = meta[columns]
        return dask.dataframe.from_delayed(
            [dask.delayed(self._get_partition)(partition, columns)],
            meta=meta,
            # A single delayed partition implies exactly two division boundaries.
            divisions=(None, None),
        )
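
    # Usage sketch (hypothetical node and column names):
    #
    #     ddf = client["table"].read_partition(0, columns=["A", "B"])
    #     df = ddf.compute()  # fetch partition 0 as a pandas DataFrame
    #
    # The return value is a single-partition dask dataframe; no data is pulled
    # from the server until it is computed.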

    def read(self, columns=None):
        """
        Access the entire DataFrame. Optionally select a subset of the columns.

        The result will be internally partitioned with dask.
        """
        structure = self.structure()
        # Build a client-side dask dataframe whose partitions pull from the
        # corresponding partitions of the server-side table.
        name = f"remote-dask-dataframe-{self.item['links']['self']}"
        dask_tasks = {
            (name, partition): (self._get_partition, partition, columns)
            for partition in range(structure.npartitions)
        }
        meta = structure.meta

        if columns is not None:
            meta = meta[columns]
        ddf = dask.dataframe.core.DataFrame(
            dask_tasks,
            name=name,
            meta=meta,
            divisions=(None,) * (1 + structure.npartitions),
        )
        if columns is not None:
            ddf = ddf[columns]
        return ddf
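
    # Usage sketch (hypothetical node name):
    #
    #     ddf = client["table"].read(columns=["A", "B"])
    #     df = ddf.compute()  # fetches every partition and concatenates them
    #
    # Each dask partition maps onto one server-side partition via
    # _get_partition, so partitions are fetched lazily, and in parallel when
    # the active dask scheduler allows it.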

    # We implement *some* of the Mapping interface here but intentionally not
    # all of it. DataFrames are not quite Mapping-like. Their __len__, for
    # example, returns the number of rows (which would be costly for us to
    # compute) rather than upholding the usual Mapping invariant
    # `len(list(obj)) == len(obj)`. Additionally, their behavior with
    # `__getitem__` is a bit "extra", e.g. df[["A", "B"]].

    def __getitem__(self, column):
        try:
            self_link = self.item["links"]["self"]
            if self_link.endswith("/"):
                self_link = self_link[:-1]
            content = handle_error(
                self.context.http_client.get(
                    self_link + f"/{column}",
                    headers={"Accept": MSGPACK_MIME_TYPE},
                )
            ).json()
        except ClientError as err:
            if err.response.status_code == HTTP_404_NOT_FOUND:
                raise KeyError(column)
            raise
        item = content["data"]
        return client_for_item(self.context, self.structure_clients, item)
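
    # Usage sketch (hypothetical column name): df_client["A"] issues a GET for
    # the "A" sub-resource and returns a client for that column (raising
    # KeyError if the server responds 404 Not Found):
    #
    #     col_client = df_client["A"]
    #     col_client.read()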

    def __iter__(self):
        yield from self.structure().columns

    # __len__ is intentionally not implemented. For DataFrames it means "number
    # of rows" which is expensive to compute.

    def write(self, dataframe):
        handle_error(
            self.context.http_client.put(
                self.item["links"]["full"],
                content=bytes(serialize_arrow(dataframe, {})),
                headers={"Content-Type": APACHE_ARROW_FILE_MIME_TYPE},
            )
        )

    def write_partition(self, dataframe, partition):
        handle_error(
            self.context.http_client.put(
                self.item["links"]["partition"].format(index=partition),
                content=bytes(serialize_arrow(dataframe, {})),
                headers={"Content-Type": APACHE_ARROW_FILE_MIME_TYPE},
            )
        )

    def append_partition(self, dataframe, partition):
        handle_error(
            self.context.http_client.patch(
                self.item["links"]["partition"].format(index=partition),
                content=bytes(serialize_arrow(dataframe, {})),
                headers={"Content-Type": APACHE_ARROW_FILE_MIME_TYPE},
            )
        )
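
    # Usage sketch (assumes write access and a pandas DataFrame `df` matching
    # the declared structure):
    #
    #     df_client.write(df)                # replace the full table (PUT)
    #     df_client.write_partition(df, 0)   # replace partition 0 (PUT)
    #     df_client.append_partition(df, 0)  # append rows to partition 0 (PATCH)
    #
    # Each call serializes the dataframe to Apache Arrow before uploading it.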

    def export(self, filepath, columns=None, *, format=None):
        """
        Download data in some format and write to a file.

        Parameters
        ----------
        filepath : str or buffer
            Filepath or writeable buffer.
        columns : List[str], optional
            Select a subset of the columns.
        format : str, optional
            If format is None and `filepath` is a filepath, the format is
            inferred from the name, e.g. 'table.csv' implies format="text/csv".
            The format may be given as a file extension ("csv") or a media
            type ("text/csv").
        """
        params = {}
        if columns is not None:
            params["column"] = columns
        return export_util(
            filepath,
            format,
            self.context.http_client.get,
            self.item["links"]["full"],
            params=params,
        )
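
    # Usage sketch (hypothetical file names and columns):
    #
    #     df_client.export("table.csv")                       # format inferred from extension
    #     df_client.export("table.dat", format="text/csv")    # explicit format
    #     df_client.export("subset.csv", columns=["A", "B"])  # column subset
    #
    # Format handling and writing to the file or buffer are delegated to
    # export_util, which downloads from the "full" link.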


# Public subclasses: one adds the dask-specific methods; the other eagerly
# computes and returns in-memory (pandas) dataframes.


class DaskDataFrameClient(_DaskDataFrameClient):
    "Client-side wrapper around a dataframe-like that returns dask dataframes"

    def compute(self):
        "Alias to client.read().compute()"
        return self.read().compute()


class DataFrameClient(_DaskDataFrameClient):
    "Client-side wrapper around a dataframe-like that returns in-memory dataframes"

    def read_partition(self, partition, columns=None):
        """
        Access one partition of the DataFrame. Optionally select a subset of the columns.
        """
        return super().read_partition(partition, columns).compute()

    def read(self, columns=None):
        """
        Access the entire DataFrame. Optionally select a subset of the columns.
        """
        return super().read(columns).compute()
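

# Usage sketch: the two public classes differ only in laziness. Assuming a
# top-level client node holding a table (names hypothetical):
#
#     client["table"].read()            # DataFrameClient: in-memory dataframe
#     client["table"].read().compute()  # DaskDataFrameClient: compute the lazy result
#     client["table"].compute()         # DaskDataFrameClient: alias for the above
#
# Which class is used depends on how the top-level client was constructed
# (e.g. structure_clients="dask" selects the dask-based clients).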