Merge remote-tracking branch 'origin/main'

This commit is contained in:
Weblate
2026-05-02 19:30:37 +00:00
9 changed files with 229 additions and 11 deletions

View File

@@ -90,10 +90,10 @@ class AccountBalanceAPITests(TestCase):
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_get_balance_unauthenticated(self):
"""Test unauthenticated request returns 403"""
"""Test unauthenticated request returns 401"""
unauthenticated_client = APIClient()
response = unauthenticated_client.get(
f"/api/accounts/{self.account.id}/balance/"
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

View File

@@ -159,7 +159,7 @@ column_mapping:
self.assertIn("import_run_id", response.data)
def test_unauthenticated_request(self):
"""Test unauthenticated request returns 403"""
"""Test unauthenticated request returns 401"""
unauthenticated_client = APIClient()
csv_content = b"date,description,amount\n2025-01-01,Test,100"
@@ -173,7 +173,7 @@ column_mapping:
format="multipart",
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@override_settings(
@@ -266,11 +266,11 @@ column_mapping:
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_profiles_unauthenticated(self):
"""Test unauthenticated request returns 403"""
"""Test unauthenticated request returns 401"""
unauthenticated_client = APIClient()
response = unauthenticated_client.get("/api/import/profiles/")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@override_settings(
@@ -397,8 +397,8 @@ column_mapping:
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_runs_unauthenticated(self):
"""Test unauthenticated request returns 403"""
"""Test unauthenticated request returns 401"""
unauthenticated_client = APIClient()
response = unauthenticated_client.get("/api/import/runs/")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

View File

@@ -1,6 +1,47 @@
import functools
import inspect
import procrastinate
from django.db import close_old_connections
_CONNECTION_CLEANUP_WRAPPED = "_wygiwyh_connection_cleanup_wrapped"
def _wrap_task_with_django_connection_cleanup(task):
if getattr(task.func, _CONNECTION_CLEANUP_WRAPPED, False):
return
func = task.func
if inspect.iscoroutinefunction(func):
@functools.wraps(func)
async def async_wrapped(*args, **kwargs):
close_old_connections()
try:
return await func(*args, **kwargs)
finally:
close_old_connections()
wrapped = async_wrapped
else:
@functools.wraps(func)
def sync_wrapped(*args, **kwargs):
close_old_connections()
try:
return func(*args, **kwargs)
finally:
close_old_connections()
wrapped = sync_wrapped
setattr(wrapped, _CONNECTION_CLEANUP_WRAPPED, True)
task.func = wrapped
def on_app_ready(app: procrastinate.App):
"""This function is ran upon procrastinate initialization."""
...
for task in set(app.tasks.values()):
_wrap_task_with_django_connection_cleanup(task)

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,89 @@
from unittest.mock import patch
import procrastinate
from django.db import connection
from django.test import SimpleTestCase, TransactionTestCase
from procrastinate.testing import InMemoryConnector
from apps.common.procrastinate import on_app_ready
def make_app_with_task(func):
app = procrastinate.App(connector=InMemoryConnector())
task = app.task(name="sample_task")(func)
return app, task
class ProcrastinateConnectionCleanupTests(SimpleTestCase):
def test_app_ready_closes_old_connections_around_sync_tasks(self):
calls = []
def sample_task(value):
calls.append(("task", value))
return value * 2
app, task = make_app_with_task(sample_task)
with patch(
"apps.common.procrastinate.close_old_connections",
create=True,
side_effect=lambda: calls.append(("cleanup", None)),
):
on_app_ready(app)
result = task.func(3)
self.assertEqual(result, 6)
self.assertEqual(
calls,
[
("cleanup", None),
("task", 3),
("cleanup", None),
],
)
def test_app_ready_closes_old_connections_when_sync_task_raises(self):
calls = []
def sample_task():
calls.append(("task", None))
raise RuntimeError("boom")
app, task = make_app_with_task(sample_task)
with patch(
"apps.common.procrastinate.close_old_connections",
create=True,
side_effect=lambda: calls.append(("cleanup", None)),
):
on_app_ready(app)
with self.assertRaises(RuntimeError):
task.func()
self.assertEqual(
calls,
[
("cleanup", None),
("task", None),
("cleanup", None),
],
)
class ProcrastinateConnectionRecoveryTests(TransactionTestCase):
def test_wrapped_task_recovers_from_closed_django_connection(self):
def sample_task():
with connection.cursor() as cursor:
cursor.execute("SELECT 1")
return cursor.fetchone()[0]
app, task = make_app_with_task(sample_task)
on_app_ready(app)
connection.ensure_connection()
connection.connection.close()
self.assertEqual(task.func(), 1)

View File

@@ -365,7 +365,9 @@ def check_for_transaction_rules(
if processed_action.set_category:
value = simple.eval(processed_action.set_category)
if isinstance(value, int):
if value is None:
transaction.category = None
elif isinstance(value, int):
transaction.category = TransactionCategory.objects.get(id=value)
else:
transaction.category = TransactionCategory.objects.get(name=value)
@@ -458,7 +460,9 @@ def check_for_transaction_rules(
transaction.account = account
elif field == TransactionRuleAction.Field.category:
if isinstance(new_value, int):
if new_value is None:
transaction.category = None
elif isinstance(new_value, int):
category = TransactionCategory.objects.get(id=new_value)
transaction.category = category
elif isinstance(new_value, str):

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,82 @@
from datetime import date
from decimal import Decimal
from unittest.mock import patch
from django.contrib.auth import get_user_model
from django.test import TransactionTestCase
from apps.accounts.models import Account
from apps.currencies.models import Currency
from apps.rules.models import TransactionRule, UpdateOrCreateTransactionRuleAction
from apps.rules.tasks import check_for_transaction_rules
from apps.transactions.models import Transaction
def run_check_for_transaction_rules_without_worker_wrapper(**kwargs):
task_func = check_for_transaction_rules.func
task_func = getattr(task_func, "__wrapped__", task_func)
return task_func(**kwargs)
class CheckForTransactionRulesTests(TransactionTestCase):
def setUp(self):
User = get_user_model()
self.user = User.objects.create_user(
email="rules@example.com",
password="testpass123",
)
self.currency = Currency.objects.create(
code="USD",
name="US Dollar",
decimal_places=2,
)
self.account = Account.objects.create(
name="Main Account",
currency=self.currency,
owner=self.user,
)
@patch("apps.rules.signals.check_for_transaction_rules.defer")
def test_update_or_create_action_can_clear_category_from_none_expression(
self, mock_defer
):
source_transaction = Transaction.objects.create(
account=self.account,
type=Transaction.Type.EXPENSE,
amount=Decimal("10.00"),
date=date(2026, 5, 4),
reference_date=date(2026, 5, 1),
description="Source without category",
category=None,
owner=self.user,
)
rule = TransactionRule.objects.create(
active=True,
on_create=False,
on_update=True,
name="Copy transaction",
trigger="True",
owner=self.user,
)
UpdateOrCreateTransactionRuleAction.objects.create(
rule=rule,
set_account="account_id",
set_type="'EX'",
set_date="date",
set_reference_date="reference_date",
set_amount="amount",
set_description="'Generated transaction'",
set_category="category_name",
)
run_check_for_transaction_rules_without_worker_wrapper(
instance_id=source_transaction.id,
user_id=self.user.id,
signal="transaction_updated",
)
generated_transaction = Transaction.objects.get(
description="Generated transaction"
)
self.assertIsNone(generated_transaction.category)