You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
212 lines
7.3 KiB
212 lines
7.3 KiB
import toolz
|
|
|
|
from .utils import ignoring
|
|
from .base import is_dask_collection
|
|
from .compatibility import Mapping
|
|
|
|
|
|
class HighLevelGraph(Mapping):
    """ Task graph composed of layers of dependent subgraphs

    This object encodes a Dask task graph that is composed of layers of
    dependent subgraphs, such as commonly occurs when building task graphs
    using high level collections like Dask array, bag, or dataframe.

    Typically each high level array, bag, or dataframe operation takes the task
    graphs of the input collections, merges them, and then adds one or more new
    layers of tasks for the new operation.  These layers typically have at
    least as many tasks as there are partitions or chunks in the collection.
    The HighLevelGraph object stores the subgraphs for each operation
    separately in sub-graphs, and also stores the dependency structure between
    them.

    Parameters
    ----------
    layers : Dict[str, Mapping]
        The subgraph layers, keyed by a unique name
    dependencies : Dict[str, Set[str]]
        The set of layers on which each layer depends

    Examples
    --------

    Here is an idealized example that shows the internal state of a
    HighLevelGraph

    >>> import dask.dataframe as dd

    >>> df = dd.read_csv('myfile.*.csv')  # doctest: +SKIP
    >>> df = df + 100  # doctest: +SKIP
    >>> df = df[df.name == 'Alice']  # doctest: +SKIP

    >>> graph = df.__dask_graph__()  # doctest: +SKIP
    >>> graph.layers  # doctest: +SKIP
    {
     'read-csv': {('read-csv', 0): (pandas.read_csv, 'myfile.0.csv'),
                  ('read-csv', 1): (pandas.read_csv, 'myfile.1.csv'),
                  ('read-csv', 2): (pandas.read_csv, 'myfile.2.csv'),
                  ('read-csv', 3): (pandas.read_csv, 'myfile.3.csv')},
     'add': {('add', 0): (operator.add, ('read-csv', 0), 100),
             ('add', 1): (operator.add, ('read-csv', 1), 100),
             ('add', 2): (operator.add, ('read-csv', 2), 100),
             ('add', 3): (operator.add, ('read-csv', 3), 100)}
     'filter': {('filter', 0): (lambda part: part[part.name == 'Alice'], ('add', 0)),
                ('filter', 1): (lambda part: part[part.name == 'Alice'], ('add', 1)),
                ('filter', 2): (lambda part: part[part.name == 'Alice'], ('add', 2)),
                ('filter', 3): (lambda part: part[part.name == 'Alice'], ('add', 3))}
    }

    >>> graph.dependencies  # doctest: +SKIP
    {
     'read-csv': set(),
     'add': {'read-csv'},
     'filter': {'add'}
    }

    See Also
    --------
    HighLevelGraph.from_collections :
        typically used by developers to make new HighLevelGraphs
    """
    def __init__(self, layers, dependencies):
        # Internal sanity checks (note: stripped under ``python -O``):
        # - layers must not nest other HighLevelGraphs,
        # - every layer name must be truthy,
        # - ``layers`` and ``dependencies`` must cover exactly the same names.
        for v in layers.values():
            assert not isinstance(v, HighLevelGraph)
        assert all(layers)
        self.layers = layers
        self.dependencies = dependencies
        assert set(dependencies) == set(layers)

    @property
    def dicts(self):
        """ Alias for ``self.layers`` kept for older callers """
        # Backwards compatibility for now
        return self.layers

    @classmethod
    def from_collections(cls, name, layer, dependencies=()):
        """ Construct a HighLevelGraph from a new layer and a set of collections

        This constructs a HighLevelGraph in the common case where we have a single
        new layer and a set of old collections on which we want to depend.

        This pulls out the ``__dask_layers__()`` method of the collections if
        they exist, and adds them to the dependencies for this new layer.  It
        also merges all of the layers from all of the dependent collections
        together into the new layers for this graph.

        Parameters
        ----------
        name : str
            The name of the new layer
        layer : Mapping
            The graph layer itself
        dependencies : List of Dask collections
            A list of other dask collections (like arrays or dataframes) that
            have graphs themselves

        Examples
        --------

        In typical usage we make a new task layer, and then pass that layer
        along with all dependent collections to this method.

        >>> def add(self, other):
        ...     name = 'add-' + tokenize(self, other)
        ...     layer = {(name, i): (add, input_key, other)
        ...              for i, input_key in enumerate(self.__dask_keys__())}
        ...     graph = HighLevelGraph.from_collections(name, layer, dependencies=[self])
        ...     return new_collection(name, graph)
        """
        layers = {name: layer}
        deps = {}
        deps[name] = set()
        # Deduplicate by object identity so the same collection passed
        # twice contributes its layers only once.
        for collection in toolz.unique(dependencies, key=id):
            if is_dask_collection(collection):
                graph = collection.__dask_graph__()
                if isinstance(graph, HighLevelGraph):
                    layers.update(graph.layers)
                    deps.update(graph.dependencies)
                    # Collections without __dask_layers__ simply add no
                    # dependency edges; hence the best-effort suppression.
                    with ignoring(AttributeError):
                        deps[name] |= set(collection.__dask_layers__())
                else:
                    # Flat (non-HighLevelGraph) graph: wrap it as a single
                    # anonymous layer keyed by the graph object's id().
                    key = id(graph)
                    layers[key] = graph
                    deps[name].add(key)
                    deps[key] = set()
            else:
                raise TypeError(type(collection))

        return cls(layers, deps)

    def __getitem__(self, key):
        # Linear scan over layers; the first layer containing ``key`` wins.
        for d in self.layers.values():
            if key in d:
                return d[key]
        raise KeyError(key)

    def __len__(self):
        # Count distinct task keys across all layers (keys may repeat
        # between layers; __iter__ already deduplicates).
        return sum(1 for _ in self)

    def items(self):
        # NOTE(review): yields (key, value) pairs lazily as a generator
        # rather than returning an ItemsView as Mapping.items normally
        # would; duplicate keys across layers are emitted only once.
        seen = set()
        for d in self.layers.values():
            for key in d:
                if key not in seen:
                    seen.add(key)
                    yield (key, d[key])

    def __iter__(self):
        # Unique task keys across all layers, in layer insertion order.
        return toolz.unique(toolz.concat(self.layers.values()))

    @classmethod
    def merge(cls, *graphs):
        """ Combine several graphs into a single HighLevelGraph

        HighLevelGraph inputs contribute their layers and dependencies
        directly; plain mappings are wrapped as single independent layers
        keyed by their id().  Anything else raises TypeError.
        """
        layers = {}
        dependencies = {}
        for g in graphs:
            if isinstance(g, HighLevelGraph):
                layers.update(g.layers)
                dependencies.update(g.dependencies)
            elif isinstance(g, Mapping):
                layers[id(g)] = g
                dependencies[id(g)] = set()
            else:
                raise TypeError(g)
        return cls(layers, dependencies)

    def visualize(self, filename='dask.pdf', format=None, **kwargs):
        """ Render this graph's layer structure to a file via graphviz

        ``**kwargs`` are forwarded to :func:`to_graphviz`.  Requires the
        optional ``graphviz`` dependency (imported lazily via ``.dot``).
        """
        from .dot import graphviz_to_file
        g = to_graphviz(self, **kwargs)
        return graphviz_to_file(g, filename, format)
|
|
|
|
|
|
def to_graphviz(hg, data_attributes=None, function_attributes=None,
                rankdir='BT', graph_attr={}, node_attr=None, edge_attr=None, **kwargs):
    """ Render a HighLevelGraph's layer-dependency structure with graphviz

    Creates one box-shaped node per layer name in ``hg.dependencies`` and
    one edge per dependency, drawn from the dependency to the dependent
    layer.

    Parameters
    ----------
    hg : HighLevelGraph
        The graph whose layers and dependencies are rendered
    data_attributes : Dict[str, Dict], optional
        Extra graphviz node attributes, keyed by layer name
    function_attributes : Dict, optional
        Accepted for API symmetry with other dask visualizers; not used by
        this function's body
    rankdir : str, optional
        Graphviz rank direction, default 'BT' (bottom to top)
    graph_attr, node_attr, edge_attr : Dict, optional
        Attribute dicts forwarded to ``graphviz.Digraph``
    **kwargs
        Additional entries merged into ``graph_attr``

    Returns
    -------
    graphviz.Digraph
    """
    from .dot import graphviz, name, label

    if data_attributes is None:
        data_attributes = {}
    if function_attributes is None:
        function_attributes = {}

    # Copy before mutating: the previous code wrote 'rankdir' and **kwargs
    # directly into the caller's dict (and used a shared mutable default
    # argument), so repeated calls could leak attributes between graphs.
    graph_attr = dict(graph_attr) if graph_attr else {}
    graph_attr['rankdir'] = rankdir
    graph_attr.update(kwargs)
    g = graphviz.Digraph(graph_attr=graph_attr,
                         node_attr=node_attr,
                         edge_attr=edge_attr)

    cache = {}

    for k in hg.dependencies:
        k_name = name(k)
        # Copy so that setdefault below does not mutate the caller's
        # per-layer attribute dicts.
        attrs = dict(data_attributes.get(k, {}))
        attrs.setdefault('label', label(k, cache=cache))
        attrs.setdefault('shape', 'box')
        g.node(k_name, **attrs)

    for k, deps in hg.dependencies.items():
        k_name = name(k)
        for dep in deps:
            dep_name = name(dep)
            # Edge points from the dependency to the layer that needs it.
            g.edge(dep_name, k_name)
    return g
|