diff --git a/app/apps/common/procrastinate.py b/app/apps/common/procrastinate.py index c6f3c3d..8e27e28 100644 --- a/app/apps/common/procrastinate.py +++ b/app/apps/common/procrastinate.py @@ -1,6 +1,47 @@ +import functools +import inspect + import procrastinate +from django.db import close_old_connections + + +_CONNECTION_CLEANUP_WRAPPED = "_wygiwyh_connection_cleanup_wrapped" + + +def _wrap_task_with_django_connection_cleanup(task): + if getattr(task.func, _CONNECTION_CLEANUP_WRAPPED, False): + return + + func = task.func + + if inspect.iscoroutinefunction(func): + + @functools.wraps(func) + async def async_wrapped(*args, **kwargs): + close_old_connections() + try: + return await func(*args, **kwargs) + finally: + close_old_connections() + + wrapped = async_wrapped + else: + + @functools.wraps(func) + def sync_wrapped(*args, **kwargs): + close_old_connections() + try: + return func(*args, **kwargs) + finally: + close_old_connections() + + wrapped = sync_wrapped + + setattr(wrapped, _CONNECTION_CLEANUP_WRAPPED, True) + task.func = wrapped def on_app_ready(app: procrastinate.App): """This function is ran upon procrastinate initialization.""" - ... + for task in set(app.tasks.values()): + _wrap_task_with_django_connection_cleanup(task) diff --git a/app/apps/common/tests/__init__.py b/app/apps/common/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/app/apps/common/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/app/apps/common/tests/test_procrastinate.py b/app/apps/common/tests/test_procrastinate.py new file mode 100644 index 0000000..be175bf --- /dev/null +++ b/app/apps/common/tests/test_procrastinate.py @@ -0,0 +1,89 @@ +from unittest.mock import patch + +import procrastinate +from django.db import connection +from django.test import SimpleTestCase, TransactionTestCase +from procrastinate.testing import InMemoryConnector + +from apps.common.procrastinate import on_app_ready + + +def make_app_with_task(func): + app = procrastinate.App(connector=InMemoryConnector()) + task = app.task(name="sample_task")(func) + + return app, task + + +class ProcrastinateConnectionCleanupTests(SimpleTestCase): + def test_app_ready_closes_old_connections_around_sync_tasks(self): + calls = [] + + def sample_task(value): + calls.append(("task", value)) + return value * 2 + + app, task = make_app_with_task(sample_task) + + with patch( + "apps.common.procrastinate.close_old_connections", + create=True, + side_effect=lambda: calls.append(("cleanup", None)), + ): + on_app_ready(app) + + result = task.func(3) + + self.assertEqual(result, 6) + self.assertEqual( + calls, + [ + ("cleanup", None), + ("task", 3), + ("cleanup", None), + ], + ) + + def test_app_ready_closes_old_connections_when_sync_task_raises(self): + calls = [] + + def sample_task(): + calls.append(("task", None)) + raise RuntimeError("boom") + + app, task = make_app_with_task(sample_task) + + with patch( + "apps.common.procrastinate.close_old_connections", + create=True, + side_effect=lambda: calls.append(("cleanup", None)), + ): + on_app_ready(app) + + with self.assertRaises(RuntimeError): + task.func() + + self.assertEqual( + calls, + [ + ("cleanup", None), + ("task", None), + ("cleanup", None), + ], + ) + + +class ProcrastinateConnectionRecoveryTests(TransactionTestCase): + def test_wrapped_task_recovers_from_closed_django_connection(self): + def sample_task(): + with connection.cursor() as cursor: + cursor.execute("SELECT 1") + return cursor.fetchone()[0] + + app, task = make_app_with_task(sample_task) + on_app_ready(app) + + connection.ensure_connection() + connection.connection.close() + + self.assertEqual(task.func(), 1)