diff --git a/flamenco_worker/may_i_run.py b/flamenco_worker/may_i_run.py index 5eec5ffd2d096779157cd3182ed642f4d89ccb27..2540474580a67313cc2a0190f3698dc775ba422c 100644 --- a/flamenco_worker/may_i_run.py +++ b/flamenco_worker/may_i_run.py @@ -38,11 +38,11 @@ class MayIRun: return if await self.may_i_run(task_id): - self._log.debug('Current task may run') + self._log.debug('Current task %s may run', task_id) return self._log.warning('We have to stop task %s', task_id) - await self.worker.stop_current_task() + await self.worker.stop_current_task(task_id) async def may_i_run(self, task_id: str) -> bool: """Asks the Manager whether we are still allowed to run the given task.""" diff --git a/flamenco_worker/worker.py b/flamenco_worker/worker.py index 245f7bae91ca2ceb409bc9194264b8ea3b2e5529..971d9e11975390b93f000663eb5e158c5af5ed94 100644 --- a/flamenco_worker/worker.py +++ b/flamenco_worker/worker.py @@ -284,17 +284,26 @@ class FlamencoWorker: self.single_iteration_fut = asyncio.ensure_future(self.single_iteration(delay), loop=self.loop) - async def stop_current_task(self): + async def stop_current_task(self, task_id: str): """Stops the current task by canceling the AsyncIO task. This causes a CancelledError in the self.single_iteration() function, which then takes care of the task status change and subsequent activity push. + + :param task_id: the task ID to stop. Will only perform a stop if it + matches the currently executing task. This is to avoid race + conditions. """ if not self.asyncio_execution_fut or self.asyncio_execution_fut.done(): self._log.warning('stop_current_task() called but no task is running.') return + if self.task_id != task_id: + self._log.warning('stop_current_task(%r) called, but current task is %r, not stopping', + task_id, self.task_id) + return + self._log.warning('Stopping task %s', self.task_id) self.task_is_silently_aborting = True diff --git a/tests/test_worker.py b/tests/test_worker.py index 17c53de711f78f18c9a74a22f02b1d49f12bcaca..974ca2bebe0ec692b7818558731b0cf9615d59b1 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -252,7 +252,7 @@ class TestWorkerTaskExecution(AbstractFWorkerTest): stop_called = True await asyncio.sleep(0.2) - await self.worker.stop_current_task() + await self.worker.stop_current_task(self.worker.task_id) asyncio.ensure_future(stop(), loop=self.loop) self.loop.run_until_complete(self.worker.single_iteration_fut) @@ -277,6 +277,61 @@ class TestWorkerTaskExecution(AbstractFWorkerTest): self.assertEqual(self.tuqueue.queue.call_count, 2) + def test_stop_current_task_mismatch(self): + + from tests.mock_responses import JsonResponse, CoroMock + + self.manager.post = CoroMock() + # response when fetching a task + self.manager.post.coro.return_value = JsonResponse({ + '_id': '58514d1e9837734f2e71b479', + 'job': '58514d1e9837734f2e71b477', + 'manager': '585a795698377345814d2f68', + 'project': '', + 'user': '580f8c66983773759afdb20e', + 'name': 'sleep-14-26', + 'status': 'processing', + 'priority': 50, + 'job_type': 'unittest', + 'task_type': 'sleep', + 'commands': [ + {'name': 'sleep', 'settings': {'time_in_seconds': 3}} + ] + }) + + self.worker.schedule_fetch_task() + + stop_called = False + + async def stop(): + nonlocal stop_called + stop_called = True + + await asyncio.sleep(0.2) + await self.worker.stop_current_task('other-task-id') + + asyncio.ensure_future(stop(), loop=self.loop) + self.loop.run_until_complete(self.worker.single_iteration_fut) + + self.assertTrue(stop_called) + + self.manager.post.assert_called_once_with('/task', loop=self.asyncio_loop) + self.tuqueue.queue.assert_any_call( + '/tasks/58514d1e9837734f2e71b479/update', + {'task_progress_percentage': 0, 'activity': '', + 'command_progress_percentage': 0, 'task_status': 'active', + 'current_command_idx': 0}, + ) + + # The task shouldn't be stopped, because the wrong task ID was requested to stop. + last_args, last_kwargs = self.tuqueue.queue.call_args + self.assertEqual(last_args[0], '/tasks/58514d1e9837734f2e71b479/update') + self.assertEqual(last_kwargs, {}) + self.assertIn('activity', last_args[1]) + self.assertEqual(last_args[1]['activity'], 'Task completed') + + self.assertEqual(self.tuqueue.queue.call_count, 2) + class WorkerPushToMasterTest(AbstractFWorkerTest): def test_one_activity(self):