import datetime
import logging
import random
import signal
import threading
import time
import uuid
from contextlib import suppress
import croniter
from .conf import connection, construct_redis_key, settings
from .exceptions import TaskDoesNotExist
from .queue import Queue
from .smear_dst import DstSmearingTz
from .task import Task, TaskStatus
from .utils import (
LazyObject, atomic_pipeline, decode_dict, utcformat, utcnow, utcparse)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Timezone helper built from settings.TIMEZONE. The constructor call is wrapped
# in a lambda and handed to LazyObject -- presumably so that settings are not
# read at import time (TODO confirm LazyObject semantics).
local_tz = LazyObject(lambda: DstSmearingTz(settings.TIMEZONE))
class CrontabSchedule:
    """Schedule described by a crontab expression.

    The expression is stored as given and evaluated with croniter against a
    datetime that has been converted through ``local_tz``.
    """

    def __init__(self, crontab):
        self.crontab = crontab

    def get_next(self, after):
        """Return the first crontab match following ``after``.

        ``after`` is passed through ``local_tz.from_utc`` before croniter sees
        it, and croniter's result is passed through ``local_tz.to_utc`` before
        being returned.
        """
        local_after = local_tz.from_utc(after)
        cron_iter = croniter.croniter(
            self.crontab, local_after, ret_type=datetime.datetime)
        upcoming = cron_iter.get_next()
        return local_tz.to_utc(upcoming)
# Lowercase alias of CrontabSchedule.
crontab = CrontabSchedule
def once_per_day(time_str):
    """Build a CrontabSchedule that matches once a day at ``time_str``.

    ``time_str`` must contain exactly one ``:`` separating the hour from the
    minute (e.g. ``"04:30"``); both parts are inserted verbatim into the
    crontab expression. Any other shape raises ValueError from the unpacking.
    """
    parts = time_str.split(':')
    hour, minute = parts
    expression = f'{minute} {hour} * * *'
    return CrontabSchedule(expression)
class PeriodicSchedule:
    """Schedule that fires at a fixed interval, anchored to local midnight.

    Within one local day (as defined by ``local_tz``) the run times are
    ``start_at``, ``start_at + interval``, ``start_at + 2 * interval``, ...
    seconds after midnight. The sequence starts over at ``start_at`` on every
    new day.

    Args:
        hours: Hours component of the interval.
        minutes: Minutes component of the interval.
        seconds: Seconds component of the interval.
        start_at: Offset of the first run of each day. Either a number of
            seconds after midnight, an ``"HH:MM"`` string, or ``None`` to pick
            a random offset between 0 and one interval.

    Raises:
        ValueError: If the total interval is not positive, if it is 12 hours
            or longer, or if ``start_at`` is 24 hours or later.
    """

    def __init__(self, *, hours=0, minutes=0, seconds=0, start_at=None):
        self.interval = hours * 60 * 60 + minutes * 60 + seconds
        # Validate with real exceptions rather than ``assert``: asserts are
        # stripped under ``python -O``, and a zero interval would otherwise
        # only surface later as a ZeroDivisionError inside get_next().
        if self.interval <= 0:
            raise ValueError(
                f"interval must be positive, got {self.interval} seconds")
        if self.interval >= 12 * 60 * 60:
            raise ValueError(
                f"interval must be shorter than 12 hours, got {self.interval} seconds")

        if start_at is None:
            start_at = random.uniform(0, self.interval)
        elif isinstance(start_at, str):
            start_hours, start_minutes = start_at.split(':')
            start_at = int(start_hours) * 60 * 60 + int(start_minutes) * 60
        if start_at >= 24 * 60 * 60:
            raise ValueError(
                f"start_at must be earlier than 24 hours, got {start_at} seconds")
        self.start_at = start_at

    def get_next(self, after):
        """Return the first run time strictly after ``after``, in UTC.

        ``after`` must be an aware datetime; it is converted to the local
        timezone before the time of day is evaluated.
        """
        after = after.astimezone(local_tz.tz)
        # Midnight of the local day that contains ``after``.
        midnight = local_tz.tz.localize(
            after.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=None),
            is_dst=True)
        # Seconds elapsed since that midnight.
        after_time = (after - midnight).total_seconds()
        if self.start_at > after_time:
            # The first run of the day has not happened yet.
            next_time = self.start_at
        else:
            # Advance to the next multiple of the interval counted from
            # start_at; ``remaining`` lies in (0, interval], so the result is
            # always later than ``after``.
            since_start = after_time - self.start_at
            remaining = self.interval - since_start % self.interval
            next_time = after_time + remaining
        next_run = local_tz.tz.normalize(midnight + datetime.timedelta(seconds=next_time))
        if next_run.day != after.day:
            # The candidate fell onto another day: restart at start_at after
            # the following midnight. That midnight is derived from 23:59:59
            # of the current day plus one second instead of "midnight + 24h"
            # (NOTE(review): presumably to cope with days that are not 24
            # hours long -- confirm against DstSmearingTz).
            next_midnight = local_tz.tz.localize(
                after.replace(hour=23, minute=59, second=59, microsecond=0, tzinfo=None),
                is_dst=False) + datetime.timedelta(seconds=1)
            next_run = next_midnight + datetime.timedelta(seconds=self.start_at)
        return next_run.astimezone(datetime.timezone.utc)
# Alias of PeriodicSchedule.
run_every = PeriodicSchedule
class SchedulerEntry:
    """A single configured schedule entry together with its stored run state.

    ``config`` is a mapping with the required keys ``'task'`` and
    ``'schedule'`` and the optional keys ``'args'``, ``'kwargs'``,
    ``'queue'`` (defaults to ``settings.SCHEDULER_QUEUE``) and
    ``'singleton'`` (defaults to ``True``). The previous run time and the id
    of the previously enqueued task are kept in a Redis hash under
    ``schedule_entry:<id>``.
    """

    def __init__(self, id, config):
        self.id = id
        self.key = construct_redis_key(f"schedule_entry:{self.id}")
        self.singleton = config.get('singleton', True)
        task_func = config['task']
        task_args = config.get('args', ())
        task_kwargs = config.get('kwargs', {})
        self.task_template = [task_func, task_args, task_kwargs]
        queue_name = config.get('queue', settings.SCHEDULER_QUEUE)
        self.queue = Queue(queue_name)
        self.schedule = config['schedule']
        self.last_save = None

        # Restore persisted state; with nothing stored, the entry behaves as
        # if it had last run right now.
        stored = decode_dict(connection.hgetall(self.key))
        stored_prev_run = stored.get("prev_run")
        if stored_prev_run:
            self.prev_run = utcparse(stored_prev_run)
        else:
            self.prev_run = utcnow()
        self.prev_task_id = stored.get("prev_task_id")
        self.next_run = self.schedule.get_next(self.prev_run)

        # Instantiate a throwaway Task so that an invalid task configuration
        # fails here rather than when the entry is due.
        Task(*self.task_template)

    @atomic_pipeline
    def save(self, *, pipeline):
        """Queue the writes that persist ``prev_run`` and ``prev_task_id``.

        The hash expires after 24 hours or five times
        ``settings.SCHEDULER_MAX_CATCHUP`` seconds, whichever is longer.
        """
        pipeline.hset(self.key, "prev_run", utcformat(self.prev_run))
        if not self.prev_task_id:
            pipeline.hdel(self.key, "prev_task_id")
        else:
            pipeline.hset(self.key, "prev_task_id", self.prev_task_id)
        one_day = 24 * 60 * 60
        ttl = max(one_day, settings.SCHEDULER_MAX_CATCHUP * 5)
        pipeline.expire(self.key, ttl)
        self.last_save = utcnow()

    @atomic_pipeline
    def process(self, now, *, pipeline):
        """Enqueue whatever is due at ``now`` and update ``next_run``.

        Runs older than ``settings.SCHEDULER_MAX_CATCHUP`` seconds are not
        caught up. A singleton entry is enqueued at most once per call and
        skipped while its previous task is still pending; a non-singleton
        entry is enqueued once for every run time that has passed.
        """
        catchup_limit = now - datetime.timedelta(seconds=settings.SCHEDULER_MAX_CATCHUP)
        reference = max(catchup_limit, self.prev_run)
        self.next_run = self.schedule.get_next(reference)

        if self.next_run > now:
            # Nothing is due. Still rewrite the stored state when it has never
            # been saved by this instance or the last save is old.
            needs_save = (
                not self.last_save or
                (now - self.last_save).total_seconds() >= settings.SCHEDULER_MAX_CATCHUP)
            if needs_save:
                self.save(pipeline=pipeline)
            return

        self.prev_run = now
        if not self.singleton:
            while self.next_run <= now:
                self.enqueue(pipeline=pipeline)
                self.next_run = self.schedule.get_next(self.next_run)
        else:
            self.next_run = self.schedule.get_next(now)
            if not self.is_enqueued():
                self.enqueue(pipeline=pipeline)
            else:
                logger.info(f'Schedule entry "{self.id}" already enqueued or running, skipping')
        self.save(pipeline=pipeline)

    def is_enqueued(self):
        """Return True if the previously enqueued task exists and is neither
        finished nor failed."""
        if not self.prev_task_id:
            return False
        try:
            prev_task = Task.fetch(self.prev_task_id)
            pending = prev_task.status not in (TaskStatus.FINISHED, TaskStatus.FAILED)
        except TaskDoesNotExist:
            # The task is gone, so it cannot be pending anymore.
            return False
        return pending

    @atomic_pipeline
    def enqueue(self, *, pipeline):
        """Enqueue the configured task on this entry's queue, remember its id
        as ``prev_task_id`` and return the new task."""
        new_task = self.queue.enqueue_call(*self.task_template, pipeline=pipeline)
        self.prev_task_id = new_task.id
        return new_task
class Scheduler:
    """Main loop that processes every entry of ``settings.SCHEDULE``."""

    def __init__(self):
        self.schedule = [
            SchedulerEntry(entry_id, entry_config)
            for entry_id, entry_config in settings.SCHEDULE.items()
        ]
        self.shutdown_requested = threading.Event()

    def setup_signal_handler(self):
        """Install a SIGINT/SIGTERM handler that requests a shutdown."""
        def stop(signum, frame):
            logger.info('Initiating redis_tasks scheduler shutdown')
            self.shutdown_requested.set()

        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, stop)

    def run(self):
        """Process the schedule until a shutdown is requested.

        Returns immediately (after logging an error) when no schedule entries
        are configured. Otherwise the scheduler mutex is held for the whole
        loop and extended on every iteration; the loop sleeps until the
        earliest upcoming run, but never longer than the heartbeat period.
        """
        if not self.schedule:
            logger.error("No schedule configured, nothing to do")
            return

        self.setup_signal_handler()

        heartbeat_seconds = 10
        heartbeat = datetime.timedelta(seconds=heartbeat_seconds)
        # The lock timeout exceeds the heartbeat period by two seconds, so a
        # loop that extends it once per heartbeat keeps holding the lock.
        with Mutex(timeout=heartbeat_seconds + 2) as mutex:
            logger.info('redis_tasks scheduler started')
            while not self.shutdown_requested.is_set():
                mutex.extend()
                now = utcnow()
                for entry in self.schedule:
                    entry.process(now)
                earliest_run = min(entry.next_run for entry in self.schedule)
                wake_at = min(earliest_run, now + heartbeat)
                # Waiting on the event (instead of sleeping) lets a shutdown
                # request interrupt the wait.
                wait_for = (wake_at - utcnow()).total_seconds()
                self.shutdown_requested.wait(wait_for)
        logger.info('redis_tasks scheduler shut down')
def scheduler_main():
    """Entry point: build a Scheduler and run it until it returns."""
    scheduler = Scheduler()
    scheduler.run()
class Mutex:
    """Redis-backed lock used to keep two schedulers from running at once.

    The lock is a single Redis key holding a random token and an expiry of
    ``timeout`` seconds. The holder must call :meth:`extend` before the expiry
    elapses to keep the lock. Used as a context manager, the lock is acquired
    on entry and released on exit.

    Args:
        timeout: Lifetime of the lock in seconds; applied on acquisition and
            on every :meth:`extend`.
    """

    expire_script = None

    def __init__(self, *, timeout):
        self.key = construct_redis_key('scheduler')
        self.timeout = timeout
        # Token written by the last successful acquire(); None while not held.
        self.token = None
        # KEYS[1]: lock key, ARGV[1]: token, ARGV[2]: milliseconds
        # return 1 if the lock was held and the expire executed, otherwise 0
        self.expire_script = connection.register_script("""
if redis.call('get', KEYS[1]) == ARGV[1] then
return redis.call('pexpire', KEYS[1], ARGV[2])
else
return 0
end""")

    def __enter__(self):
        """Acquire the lock, waiting ``timeout + 1`` seconds if it is taken.

        Raises:
            RuntimeError: If the lock is still taken after waiting.
        """
        if not self.acquire(wait=False):
            # One second more than the lock's lifetime: a lock left behind by
            # a dead instance has expired by then.
            wait_for = self.timeout + 1
            logger.warning("Found signs of an already running scheduler instance, "
                           f"waiting {wait_for} seconds for it to disappear")
            try:
                self.acquire(wait=wait_for)
            except TimeoutError as err:
                raise RuntimeError("redis_tasks scheduler already running") from err
        return self

    def __exit__(self, *exc):
        """Release the lock if this instance holds it."""
        if self.token:
            # Setting the expiry to 0 ms makes Redis drop the key, but only if
            # it still carries our token (checked inside the script).
            self.expire_script(keys=[self.key], args=[self.token, 0])
            self.token = None

    def acquire(self, wait=None):
        """Try to take the lock.

        Args:
            wait: Falsy (``None``, ``False``, ``0``) for a single attempt;
                otherwise the number of seconds to keep retrying, with one
                attempt every 0.1 seconds.

        Returns:
            True if the lock was taken, False if a single attempt failed.

        Raises:
            TimeoutError: If ``wait`` seconds passed without success.
        """
        token = str(uuid.uuid1()).encode()
        # Only compute a deadline when waiting was requested. Adding the
        # default ``wait=None`` to a float used to raise TypeError here.
        stop_trying_at = time.time() + wait if wait else None
        while True:
            # SET NX with an expiry: succeeds only when the key is absent.
            acquired = connection.set(self.key, token, nx=True, px=int(self.timeout * 1000))
            if acquired:
                self.token = token
                return True
            if not wait:
                return False
            if time.time() > stop_trying_at:
                raise TimeoutError
            time.sleep(0.1)

    def extend(self):
        """Reset the lock's expiry to ``timeout`` seconds from now.

        Raises:
            RuntimeError: If the key no longer holds this instance's token.
        """
        if not self.expire_script(keys=[self.key], args=[self.token, int(self.timeout * 1000)]):
            raise RuntimeError("Cannot refresh a lock that's no longer owned")