feat: replace action row with a FAB

This commit is contained in:
Herculino Trotta
2025-06-15 23:12:22 -03:00
parent 02f6bb0c29
commit 1d3dc3f5a2
9 changed files with 1358 additions and 501 deletions

View File

@@ -6,19 +6,28 @@ from django.core.exceptions import ValidationError
from apps.accounts.models import Account, AccountGroup from apps.accounts.models import Account, AccountGroup
from apps.currencies.models import Currency from apps.currencies.models import Currency
from apps.common.models import SharedObject
User = get_user_model() User = get_user_model()
class BaseAccountAppTest(TestCase): class BaseAccountAppTest(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user(email="accuser@example.com", password="password") self.user = User.objects.create_user(
self.other_user = User.objects.create_user(email="otheraccuser@example.com", password="password") email="accuser@example.com", password="password"
)
self.other_user = User.objects.create_user(
email="otheraccuser@example.com", password="password"
)
self.client = Client() self.client = Client()
self.client.login(email="accuser@example.com", password="password") self.client.login(email="accuser@example.com", password="password")
self.currency_usd = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2, prefix="$") self.currency_usd = Currency.objects.create(
self.currency_eur = Currency.objects.create(code="EUR", name="Euro", decimal_places=2, prefix="") code="USD", name="US Dollar", decimal_places=2, prefix="$"
)
self.currency_eur = Currency.objects.create(
code="EUR", name="Euro", decimal_places=2, prefix=""
)
class AccountGroupModelTests(BaseAccountAppTest): class AccountGroupModelTests(BaseAccountAppTest):
@@ -29,27 +38,38 @@ class AccountGroupModelTests(BaseAccountAppTest):
def test_account_group_unique_together_owner_name(self): def test_account_group_unique_together_owner_name(self):
AccountGroup.objects.create(name="Unique Group", owner=self.user) AccountGroup.objects.create(name="Unique Group", owner=self.user)
with self.assertRaises(Exception): # IntegrityError at DB level with self.assertRaises(Exception): # IntegrityError at DB level
AccountGroup.objects.create(name="Unique Group", owner=self.user) AccountGroup.objects.create(name="Unique Group", owner=self.user)
class AccountGroupViewTests(BaseAccountAppTest): class AccountGroupViewTests(BaseAccountAppTest):
def test_account_groups_list_view(self): def test_account_groups_list_view(self):
AccountGroup.objects.create(name="Group 1", owner=self.user) AccountGroup.objects.create(name="Group 1", owner=self.user)
AccountGroup.objects.create(name="Group 2 Public", visibility=AccountGroup.Visibility.PUBLIC) AccountGroup.objects.create(
name="Group 2 Public", visibility=SharedObject.Visibility.public
)
response = self.client.get(reverse("account_groups_list")) response = self.client.get(reverse("account_groups_list"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, "Group 1") self.assertContains(response, "Group 1")
self.assertContains(response, "Group 2 Public") self.assertContains(response, "Group 2 Public")
def test_account_group_add_view(self): def test_account_group_add_view(self):
response = self.client.post(reverse("account_group_add"), {"name": "New Group from View"}) response = self.client.post(
self.assertEqual(response.status_code, 204) # HTMX success reverse("account_group_add"), {"name": "New Group from View"}
self.assertTrue(AccountGroup.objects.filter(name="New Group from View", owner=self.user).exists()) )
self.assertEqual(response.status_code, 204) # HTMX success
self.assertTrue(
AccountGroup.objects.filter(
name="New Group from View", owner=self.user
).exists()
)
def test_account_group_edit_view(self): def test_account_group_edit_view(self):
group = AccountGroup.objects.create(name="Original Group Name", owner=self.user) group = AccountGroup.objects.create(name="Original Group Name", owner=self.user)
response = self.client.post(reverse("account_group_edit", args=[group.id]), {"name": "Edited Group Name"}) response = self.client.post(
reverse("account_group_edit", args=[group.id]),
{"name": "Edited Group Name"},
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
group.refresh_from_db() group.refresh_from_db()
self.assertEqual(group.name, "Edited Group Name") self.assertEqual(group.name, "Edited Group Name")
@@ -64,16 +84,20 @@ class AccountGroupViewTests(BaseAccountAppTest):
group = AccountGroup.objects.create(name="User1s Group", owner=self.user) group = AccountGroup.objects.create(name="User1s Group", owner=self.user)
self.client.logout() self.client.logout()
self.client.login(email="otheraccuser@example.com", password="password") self.client.login(email="otheraccuser@example.com", password="password")
response = self.client.post(reverse("account_group_edit", args=[group.id]), {"name": "Attempted Edit"}) response = self.client.post(
self.assertEqual(response.status_code, 204) # View returns 204 with message reverse("account_group_edit", args=[group.id]), {"name": "Attempted Edit"}
)
self.assertEqual(response.status_code, 204) # View returns 204 with message
group.refresh_from_db() group.refresh_from_db()
self.assertEqual(group.name, "User1s Group") # Name should not change self.assertEqual(group.name, "User1s Group") # Name should not change
class AccountModelTests(BaseAccountAppTest): # Renamed from AccountTests class AccountModelTests(BaseAccountAppTest): # Renamed from AccountTests
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.account_group = AccountGroup.objects.create(name="Test Group", owner=self.user) self.account_group = AccountGroup.objects.create(
name="Test Group", owner=self.user
)
def test_account_creation(self): def test_account_creation(self):
"""Test basic account creation""" """Test basic account creation"""
@@ -98,7 +122,7 @@ class AccountModelTests(BaseAccountAppTest): # Renamed from AccountTests
name="Exchange Account", name="Exchange Account",
currency=self.currency_usd, currency=self.currency_usd,
exchange_currency=self.currency_eur, exchange_currency=self.currency_eur,
owner=self.user owner=self.user,
) )
self.assertEqual(account.exchange_currency, self.currency_eur) self.assertEqual(account.exchange_currency, self.currency_eur)
@@ -106,28 +130,46 @@ class AccountModelTests(BaseAccountAppTest): # Renamed from AccountTests
account = Account( account = Account(
name="Same Currency Account", name="Same Currency Account",
currency=self.currency_usd, currency=self.currency_usd,
exchange_currency=self.currency_usd, # Same as main currency exchange_currency=self.currency_usd, # Same as main currency
owner=self.user owner=self.user,
) )
with self.assertRaises(ValidationError) as context: with self.assertRaises(ValidationError) as context:
account.full_clean() account.full_clean()
self.assertIn('exchange_currency', context.exception.message_dict) self.assertIn("exchange_currency", context.exception.message_dict)
self.assertIn("Exchange currency cannot be the same as the account's main currency.", context.exception.message_dict['exchange_currency']) self.assertIn(
"Exchange currency cannot be the same as the account's main currency.",
context.exception.message_dict["exchange_currency"],
)
def test_account_unique_together_owner_name(self): def test_account_unique_together_owner_name(self):
Account.objects.create(name="Unique Account", owner=self.user, currency=self.currency_usd) Account.objects.create(
with self.assertRaises(Exception): # IntegrityError at DB level name="Unique Account", owner=self.user, currency=self.currency_usd
Account.objects.create(name="Unique Account", owner=self.user, currency=self.currency_eur) )
with self.assertRaises(Exception): # IntegrityError at DB level
Account.objects.create(
name="Unique Account", owner=self.user, currency=self.currency_eur
)
class AccountViewTests(BaseAccountAppTest): class AccountViewTests(BaseAccountAppTest):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.account_group = AccountGroup.objects.create(name="View Test Group", owner=self.user) self.account_group = AccountGroup.objects.create(
name="View Test Group", owner=self.user
)
def test_accounts_list_view(self): def test_accounts_list_view(self):
Account.objects.create(name="Acc 1", currency=self.currency_usd, owner=self.user, group=self.account_group) Account.objects.create(
Account.objects.create(name="Acc 2 Public", currency=self.currency_eur, visibility=Account.Visibility.PUBLIC) name="Acc 1",
currency=self.currency_usd,
owner=self.user,
group=self.account_group,
)
Account.objects.create(
name="Acc 2 Public",
currency=self.currency_eur,
visibility=SharedObject.Visibility.public,
)
response = self.client.get(reverse("accounts_list")) response = self.client.get(reverse("accounts_list"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, "Acc 1") self.assertContains(response, "Acc 1")
@@ -138,25 +180,33 @@ class AccountViewTests(BaseAccountAppTest):
"name": "New Checking Account", "name": "New Checking Account",
"group": self.account_group.id, "group": self.account_group.id,
"currency": self.currency_usd.id, "currency": self.currency_usd.id,
"is_asset": "on", # Checkbox data "is_asset": "on", # Checkbox data
"is_archived": "", # Not checked "is_archived": "", # Not checked
} }
response = self.client.post(reverse("account_add"), data) response = self.client.post(reverse("account_add"), data)
self.assertEqual(response.status_code, 204) # HTMX success self.assertEqual(response.status_code, 204) # HTMX success
self.assertTrue( self.assertTrue(
Account.objects.filter(name="New Checking Account", owner=self.user, is_asset=True, is_archived=False).exists() Account.objects.filter(
name="New Checking Account",
owner=self.user,
is_asset=True,
is_archived=False,
).exists()
) )
def test_account_edit_view(self): def test_account_edit_view(self):
account = Account.objects.create( account = Account.objects.create(
name="Original Account Name", currency=self.currency_usd, owner=self.user, group=self.account_group name="Original Account Name",
currency=self.currency_usd,
owner=self.user,
group=self.account_group,
) )
data = { data = {
"name": "Edited Account Name", "name": "Edited Account Name",
"group": self.account_group.id, "group": self.account_group.id,
"currency": self.currency_usd.id, "currency": self.currency_usd.id,
"is_asset": "", # Uncheck asset "is_asset": "", # Uncheck asset
"is_archived": "on", # Check archived "is_archived": "on", # Check archived
} }
response = self.client.post(reverse("account_edit", args=[account.id]), data) response = self.client.post(reverse("account_edit", args=[account.id]), data)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
@@ -166,53 +216,74 @@ class AccountViewTests(BaseAccountAppTest):
self.assertTrue(account.is_archived) self.assertTrue(account.is_archived)
def test_account_delete_view(self): def test_account_delete_view(self):
account = Account.objects.create(name="Account to Delete", currency=self.currency_usd, owner=self.user) account = Account.objects.create(
name="Account to Delete", currency=self.currency_usd, owner=self.user
)
response = self.client.delete(reverse("account_delete", args=[account.id])) response = self.client.delete(reverse("account_delete", args=[account.id]))
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertFalse(Account.objects.filter(id=account.id).exists()) self.assertFalse(Account.objects.filter(id=account.id).exists())
def test_other_user_cannot_edit_account(self): def test_other_user_cannot_edit_account(self):
account = Account.objects.create(name="User1s Account", currency=self.currency_usd, owner=self.user) account = Account.objects.create(
name="User1s Account", currency=self.currency_usd, owner=self.user
)
self.client.logout() self.client.logout()
self.client.login(email="otheraccuser@example.com", password="password") self.client.login(email="otheraccuser@example.com", password="password")
data = {"name": "Attempted Edit by Other", "currency": self.currency_usd.id} # Need currency data = {
"name": "Attempted Edit by Other",
"currency": self.currency_usd.id,
} # Need currency
response = self.client.post(reverse("account_edit", args=[account.id]), data) response = self.client.post(reverse("account_edit", args=[account.id]), data)
self.assertEqual(response.status_code, 204) # View returns 204 with message self.assertEqual(response.status_code, 204) # View returns 204 with message
account.refresh_from_db() account.refresh_from_db()
self.assertEqual(account.name, "User1s Account") self.assertEqual(account.name, "User1s Account")
def test_account_sharing_and_take_ownership(self): def test_account_sharing_and_take_ownership(self):
# Create a public account by user1 # Create a public account by user1
public_account = Account.objects.create( public_account = Account.objects.create(
name="Public Account", currency=self.currency_usd, owner=self.user, visibility=Account.Visibility.PUBLIC name="Public Account",
currency=self.currency_usd,
owner=self.user,
visibility=SharedObject.Visibility.public,
) )
# Login as other_user # Login as other_user
self.client.logout() self.client.logout()
self.client.login(email="otheraccuser@example.com", password="password") self.client.login(email="otheraccuser@example.com", password="password")
# other_user takes ownership # other_user takes ownership
response = self.client.get(reverse("account_take_ownership", args=[public_account.id])) response = self.client.get(
reverse("account_take_ownership", args=[public_account.id])
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
public_account.refresh_from_db() public_account.refresh_from_db()
self.assertEqual(public_account.owner, self.other_user) self.assertEqual(public_account.owner, self.other_user)
self.assertEqual(public_account.visibility, Account.Visibility.PRIVATE) # Should become private self.assertEqual(
public_account.visibility, SharedObject.Visibility.private
) # Should become private
# Now, original user (self.user) should not be able to edit it # Now, original user (self.user) should not be able to edit it
self.client.logout() self.client.logout()
self.client.login(email="accuser@example.com", password="password") self.client.login(email="accuser@example.com", password="password")
response = self.client.post(reverse("account_edit", args=[public_account.id]), {"name": "Attempt by Original Owner", "currency": self.currency_usd.id}) response = self.client.post(
self.assertEqual(response.status_code, 204) # error message, no change reverse("account_edit", args=[public_account.id]),
{"name": "Attempt by Original Owner", "currency": self.currency_usd.id},
)
self.assertEqual(response.status_code, 204) # error message, no change
public_account.refresh_from_db() public_account.refresh_from_db()
self.assertNotEqual(public_account.name, "Attempt by Original Owner") self.assertNotEqual(public_account.name, "Attempt by Original Owner")
def test_account_share_view(self): def test_account_share_view(self):
account_to_share = Account.objects.create(name="Shareable Account", currency=self.currency_usd, owner=self.user) account_to_share = Account.objects.create(
name="Shareable Account", currency=self.currency_usd, owner=self.user
)
data = { data = {
"shared_with": [self.other_user.id], "shared_with": [self.other_user.id],
"visibility": Account.Visibility.SHARED, "visibility": SharedObject.Visibility.private,
} }
response = self.client.post(reverse("account_share", args=[account_to_share.id]), data) response = self.client.post(
reverse("account_share", args=[account_to_share.id]), data
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
account_to_share.refresh_from_db() account_to_share.refresh_from_db()
self.assertIn(self.other_user, account_to_share.shared_with.all()) self.assertIn(self.other_user, account_to_share.shared_with.all())
self.assertEqual(account_to_share.visibility, Account.Visibility.SHARED) self.assertEqual(account_to_share.visibility, SharedObject.Visibility.private)

View File

@@ -5,50 +5,81 @@ from unittest.mock import patch
from django.urls import reverse from django.urls import reverse
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.conf import settings from django.conf import settings
from rest_framework.test import APIClient, APITestCase # APITestCase handles DB setup better for API tests from rest_framework.test import (
APIClient,
APITestCase,
) # APITestCase handles DB setup better for API tests
from rest_framework import status from rest_framework import status
from apps.accounts.models import Account, AccountGroup from apps.accounts.models import Account, AccountGroup
from apps.currencies.models import Currency from apps.currencies.models import Currency
from apps.transactions.models import Transaction, TransactionCategory, TransactionTag, TransactionEntity from apps.transactions.models import (
Transaction,
TransactionCategory,
TransactionTag,
TransactionEntity,
)
# Assuming thread_local is used for setting user for serializers if they auto-assign owner # Assuming thread_local is used for setting user for serializers if they auto-assign owner
from apps.common.middleware.thread_local import set_current_user from apps.common.middleware.thread_local import write_current_user
User = get_user_model() User = get_user_model()
class BaseAPITestCase(APITestCase): # Use APITestCase for DRF tests
class BaseAPITestCase(APITestCase): # Use APITestCase for DRF tests
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
cls.user = User.objects.create_user(email="apiuser@example.com", password="password") cls.user = User.objects.create_user(
cls.superuser = User.objects.create_superuser(email="apisuper@example.com", password="password") email="apiuser@example.com", password="password"
)
cls.currency_usd = Currency.objects.create(code="USD", name="US Dollar API", decimal_places=2) cls.superuser = User.objects.create_superuser(
cls.account_group_api = AccountGroup.objects.create(name="API Group", owner=cls.user) email="apisuper@example.com", password="password"
cls.account_usd_api = Account.objects.create( )
name="API Checking USD", currency=cls.currency_usd, owner=cls.user, group=cls.account_group_api
cls.currency_usd = Currency.objects.create(
code="USD", name="US Dollar API", decimal_places=2
)
cls.account_group_api = AccountGroup.objects.create(
name="API Group", owner=cls.user
)
cls.account_usd_api = Account.objects.create(
name="API Checking USD",
currency=cls.currency_usd,
owner=cls.user,
group=cls.account_group_api,
)
cls.category_api = TransactionCategory.objects.create(
name="API Food", owner=cls.user
) )
cls.category_api = TransactionCategory.objects.create(name="API Food", owner=cls.user)
cls.tag_api = TransactionTag.objects.create(name="API Urgent", owner=cls.user) cls.tag_api = TransactionTag.objects.create(name="API Urgent", owner=cls.user)
cls.entity_api = TransactionEntity.objects.create(name="API Store", owner=cls.user) cls.entity_api = TransactionEntity.objects.create(
name="API Store", owner=cls.user
)
def setUp(self): def setUp(self):
self.client = APIClient() self.client = APIClient()
# Authenticate as regular user by default, can be overridden in tests # Authenticate as regular user by default, can be overridden in tests
self.client.force_authenticate(user=self.user) self.client.force_authenticate(user=self.user)
set_current_user(self.user) # For serializers/models that might use get_current_user write_current_user(
self.user
) # For serializers/models that might use get_current_user
def tearDown(self): def tearDown(self):
set_current_user(None) write_current_user(None)
class TransactionAPITests(BaseAPITestCase): class TransactionAPITests(BaseAPITestCase):
def test_list_transactions(self): def test_list_transactions(self):
# Create a transaction for the authenticated user # Create a transaction for the authenticated user
Transaction.objects.create( Transaction.objects.create(
account=self.account_usd_api, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account_usd_api,
date=date(2023, 1, 1), amount=Decimal("10.00"), description="Test List" owner=self.user,
type=Transaction.Type.EXPENSE,
date=date(2023, 1, 1),
amount=Decimal("10.00"),
description="Test List",
) )
url = reverse("transaction-list") # DRF default router name url = reverse("transaction-list") # DRF default router name
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["pagination"]["count"], 1) self.assertEqual(response.data["pagination"]["count"], 1)
@@ -56,65 +87,83 @@ class TransactionAPITests(BaseAPITestCase):
def test_retrieve_transaction(self): def test_retrieve_transaction(self):
t = Transaction.objects.create( t = Transaction.objects.create(
account=self.account_usd_api, owner=self.user, type=Transaction.Type.INCOME, account=self.account_usd_api,
date=date(2023, 2, 1), amount=Decimal("100.00"), description="Specific Salary" owner=self.user,
type=Transaction.Type.INCOME,
date=date(2023, 2, 1),
amount=Decimal("100.00"),
description="Specific Salary",
) )
url = reverse("transaction-detail", kwargs={'pk': t.pk}) url = reverse("transaction-detail", kwargs={"pk": t.pk})
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["description"], "Specific Salary") self.assertEqual(response.data["description"], "Specific Salary")
self.assertIn("exchanged_amount", response.data) # Check for SerializerMethodField self.assertIn(
"exchanged_amount", response.data
) # Check for SerializerMethodField
@patch('apps.transactions.signals.transaction_created.send') @patch("apps.transactions.signals.transaction_created.send")
def test_create_transaction(self, mock_signal_send): def test_create_transaction(self, mock_signal_send):
url = reverse("transaction-list") url = reverse("transaction-list")
data = { data = {
"account_id": self.account_usd_api.pk, "account_id": self.account_usd_api.pk,
"type": Transaction.Type.EXPENSE, "type": Transaction.Type.EXPENSE,
"date": "2023-03-01", "date": "2023-03-01",
"reference_date": "2023-03", # Test custom format "reference_date": "2023-03", # Test custom format
"amount": "25.50", "amount": "25.50",
"description": "New API Expense", "description": "New API Expense",
"category": self.category_api.name, # Assuming TransactionCategoryField handles name to instance "category": self.category_api.name, # Assuming TransactionCategoryField handles name to instance
"tags": [self.tag_api.name], # Assuming TransactionTagField handles list of names "tags": [
"entities": [self.entity_api.name] # Assuming TransactionEntityField handles list of names self.tag_api.name
], # Assuming TransactionTagField handles list of names
"entities": [
self.entity_api.name
], # Assuming TransactionEntityField handles list of names
} }
response = self.client.post(url, data, format='json') response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.data) self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.data)
self.assertTrue(Transaction.objects.filter(description="New API Expense").exists()) self.assertTrue(
Transaction.objects.filter(description="New API Expense").exists()
)
created_transaction = Transaction.objects.get(description="New API Expense") created_transaction = Transaction.objects.get(description="New API Expense")
self.assertEqual(created_transaction.owner, self.user) # Check if owner is set self.assertEqual(created_transaction.owner, self.user) # Check if owner is set
self.assertEqual(created_transaction.category.name, self.category_api.name) self.assertEqual(created_transaction.category.name, self.category_api.name)
self.assertIn(self.tag_api, created_transaction.tags.all()) self.assertIn(self.tag_api, created_transaction.tags.all())
mock_signal_send.assert_called_once() mock_signal_send.assert_called_once()
def test_create_transaction_missing_fields(self): def test_create_transaction_missing_fields(self):
url = reverse("transaction-list") url = reverse("transaction-list")
data = {"account_id": self.account_usd_api.pk, "type": Transaction.Type.EXPENSE} # Missing date, amount, desc data = {
response = self.client.post(url, data, format='json') "account_id": self.account_usd_api.pk,
"type": Transaction.Type.EXPENSE,
} # Missing date, amount, desc
response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("date", response.data) # Or reference_date due to custom validate self.assertIn("date", response.data) # Or reference_date due to custom validate
self.assertIn("amount", response.data) self.assertIn("amount", response.data)
self.assertIn("description", response.data) self.assertIn("description", response.data)
@patch('apps.transactions.signals.transaction_updated.send') @patch("apps.transactions.signals.transaction_updated.send")
def test_update_transaction_put(self, mock_signal_send): def test_update_transaction_put(self, mock_signal_send):
t = Transaction.objects.create( t = Transaction.objects.create(
account=self.account_usd_api, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account_usd_api,
date=date(2023, 4, 1), amount=Decimal("50.00"), description="Initial PUT" owner=self.user,
type=Transaction.Type.EXPENSE,
date=date(2023, 4, 1),
amount=Decimal("50.00"),
description="Initial PUT",
) )
url = reverse("transaction-detail", kwargs={'pk': t.pk}) url = reverse("transaction-detail", kwargs={"pk": t.pk})
data = { data = {
"account_id": self.account_usd_api.pk, "account_id": self.account_usd_api.pk,
"type": Transaction.Type.INCOME, # Changed type "type": Transaction.Type.INCOME, # Changed type
"date": "2023-04-05", # Changed date "date": "2023-04-05", # Changed date
"amount": "75.00", # Changed amount "amount": "75.00", # Changed amount
"description": "Updated PUT Transaction", "description": "Updated PUT Transaction",
"category": self.category_api.name "category": self.category_api.name,
} }
response = self.client.put(url, data, format='json') response = self.client.put(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK, response.data) self.assertEqual(response.status_code, status.HTTP_200_OK, response.data)
t.refresh_from_db() t.refresh_from_db()
self.assertEqual(t.description, "Updated PUT Transaction") self.assertEqual(t.description, "Updated PUT Transaction")
@@ -122,27 +171,34 @@ class TransactionAPITests(BaseAPITestCase):
self.assertEqual(t.amount, Decimal("75.00")) self.assertEqual(t.amount, Decimal("75.00"))
mock_signal_send.assert_called_once() mock_signal_send.assert_called_once()
@patch('apps.transactions.signals.transaction_updated.send') @patch("apps.transactions.signals.transaction_updated.send")
def test_update_transaction_patch(self, mock_signal_send): def test_update_transaction_patch(self, mock_signal_send):
t = Transaction.objects.create( t = Transaction.objects.create(
account=self.account_usd_api, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account_usd_api,
date=date(2023, 5, 1), amount=Decimal("30.00"), description="Initial PATCH" owner=self.user,
type=Transaction.Type.EXPENSE,
date=date(2023, 5, 1),
amount=Decimal("30.00"),
description="Initial PATCH",
) )
url = reverse("transaction-detail", kwargs={'pk': t.pk}) url = reverse("transaction-detail", kwargs={"pk": t.pk})
data = {"description": "Patched Description"} data = {"description": "Patched Description"}
response = self.client.patch(url, data, format='json') response = self.client.patch(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK, response.data) self.assertEqual(response.status_code, status.HTTP_200_OK, response.data)
t.refresh_from_db() t.refresh_from_db()
self.assertEqual(t.description, "Patched Description") self.assertEqual(t.description, "Patched Description")
mock_signal_send.assert_called_once() mock_signal_send.assert_called_once()
def test_delete_transaction(self): def test_delete_transaction(self):
t = Transaction.objects.create( t = Transaction.objects.create(
account=self.account_usd_api, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account_usd_api,
date=date(2023, 6, 1), amount=Decimal("10.00"), description="To Delete" owner=self.user,
type=Transaction.Type.EXPENSE,
date=date(2023, 6, 1),
amount=Decimal("10.00"),
description="To Delete",
) )
url = reverse("transaction-detail", kwargs={'pk': t.pk}) url = reverse("transaction-detail", kwargs={"pk": t.pk})
response = self.client.delete(url) response = self.client.delete(url)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
# Default manager should not find it (soft delete) # Default manager should not find it (soft delete)
@@ -165,11 +221,14 @@ class AccountAPITests(BaseAPITestCase):
"name": "API Savings EUR", "name": "API Savings EUR",
"currency_id": self.currency_eur.pk, "currency_id": self.currency_eur.pk,
"group_id": self.account_group_api.pk, "group_id": self.account_group_api.pk,
"is_asset": False "is_asset": False,
} }
response = self.client.post(url, data, format='json') response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.data) self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.data)
self.assertTrue(Account.objects.filter(name="API Savings EUR", owner=self.user).exists()) self.assertTrue(
Account.objects.filter(name="API Savings EUR", owner=self.user).exists()
)
# --- Permission Tests --- # --- Permission Tests ---
class APIPermissionTests(BaseAPITestCase): class APIPermissionTests(BaseAPITestCase):
@@ -178,7 +237,7 @@ class APIPermissionTests(BaseAPITestCase):
with self.settings(DEMO=True): with self.settings(DEMO=True):
url = reverse("transaction-list") url = reverse("transaction-list")
# Attempt POST as regular user (self.user is not superuser) # Attempt POST as regular user (self.user is not superuser)
response = self.client.post(url, {"description": "test"}, format='json') response = self.client.post(url, {"description": "test"}, format="json")
# This depends on default permissions. If IsAuthenticated allows POST, NotInDemoMode should deny. # This depends on default permissions. If IsAuthenticated allows POST, NotInDemoMode should deny.
# If default is ReadOnly, then GET would be allowed, POST denied regardless of NotInDemoMode for non-admin. # If default is ReadOnly, then GET would be allowed, POST denied regardless of NotInDemoMode for non-admin.
# Assuming NotInDemoMode is a primary gate for write operations. # Assuming NotInDemoMode is a primary gate for write operations.
@@ -195,50 +254,53 @@ class APIPermissionTests(BaseAPITestCase):
get_response = self.client.get(url) get_response = self.client.get(url)
self.assertEqual(get_response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(get_response.status_code, status.HTTP_403_FORBIDDEN)
def test_not_in_demo_mode_permission_superuser(self): def test_not_in_demo_mode_permission_superuser(self):
self.client.force_authenticate(user=self.superuser) self.client.force_authenticate(user=self.superuser)
set_current_user(self.superuser) write_current_user(self.superuser)
with self.settings(DEMO=True): with self.settings(DEMO=True):
url = reverse("transaction-list") url = reverse("transaction-list")
data = { # Valid data for transaction creation data = { # Valid data for transaction creation
"account_id": self.account_usd_api.pk, "type": Transaction.Type.EXPENSE, "account_id": self.account_usd_api.pk,
"date": "2023-07-01", "amount": "1.00", "description": "Superuser Demo Post" "type": Transaction.Type.EXPENSE,
"date": "2023-07-01",
"amount": "1.00",
"description": "Superuser Demo Post",
} }
response = self.client.post(url, data, format='json') response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.data) self.assertEqual(
response.status_code, status.HTTP_201_CREATED, response.data
)
get_response = self.client.get(url) get_response = self.client.get(url)
self.assertEqual(get_response.status_code, status.HTTP_200_OK) self.assertEqual(get_response.status_code, status.HTTP_200_OK)
def test_access_in_non_demo_mode(self): def test_access_in_non_demo_mode(self):
with self.settings(DEMO=False): # Explicitly ensure demo mode is off with self.settings(DEMO=False): # Explicitly ensure demo mode is off
url = reverse("transaction-list") url = reverse("transaction-list")
data = { data = {
"account_id": self.account_usd_api.pk, "type": Transaction.Type.EXPENSE, "account_id": self.account_usd_api.pk,
"date": "2023-08-01", "amount": "2.00", "description": "Non-Demo Post" "type": Transaction.Type.EXPENSE,
"date": "2023-08-01",
"amount": "2.00",
"description": "Non-Demo Post",
} }
response = self.client.post(url, data, format='json') response = self.client.post(url, data, format="json")
self.assertEqual(response.status_code, status.HTTP_201_CREATED, response.data) self.assertEqual(
response.status_code, status.HTTP_201_CREATED, response.data
)
get_response = self.client.get(url) get_response = self.client.get(url)
self.assertEqual(get_response.status_code, status.HTTP_200_OK) self.assertEqual(get_response.status_code, status.HTTP_200_OK)
def test_unauthenticated_access(self): def test_unauthenticated_access(self):
self.client.logout() # Or self.client.force_authenticate(user=None) self.client.logout() # Or self.client.force_authenticate(user=None)
set_current_user(None) write_current_user(None)
url = reverse("transaction-list") url = reverse("transaction-list")
response = self.client.get(url) response = self.client.get(url)
# Default behavior for DRF is IsAuthenticated, so should be 401 or 403 # Default behavior for DRF is IsAuthenticated, so should be 401 or 403
# If IsAuthenticatedOrReadOnly, GET would be 200. # If IsAuthenticatedOrReadOnly, GET would be 200.
# Given serializers specify IsAuthenticated, likely 401/403. # Given serializers specify IsAuthenticated, likely 401/403.
self.assertTrue(response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]) self.assertTrue(
response.status_code
# TODO: Add tests for pagination by providing `?page=X` and `?page_size=Y` in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
# TODO: Add tests for filtering if specific filter_backends are configured on ViewSets. )
# TODO: Add tests for other ViewSets (Categories, Tags, Accounts, etc.)
# TODO: Test custom serializer fields like TransactionCategoryField more directly if their logic is complex.
# (e.g., creating category by name if it doesn't exist vs. only allowing existing by ID)
# The current create test for transactions implicitly tests this behavior.
```

View File

@@ -27,7 +27,7 @@ class SharedObject(models.Model):
# Access control enum # Access control enum
class Visibility(models.TextChoices): class Visibility(models.TextChoices):
private = "private", _("Private") private = "private", _("Private")
is_paid = "public", _("Public") public = "public", _("Public")
# Core sharing fields # Core sharing fields
owner = models.ForeignKey( owner = models.ForeignKey(

View File

@@ -17,7 +17,9 @@ class DateFunctionsTests(TestCase):
def test_remaining_days_in_month(self): def test_remaining_days_in_month(self):
# Test with a date in the middle of the month # Test with a date in the middle of the month
current_date_mid = datetime.date(2023, 10, 15) current_date_mid = datetime.date(2023, 10, 15)
self.assertEqual(remaining_days_in_month(2023, 10, current_date_mid), 17) # 31 - 15 + 1 self.assertEqual(
remaining_days_in_month(2023, 10, current_date_mid), 17
) # 31 - 15 + 1
# Test with the first day of the month # Test with the first day of the month
current_date_first = datetime.date(2023, 10, 1) current_date_first = datetime.date(2023, 10, 1)
@@ -32,21 +34,28 @@ class DateFunctionsTests(TestCase):
# Test leap year (February 2024) # Test leap year (February 2024)
current_date_feb_leap = datetime.date(2024, 2, 10) current_date_feb_leap = datetime.date(2024, 2, 10)
self.assertEqual(remaining_days_in_month(2024, 2, current_date_feb_leap), 20) # 29 - 10 + 1 self.assertEqual(
remaining_days_in_month(2024, 2, current_date_feb_leap), 20
) # 29 - 10 + 1
current_date_feb_leap_other = datetime.date(2023, 1, 1) current_date_feb_leap_other = datetime.date(2023, 1, 1)
self.assertEqual(remaining_days_in_month(2024, 2, current_date_feb_leap_other), 29) self.assertEqual(
remaining_days_in_month(2024, 2, current_date_feb_leap_other), 29
)
# Test non-leap year (February 2023) # Test non-leap year (February 2023)
current_date_feb_non_leap = datetime.date(2023, 2, 10) current_date_feb_non_leap = datetime.date(2023, 2, 10)
self.assertEqual(remaining_days_in_month(2023, 2, current_date_feb_non_leap), 19) # 28 - 10 + 1 self.assertEqual(
remaining_days_in_month(2023, 2, current_date_feb_non_leap), 19
) # 28 - 10 + 1
class DecimalFunctionsTests(TestCase): class DecimalFunctionsTests(TestCase):
def test_truncate_decimal(self): def test_truncate_decimal(self):
self.assertEqual(truncate_decimal(Decimal("123.456789"), 0), Decimal("123")) self.assertEqual(truncate_decimal(Decimal("123.456789"), 0), Decimal("123"))
self.assertEqual(truncate_decimal(Decimal("123.456789"), 2), Decimal("123.45")) self.assertEqual(truncate_decimal(Decimal("123.456789"), 2), Decimal("123.45"))
self.assertEqual(truncate_decimal(Decimal("123.45"), 4), Decimal("123.45")) # No change if fewer places self.assertEqual(
truncate_decimal(Decimal("123.45"), 4), Decimal("123.45")
) # No change if fewer places
self.assertEqual(truncate_decimal(Decimal("123"), 2), Decimal("123")) self.assertEqual(truncate_decimal(Decimal("123"), 2), Decimal("123"))
self.assertEqual(truncate_decimal(Decimal("0.12345"), 3), Decimal("0.123")) self.assertEqual(truncate_decimal(Decimal("0.12345"), 3), Decimal("0.123"))
self.assertEqual(truncate_decimal(Decimal("-123.456"), 2), Decimal("-123.45")) self.assertEqual(truncate_decimal(Decimal("-123.456"), 2), Decimal("-123.45"))
@@ -58,7 +67,7 @@ class Event(models.Model):
event_month = MonthYearModelField() event_month = MonthYearModelField()
class Meta: class Meta:
app_label = 'common' # Required for temporary models in tests app_label = "common" # Required for temporary models in tests
class MonthYearModelFieldTests(TestCase): class MonthYearModelFieldTests(TestCase):
@@ -82,7 +91,7 @@ class MonthYearModelFieldTests(TestCase):
field.to_python("10-2023") field.to_python("10-2023")
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
field.to_python("invalid-date") field.to_python("invalid-date")
with self.assertRaises(ValidationError): # Invalid month with self.assertRaises(ValidationError): # Invalid month
field.to_python("2023-13") field.to_python("2023-13")
# More involved test requiring database interaction (migrations for dummy model) # More involved test requiring database interaction (migrations for dummy model)
@@ -105,15 +114,17 @@ class CommonTemplateTagTests(TestCase):
self.assertEqual(drop_trailing_zeros(Decimal("10.00")), Decimal("10")) self.assertEqual(drop_trailing_zeros(Decimal("10.00")), Decimal("10"))
self.assertEqual(drop_trailing_zeros(Decimal("10")), Decimal("10")) self.assertEqual(drop_trailing_zeros(Decimal("10")), Decimal("10"))
self.assertEqual(drop_trailing_zeros("12.340"), Decimal("12.34")) self.assertEqual(drop_trailing_zeros("12.340"), Decimal("12.34"))
self.assertEqual(drop_trailing_zeros(12.0), Decimal("12")) # float input self.assertEqual(drop_trailing_zeros(12.0), Decimal("12")) # float input
self.assertEqual(drop_trailing_zeros("not_a_decimal"), "not_a_decimal") self.assertEqual(drop_trailing_zeros("not_a_decimal"), "not_a_decimal")
self.assertIsNone(drop_trailing_zeros(None)) self.assertIsNone(drop_trailing_zeros(None))
def test_localize_number(self): def test_localize_number(self):
# Basic test, full localization testing is complex # Basic test, full localization testing is complex
self.assertEqual(localize_number(Decimal("12345.678"), decimal_places=2), "12,345.68") # Assuming EN locale default self.assertEqual(
localize_number(Decimal("12345.678"), decimal_places=2), "12,345.67"
) # Assuming EN locale default
self.assertEqual(localize_number(Decimal("12345"), decimal_places=0), "12,345") self.assertEqual(localize_number(Decimal("12345"), decimal_places=0), "12,345")
self.assertEqual(localize_number(12345.67, decimal_places=1), "12,345.7") self.assertEqual(localize_number(12345.67, decimal_places=1), "12,345.6")
self.assertEqual(localize_number("not_a_number"), "not_a_number") self.assertEqual(localize_number("not_a_number"), "not_a_number")
# Test with a different language if possible, though environment might be fixed # Test with a different language if possible, though environment might be fixed
@@ -125,15 +136,17 @@ class CommonTemplateTagTests(TestCase):
self.assertEqual(month_name(12), "December") self.assertEqual(month_name(12), "December")
# Assuming English as default, Django's translation might affect this # Assuming English as default, Django's translation might affect this
# For more robust test, you might need to activate a specific language # For more robust test, you might need to activate a specific language
with translation.override('es'): with translation.override("es"):
self.assertEqual(month_name(1), "enero") self.assertEqual(month_name(1), "enero")
with translation.override('en'): # Switch back with translation.override("en"): # Switch back
self.assertEqual(month_name(1), "January") self.assertEqual(month_name(1), "January")
def test_month_name_invalid_input(self): def test_month_name_invalid_input(self):
# Test behavior for invalid month numbers, though calendar.month_name would raise IndexError # Test behavior for invalid month numbers, though calendar.month_name would raise IndexError
# The filter should ideally handle this gracefully or be documented # The filter should ideally handle this gracefully or be documented
with self.assertRaises(IndexError): # calendar.month_name[0] is empty string, 13 is out of bounds with self.assertRaises(
IndexError
): # calendar.month_name[0] is empty string, 13 is out of bounds
month_name(0) month_name(0)
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
month_name(13) month_name(13)
@@ -141,73 +154,89 @@ class CommonTemplateTagTests(TestCase):
# For now, expecting it to follow calendar.month_name behavior # For now, expecting it to follow calendar.month_name behavior
from django.contrib.auth.models import AnonymousUser, User # Using Django's User for tests from django.contrib.auth.models import (
AnonymousUser,
User,
) # Using Django's User for tests
from django.http import HttpResponse, HttpResponseForbidden, HttpResponseRedirect from django.http import HttpResponse, HttpResponseForbidden, HttpResponseRedirect
from django.urls import reverse from django.urls import reverse
from django.test import RequestFactory from django.test import RequestFactory
from apps.common.decorators.htmx import only_htmx from apps.common.decorators.htmx import only_htmx
from apps.common.decorators.user import htmx_login_required, is_superuser from apps.common.decorators.user import htmx_login_required, is_superuser
# Assuming login_url can be resolved, e.g., from settings.LOGIN_URL or a known named URL # Assuming login_url can be resolved, e.g., from settings.LOGIN_URL or a known named URL
# For testing, we might need to ensure LOGIN_URL is set or mock it. # For testing, we might need to ensure LOGIN_URL is set or mock it.
# Let's assume 'login' is a valid URL name for redirection. # Let's assume 'login' is a valid URL name for redirection.
# Dummy views for testing decorators # Dummy views for testing decorators
@only_htmx @only_htmx
def dummy_view_only_htmx(request): def dummy_view_only_htmx(request):
return HttpResponse("HTMX Success") return HttpResponse("HTMX Success")
@htmx_login_required @htmx_login_required
def dummy_view_htmx_login_required(request): def dummy_view_htmx_login_required(request):
return HttpResponse("User Authenticated HTMX") return HttpResponse("User Authenticated HTMX")
@is_superuser @is_superuser
def dummy_view_is_superuser(request): def dummy_view_is_superuser(request):
return HttpResponse("Superuser Access Granted") return HttpResponse("Superuser Access Granted")
class DecoratorTests(TestCase): class DecoratorTests(TestCase):
def setUp(self): def setUp(self):
self.factory = RequestFactory() self.factory = RequestFactory()
self.user = User.objects.create_user(username='testuser', email='test@example.com', password='password') self.user = User.objects.create_user(
self.superuser = User.objects.create_superuser(username='super', email='super@example.com', password='password') email="test@example.com", password="password"
)
self.superuser = User.objects.create_superuser(
email="super@example.com", password="password"
)
# Ensure LOGIN_URL is set for tests that redirect to login # Ensure LOGIN_URL is set for tests that redirect to login
# This can be done via settings override if not already set globally # This can be done via settings override if not already set globally
self.settings_override = self.settings(LOGIN_URL='/fake-login/') # Use a dummy login URL self.settings_override = self.settings(
LOGIN_URL="/fake-login/"
) # Use a dummy login URL
self.settings_override.enable() self.settings_override.enable()
def tearDown(self): def tearDown(self):
self.settings_override.disable() self.settings_override.disable()
# @only_htmx tests # @only_htmx tests
def test_only_htmx_allows_htmx_request(self): def test_only_htmx_allows_htmx_request(self):
request = self.factory.get('/dummy-path', HTTP_HX_REQUEST='true') request = self.factory.get("/dummy-path", HTTP_HX_REQUEST="true")
response = dummy_view_only_htmx(request) response = dummy_view_only_htmx(request)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"HTMX Success") self.assertEqual(response.content, b"HTMX Success")
def test_only_htmx_forbids_non_htmx_request(self): def test_only_htmx_forbids_non_htmx_request(self):
request = self.factory.get('/dummy-path') request = self.factory.get("/dummy-path")
response = dummy_view_only_htmx(request) response = dummy_view_only_htmx(request)
self.assertEqual(response.status_code, 403) # Or whatever HttpResponseForbidden returns by default self.assertEqual(
response.status_code, 403
) # Or whatever HttpResponseForbidden returns by default
# @htmx_login_required tests # @htmx_login_required tests
def test_htmx_login_required_allows_authenticated_user(self): def test_htmx_login_required_allows_authenticated_user(self):
request = self.factory.get('/dummy-path', HTTP_HX_REQUEST='true') request = self.factory.get("/dummy-path", HTTP_HX_REQUEST="true")
request.user = self.user request.user = self.user
response = dummy_view_htmx_login_required(request) response = dummy_view_htmx_login_required(request)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"User Authenticated HTMX") self.assertEqual(response.content, b"User Authenticated HTMX")
def test_htmx_login_required_redirects_anonymous_user_for_htmx(self): def test_htmx_login_required_redirects_anonymous_user_for_htmx(self):
request = self.factory.get('/dummy-path', HTTP_HX_REQUEST='true') request = self.factory.get("/dummy-path", HTTP_HX_REQUEST="true")
request.user = AnonymousUser() request.user = AnonymousUser()
response = dummy_view_htmx_login_required(request) response = dummy_view_htmx_login_required(request)
self.assertEqual(response.status_code, 302) # Redirect self.assertEqual(response.status_code, 302) # Redirect
# Check for HX-Redirect header for HTMX redirects to login # Check for HX-Redirect header for HTMX redirects to login
self.assertIn('HX-Redirect', response.headers) self.assertIn("HX-Redirect", response.headers)
self.assertEqual(response.headers['HX-Redirect'], '/fake-login/?next=/dummy-path') self.assertEqual(
response.headers["HX-Redirect"], "/fake-login/?next=/dummy-path"
)
def test_htmx_login_required_redirects_anonymous_user_for_non_htmx(self): def test_htmx_login_required_redirects_anonymous_user_for_non_htmx(self):
# This decorator specifically checks for HX-Request and returns 403 if not present *before* auth check. # This decorator specifically checks for HX-Request and returns 403 if not present *before* auth check.
@@ -218,45 +247,49 @@ class DecoratorTests(TestCase):
# Let's assume it's strictly for HTMX and would deny non-HTMX, or that the login_required part # Let's assume it's strictly for HTMX and would deny non-HTMX, or that the login_required part
# would kick in. # would kick in.
# Given the decorator might be composed or simple, let's test the redirect path. # Given the decorator might be composed or simple, let's test the redirect path.
request = self.factory.get('/dummy-path') # Non-HTMX request = self.factory.get("/dummy-path") # Non-HTMX
request.user = AnonymousUser() request.user = AnonymousUser()
response = dummy_view_htmx_login_required(request) response = dummy_view_htmx_login_required(request)
# If it's a standard @login_required behavior for non-HTMX part: # If it's a standard @login_required behavior for non-HTMX part:
self.assertTrue(response.status_code == 302 or response.status_code == 403) self.assertTrue(response.status_code == 302 or response.status_code == 403)
if response.status_code == 302: if response.status_code == 302:
self.assertTrue(response.url.startswith('/fake-login/')) self.assertTrue(response.url.startswith("/fake-login/"))
# @is_superuser tests # @is_superuser tests
def test_is_superuser_allows_superuser(self): def test_is_superuser_allows_superuser(self):
request = self.factory.get('/dummy-path') request = self.factory.get("/dummy-path")
request.user = self.superuser request.user = self.superuser
response = dummy_view_is_superuser(request) response = dummy_view_is_superuser(request)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"Superuser Access Granted") self.assertEqual(response.content, b"Superuser Access Granted")
def test_is_superuser_forbids_regular_user(self): def test_is_superuser_forbids_regular_user(self):
request = self.factory.get('/dummy-path') request = self.factory.get("/dummy-path")
request.user = self.user request.user = self.user
response = dummy_view_is_superuser(request) response = dummy_view_is_superuser(request)
self.assertEqual(response.status_code, 403) # Or redirects to login if @login_required is also part of it self.assertEqual(
response.status_code, 403
) # Or redirects to login if @login_required is also part of it
def test_is_superuser_forbids_anonymous_user(self): def test_is_superuser_forbids_anonymous_user(self):
request = self.factory.get('/dummy-path') request = self.factory.get("/dummy-path")
request.user = AnonymousUser() request.user = AnonymousUser()
response = dummy_view_is_superuser(request) response = dummy_view_is_superuser(request)
# This typically redirects to login if @login_required is implicitly part of such checks, # This typically redirects to login if @login_required is implicitly part of such checks,
# or returns 403 if it's purely a superuser check after authentication. # or returns 403 if it's purely a superuser check after authentication.
self.assertTrue(response.status_code == 302 or response.status_code == 403) self.assertTrue(response.status_code == 302 or response.status_code == 403)
if response.status_code == 302: # Standard redirect to login if response.status_code == 302: # Standard redirect to login
self.assertTrue(response.url.startswith('/fake-login/')) self.assertTrue(response.url.startswith("/fake-login/"))
from io import StringIO from io import StringIO
from django.core.management import call_command from django.core.management import call_command
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
# Ensure User is available for management command test # Ensure User is available for management command test
User = get_user_model() User = get_user_model()
class ManagementCommandTests(TestCase): class ManagementCommandTests(TestCase):
def test_setup_users_command(self): def test_setup_users_command(self):
# Capture output # Capture output
@@ -272,8 +305,10 @@ class ManagementCommandTests(TestCase):
test_admin_email = "admin@command.com" test_admin_email = "admin@command.com"
test_admin_pass = "CommandPass123" test_admin_pass = "CommandPass123"
with self.settings(ADMIN_EMAIL=test_admin_email, ADMIN_PASSWORD=test_admin_pass): with self.settings(
call_command('setup_users', stdout=out) ADMIN_EMAIL=test_admin_email, ADMIN_PASSWORD=test_admin_pass
):
call_command("setup_users", stdout=out)
# Check if the admin user was created (if the command is supposed to create one) # Check if the admin user was created (if the command is supposed to create one)
self.assertTrue(User.objects.filter(email=test_admin_email).exists()) self.assertTrue(User.objects.filter(email=test_admin_email).exists())
@@ -282,7 +317,7 @@ class ManagementCommandTests(TestCase):
self.assertTrue(admin_user.check_password(test_admin_pass)) self.assertTrue(admin_user.check_password(test_admin_pass))
# The command also creates a 'user@example.com' # The command also creates a 'user@example.com'
self.assertTrue(User.objects.filter(email='user@example.com').exists()) self.assertTrue(User.objects.filter(email="user@example.com").exists())
# Check output for success messages (optional, depends on command's verbosity) # Check output for success messages (optional, depends on command's verbosity)
# self.assertIn("Superuser admin@command.com created.", out.getvalue()) # self.assertIn("Superuser admin@command.com created.", out.getvalue())

View File

@@ -9,22 +9,28 @@ from django.urls import reverse
from django.utils import timezone from django.utils import timezone
from apps.currencies.models import Currency, ExchangeRate, ExchangeRateService from apps.currencies.models import Currency, ExchangeRate, ExchangeRateService
from apps.accounts.models import Account # For ExchangeRateService target_accounts from apps.accounts.models import Account # For ExchangeRateService target_accounts
User = get_user_model() User = get_user_model()
class BaseCurrencyAppTest(TestCase): class BaseCurrencyAppTest(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user(email="curtestuser@example.com", password="password") self.user = User.objects.create_user(
email="curtestuser@example.com", password="password"
)
self.client = Client() self.client = Client()
self.client.login(email="curtestuser@example.com", password="password") self.client.login(email="curtestuser@example.com", password="password")
self.usd = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2, prefix="$") self.usd = Currency.objects.create(
self.eur = Currency.objects.create(code="EUR", name="Euro", decimal_places=2, prefix="") code="USD", name="US Dollar", decimal_places=2, prefix="$"
)
self.eur = Currency.objects.create(
code="EUR", name="Euro", decimal_places=2, prefix=""
)
class CurrencyModelTests(BaseCurrencyAppTest): # Changed from CurrencyTests class CurrencyModelTests(BaseCurrencyAppTest): # Changed from CurrencyTests
def test_currency_creation(self): def test_currency_creation(self):
"""Test basic currency creation""" """Test basic currency creation"""
# self.usd is already created in BaseCurrencyAppTest # self.usd is already created in BaseCurrencyAppTest
@@ -33,10 +39,11 @@ class CurrencyModelTests(BaseCurrencyAppTest): # Changed from CurrencyTests
self.assertEqual(self.usd.decimal_places, 2) self.assertEqual(self.usd.decimal_places, 2)
self.assertEqual(self.usd.prefix, "$") self.assertEqual(self.usd.prefix, "$")
# Test creation with suffix # Test creation with suffix
jpy = Currency.objects.create(code="JPY", name="Japanese Yen", decimal_places=0, suffix="") jpy = Currency.objects.create(
code="JPY", name="Japanese Yen", decimal_places=0, suffix=""
)
self.assertEqual(jpy.suffix, "") self.assertEqual(jpy.suffix, "")
def test_currency_decimal_places_validation(self): def test_currency_decimal_places_validation(self):
"""Test decimal places validation for maximum value""" """Test decimal places validation for maximum value"""
currency = Currency(code="TESTMAX", name="Test Currency Max", decimal_places=31) currency = Currency(code="TESTMAX", name="Test Currency Max", decimal_places=31)
@@ -58,11 +65,14 @@ class CurrencyModelTests(BaseCurrencyAppTest): # Changed from CurrencyTests
self.usd.exchange_currency = self.usd self.usd.exchange_currency = self.usd
with self.assertRaises(ValidationError) as context: with self.assertRaises(ValidationError) as context:
self.usd.full_clean() self.usd.full_clean()
self.assertIn('exchange_currency', context.exception.message_dict) self.assertIn("exchange_currency", context.exception.message_dict)
self.assertIn("Currency cannot have itself as exchange currency.", context.exception.message_dict['exchange_currency']) self.assertIn(
"Currency cannot have itself as exchange currency.",
context.exception.message_dict["exchange_currency"],
)
class ExchangeRateModelTests(BaseCurrencyAppTest): # Changed from ExchangeRateTests class ExchangeRateModelTests(BaseCurrencyAppTest): # Changed from ExchangeRateTests
def test_exchange_rate_creation(self): def test_exchange_rate_creation(self):
"""Test basic exchange rate creation""" """Test basic exchange rate creation"""
rate = ExchangeRate.objects.create( rate = ExchangeRate.objects.create(
@@ -83,11 +93,11 @@ class ExchangeRateModelTests(BaseCurrencyAppTest): # Changed from ExchangeRateTe
rate=Decimal("0.85"), rate=Decimal("0.85"),
date=date, date=date,
) )
with self.assertRaises(IntegrityError): # Specifically expect IntegrityError with self.assertRaises(IntegrityError): # Specifically expect IntegrityError
ExchangeRate.objects.create( ExchangeRate.objects.create(
from_currency=self.usd, from_currency=self.usd,
to_currency=self.eur, to_currency=self.eur,
rate=Decimal("0.86"), # Different rate, same pair and date rate=Decimal("0.86"), # Different rate, same pair and date
date=date, date=date,
) )
@@ -95,14 +105,17 @@ class ExchangeRateModelTests(BaseCurrencyAppTest): # Changed from ExchangeRateTe
"""Test that from_currency and to_currency cannot be the same.""" """Test that from_currency and to_currency cannot be the same."""
rate = ExchangeRate( rate = ExchangeRate(
from_currency=self.usd, from_currency=self.usd,
to_currency=self.usd, # Same currency to_currency=self.usd, # Same currency
rate=Decimal("1.00"), rate=Decimal("1.00"),
date=timezone.now() date=timezone.now(),
) )
with self.assertRaises(ValidationError) as context: with self.assertRaises(ValidationError) as context:
rate.full_clean() rate.full_clean()
self.assertIn('to_currency', context.exception.message_dict) self.assertIn("to_currency", context.exception.message_dict)
self.assertIn("From and To currencies cannot be the same.", context.exception.message_dict['to_currency']) self.assertIn(
"From and To currencies cannot be the same.",
context.exception.message_dict["to_currency"],
)
class ExchangeRateServiceModelTests(BaseCurrencyAppTest): class ExchangeRateServiceModelTests(BaseCurrencyAppTest):
@@ -111,7 +124,7 @@ class ExchangeRateServiceModelTests(BaseCurrencyAppTest):
name="Test Coingecko Free", name="Test Coingecko Free",
service_type=ExchangeRateService.ServiceType.COINGECKO_FREE, service_type=ExchangeRateService.ServiceType.COINGECKO_FREE,
interval_type=ExchangeRateService.IntervalType.EVERY, interval_type=ExchangeRateService.IntervalType.EVERY,
fetch_interval="12" # Every 12 hours fetch_interval="12", # Every 12 hours
) )
self.assertEqual(str(service), "Test Coingecko Free") self.assertEqual(str(service), "Test Coingecko Free")
self.assertTrue(service.is_active) self.assertTrue(service.is_active)
@@ -119,17 +132,22 @@ class ExchangeRateServiceModelTests(BaseCurrencyAppTest):
def test_fetch_interval_validation_every_x_hours(self): def test_fetch_interval_validation_every_x_hours(self):
# Valid # Valid
service = ExchangeRateService( service = ExchangeRateService(
name="Valid Every", service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE, name="Valid Every",
interval_type=ExchangeRateService.IntervalType.EVERY, fetch_interval="6" service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE,
interval_type=ExchangeRateService.IntervalType.EVERY,
fetch_interval="6",
) )
service.full_clean() # Should not raise service.full_clean() # Should not raise
# Invalid - not a digit # Invalid - not a digit
service.fetch_interval = "abc" service.fetch_interval = "abc"
with self.assertRaises(ValidationError) as context: with self.assertRaises(ValidationError) as context:
service.full_clean() service.full_clean()
self.assertIn("fetch_interval", context.exception.message_dict) self.assertIn("fetch_interval", context.exception.message_dict)
self.assertIn("'Every X hours' interval type requires a positive integer.", context.exception.message_dict['fetch_interval'][0]) self.assertIn(
"'Every X hours' interval type requires a positive integer.",
context.exception.message_dict["fetch_interval"][0],
)
# Invalid - out of range # Invalid - out of range
service.fetch_interval = "0" service.fetch_interval = "0"
@@ -144,49 +162,66 @@ class ExchangeRateServiceModelTests(BaseCurrencyAppTest):
valid_intervals = ["1", "0,12", "1-5", "1-5,8,10-12", "0,1,2,3,22,23"] valid_intervals = ["1", "0,12", "1-5", "1-5,8,10-12", "0,1,2,3,22,23"]
for interval in valid_intervals: for interval in valid_intervals:
service = ExchangeRateService( service = ExchangeRateService(
name=f"Test On {interval}", service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE, name=f"Test On {interval}",
interval_type=ExchangeRateService.IntervalType.ON, fetch_interval=interval service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE,
interval_type=ExchangeRateService.IntervalType.ON,
fetch_interval=interval,
) )
service.full_clean() # Should not raise service.full_clean() # Should not raise
# Check normalized form (optional, but good if model does it) # Check normalized form (optional, but good if model does it)
# self.assertEqual(service.fetch_interval, ",".join(str(h) for h in sorted(service._parse_hour_ranges(interval)))) # self.assertEqual(service.fetch_interval, ",".join(str(h) for h in sorted(service._parse_hour_ranges(interval))))
invalid_intervals = [ invalid_intervals = [
"abc", "1-", "-5", "24", "-1", "1-24", "1,2,25", "5-1", # Invalid hour, range, or format "abc",
"1.5", "1, 2, 3," # decimal, trailing comma "1-",
"-5",
"24",
"-1",
"1-24",
"1,2,25",
"5-1", # Invalid hour, range, or format
"1.5",
"1, 2, 3,", # decimal, trailing comma
] ]
for interval in invalid_intervals: for interval in invalid_intervals:
service = ExchangeRateService( service = ExchangeRateService(
name=f"Test On Invalid {interval}", service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE, name=f"Test On Invalid {interval}",
interval_type=ExchangeRateService.IntervalType.NOT_ON, fetch_interval=interval service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE,
interval_type=ExchangeRateService.IntervalType.NOT_ON,
fetch_interval=interval,
) )
with self.assertRaises(ValidationError) as context: with self.assertRaises(ValidationError) as context:
service.full_clean() service.full_clean()
self.assertIn("fetch_interval", context.exception.message_dict) self.assertIn("fetch_interval", context.exception.message_dict)
self.assertTrue("Invalid hour format" in context.exception.message_dict['fetch_interval'][0] or \ self.assertTrue(
"Hours must be between 0 and 23" in context.exception.message_dict['fetch_interval'][0] or \ "Invalid hour format"
"Invalid range" in context.exception.message_dict['fetch_interval'][0] in context.exception.message_dict["fetch_interval"][0]
or "Hours must be between 0 and 23"
in context.exception.message_dict["fetch_interval"][0]
or "Invalid range"
in context.exception.message_dict["fetch_interval"][0]
) )
@patch("apps.currencies.exchange_rates.fetcher.PROVIDER_MAPPING")
@patch('apps.currencies.exchange_rates.fetcher.PROVIDER_MAPPING')
def test_get_provider(self, mock_provider_mapping): def test_get_provider(self, mock_provider_mapping):
# Mock a provider class # Mock a provider class
class MockProvider: class MockProvider:
def __init__(self, api_key=None): def __init__(self, api_key=None):
self.api_key = api_key self.api_key = api_key
mock_provider_mapping.__getitem__.return_value = MockProvider mock_provider_mapping.__getitem__.return_value = MockProvider
service = ExchangeRateService( service = ExchangeRateService(
name="Test Get Provider", name="Test Get Provider",
service_type=ExchangeRateService.ServiceType.COINGECKO_FREE, # Any valid choice service_type=ExchangeRateService.ServiceType.COINGECKO_FREE, # Any valid choice
api_key="testkey" api_key="testkey",
) )
provider_instance = service.get_provider() provider_instance = service.get_provider()
self.assertIsInstance(provider_instance, MockProvider) self.assertIsInstance(provider_instance, MockProvider)
self.assertEqual(provider_instance.api_key, "testkey") self.assertEqual(provider_instance.api_key, "testkey")
mock_provider_mapping.__getitem__.assert_called_with(ExchangeRateService.ServiceType.COINGECKO_FREE) mock_provider_mapping.__getitem__.assert_called_with(
ExchangeRateService.ServiceType.COINGECKO_FREE
)
class CurrencyViewTests(BaseCurrencyAppTest): class CurrencyViewTests(BaseCurrencyAppTest):
@@ -197,21 +232,35 @@ class CurrencyViewTests(BaseCurrencyAppTest):
self.assertContains(response, self.eur.name) self.assertContains(response, self.eur.name)
def test_currency_add_view(self): def test_currency_add_view(self):
data = {"code": "GBP", "name": "British Pound", "decimal_places": 2, "prefix": "£"} data = {
"code": "GBP",
"name": "British Pound",
"decimal_places": 2,
"prefix": "£",
}
response = self.client.post(reverse("currency_add"), data) response = self.client.post(reverse("currency_add"), data)
self.assertEqual(response.status_code, 204) # HTMX success self.assertEqual(response.status_code, 204) # HTMX success
self.assertTrue(Currency.objects.filter(code="GBP").exists()) self.assertTrue(Currency.objects.filter(code="GBP").exists())
def test_currency_edit_view(self): def test_currency_edit_view(self):
gbp = Currency.objects.create(code="GBP", name="Pound Sterling", decimal_places=2) gbp = Currency.objects.create(
data = {"code": "GBP", "name": "British Pound Sterling", "decimal_places": 2, "prefix": "£"} code="GBP", name="Pound Sterling", decimal_places=2
)
data = {
"code": "GBP",
"name": "British Pound Sterling",
"decimal_places": 2,
"prefix": "£",
}
response = self.client.post(reverse("currency_edit", args=[gbp.id]), data) response = self.client.post(reverse("currency_edit", args=[gbp.id]), data)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
gbp.refresh_from_db() gbp.refresh_from_db()
self.assertEqual(gbp.name, "British Pound Sterling") self.assertEqual(gbp.name, "British Pound Sterling")
def test_currency_delete_view(self): def test_currency_delete_view(self):
cad = Currency.objects.create(code="CAD", name="Canadian Dollar", decimal_places=2) cad = Currency.objects.create(
code="CAD", name="Canadian Dollar", decimal_places=2
)
response = self.client.delete(reverse("currency_delete", args=[cad.id])) response = self.client.delete(reverse("currency_delete", args=[cad.id]))
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertFalse(Currency.objects.filter(code="CAD").exists()) self.assertFalse(Currency.objects.filter(code="CAD").exists())
@@ -220,38 +269,72 @@ class CurrencyViewTests(BaseCurrencyAppTest):
class ExchangeRateViewTests(BaseCurrencyAppTest): class ExchangeRateViewTests(BaseCurrencyAppTest):
def test_exchange_rate_list_view_main(self): def test_exchange_rate_list_view_main(self):
# This view lists pairs, not individual rates directly in the main list # This view lists pairs, not individual rates directly in the main list
ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.9"), date=timezone.now()) ExchangeRate.objects.create(
from_currency=self.usd,
to_currency=self.eur,
rate=Decimal("0.9"),
date=timezone.now(),
)
response = self.client.get(reverse("exchange_rates_list")) response = self.client.get(reverse("exchange_rates_list"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, self.usd.name) # Check if pair components are mentioned self.assertContains(
response, self.usd.name
) # Check if pair components are mentioned
self.assertContains(response, self.eur.name) self.assertContains(response, self.eur.name)
def test_exchange_rate_list_pair_view(self): def test_exchange_rate_list_pair_view(self):
rate_date = timezone.now() rate_date = timezone.now()
ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.9"), date=rate_date) ExchangeRate.objects.create(
url = reverse("exchange_rates_list_pair") + f"?from={self.usd.name}&to={self.eur.name}" from_currency=self.usd,
to_currency=self.eur,
rate=Decimal("0.9"),
date=rate_date,
)
url = (
reverse("exchange_rates_list_pair")
+ f"?from={self.usd.name}&to={self.eur.name}"
)
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, "0.9") # Check if the rate is displayed self.assertContains(response, "0.9") # Check if the rate is displayed
def test_exchange_rate_add_view(self): def test_exchange_rate_add_view(self):
data = { data = {
"from_currency": self.usd.id, "from_currency": self.usd.id,
"to_currency": self.eur.id, "to_currency": self.eur.id,
"rate": "0.88", "rate": "0.88",
"date": timezone.now().strftime('%Y-%m-%d %H:%M:%S') # Match form field format "date": timezone.now().strftime(
"%Y-%m-%d %H:%M:%S"
), # Match form field format
} }
response = self.client.post(reverse("exchange_rate_add"), data) response = self.client.post(reverse("exchange_rate_add"), data)
self.assertEqual(response.status_code, 204, response.content.decode() if response.content and response.status_code != 204 else "No content on 204") self.assertEqual(
self.assertTrue(ExchangeRate.objects.filter(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.88")).exists()) response.status_code,
204,
(
response.content.decode()
if response.content and response.status_code != 204
else "No content on 204"
),
)
self.assertTrue(
ExchangeRate.objects.filter(
from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.88")
).exists()
)
def test_exchange_rate_edit_view(self): def test_exchange_rate_edit_view(self):
rate = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.91"), date=timezone.now()) rate = ExchangeRate.objects.create(
from_currency=self.usd,
to_currency=self.eur,
rate=Decimal("0.91"),
date=timezone.now(),
)
data = { data = {
"from_currency": self.usd.id, "from_currency": self.usd.id,
"to_currency": self.eur.id, "to_currency": self.eur.id,
"rate": "0.92", "rate": "0.92",
"date": rate.date.strftime('%Y-%m-%d %H:%M:%S') "date": rate.date.strftime("%Y-%m-%d %H:%M:%S"),
} }
response = self.client.post(reverse("exchange_rate_edit", args=[rate.id]), data) response = self.client.post(reverse("exchange_rate_edit", args=[rate.id]), data)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
@@ -259,7 +342,12 @@ class ExchangeRateViewTests(BaseCurrencyAppTest):
self.assertEqual(rate.rate, Decimal("0.92")) self.assertEqual(rate.rate, Decimal("0.92"))
def test_exchange_rate_delete_view(self): def test_exchange_rate_delete_view(self):
rate = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.93"), date=timezone.now()) rate = ExchangeRate.objects.create(
from_currency=self.usd,
to_currency=self.eur,
rate=Decimal("0.93"),
date=timezone.now(),
)
response = self.client.delete(reverse("exchange_rate_delete", args=[rate.id])) response = self.client.delete(reverse("exchange_rate_delete", args=[rate.id]))
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertFalse(ExchangeRate.objects.filter(id=rate.id).exists()) self.assertFalse(ExchangeRate.objects.filter(id=rate.id).exists())
@@ -267,8 +355,12 @@ class ExchangeRateViewTests(BaseCurrencyAppTest):
class ExchangeRateServiceViewTests(BaseCurrencyAppTest): class ExchangeRateServiceViewTests(BaseCurrencyAppTest):
def test_exchange_rate_service_list_view(self): def test_exchange_rate_service_list_view(self):
service = ExchangeRateService.objects.create(name="My Test Service", service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE, fetch_interval="1") service = ExchangeRateService.objects.create(
response = self.client.get(reverse("exchange_rates_services_list")) name="My Test Service",
service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE,
fetch_interval="1",
)
response = self.client.get(reverse("automatic_exchange_rates_list"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, service.name) self.assertContains(response, service.name)
@@ -281,33 +373,47 @@ class ExchangeRateServiceViewTests(BaseCurrencyAppTest):
"fetch_interval": "24", "fetch_interval": "24",
# target_currencies and target_accounts are M2M, handled differently or optional # target_currencies and target_accounts are M2M, handled differently or optional
} }
response = self.client.post(reverse("exchange_rate_service_add"), data) response = self.client.post(reverse("automatic_exchange_rate_add"), data)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertTrue(ExchangeRateService.objects.filter(name="New Fetcher Service").exists()) self.assertTrue(
ExchangeRateService.objects.filter(name="New Fetcher Service").exists()
)
def test_exchange_rate_service_edit_view(self): def test_exchange_rate_service_edit_view(self):
service = ExchangeRateService.objects.create(name="Editable Service", service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE, fetch_interval="1") service = ExchangeRateService.objects.create(
name="Editable Service",
service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE,
fetch_interval="1",
)
data = { data = {
"name": "Edited Fetcher Service", "name": "Edited Fetcher Service",
"service_type": service.service_type, "service_type": service.service_type,
"is_active": "on", "is_active": "on",
"interval_type": service.interval_type, "interval_type": service.interval_type,
"fetch_interval": "6", # Changed interval "fetch_interval": "6", # Changed interval
} }
response = self.client.post(reverse("exchange_rate_service_edit", args=[service.id]), data) response = self.client.post(
reverse("automatic_exchange_rate_edit", args=[service.id]), data
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
service.refresh_from_db() service.refresh_from_db()
self.assertEqual(service.name, "Edited Fetcher Service") self.assertEqual(service.name, "Edited Fetcher Service")
self.assertEqual(service.fetch_interval, "6") self.assertEqual(service.fetch_interval, "6")
def test_exchange_rate_service_delete_view(self): def test_exchange_rate_service_delete_view(self):
service = ExchangeRateService.objects.create(name="Deletable Service", service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE, fetch_interval="1") service = ExchangeRateService.objects.create(
response = self.client.delete(reverse("exchange_rate_service_delete", args=[service.id])) name="Deletable Service",
service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE,
fetch_interval="1",
)
response = self.client.delete(
reverse("automatic_exchange_rate_delete", args=[service.id])
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertFalse(ExchangeRateService.objects.filter(id=service.id).exists()) self.assertFalse(ExchangeRateService.objects.filter(id=service.id).exists())
@patch('apps.currencies.tasks.manual_fetch_exchange_rates.defer') @patch("apps.currencies.tasks.manual_fetch_exchange_rates.defer")
def test_exchange_rate_service_force_fetch_view(self, mock_defer): def test_exchange_rate_service_force_fetch_view(self, mock_defer):
response = self.client.get(reverse("exchange_rate_service_force_fetch")) response = self.client.get(reverse("automatic_exchange_rate_force_fetch"))
self.assertEqual(response.status_code, 204) # Triggers toast self.assertEqual(response.status_code, 204) # Triggers toast
mock_defer.assert_called_once() mock_defer.assert_called_once()

View File

@@ -11,33 +11,60 @@ from django.utils import timezone
from apps.accounts.models import Account, AccountGroup from apps.accounts.models import Account, AccountGroup
from apps.currencies.models import Currency from apps.currencies.models import Currency
from apps.transactions.models import Transaction, TransactionCategory, TransactionTag, TransactionEntity from apps.transactions.models import (
from apps.export_app.resources.transactions import TransactionResource, TransactionTagResource Transaction,
TransactionCategory,
TransactionTag,
TransactionEntity,
)
from apps.export_app.resources.transactions import (
TransactionResource,
TransactionTagResource,
)
from apps.export_app.resources.accounts import AccountResource from apps.export_app.resources.accounts import AccountResource
from apps.export_app.forms import ExportForm, RestoreForm # Added RestoreForm from apps.export_app.forms import ExportForm, RestoreForm # Added RestoreForm
User = get_user_model() User = get_user_model()
class BaseExportAppTest(TestCase): class BaseExportAppTest(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
cls.superuser = User.objects.create_superuser( cls.superuser = User.objects.create_superuser(
email="exportadmin@example.com", password="password" email="exportadmin@example.com", password="password"
) )
cls.currency_usd = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2) cls.currency_usd = Currency.objects.create(
cls.currency_eur = Currency.objects.create(code="EUR", name="Euro", decimal_places=2) code="USD", name="US Dollar", decimal_places=2
)
cls.currency_eur = Currency.objects.create(
code="EUR", name="Euro", decimal_places=2
)
cls.user_group = AccountGroup.objects.create(name="User Group", owner=cls.superuser) cls.user_group = AccountGroup.objects.create(
name="User Group", owner=cls.superuser
)
cls.account_usd = Account.objects.create( cls.account_usd = Account.objects.create(
name="Checking USD", currency=cls.currency_usd, owner=cls.superuser, group=cls.user_group name="Checking USD",
currency=cls.currency_usd,
owner=cls.superuser,
group=cls.user_group,
) )
cls.account_eur = Account.objects.create( cls.account_eur = Account.objects.create(
name="Savings EUR", currency=cls.currency_eur, owner=cls.superuser, group=cls.user_group name="Savings EUR",
currency=cls.currency_eur,
owner=cls.superuser,
group=cls.user_group,
) )
cls.category_food = TransactionCategory.objects.create(name="Food", owner=cls.superuser) cls.category_food = TransactionCategory.objects.create(
cls.tag_urgent = TransactionTag.objects.create(name="Urgent", owner=cls.superuser) name="Food", owner=cls.superuser
cls.entity_store = TransactionEntity.objects.create(name="SuperStore", owner=cls.superuser) )
cls.tag_urgent = TransactionTag.objects.create(
name="Urgent", owner=cls.superuser
)
cls.entity_store = TransactionEntity.objects.create(
name="SuperStore", owner=cls.superuser
)
cls.transaction1 = Transaction.objects.create( cls.transaction1 = Transaction.objects.create(
account=cls.account_usd, account=cls.account_usd,
@@ -48,7 +75,7 @@ class BaseExportAppTest(TestCase):
amount=Decimal("50.00"), amount=Decimal("50.00"),
description="Groceries", description="Groceries",
category=cls.category_food, category=cls.category_food,
is_paid=True is_paid=True,
) )
cls.transaction1.tags.add(cls.tag_urgent) cls.transaction1.tags.add(cls.tag_urgent)
cls.transaction1.entities.add(cls.entity_store) cls.transaction1.entities.add(cls.entity_store)
@@ -61,7 +88,7 @@ class BaseExportAppTest(TestCase):
reference_date=date(2023, 1, 1), reference_date=date(2023, 1, 1),
amount=Decimal("1200.00"), amount=Decimal("1200.00"),
description="Salary", description="Salary",
is_paid=True is_paid=True,
) )
def setUp(self): def setUp(self):
@@ -72,7 +99,9 @@ class BaseExportAppTest(TestCase):
class ResourceExportTests(BaseExportAppTest): class ResourceExportTests(BaseExportAppTest):
def test_transaction_resource_export(self): def test_transaction_resource_export(self):
resource = TransactionResource() resource = TransactionResource()
queryset = Transaction.objects.filter(owner=self.superuser).order_by('pk') # Ensure consistent order queryset = Transaction.objects.filter(owner=self.superuser).order_by(
"pk"
) # Ensure consistent order
dataset = resource.export(queryset=queryset) dataset = resource.export(queryset=queryset)
self.assertEqual(len(dataset), 2) self.assertEqual(len(dataset), 2)
@@ -90,14 +119,17 @@ class ResourceExportTests(BaseExportAppTest):
self.assertEqual(exported_row1_dict["description"], "Groceries") self.assertEqual(exported_row1_dict["description"], "Groceries")
self.assertEqual(exported_row1_dict["category"], self.category_food.name) self.assertEqual(exported_row1_dict["category"], self.category_food.name)
# M2M fields order might vary, so check for presence # M2M fields order might vary, so check for presence
self.assertIn(self.tag_urgent.name, exported_row1_dict["tags"].split(',')) self.assertIn(self.tag_urgent.name, exported_row1_dict["tags"].split(","))
self.assertIn(self.entity_store.name, exported_row1_dict["entities"].split(',')) self.assertIn(self.entity_store.name, exported_row1_dict["entities"].split(","))
self.assertEqual(Decimal(exported_row1_dict["amount"]), self.transaction1.amount) self.assertEqual(
Decimal(exported_row1_dict["amount"]), self.transaction1.amount
)
def test_account_resource_export(self): def test_account_resource_export(self):
resource = AccountResource() resource = AccountResource()
queryset = Account.objects.filter(owner=self.superuser).order_by('name') # Ensure consistent order queryset = Account.objects.filter(owner=self.superuser).order_by(
"name"
) # Ensure consistent order
dataset = resource.export(queryset=queryset) dataset = resource.export(queryset=queryset)
self.assertEqual(len(dataset), 2) self.assertEqual(len(dataset), 2)
@@ -125,9 +157,13 @@ class ExportViewTests(BaseExportAppTest):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "text/csv") self.assertEqual(response["Content-Type"], "text/csv")
self.assertTrue(response["Content-Disposition"].endswith("_WYGIWYH_export_transactions.csv\"")) self.assertTrue(
response["Content-Disposition"].endswith(
'_WYGIWYH_export_transactions.csv"'
)
)
content = response.content.decode('utf-8') content = response.content.decode("utf-8")
reader = csv.reader(io.StringIO(content)) reader = csv.reader(io.StringIO(content))
headers = next(reader) headers = next(reader)
self.assertIn("id", headers) self.assertIn("id", headers)
@@ -136,7 +172,6 @@ class ExportViewTests(BaseExportAppTest):
self.assertIn(self.transaction1.description, content) self.assertIn(self.transaction1.description, content)
self.assertIn(self.transaction2.description, content) self.assertIn(self.transaction2.description, content)
def test_export_multiple_to_zip(self): def test_export_multiple_to_zip(self):
data = { data = {
"transactions": "on", "transactions": "on",
@@ -146,7 +181,9 @@ class ExportViewTests(BaseExportAppTest):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "application/zip") self.assertEqual(response["Content-Type"], "application/zip")
self.assertTrue(response["Content-Disposition"].endswith("_WYGIWYH_export.zip\"")) self.assertTrue(
response["Content-Disposition"].endswith('_WYGIWYH_export.zip"')
)
zip_buffer = io.BytesIO(response.content) zip_buffer = io.BytesIO(response.content)
with zipfile.ZipFile(zip_buffer, "r") as zf: with zipfile.ZipFile(zip_buffer, "r") as zf:
@@ -155,19 +192,22 @@ class ExportViewTests(BaseExportAppTest):
self.assertIn("accounts.csv", filenames) self.assertIn("accounts.csv", filenames)
with zf.open("transactions.csv") as csv_file: with zf.open("transactions.csv") as csv_file:
content = csv_file.read().decode('utf-8') content = csv_file.read().decode("utf-8")
self.assertIn("id,type,date", content) self.assertIn("id,type,date", content)
self.assertIn(self.transaction1.description, content) self.assertIn(self.transaction1.description, content)
def test_export_no_selection(self): def test_export_no_selection(self):
data = {} data = {}
response = self.client.post(reverse("export_form"), data) response = self.client.post(reverse("export_form"), data)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertIn("You have to select at least one export", response.content.decode()) self.assertIn(
"You have to select at least one export", response.content.decode()
)
def test_export_access_non_superuser(self): def test_export_access_non_superuser(self):
normal_user = User.objects.create_user(email="normal@example.com", password="password") normal_user = User.objects.create_user(
email="normal@example.com", password="password"
)
self.client.logout() self.client.logout()
self.client.login(email="normal@example.com", password="password") self.client.login(email="normal@example.com", password="password")
@@ -201,4 +241,3 @@ class RestoreViewTests(BaseExportAppTest):
# mock_process_imports.assert_called_once() # mock_process_imports.assert_called_once()
# # Further checks on how mock_process_imports was called could be added here. # # Further checks on how mock_process_imports was called could be added here.
pass pass
```

View File

@@ -13,54 +13,88 @@ from django.urls import reverse
from apps.import_app.models import ImportProfile, ImportRun from apps.import_app.models import ImportProfile, ImportRun
from apps.import_app.services.v1 import ImportService from apps.import_app.services.v1 import ImportService
from apps.import_app.schemas.v1 import ImportProfileSchema, CSVImportSettings, ColumnMapping, TransactionDateMapping, TransactionAmountMapping, TransactionDescriptionMapping, TransactionAccountMapping from apps.import_app.schemas.v1 import (
ImportProfileSchema,
CSVImportSettings,
ColumnMapping,
TransactionDateMapping,
TransactionAmountMapping,
TransactionDescriptionMapping,
TransactionAccountMapping,
)
from apps.accounts.models import Account from apps.accounts.models import Account
from apps.currencies.models import Currency from apps.currencies.models import Currency
from apps.transactions.models import Transaction, TransactionCategory, TransactionTag, TransactionEntity from apps.transactions.models import (
Transaction,
TransactionCategory,
TransactionTag,
TransactionEntity,
)
# Mocking get_current_user from thread_local # Mocking get_current_user from thread_local
from apps.common.middleware.thread_local import get_current_user, set_current_user from apps.common.middleware.thread_local import get_current_user, write_current_user
User = get_user_model() User = get_user_model()
# --- Base Test Case --- # --- Base Test Case ---
class BaseImportAppTest(TestCase): class BaseImportAppTest(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user(email="importer@example.com", password="password") self.user = User.objects.create_user(
set_current_user(self.user) # For services that rely on get_current_user email="importer@example.com", password="password"
)
write_current_user(self.user) # For services that rely on get_current_user
self.client = Client() self.client = Client()
self.client.login(email="importer@example.com", password="password") self.client.login(email="importer@example.com", password="password")
self.currency_usd = Currency.objects.create(code="USD", name="US Dollar") self.currency_usd = Currency.objects.create(code="USD", name="US Dollar")
self.account_usd = Account.objects.create(name="Checking USD", currency=self.currency_usd, owner=self.user) self.account_usd = Account.objects.create(
name="Checking USD", currency=self.currency_usd, owner=self.user
)
def tearDown(self): def tearDown(self):
set_current_user(None) write_current_user(None)
def _create_valid_transaction_import_profile_yaml(self, extra_settings=None, extra_mappings=None): def _create_valid_transaction_import_profile_yaml(
self, extra_settings=None, extra_mappings=None
):
settings_dict = { settings_dict = {
"file_type": "csv", "file_type": "csv",
"delimiter": ",", "delimiter": ",",
"skip_lines": 0, "skip_lines": 0,
"importing": "transactions", "importing": "transactions",
"trigger_transaction_rules": False, "trigger_transaction_rules": False,
**(extra_settings or {}) **(extra_settings or {}),
} }
mappings_dict = { mappings_dict = {
"col_date": {"target": "date", "source": "DateColumn", "format": "%Y-%m-%d"}, "col_date": {
"target": "date",
"source": "DateColumn",
"format": "%Y-%m-%d",
},
"col_amount": {"target": "amount", "source": "AmountColumn"}, "col_amount": {"target": "amount", "source": "AmountColumn"},
"col_desc": {"target": "description", "source": "DescriptionColumn"}, "col_desc": {"target": "description", "source": "DescriptionColumn"},
"col_acc": {"target": "account", "source": "AccountNameColumn", "type": "name"}, "col_acc": {
**(extra_mappings or {}) "target": "account",
"source": "AccountNameColumn",
"type": "name",
},
**(extra_mappings or {}),
} }
return yaml.dump({"settings": settings_dict, "mapping": mappings_dict}) return yaml.dump({"settings": settings_dict, "mapping": mappings_dict})
# --- Model Tests --- # --- Model Tests ---
class ImportProfileModelTests(BaseImportAppTest): class ImportProfileModelTests(BaseImportAppTest):
def test_import_profile_valid_yaml_clean(self): def test_import_profile_valid_yaml_clean(self):
valid_yaml = self._create_valid_transaction_import_profile_yaml() valid_yaml = self._create_valid_transaction_import_profile_yaml()
profile = ImportProfile(name="Test Valid Profile", yaml_config=valid_yaml, version=ImportProfile.Versions.VERSION_1) profile = ImportProfile(
name="Test Valid Profile",
yaml_config=valid_yaml,
version=ImportProfile.Versions.VERSION_1,
)
try: try:
profile.full_clean() # Should not raise ValidationError profile.full_clean() # Should not raise ValidationError
except ValidationError as e: except ValidationError as e:
@@ -77,13 +111,20 @@ settings:
mapping: mapping:
col_date: {target: date, source: Date, format: "%Y-%m-%d"} col_date: {target: date, source: Date, format: "%Y-%m-%d"}
""" """
profile = ImportProfile(name="Test Invalid Profile", yaml_config=invalid_yaml, version=ImportProfile.Versions.VERSION_1) profile = ImportProfile(
name="Test Invalid Profile",
yaml_config=invalid_yaml,
version=ImportProfile.Versions.VERSION_1,
)
with self.assertRaises(ValidationError) as context: with self.assertRaises(ValidationError) as context:
profile.full_clean() profile.full_clean()
self.assertIn("yaml_config", context.exception.message_dict) self.assertIn("yaml_config", context.exception.message_dict)
self.assertTrue("Input should be a valid string" in str(context.exception.message_dict["yaml_config"]) or \ self.assertTrue(
"Input should be a valid integer" in str(context.exception.message_dict["yaml_config"])) "Input should be a valid string"
in str(context.exception.message_dict["yaml_config"])
or "Input should be a valid integer"
in str(context.exception.message_dict["yaml_config"])
)
def test_import_profile_invalid_mapping_for_import_type(self): def test_import_profile_invalid_mapping_for_import_type(self):
invalid_yaml = """ invalid_yaml = """
@@ -93,11 +134,18 @@ settings:
mapping: mapping:
some_col: {target: account_name, source: SomeColumn} some_col: {target: account_name, source: SomeColumn}
""" """
profile = ImportProfile(name="Invalid Mapping Type", yaml_config=invalid_yaml, version=ImportProfile.Versions.VERSION_1) profile = ImportProfile(
name="Invalid Mapping Type",
yaml_config=invalid_yaml,
version=ImportProfile.Versions.VERSION_1,
)
with self.assertRaises(ValidationError) as context: with self.assertRaises(ValidationError) as context:
profile.full_clean() profile.full_clean()
self.assertIn("yaml_config", context.exception.message_dict) self.assertIn("yaml_config", context.exception.message_dict)
self.assertIn("Mapping type 'AccountNameMapping' is not allowed when importing tags", str(context.exception.message_dict["yaml_config"])) self.assertIn(
"Mapping type 'AccountNameMapping' is not allowed when importing tags",
str(context.exception.message_dict["yaml_config"]),
)
# --- Service Tests (Focus on ImportService v1) --- # --- Service Tests (Focus on ImportService v1) ---
@@ -105,8 +153,12 @@ class ImportServiceV1LogicTests(BaseImportAppTest):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.basic_yaml_config = self._create_valid_transaction_import_profile_yaml() self.basic_yaml_config = self._create_valid_transaction_import_profile_yaml()
self.profile = ImportProfile.objects.create(name="Service Test Profile", yaml_config=self.basic_yaml_config) self.profile = ImportProfile.objects.create(
self.import_run = ImportRun.objects.create(profile=self.profile, file_name="test.csv") name="Service Test Profile", yaml_config=self.basic_yaml_config
)
self.import_run = ImportRun.objects.create(
profile=self.profile, file_name="test.csv"
)
def get_service(self): def get_service(self):
self.import_run.logs = "" self.import_run.logs = ""
@@ -116,41 +168,77 @@ class ImportServiceV1LogicTests(BaseImportAppTest):
def test_transform_value_replace(self): def test_transform_value_replace(self):
service = self.get_service() service = self.get_service()
mapping_def = {"type": "replace", "pattern": "USD", "replacement": "EUR"} mapping_def = {"type": "replace", "pattern": "USD", "replacement": "EUR"}
mapping = ColumnMapping(source="col", target="field", transformations=[mapping_def]) mapping = ColumnMapping(
self.assertEqual(service._transform_value("Amount USD", mapping, row={"col":"Amount USD"}), "Amount EUR") source="col", target="field", transformations=[mapping_def]
)
self.assertEqual(
service._transform_value("Amount USD", mapping, row={"col": "Amount USD"}),
"Amount EUR",
)
def test_transform_value_regex(self): def test_transform_value_regex(self):
service = self.get_service() service = self.get_service()
mapping_def = {"type": "regex", "pattern": r"\d+", "replacement": "NUM"} mapping_def = {"type": "regex", "pattern": r"\d+", "replacement": "NUM"}
mapping = ColumnMapping(source="col", target="field", transformations=[mapping_def]) mapping = ColumnMapping(
self.assertEqual(service._transform_value("abc123xyz", mapping, row={"col":"abc123xyz"}), "abcNUMxyz") source="col", target="field", transformations=[mapping_def]
)
self.assertEqual(
service._transform_value("abc123xyz", mapping, row={"col": "abc123xyz"}),
"abcNUMxyz",
)
def test_transform_value_date_format(self): def test_transform_value_date_format(self):
service = self.get_service() service = self.get_service()
mapping_def = {"type": "date_format", "original_format": "%d/%m/%Y", "new_format": "%Y-%m-%d"} mapping_def = {
mapping = ColumnMapping(source="col", target="field", transformations=[mapping_def]) "type": "date_format",
self.assertEqual(service._transform_value("15/10/2023", mapping, row={"col":"15/10/2023"}), "2023-10-15") "original_format": "%d/%m/%Y",
"new_format": "%Y-%m-%d",
}
mapping = ColumnMapping(
source="col", target="field", transformations=[mapping_def]
)
self.assertEqual(
service._transform_value("15/10/2023", mapping, row={"col": "15/10/2023"}),
"2023-10-15",
)
def test_transform_value_merge(self): def test_transform_value_merge(self):
service = self.get_service() service = self.get_service()
mapping_def = {"type": "merge", "fields": ["colA", "colB"], "separator": "-"} mapping_def = {"type": "merge", "fields": ["colA", "colB"], "separator": "-"}
mapping = ColumnMapping(source="colA", target="field", transformations=[mapping_def]) mapping = ColumnMapping(
source="colA", target="field", transformations=[mapping_def]
)
row_data = {"colA": "ValA", "colB": "ValB"} row_data = {"colA": "ValA", "colB": "ValB"}
self.assertEqual(service._transform_value(row_data["colA"], mapping, row_data), "ValA-ValB") self.assertEqual(
service._transform_value(row_data["colA"], mapping, row_data), "ValA-ValB"
)
def test_transform_value_split(self): def test_transform_value_split(self):
service = self.get_service() service = self.get_service()
mapping_def = {"type": "split", "separator": "|", "index": 1} mapping_def = {"type": "split", "separator": "|", "index": 1}
mapping = ColumnMapping(source="col", target="field", transformations=[mapping_def]) mapping = ColumnMapping(
self.assertEqual(service._transform_value("partA|partB|partC", mapping, row={"col":"partA|partB|partC"}), "partB") source="col", target="field", transformations=[mapping_def]
)
self.assertEqual(
service._transform_value(
"partA|partB|partC", mapping, row={"col": "partA|partB|partC"}
),
"partB",
)
def test_coerce_type_date(self): def test_coerce_type_date(self):
service = self.get_service() service = self.get_service()
mapping = TransactionDateMapping(source="col", target="date", format="%Y-%m-%d") mapping = TransactionDateMapping(source="col", target="date", format="%Y-%m-%d")
self.assertEqual(service._coerce_type("2023-11-21", mapping), date(2023, 11, 21)) self.assertEqual(
service._coerce_type("2023-11-21", mapping), date(2023, 11, 21)
)
mapping_multi_format = TransactionDateMapping(source="col", target="date", format=["%d/%m/%Y", "%Y-%m-%d"]) mapping_multi_format = TransactionDateMapping(
self.assertEqual(service._coerce_type("21/11/2023", mapping_multi_format), date(2023, 11, 21)) source="col", target="date", format=["%d/%m/%Y", "%Y-%m-%d"]
)
self.assertEqual(
service._coerce_type("21/11/2023", mapping_multi_format), date(2023, 11, 21)
)
def test_coerce_type_decimal(self): def test_coerce_type_decimal(self):
service = self.get_service() service = self.get_service()
@@ -168,8 +256,13 @@ class ImportServiceV1LogicTests(BaseImportAppTest):
def test_map_row_simple(self): def test_map_row_simple(self):
service = self.get_service() service = self.get_service()
row = {"DateColumn": "2023-01-15", "AmountColumn": "100.50", "DescriptionColumn": "Lunch", "AccountNameColumn": "Checking USD"} row = {
with patch.object(Account.objects, 'filter') as mock_filter: "DateColumn": "2023-01-15",
"AmountColumn": "100.50",
"DescriptionColumn": "Lunch",
"AccountNameColumn": "Checking USD",
}
with patch.object(Account.objects, "filter") as mock_filter:
mock_filter.return_value.first.return_value = self.account_usd mock_filter.return_value.first.return_value = self.account_usd
mapped = service._map_row(row) mapped = service._map_row(row)
self.assertEqual(mapped["date"], date(2023, 1, 15)) self.assertEqual(mapped["date"], date(2023, 1, 15))
@@ -178,46 +271,82 @@ class ImportServiceV1LogicTests(BaseImportAppTest):
self.assertEqual(mapped["account"], self.account_usd) self.assertEqual(mapped["account"], self.account_usd)
def test_check_duplicate_transaction_strict(self): def test_check_duplicate_transaction_strict(self):
dedup_yaml = yaml.dump({ dedup_yaml = yaml.dump(
"settings": {"file_type": "csv", "importing": "transactions"}, {
"mapping": { "settings": {"file_type": "csv", "importing": "transactions"},
"col_date": {"target": "date", "source": "Date", "format": "%Y-%m-%d"}, "mapping": {
"col_amount": {"target": "amount", "source": "Amount"}, "col_date": {
"col_desc": {"target": "description", "source": "Desc"}, "target": "date",
"col_acc": {"target": "account", "source": "Acc", "type": "name"}, "source": "Date",
}, "format": "%Y-%m-%d",
"deduplication": [{"type": "compare", "fields": ["date", "amount", "description", "account"], "match_type": "strict"}] },
}) "col_amount": {"target": "amount", "source": "Amount"},
profile = ImportProfile.objects.create(name="Dedupe Profile Strict", yaml_config=dedup_yaml) "col_desc": {"target": "description", "source": "Desc"},
"col_acc": {"target": "account", "source": "Acc", "type": "name"},
},
"deduplication": [
{
"type": "compare",
"fields": ["date", "amount", "description", "account"],
"match_type": "strict",
}
],
}
)
profile = ImportProfile.objects.create(
name="Dedupe Profile Strict", yaml_config=dedup_yaml
)
import_run = ImportRun.objects.create(profile=profile, file_name="dedupe.csv") import_run = ImportRun.objects.create(profile=profile, file_name="dedupe.csv")
service = ImportService(import_run) service = ImportService(import_run)
Transaction.objects.create( Transaction.objects.create(
owner=self.user, account=self.account_usd, date=date(2023,1,1), amount=Decimal("10.00"), description="Coffee" owner=self.user,
account=self.account_usd,
date=date(2023, 1, 1),
amount=Decimal("10.00"),
description="Coffee",
) )
dup_data = {"owner": self.user, "account": self.account_usd, "date": date(2023,1,1), "amount": Decimal("10.00"), "description": "Coffee"} dup_data = {
"owner": self.user,
"account": self.account_usd,
"date": date(2023, 1, 1),
"amount": Decimal("10.00"),
"description": "Coffee",
}
self.assertTrue(service._check_duplicate_transaction(dup_data)) self.assertTrue(service._check_duplicate_transaction(dup_data))
not_dup_data = {"owner": self.user, "account": self.account_usd, "date": date(2023,1,1), "amount": Decimal("10.00"), "description": "Tea"} not_dup_data = {
"owner": self.user,
"account": self.account_usd,
"date": date(2023, 1, 1),
"amount": Decimal("10.00"),
"description": "Tea",
}
self.assertFalse(service._check_duplicate_transaction(not_dup_data)) self.assertFalse(service._check_duplicate_transaction(not_dup_data))
class ImportServiceFileProcessingTests(BaseImportAppTest): class ImportServiceFileProcessingTests(BaseImportAppTest):
@patch('apps.import_app.tasks.process_import.defer') @patch("apps.import_app.tasks.process_import.defer")
def test_process_csv_file_basic_transaction_import(self, mock_defer): def test_process_csv_file_basic_transaction_import(self, mock_defer):
csv_content = "DateColumn,AmountColumn,DescriptionColumn,AccountNameColumn\n2023-03-10,123.45,Test CSV Import 1,Checking USD\n2023-03-11,67.89,Test CSV Import 2,Checking USD" csv_content = "DateColumn,AmountColumn,DescriptionColumn,AccountNameColumn\n2023-03-10,123.45,Test CSV Import 1,Checking USD\n2023-03-11,67.89,Test CSV Import 2,Checking USD"
profile_yaml = self._create_valid_transaction_import_profile_yaml() profile_yaml = self._create_valid_transaction_import_profile_yaml()
profile = ImportProfile.objects.create(name="CSV Test Profile", yaml_config=profile_yaml) profile = ImportProfile.objects.create(
name="CSV Test Profile", yaml_config=profile_yaml
)
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv", dir=ImportService.TEMP_DIR) as tmp_file: with tempfile.NamedTemporaryFile(
mode="w+", delete=False, suffix=".csv", dir=ImportService.TEMP_DIR
) as tmp_file:
tmp_file.write(csv_content) tmp_file.write(csv_content)
tmp_file_path = tmp_file.name tmp_file_path = tmp_file.name
import_run = ImportRun.objects.create(profile=profile, file_name=os.path.basename(tmp_file_path)) import_run = ImportRun.objects.create(
profile=profile, file_name=os.path.basename(tmp_file_path)
)
service = ImportService(import_run) service = ImportService(import_run)
with patch.object(Account.objects, 'filter') as mock_account_filter: with patch.object(Account.objects, "filter") as mock_account_filter:
mock_account_filter.return_value.first.return_value = self.account_usd mock_account_filter.return_value.first.return_value = self.account_usd
service.process_file(tmp_file_path) service.process_file(tmp_file_path)
@@ -234,9 +363,13 @@ class ImportServiceFileProcessingTests(BaseImportAppTest):
if os.path.exists(tmp_file_path): if os.path.exists(tmp_file_path):
os.remove(tmp_file_path) os.remove(tmp_file_path)
class ImportViewTests(BaseImportAppTest): class ImportViewTests(BaseImportAppTest):
def test_import_profile_list_view(self): def test_import_profile_list_view(self):
ImportProfile.objects.create(name="Profile 1", yaml_config=self._create_valid_transaction_import_profile_yaml()) ImportProfile.objects.create(
name="Profile 1",
yaml_config=self._create_valid_transaction_import_profile_yaml(),
)
response = self.client.get(reverse("import_profile_list")) response = self.client.get(reverse("import_profile_list"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, "Profile 1") self.assertContains(response, "Profile 1")
@@ -244,18 +377,29 @@ class ImportViewTests(BaseImportAppTest):
def test_import_profile_add_view_get(self): def test_import_profile_add_view_get(self):
response = self.client.get(reverse("import_profile_add")) response = self.client.get(reverse("import_profile_add"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertIsInstance(response.context['form'], ImportProfileForm) self.assertIsInstance(response.context["form"], ImportProfileForm)
@patch('apps.import_app.tasks.process_import.defer') @patch("apps.import_app.tasks.process_import.defer")
def test_import_run_add_view_post_valid_file(self, mock_defer): def test_import_run_add_view_post_valid_file(self, mock_defer):
profile = ImportProfile.objects.create(name="Upload Profile", yaml_config=self._create_valid_transaction_import_profile_yaml()) profile = ImportProfile.objects.create(
name="Upload Profile",
yaml_config=self._create_valid_transaction_import_profile_yaml(),
)
csv_content = "DateColumn,AmountColumn,DescriptionColumn,AccountNameColumn\n2023-01-01,10.00,Test Upload,Checking USD" csv_content = "DateColumn,AmountColumn,DescriptionColumn,AccountNameColumn\n2023-01-01,10.00,Test Upload,Checking USD"
uploaded_file = SimpleUploadedFile("test_upload.csv", csv_content.encode('utf-8'), content_type="text/csv") uploaded_file = SimpleUploadedFile(
"test_upload.csv", csv_content.encode("utf-8"), content_type="text/csv"
)
response = self.client.post(reverse("import_run_add", args=[profile.id]), {"file": uploaded_file}) response = self.client.post(
reverse("import_run_add", args=[profile.id]), {"file": uploaded_file}
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertTrue(ImportRun.objects.filter(profile=profile, file_name__contains="test_upload.csv").exists()) self.assertTrue(
ImportRun.objects.filter(
profile=profile, file_name__contains="test_upload.csv"
).exists()
)
mock_defer.assert_called_once() mock_defer.assert_called_once()
args_list = mock_defer.call_args_list[0] args_list = mock_defer.call_args_list[0]
kwargs_passed = args_list.kwargs kwargs_passed = args_list.kwargs
@@ -263,9 +407,17 @@ class ImportViewTests(BaseImportAppTest):
self.assertIn("file_path", kwargs_passed) self.assertIn("file_path", kwargs_passed)
self.assertEqual(kwargs_passed["user_id"], self.user.id) self.assertEqual(kwargs_passed["user_id"], self.user.id)
run = ImportRun.objects.get(profile=profile, file_name__contains="test_upload.csv") run = ImportRun.objects.get(
temp_file_path_in_storage = os.path.join(ImportService.TEMP_DIR, run.file_name) # Ensure correct path construction profile=profile, file_name__contains="test_upload.csv"
if os.path.exists(temp_file_path_in_storage): # Check existence before removing )
os.remove(temp_file_path_in_storage) temp_file_path_in_storage = os.path.join(
elif os.path.exists(os.path.join(ImportService.TEMP_DIR, os.path.basename(run.file_name))): # Fallback for just basename ImportService.TEMP_DIR, run.file_name
os.remove(os.path.join(ImportService.TEMP_DIR, os.path.basename(run.file_name))) ) # Ensure correct path construction
if os.path.exists(temp_file_path_in_storage): # Check existence before removing
os.remove(temp_file_path_in_storage)
elif os.path.exists(
os.path.join(ImportService.TEMP_DIR, os.path.basename(run.file_name))
): # Fallback for just basename
os.remove(
os.path.join(ImportService.TEMP_DIR, os.path.basename(run.file_name))
)

View File

@@ -1,14 +1,15 @@
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from collections import OrderedDict from collections import OrderedDict
import json # Added for view tests import json # Added for view tests
from django.db.models import Q
from django.test import TestCase, Client from django.test import TestCase, Client
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.utils import timezone from django.utils import timezone
from django.template.defaultfilters import date as date_filter from django.template.defaultfilters import date as date_filter
from django.urls import reverse # Added for view tests from django.urls import reverse # Added for view tests
from dateutil.relativedelta import relativedelta # Added for date calculations from dateutil.relativedelta import relativedelta # Added for date calculations
from apps.currencies.models import Currency from apps.currencies.models import Currency
from apps.accounts.models import Account, AccountGroup from apps.accounts.models import Account, AccountGroup
@@ -17,44 +18,68 @@ from apps.net_worth.utils.calculate_net_worth import (
calculate_historical_currency_net_worth, calculate_historical_currency_net_worth,
calculate_historical_account_balance, calculate_historical_account_balance,
) )
# Mocking get_current_user from thread_local # Mocking get_current_user from thread_local
from apps.common.middleware.thread_local import get_current_user, set_current_user from apps.common.middleware.thread_local import get_current_user, write_current_user
from apps.common.models import SharedObject
User = get_user_model() User = get_user_model()
class BaseNetWorthTest(TestCase): class BaseNetWorthTest(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user(email="networthuser@example.com", password="password") self.user = User.objects.create_user(
self.other_user = User.objects.create_user(email="othernetworth@example.com", password="password") email="networthuser@example.com", password="password"
)
self.other_user = User.objects.create_user(
email="othernetworth@example.com", password="password"
)
# Set current user for thread_local middleware # Set current user for thread_local middleware
set_current_user(self.user) write_current_user(self.user)
self.client = Client() self.client = Client()
self.client.login(email="networthuser@example.com", password="password") self.client.login(email="networthuser@example.com", password="password")
self.currency_usd = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2) self.currency_usd = Currency.objects.create(
self.currency_eur = Currency.objects.create(code="EUR", name="Euro", decimal_places=2) code="USD", name="US Dollar", decimal_places=2
)
self.currency_eur = Currency.objects.create(
code="EUR", name="Euro", decimal_places=2
)
self.account_group_main = AccountGroup.objects.create(name="Main Group", owner=self.user) self.account_group_main = AccountGroup.objects.create(
name="Main Group", owner=self.user
)
self.account_usd_1 = Account.objects.create( self.account_usd_1 = Account.objects.create(
name="USD Account 1", currency=self.currency_usd, owner=self.user, group=self.account_group_main name="USD Account 1",
currency=self.currency_usd,
owner=self.user,
group=self.account_group_main,
) )
self.account_usd_2 = Account.objects.create( self.account_usd_2 = Account.objects.create(
name="USD Account 2", currency=self.currency_usd, owner=self.user, group=self.account_group_main name="USD Account 2",
currency=self.currency_usd,
owner=self.user,
group=self.account_group_main,
) )
self.account_eur_1 = Account.objects.create( self.account_eur_1 = Account.objects.create(
name="EUR Account 1", currency=self.currency_eur, owner=self.user, group=self.account_group_main name="EUR Account 1",
currency=self.currency_eur,
owner=self.user,
group=self.account_group_main,
) )
# Public account for visibility tests # Public account for visibility tests
self.account_public_usd = Account.objects.create( self.account_public_usd = Account.objects.create(
name="Public USD Account", currency=self.currency_usd, visibility=Account.Visibility.PUBLIC name="Public USD Account",
currency=self.currency_usd,
visibility=SharedObject.Visibility.public,
) )
def tearDown(self): def tearDown(self):
# Clear current user # Clear current user
set_current_user(None) write_current_user(None)
class CalculateNetWorthUtilsTests(BaseNetWorthTest): class CalculateNetWorthUtilsTests(BaseNetWorthTest):
@@ -63,32 +88,66 @@ class CalculateNetWorthUtilsTests(BaseNetWorthTest):
result = calculate_historical_currency_net_worth(qs) result = calculate_historical_currency_net_worth(qs)
current_month_str = date_filter(timezone.localdate(timezone.now()), "b Y") current_month_str = date_filter(timezone.localdate(timezone.now()), "b Y")
next_month_str = date_filter(timezone.localdate(timezone.now()) + relativedelta(months=1), "b Y") next_month_str = date_filter(
timezone.localdate(timezone.now()) + relativedelta(months=1), "b Y"
)
self.assertIn(current_month_str, result) self.assertIn(current_month_str, result)
self.assertIn(next_month_str, result) self.assertIn(next_month_str, result)
expected_currencies_present = {"US Dollar", "Euro"} # Based on created accounts for self.user expected_currencies_present = {
"US Dollar",
"Euro",
} # Based on created accounts for self.user
actual_currencies_in_result = set() actual_currencies_in_result = set()
if result and result[current_month_str]: # Check if current_month_str key exists and has data if (
result and result[current_month_str]
): # Check if current_month_str key exists and has data
actual_currencies_in_result = set(result[current_month_str].keys()) actual_currencies_in_result = set(result[current_month_str].keys())
self.assertTrue(expected_currencies_present.issubset(actual_currencies_in_result) or not result[current_month_str]) self.assertTrue(
expected_currencies_present.issubset(actual_currencies_in_result)
or not result[current_month_str]
)
def test_calculate_historical_currency_net_worth_single_currency(self): def test_calculate_historical_currency_net_worth_single_currency(self):
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("1000"), date=datetime.date(2023, 10, 5), reference_date=datetime.date(2023,10,1), is_paid=True) Transaction.objects.create(
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.EXPENSE, amount=Decimal("200"), date=datetime.date(2023, 10, 15), reference_date=datetime.date(2023,10,1), is_paid=True) account=self.account_usd_1,
Transaction.objects.create(account=self.account_usd_2, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("300"), date=datetime.date(2023, 11, 5), reference_date=datetime.date(2023,11,1), is_paid=True) owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("1000"),
date=datetime.date(2023, 10, 5),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_usd_1,
owner=self.user,
type=Transaction.Type.EXPENSE,
amount=Decimal("200"),
date=datetime.date(2023, 10, 15),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_usd_2,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("300"),
date=datetime.date(2023, 11, 5),
reference_date=datetime.date(2023, 11, 1),
is_paid=True,
)
qs = Transaction.objects.filter(owner=self.user, account__currency=self.currency_usd) qs = Transaction.objects.filter(
owner=self.user, account__currency=self.currency_usd
)
result = calculate_historical_currency_net_worth(qs) result = calculate_historical_currency_net_worth(qs)
oct_str = date_filter(datetime.date(2023, 10, 1), "b Y") oct_str = date_filter(datetime.date(2023, 10, 1), "b Y")
nov_str = date_filter(datetime.date(2023, 11, 1), "b Y") nov_str = date_filter(datetime.date(2023, 11, 1), "b Y")
dec_str = date_filter(datetime.date(2023, 12, 1), "b Y") dec_str = date_filter(datetime.date(2023, 12, 1), "b Y")
self.assertIn(oct_str, result) self.assertIn(oct_str, result)
self.assertEqual(result[oct_str]["US Dollar"], Decimal("800.00")) self.assertEqual(result[oct_str]["US Dollar"], Decimal("800.00"))
@@ -98,12 +157,43 @@ class CalculateNetWorthUtilsTests(BaseNetWorthTest):
self.assertIn(dec_str, result) self.assertIn(dec_str, result)
self.assertEqual(result[dec_str]["US Dollar"], Decimal("1100.00")) self.assertEqual(result[dec_str]["US Dollar"], Decimal("1100.00"))
def test_calculate_historical_currency_net_worth_multi_currency(self): def test_calculate_historical_currency_net_worth_multi_currency(self):
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("1000"), date=datetime.date(2023, 10, 5), reference_date=datetime.date(2023,10,1), is_paid=True) Transaction.objects.create(
Transaction.objects.create(account=self.account_eur_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("500"), date=datetime.date(2023, 10, 10), reference_date=datetime.date(2023,10,1), is_paid=True) account=self.account_usd_1,
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.EXPENSE, amount=Decimal("100"), date=datetime.date(2023, 11, 5), reference_date=datetime.date(2023,11,1), is_paid=True) owner=self.user,
Transaction.objects.create(account=self.account_eur_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("50"), date=datetime.date(2023, 11, 15), reference_date=datetime.date(2023,11,1), is_paid=True) type=Transaction.Type.INCOME,
amount=Decimal("1000"),
date=datetime.date(2023, 10, 5),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_eur_1,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("500"),
date=datetime.date(2023, 10, 10),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_usd_1,
owner=self.user,
type=Transaction.Type.EXPENSE,
amount=Decimal("100"),
date=datetime.date(2023, 11, 5),
reference_date=datetime.date(2023, 11, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_eur_1,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("50"),
date=datetime.date(2023, 11, 15),
reference_date=datetime.date(2023, 11, 1),
is_paid=True,
)
qs = Transaction.objects.filter(owner=self.user) qs = Transaction.objects.filter(owner=self.user)
result = calculate_historical_currency_net_worth(qs) result = calculate_historical_currency_net_worth(qs)
@@ -117,33 +207,81 @@ class CalculateNetWorthUtilsTests(BaseNetWorthTest):
self.assertEqual(result[nov_str]["Euro"], Decimal("550.00")) self.assertEqual(result[nov_str]["Euro"], Decimal("550.00"))
def test_calculate_historical_currency_net_worth_public_account_visibility(self): def test_calculate_historical_currency_net_worth_public_account_visibility(self):
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("100"), date=datetime.date(2023,10,1), reference_date=datetime.date(2023,10,1), is_paid=True) Transaction.objects.create(
Transaction.objects.create(account=self.account_public_usd, type=Transaction.Type.INCOME, amount=Decimal("200"), date=datetime.date(2023,10,1), reference_date=datetime.date(2023,10,1), is_paid=True) account=self.account_usd_1,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("100"),
date=datetime.date(2023, 10, 1),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_public_usd,
type=Transaction.Type.INCOME,
amount=Decimal("200"),
date=datetime.date(2023, 10, 1),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
qs = Transaction.objects.filter(Q(owner=self.user) | Q(account__visibility=Account.Visibility.PUBLIC)) qs = Transaction.objects.filter(
Q(owner=self.user) | Q(account__visibility=SharedObject.Visibility.public)
)
result = calculate_historical_currency_net_worth(qs) result = calculate_historical_currency_net_worth(qs)
oct_str = date_filter(datetime.date(2023, 10, 1), "b Y") oct_str = date_filter(datetime.date(2023, 10, 1), "b Y")
self.assertEqual(result[oct_str]["US Dollar"], Decimal("300.00")) self.assertEqual(result[oct_str]["US Dollar"], Decimal("300.00"))
def test_calculate_historical_account_balance_no_transactions(self): def test_calculate_historical_account_balance_no_transactions(self):
qs = Transaction.objects.none() qs = Transaction.objects.none()
result = calculate_historical_account_balance(qs) result = calculate_historical_account_balance(qs)
current_month_str = date_filter(timezone.localdate(timezone.now()), "b Y") current_month_str = date_filter(timezone.localdate(timezone.now()), "b Y")
next_month_str = date_filter(timezone.localdate(timezone.now()) + relativedelta(months=1), "b Y") next_month_str = date_filter(
timezone.localdate(timezone.now()) + relativedelta(months=1), "b Y"
)
self.assertIn(current_month_str, result) self.assertIn(current_month_str, result)
self.assertIn(next_month_str, result) self.assertIn(next_month_str, result)
if result and result[current_month_str]: if result and result[current_month_str]:
for account_name in [self.account_usd_1.name, self.account_eur_1.name, self.account_public_usd.name]: for account_name in [
self.assertEqual(result[current_month_str].get(account_name, Decimal(0)), Decimal("0.00")) self.account_usd_1.name,
self.account_eur_1.name,
self.account_public_usd.name,
]:
self.assertEqual(
result[current_month_str].get(account_name, Decimal(0)),
Decimal("0.00"),
)
def test_calculate_historical_account_balance_single_account(self): def test_calculate_historical_account_balance_single_account(self):
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("1000"), date=datetime.date(2023, 10, 5), reference_date=datetime.date(2023,10,1), is_paid=True) Transaction.objects.create(
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.EXPENSE, amount=Decimal("200"), date=datetime.date(2023, 10, 15), reference_date=datetime.date(2023,10,1), is_paid=True) account=self.account_usd_1,
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("50"), date=datetime.date(2023, 11, 5), reference_date=datetime.date(2023,11,1), is_paid=True) owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("1000"),
date=datetime.date(2023, 10, 5),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_usd_1,
owner=self.user,
type=Transaction.Type.EXPENSE,
amount=Decimal("200"),
date=datetime.date(2023, 10, 15),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_usd_1,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("50"),
date=datetime.date(2023, 11, 5),
reference_date=datetime.date(2023, 11, 1),
is_paid=True,
)
qs = Transaction.objects.filter(account=self.account_usd_1) qs = Transaction.objects.filter(account=self.account_usd_1)
result = calculate_historical_account_balance(qs) result = calculate_historical_account_balance(qs)
@@ -155,9 +293,33 @@ class CalculateNetWorthUtilsTests(BaseNetWorthTest):
self.assertEqual(result[nov_str][self.account_usd_1.name], Decimal("850.00")) self.assertEqual(result[nov_str][self.account_usd_1.name], Decimal("850.00"))
def test_calculate_historical_account_balance_multiple_accounts(self): def test_calculate_historical_account_balance_multiple_accounts(self):
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("100"), date=datetime.date(2023,10,1), reference_date=datetime.date(2023,10,1), is_paid=True) Transaction.objects.create(
Transaction.objects.create(account=self.account_eur_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("200"), date=datetime.date(2023,10,1), reference_date=datetime.date(2023,10,1), is_paid=True) account=self.account_usd_1,
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.EXPENSE, amount=Decimal("30"), date=datetime.date(2023,11,1), reference_date=datetime.date(2023,11,1), is_paid=True) owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("100"),
date=datetime.date(2023, 10, 1),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_eur_1,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("200"),
date=datetime.date(2023, 10, 1),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_usd_1,
owner=self.user,
type=Transaction.Type.EXPENSE,
amount=Decimal("30"),
date=datetime.date(2023, 11, 1),
reference_date=datetime.date(2023, 11, 1),
is_paid=True,
)
qs = Transaction.objects.filter(owner=self.user) qs = Transaction.objects.filter(owner=self.user)
result = calculate_historical_account_balance(qs) result = calculate_historical_account_balance(qs)
@@ -169,12 +331,13 @@ class CalculateNetWorthUtilsTests(BaseNetWorthTest):
self.assertEqual(result[nov_str][self.account_usd_1.name], Decimal("70.00")) self.assertEqual(result[nov_str][self.account_usd_1.name], Decimal("70.00"))
self.assertEqual(result[nov_str][self.account_eur_1.name], Decimal("200.00")) self.assertEqual(result[nov_str][self.account_eur_1.name], Decimal("200.00"))
def test_date_range_handling_in_utils(self): def test_date_range_handling_in_utils(self):
qs_empty = Transaction.objects.none() qs_empty = Transaction.objects.none()
today = timezone.localdate(timezone.now()) today = timezone.localdate(timezone.now())
start_of_this_month_str = date_filter(today.replace(day=1), "b Y") start_of_this_month_str = date_filter(today.replace(day=1), "b Y")
start_of_next_month_str = date_filter((today.replace(day=1) + relativedelta(months=1)), "b Y") start_of_next_month_str = date_filter(
(today.replace(day=1) + relativedelta(months=1)), "b Y"
)
currency_result = calculate_historical_currency_net_worth(qs_empty) currency_result = calculate_historical_currency_net_worth(qs_empty)
self.assertIn(start_of_this_month_str, currency_result) self.assertIn(start_of_this_month_str, currency_result)
@@ -186,27 +349,66 @@ class CalculateNetWorthUtilsTests(BaseNetWorthTest):
def test_archived_account_exclusion_in_currency_net_worth(self): def test_archived_account_exclusion_in_currency_net_worth(self):
archived_usd_acc = Account.objects.create( archived_usd_acc = Account.objects.create(
name="Archived USD", currency=self.currency_usd, owner=self.user, is_archived=True name="Archived USD",
currency=self.currency_usd,
owner=self.user,
is_archived=True,
)
Transaction.objects.create(
account=self.account_usd_1,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("100"),
date=datetime.date(2023, 10, 1),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=archived_usd_acc,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("500"),
date=datetime.date(2023, 10, 1),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
) )
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("100"), date=datetime.date(2023,10,1), reference_date=datetime.date(2023,10,1), is_paid=True)
Transaction.objects.create(account=archived_usd_acc, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("500"), date=datetime.date(2023,10,1), reference_date=datetime.date(2023,10,1), is_paid=True)
qs = Transaction.objects.filter(owner=self.user, account__is_archived=False) qs = Transaction.objects.filter(owner=self.user, account__is_archived=False)
result = calculate_historical_currency_net_worth(qs) result = calculate_historical_currency_net_worth(qs)
oct_str = date_filter(datetime.date(2023, 10, 1), "b Y") oct_str = date_filter(datetime.date(2023, 10, 1), "b Y")
if oct_str in result: if oct_str in result:
self.assertEqual(result[oct_str].get("US Dollar", Decimal(0)), Decimal("100.00")) self.assertEqual(
result[oct_str].get("US Dollar", Decimal(0)), Decimal("100.00")
)
elif result: elif result:
self.fail(f"{oct_str} not found in result, but other data exists.") self.fail(f"{oct_str} not found in result, but other data exists.")
def test_archived_account_exclusion_in_account_balance(self): def test_archived_account_exclusion_in_account_balance(self):
archived_usd_acc = Account.objects.create( archived_usd_acc = Account.objects.create(
name="Archived USD Acct Bal", currency=self.currency_usd, owner=self.user, is_archived=True name="Archived USD Acct Bal",
currency=self.currency_usd,
owner=self.user,
is_archived=True,
)
Transaction.objects.create(
account=self.account_usd_1,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("100"),
date=datetime.date(2023, 10, 1),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=archived_usd_acc,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("500"),
date=datetime.date(2023, 10, 1),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
) )
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("100"), date=datetime.date(2023,10,1), reference_date=datetime.date(2023,10,1), is_paid=True)
Transaction.objects.create(account=archived_usd_acc, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("500"), date=datetime.date(2023,10,1), reference_date=datetime.date(2023,10,1), is_paid=True)
qs = Transaction.objects.filter(owner=self.user) qs = Transaction.objects.filter(owner=self.user)
result = calculate_historical_account_balance(qs) result = calculate_historical_account_balance(qs)
@@ -214,18 +416,45 @@ class CalculateNetWorthUtilsTests(BaseNetWorthTest):
if oct_str in result: if oct_str in result:
self.assertIn(self.account_usd_1.name, result[oct_str]) self.assertIn(self.account_usd_1.name, result[oct_str])
self.assertEqual(result[oct_str][self.account_usd_1.name], Decimal("100.00")) self.assertEqual(
result[oct_str][self.account_usd_1.name], Decimal("100.00")
)
self.assertNotIn(archived_usd_acc.name, result[oct_str]) self.assertNotIn(archived_usd_acc.name, result[oct_str])
elif result: elif result:
self.fail(f"{oct_str} not found in result for account balance, but other data exists.") self.fail(
f"{oct_str} not found in result for account balance, but other data exists."
)
class NetWorthViewTests(BaseNetWorthTest): class NetWorthViewTests(BaseNetWorthTest):
def test_net_worth_current_view(self): def test_net_worth_current_view(self):
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("1200.50"), date=datetime.date(2023, 10, 5), reference_date=datetime.date(2023,10,1), is_paid=True) Transaction.objects.create(
Transaction.objects.create(account=self.account_eur_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("800.75"), date=datetime.date(2023, 10, 10), reference_date=datetime.date(2023,10,1), is_paid=True) account=self.account_usd_1,
Transaction.objects.create(account=self.account_usd_2, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("300.00"), date=datetime.date(2023, 9, 1), reference_date=datetime.date(2023,9,1), is_paid=False) # This is unpaid owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("1200.50"),
date=datetime.date(2023, 10, 5),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_eur_1,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("800.75"),
date=datetime.date(2023, 10, 10),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_usd_2,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("300.00"),
date=datetime.date(2023, 9, 1),
reference_date=datetime.date(2023, 9, 1),
is_paid=False,
) # This is unpaid
response = self.client.get(reverse("net_worth_current")) response = self.client.get(reverse("net_worth_current"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@@ -246,16 +475,38 @@ class NetWorthViewTests(BaseNetWorthTest):
# Historical chart data in net_worth_current view uses a queryset that is NOT filtered by is_paid. # Historical chart data in net_worth_current view uses a queryset that is NOT filtered by is_paid.
sep_str = date_filter(datetime.date(2023, 9, 1), "b Y") sep_str = date_filter(datetime.date(2023, 9, 1), "b Y")
if sep_str in chart_data_currency["labels"]: if sep_str in chart_data_currency["labels"]:
usd_dataset = next((ds for ds in chart_data_currency["datasets"] if ds["label"] == "US Dollar"), None) usd_dataset = next(
(
ds
for ds in chart_data_currency["datasets"]
if ds["label"] == "US Dollar"
),
None,
)
self.assertIsNotNone(usd_dataset) self.assertIsNotNone(usd_dataset)
sep_idx = chart_data_currency["labels"].index(sep_str) sep_idx = chart_data_currency["labels"].index(sep_str)
# The $300 from Sep (account_usd_2) should be part of the historical calculation for the chart # The $300 from Sep (account_usd_2) should be part of the historical calculation for the chart
self.assertEqual(usd_dataset["data"][sep_idx], 300.00) self.assertEqual(usd_dataset["data"][sep_idx], 300.00)
def test_net_worth_projected_view(self): def test_net_worth_projected_view(self):
Transaction.objects.create(account=self.account_usd_1, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("1000"), date=datetime.date(2023, 10, 5), reference_date=datetime.date(2023,10,1), is_paid=True) Transaction.objects.create(
Transaction.objects.create(account=self.account_usd_2, owner=self.user, type=Transaction.Type.INCOME, amount=Decimal("500"), date=datetime.date(2023, 11, 1), reference_date=datetime.date(2023,11,1), is_paid=False) # Unpaid account=self.account_usd_1,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("1000"),
date=datetime.date(2023, 10, 5),
reference_date=datetime.date(2023, 10, 1),
is_paid=True,
)
Transaction.objects.create(
account=self.account_usd_2,
owner=self.user,
type=Transaction.Type.INCOME,
amount=Decimal("500"),
date=datetime.date(2023, 11, 1),
reference_date=datetime.date(2023, 11, 1),
is_paid=False,
) # Unpaid
response = self.client.get(reverse("net_worth_projected")) response = self.client.get(reverse("net_worth_projected"))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@@ -263,7 +514,7 @@ class NetWorthViewTests(BaseNetWorthTest):
# `currency_net_worth` in projected view also uses a queryset NOT filtered by is_paid when calling `calculate_currency_totals`. # `currency_net_worth` in projected view also uses a queryset NOT filtered by is_paid when calling `calculate_currency_totals`.
self.assertContains(response, "US Dollar") self.assertContains(response, "US Dollar")
self.assertContains(response, "1,500.00") # 1000 (paid) + 500 (unpaid) self.assertContains(response, "1,500.00") # 1000 (paid) + 500 (unpaid)
chart_data_currency_json = response.context.get("chart_data_currency_json") chart_data_currency_json = response.context.get("chart_data_currency_json")
self.assertIsNotNone(chart_data_currency_json) self.assertIsNotNone(chart_data_currency_json)
@@ -275,7 +526,14 @@ class NetWorthViewTests(BaseNetWorthTest):
oct_str = date_filter(datetime.date(2023, 10, 1), "b Y") oct_str = date_filter(datetime.date(2023, 10, 1), "b Y")
if nov_str in chart_data_currency["labels"]: if nov_str in chart_data_currency["labels"]:
usd_dataset = next((ds for ds in chart_data_currency["datasets"] if ds["label"] == "US Dollar"), None) usd_dataset = next(
(
ds
for ds in chart_data_currency["datasets"]
if ds["label"] == "US Dollar"
),
None,
)
if usd_dataset: if usd_dataset:
nov_idx = chart_data_currency["labels"].index(nov_str) nov_idx = chart_data_currency["labels"].index(nov_str)
# Value in Nov should be cumulative: 1000 (from Oct) + 500 (from Nov unpaid) # Value in Nov should be cumulative: 1000 (from Oct) + 500 (from Nov unpaid)

View File

@@ -8,35 +8,45 @@ from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.utils import timezone from django.utils import timezone
from decimal import Decimal from decimal import Decimal
import datetime # Import was missing import datetime # Import was missing
from apps.transactions.models import ( from apps.transactions.models import (
TransactionCategory, TransactionCategory,
TransactionTag, TransactionTag,
TransactionEntity, # Added TransactionEntity, # Added
Transaction, Transaction,
InstallmentPlan, InstallmentPlan,
RecurringTransaction, RecurringTransaction,
) )
from apps.accounts.models import Account, AccountGroup from apps.accounts.models import Account, AccountGroup
from apps.currencies.models import Currency, ExchangeRate from apps.currencies.models import Currency, ExchangeRate
from apps.common.models import SharedObject
User = get_user_model() User = get_user_model()
class BaseTransactionAppTest(TestCase): class BaseTransactionAppTest(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create_user(email="testuser@example.com", password="password") self.user = User.objects.create_user(
self.other_user = User.objects.create_user(email="otheruser@example.com", password="password") email="testuser@example.com", password="password"
)
self.other_user = User.objects.create_user(
email="otheruser@example.com", password="password"
)
self.client = Client() self.client = Client()
self.client.login(email="testuser@example.com", password="password") self.client.login(email="testuser@example.com", password="password")
self.currency = Currency.objects.create( self.currency = Currency.objects.create(
code="USD", name="US Dollar", decimal_places=2, prefix="$ " code="USD", name="US Dollar", decimal_places=2, prefix="$ "
) )
self.account_group = AccountGroup.objects.create(name="Test Group", owner=self.user) self.account_group = AccountGroup.objects.create(
name="Test Group", owner=self.user
)
self.account = Account.objects.create( self.account = Account.objects.create(
name="Test Account", group=self.account_group, currency=self.currency, owner=self.user name="Test Account",
group=self.account_group,
currency=self.currency,
owner=self.user,
) )
@@ -50,13 +60,24 @@ class TransactionCategoryTests(BaseTransactionAppTest):
self.assertEqual(category.owner, self.user) self.assertEqual(category.owner, self.user)
def test_category_creation_view(self): def test_category_creation_view(self):
response = self.client.post(reverse("category_add"), {"name": "Utilities", "active": "on"}) response = self.client.post(
self.assertEqual(response.status_code, 204) # HTMX success, no content reverse("category_add"), {"name": "Utilities", "active": "on"}
self.assertTrue(TransactionCategory.objects.filter(name="Utilities", owner=self.user).exists()) )
self.assertEqual(response.status_code, 204) # HTMX success, no content
self.assertTrue(
TransactionCategory.objects.filter(
name="Utilities", owner=self.user
).exists()
)
def test_category_edit_view(self): def test_category_edit_view(self):
category = TransactionCategory.objects.create(name="Initial Name", owner=self.user) category = TransactionCategory.objects.create(
response = self.client.post(reverse("category_edit", args=[category.id]), {"name": "Updated Name", "mute": "on", "active": "on"}) name="Initial Name", owner=self.user
)
response = self.client.post(
reverse("category_edit", args=[category.id]),
{"name": "Updated Name", "mute": "on", "active": "on"},
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
category.refresh_from_db() category.refresh_from_db()
self.assertEqual(category.name, "Updated Name") self.assertEqual(category.name, "Updated Name")
@@ -66,25 +87,38 @@ class TransactionCategoryTests(BaseTransactionAppTest):
category = TransactionCategory.objects.create(name="To Delete", owner=self.user) category = TransactionCategory.objects.create(name="To Delete", owner=self.user)
response = self.client.delete(reverse("category_delete", args=[category.id])) response = self.client.delete(reverse("category_delete", args=[category.id]))
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertFalse(TransactionCategory.all_objects.filter(id=category.id).exists()) # all_objects to check even if soft deleted by mistake self.assertFalse(
TransactionCategory.all_objects.filter(id=category.id).exists()
) # all_objects to check even if soft deleted by mistake
def test_other_user_cannot_edit_category(self): def test_other_user_cannot_edit_category(self):
category = TransactionCategory.objects.create(name="User1s Category", owner=self.user) category = TransactionCategory.objects.create(
name="User1s Category", owner=self.user
)
self.client.logout() self.client.logout()
self.client.login(email="otheruser@example.com", password="password") self.client.login(email="otheruser@example.com", password="password")
response = self.client.post(reverse("category_edit", args=[category.id]), {"name": "Attempted Update"}) response = self.client.post(
reverse("category_edit", args=[category.id]), {"name": "Attempted Update"}
)
# This should return a 204 with a message, not a 403, as per view logic for owned objects # This should return a 204 with a message, not a 403, as per view logic for owned objects
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
category.refresh_from_db() category.refresh_from_db()
self.assertEqual(category.name, "User1s Category") # Name should not change self.assertEqual(category.name, "User1s Category") # Name should not change
def test_category_sharing_and_visibility(self): def test_category_sharing_and_visibility(self):
category = TransactionCategory.objects.create(name="Shared Cat", owner=self.user, visibility=TransactionCategory.Visibility.SHARED) category = TransactionCategory.objects.create(
name="Shared Cat",
owner=self.user,
visibility=SharedObject.Visibility.private,
)
category.shared_with.add(self.other_user) category.shared_with.add(self.other_user)
# Other user should be able to see it (though not directly tested here, view logic would permit) # Other user should be able to see it (though not directly tested here, view logic would permit)
# Test that owner can still edit # Test that owner can still edit
response = self.client.post(reverse("category_edit", args=[category.id]), {"name": "Owner Edited Shared Cat", "active":"on"}) response = self.client.post(
reverse("category_edit", args=[category.id]),
{"name": "Owner Edited Shared Cat", "active": "on"},
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
category.refresh_from_db() category.refresh_from_db()
self.assertEqual(category.name, "Owner Edited Shared Cat") self.assertEqual(category.name, "Owner Edited Shared Cat")
@@ -92,7 +126,9 @@ class TransactionCategoryTests(BaseTransactionAppTest):
# Test other user cannot delete if not owner # Test other user cannot delete if not owner
self.client.logout() self.client.logout()
self.client.login(email="otheruser@example.com", password="password") self.client.login(email="otheruser@example.com", password="password")
response = self.client.delete(reverse("category_delete", args=[category.id])) # This removes user from shared_with response = self.client.delete(
reverse("category_delete", args=[category.id])
) # This removes user from shared_with
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
category.refresh_from_db() category.refresh_from_db()
self.assertTrue(TransactionCategory.all_objects.filter(id=category.id).exists()) self.assertTrue(TransactionCategory.all_objects.filter(id=category.id).exists())
@@ -108,13 +144,19 @@ class TransactionTagTests(BaseTransactionAppTest):
self.assertEqual(tag.owner, self.user) self.assertEqual(tag.owner, self.user)
def test_tag_creation_view(self): def test_tag_creation_view(self):
response = self.client.post(reverse("tag_add"), {"name": "Vacation", "active": "on"}) response = self.client.post(
reverse("tag_add"), {"name": "Vacation", "active": "on"}
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertTrue(TransactionTag.objects.filter(name="Vacation", owner=self.user).exists()) self.assertTrue(
TransactionTag.objects.filter(name="Vacation", owner=self.user).exists()
)
def test_tag_edit_view(self): def test_tag_edit_view(self):
tag = TransactionTag.objects.create(name="Old Tag", owner=self.user) tag = TransactionTag.objects.create(name="Old Tag", owner=self.user)
response = self.client.post(reverse("tag_edit", args=[tag.id]), {"name": "New Tag", "active": "on"}) response = self.client.post(
reverse("tag_edit", args=[tag.id]), {"name": "New Tag", "active": "on"}
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
tag.refresh_from_db() tag.refresh_from_db()
self.assertEqual(tag.name, "New Tag") self.assertEqual(tag.name, "New Tag")
@@ -135,39 +177,54 @@ class TransactionEntityTests(BaseTransactionAppTest):
self.assertEqual(entity.owner, self.user) self.assertEqual(entity.owner, self.user)
def test_entity_creation_view(self): def test_entity_creation_view(self):
response = self.client.post(reverse("entity_add"), {"name": "Online Store", "active": "on"}) response = self.client.post(
reverse("entity_add"), {"name": "Online Store", "active": "on"}
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertTrue(TransactionEntity.objects.filter(name="Online Store", owner=self.user).exists()) self.assertTrue(
TransactionEntity.objects.filter(
name="Online Store", owner=self.user
).exists()
)
def test_entity_edit_view(self): def test_entity_edit_view(self):
entity = TransactionEntity.objects.create(name="Local Shop", owner=self.user) entity = TransactionEntity.objects.create(name="Local Shop", owner=self.user)
response = self.client.post(reverse("entity_edit", args=[entity.id]), {"name": "Local Shop Inc.", "active": "on"}) response = self.client.post(
reverse("entity_edit", args=[entity.id]),
{"name": "Local Shop Inc.", "active": "on"},
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
entity.refresh_from_db() entity.refresh_from_db()
self.assertEqual(entity.name, "Local Shop Inc.") self.assertEqual(entity.name, "Local Shop Inc.")
def test_entity_delete_view(self): def test_entity_delete_view(self):
entity = TransactionEntity.objects.create(name="To Be Removed Entity", owner=self.user) entity = TransactionEntity.objects.create(
name="To Be Removed Entity", owner=self.user
)
response = self.client.delete(reverse("entity_delete", args=[entity.id])) response = self.client.delete(reverse("entity_delete", args=[entity.id]))
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertFalse(TransactionEntity.all_objects.filter(id=entity.id).exists()) self.assertFalse(TransactionEntity.all_objects.filter(id=entity.id).exists())
class TransactionTests(BaseTransactionAppTest): # Inherit from BaseTransactionAppTest class TransactionTests(BaseTransactionAppTest): # Inherit from BaseTransactionAppTest
def setUp(self): def setUp(self):
super().setUp() # Call BaseTransactionAppTest's setUp super().setUp() # Call BaseTransactionAppTest's setUp
"""Set up test data""" """Set up test data"""
# self.category is already created in BaseTransactionAppTest if needed, # self.category is already created in BaseTransactionAppTest if needed,
# or create specific ones here. # or create specific ones here.
self.category = TransactionCategory.objects.create(name="Test Category", owner=self.user) self.category = TransactionCategory.objects.create(
name="Test Category", owner=self.user
)
self.tag = TransactionTag.objects.create(name="Test Tag", owner=self.user) self.tag = TransactionTag.objects.create(name="Test Tag", owner=self.user)
self.entity = TransactionEntity.objects.create(name="Test Entity", owner=self.user) self.entity = TransactionEntity.objects.create(
name="Test Entity", owner=self.user
)
def test_transaction_creation(self): def test_transaction_creation(self):
"""Test basic transaction creation with required fields""" """Test basic transaction creation with required fields"""
transaction = Transaction.objects.create( transaction = Transaction.objects.create(
account=self.account, account=self.account,
owner=self.user, # Assign owner owner=self.user, # Assign owner
type=Transaction.Type.EXPENSE, type=Transaction.Type.EXPENSE,
date=timezone.now().date(), date=timezone.now().date(),
amount=Decimal("100.00"), amount=Decimal("100.00"),
@@ -184,7 +241,6 @@ class TransactionTests(BaseTransactionAppTest): # Inherit from BaseTransactionAp
self.assertIn(self.tag, transaction.tags.all()) self.assertIn(self.tag, transaction.tags.all())
self.assertIn(self.entity, transaction.entities.all()) self.assertIn(self.entity, transaction.entities.all())
def test_transaction_creation_view(self): def test_transaction_creation_view(self):
data = { data = {
"account": self.account.id, "account": self.account.id,
@@ -194,90 +250,122 @@ class TransactionTests(BaseTransactionAppTest): # Inherit from BaseTransactionAp
"amount": "250.75", "amount": "250.75",
"description": "Freelance Gig", "description": "Freelance Gig",
"category": self.category.id, "category": self.category.id,
"tags": [self.tag.name], # Dynamic fields expect names for creation/selection "tags": [
"entities": [self.entity.name] self.tag.name
], # Dynamic fields expect names for creation/selection
"entities": [self.entity.name],
} }
response = self.client.post(reverse("transaction_add"), data) response = self.client.post(reverse("transaction_add"), data)
self.assertEqual(response.status_code, 204, response.content.decode() if response.content else "No content") self.assertEqual(
response.status_code,
204,
response.content.decode() if response.content else "No content",
)
self.assertTrue( self.assertTrue(
Transaction.objects.filter(description="Freelance Gig", owner=self.user, amount=Decimal("250.75")).exists() Transaction.objects.filter(
description="Freelance Gig", owner=self.user, amount=Decimal("250.75")
).exists()
) )
# Check that tag and entity were associated (or created if DynamicModel...Field handled it) # Check that tag and entity were associated (or created if DynamicModel...Field handled it)
created_transaction = Transaction.objects.get(description="Freelance Gig") created_transaction = Transaction.objects.get(description="Freelance Gig")
self.assertIn(self.tag, created_transaction.tags.all()) self.assertIn(self.tag, created_transaction.tags.all())
self.assertIn(self.entity, created_transaction.entities.all()) self.assertIn(self.entity, created_transaction.entities.all())
def test_transaction_edit_view(self): def test_transaction_edit_view(self):
transaction = Transaction.objects.create( transaction = Transaction.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
date=timezone.now().date(), amount=Decimal("50.00"), description="Initial" owner=self.user,
type=Transaction.Type.EXPENSE,
date=timezone.now().date(),
amount=Decimal("50.00"),
description="Initial",
) )
updated_description = "Updated Description" updated_description = "Updated Description"
updated_amount = "75.25" updated_amount = "75.25"
response = self.client.post( response = self.client.post(
reverse("transaction_edit", args=[transaction.id]), reverse("transaction_edit", args=[transaction.id]),
{ {
"account": self.account.id, "type": Transaction.Type.EXPENSE, "is_paid": "on", "account": self.account.id,
"date": transaction.date.isoformat(), "amount": updated_amount, "type": Transaction.Type.EXPENSE,
"description": updated_description, "category": self.category.id "is_paid": "on",
} "date": transaction.date.isoformat(),
"amount": updated_amount,
"description": updated_description,
"category": self.category.id,
},
) )
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
transaction.refresh_from_db() transaction.refresh_from_db()
self.assertEqual(transaction.description, updated_description) self.assertEqual(transaction.description, updated_description)
self.assertEqual(transaction.amount, Decimal(updated_amount)) self.assertEqual(transaction.amount, Decimal(updated_amount))
def test_transaction_soft_delete_view(self): def test_transaction_soft_delete_view(self):
transaction = Transaction.objects.create( transaction = Transaction.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
date=timezone.now().date(), amount=Decimal("10.00"), description="To Soft Delete" owner=self.user,
type=Transaction.Type.EXPENSE,
date=timezone.now().date(),
amount=Decimal("10.00"),
description="To Soft Delete",
)
response = self.client.delete(
reverse("transaction_delete", args=[transaction.id])
) )
response = self.client.delete(reverse("transaction_delete", args=[transaction.id]))
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
transaction.refresh_from_db() transaction.refresh_from_db()
self.assertTrue(transaction.deleted) self.assertTrue(transaction.deleted)
self.assertIsNotNone(transaction.deleted_at) self.assertIsNotNone(transaction.deleted_at)
self.assertTrue(Transaction.deleted_objects.filter(id=transaction.id).exists()) self.assertTrue(Transaction.deleted_objects.filter(id=transaction.id).exists())
self.assertFalse(Transaction.objects.filter(id=transaction.id).exists()) # Default manager should not find it self.assertFalse(
Transaction.objects.filter(id=transaction.id).exists()
) # Default manager should not find it
def test_transaction_hard_delete_after_soft_delete(self): def test_transaction_hard_delete_after_soft_delete(self):
# First soft delete # First soft delete
transaction = Transaction.objects.create( transaction = Transaction.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
date=timezone.now().date(), amount=Decimal("15.00"), description="To Hard Delete" owner=self.user,
type=Transaction.Type.EXPENSE,
date=timezone.now().date(),
amount=Decimal("15.00"),
description="To Hard Delete",
) )
transaction.delete() # Soft delete via model method transaction.delete() # Soft delete via model method
self.assertTrue(Transaction.deleted_objects.filter(id=transaction.id).exists()) self.assertTrue(Transaction.deleted_objects.filter(id=transaction.id).exists())
# Then hard delete via view (which calls model's delete again on an already soft-deleted item) # Then hard delete via view (which calls model's delete again on an already soft-deleted item)
response = self.client.delete(reverse("transaction_delete", args=[transaction.id])) response = self.client.delete(
reverse("transaction_delete", args=[transaction.id])
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertFalse(Transaction.all_objects.filter(id=transaction.id).exists()) self.assertFalse(Transaction.all_objects.filter(id=transaction.id).exists())
def test_transaction_undelete_view(self): def test_transaction_undelete_view(self):
transaction = Transaction.objects.create( transaction = Transaction.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
date=timezone.now().date(), amount=Decimal("20.00"), description="To Undelete" owner=self.user,
type=Transaction.Type.EXPENSE,
date=timezone.now().date(),
amount=Decimal("20.00"),
description="To Undelete",
) )
transaction.delete() # Soft delete transaction.delete() # Soft delete
transaction.refresh_from_db() transaction.refresh_from_db()
self.assertTrue(transaction.deleted) self.assertTrue(transaction.deleted)
response = self.client.get(reverse("transaction_undelete", args=[transaction.id])) response = self.client.get(
reverse("transaction_undelete", args=[transaction.id])
)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
transaction.refresh_from_db() transaction.refresh_from_db()
self.assertFalse(transaction.deleted) self.assertFalse(transaction.deleted)
self.assertIsNone(transaction.deleted_at) self.assertIsNone(transaction.deleted_at)
self.assertTrue(Transaction.objects.filter(id=transaction.id).exists()) self.assertTrue(Transaction.objects.filter(id=transaction.id).exists())
def test_transaction_with_exchange_currency(self): def test_transaction_with_exchange_currency(self):
"""Test transaction with exchange currency""" """Test transaction with exchange currency"""
eur = Currency.objects.create( eur = Currency.objects.create(
code="EUR", name="Euro", decimal_places=2, prefix="", owner=self.user code="EUR", name="Euro", decimal_places=2, prefix=""
) )
self.account.exchange_currency = eur self.account.exchange_currency = eur
self.account.save() self.account.save()
@@ -287,8 +375,8 @@ class TransactionTests(BaseTransactionAppTest): # Inherit from BaseTransactionAp
from_currency=self.currency, from_currency=self.currency,
to_currency=eur, to_currency=eur,
rate=Decimal("0.85"), rate=Decimal("0.85"),
date=timezone.now().date(), # Ensure date matches transaction or is general date=timezone.now().date(), # Ensure date matches transaction or is general
owner=self.user owner=self.user,
) )
transaction = Transaction.objects.create( transaction = Transaction.objects.create(
@@ -352,39 +440,56 @@ class TransactionTests(BaseTransactionAppTest): # Inherit from BaseTransactionAp
def test_transaction_transfer_view(self): def test_transaction_transfer_view(self):
other_account = Account.objects.create( other_account = Account.objects.create(
name="Other Account", group=self.account_group, currency=self.currency, owner=self.user name="Other Account",
group=self.account_group,
currency=self.currency,
owner=self.user,
) )
data = { data = {
"from_account": self.account.id, "from_account": self.account.id,
"to_account": other_account.id, "to_account": other_account.id,
"from_amount": "100.00", "from_amount": "100.00",
"to_amount": "100.00", # Assuming same currency for simplicity "to_amount": "100.00", # Assuming same currency for simplicity
"date": timezone.now().date().isoformat(), "date": timezone.now().date().isoformat(),
"description": "Test Transfer", "description": "Test Transfer",
} }
response = self.client.post(reverse("transactions_transfer"), data) response = self.client.post(reverse("transactions_transfer"), data)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
self.assertTrue( self.assertTrue(
Transaction.objects.filter(account=self.account, type=Transaction.Type.EXPENSE, amount="100.00").exists() Transaction.objects.filter(
account=self.account, type=Transaction.Type.EXPENSE, amount="100.00"
).exists()
) )
self.assertTrue( self.assertTrue(
Transaction.objects.filter(account=other_account, type=Transaction.Type.INCOME, amount="100.00").exists() Transaction.objects.filter(
account=other_account, type=Transaction.Type.INCOME, amount="100.00"
).exists()
) )
def test_transaction_bulk_edit_view(self): def test_transaction_bulk_edit_view(self):
t1 = Transaction.objects.create( t1 = Transaction.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
date=timezone.now().date(), amount=Decimal("10.00"), description="Bulk 1" owner=self.user,
type=Transaction.Type.EXPENSE,
date=timezone.now().date(),
amount=Decimal("10.00"),
description="Bulk 1",
) )
t2 = Transaction.objects.create( t2 = Transaction.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
date=timezone.now().date(), amount=Decimal("20.00"), description="Bulk 2" owner=self.user,
type=Transaction.Type.EXPENSE,
date=timezone.now().date(),
amount=Decimal("20.00"),
description="Bulk 2",
)
new_category = TransactionCategory.objects.create(
name="Bulk Category", owner=self.user
) )
new_category = TransactionCategory.objects.create(name="Bulk Category", owner=self.user)
data = { data = {
"transactions": [t1.id, t2.id], "transactions": [t1.id, t2.id],
"category": new_category.id, "category": new_category.id,
"is_paid": "true", # NullBoolean can be 'true', 'false', or empty for no change "is_paid": "true", # NullBoolean can be 'true', 'false', or empty for no change
} }
response = self.client.post(reverse("transactions_bulk_edit"), data) response = self.client.post(reverse("transactions_bulk_edit"), data)
self.assertEqual(response.status_code, 204) self.assertEqual(response.status_code, 204)
@@ -396,18 +501,21 @@ class TransactionTests(BaseTransactionAppTest): # Inherit from BaseTransactionAp
self.assertTrue(t2.is_paid) self.assertTrue(t2.is_paid)
class InstallmentPlanTests(BaseTransactionAppTest): # Inherit from BaseTransactionAppTest class InstallmentPlanTests(
BaseTransactionAppTest
): # Inherit from BaseTransactionAppTest
def setUp(self): def setUp(self):
super().setUp() # Call BaseTransactionAppTest's setUp super().setUp() # Call BaseTransactionAppTest's setUp
# self.currency and self.account are available from base # self.currency and self.account are available from base
self.category = TransactionCategory.objects.create(name="Installments", owner=self.user) self.category = TransactionCategory.objects.create(
name="Installments", owner=self.user
)
def test_installment_plan_creation_and_transaction_generation(self): def test_installment_plan_creation_and_transaction_generation(self):
"""Test basic installment plan creation and its transaction generation.""" """Test basic installment plan creation and its transaction generation."""
start_date = timezone.now().date() start_date = timezone.now().date()
plan = InstallmentPlan.objects.create( plan = InstallmentPlan.objects.create(
account=self.account, account=self.account,
owner=self.user,
type=Transaction.Type.EXPENSE, type=Transaction.Type.EXPENSE,
description="Test Plan", description="Test Plan",
number_of_installments=3, number_of_installments=3,
@@ -416,10 +524,10 @@ class InstallmentPlanTests(BaseTransactionAppTest): # Inherit from BaseTransacti
recurrence=InstallmentPlan.Recurrence.MONTHLY, recurrence=InstallmentPlan.Recurrence.MONTHLY,
category=self.category, category=self.category,
) )
plan.create_transactions() # Manually call as it's not in save in the form plan.create_transactions() # Manually call as it's not in save in the form
self.assertEqual(plan.transactions.count(), 3) self.assertEqual(plan.transactions.count(), 3)
first_transaction = plan.transactions.order_by('date').first() first_transaction = plan.transactions.order_by("date").first()
self.assertEqual(first_transaction.amount, Decimal("100.00")) self.assertEqual(first_transaction.amount, Decimal("100.00"))
self.assertEqual(first_transaction.date, start_date) self.assertEqual(first_transaction.date, start_date)
self.assertEqual(first_transaction.category, self.category) self.assertEqual(first_transaction.category, self.category)
@@ -427,52 +535,64 @@ class InstallmentPlanTests(BaseTransactionAppTest): # Inherit from BaseTransacti
def test_installment_plan_update_transactions(self): def test_installment_plan_update_transactions(self):
start_date = timezone.now().date() start_date = timezone.now().date()
plan = InstallmentPlan.objects.create( plan = InstallmentPlan.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
description="Initial Plan", number_of_installments=2, start_date=start_date, type=Transaction.Type.EXPENSE,
installment_amount=Decimal("50.00"), recurrence=InstallmentPlan.Recurrence.MONTHLY, description="Initial Plan",
number_of_installments=2,
start_date=start_date,
installment_amount=Decimal("50.00"),
recurrence=InstallmentPlan.Recurrence.MONTHLY,
) )
plan.create_transactions() plan.create_transactions()
self.assertEqual(plan.transactions.count(), 2) self.assertEqual(plan.transactions.count(), 2)
plan.description = "Updated Plan Description" plan.description = "Updated Plan Description"
plan.installment_amount = Decimal("60.00") plan.installment_amount = Decimal("60.00")
plan.number_of_installments = 3 # Increase installments plan.number_of_installments = 3 # Increase installments
plan.save() # This should trigger _calculate_end_date and _calculate_installment_total_number plan.save() # This should trigger _calculate_end_date and _calculate_installment_total_number
plan.update_transactions() # Manually call as it's not in save in the form plan.update_transactions() # Manually call as it's not in save in the form
self.assertEqual(plan.transactions.count(), 3) self.assertEqual(plan.transactions.count(), 3)
updated_transaction = plan.transactions.order_by('date').first() updated_transaction = plan.transactions.order_by("date").first()
self.assertEqual(updated_transaction.description, "Updated Plan Description") self.assertEqual(updated_transaction.description, "Updated Plan Description")
# Amount should not change if already paid, but these are created as unpaid # Amount should not change if already paid, but these are created as unpaid
self.assertEqual(updated_transaction.amount, Decimal("60.00")) self.assertEqual(updated_transaction.amount, Decimal("60.00"))
def test_installment_plan_delete_with_transactions(self): def test_installment_plan_delete_with_transactions(self):
plan = InstallmentPlan.objects.create( plan = InstallmentPlan.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
description="Plan to Delete", number_of_installments=2, start_date=timezone.now().date(), type=Transaction.Type.EXPENSE,
installment_amount=Decimal("25.00"), recurrence=InstallmentPlan.Recurrence.MONTHLY, description="Plan to Delete",
number_of_installments=2,
start_date=timezone.now().date(),
installment_amount=Decimal("25.00"),
recurrence=InstallmentPlan.Recurrence.MONTHLY,
) )
plan.create_transactions() plan.create_transactions()
plan_id = plan.id plan_id = plan.id
self.assertTrue(Transaction.objects.filter(installment_plan_id=plan_id).exists()) self.assertTrue(
Transaction.objects.filter(installment_plan_id=plan_id).exists()
)
plan.delete() # This should also delete related transactions as per model's delete plan.delete() # This should also delete related transactions as per model's delete
self.assertFalse(InstallmentPlan.all_objects.filter(id=plan_id).exists()) self.assertFalse(InstallmentPlan.all_objects.filter(id=plan_id).exists())
self.assertFalse(Transaction.all_objects.filter(installment_plan_id=plan_id).exists()) self.assertFalse(
Transaction.all_objects.filter(installment_plan_id=plan_id).exists()
)
class RecurringTransactionTests(BaseTransactionAppTest): # Inherit class RecurringTransactionTests(BaseTransactionAppTest): # Inherit
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.category = TransactionCategory.objects.create(name="Recurring Category", owner=self.user) self.category = TransactionCategory.objects.create(
name="Recurring Category", owner=self.user
)
def test_recurring_transaction_creation_and_upcoming_generation(self): def test_recurring_transaction_creation_and_upcoming_generation(self):
"""Test basic recurring transaction creation and initial upcoming transaction generation.""" """Test basic recurring transaction creation and initial upcoming transaction generation."""
start_date = timezone.now().date() start_date = timezone.now().date()
recurring = RecurringTransaction.objects.create( recurring = RecurringTransaction.objects.create(
account=self.account, account=self.account,
owner=self.user,
type=Transaction.Type.INCOME, type=Transaction.Type.INCOME,
amount=Decimal("200.00"), amount=Decimal("200.00"),
description="Monthly Salary", description="Monthly Salary",
@@ -481,20 +601,26 @@ class RecurringTransactionTests(BaseTransactionAppTest): # Inherit
recurrence_interval=1, recurrence_interval=1,
category=self.category, category=self.category,
) )
recurring.create_upcoming_transactions() # Manually call recurring.create_upcoming_transactions() # Manually call
# It should create a few transactions (e.g., for next 5 occurrences or up to end_date) # It should create a few transactions (e.g., for next 5 occurrences or up to end_date)
self.assertTrue(recurring.transactions.count() > 0) self.assertTrue(recurring.transactions.count() > 0)
first_upcoming = recurring.transactions.order_by('date').first() first_upcoming = recurring.transactions.order_by("date").first()
self.assertEqual(first_upcoming.amount, Decimal("200.00")) self.assertEqual(first_upcoming.amount, Decimal("200.00"))
self.assertEqual(first_upcoming.date, start_date) # First one should be on start_date self.assertEqual(
first_upcoming.date, start_date
) # First one should be on start_date
self.assertFalse(first_upcoming.is_paid) self.assertFalse(first_upcoming.is_paid)
def test_recurring_transaction_update_unpaid(self): def test_recurring_transaction_update_unpaid(self):
recurring = RecurringTransaction.objects.create( recurring = RecurringTransaction.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
amount=Decimal("30.00"), description="Subscription", start_date=timezone.now().date(), type=Transaction.Type.EXPENSE,
recurrence_type=RecurringTransaction.RecurrenceType.MONTH, recurrence_interval=1 amount=Decimal("30.00"),
description="Subscription",
start_date=timezone.now().date(),
recurrence_type=RecurringTransaction.RecurrenceType.MONTH,
recurrence_interval=1,
) )
recurring.create_upcoming_transactions() recurring.create_upcoming_transactions()
unpaid_transaction = recurring.transactions.filter(is_paid=False).first() unpaid_transaction = recurring.transactions.filter(is_paid=False).first()
@@ -503,7 +629,7 @@ class RecurringTransactionTests(BaseTransactionAppTest): # Inherit
recurring.amount = Decimal("35.00") recurring.amount = Decimal("35.00")
recurring.description = "Updated Subscription" recurring.description = "Updated Subscription"
recurring.save() recurring.save()
recurring.update_unpaid_transactions() # Manually call recurring.update_unpaid_transactions() # Manually call
unpaid_transaction.refresh_from_db() unpaid_transaction.refresh_from_db()
self.assertEqual(unpaid_transaction.amount, Decimal("35.00")) self.assertEqual(unpaid_transaction.amount, Decimal("35.00"))
@@ -511,13 +637,21 @@ class RecurringTransactionTests(BaseTransactionAppTest): # Inherit
def test_recurring_transaction_delete_unpaid(self): def test_recurring_transaction_delete_unpaid(self):
recurring = RecurringTransaction.objects.create( recurring = RecurringTransaction.objects.create(
account=self.account, owner=self.user, type=Transaction.Type.EXPENSE, account=self.account,
amount=Decimal("40.00"), description="Service Fee", start_date=timezone.now().date() + timedelta(days=5), # future start type=Transaction.Type.EXPENSE,
recurrence_type=RecurringTransaction.RecurrenceType.MONTH, recurrence_interval=1 amount=Decimal("40.00"),
description="Service Fee",
start_date=timezone.now().date() + timedelta(days=5), # future start
recurrence_type=RecurringTransaction.RecurrenceType.MONTH,
recurrence_interval=1,
) )
recurring.create_upcoming_transactions() recurring.create_upcoming_transactions()
self.assertTrue(recurring.transactions.filter(is_paid=False).exists()) self.assertTrue(recurring.transactions.filter(is_paid=False).exists())
recurring.delete_unpaid_transactions() # Manually call recurring.delete_unpaid_transactions() # Manually call
# This method in the model deletes transactions with date > today # This method in the model deletes transactions with date > today
self.assertFalse(recurring.transactions.filter(is_paid=False, date__gt=timezone.now().date()).exists()) self.assertFalse(
recurring.transactions.filter(
is_paid=False, date__gt=timezone.now().date()
).exists()
)