fix(procrastinate): close Django connections around jobs

This commit is contained in:
Herculino Trotta
2026-05-02 16:15:26 -03:00
parent 8fc11b0acf
commit 34a2b6bfd4
3 changed files with 132 additions and 1 deletions

View File

@@ -1,6 +1,47 @@
import functools
import inspect
import procrastinate
from django.db import close_old_connections
_CONNECTION_CLEANUP_WRAPPED = "_wygiwyh_connection_cleanup_wrapped"
def _wrap_task_with_django_connection_cleanup(task):
if getattr(task.func, _CONNECTION_CLEANUP_WRAPPED, False):
return
func = task.func
if inspect.iscoroutinefunction(func):
@functools.wraps(func)
async def async_wrapped(*args, **kwargs):
close_old_connections()
try:
return await func(*args, **kwargs)
finally:
close_old_connections()
wrapped = async_wrapped
else:
@functools.wraps(func)
def sync_wrapped(*args, **kwargs):
close_old_connections()
try:
return func(*args, **kwargs)
finally:
close_old_connections()
wrapped = sync_wrapped
setattr(wrapped, _CONNECTION_CLEANUP_WRAPPED, True)
task.func = wrapped
def on_app_ready(app: procrastinate.App):
"""This function is ran upon procrastinate initialization."""
...
for task in set(app.tasks.values()):
_wrap_task_with_django_connection_cleanup(task)

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,89 @@
from unittest.mock import patch
import procrastinate
from django.db import connection
from django.test import SimpleTestCase, TransactionTestCase
from procrastinate.testing import InMemoryConnector
from apps.common.procrastinate import on_app_ready
def make_app_with_task(func):
app = procrastinate.App(connector=InMemoryConnector())
task = app.task(name="sample_task")(func)
return app, task
class ProcrastinateConnectionCleanupTests(SimpleTestCase):
def test_app_ready_closes_old_connections_around_sync_tasks(self):
calls = []
def sample_task(value):
calls.append(("task", value))
return value * 2
app, task = make_app_with_task(sample_task)
with patch(
"apps.common.procrastinate.close_old_connections",
create=True,
side_effect=lambda: calls.append(("cleanup", None)),
):
on_app_ready(app)
result = task.func(3)
self.assertEqual(result, 6)
self.assertEqual(
calls,
[
("cleanup", None),
("task", 3),
("cleanup", None),
],
)
def test_app_ready_closes_old_connections_when_sync_task_raises(self):
calls = []
def sample_task():
calls.append(("task", None))
raise RuntimeError("boom")
app, task = make_app_with_task(sample_task)
with patch(
"apps.common.procrastinate.close_old_connections",
create=True,
side_effect=lambda: calls.append(("cleanup", None)),
):
on_app_ready(app)
with self.assertRaises(RuntimeError):
task.func()
self.assertEqual(
calls,
[
("cleanup", None),
("task", None),
("cleanup", None),
],
)
class ProcrastinateConnectionRecoveryTests(TransactionTestCase):
def test_wrapped_task_recovers_from_closed_django_connection(self):
def sample_task():
with connection.cursor() as cursor:
cursor.execute("SELECT 1")
return cursor.fetchone()[0]
app, task = make_app_with_task(sample_task)
on_app_ready(app)
connection.ensure_connection()
connection.connection.close()
self.assertEqual(task.func(), 1)