mirror of
https://github.com/eitchtee/WYGIWYH.git
synced 2026-05-03 14:04:27 +02:00
fix(procrastinate): close Django connections around jobs
This commit is contained in:
@@ -1,6 +1,47 @@
|
||||
import functools
|
||||
import inspect
|
||||
|
||||
import procrastinate
|
||||
from django.db import close_old_connections
|
||||
|
||||
|
||||
_CONNECTION_CLEANUP_WRAPPED = "_wygiwyh_connection_cleanup_wrapped"
|
||||
|
||||
|
||||
def _wrap_task_with_django_connection_cleanup(task):
|
||||
if getattr(task.func, _CONNECTION_CLEANUP_WRAPPED, False):
|
||||
return
|
||||
|
||||
func = task.func
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapped(*args, **kwargs):
|
||||
close_old_connections()
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
finally:
|
||||
close_old_connections()
|
||||
|
||||
wrapped = async_wrapped
|
||||
else:
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapped(*args, **kwargs):
|
||||
close_old_connections()
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
close_old_connections()
|
||||
|
||||
wrapped = sync_wrapped
|
||||
|
||||
setattr(wrapped, _CONNECTION_CLEANUP_WRAPPED, True)
|
||||
task.func = wrapped
|
||||
|
||||
|
||||
def on_app_ready(app: procrastinate.App):
|
||||
"""This function is ran upon procrastinate initialization."""
|
||||
...
|
||||
for task in set(app.tasks.values()):
|
||||
_wrap_task_with_django_connection_cleanup(task)
|
||||
|
||||
1
app/apps/common/tests/__init__.py
Normal file
1
app/apps/common/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
89
app/apps/common/tests/test_procrastinate.py
Normal file
89
app/apps/common/tests/test_procrastinate.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import procrastinate
|
||||
from django.db import connection
|
||||
from django.test import SimpleTestCase, TransactionTestCase
|
||||
from procrastinate.testing import InMemoryConnector
|
||||
|
||||
from apps.common.procrastinate import on_app_ready
|
||||
|
||||
|
||||
def make_app_with_task(func):
|
||||
app = procrastinate.App(connector=InMemoryConnector())
|
||||
task = app.task(name="sample_task")(func)
|
||||
|
||||
return app, task
|
||||
|
||||
|
||||
class ProcrastinateConnectionCleanupTests(SimpleTestCase):
|
||||
def test_app_ready_closes_old_connections_around_sync_tasks(self):
|
||||
calls = []
|
||||
|
||||
def sample_task(value):
|
||||
calls.append(("task", value))
|
||||
return value * 2
|
||||
|
||||
app, task = make_app_with_task(sample_task)
|
||||
|
||||
with patch(
|
||||
"apps.common.procrastinate.close_old_connections",
|
||||
create=True,
|
||||
side_effect=lambda: calls.append(("cleanup", None)),
|
||||
):
|
||||
on_app_ready(app)
|
||||
|
||||
result = task.func(3)
|
||||
|
||||
self.assertEqual(result, 6)
|
||||
self.assertEqual(
|
||||
calls,
|
||||
[
|
||||
("cleanup", None),
|
||||
("task", 3),
|
||||
("cleanup", None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_app_ready_closes_old_connections_when_sync_task_raises(self):
|
||||
calls = []
|
||||
|
||||
def sample_task():
|
||||
calls.append(("task", None))
|
||||
raise RuntimeError("boom")
|
||||
|
||||
app, task = make_app_with_task(sample_task)
|
||||
|
||||
with patch(
|
||||
"apps.common.procrastinate.close_old_connections",
|
||||
create=True,
|
||||
side_effect=lambda: calls.append(("cleanup", None)),
|
||||
):
|
||||
on_app_ready(app)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
task.func()
|
||||
|
||||
self.assertEqual(
|
||||
calls,
|
||||
[
|
||||
("cleanup", None),
|
||||
("task", None),
|
||||
("cleanup", None),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ProcrastinateConnectionRecoveryTests(TransactionTestCase):
|
||||
def test_wrapped_task_recovers_from_closed_django_connection(self):
|
||||
def sample_task():
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("SELECT 1")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
app, task = make_app_with_task(sample_task)
|
||||
on_app_ready(app)
|
||||
|
||||
connection.ensure_connection()
|
||||
connection.connection.close()
|
||||
|
||||
self.assertEqual(task.func(), 1)
|
||||
Reference in New Issue
Block a user