From 34a2b6bfd4f6b72e767f3ec2c3a4285c3e37657d Mon Sep 17 00:00:00 2001 From: Herculino Trotta Date: Sat, 2 May 2026 16:15:26 -0300 Subject: [PATCH 1/4] fix(procrastinate): close Django connections around jobs --- app/apps/common/procrastinate.py | 43 +++++++++- app/apps/common/tests/__init__.py | 1 + app/apps/common/tests/test_procrastinate.py | 89 +++++++++++++++++++++ 3 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 app/apps/common/tests/__init__.py create mode 100644 app/apps/common/tests/test_procrastinate.py diff --git a/app/apps/common/procrastinate.py b/app/apps/common/procrastinate.py index c6f3c3d..8e27e28 100644 --- a/app/apps/common/procrastinate.py +++ b/app/apps/common/procrastinate.py @@ -1,6 +1,47 @@ +import functools +import inspect + import procrastinate +from django.db import close_old_connections + + +_CONNECTION_CLEANUP_WRAPPED = "_wygiwyh_connection_cleanup_wrapped" + + +def _wrap_task_with_django_connection_cleanup(task): + if getattr(task.func, _CONNECTION_CLEANUP_WRAPPED, False): + return + + func = task.func + + if inspect.iscoroutinefunction(func): + + @functools.wraps(func) + async def async_wrapped(*args, **kwargs): + close_old_connections() + try: + return await func(*args, **kwargs) + finally: + close_old_connections() + + wrapped = async_wrapped + else: + + @functools.wraps(func) + def sync_wrapped(*args, **kwargs): + close_old_connections() + try: + return func(*args, **kwargs) + finally: + close_old_connections() + + wrapped = sync_wrapped + + setattr(wrapped, _CONNECTION_CLEANUP_WRAPPED, True) + task.func = wrapped def on_app_ready(app: procrastinate.App): """This function is ran upon procrastinate initialization.""" - ... + for task in set(app.tasks.values()): + _wrap_task_with_django_connection_cleanup(task) diff --git a/app/apps/common/tests/__init__.py b/app/apps/common/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/app/apps/common/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/app/apps/common/tests/test_procrastinate.py b/app/apps/common/tests/test_procrastinate.py new file mode 100644 index 0000000..be175bf --- /dev/null +++ b/app/apps/common/tests/test_procrastinate.py @@ -0,0 +1,89 @@ +from unittest.mock import patch + +import procrastinate +from django.db import connection +from django.test import SimpleTestCase, TransactionTestCase +from procrastinate.testing import InMemoryConnector + +from apps.common.procrastinate import on_app_ready + + +def make_app_with_task(func): + app = procrastinate.App(connector=InMemoryConnector()) + task = app.task(name="sample_task")(func) + + return app, task + + +class ProcrastinateConnectionCleanupTests(SimpleTestCase): + def test_app_ready_closes_old_connections_around_sync_tasks(self): + calls = [] + + def sample_task(value): + calls.append(("task", value)) + return value * 2 + + app, task = make_app_with_task(sample_task) + + with patch( + "apps.common.procrastinate.close_old_connections", + create=True, + side_effect=lambda: calls.append(("cleanup", None)), + ): + on_app_ready(app) + + result = task.func(3) + + self.assertEqual(result, 6) + self.assertEqual( + calls, + [ + ("cleanup", None), + ("task", 3), + ("cleanup", None), + ], + ) + + def test_app_ready_closes_old_connections_when_sync_task_raises(self): + calls = [] + + def sample_task(): + calls.append(("task", None)) + raise RuntimeError("boom") + + app, task = make_app_with_task(sample_task) + + with patch( + "apps.common.procrastinate.close_old_connections", + create=True, + side_effect=lambda: calls.append(("cleanup", None)), + ): + on_app_ready(app) + + with self.assertRaises(RuntimeError): + task.func() + + self.assertEqual( + calls, + [ + ("cleanup", None), + ("task", None), + ("cleanup", None), + ], + ) + + +class ProcrastinateConnectionRecoveryTests(TransactionTestCase): + def test_wrapped_task_recovers_from_closed_django_connection(self): + def sample_task(): + with connection.cursor() as cursor: + cursor.execute("SELECT 1") + return cursor.fetchone()[0] + + app, task = make_app_with_task(sample_task) + on_app_ready(app) + + connection.ensure_connection() + connection.connection.close() + + self.assertEqual(task.func(), 1) From 78171183ccbd8c0a357960e3a1893f116bdd369e Mon Sep 17 00:00:00 2001 From: Herculino Trotta Date: Sat, 2 May 2026 16:15:48 -0300 Subject: [PATCH 2/4] test(currencies): avoid test discovery collision --- app/apps/currencies/{tests.py => tests/test_models.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename app/apps/currencies/{tests.py => tests/test_models.py} (100%) diff --git a/app/apps/currencies/tests.py b/app/apps/currencies/tests/test_models.py similarity index 100% rename from app/apps/currencies/tests.py rename to app/apps/currencies/tests/test_models.py From 63c69e5c6acd640e7fc26838bd5915c81455c253 Mon Sep 17 00:00:00 2001 From: Herculino Trotta Date: Sat, 2 May 2026 16:16:08 -0300 Subject: [PATCH 3/4] test(api): expect unauthorized for anonymous requests --- app/apps/api/tests/test_accounts.py | 4 ++-- app/apps/api/tests/test_imports.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/app/apps/api/tests/test_accounts.py b/app/apps/api/tests/test_accounts.py index 50e7a94..9da60c9 100644 --- a/app/apps/api/tests/test_accounts.py +++ b/app/apps/api/tests/test_accounts.py @@ -90,10 +90,10 @@ class AccountBalanceAPITests(TestCase): self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_get_balance_unauthenticated(self): - """Test unauthenticated request returns 403""" + """Test unauthenticated request returns 401""" unauthenticated_client = APIClient() response = unauthenticated_client.get( f"/api/accounts/{self.account.id}/balance/" ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) diff --git a/app/apps/api/tests/test_imports.py b/app/apps/api/tests/test_imports.py index 9509e4d..4b41cd1 100644 --- a/app/apps/api/tests/test_imports.py +++ b/app/apps/api/tests/test_imports.py @@ -159,7 +159,7 @@ column_mapping: self.assertIn("import_run_id", response.data) def test_unauthenticated_request(self): - """Test unauthenticated request returns 403""" + """Test unauthenticated request returns 401""" unauthenticated_client = APIClient() csv_content = b"date,description,amount\n2025-01-01,Test,100" @@ -173,7 +173,7 @@ column_mapping: format="multipart", ) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) @override_settings( @@ -266,11 +266,11 @@ column_mapping: self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_profiles_unauthenticated(self): - """Test unauthenticated request returns 403""" + """Test unauthenticated request returns 401""" unauthenticated_client = APIClient() response = unauthenticated_client.get("/api/import/profiles/") - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) @override_settings( @@ -397,8 +397,8 @@ column_mapping: self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_runs_unauthenticated(self): - """Test unauthenticated request returns 403""" + """Test unauthenticated request returns 401""" unauthenticated_client = APIClient() response = unauthenticated_client.get("/api/import/runs/") - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) From d72ff3cdf58012ef48d7148189375cbe29d070ac Mon Sep 17 00:00:00 2001 From: Herculino Trotta Date: Sat, 2 May 2026 16:16:27 -0300 Subject: [PATCH 4/4] fix(rules): allow category expressions to clear categories --- app/apps/rules/tasks.py | 8 ++- app/apps/rules/tests/__init__.py | 1 + app/apps/rules/tests/test_tasks.py | 82 ++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 app/apps/rules/tests/__init__.py create mode 100644 app/apps/rules/tests/test_tasks.py diff --git a/app/apps/rules/tasks.py b/app/apps/rules/tasks.py index 79e2478..5d3ee4a 100644 --- a/app/apps/rules/tasks.py +++ b/app/apps/rules/tasks.py @@ -365,7 +365,9 @@ def check_for_transaction_rules( if processed_action.set_category: value = simple.eval(processed_action.set_category) - if isinstance(value, int): + if value is None: + transaction.category = None + elif isinstance(value, int): transaction.category = TransactionCategory.objects.get(id=value) else: transaction.category = TransactionCategory.objects.get(name=value) @@ -458,7 +460,9 @@ def check_for_transaction_rules( transaction.account = account elif field == TransactionRuleAction.Field.category: - if isinstance(new_value, int): + if new_value is None: + transaction.category = None + elif isinstance(new_value, int): category = TransactionCategory.objects.get(id=new_value) transaction.category = category elif isinstance(new_value, str): diff --git a/app/apps/rules/tests/__init__.py b/app/apps/rules/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/app/apps/rules/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/app/apps/rules/tests/test_tasks.py b/app/apps/rules/tests/test_tasks.py new file mode 100644 index 0000000..7516426 --- /dev/null +++ b/app/apps/rules/tests/test_tasks.py @@ -0,0 +1,82 @@ +from datetime import date +from decimal import Decimal +from unittest.mock import patch + +from django.contrib.auth import get_user_model +from django.test import TransactionTestCase + +from apps.accounts.models import Account +from apps.currencies.models import Currency +from apps.rules.models import TransactionRule, UpdateOrCreateTransactionRuleAction +from apps.rules.tasks import check_for_transaction_rules +from apps.transactions.models import Transaction + + +def run_check_for_transaction_rules_without_worker_wrapper(**kwargs): + task_func = check_for_transaction_rules.func + task_func = getattr(task_func, "__wrapped__", task_func) + + return task_func(**kwargs) + + +class CheckForTransactionRulesTests(TransactionTestCase): + def setUp(self): + User = get_user_model() + self.user = User.objects.create_user( + email="rules@example.com", + password="testpass123", + ) + self.currency = Currency.objects.create( + code="USD", + name="US Dollar", + decimal_places=2, + ) + self.account = Account.objects.create( + name="Main Account", + currency=self.currency, + owner=self.user, + ) + + @patch("apps.rules.signals.check_for_transaction_rules.defer") + def test_update_or_create_action_can_clear_category_from_none_expression( + self, mock_defer + ): + source_transaction = Transaction.objects.create( + account=self.account, + type=Transaction.Type.EXPENSE, + amount=Decimal("10.00"), + date=date(2026, 5, 4), + reference_date=date(2026, 5, 1), + description="Source without category", + category=None, + owner=self.user, + ) + rule = TransactionRule.objects.create( + active=True, + on_create=False, + on_update=True, + name="Copy transaction", + trigger="True", + owner=self.user, + ) + UpdateOrCreateTransactionRuleAction.objects.create( + rule=rule, + set_account="account_id", + set_type="'EX'", + set_date="date", + set_reference_date="reference_date", + set_amount="amount", + set_description="'Generated transaction'", + set_category="category_name", + ) + + run_check_for_transaction_rules_without_worker_wrapper( + instance_id=source_transaction.id, + user_id=self.user.id, + signal="transaction_updated", + ) + + generated_transaction = Transaction.objects.get( + description="Generated transaction" + ) + self.assertIsNone(generated_transaction.category)