Source code for tiled.client.xarray

import threading

import dask
import dask.array
import pandas
import xarray

from .base import BaseStructureClient
from ..serialization.dataframe import deserialize_arrow
from ..utils import APACHE_ARROW_FILE_MIME_TYPE
from .node import Node

# Columns at least this long are read individually rather than batched into a
# single wide-table request (see _build_arrays below).
LENGTH_LIMIT_FOR_WIDE_TABLE_OPTIMIZATION = 1_000_000


class DaskDatasetClient(Node):
    def _repr_pretty_(self, p, cycle):
        """
        Provide "pretty" display in IPython/Jupyter.

        See https://ipython.readthedocs.io/en/stable/config/integrating.html#rich-display
        """
        p.text(f"<{type(self).__name__} {list(self)}>")

    def _ipython_key_completions_(self):
        """
        Provide method for the key-autocompletions in IPython.

        See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion
        """
        return list(self)

    def download(self):
        super().download()
        # List the children and read all the data so that it is fetched
        # and cached.
        list(self)
        self.read().compute()

    def _build_arrays(self, variables, optimize_wide_table):
        data_vars = {}
        coords = {}
        # Optimization: Download scalar columns in a batch as a DataFrame,
        # on first access.
        coords_fetcher = _WideTableFetcher(
            self.context.get_content, self.item["links"]["full"]
        )
        data_vars_fetcher = _WideTableFetcher(
            self.context.get_content, self.item["links"]["full"]
        )
        array_clients = {}
        array_structures = {}
        first_dims = []
        for name, array_client in self.items():
            if (variables is not None) and (name not in variables):
                continue
            array_clients[name] = array_client
            array_structure = array_client.structure()
            array_structures[name] = array_structure
            if array_structure.macro.shape:
                first_dims.append(array_structure.macro.shape[0])
            else:
                first_dims.append(None)
        if len(set(first_dims)) > 1:
            # The arrays are ragged, not tabular, so the wide-table
            # optimization does not apply.
            optimize_wide_table = False
        for name, array_client in array_clients.items():
            array_structure = array_structures[name]
            shape = array_structure.macro.shape
            spec_names = set(spec.name for spec in array_client.specs)
            if optimize_wide_table and (
                (not shape)  # empty
                or (
                    (shape[0] < LENGTH_LIMIT_FOR_WIDE_TABLE_OPTIMIZATION)
                    and (len(shape) < 2)
                )
            ):
                if "xarray_coord" in spec_names:
                    coords[name] = (
                        array_client.dims,
                        coords_fetcher.register(name, array_client, array_structure),
                    )
                elif "xarray_data_var" in spec_names:
                    data_vars[name] = (
                        array_client.dims,
                        data_vars_fetcher.register(
                            name, array_client, array_structure
                        ),
                    )
                else:
                    raise ValueError(
                        "Child nodes of xarray_dataset should include spec "
                        "'xarray_coord' or 'xarray_data_var'."
                    )
            else:
                if "xarray_coord" in spec_names:
                    coords[name] = (array_client.dims, array_client.read())
                elif "xarray_data_var" in spec_names:
                    data_vars[name] = (array_client.dims, array_client.read())
                else:
                    raise ValueError(
                        "Child nodes of xarray_dataset should include spec "
                        "'xarray_coord' or 'xarray_data_var'."
                    )
        return data_vars, coords

    def read(self, variables=None, *, optimize_wide_table=True):
        data_vars, coords = self._build_arrays(variables, optimize_wide_table)
        return xarray.Dataset(
            data_vars=data_vars, coords=coords, attrs=self.metadata["attrs"]
        )
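

# Illustrative usage sketch (not part of this module): obtaining and reading a
# DaskDatasetClient. The server URI and dataset key below are hypothetical.
def _example_lazy_read():
    from tiled.client import from_uri

    # Request dask-backed ("lazy") structure clients.
    client = from_uri("http://localhost:8000", "dask")  # hypothetical server
    dataset_client = client["my_dataset"]  # hypothetical key
    ds = dataset_client.read()  # xarray.Dataset backed by dask arrays
    # Pass `variables` to fetch only a subset of the columns.
    subset = dataset_client.read(variables=["temperature"])  # hypothetical name
    # .compute() materializes the dask arrays into memory.
    return ds.compute(), subset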


class DatasetClient(DaskDatasetClient):
    def read(self, variables=None, *, optimize_wide_table=True):
        # Fetch lazily via the parent class, then materialize: .load()
        # computes the dask-backed Dataset into numpy arrays in place.
        return (
            super()
            .read(variables=variables, optimize_wide_table=optimize_wide_table)
            .load()
        )

    def download(self):
        # Do not run super().download() because DaskDatasetClient calls
        # compute(), which does not apply here.
        BaseStructureClient.download(self)
        self._ipython_key_completions_()
        self.read()
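

# Illustrative usage sketch (not part of this module): DatasetClient is the
# eager counterpart of DaskDatasetClient; read() returns an xarray.Dataset of
# in-memory numpy arrays. The URI and key below are hypothetical.
def _example_eager_read():
    from tiled.client import from_uri

    client = from_uri("http://localhost:8000")  # default: eager clients
    ds = client["my_dataset"].read()  # hypothetical key
    return ds  # numpy-backed Dataset; no .compute() step is needed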


URL_CHARACTER_LIMIT = 2000  # number of characters
_EXTRA_CHARS_PER_ITEM = len("&field=")


class _WideTableFetcher:
    def __init__(self, get, link):
        self.get = get
        self.link = link
        self.variables = []
        self._dataframe = None
        # This lock ensures that multiple threads (e.g. dask worker threads)
        # do not prompt us to re-request the same data. Only the first worker
        # to ask for the data should trigger a request.
        self._lock = threading.Lock()

    def register(self, name, array_client, array_structure):
        if self._dataframe is not None:
            raise RuntimeError("Cannot add variables; already fetched.")
        self.variables.append(name)
        # TODO Can we avoid .values here?
        return dask.array.from_delayed(
            dask.delayed(self.dataframe)()[name].values,
            shape=array_structure.macro.shape,
            dtype=array_structure.micro.to_numpy_dtype(),
        )

    def dataframe(self):
        with self._lock:
            if self._dataframe is None:
                # If self.variables contains many and/or lengthy names,
                # we can bump into the URI size limit commonly imposed by
                # HTTP stacks (e.g. nginx). The HTTP spec does not define a
                # limit, but a common setting is 4K or 8K (for all the headers
                # together). As another reference point, Internet Explorer
                # imposes a 2048-character limit on URLs.
                variables = []
                dataframes = []
                budget = URL_CHARACTER_LIMIT
                budget -= len(self.link)
                # Fetch the variables in batches.
                for variable in self.variables:
                    budget -= _EXTRA_CHARS_PER_ITEM + len(variable)
                    if budget < 0:
                        # Fetch a batch and then add `variable` to the next batch.
                        dataframes.append(self._fetch_variables(variables))
                        variables.clear()
                        budget = URL_CHARACTER_LIMIT - (
                            _EXTRA_CHARS_PER_ITEM + len(variable)
                        )
                    variables.append(variable)
                if variables:
                    # Fetch the final batch.
                    dataframes.append(self._fetch_variables(variables))
                self._dataframe = pandas.concat(dataframes, axis=1).reset_index()
        return self._dataframe

    def _fetch_variables(self, variables):
        content = self.get(
            self.link,
            params={"format": APACHE_ARROW_FILE_MIME_TYPE, "field": variables},
        )
        return deserialize_arrow(content)
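

# Illustrative sketch (hypothetical helper, not used above): the URL-budget
# batching performed in _WideTableFetcher.dataframe, factored out as a pure
# function over field names for clarity.
def _example_batch_fields(fields, base_url_length, limit=URL_CHARACTER_LIMIT):
    batches = []
    current = []
    budget = limit - base_url_length
    for field in fields:
        cost = _EXTRA_CHARS_PER_ITEM + len(field)
        budget -= cost
        if budget < 0 and current:
            # This batch would overflow the URL length limit; flush it and
            # start a new batch containing `field`.
            batches.append(current)
            current = []
            budget = limit - base_url_length - cost
        current.append(field)
    if current:
        batches.append(current)  # final, partially filled batch
    return batches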