from __future__ import print_function, division, absolute_import
from collections import defaultdict, deque, OrderedDict
from datetime import datetime, timedelta
import logging
import math
from math import log
import os
import pickle
import random
import socket
from time import time
from timeit import default_timer
try:
from cytoolz import frequencies, topk
except ImportError:
from toolz import frequencies, topk
from toolz import memoize, valmap, first, second, keymap, unique, concat, merge
from tornado import gen
from tornado.gen import Return
from tornado.queues import Queue
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.iostream import StreamClosedError, IOStream
from dask.compatibility import PY3, unicode
from dask.core import reverse_dict
from dask.order import order
from .batched import BatchedSend
from .config import config
from .core import (rpc, connect, read, write, MAX_BUFFER_SIZE,
Server, send_recv, coerce_to_address, error_message)
from .utils import (All, ignoring, clear_queue, get_ip, ignore_exceptions,
ensure_ip, get_fileno_limit, log_errors, key_split, mean,
divide_n_among_bins)
from .utils_comm import (scatter_to_workers, gather_from_workers)
from .versions import get_versions
logger = logging.getLogger(__name__)
BANDWIDTH = config.get('bandwidth', 100e6)
ALLOWED_FAILURES = config.get('allowed-failures', 3)
LOG_PDB = config.get('pdb-on-err') or os.environ.get('DASK_ERROR_PDB', False)
class Scheduler(Server):
""" Dynamic distributed task scheduler
The scheduler tracks the current state of workers, data, and computations.
The scheduler listens for events and responds by controlling workers
appropriately. It continuously tries to use the workers to execute an ever
growing dask graph.
All events are handled quickly, in linear time with respect to their input
(which is often of constant size) and generally within a millisecond. To
accomplish this the scheduler tracks a lot of state. Every operation
maintains the consistency of this state.
The scheduler communicates with the outside world through Tornado IOStreams.
It maintains a consistent and valid view of the world even when listening
to several clients at once.
A Scheduler is typically started either with the ``dask-scheduler``
executable::
$ dask-scheduler
Scheduler started at 127.0.0.1:8786
Or it is created automatically as part of a ``LocalCluster`` when a ``Client``
starts up without connection information::
>>> c = Client() # doctest: +SKIP
>>> c.cluster.scheduler # doctest: +SKIP
Scheduler(...)
Users typically do not interact with the scheduler directly but rather with
the client object ``Client``.
**State**
The scheduler contains the following state variables. Each variable is
listed along with what it stores and a brief description; a short example of
how a few of them relate follows the list.
* **tasks:** ``{key: task}``:
Dictionary mapping key to a serialized task like the following:
``{'function': b'...', 'args': b'...'}`` or ``{'task': b'...'}``
* **dependencies:** ``{key: {keys}}``:
Dictionary showing which keys depend on which others
* **dependents:** ``{key: {keys}}``:
Dictionary showing which keys are dependent on which others
* **task_state:** ``{key: string}``:
Dictionary listing the current state of every task among the following:
released, waiting, stacks, queue, no-worker, processing, memory, erred
* **priority:** ``{key: tuple}``:
A score per key that determines its priority
* **waiting:** ``{key: {key}}``:
Dictionary like dependencies but excludes keys already computed
* **waiting_data:** ``{key: {key}}``:
Dictionary like dependents but excludes keys already computed
* **ready:** ``deque(key)``
Keys that are ready to run, but not yet assigned to a worker
* **stacks:** ``{worker: [keys]}``:
List of keys waiting to be sent to each worker
* **processing:** ``{worker: {key: cost}}``:
Keys currently in execution on each worker, mapped to their expected
durations
* **stack_durations:** ``{worker: [duration]}``:
Expected durations of the tasks in each worker's stack
* **stack_duration:** ``{worker: duration}``:
Total expected duration of all tasks in each worker's stack
* **rprocessing:** ``{key: {worker}}``:
Set of workers currently executing a particular task
* **who_has:** ``{key: {worker}}``:
Where each key lives. The current state of distributed memory.
* **has_what:** ``{worker: {key}}``:
What worker has what keys. The transpose of who_has.
* **released:** ``{keys}``
Set of keys that are known, but released from memory
* **unrunnable:** ``{key}``
Keys that we are unable to run
* **restrictions:** ``{key: {hostnames}}``:
A set of hostnames per key of where that key can be run. Usually this
is empty unless a key has been specifically restricted to only run on
certain hosts. These restrictions don't include a worker port. Any
worker on that hostname is deemed valid.
* **loose_restrictions:** ``{key}``:
Set of keys for which we are allowed to violate restrictions (see above)
if no valid workers are present.
* **exceptions:** ``{key: Exception}``:
A dict mapping keys to remote exceptions
* **tracebacks:** ``{key: list}``:
A dict mapping keys to remote tracebacks stored as a list of strings
* **exceptions_blame:** ``{key: key}``:
A dict mapping a key to another key on which it depends that has failed
* **suspicious_tasks:** ``{key: int}``
Number of times a task has been involved in a worker failure
* **deleted_keys:** ``{worker: {key}}``
Keys that should be deleted from each worker
* **wants_what:** ``{client: {key}}``:
Which keys are wanted by each client. The transpose of who_wants.
* **who_wants:** ``{key: {client}}``:
Which clients want each key. The active targets of computation.
* **nbytes:** ``{key: int}``:
Number of bytes for a key as reported by workers holding that key.
* **stealable:** ``[{key}]``
A list of sets of stealable keys, ordered by stealability
* **ncores:** ``{worker: int}``:
Number of cores owned by each worker
* **idle:** ``{worker}``:
Set of workers that are not fully utilized
* **worker_info:** ``{worker: {str: data}}``:
Information about each worker
* **host_info:** ``{hostname: dict}``:
Information about each worker host
* **worker_bytes:** ``{worker: int}``:
Number of bytes in memory on each worker
* **occupancy:** ``{worker: time}``
Expected runtime for all tasks currently processing on a worker
* **services:** ``{str: port}``:
Other services running on this scheduler, like HTTP
* **loop:** ``IOLoop``:
The running Tornado IOLoop
* **streams:** ``{client: BatchedSend}``:
Streams to the connected clients, over which we both accept stimuli and
report results
* **task_duration:** ``{key-prefix: time}``
Time we expect certain functions to take, e.g. ``{'sum': 0.25}``
* **coroutines:** ``[Futures]``:
A list of active futures that control operation
* **scheduler_queues:** ``[Queues]``:
A list of Tornado Queues from which we accept stimuli
* **report_queues:** ``[Queues]``:
A list of Tornado Queues on which we report results
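**Examples**
A rough sketch (hypothetical key ``'x'`` and worker address) of how a few of
these collections relate once a task's result sits in distributed memory:
>>> s = c.cluster.scheduler  # doctest: +SKIP
>>> s.task_state['x']  # doctest: +SKIP
'memory'
>>> s.who_has['x']  # doctest: +SKIP
{'192.168.0.1:54321'}
>>> 'x' in s.has_what['192.168.0.1:54321']  # doctest: +SKIP
True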
"""
default_port = 8786
def __init__(self, center=None, loop=None,
max_buffer_size=MAX_BUFFER_SIZE, delete_interval=500,
synchronize_worker_interval=60000,
ip=None, services=None, allowed_failures=ALLOWED_FAILURES,
validate=False, steal=True, **kwargs):
# Attributes
self.ip = ip or get_ip()
self.allowed_failures = allowed_failures
self.validate = validate
self.status = None
self.delete_interval = delete_interval
self.synchronize_worker_interval = synchronize_worker_interval
self.steal = steal
# Communication state
self.loop = loop or IOLoop.current()
self.scheduler_queues = [Queue()]
self.report_queues = []
self.worker_streams = dict()
self.streams = dict()
self.coroutines = []
self._worker_coroutines = []
self._ipython_kernel = None
# Task state
self.tasks = dict()
self.task_state = dict()
self.dependencies = dict()
self.dependents = dict()
self.generation = 0
self.released = set()
self.priority = dict()
self.nbytes = dict()
self.worker_bytes = dict()
self.processing = dict()
self.rprocessing = defaultdict(set)
self.task_duration = {prefix: 0.00001 for prefix in fast_tasks}
self.restrictions = dict()
self.loose_restrictions = set()
self.suspicious_tasks = defaultdict(lambda: 0)
self.stacks = dict()
self.stack_durations = dict()
self.stack_duration = dict()
self.waiting = dict()
self.waiting_data = dict()
self.ready = deque()
self.unrunnable = set()
self.idle = set()
self.maybe_idle = set()
self.who_has = dict()
self.has_what = dict()
self.who_wants = defaultdict(set)
self.wants_what = defaultdict(set)
self.deleted_keys = defaultdict(set)
self.exceptions = dict()
self.tracebacks = dict()
self.exceptions_blame = dict()
self.datasets = dict()
self.stealable = [set() for i in range(12)]
self.key_stealable = dict()
self.stealable_unknown_durations = defaultdict(set)
# Worker state
self.ncores = dict()
self.worker_info = dict()
self.host_info = defaultdict(dict)
self.aliases = dict()
self.saturated = set()
self.occupancy = dict()
self.plugins = []
self.transition_log = deque(maxlen=config.get('transition-log-length',
100000))
self.compute_handlers = {'update-graph': self.update_graph,
'update-data': self.update_data,
'missing-data': self.stimulus_missing_data,
'client-releases-keys': self.client_releases_keys,
'restart': self.restart}
self.handlers = {'register-client': self.add_client,
'scatter': self.scatter,
'register': self.add_worker,
'unregister': self.remove_worker,
'gather': self.gather,
'cancel': self.stimulus_cancel,
'feed': self.feed,
'terminate': self.close,
'broadcast': self.broadcast,
'ncores': self.get_ncores,
'has_what': self.get_has_what,
'who_has': self.get_who_has,
'stacks': self.get_stacks,
'processing': self.get_processing,
'nbytes': self.get_nbytes,
'versions': self.get_versions,
'add_keys': self.add_keys,
'rebalance': self.rebalance,
'replicate': self.replicate,
'start_ipython': self.start_ipython,
'list_datasets': self.list_datasets,
'get_dataset': self.get_dataset,
'publish_dataset': self.publish_dataset,
'unpublish_dataset': self.unpublish_dataset,
'update_data': self.update_data,
'change_worker_cores': self.change_worker_cores}
self.services = {}
for k, v in (services or {}).items():
if isinstance(k, tuple):
k, port = k
else:
port = 0
try:
service = v(self, io_loop=self.loop)
service.listen(port)
self.services[k] = service
except Exception as e:
logger.info("Could not launch service: %s-%d", k, port,
exc_info=True)
self._transitions = {
('released', 'waiting'): self.transition_released_waiting,
('waiting', 'ready'): self.transition_waiting_ready,
('waiting', 'released'): self.transition_waiting_released,
('queue', 'processing'): self.transition_ready_processing,
('stacks', 'processing'): self.transition_ready_processing,
('processing', 'released'): self.transition_processing_released,
('queue', 'released'): self.transition_ready_released,
('stacks', 'released'): self.transition_ready_released,
('no-worker', 'released'): self.transition_ready_released,
('processing', 'memory'): self.transition_processing_memory,
('processing', 'erred'): self.transition_processing_erred,
('released', 'forgotten'): self.transition_released_forgotten,
('memory', 'forgotten'): self.transition_memory_forgotten,
('erred', 'forgotten'): self.transition_released_forgotten,
('memory', 'released'): self.transition_memory_released,
('released', 'erred'): self.transition_released_erred
}
connection_limit = get_fileno_limit() / 2
super(Scheduler, self).__init__(handlers=self.handlers,
max_buffer_size=max_buffer_size, io_loop=self.loop,
connection_limit=connection_limit, deserialize=False, **kwargs)
##################
# Administration #
##################
def __str__(self):
return '<Scheduler: "%s:%d" processes: %d cores: %d>' % (
self.ip, self.port, len(self.ncores), sum(self.ncores.values()))
__repr__ = __str__
@property
def address(self):
return '%s:%d' % (self.ip, self.port)
@property
def address_tuple(self):
return (self.ip, self.port)
def identity(self, stream):
""" Basic information about ourselves and our cluster """
d = {'type': type(self).__name__,
     'id': str(self.id),
     'services': {key: v.port for (key, v) in self.services.items()},
     'workers': dict(self.worker_info)}
return d
def get_versions(self, stream):
""" Basic information about ourselves and our cluster """
return get_versions()
def start(self, port=8786, start_queues=True):
""" Clear out old state and restart all running coroutines """
collections = [self.tasks, self.dependencies, self.dependents,
self.waiting, self.waiting_data, self.released, self.priority,
self.nbytes, self.restrictions, self.loose_restrictions,
self.ready, self.who_wants, self.wants_what]
for collection in collections:
collection.clear()
with ignoring(AttributeError):
for c in self._worker_coroutines:
c.cancel()
self._delete_periodic_callback = \
PeriodicCallback(callback=self.clear_data_from_workers,
callback_time=self.delete_interval,
io_loop=self.loop)
self._delete_periodic_callback.start()
self._synchronize_data_periodic_callback = \
PeriodicCallback(callback=self.synchronize_worker_data,
callback_time=self.synchronize_worker_interval,
io_loop=self.loop)
self._synchronize_data_periodic_callback.start()
if start_queues:
self.loop.add_callback(self.handle_queues, self.scheduler_queues[0], None)
for cor in self.coroutines:
if cor.done():
exc = cor.exception()
if exc:
raise exc
if self.status != 'running':
self.listen(port)
self.status = 'running'
logger.info(" Scheduler at: %20s:%s", self.ip, self.port)
for k, v in self.services.items():
logger.info("%11s at: %20s:%s", k, self.ip, v.port)
return self.finished()
@gen.coroutine
def finished(self):
""" Wait until all coroutines have ceased """
while any(not c.done() for c in self.coroutines):
yield All(self.coroutines)
def close_streams(self):
""" Close all active IOStreams """
for stream in self.streams.values():
stream.stream.close()
self.rpc.close()
@gen.coroutine
def close(self, stream=None, fast=False):
""" Send cleanup signal to all coroutines then wait until finished
See Also
--------
Scheduler.cleanup
"""
self._delete_periodic_callback.stop()
self._synchronize_data_periodic_callback.stop()
for service in self.services.values():
service.stop()
yield self.cleanup()
if not fast:
yield self.finished()
self.close_streams()
self.status = 'closed'
self.stop()
@gen.coroutine
def cleanup(self):
""" Clean up queues and coroutines, prepare to stop """
if self.status == 'closing':
raise gen.Return()
self.status = 'closing'
logger.debug("Cleaning up coroutines")
for w, bstream in list(self.worker_streams.items()):
with ignoring(AttributeError):
yield bstream.close(ignore_closed=True)
for s in self.scheduler_queues[1:]:
s.put_nowait({'op': 'close-stream'})
for q in self.report_queues:
q.put_nowait({'op': 'close'})
###########
# Stimuli #
###########
def add_worker(self, stream=None, address=None, keys=(), ncores=None,
name=None, coerce_address=True, nbytes=None, now=None,
host_info=None, **info):
""" Add a new worker to the cluster """
with log_errors():
local_now = time()
now = now or time()
info = info or {}
host_info = host_info or {}
if coerce_address:
address = self.coerce_address(address)
host, port = address.split(':')
self.host_info[host]['last-seen'] = local_now
if address not in self.worker_info:
self.worker_info[address] = dict()
if info:
self.worker_info[address].update(info)
if host_info:
self.host_info[host].update(host_info)
delay = time() - now
self.worker_info[address]['time-delay'] = delay
self.worker_info[address]['last-seen'] = time()
if address in self.ncores:
return 'OK'
name = name or address
if name in self.aliases:
return 'name taken, %s' % name
if coerce_address:
if 'ports' not in self.host_info[host]:
self.host_info[host].update({'ports': set(), 'cores': 0})
self.host_info[host]['ports'].add(port)
self.host_info[host]['cores'] += ncores
self.ncores[address] = ncores
self.aliases[name] = address
self.worker_info[address]['name'] = name
if address not in self.processing:
self.has_what[address] = set()
self.worker_bytes[address] = 0
self.processing[address] = dict()
self.occupancy[address] = 0
self.stacks[address] = deque()
self.stack_durations[address] = deque()
self.stack_duration[address] = 0
if nbytes:
self.nbytes.update(nbytes)
# for key in keys: # TODO
# self.mark_key_in_memory(key, [address])
self.worker_streams[address] = BatchedSend(interval=2, loop=self.loop)
self._worker_coroutines.append(self.worker_stream(address))
if self.ncores[address] > len(self.processing[address]):
self.idle.add(address)
for key in list(self.unrunnable):
r = self.restrictions.get(key, [])
if address in r or host in r or name in r:
self.transitions({key: 'released'})
self.maybe_idle.add(address)
self.ensure_occupied()
logger.info("Register %s", str(address))
return 'OK'
def update_graph(self, client=None, tasks=None, keys=None,
dependencies=None, restrictions=None, priority=None,
loose_restrictions=None):
"""
Add new computations to the internal dask graph
This happens whenever the Client calls submit, map, get, or compute.
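Examples
--------
A schematic sketch of the payload (hypothetical keys and client id; in
practice the task specifications arrive from the client in serialized form):
>>> scheduler.update_graph(client='client-1',  # doctest: +SKIP
...                        tasks={'x': {'task': b'...'}, 'y': {'task': b'...'}},
...                        keys=['y'],
...                        dependencies={'y': {'x'}, 'x': set()})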
"""
for k in list(tasks):
    if tasks[k] is k or k in self.tasks:
        del tasks[k]
original_keys = keys
keys = set(keys)
for k in keys:
self.who_wants[k].add(client)
self.wants_what[client].add(k)
n = 0
while len(tasks) != n: # walk through new tasks, cancel any bad deps
n = len(tasks)
for k, deps in list(dependencies.items()):
if any(dep not in self.dependencies and dep not in tasks
for dep in deps): # bad key
logger.info('User asked for computation on lost data, %s', k)
del tasks[k]
del dependencies[k]
if k in keys:
keys.remove(k)
self.report({'op': 'cancelled-key', 'key': k})
self.client_releases_keys(keys=[k], client=client)
stack = list(keys)
touched = set()
while stack:
k = stack.pop()
if k in self.dependencies:
continue
touched.add(k)
if k not in self.tasks and k in tasks:
self.tasks[k] = tasks[k]
self.dependencies[k] = set(dependencies.get(k, ()))
self.released.add(k)
self.task_state[k] = 'released'
for dep in self.dependencies[k]:
if dep not in self.dependents:
self.dependents[dep] = set()
self.dependents[dep].add(k)
if k not in self.dependents:
self.dependents[k] = set()
stack.extend(self.dependencies[k])
recommendations = OrderedDict()
new_priority = priority or order(tasks) # TODO: define order wrt old graph
self.generation += 1 # older graph generations take precedence
for key in set(new_priority) & touched:
if key not in self.priority:
self.priority[key] = (self.generation, new_priority[key]) # prefer old
if restrictions:
restrictions = {k: set(map(self.coerce_address, v))
for k, v in restrictions.items()}
self.restrictions.update(restrictions)
if loose_restrictions:
self.loose_restrictions |= set(loose_restrictions)
for key in sorted(touched | keys, key=self.priority.get):
if self.task_state[key] == 'released':
recommendations[key] = 'waiting'
for key in touched | keys:
for dep in self.dependencies[key]:
if dep in self.exceptions_blame:
self.exceptions_blame[key] = self.exceptions_blame[dep]
recommendations[key] = 'erred'
break
self.transitions(recommendations)
for plugin in self.plugins[:]:
try:
plugin.update_graph(self, client=client, tasks=tasks,
keys=keys, restrictions=restrictions or {},
dependencies=dependencies,
loose_restrictions=loose_restrictions)
except Exception as e:
logger.exception(e)
for key in keys:
if self.task_state[key] in ('memory', 'erred'):
self.report_on_key(key)
self.ensure_occupied()
def stimulus_task_finished(self, key=None, worker=None, **kwargs):
""" Mark that a task has finished execution on a particular worker """
# logger.debug("Stimulus task finished %s, %s", key, worker)
self.maybe_idle.add(worker)
if key not in self.task_state:
return {}
if self.task_state[key] == 'processing':
recommendations = self.transition(key, 'memory', worker=worker,
**kwargs)
else:
recommendations = {}
if self.task_state[key] == 'memory':
self.who_has[key].add(worker)
if key not in self.has_what[worker]:
self.worker_bytes[worker] += self.nbytes.get(key, 1000)
self.has_what[worker].add(key)
return recommendations
def stimulus_task_erred(self, key=None, worker=None,
exception=None, traceback=None, **kwargs):
""" Mark that a task has erred on a particular worker """
logger.debug("Stimulus task erred %s, %s", key, worker)
self.maybe_idle.add(worker)
if key not in self.task_state:
return {}
if self.task_state[key] == 'processing':
recommendations = self.transition(key, 'erred', cause=key,
exception=exception, traceback=traceback)
else:
recommendations = {}
return recommendations
def stimulus_missing_data(self, keys=None, key=None, worker=None,
ensure=True, **kwargs):
""" Mark that certain keys have gone missing. Recover. """
logger.debug("Stimulus missing data %s, %s", key, worker)
if worker:
self.maybe_idle.add(worker)
recommendations = OrderedDict()
for k in set(keys):
if self.task_state.get(k) == 'memory':
for w in set(self.who_has[k]):
self.has_what[w].remove(k)
self.who_has[k].remove(w)
self.worker_bytes[w] -= self.nbytes.get(k, 1000)
recommendations[k] = 'released'
if key:
recommendations[key] = 'released'
self.transitions(recommendations)
if ensure:
self.ensure_occupied()
return {}
def remove_worker(self, stream=None, address=None, safe=False):
"""
Remove worker from cluster
We do this when a worker reports that it plans to leave or when it
appears to be unresponsive. This may send its tasks back to a released
state.
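Examples
--------
A rough sketch of a direct call (hypothetical worker address); normally this
is triggered by the ``unregister`` handler or by a closed worker stream:
>>> scheduler.remove_worker(address='127.0.0.1:45011')  # doctest: +SKIP
'OK'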
"""
with log_errors(pdb=LOG_PDB):
address = self.coerce_address(address)
logger.info("Remove worker %s", address)
if address not in self.processing:
return 'already-removed'
with ignoring(AttributeError):
stream = self.worker_streams[address].stream
if not stream.closed():
stream.close()
host, port = address.split(':')
self.host_info[host]['cores'] -= self.ncores[address]
self.host_info[host]['ports'].remove(port)
if not self.host_info[host]['ports']:
del self.host_info[host]
del self.worker_streams[address]
del self.ncores[address]
del self.aliases[self.worker_info[address]['name']]
del self.worker_info[address]
if address in self.maybe_idle:
self.maybe_idle.remove(address)
if address in self.idle:
self.idle.remove(address)
if address in self.saturated:
self.saturated.remove(address)
recommendations = OrderedDict()
in_flight = set(self.processing.pop(address))
for k in list(in_flight):
self.rprocessing[k].remove(address)
if not safe:
self.suspicious_tasks[k] += 1
if not safe and self.suspicious_tasks[k] > self.allowed_failures:
e = pickle.dumps(KilledWorker(k, address))
r = self.transition(k, 'erred', exception=e, cause=k)
recommendations.update(r)
in_flight.remove(k)
elif not self.rprocessing[k]:
recommendations[k] = 'released'
for k in self.stacks.pop(address):
if k in self.tasks:
recommendations[k] = 'waiting'
del self.stack_durations[address]
del self.stack_duration[address]
del self.occupancy[address]
del self.worker_bytes[address]
for key in self.has_what.pop(address):
self.who_has[key].remove(address)
if not self.who_has[key]:
if key in self.tasks:
recommendations[key] = 'released'
else:
recommendations[key] = 'forgotten'
self.transitions(recommendations)
if not self.stacks:
logger.info("Lost all workers")
self.ensure_occupied()
return 'OK'
def stimulus_cancel(self, stream, keys=None, client=None):
""" Stop execution on a list of keys """
logger.info("Client %s requests to cancel %d keys", client, len(keys))
for key in keys:
self.cancel_key(key, client)
def cancel_key(self, key, client, retries=5):
""" Cancel a particular key and all dependents """
# TODO: this should be converted to use the transition mechanism
if key not in self.who_wants: # no key yet, let's try again in 200ms
if retries:
self.loop.add_future(gen.sleep(0.2),
lambda _: self.cancel_key(key, client, retries - 1))
return
if self.who_wants[key] == {client}: # no one else wants this key
for dep in list(self.dependents[key]):
self.cancel_key(dep, client)
logger.debug("Scheduler cancels key %s", key)
self.report({'op': 'cancelled-key', 'key': key})
self.client_releases_keys(keys=[key], client=client)
def client_releases_keys(self, keys=None, client=None):
""" Remove keys from client desired list """
for key in list(keys):
if key in self.wants_what[client]:
self.wants_what[client].remove(key)
s = self.who_wants[key]
s.remove(client)
if not s:
del self.who_wants[key]
if key in self.waiting_data and not self.waiting_data[key]:
r = self.transition(key, 'released')
self.transitions(r)
if key in self.dependents and not self.dependents[key]:
r = self.transition(key, 'forgotten')
self.transitions(r)
def client_wants_keys(self, keys=None, client=None):
for k in keys:
self.who_wants[k].add(client)
self.wants_what[client].add(k)
######################################
# Task Validation (currently unused) #
######################################
def validate_released(self, key):
assert key in self.dependencies
assert self.task_state[key] == 'released'
assert key not in self.waiting_data
assert key not in self.who_has
assert key not in self.rprocessing
# assert key not in self.ready
assert key not in self.waiting
assert not any(key in self.waiting_data.get(dep, ())
for dep in self.dependencies[key])
assert key in self.released
def validate_waiting(self, key):
assert key in self.waiting
assert key in self.waiting_data
assert key not in self.who_has
assert key not in self.rprocessing
assert key not in self.released
for dep in self.dependencies[key]:
assert (dep in self.who_has) + (dep in self.waiting[key]) == 1
def validate_processing(self, key):
assert key not in self.waiting
assert key in self.waiting_data
assert key in self.rprocessing
for w in self.rprocessing[key]:
assert key in self.processing[w]
assert key not in self.who_has
for dep in self.dependencies[key]:
assert dep in self.who_has
def validate_memory(self, key):
assert key in self.who_has
assert key not in self.rprocessing
assert key not in self.waiting
assert key not in self.released
for dep in self.dependents[key]:
assert (dep in self.who_has) + (dep in self.waiting_data[key]) == 1
def validate_queue(self, key):
# assert key in self.ready
assert key not in self.released
assert key not in self.rprocessing
assert key not in self.who_has
assert key not in self.waiting
for dep in self.dependencies[key]:
assert dep in self.who_has
def validate_stacks(self, key):
# assert any(key in stack for stack in self.stacks.values())
assert key not in self.released
assert key not in self.rprocessing
assert key not in self.who_has
assert key not in self.waiting
for dep in self.dependencies[key]:
assert dep in self.who_has
def validate_key(self, key):
try:
try:
func = getattr(self, 'validate_' + self.task_state[key])
except KeyError:
logger.debug("Key lost: %s", key)
except AttributeError:
logger.info("self.validate_%s not found", self.task_state[key])
else:
func(key)
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def validate_state(self, allow_overlap=False, allow_bad_stacks=True):
validate_state(self.dependencies, self.dependents, self.waiting,
self.waiting_data, self.ready, self.who_has, self.stacks,
self.processing, None, self.released, self.who_wants,
self.wants_what, tasks=self.tasks, erred=self.exceptions_blame,
allow_overlap=allow_overlap, allow_bad_stacks=allow_bad_stacks)
if not (set(self.ncores) == \
set(self.has_what) == \
set(self.stacks) == \
set(self.processing) == \
set(self.worker_info) == \
set(self.worker_streams)):
raise ValueError("Workers not the same in all collections")
assert self.worker_bytes == {w: sum(self.nbytes[k] for k in keys)
for w, keys in self.has_what.items()}
for w in self.stacks:
assert abs(sum(self.stack_durations[w]) - self.stack_duration[w]) < 1e-8
assert len(self.stack_durations[w]) == len(self.stacks[w])
###################
# Manage Messages #
###################
def report(self, msg):
"""
Publish updates to all listening Queues and Streams
If the message contains a key then we only send the message to those
streams that care about the key.
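Examples
--------
A typical message (hypothetical key), matching those emitted elsewhere in
this module:
>>> scheduler.report({'op': 'key-in-memory', 'key': 'x'})  # doctest: +SKIP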
"""
for q in self.report_queues:
q.put_nowait(msg)
if 'key' in msg:
streams = [self.streams[c]
for c in self.who_wants.get(msg['key'], ())
if c in self.streams]
else:
streams = self.streams.values()
for s in streams:
try:
s.send(msg)
# logger.debug("Scheduler sends message to client %s", msg)
except StreamClosedError:
logger.critical("Tried writing to closed stream: %s", msg)
@gen.coroutine
def add_client(self, stream, client=None):
""" Add client to network
We listen to all future messages from this IOStream.
"""
logger.info("Receive client connection: %s", client)
bstream = BatchedSend(interval=2, loop=self.loop)
bstream.start(stream)
self.streams[client] = bstream
try:
yield self.handle_messages(stream, bstream, client=client)
finally:
if not stream.closed():
bstream.send({'op': 'stream-closed'})
yield bstream.close(ignore_closed=True)
del self.streams[client]
logger.info("Close client connection: %s", client)
def remove_client(self, client=None):
""" Remove client from network """
logger.info("Remove client %s", client)
self.client_releases_keys(self.wants_what.get(client, ()), client)
with ignoring(KeyError):
del self.wants_what[client]
@gen.coroutine
def handle_messages(self, in_queue, report, client=None):
"""
The master client coroutine. Handles all inbound messages from clients.
This runs once per Client IOStream or Queue.
See Also
--------
Scheduler.worker_stream: The equivalent function for workers
"""
with log_errors(pdb=LOG_PDB):
if isinstance(in_queue, Queue):
next_message = in_queue.get
elif isinstance(in_queue, IOStream):
next_message = lambda: read(in_queue, deserialize=self.deserialize)
else:
raise NotImplementedError()
if isinstance(report, Queue):
put = report.put_nowait
elif isinstance(report, IOStream):
put = lambda msg: write(report, msg)
elif isinstance(report, BatchedSend):
put = report.send
else:
put = lambda msg: None
put({'op': 'stream-start'})
breakout = False
while True:
try:
msgs = yield next_message()
except (StreamClosedError, AssertionError, GeneratorExit):
break
except Exception as e:
logger.exception(e)
put(error_message(e, status='scheduler-error'))
continue
if not isinstance(msgs, list):
msgs = [msgs]
for msg in msgs:
# logger.debug("scheduler receives message %s", msg)
try:
    op = msg.pop('op')
except Exception as e:
    logger.exception(e)
    put(error_message(e, status='scheduler-error'))
    continue
if op == 'close-stream':
breakout = True
break
elif op == 'close':
breakout = True
self.close()
break
elif op in self.compute_handlers:
try:
result = self.compute_handlers[op](**msg)
if isinstance(result, gen.Future):
yield result
except Exception as e:
logger.exception(e)
raise
else:
logger.warn("Bad message: op=%s, %s", op, msg, exc_info=True)
if op == 'close':
breakout = True
break
if breakout:
break
self.remove_client(client=client)
logger.debug('Finished handle_messages coroutine')
def handle_queues(self, scheduler_queue, report_queue):
"""
Register new control and report queues to the Scheduler
Queues are not in common use. This may be deprecated in the future.
"""
self.scheduler_queues.append(scheduler_queue)
if report_queue:
self.report_queues.append(report_queue)
future = self.handle_messages(scheduler_queue, report_queue)
self.coroutines.append(future)
return future
def send_task_to_worker(self, worker, key):
""" Send a single computational task to a worker """
msg = {'op': 'compute-task',
'key': key}
deps = self.dependencies[key]
if deps:
msg['who_has'] = {dep: tuple(self.who_has.get(dep, ()))
for dep in deps}
task = self.tasks[key]
if type(task) is dict:
msg.update(task)
else:
msg['task'] = task
self.worker_streams[worker].send(msg)
@gen.coroutine
def worker_stream(self, worker):
"""
Listen to responses from a single worker
This is the main loop for scheduler-worker interaction
See Also
--------
Scheduler.handle_messages: Equivalent coroutine for clients
"""
yield gen.sleep(0)
ip, port = coerce_to_address(worker, out=tuple)
stream = yield connect(ip, port)
yield write(stream, {'op': 'compute-stream'})
self.worker_streams[worker].start(stream)
logger.info("Starting worker compute stream, %s", worker)
try:
while True:
msgs = yield read(stream)
if not isinstance(msgs, list):
msgs = [msgs]
if worker in self.worker_info:
recommendations = OrderedDict()
for msg in msgs:
# logger.debug("Compute response from worker %s, %s",
# worker, msg)
if msg == 'OK': # from close
break
self.correct_time_delay(worker, msg)
key = msg['key']
if msg['status'] == 'OK':
r = self.stimulus_task_finished(worker=worker, **msg)
recommendations.update(r)
elif msg['status'] == 'error':
r = self.stimulus_task_erred(worker=worker, **msg)
recommendations.update(r)
elif msg['status'] == 'missing-data':
r = self.stimulus_missing_data(worker=worker,
ensure=False, **msg)
recommendations.update(r)
else:
logger.warn("Unknown message type, %s, %s",
msg['status'], msg)
self.transitions(recommendations)
if self.validate:
logger.info("Messages: %s\nRecommendations: %s",
msgs, recommendations)
self.ensure_occupied()
except (StreamClosedError, IOError, OSError):
logger.info("Worker failed from closed stream: %s", worker)
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
finally:
if not stream.closed():
stream.close()
self.remove_worker(address=worker)
def correct_time_delay(self, worker, msg):
"""
Apply offset time delay in message times.
Clocks on different workers differ. We keep track of a relative "now"
through periodic heartbeats. We use this known delay to align message
times to Scheduler local time. In particular this helps with
diagnostics.
Operates in place
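Examples
--------
A rough sketch with made-up numbers: if the heartbeats established a 2.5
second delay for this worker, timestamps in the message shift by that amount:
>>> msg = {'compute_start': 10.0, 'compute_stop': 11.0}  # doctest: +SKIP
>>> scheduler.worker_info[worker]['time-delay'] = 2.5  # doctest: +SKIP
>>> scheduler.correct_time_delay(worker, msg)  # doctest: +SKIP
>>> msg['compute_start']  # doctest: +SKIP
12.5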
"""
if 'time-delay' in self.worker_info[worker]:
delay = self.worker_info[worker]['time-delay']
for key in ['transfer_start', 'transfer_stop', 'time',
'compute_start', 'compute_stop', 'disk_load_start',
'disk_load_stop']:
if key in msg:
msg[key] += delay
@gen.coroutine
def clear_data_from_workers(self):
""" Send delete signals to clear unused data from workers
This watches the ``.deleted_keys`` attribute, which stores a set of
keys to be deleted from each worker. This function is run periodically
by the ``._delete_periodic_callback`` to actually remove the data.
This runs every ``self.delete_interval`` milliseconds.
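Examples
--------
A minimal sketch (hypothetical worker address and key): queue a key for
deletion, then flush it directly rather than waiting for the callback:
>>> scheduler.deleted_keys['127.0.0.1:45011'].add('x')  # doctest: +SKIP
>>> yield scheduler.clear_data_from_workers()  # doctest: +SKIP
'OK'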
"""
if self.deleted_keys:
d = self.deleted_keys.copy()
self.deleted_keys.clear()
coroutines = [self.rpc(addr=worker).delete_data(
keys=list(keys - self.has_what.get(worker,
set())),
report=False)
for worker, keys in d.items()
if keys]
for worker, keys in d.items():
logger.debug("Remove %d keys from worker %s", len(keys), worker)
yield ignore_exceptions(coroutines, socket.error, StreamClosedError)
raise Return('OK')
def add_plugin(self, plugin):
"""
Add external plugin to scheduler
See https://distributed.readthedocs.io/en/latest/plugins.html
"""
self.plugins.append(plugin)
def remove_plugin(self, plugin):
""" Remove external plugin from scheduler """
self.plugins.remove(plugin)
############################
# Less common interactions #
############################
@gen.coroutine
def scatter(self, stream=None, data=None, workers=None, client=None,
broadcast=False, timeout=2):
""" Send data out to workers
See also
--------
Scheduler.broadcast:
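Examples
--------
A hedged sketch from within a coroutine (hypothetical client id; values
shown unserialized for readability):
>>> keys = yield scheduler.scatter(data={'x': 1, 'y': 2},  # doctest: +SKIP
...                                client='client-1')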
"""
start = time()
while not self.ncores:
yield gen.sleep(0.2)
if time() > start + timeout:
raise gen.TimeoutError("No workers found")
if workers is not None:
workers = [self.coerce_address(w) for w in workers]
ncores = workers if workers is not None else self.ncores
keys, who_has, nbytes = yield scatter_to_workers(ncores, data,
report=False,
serialize=False)
self.update_data(who_has=who_has, nbytes=nbytes, client=client)
if broadcast:
if broadcast == True:
n = len(ncores)
else:
n = broadcast
yield self.replicate(keys=keys, workers=workers, n=n)
raise gen.Return(keys)
@gen.coroutine
def gather(self, stream=None, keys=None):
""" Collect data in from workers """
keys = list(keys)
who_has = {key: self.who_has.get(key, ()) for key in keys}
try:
data = yield gather_from_workers(who_has, rpc=self.rpc, close=False)
result = {'status': 'OK', 'data': data}
except KeyError as e:
logger.debug("Couldn't gather keys %s", e)
result = {'status': 'error', 'keys': e.args}
raise gen.Return(result)
@gen.coroutine
def restart(self, environment=None):
""" Restart all workers. Reset local state. """
n = len(self.ncores)
with log_errors():
logger.debug("Send shutdown signal to workers")
for q in self.scheduler_queues + self.report_queues:
clear_queue(q)
nannies = {addr: d['services']['nanny']
for addr, d in self.worker_info.items()}
for addr in nannies:
self.remove_worker(address=addr)
for client, keys in self.wants_what.items():
self.client_releases_keys(keys=keys, client=client)
logger.debug("Send kill signal to nannies: %s", nannies)
nannies = [rpc(ip=worker_address.split(':')[0], port=n_port)
for worker_address, n_port in nannies.items()]
try:
yield All([nanny.kill() for nanny in nannies])
logger.debug("Received done signal from nannies")
while self.ncores:
yield gen.sleep(0.01)
logger.debug("Workers all removed. Sending startup signal")
# All quiet
resps = yield All([nanny.instantiate(close=True,
environment=environment) for nanny in nannies])
assert all(resp == 'OK' for resp in resps)
finally:
for nanny in nannies:
nanny.close_rpc()
self.start()
logger.debug("All workers reporting in")
self.report({'op': 'restart'})
for plugin in self.plugins[:]:
try:
plugin.restart(self)
except Exception as e:
logger.exception(e)
@gen.coroutine
def broadcast(self, stream=None, msg=None, workers=None, hosts=None,
nanny=False):
""" Broadcast message to workers, return all results """
if workers is None:
if hosts is None:
workers = list(self.ncores)
else:
workers = []
if hosts is not None:
for host in hosts:
if host in self.host_info:
workers.extend([host + ':' + port
for port in self.host_info[host]['ports']])
# TODO replace with worker_list
if nanny:
addresses = []
for addr in workers:
ip = addr.split(':')[0]
port = self.worker_info[addr]['services']['nanny']
addresses.append('%s:%d' % (ip, port))
else:
addresses = workers
results = yield All([send_recv(arg=address, close=True, **msg)
for address in addresses])
raise Return(dict(zip(workers, results)))
@gen.coroutine
def rebalance(self, stream=None, keys=None, workers=None):
""" Rebalance keys so that each worker stores roughly equal bytes
**Policy**
This orders the workers by what fraction of bytes of the existing keys
they have. It walks down this list from most-to-least loaded. At each
sender it finds the largest results available and sends them to the least
occupied worker, until either the sender or the recipient reaches the
average expected load.
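Examples
--------
A rough sketch from within a coroutine (hypothetical keys):
>>> yield scheduler.rebalance(keys=['x', 'y'])  # doctest: +SKIP
{'status': 'OK'}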
"""
with log_errors():
keys = set(keys or self.who_has)
workers = set(workers or self.ncores)
if not keys.issubset(self.who_has):
raise Return({'status': 'missing-data',
'keys': list(keys - set(self.who_has))})
workers_by_key = {k: self.who_has.get(k, set()) & workers for k in keys}
keys_by_worker = {w: set() for w in workers}
for k, v in workers_by_key.items():
for vv in v:
keys_by_worker[vv].add(k)
worker_bytes = {w: sum(self.nbytes.get(k, 1000) for k in v)
for w, v in keys_by_worker.items()}
avg = sum(worker_bytes.values()) / len(worker_bytes)
sorted_workers = list(map(first, sorted(worker_bytes.items(),
key=second, reverse=True)))
recipients = iter(reversed(sorted_workers))
recipient = next(recipients)
msgs = [] # (sender, recipient, key)
for sender in sorted_workers[:len(workers) // 2]:
sender_keys = {k: self.nbytes.get(k, 1000)
for k in keys_by_worker[sender]}
sender_keys = iter(sorted(sender_keys.items(),
key=second, reverse=True))
try:
while worker_bytes[sender] > avg:
while (worker_bytes[recipient] < avg and
worker_bytes[sender] > avg):
k, nb = next(sender_keys)
if k not in keys_by_worker[recipient]:
keys_by_worker[recipient].add(k)
# keys_by_worker[sender].remove(k)
msgs.append((sender, recipient, k))
worker_bytes[sender] -= nb
worker_bytes[recipient] += nb
if worker_bytes[sender] > avg:
recipient = next(recipients)
except StopIteration:
break
to_recipients = defaultdict(lambda: defaultdict(list))
to_senders = defaultdict(list)
for sender, recipient, key in msgs:
to_recipients[recipient][key].append(sender)
to_senders[sender].append(key)
result = yield {r: self.rpc(addr=r).gather(who_has=v)
for r, v in to_recipients.items()}
if not all(r['status'] == 'OK' for r in result.values()):
raise Return({'status': 'missing-data',
              'keys': sum([r['keys'] for r in result.values()
                           if 'keys' in r], [])})
for sender, recipient, key in msgs:
self.who_has[key].add(recipient)
self.has_what[recipient].add(key)
self.worker_bytes[recipient] += self.nbytes.get(key, 1000)
result = yield {r: self.rpc(addr=r).delete_data(keys=v, report=False)
for r, v in to_senders.items()}
for sender, recipient, key in msgs:
self.who_has[key].remove(sender)
self.has_what[sender].remove(key)
self.worker_bytes[sender] -= self.nbytes.get(key, 1000)
raise Return({'status': 'OK'})
@gen.coroutine
def replicate(self, stream=None, keys=None, n=None, workers=None,
branching_factor=2, delete=True):
""" Replicate data throughout cluster
This performs a tree copy of the data throughout the network
individually on each piece of data.
Parameters
----------
keys: Iterable
list of keys to replicate
n: int
Number of replications we expect to see within the cluster
branching_factor: int, optional
The number of workers that can copy data in each generation
See also
--------
Scheduler.rebalance
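Examples
--------
A rough sketch from within a coroutine (hypothetical key), asking for two
copies of ``'x'`` across the cluster:
>>> yield scheduler.replicate(keys=['x'], n=2)  # doctest: +SKIP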
"""
workers = set(self.workers_list(workers))
if n is None:
n = len(workers)
n = min(n, len(workers))
keys = set(keys)
if n == 0:
raise ValueError("Can not use replicate to delete data")
if not keys.issubset(self.who_has):
raise Return({'status': 'missing-data',
'keys': list(keys - set(self.who_has))})
# Delete extraneous data
if delete:
del_keys = {k: random.sample(self.who_has[k] & workers,
len(self.who_has[k] & workers) - n)
for k in keys
if len(self.who_has[k] & workers) > n}
del_workers = {k: v for k, v in reverse_dict(del_keys).items() if v}
yield [self.rpc(addr=worker).delete_data(keys=list(worker_keys),
                                         report=False)
       for worker, worker_keys in del_workers.items()]
for worker, worker_keys in del_workers.items():
    self.has_what[worker] -= worker_keys
    for key in worker_keys:
        self.who_has[key].remove(worker)
        self.worker_bytes[worker] -= self.nbytes.get(key, 1000)
keys = {k for k in keys if len(self.who_has[k] & workers) < n}
# Copy not-yet-filled data
while keys:
gathers = defaultdict(dict)
for k in list(keys):
missing = workers - self.who_has[k]
count = min(max(n - len(self.who_has[k] & workers), 0),
branching_factor * len(self.who_has[k]))
if not count:
keys.remove(k)
else:
sample = random.sample(missing, count)
for w in sample:
gathers[w][k] = list(self.who_has[k])
results = yield {w: self.rpc(addr=w).gather(who_has=who_has)
for w, who_has in gathers.items()}
for w, v in results.items():
if v['status'] == 'OK':
self.add_keys(address=w, keys=list(gathers[w]))
def workers_to_close(self, memory_ratio=2):
if not self.idle or self.ready:
return []
limit_bytes = {w: self.worker_info[w]['memory_limit']
for w in self.worker_info}
worker_bytes = self.worker_bytes
limit = sum(limit_bytes.values())
total = sum(worker_bytes.values())
idle = sorted(self.idle, key=worker_bytes.get, reverse=True)
to_close = []
while idle:
w = idle.pop()
limit -= limit_bytes[w]
if limit >= memory_ratio * total: # still plenty of space
to_close.append(w)
else:
break
return to_close
@gen.coroutine
def retire_workers(self, stream=None, workers=None, remove=True):
if workers is None:
while True:
try:
workers = self.workers_to_close()
if workers:
yield self.retire_workers(workers=workers, remove=remove)
raise gen.Return(list(workers))
except KeyError: # keys left during replicate
pass
workers = set(workers)
keys = set.union(*[self.has_what[w] for w in workers])
keys = {k for k in keys if self.who_has[k].issubset(workers)}
other_workers = set(self.worker_info) - workers
if keys:
if other_workers:
yield self.replicate(keys=keys, workers=other_workers, n=1,
delete=False)
else:
raise gen.Return([])
if remove:
for w in workers:
self.remove_worker(address=w, safe=True)
raise gen.Return(list(workers))
@gen.coroutine
def synchronize_worker_data(self, stream=None, worker=None):
if worker is None:
result = yield {w: self.synchronize_worker_data(worker=w)
for w in self.worker_info}
result = {k: v for k, v in result.items() if any(v.values())}
if result:
logger.info("Excess keys found on workers: %s", result)
raise Return(result or None)
else:
keys = yield self.rpc(addr=worker).keys()
keys = set(keys)
missing = self.has_what[worker] - keys
if missing:
logger.info("Expected data missing from worker: %s, %s",
worker, missing)
extra = keys - self.has_what[worker] - self.deleted_keys[worker]
if extra:
yield gen.sleep(self.synchronize_worker_interval / 1000) # delay
keys = yield self.rpc(addr=worker).keys() # check again
extra &= set(keys) # make sure the keys are still on worker
extra -= self.has_what[worker] # and still unknown to scheduler
if extra: # still around? delete them
yield self.rpc(addr=worker).delete_data(keys=list(extra),
report=False)
raise Return({'extra': list(extra), 'missing': list(missing)})
def add_keys(self, stream=None, address=None, keys=()):
"""
Learn that a worker has certain keys
This should not be used in practice and is mostly here for legacy
reasons.
"""
address = coerce_to_address(address)
if address not in self.worker_info:
return 'not found'
for key in keys:
if key in self.who_has:
if key not in self.has_what[address]:
self.worker_bytes[address] += self.nbytes.get(key, 1000)
self.has_what[address].add(key)
self.who_has[key].add(address)
# else:
# TODO: delete key from worker
return 'OK'
def update_data(self, stream=None, who_has=None, nbytes=None, client=None):
"""
Learn that new data has entered the network from an external source
See Also
--------
Scheduler.mark_key_in_memory
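Examples
--------
A minimal sketch (hypothetical key, size, and already-registered worker):
>>> scheduler.update_data(who_has={'x': ['127.0.0.1:45011']},  # doctest: +SKIP
...                       nbytes={'x': 8000})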
"""
with log_errors():
who_has = {k: [self.coerce_address(vv) for vv in v]
for k, v in who_has.items()}
logger.debug("Update data %s", who_has)
if client:
self.client_wants_keys(keys=list(who_has), client=client)
# for key, workers in who_has.items(): # TODO
# self.mark_key_in_memory(key, workers)
self.nbytes.update(nbytes)
for key, workers in who_has.items():
if key not in self.dependents:
self.dependents[key] = set()
if key not in self.dependencies:
self.dependencies[key] = set()
self.task_state[key] = 'memory'
self.who_has[key] = set(workers)
for w in workers:
if key not in self.has_what[w]:
self.worker_bytes[w] += self.nbytes.get(key, 1000)
self.has_what[w].add(key)
self.waiting_data[key] = set()
self.report({'op': 'key-in-memory',
'key': key,
'workers': list(workers)})
def report_on_key(self, key):
if key not in self.task_state:
self.report({'op': 'cancelled-key',
'key': key})
elif self.task_state[key] == 'memory':
self.report({'op': 'key-in-memory',
'key': key})
elif self.task_state[key] == 'erred':
failing_key = self.exceptions_blame[key]
self.report({'op': 'task-erred',
'key': key,
'exception': self.exceptions[failing_key],
'traceback': self.tracebacks.get(failing_key, None)})
@gen.coroutine
def feed(self, stream, function=None, setup=None, teardown=None, interval=1, **kwargs):
"""
Provides a data stream to external requester
Caution: this runs arbitrary Python code on the scheduler. This should
eventually be phased out. It is mostly used by diagnostics.
"""
import pickle
with log_errors():
if function:
function = pickle.loads(function)
if setup:
setup = pickle.loads(setup)
if teardown:
teardown = pickle.loads(teardown)
state = setup(self) if setup else None
if isinstance(state, gen.Future):
state = yield state
try:
while True:
if state is None:
response = function(self)
else:
response = function(self, state)
yield write(stream, response)
yield gen.sleep(interval)
except (OSError, IOError, StreamClosedError):
if teardown:
teardown(self, state)
def get_stacks(self, stream=None, workers=None):
if workers is not None:
workers = set(map(self.coerce_address, workers))
return {w: list(self.stacks[w]) for w in workers}
else:
return valmap(list, self.stacks)
def get_processing(self, stream=None, workers=None):
if workers is not None:
workers = set(map(self.coerce_address, workers))
return {w: list(self.processing[w]) for w in workers}
else:
return valmap(list, self.processing)
def get_who_has(self, stream=None, keys=None):
if keys is not None:
return {k: list(self.who_has.get(k, [])) for k in keys}
else:
return valmap(list, self.who_has)
def get_has_what(self, stream=None, workers=None):
if workers is not None:
workers = map(self.coerce_address, workers)
return {w: list(self.has_what.get(w, ())) for w in workers}
else:
return valmap(list, self.has_what)
def get_ncores(self, stream=None, workers=None):
if workers is not None:
workers = map(self.coerce_address, workers)
return {w: self.ncores.get(w, None) for w in workers}
else:
return self.ncores
def get_nbytes(self, stream=None, keys=None, summary=True):
with log_errors():
if keys is not None:
result = {k: self.nbytes[k] for k in keys}
else:
result = self.nbytes
if summary:
out = defaultdict(lambda: 0)
for k, v in result.items():
out[key_split(k)] += v
result = out
return result
def publish_dataset(self, stream=None, keys=None, data=None, name=None,
client=None):
if name in self.datasets:
raise KeyError("Dataset %s already exists" % name)
self.client_wants_keys(keys, 'published-%s' % name)
self.datasets[name] = {'data': data, 'keys': keys}
return {'status': 'OK', 'name': name}
def unpublish_dataset(self, stream=None, name=None):
out = self.datasets.pop(name, {'keys': []})
self.client_releases_keys(out['keys'], 'published-%s' % name)
def list_datasets(self, *args):
return list(sorted(self.datasets.keys()))
def get_dataset(self, stream, name=None, client=None):
if name in self.datasets:
return self.datasets[name]
else:
raise KeyError("Dataset '%s' not found" % name)
def change_worker_cores(self, stream=None, worker=None, diff=0):
""" Add or remove cores from a worker
This is used when a worker wants to spin off a long-running task
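Examples
--------
A minimal sketch (hypothetical worker address); a positive ``diff`` raises
the worker's available core count by that amount:
>>> scheduler.change_worker_cores(worker='127.0.0.1:45011', diff=1)  # doctest: +SKIP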
"""
self.ncores[worker] += diff
self.maybe_idle.add(worker)
self.ensure_occupied()
#####################
# State Transitions #
#####################
def transition_released_waiting(self, key):
try:
if self.validate:
assert key in self.tasks
assert key in self.dependencies
assert key in self.dependents
assert key not in self.waiting
# assert key not in self.readyset
# assert key not in self.rstacks
assert key not in self.who_has
assert key not in self.rprocessing
# assert all(dep in self.task_state
# for dep in self.dependencies[key])
if not all(dep in self.task_state for dep in
self.dependencies[key]):
return {key: 'forgotten'}
self.waiting[key] = set()
recommendations = OrderedDict()
for dep in self.dependencies[key]:
if dep in self.exceptions_blame:
self.exceptions_blame[key] = self.exceptions_blame[dep]
recommendations[key] = 'erred'
return recommendations
for dep in self.dependencies[key]:
if dep not in self.who_has:
self.waiting[key].add(dep)
if dep in self.released:
recommendations[dep] = 'waiting'
else:
self.waiting_data[dep].add(key)
if not self.waiting[key]:
recommendations[key] = 'ready'
self.waiting_data[key] = {dep for dep in self.dependents[key]
if dep not in self.who_has
and dep not in self.released
and dep not in self.exceptions_blame}
self.task_state[key] = 'waiting'
self.released.remove(key)
if self.validate:
assert key in self.waiting
assert key in self.waiting_data
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_waiting_ready(self, key):
try:
if self.validate:
assert key in self.waiting
assert not self.waiting[key]
assert key not in self.who_has
assert key not in self.exceptions_blame
assert key not in self.rprocessing
# assert key not in self.readyset
assert key not in self.unrunnable
assert all(dep in self.who_has
for dep in self.dependencies[key])
del self.waiting[key]
if self.dependencies.get(key, None) or key in self.restrictions:
new_worker = decide_worker(self.dependencies, self.stacks,
self.stack_duration, self.processing, self.who_has,
self.has_what, self.restrictions,
self.loose_restrictions, self.nbytes, self.ncores, key)
if not new_worker:
self.unrunnable.add(key)
self.task_state[key] = 'no-worker'
else:
self.stacks[new_worker].append(key)
duration = self.task_duration.get(key_split(key), 0.5)
self.stack_durations[new_worker].append(duration)
self.stack_duration[new_worker] += duration
self.maybe_idle.add(new_worker)
self.put_key_in_stealable(key)
self.task_state[key] = 'stacks'
else:
self.ready.appendleft(key)
self.task_state[key] = 'queue'
return {}
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_ready_processing(self, key, worker=None, latency=5e-3):
try:
if self.validate:
assert key not in self.waiting
assert key not in self.who_has
assert key not in self.exceptions_blame
assert self.task_state[key] in ('queue', 'stacks')
if self.task_state[key] == 'no-worker':
raise ValueError()
assert worker
duration = self.task_duration.get(key_split(key), latency*100)
self.processing[worker][key] = duration
self.rprocessing[key].add(worker)
self.occupancy[worker] += duration
self.task_state[key] = 'processing'
self.remove_key_from_stealable(key)
# logger.debug("Send job to worker: %s, %s", worker, key)
try:
self.send_task_to_worker(worker, key)
except StreamClosedError:
self.remove_worker(address=worker)
return {}
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_processing_memory(self, key, nbytes=None, type=None,
worker=None, compute_start=None, compute_stop=None,
transfer_start=None, transfer_stop=None, **kwargs):
try:
if self.validate:
assert key in self.rprocessing
assert all(key in self.processing[w] for w in self.rprocessing[key])
assert key not in self.waiting
assert key not in self.who_has
assert key not in self.exceptions_blame
# assert all(dep in self.waiting_data[key ] for dep in
# self.dependents[key] if self.task_state[dep] in
# ['waiting', 'queue', 'stacks'])
# assert key not in self.nbytes
assert self.task_state[key] == 'processing'
if worker not in self.processing:
return {key: 'released'}
#############################
# Update Timing Information #
#############################
if compute_start:
# Update average task duration for worker
info = self.worker_info[worker]
ks = key_split(key)
gap = (transfer_start or compute_start) - info.get('last-task', 0)
old_duration = self.task_duration.get(ks, 0)
new_duration = compute_stop - compute_start
if (not old_duration or
gap > max(10e-3, info.get('latency', 0), old_duration)):
avg_duration = new_duration
else:
avg_duration = (0.5 * old_duration
+ 0.5 * new_duration)
self.task_duration[ks] = avg_duration
if ks in self.stealable_unknown_durations:
for k in self.stealable_unknown_durations.pop(ks, ()):
if self.task_state.get(k) == 'stacks':
self.put_key_in_stealable(k)
info['last-task'] = compute_stop
############################
# Update State Information #
############################
if nbytes:
self.nbytes[key] = nbytes
self.who_has[key] = set()
if worker:
self.who_has[key].add(worker)
self.has_what[worker].add(key)
self.worker_bytes[worker] += self.nbytes.get(key, 1000)
workers = self.rprocessing.pop(key)
for worker in workers:
self.occupancy[worker] -= self.processing[worker].pop(key)
recommendations = OrderedDict()
deps = self.dependents.get(key, [])
if len(deps) > 1:
deps = sorted(deps, key=self.priority.get, reverse=True)
for dep in deps:
if dep in self.waiting:
s = self.waiting[dep]
s.remove(key)
if not s: # new task ready to run
recommendations[dep] = 'ready'
for dep in self.dependencies.get(key, []):
if dep in self.waiting_data:
s = self.waiting_data[dep]
s.remove(key)
if (not s and dep and
dep not in self.who_wants and
not self.waiting_data.get(dep)):
recommendations[dep] = 'released'
if (not self.waiting_data.get(key) and
key not in self.who_wants):
recommendations[key] = 'released'
else:
msg = {'op': 'key-in-memory',
'key': key}
if type is not None:
msg['type'] = type
self.report(msg)
self.task_state[key] = 'memory'
if self.validate:
assert key not in self.rprocessing
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_memory_released(self, key, safe=False):
try:
if self.validate:
assert key in self.who_has
assert key not in self.released
# assert key not in self.readyset
assert key not in self.waiting
assert key not in self.rprocessing
if safe:
assert not self.waiting_data.get(key)
# assert key not in self.who_wants
recommendations = OrderedDict()
for dep in self.waiting_data.get(key, ()): # lost dependency
if self.task_state[dep] == 'waiting':
self.waiting[dep].add(key)
else:
recommendations[dep] = 'waiting'
workers = self.who_has.pop(key)
for w in workers:
if w in self.worker_info: # in case worker has died
self.has_what[w].remove(key)
self.worker_bytes[w] -= self.nbytes.get(key, 1000)
self.deleted_keys[w].add(key)
self.released.add(key)
self.task_state[key] = 'released'
self.report({'op': 'lost-data', 'key': key})
if key not in self.tasks: # pure data
recommendations[key] = 'forgotten'
elif not all(dep in self.task_state
for dep in self.dependencies[key]):
recommendations[key] = 'forgotten'
elif key in self.who_wants or self.waiting_data.get(key):
recommendations[key] = 'waiting'
if key in self.waiting_data:
del self.waiting_data[key]
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_released_erred(self, key):
try:
if self.validate:
with log_errors(pdb=LOG_PDB):
assert key in self.exceptions_blame
assert key not in self.who_has
assert key not in self.waiting
assert key not in self.waiting_data
recommendations = {}
failing_key = self.exceptions_blame[key]
for dep in self.dependents[key]:
self.exceptions_blame[dep] = failing_key
if dep not in self.who_has:
recommendations[dep] = 'erred'
self.report({'op': 'task-erred',
'key': key,
'exception': self.exceptions[failing_key],
'traceback': self.tracebacks.get(failing_key, None)})
self.task_state[key] = 'erred'
self.released.remove(key)
# TODO: waiting data?
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_waiting_released(self, key):
try:
if self.validate:
assert key in self.waiting
assert key in self.waiting_data
assert key not in self.who_has
assert key not in self.rprocessing
recommendations = {}
del self.waiting[key]
for dep in self.dependencies[key]:
if dep in self.waiting_data:
if key in self.waiting_data[dep]:
self.waiting_data[dep].remove(key)
if not self.waiting_data[dep] and dep not in self.who_wants:
recommendations[dep] = 'released'
assert self.task_state[dep] != 'erred'
self.task_state[key] = 'released'
self.released.add(key)
if self.validate:
assert not any(key in self.waiting_data.get(dep, ())
for dep in self.dependencies[key])
if any(dep not in self.task_state for dep in
self.dependencies[key]):
recommendations[key] = 'forgotten'
elif (key not in self.exceptions_blame and
(key in self.who_wants or self.waiting_data.get(key))):
recommendations[key] = 'waiting'
del self.waiting_data[key]
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_processing_released(self, key):
try:
if self.validate:
assert key in self.rprocessing
assert key not in self.who_has
assert self.task_state[key] == 'processing'
for w in self.rprocessing.pop(key):
self.occupancy[w] -= self.processing[w].pop(key)
self.released.add(key)
self.task_state[key] = 'released'
recommendations = OrderedDict()
if any(dep not in self.task_state
for dep in self.dependencies[key]):
recommendations[key] = 'forgotten'
elif self.waiting_data[key] or key in self.who_wants:
recommendations[key] = 'waiting'
else:
for dep in self.dependencies[key]:
if dep not in self.released:
assert key in self.waiting_data[dep]
self.waiting_data[dep].remove(key)
if not self.waiting_data[dep] and dep not in self.who_wants:
recommendations[dep] = 'released'
del self.waiting_data[key]
if self.validate:
assert key not in self.rprocessing
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_ready_released(self, key):
try:
if self.validate:
assert key not in self.who_has
assert self.task_state[key] in ('stacks', 'queue', 'no-worker')
if self.task_state[key] == 'no-worker':
self.unrunnable.remove(key)
if self.task_state[key] == 'stacks': # TODO: non-linear
for w in self.stacks:
if key in self.stacks[w]:
for i, k in enumerate(self.stacks[w]):
if k == key:
del self.stacks[w][i]
duration = self.stack_durations[w][i]
del self.stack_durations[w][i]
self.stack_duration[w] -= duration
break
self.released.add(key)
self.task_state[key] = 'released'
for dep in self.dependencies[key]:
try:
self.waiting_data[dep].remove(key)
except KeyError: # dep may also be released
pass
# TODO: maybe release dep if not about to wait?
if self.waiting_data[key] or key in self.who_wants:
recommendations = {key: 'waiting'}
else:
recommendations = {}
del self.waiting_data[key]
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_processing_erred(self, key, cause=None, exception=None,
traceback=None):
try:
if self.validate:
assert cause or key in self.exceptions_blame
assert key in self.rprocessing
assert key not in self.who_has
assert key not in self.waiting
# assert key not in self.rstacks
# assert key not in self.readyset
if exception:
self.exceptions[key] = exception
if traceback:
self.tracebacks[key] = traceback
if cause:
self.exceptions_blame[key] = cause
failing_key = self.exceptions_blame[key]
recommendations = {}
for dep in self.dependents[key]:
self.exceptions_blame[dep] = failing_key  # propagate the root cause, as in transition_released_erred
recommendations[dep] = 'erred'
for dep in self.dependencies.get(key, []):
if dep in self.waiting_data:
s = self.waiting_data[dep]
if key in s:
s.remove(key)
if (not s and dep and
dep not in self.who_wants and
not self.waiting_data.get(dep)):
recommendations[dep] = 'released'
for w in self.rprocessing.pop(key):
self.occupancy[w] -= self.processing[w].pop(key)
del self.waiting_data[key] # do anything with this?
self.task_state[key] = 'erred'
self.report({'op': 'task-erred',
'key': key,
'exception': self.exceptions[failing_key],
'traceback': self.tracebacks.get(failing_key)})
if self.validate:
assert key not in self.rprocessing
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def remove_key(self, key):
if key in self.tasks:
del self.tasks[key]
del self.task_state[key]
if key in self.dependencies:
del self.dependencies[key]
del self.dependents[key]
if key in self.restrictions:
del self.restrictions[key]
if key in self.loose_restrictions:
self.loose_restrictions.remove(key)
if key in self.priority:
del self.priority[key]
if key in self.exceptions:
del self.exceptions[key]
if key in self.exceptions_blame:
del self.exceptions_blame[key]
if key in self.released:
self.released.remove(key)
if key in self.waiting_data:
del self.waiting_data[key]
if key in self.suspicious_tasks:
del self.suspicious_tasks[key]
if key in self.nbytes:
del self.nbytes[key]
def transition_memory_forgotten(self, key):
try:
if self.validate:
assert key in self.dependents
assert self.task_state[key] == 'memory'
assert key in self.waiting_data
assert key in self.who_has
assert key not in self.rprocessing
# assert key not in self.ready
assert key not in self.waiting
recommendations = {}
for dep in self.waiting_data[key]:
recommendations[dep] = 'forgotten'
for dep in self.dependents[key]:
if self.task_state[dep] == 'released':
recommendations[dep] = 'forgotten'
for dep in self.dependencies.get(key, ()):
try:
s = self.dependents[dep]
s.remove(key)
if not s and dep not in self.who_wants:
assert dep is not key
recommendations[dep] = 'forgotten'
except KeyError:
pass
workers = self.who_has.pop(key)
for w in workers:
if w in self.worker_info: # in case worker has died
self.has_what[w].remove(key)
self.worker_bytes[w] -= self.nbytes.get(key, 1000)
self.deleted_keys[w].add(key)
if self.validate:
assert all(key not in self.dependents[dep]
for dep in self.dependencies[key]
if dep in self.task_state)
assert all(key not in self.waiting_data.get(dep, ())
for dep in self.dependencies[key]
if dep in self.task_state)
self.remove_key(key)
self.report_on_key(key)
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition_released_forgotten(self, key):
try:
if self.validate:
assert key in self.dependencies
assert self.task_state[key] in ('released', 'erred')
# assert not self.waiting_data[key]
if key in self.tasks and self.dependencies[key].issubset(self.task_state):
assert key not in self.who_wants
assert not self.dependents[key]
assert not any(key in self.waiting_data.get(dep, ())
for dep in self.dependencies[key])
assert key not in self.who_has
assert key not in self.rprocessing
# assert key not in self.ready
assert key not in self.waiting
recommendations = {}
for dep in self.dependencies[key]:
try:
s = self.dependents[dep]
s.remove(key)
if not s and dep not in self.who_wants:
assert dep is not key
recommendations[dep] = 'forgotten'
except KeyError:
pass
for dep in self.dependents[key]:
if self.task_state[dep] not in ('memory', 'erred'):
recommendations[dep] = 'forgotten'
for dep in self.dependents[key]:
if self.task_state[dep] == 'released':
recommendations[dep] = 'forgotten'
for dep in self.dependencies[key]:
try:
self.waiting_data[dep].remove(key)
except KeyError:
pass
if self.validate:
assert all(key not in self.dependents[dep]
for dep in self.dependencies[key]
if dep in self.task_state)
assert all(key not in self.waiting_data.get(dep, ())
for dep in self.dependencies[key]
if dep in self.task_state)
self.remove_key(key)
self.report_on_key(key)
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transition(self, key, finish, *args, **kwargs):
""" Transition a key from its current state to the finish state
Examples
--------
>>> self.transition('x', 'waiting')
{'x': 'ready'}
Returns
-------
Dictionary of recommendations for future transitions
See Also
--------
Scheduler.transitions: transitive version of this function
"""
try:
try:
start = self.task_state[key]
except KeyError:
return {}
if start == finish:
return {}
if (start, finish) in self._transitions:
func = self._transitions[start, finish]
recommendations = func(key, *args, **kwargs)
else:
func = self._transitions['released', finish]
assert not args and not kwargs
a = self.transition(key, 'released')
if key in a:
func = self._transitions['released', a[key]]
b = func(key)
a = a.copy()
a.update(b)
recommendations = a
start = 'released'
finish2 = self.task_state.get(key, 'forgotten')
self.transition_log.append((key, start, finish2, recommendations))
if self.validate:
logger.info("Transition %s->%s: %s New: %s",
start, finish2, key, recommendations)
for plugin in self.plugins:
try:
plugin.transition(key, start, finish2, *args, **kwargs)
except Exception:
logger.info("Plugin failed with exception", exc_info=True)
return recommendations
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb; pdb.set_trace()
raise
def transitions(self, recommendations):
""" Process transitions until none are left
This includes feedback from previous transitions and continues until we
reach a steady state
"""
keys = set()
recommendations = recommendations.copy()
while recommendations:
key, finish = recommendations.popitem()
keys.add(key)
new = self.transition(key, finish)
recommendations.update(new)
if self.validate:
for key in keys:
self.validate_key(key)
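# --- Illustrative sketch (not from the distributed scheduler source) ----------
# transitions() above drains a work list of {key: finish-state} recommendations
# to a fixed point, since each applied transition may recommend further moves.
# A toy version of that pattern, with a hypothetical apply_transition callable:
def drain_recommendations(recommendations, apply_transition):
    """Apply recommendations, folding in new ones, until none remain."""
    recommendations = dict(recommendations)
    touched = set()
    while recommendations:
        key, finish = recommendations.popitem()
        touched.add(key)
        recommendations.update(apply_transition(key, finish))
    return touched

# Example: finishing 'x' recommends releasing 'y', which recommends nothing.
# drain_recommendations({'x': 'memory'},
#                       lambda k, f: {'y': 'released'} if k == 'x' else {})
# -> {'x', 'y'}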
def transition_story(self, *keys):
""" Get all transitions that touch one of the input keys """
keys = set(keys)
return [t for t in self.transition_log
if t[0] in keys or keys.intersection(t[3])]
##############################
# Assigning Tasks to Workers #
##############################
def ensure_occupied(self):
""" Run ready tasks on idle workers
**Work stealing policy**
If some workers are idle but not others, if there are no globally ready
tasks, and if there are tasks in worker stacks, then we start to pull
preferred tasks from overburdened workers and deploy them back into the
global pool in the following manner.
We determine the number of tasks to reclaim as the number of all tasks
in all stacks times the fraction of idle workers to all workers.
We sort the stacks by size and walk through them, reclaiming half of
each stack until we have enough tasks to fill the global pool.
We are careful not to reclaim tasks that are restricted to run on
certain workers.
See also
--------
Scheduler.ensure_occupied_queue
Scheduler.ensure_occupied_stacks
Scheduler.work_steal
"""
with log_errors(pdb=LOG_PDB):
for worker in self.maybe_idle:
self.ensure_occupied_stacks(worker)
self.maybe_idle.clear()
if self.idle and self.ready:
if len(self.ready) < len(self.idle):
def keyfunc(w):
return (-len(self.stacks[w]) - len(self.processing[w]),
-len(self.has_what.get(w, ())))
for worker in topk(len(self.ready), self.idle, key=keyfunc):
self.ensure_occupied_queue(worker, count=1)
else:
# Fill up empty cores
workers = list(self.idle)
free_cores = [self.ncores[w] - len(self.processing[w])
for w in workers]
workers2 = [] # Clean out workers that *are* actually full
free_cores2 = []
for w, fs in zip(workers, free_cores):
if fs > 0:
workers2.append(w)
free_cores2.append(fs)
if workers2:
n = min(sum(free_cores2), len(self.ready))
counts = divide_n_among_bins(n, free_cores2)
for worker, count in zip(workers2, counts):
self.ensure_occupied_queue(worker, count=count)
# Fill up unsaturated cores by time
workers = list(self.idle)
latency = 5e-3
free_time = [latency * self.ncores[w] - self.occupancy[w]
for w in workers]
workers2 = [] # Clean out workers that *are* actually full
free_time2 = []
for w, fs in zip(workers, free_time):
if fs > 0:
workers2.append(w)
free_time2.append(fs)
total_free_time = sum(free_time2)
if workers2 and total_free_time > 0:
tasks = []
while self.ready and total_free_time > 0:
task = self.ready.pop()
if self.task_state.get(task) != 'queue':
continue
total_free_time -= self.task_duration.get(key_split(task), 1)
tasks.append(task)
self.ready.extend(tasks[::-1])
counts = divide_n_among_bins(len(tasks), free_time2)
for worker, count in zip(workers2, counts):
self.ensure_occupied_queue(worker, count=count)
if self.idle and any(self.stealable):
thieves = self.work_steal()
for worker in thieves:
self.ensure_occupied_stacks(worker)
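# --- Illustrative sketch (not from the distributed scheduler source) ----------
# The work-stealing policy described in ensure_occupied's docstring sizes the
# reclaim budget as (total stacked tasks) * (idle workers / all workers), then
# walks the largest stacks, taking half of each, until the budget is met.  A
# hedged sketch of that arithmetic with hypothetical inputs:
def reclaim_budget(stack_sizes, n_idle, n_workers):
    """Return (budget, reclaimed) under the docstring's reclaim rule."""
    total_stacked = sum(stack_sizes.values())
    budget = int(total_stacked * n_idle / n_workers)
    reclaimed = 0
    for worker in sorted(stack_sizes, key=stack_sizes.get, reverse=True):
        if reclaimed >= budget:
            break
        reclaimed += stack_sizes[worker] // 2
    return budget, reclaimed

# Example: 100 stacked tasks, 2 of 4 workers idle -> budget of 50 tasks,
# met by taking 30 from the 60-task stack and 20 from the 40-task stack.
# reclaim_budget({'a:8000': 60, 'b:8000': 40}, n_idle=2, n_workers=4)  # (50, 50)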
def ensure_occupied_stacks(self, worker):
""" Send tasks to worker while it has tasks and free cores
These tasks may come from the worker's own stacks or from the global
ready deque.
We update the idle workers set appropriately.
See Also
--------
Scheduler.ensure_occupied
Scheduler.ensure_occupied_queue
"""
stack = self.stacks[worker]
latency = 5e-3
while (stack and
(self.ncores[worker] > len(self.processing[worker]) or
self.occupancy[worker] < latency * self.ncores[worker])):
key = stack.pop()
duration = self.stack_durations[worker].pop()
self.stack_duration[worker] -= duration
if self.task_state.get(key) == 'stacks':
r = self.transition(key, 'processing',
worker=worker, latency=latency)
if stack:
self.saturated.add(worker)
if worker in self.idle:
self.idle.remove(worker)
else:
if worker in self.saturated:
self.saturated.remove(worker)
self._check_idle(worker)
def put_key_in_stealable(self, key):
ratio, loc = self.steal_time_ratio(key)
if ratio is not None:
self.stealable[loc].add(key)
self.key_stealable[key] = loc
def remove_key_from_stealable(self, key):
loc = self.key_stealable.pop(key, None)
if loc is not None:
try:
self.stealable[loc].remove(key)
except KeyError:  # key already removed from this bin
pass
def ensure_occupied_queue(self, worker, count):
"""
Send at most count tasks from the ready queue to the specified worker
See also
--------
Scheduler.ensure_occupied
Scheduler.ensure_occupied_stacks
"""
for i in range(count):
try:
key = self.ready.pop()
while self.task_state.get(key) != 'queue':
key = self.ready.pop()
except (IndexError, KeyError):
break
if self.task_state[key] == 'queue':
r = self.transition(key, 'processing', worker=worker)
self._check_idle(worker)
def work_steal(self):
""" Steal tasks from saturated workers to idle workers
This moves tasks from the bottom of the stacks of over-occupied workers
to the stacks of idling workers.
See also
--------
Scheduler.ensure_occupied
"""
if not self.steal:
return []
with log_errors():
thieves = set()
for level, stealable in enumerate(self.stealable[:-1]):
if not stealable:
continue
if len(self.idle) == len(self.ncores): # no stacks
stealable.clear()
continue
# Enough idleness to continue?
ratio = 2 ** (level - 3)
n_saturated = len(self.ncores) - len(self.idle)
duration_if_hold = len(stealable) / n_saturated
duration_if_steal = ratio
if level > 1 and duration_if_hold < duration_if_steal:
break
while stealable and self.idle:
for w in list(self.idle):
try:
key = stealable.pop()
except KeyError:  # this stealable bin is exhausted
break
else:
if self.task_state.get(key) == 'stacks':
self.stacks[w].append(key)
duration = self.task_duration.get(key_split(key), 0.5)
self.stack_durations[w].append(duration)
self.stack_duration[w] += duration
thieves.add(w)
if (self.ncores[w] <=
len(self.processing[w]) + len(self.stacks[w])):
self.idle.remove(w)
if stealable:
break
logger.debug('Stolen tasks for %d workers', len(thieves))
return thieves
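# --- Illustrative sketch (not from the distributed scheduler source) ----------
# The break-even test inside work_steal compares the backlog per saturated
# worker against the bin's cost factor 2**(level - 3): for bins past level 1,
# stealing only proceeds while holding the work would take at least that long.
# A hedged sketch of that test, with hypothetical argument names:
def worth_stealing(level, n_stealable, n_saturated):
    if level <= 1:
        return True
    return n_stealable / n_saturated >= 2 ** (level - 3)

# worth_stealing(level=4, n_stealable=10, n_saturated=5)  # True  (2 >= 2)
# worth_stealing(level=5, n_stealable=10, n_saturated=5)  # False (2 < 4)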
def steal_time_ratio(self, key, bandwidth=BANDWIDTH):
""" The compute to communication time ratio of a key
Returns
-------
ratio: The compute/communication time ratio of the task
loc: The self.stealable bin into which this key should go
"""
if key in self.restrictions and key not in self.loose_restrictions:
return None, None # don't steal
nbytes = sum(self.nbytes.get(k, 1000) for k in self.dependencies[key])
transfer_time = nbytes / bandwidth
split = key_split(key)
if split in fast_tasks:
return None, None
try:
compute_time = self.task_duration[split]
except KeyError:
self.stealable_unknown_durations[split].add(key)
return None, None
else:
try:
ratio = compute_time / transfer_time
except ZeroDivisionError:
ratio = 10000
if ratio > 8:
loc = 0
elif ratio < 2**-8:
loc = -1
else:
loc = int(-round(log(ratio) / log(2), 0) + 3)
return ratio, loc
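# --- Illustrative sketch (not from the distributed scheduler source) ----------
# steal_time_ratio buckets tasks by their compute-to-transfer ratio on a log2
# scale: ratios above 8 go to bin 0 (steal eagerly), ratios below 2**-8 go to
# bin -1 (effectively never), and everything else lands in bin
# -round(log2(ratio)) + 3.  A standalone sketch of just the binning:
from math import log

def stealable_bin(ratio):
    if ratio > 8:
        return 0
    if ratio < 2 ** -8:
        return -1
    return int(-round(log(ratio) / log(2), 0) + 3)

# stealable_bin(16)      # 0  -- compute dwarfs transfer, steal first
# stealable_bin(1)       # 3  -- break-even
# stealable_bin(1 / 64)  # 9  -- transfer dominates, steal reluctantly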
def issaturated(self, worker, latency=5e-3):
"""
Determine if a worker has enough work to avoid being idle
A worker is saturated if both of the following criteria are met:
1. It holds more tasks (stacked plus processing) than it has cores
2. The expected time to complete all of its currently assigned tasks
exceeds one round-trip latency per core. This matters when it holds
many small tasks
"""
return (len(self.stacks[worker]) + len(self.processing[worker])
> self.ncores[worker] and
self.occupancy[worker] > latency * self.ncores[worker])
def _check_idle(self, worker, latency=5e-3):
if not self.issaturated(worker, latency=latency):
self.idle.add(worker)
elif worker in self.idle:
self.idle.remove(worker)
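# --- Illustrative sketch (not from the distributed scheduler source) ----------
# issaturated/_check_idle above classify a worker by two thresholds: task count
# versus cores, and expected occupancy versus one latency per core.  A minimal
# sketch with hypothetical counts:
def is_saturated(n_stacked, n_processing, ncores, occupancy, latency=5e-3):
    return (n_stacked + n_processing > ncores and
            occupancy > latency * ncores)

# Four cores, six assigned tasks worth 0.1 s of work -> saturated:
# is_saturated(n_stacked=2, n_processing=4, ncores=4, occupancy=0.1)   # True
# Same six tasks but only 0.01 s of expected work -> still counts as idle:
# is_saturated(n_stacked=2, n_processing=4, ncores=4, occupancy=0.01)  # False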
#####################
# Utility functions #
#####################
def coerce_address(self, addr):
"""
Coerce possible input addresses to canonical form
Handles lists, strings, bytes, tuples, or aliases
"""
if isinstance(addr, list):
addr = tuple(addr)
if addr in self.aliases:
addr = self.aliases[addr]
if isinstance(addr, bytes):
addr = addr.decode()
if addr in self.aliases:
addr = self.aliases[addr]
if isinstance(addr, unicode):
if ':' in addr:
addr = tuple(addr.rsplit(':', 1))
else:
addr = ensure_ip(addr)
if isinstance(addr, tuple):
ip, port = addr
if PY3 and isinstance(ip, bytes):
ip = ip.decode()
ip = ensure_ip(ip)
port = int(port)
addr = '%s:%d' % (ip, port)
return addr
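# --- Illustrative sketch (not from the distributed scheduler source) ----------
# coerce_address above funnels lists, tuples, bytes, aliases, and 'host:port'
# strings into a single canonical 'ip:port' form.  A simplified Python 3
# sketch (it skips the alias table lookups and the ensure_ip resolution that
# the real method performs):
def normalize_address(addr):
    if isinstance(addr, list):
        addr = tuple(addr)
    if isinstance(addr, bytes):
        addr = addr.decode()
    if isinstance(addr, str) and ':' in addr:
        addr = tuple(addr.rsplit(':', 1))
    if isinstance(addr, tuple):
        host, port = addr
        if isinstance(host, bytes):
            host = host.decode()
        addr = '%s:%d' % (host, int(port))
    return addr

# normalize_address(('127.0.0.1', 8786))  # '127.0.0.1:8786'
# normalize_address(b'127.0.0.1:8786')    # '127.0.0.1:8786'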
def workers_list(self, workers):
"""
List of qualifying workers
Takes a list of worker addresses or hostnames.
Returns a list of all worker addresses that match
"""
if workers is None:
return list(self.ncores)
out = set()
for w in workers:
if ':' in w:
out.add(w)
else:
out.update({ww for ww in self.ncores if w in ww}) # TODO: quadratic
return list(out)
def start_ipython(self, stream=None):
"""Start an IPython kernel
Returns Jupyter connection info dictionary.
"""
from ._ipython_utils import start_ipython
if self._ipython_kernel is None:
self._ipython_kernel = start_ipython(
ip=self.ip,
ns={'scheduler': self},
log=logger,
)
return self._ipython_kernel.get_connection_info()
def decide_worker(dependencies, stacks, stack_duration, processing, who_has,
has_what, restrictions, loose_restrictions, nbytes, ncores, key):
""" Decide which worker should take task
>>> dependencies = {'c': {'b'}, 'b': {'a'}}
>>> stacks = {'alice:8000': ['z'], 'bob:8000': []}
>>> stack_duration = {'alice:8000': 1, 'bob:8000': 0}
>>> processing = {'alice:8000': set(), 'bob:8000': set()}
>>> who_has = {'a': {'alice:8000'}}
>>> has_what = {'alice:8000': {'a'}}
>>> nbytes = {'a': 100}
>>> ncores = {'alice:8000': 1, 'bob:8000': 1}
>>> restrictions = {}
>>> loose_restrictions = set()
We choose the worker that has the data on which 'b' depends (alice has 'a')
>>> decide_worker(dependencies, stacks, stack_duration, processing, who_has,
... has_what, restrictions, loose_restrictions, nbytes, ncores, 'b')
'alice:8000'
If both Alice and Bob have dependencies then we choose the less-busy worker
>>> who_has = {'a': {'alice:8000', 'bob:8000'}}
>>> has_what = {'alice:8000': {'a'}, 'bob:8000': {'a'}}
>>> decide_worker(dependencies, stacks, stack_duration, processing, who_has,
... has_what, restrictions, loose_restrictions, nbytes, ncores, 'b')
'bob:8000'
Optionally provide restrictions on where jobs are allowed to run
>>> restrictions = {'b': {'alice', 'charlie'}}
>>> decide_worker(dependencies, stacks, stack_duration, processing, who_has,
... has_what, restrictions, loose_restrictions, nbytes, ncores, 'b')
'alice:8000'
If the task requires data communication, then we choose to minimize the
number of bytes sent between workers. This takes precedence over worker
occupancy.
>>> dependencies = {'c': {'a', 'b'}}
>>> who_has = {'a': {'alice:8000'}, 'b': {'bob:8000'}}
>>> has_what = {'alice:8000': {'a'}, 'bob:8000': {'b'}}
>>> nbytes = {'a': 1, 'b': 1000}
>>> stacks = {'alice:8000': [], 'bob:8000': []}
>>> stack_duration = {'alice:8000': 0, 'bob:8000': 0}
>>> decide_worker(dependencies, stacks, stack_duration, processing, who_has,
... has_what, {}, set(), nbytes, ncores, 'c')
'bob:8000'
"""
deps = dependencies[key]
assert all(d in who_has for d in deps)
workers = frequencies([w for dep in deps
for w in who_has[dep]])
if not workers:
workers = stacks
if key in restrictions:
r = restrictions[key]
workers = {w for w in workers if w in r or w.split(':')[0] in r} # TODO: nonlinear
if not workers:
workers = {w for w in stacks if w in r or w.split(':')[0] in r}
if not workers:
if key in loose_restrictions:
return decide_worker(dependencies, stacks, stack_duration,
processing, who_has, has_what, {}, set(), nbytes,
ncores, key)
else:
return None
if not workers or not stacks:
return None
if len(workers) == 1:
return first(workers)
# Select worker that will finish task first
def objective(w):
comm_bytes = sum([nbytes.get(k, 1000) for k in dependencies[key]
if w not in who_has[k]])
stack_time = stack_duration[w] / ncores[w]
start_time = comm_bytes / BANDWIDTH + stack_time
return start_time
return min(workers, key=objective)
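# --- Illustrative sketch (not from the distributed scheduler source) ----------
# The objective above estimates when a worker could start the task: bytes it
# still needs to receive divided by bandwidth, plus its stacked work spread
# over its cores.  A worked example with hypothetical numbers and the default
# 100 MB/s bandwidth:
def start_time_estimate(comm_bytes, stack_duration, ncores, bandwidth=100e6):
    return comm_bytes / bandwidth + stack_duration / ncores

# alice must fetch 10 MB but has an empty stack; bob holds the data but has
# half a second of stacked work on one core -- the task goes to alice.
# start_time_estimate(10e6, 0.0, 1)  # 0.1
# start_time_estimate(0, 0.5, 1)     # 0.5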
def validate_state(dependencies, dependents, waiting, waiting_data, ready,
who_has, stacks, processing, finished_results, released,
who_wants, wants_what, tasks=None, allow_overlap=False, allow_bad_stacks=False,
erred=None, **kwargs):
"""
Validate a current runtime state
This performs a sequence of checks on the entire graph, running in roughly
linear time. It raises an AssertionError or ValueError if anything doesn't
check out.
"""
in_stacks = {k for v in stacks.values() for k in v}
in_processing = {k for v in processing.values() for k in v}
keys = {key for key in dependents if not dependents[key]}
ready_set = set(ready)
assert set(waiting).issubset(dependencies), "waiting not subset of deps"
assert set(waiting_data).issubset(dependents), "waiting_data not subset"
if tasks is not None:
assert ready_set.issubset(tasks), "All ready tasks are tasks"
assert set(dependents).issubset(set(tasks) | set(who_has)), "all dependents tasks"
assert set(dependencies).issubset(set(tasks) | set(who_has)), "all dependencies tasks"
for k, v in waiting.items():
assert v, "waiting on empty set"
assert v.issubset(dependencies[k]), "waiting set not dependencies"
for vv in v:
assert vv not in who_has, ("waiting dependency in memory", k, vv)
assert vv not in released, ("dependency released", k, vv)
for dep in dependencies[k]:
assert dep in v or who_has.get(dep), ("dep missing", k, dep)
for k, v in waiting_data.items():
for vv in v:
if vv in released:
raise ValueError('dependent not in play', k, vv)
if not (vv in ready_set or
vv in waiting or
vv in in_stacks or
vv in in_processing):
raise ValueError('dependent not in play2', k, vv)
for v in concat(processing.values()):
assert v in dependencies, "all processing keys in dependencies"
for key in who_has:
assert key in waiting_data or key in who_wants
@memoize
def check_key(key):
""" Validate a single key, recurse downwards """
vals = ([key in waiting,
key in ready,
key in in_stacks,
key in in_processing,
bool(who_has.get(key)),
key in released,
key in erred])
if ((allow_overlap and sum(vals) < 1) or
(not allow_overlap and sum(vals) != 1)):
if not (in_stacks and waiting): # known ok state
raise ValueError("Key exists in wrong number of places", key, vals)
for dep in dependencies[key]:
if dep in dependents:
check_key(dep) # Recursive case
if who_has.get(key):
assert not any(key in waiting.get(dep, ())
for dep in dependents.get(key, ()))
assert not waiting.get(key)
if not allow_bad_stacks and (key in in_stacks or key in in_processing):
if not all(who_has.get(dep) for dep in dependencies[key]):
raise ValueError("Key in stacks/processing without all deps",
key)
assert not waiting.get(key)
assert key not in ready
if finished_results is not None:
if key in finished_results:
assert who_has.get(key)
assert key in keys
if key in keys and who_has.get(key):
assert key in finished_results
for key, s in who_wants.items():
assert s, "empty who_wants"
for client in s:
assert key in wants_what[client]
if key in waiting:
assert waiting[key], 'waiting empty'
if key in ready:
assert key not in waiting
return True
assert all(map(check_key, keys))
_round_robin = [0]
fast_tasks = {'rechunk-split', 'shuffle-split'}
class KilledWorker(Exception):
pass