# Source code for redis_tasks.worker

import logging
from operator import attrgetter

from .conf import connection, construct_redis_key, settings
from .exceptions import WorkerDoesNotExist
from .queue import Queue
from .registries import worker_registry
from .task import Task
from .utils import atomic_pipeline, enum, utcformat, utcnow, utcparse

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lifecycle states of a Worker. The value is what Worker._save() writes into
# the `state` field of the worker's Redis hash.
# NOTE(review): `enum` is the helper imported from `.utils`; the exact type it
# returns is not visible from this file.
WorkerState = enum(
    'WorkerState',
    IDLE='idle',  # set by Worker.startup() and Worker.end_task()
    BUSY='busy',  # set by Worker.start_task()
    DEAD='dead',  # set by Worker.shutdown() and Worker.died()
)


class Worker:
    """A worker's record in Redis plus the state transitions applied to it.

    The worker's scalar attributes (``description``, ``state``, ``queues``,
    ``started_at``, ``shutdown_at``) live in a Redis hash at ``self.key``.
    The id of the task it is currently assigned lives in a separate Redis
    list at ``self.task_key`` that holds at most one element.

    Methods decorated with ``@atomic_pipeline`` take a keyword-only
    ``pipeline`` and queue their Redis writes on it.
    NOTE(review): presumably the decorator supplies and executes a pipeline
    when the caller does not pass one -- confirm in ``.utils``.
    """

    @classmethod
    def all(cls):
        """Return every worker known to the registry, sorted by description."""
        workers = (
            cls.fetch(worker_id)
            for worker_id in worker_registry.get_worker_ids()
        )
        # sorted() already returns a new list; no extra list() needed.
        return sorted(workers, key=attrgetter('description'))

    @classmethod
    def fetch(cls, id):
        """Load the existing worker with the given id from Redis.

        Raises:
            WorkerDoesNotExist: if no hash is stored for this id.
        """
        return cls(fetch_id=id)

    def __init__(self, id=None, *, description=None, queues=None,
                 fetch_id=None):
        """Create a new worker object, or load an existing one.

        Args:
            id: Identifier for a new worker.
            description: Human-readable label; defaults to ``"Worker {id}"``.
            queues: Queues this worker serves. Required (and must be
                non-empty) unless ``fetch_id`` is given.
            fetch_id: If given, all other attributes are loaded from Redis
                via :meth:`refresh` and the remaining arguments are ignored.

        Raises:
            ValueError: if a new worker is created without queues.
            WorkerDoesNotExist: if ``fetch_id`` has no stored data.
        """
        self.id = id or fetch_id
        self.key = construct_redis_key('worker:' + self.id)
        self.task_key = construct_redis_key('worker_task:' + self.id)

        if fetch_id:
            # Existing worker: every other attribute comes from Redis.
            self.refresh()
            return

        self.description = description or f"Worker {id}"
        self.state = None
        self.queues = queues
        if not queues:
            raise ValueError("Worker needs queues")
        self.current_task_id = None
        self.started_at = None
        self.shutdown_at = None

    def refresh(self):
        """Reload all attributes of this worker from Redis.

        Raises:
            WorkerDoesNotExist: if the worker hash is missing or empty.
        """
        # Read the hash and the task list in one round trip.
        with connection.pipeline() as pipeline:
            pipeline.hgetall(self.key)
            pipeline.lrange(self.task_key, 0, -1)
            obj, task_id = pipeline.execute()

        if not obj:
            raise WorkerDoesNotExist(self.id)

        # The task list is expected to hold at most one task id.
        assert len(task_id) < 2
        self.current_task_id = task_id[0].decode() if task_id else None

        # Both keys and values are decoded from bytes before use.
        obj = {k.decode(): v.decode() for k, v in obj.items()}
        self.state = obj['state']
        self.description = obj['description']
        self.queues = [Queue(q) for q in obj['queues'].split(',')]
        # Date fields are optional in the hash; absent or empty means None.
        for k in ['started_at', 'shutdown_at']:
            setattr(self, k, utcparse(obj[k]) if obj.get(k) else None)

    @atomic_pipeline
    def _save(self, fields=None, *, pipeline):
        """Queue writes persisting the given attributes to Redis.

        Args:
            fields: Attribute names to persist; ``None`` means all of them.
            pipeline: Redis pipeline the commands are queued on.

        Raises:
            AttributeError: if a name in ``fields`` is not a persistable
                attribute (or not an attribute of this object at all).
        """
        string_fields = ['description', 'state']
        date_fields = ['started_at', 'shutdown_at']
        special_fields = ['queues', 'current_task_id']
        if fields is None:
            fields = string_fields + date_fields + special_fields

        deletes = []
        store = {}
        for field in fields:
            value = getattr(self, field)
            if field == 'queues':
                # Stored as a comma-separated list of queue names.
                store['queues'] = ','.join(q.name for q in self.queues)
            elif field == 'current_task_id':
                # Kept in its own list key: clear it, then push the new id.
                pipeline.delete(self.task_key)
                if value:
                    pipeline.lpush(self.task_key, value)
            elif value is None:
                # None is represented by the hash field being absent.
                deletes.append(field)
            elif field in string_fields:
                store[field] = value
            elif field in date_fields:
                store[field] = utcformat(value)
            else:
                raise AttributeError(f'{field} is not a valid attribute')

        if deletes:
            pipeline.hdel(self.key, *deletes)
        if store:
            pipeline.hset(self.key, mapping=store)

    def heartbeat(self):
        """Send a heartbeat.

        Raises WorkerDoesNotExist if the registry considers this worker as
        dead.
        """
        worker_registry.heartbeat(self)

    @atomic_pipeline
    def startup(self, *, pipeline):
        """Mark the worker idle, register it and persist all attributes."""
        logger.info(f'Worker {self.description} [{self.id}] started')
        self.state = WorkerState.IDLE
        self.started_at = utcnow()
        worker_registry.add(self, pipeline=pipeline)
        self._save(pipeline=pipeline)

    @atomic_pipeline
    def start_task(self, task, *, pipeline):
        """Switch from idle to busy and mark ``task`` as running on this worker.

        ``current_task_id`` must already equal ``task.id`` when this is
        called.
        """
        assert self.state == WorkerState.IDLE
        # The task should be assigned by the queue
        assert self.current_task_id == task.id
        self.state = WorkerState.BUSY
        self._save(['state'], pipeline=pipeline)
        task.set_running(self, pipeline=pipeline)

    @atomic_pipeline
    def end_task(self, task, outcome, *, pipeline):
        """Switch back to idle, clear the current task and record ``outcome``."""
        assert self.state == WorkerState.BUSY
        assert self.current_task_id == task.id
        self.state = WorkerState.IDLE
        self.current_task_id = None
        self._save(['state', 'current_task_id'], pipeline=pipeline)
        task.handle_outcome(outcome, pipeline=pipeline)

    @atomic_pipeline
    def shutdown(self, *, pipeline):
        """Cleanly retire an idle worker.

        Removes it from the registry, marks it dead and sets the worker hash
        to expire after ``settings.DEAD_WORKER_TTL``.
        """
        logger.info(f'Worker {self.description} [{self.id}] shut down')
        assert self.state == WorkerState.IDLE
        worker_registry.remove(self, pipeline=pipeline)
        self.state = WorkerState.DEAD
        self.shutdown_at = utcnow()
        self._save(['state', 'shutdown_at'], pipeline=pipeline)
        pipeline.expire(self.key, settings.DEAD_WORKER_TTL)

    @atomic_pipeline
    def died(self, *, pipeline):
        """Record that the worker died.

        Unlike :meth:`shutdown`, the worker may still hold a task; that task
        is fetched and notified via ``handle_worker_death`` before the worker
        is marked dead and its hash is set to expire.
        """
        logger.warning(f'Worker {self.description} [{self.id}] died')
        worker_registry.remove(self, pipeline=pipeline)
        self.shutdown_at = utcnow()
        if self.current_task_id:
            task = Task.fetch(self.current_task_id)
            task.handle_worker_death(pipeline=pipeline)
            self.current_task_id = None
        self.state = WorkerState.DEAD
        self._save(['state', 'current_task_id', 'shutdown_at'],
                   pipeline=pipeline)
        pipeline.expire(self.key, settings.DEAD_WORKER_TTL)

    def fetch_current_task(self):
        """Returns the currently executing task, or None if there is none."""
        if self.current_task_id:
            return Task.fetch(self.current_task_id)
        return None