diff --git a/app/apps/api/tests/__init__.py b/app/apps/api/tests/__init__.py index 3c860ef..9ca8b36 100644 --- a/app/apps/api/tests/__init__.py +++ b/app/apps/api/tests/__init__.py @@ -1,4 +1,5 @@ # Import all test classes for Django test discovery from .test_imports import * from .test_accounts import * - +from .test_data_isolation import * +from .test_shared_access import * diff --git a/app/apps/api/tests/test_data_isolation.py b/app/apps/api/tests/test_data_isolation.py new file mode 100644 index 0000000..5f13c5c --- /dev/null +++ b/app/apps/api/tests/test_data_isolation.py @@ -0,0 +1,719 @@ +from datetime import date +from decimal import Decimal + +from django.contrib.auth import get_user_model +from django.test import TestCase, override_settings +from rest_framework import status +from rest_framework.test import APIClient + +from apps.accounts.models import Account, AccountGroup +from apps.currencies.models import Currency +from apps.dca.models import DCAStrategy, DCAEntry +from apps.transactions.models import ( + Transaction, + TransactionCategory, + TransactionTag, + TransactionEntity, + InstallmentPlan, + RecurringTransaction, +) + + +ACCESS_DENIED_CODES = [status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND] + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class AccountDataIsolationTests(TestCase): + """Tests to ensure users cannot access other users' accounts.""" + + def setUp(self): + """Set up test data with two distinct users.""" + User = get_user_model() + + # User 1 - the requester + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + # User 2 - owner of data that user1 should NOT access + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + self.client2 = APIClient() + self.client2.force_authenticate(user=self.user2) + + # Shared currency + self.currency = Currency.objects.create( + code="USD", name="US Dollar", decimal_places=2, prefix="$ " + ) + + # User 1's account + self.user1_account_group = AccountGroup.all_objects.create( + name="User1 Group", owner=self.user1 + ) + self.user1_account = Account.all_objects.create( + name="User1 Account", + group=self.user1_account_group, + currency=self.currency, + owner=self.user1, + ) + + # User 2's account (private, should be invisible to user1) + self.user2_account_group = AccountGroup.all_objects.create( + name="User2 Group", owner=self.user2 + ) + self.user2_account = Account.all_objects.create( + name="User2 Account", + group=self.user2_account_group, + currency=self.currency, + owner=self.user2, + ) + + def test_user_cannot_see_other_users_accounts_in_list(self): + """GET /api/accounts/ should only return user's own accounts.""" + response = self.client1.get("/api/accounts/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # User1 should only see their own account + account_ids = [acc["id"] for acc in response.data["results"]] + self.assertIn(self.user1_account.id, account_ids) + self.assertNotIn(self.user2_account.id, account_ids) + + def test_user_cannot_access_other_users_account_detail(self): + """GET /api/accounts/{id}/ should deny access to other user's account.""" + response = self.client1.get(f"/api/accounts/{self.user2_account.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_modify_other_users_account(self): + """PATCH on other user's account should deny access.""" + response = self.client1.patch( + f"/api/accounts/{self.user2_account.id}/", + {"name": "Hacked Account"}, + ) + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + # Verify account name wasn't changed + self.user2_account.refresh_from_db() + self.assertEqual(self.user2_account.name, "User2 Account") + + def test_user_cannot_delete_other_users_account(self): + """DELETE on other user's account should deny access.""" + response = self.client1.delete(f"/api/accounts/{self.user2_account.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + # Verify account still exists + self.assertTrue(Account.all_objects.filter(id=self.user2_account.id).exists()) + + def test_user_cannot_get_balance_of_other_users_account(self): + """Balance action on other user's account should deny access.""" + response = self.client1.get(f"/api/accounts/{self.user2_account.id}/balance/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_can_access_own_account(self): + """User can access their own account normally.""" + response = self.client1.get(f"/api/accounts/{self.user1_account.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["name"], "User1 Account") + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class AccountGroupDataIsolationTests(TestCase): + """Tests to ensure users cannot access other users' account groups.""" + + def setUp(self): + """Set up test data with two distinct users.""" + User = get_user_model() + + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + + # User 1's account group + self.user1_group = AccountGroup.all_objects.create( + name="User1 Group", owner=self.user1 + ) + + # User 2's account group + self.user2_group = AccountGroup.all_objects.create( + name="User2 Group", owner=self.user2 + ) + + def test_user_cannot_see_other_users_account_groups(self): + """GET /api/account-groups/ should only return user's own groups.""" + response = self.client1.get("/api/account-groups/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + group_ids = [grp["id"] for grp in response.data["results"]] + self.assertIn(self.user1_group.id, group_ids) + self.assertNotIn(self.user2_group.id, group_ids) + + def test_user_cannot_access_other_users_account_group_detail(self): + """GET /api/account-groups/{id}/ should deny access to other user's group.""" + response = self.client1.get(f"/api/account-groups/{self.user2_group.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_modify_other_users_account_group(self): + """PATCH on other user's account group should deny access.""" + response = self.client1.patch( + f"/api/account-groups/{self.user2_group.id}/", + {"name": "Hacked Group"}, + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + self.user2_group.refresh_from_db() + self.assertEqual(self.user2_group.name, "User2 Group") + + def test_user_cannot_delete_other_users_account_group(self): + """DELETE on other user's account group should deny access.""" + response = self.client1.delete(f"/api/account-groups/{self.user2_group.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + self.assertTrue( + AccountGroup.all_objects.filter(id=self.user2_group.id).exists() + ) + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class TransactionDataIsolationTests(TestCase): + """Tests to ensure users cannot access other users' transactions.""" + + def setUp(self): + """Set up test data with transactions for two distinct users.""" + User = get_user_model() + + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + + self.currency = Currency.objects.create( + code="USD", name="US Dollar", decimal_places=2, prefix="$ " + ) + + # User 1's account and transaction + self.user1_account = Account.all_objects.create( + name="User1 Account", currency=self.currency, owner=self.user1 + ) + self.user1_transaction = Transaction.userless_all_objects.create( + account=self.user1_account, + type=Transaction.Type.INCOME, + amount=Decimal("100.00"), + is_paid=True, + date=date(2025, 1, 1), + description="User1 Income", + owner=self.user1, + ) + + # User 2's account and transaction + self.user2_account = Account.all_objects.create( + name="User2 Account", currency=self.currency, owner=self.user2 + ) + self.user2_transaction = Transaction.userless_all_objects.create( + account=self.user2_account, + type=Transaction.Type.EXPENSE, + amount=Decimal("50.00"), + is_paid=True, + date=date(2025, 1, 1), + description="User2 Expense", + owner=self.user2, + ) + + def test_user_cannot_see_other_users_transactions_in_list(self): + """GET /api/transactions/ should only return user's own transactions.""" + response = self.client1.get("/api/transactions/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + transaction_ids = [t["id"] for t in response.data["results"]] + self.assertIn(self.user1_transaction.id, transaction_ids) + self.assertNotIn(self.user2_transaction.id, transaction_ids) + + def test_user_cannot_access_other_users_transaction_detail(self): + """GET /api/transactions/{id}/ should deny access to other user's transaction.""" + response = self.client1.get(f"/api/transactions/{self.user2_transaction.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_modify_other_users_transaction(self): + """PATCH on other user's transaction should deny access.""" + response = self.client1.patch( + f"/api/transactions/{self.user2_transaction.id}/", + {"description": "Hacked Transaction"}, + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + self.user2_transaction.refresh_from_db() + self.assertEqual(self.user2_transaction.description, "User2 Expense") + + def test_user_cannot_delete_other_users_transaction(self): + """DELETE on other user's transaction should deny access.""" + response = self.client1.delete( + f"/api/transactions/{self.user2_transaction.id}/" + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + self.assertTrue( + Transaction.userless_all_objects.filter( + id=self.user2_transaction.id + ).exists() + ) + + def test_user_cannot_create_transaction_in_other_users_account(self): + """POST /api/transactions/ with other user's account should fail.""" + response = self.client1.post( + "/api/transactions/", + { + "account": self.user2_account.id, + "type": "IN", + "amount": "100.00", + "date": "2025-01-15", + "description": "Sneaky transaction", + }, + format="json", + ) + + # Should deny access - 400 (validation error), 403, or 404 + self.assertIn( + response.status_code, + ACCESS_DENIED_CODES + [status.HTTP_400_BAD_REQUEST], + ) + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class CategoryTagEntityIsolationTests(TestCase): + """Tests for isolation of categories, tags, and entities between users.""" + + def setUp(self): + """Set up test data.""" + User = get_user_model() + + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + + # User 1's categories, tags, entities + self.user1_category = TransactionCategory.all_objects.create( + name="User1 Category", owner=self.user1 + ) + self.user1_tag = TransactionTag.all_objects.create( + name="User1 Tag", owner=self.user1 + ) + self.user1_entity = TransactionEntity.all_objects.create( + name="User1 Entity", owner=self.user1 + ) + + # User 2's categories, tags, entities + self.user2_category = TransactionCategory.all_objects.create( + name="User2 Category", owner=self.user2 + ) + self.user2_tag = TransactionTag.all_objects.create( + name="User2 Tag", owner=self.user2 + ) + self.user2_entity = TransactionEntity.all_objects.create( + name="User2 Entity", owner=self.user2 + ) + + def test_user_cannot_see_other_users_categories(self): + """GET /api/categories/ should only return user's own categories.""" + response = self.client1.get("/api/categories/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + category_ids = [c["id"] for c in response.data["results"]] + self.assertIn(self.user1_category.id, category_ids) + self.assertNotIn(self.user2_category.id, category_ids) + + def test_user_cannot_access_other_users_category_detail(self): + """GET /api/categories/{id}/ should deny access to other user's category.""" + response = self.client1.get(f"/api/categories/{self.user2_category.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_see_other_users_tags(self): + """GET /api/tags/ should only return user's own tags.""" + response = self.client1.get("/api/tags/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + tag_ids = [t["id"] for t in response.data["results"]] + self.assertIn(self.user1_tag.id, tag_ids) + self.assertNotIn(self.user2_tag.id, tag_ids) + + def test_user_cannot_access_other_users_tag_detail(self): + """GET /api/tags/{id}/ should deny access to other user's tag.""" + response = self.client1.get(f"/api/tags/{self.user2_tag.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_see_other_users_entities(self): + """GET /api/entities/ should only return user's own entities.""" + response = self.client1.get("/api/entities/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + entity_ids = [e["id"] for e in response.data["results"]] + self.assertIn(self.user1_entity.id, entity_ids) + self.assertNotIn(self.user2_entity.id, entity_ids) + + def test_user_cannot_access_other_users_entity_detail(self): + """GET /api/entities/{id}/ should deny access to other user's entity.""" + response = self.client1.get(f"/api/entities/{self.user2_entity.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_modify_other_users_category(self): + """PATCH on other user's category should deny access.""" + response = self.client1.patch( + f"/api/categories/{self.user2_category.id}/", + {"name": "Hacked Category"}, + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_delete_other_users_tag(self): + """DELETE on other user's tag should deny access.""" + response = self.client1.delete(f"/api/tags/{self.user2_tag.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + self.assertTrue( + TransactionTag.all_objects.filter(id=self.user2_tag.id).exists() + ) + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class DCADataIsolationTests(TestCase): + """Tests to ensure users cannot access other users' DCA strategies and entries.""" + + def setUp(self): + """Set up test data.""" + User = get_user_model() + + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + + self.currency1 = Currency.objects.create( + code="BTC", name="Bitcoin", decimal_places=8, prefix="" + ) + self.currency2 = Currency.objects.create( + code="USD", name="US Dollar", decimal_places=2, prefix="$ " + ) + + # User 1's DCA strategy and entry + self.user1_strategy = DCAStrategy.all_objects.create( + name="User1 BTC Strategy", + target_currency=self.currency1, + payment_currency=self.currency2, + owner=self.user1, + ) + self.user1_entry = DCAEntry.objects.create( + strategy=self.user1_strategy, + date=date(2025, 1, 1), + amount_paid=Decimal("100.00"), + amount_received=Decimal("0.001"), + ) + + # User 2's DCA strategy and entry + self.user2_strategy = DCAStrategy.all_objects.create( + name="User2 BTC Strategy", + target_currency=self.currency1, + payment_currency=self.currency2, + owner=self.user2, + ) + self.user2_entry = DCAEntry.objects.create( + strategy=self.user2_strategy, + date=date(2025, 1, 1), + amount_paid=Decimal("200.00"), + amount_received=Decimal("0.002"), + ) + + def test_user_cannot_see_other_users_dca_strategies(self): + """GET /api/dca/strategies/ should only return user's own strategies.""" + response = self.client1.get("/api/dca/strategies/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + strategy_ids = [s["id"] for s in response.data["results"]] + self.assertIn(self.user1_strategy.id, strategy_ids) + self.assertNotIn(self.user2_strategy.id, strategy_ids) + + def test_user_cannot_access_other_users_dca_strategy_detail(self): + """GET /api/dca/strategies/{id}/ should deny access to other user's strategy.""" + response = self.client1.get(f"/api/dca/strategies/{self.user2_strategy.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_access_other_users_dca_entries(self): + """GET /api/dca/entries/ filtered by other user's strategy should return empty.""" + response = self.client1.get( + f"/api/dca/entries/?strategy={self.user2_strategy.id}" + ) + + # Either OK with empty results or error + if response.status_code == status.HTTP_200_OK: + entry_ids = [e["id"] for e in response.data["results"]] + self.assertNotIn(self.user2_entry.id, entry_ids) + + def test_user_cannot_access_other_users_dca_entry_detail(self): + """GET /api/dca/entries/{id}/ should deny access to other user's entry.""" + response = self.client1.get(f"/api/dca/entries/{self.user2_entry.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_access_other_users_strategy_investment_frequency(self): + """investment_frequency action on other user's strategy should deny access.""" + response = self.client1.get( + f"/api/dca/strategies/{self.user2_strategy.id}/investment_frequency/" + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_access_other_users_strategy_price_comparison(self): + """price_comparison action on other user's strategy should deny access.""" + response = self.client1.get( + f"/api/dca/strategies/{self.user2_strategy.id}/price_comparison/" + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_access_other_users_strategy_current_price(self): + """current_price action on other user's strategy should deny access.""" + response = self.client1.get( + f"/api/dca/strategies/{self.user2_strategy.id}/current_price/" + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_modify_other_users_dca_strategy(self): + """PATCH on other user's DCA strategy should deny access.""" + response = self.client1.patch( + f"/api/dca/strategies/{self.user2_strategy.id}/", + {"name": "Hacked Strategy"}, + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_delete_other_users_dca_entry(self): + """DELETE on other user's DCA entry should deny access.""" + response = self.client1.delete(f"/api/dca/entries/{self.user2_entry.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + self.assertTrue(DCAEntry.objects.filter(id=self.user2_entry.id).exists()) + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class InstallmentRecurringIsolationTests(TestCase): + """Tests for isolation of installment plans and recurring transactions.""" + + def setUp(self): + """Set up test data.""" + User = get_user_model() + + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + + self.currency = Currency.objects.create( + code="USD", name="US Dollar", decimal_places=2, prefix="$ " + ) + + # User 1's account + self.user1_account = Account.all_objects.create( + name="User1 Account", currency=self.currency, owner=self.user1 + ) + + # User 2's account + self.user2_account = Account.all_objects.create( + name="User2 Account", currency=self.currency, owner=self.user2 + ) + + # User 1's installment plan + self.user1_installment = InstallmentPlan.all_objects.create( + account=self.user1_account, + type=Transaction.Type.EXPENSE, + description="User1 Installment", + number_of_installments=12, + start_date=date(2025, 1, 1), + installment_amount=Decimal("100.00"), + ) + + # User 2's installment plan + self.user2_installment = InstallmentPlan.all_objects.create( + account=self.user2_account, + type=Transaction.Type.EXPENSE, + description="User2 Installment", + number_of_installments=6, + start_date=date(2025, 1, 1), + installment_amount=Decimal("200.00"), + ) + + # User 1's recurring transaction + self.user1_recurring = RecurringTransaction.all_objects.create( + account=self.user1_account, + type=Transaction.Type.EXPENSE, + amount=Decimal("50.00"), + description="User1 Recurring", + start_date=date(2025, 1, 1), + recurrence_type=RecurringTransaction.RecurrenceType.MONTH, + recurrence_interval=1, + ) + + # User 2's recurring transaction + self.user2_recurring = RecurringTransaction.all_objects.create( + account=self.user2_account, + type=Transaction.Type.INCOME, + amount=Decimal("1000.00"), + description="User2 Recurring", + start_date=date(2025, 1, 1), + recurrence_type=RecurringTransaction.RecurrenceType.MONTH, + recurrence_interval=1, + ) + + def test_user_cannot_see_other_users_installment_plans(self): + """GET /api/installment-plans/ should only return user's own plans.""" + response = self.client1.get("/api/installment-plans/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + plan_ids = [p["id"] for p in response.data["results"]] + self.assertIn(self.user1_installment.id, plan_ids) + self.assertNotIn(self.user2_installment.id, plan_ids) + + def test_user_cannot_access_other_users_installment_plan_detail(self): + """GET /api/installment-plans/{id}/ should deny access to other user's plan.""" + response = self.client1.get( + f"/api/installment-plans/{self.user2_installment.id}/" + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_see_other_users_recurring_transactions(self): + """GET /api/recurring-transactions/ should only return user's own recurring.""" + response = self.client1.get("/api/recurring-transactions/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + recurring_ids = [r["id"] for r in response.data["results"]] + self.assertIn(self.user1_recurring.id, recurring_ids) + self.assertNotIn(self.user2_recurring.id, recurring_ids) + + def test_user_cannot_access_other_users_recurring_transaction_detail(self): + """GET /api/recurring-transactions/{id}/ should deny access to other user's recurring.""" + response = self.client1.get( + f"/api/recurring-transactions/{self.user2_recurring.id}/" + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_modify_other_users_installment_plan(self): + """PATCH on other user's installment plan should deny access.""" + response = self.client1.patch( + f"/api/installment-plans/{self.user2_installment.id}/", + {"description": "Hacked Installment"}, + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_cannot_delete_other_users_recurring_transaction(self): + """DELETE on other user's recurring transaction should deny access.""" + response = self.client1.delete( + f"/api/recurring-transactions/{self.user2_recurring.id}/" + ) + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + self.assertTrue( + RecurringTransaction.all_objects.filter(id=self.user2_recurring.id).exists() + ) diff --git a/app/apps/api/tests/test_shared_access.py b/app/apps/api/tests/test_shared_access.py new file mode 100644 index 0000000..529cdce --- /dev/null +++ b/app/apps/api/tests/test_shared_access.py @@ -0,0 +1,587 @@ +from datetime import date +from decimal import Decimal + +from django.contrib.auth import get_user_model +from django.test import TestCase, override_settings +from rest_framework import status +from rest_framework.test import APIClient + +from apps.accounts.models import Account, AccountGroup +from apps.currencies.models import Currency +from apps.dca.models import DCAStrategy, DCAEntry +from apps.transactions.models import ( + Transaction, + TransactionCategory, + TransactionTag, + TransactionEntity, +) + + +ACCESS_DENIED_CODES = [status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND] + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class SharedAccountAccessTests(TestCase): + """Tests for shared account access via shared_with field.""" + + def setUp(self): + """Set up test data with shared accounts.""" + User = get_user_model() + + # User 1 - owner + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + # User 2 - will have shared access + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + self.client2 = APIClient() + self.client2.force_authenticate(user=self.user2) + + # User 3 - no shared access + self.user3 = User.objects.create_user( + email="user3@test.com", password="testpass123" + ) + self.client3 = APIClient() + self.client3.force_authenticate(user=self.user3) + + self.currency = Currency.objects.create( + code="USD", name="US Dollar", decimal_places=2, prefix="$ " + ) + + # User 1's account shared with user 2 + self.shared_account = Account.all_objects.create( + name="Shared Account", + currency=self.currency, + owner=self.user1, + visibility="private", + ) + self.shared_account.shared_with.add(self.user2) + + # User 1's private account (not shared) + self.private_account = Account.all_objects.create( + name="Private Account", + currency=self.currency, + owner=self.user1, + visibility="private", + ) + + # Transaction in shared account + self.shared_transaction = Transaction.userless_all_objects.create( + account=self.shared_account, + type=Transaction.Type.INCOME, + amount=Decimal("100.00"), + is_paid=True, + date=date(2025, 1, 1), + description="Shared Transaction", + owner=self.user1, + ) + + # Transaction in private account + self.private_transaction = Transaction.userless_all_objects.create( + account=self.private_account, + type=Transaction.Type.EXPENSE, + amount=Decimal("50.00"), + is_paid=True, + date=date(2025, 1, 1), + description="Private Transaction", + owner=self.user1, + ) + + def test_user_can_see_accounts_shared_with_them(self): + """User2 should see the account shared with them.""" + response = self.client2.get("/api/accounts/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + account_ids = [acc["id"] for acc in response.data["results"]] + self.assertIn(self.shared_account.id, account_ids) + + def test_user_cannot_see_accounts_not_shared_with_them(self): + """User2 should NOT see user1's private (non-shared) account.""" + response = self.client2.get("/api/accounts/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + account_ids = [acc["id"] for acc in response.data["results"]] + self.assertNotIn(self.private_account.id, account_ids) + + def test_user_can_access_shared_account_detail(self): + """User2 should be able to access shared account details.""" + response = self.client2.get(f"/api/accounts/{self.shared_account.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["name"], "Shared Account") + + def test_user_without_share_cannot_access_shared_account(self): + """User3 should NOT be able to access the shared account.""" + response = self.client3.get(f"/api/accounts/{self.shared_account.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_can_see_transactions_in_shared_account(self): + """User2 should see transactions in the shared account.""" + response = self.client2.get("/api/transactions/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + transaction_ids = [t["id"] for t in response.data["results"]] + self.assertIn(self.shared_transaction.id, transaction_ids) + self.assertNotIn(self.private_transaction.id, transaction_ids) + + def test_user_can_access_transaction_in_shared_account(self): + """User2 should be able to access transaction details in shared account.""" + response = self.client2.get(f"/api/transactions/{self.shared_transaction.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["description"], "Shared Transaction") + + def test_user_cannot_access_transaction_in_non_shared_account(self): + """User2 should NOT access transactions in user1's private account.""" + response = self.client2.get(f"/api/transactions/{self.private_transaction.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) + + def test_user_can_get_balance_of_shared_account(self): + """User2 should be able to get balance of shared account.""" + response = self.client2.get(f"/api/accounts/{self.shared_account.id}/balance/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("current_balance", response.data) + + def test_sharing_works_with_multiple_users(self): + """Account shared with multiple users should be accessible by all.""" + # Add user3 to shared_with + self.shared_account.shared_with.add(self.user3) + + # User2 still has access + response2 = self.client2.get(f"/api/accounts/{self.shared_account.id}/") + self.assertEqual(response2.status_code, status.HTTP_200_OK) + + # User3 now has access + response3 = self.client3.get(f"/api/accounts/{self.shared_account.id}/") + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class PublicVisibilityTests(TestCase): + """Tests for public visibility access.""" + + def setUp(self): + """Set up test data with public accounts.""" + User = get_user_model() + + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + self.client2 = APIClient() + self.client2.force_authenticate(user=self.user2) + + self.currency = Currency.objects.create( + code="USD", name="US Dollar", decimal_places=2, prefix="$ " + ) + + # User 1's public account + self.public_account = Account.all_objects.create( + name="Public Account", + currency=self.currency, + owner=self.user1, + visibility="public", + ) + + # User 1's private account + self.private_account = Account.all_objects.create( + name="Private Account", + currency=self.currency, + owner=self.user1, + visibility="private", + ) + + # Transaction in public account + self.public_transaction = Transaction.userless_all_objects.create( + account=self.public_account, + type=Transaction.Type.INCOME, + amount=Decimal("100.00"), + is_paid=True, + date=date(2025, 1, 1), + description="Public Transaction", + owner=self.user1, + ) + + def test_user_can_see_public_accounts(self): + """User2 should see user1's public account.""" + response = self.client2.get("/api/accounts/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + account_ids = [acc["id"] for acc in response.data["results"]] + self.assertIn(self.public_account.id, account_ids) + self.assertNotIn(self.private_account.id, account_ids) + + def test_user_can_access_public_account_detail(self): + """User2 should be able to access public account details.""" + response = self.client2.get(f"/api/accounts/{self.public_account.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["name"], "Public Account") + + def test_user_can_see_transactions_in_public_accounts(self): + """User2 should see transactions in public accounts.""" + response = self.client2.get("/api/transactions/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + transaction_ids = [t["id"] for t in response.data["results"]] + self.assertIn(self.public_transaction.id, transaction_ids) + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class SharedCategoryTagEntityTests(TestCase): + """Tests for shared categories, tags, and entities.""" + + def setUp(self): + """Set up test data with shared categories/tags/entities.""" + User = get_user_model() + + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + self.client2 = APIClient() + self.client2.force_authenticate(user=self.user2) + + self.user3 = User.objects.create_user( + email="user3@test.com", password="testpass123" + ) + self.client3 = APIClient() + self.client3.force_authenticate(user=self.user3) + + # User 1's category shared with user 2 + self.shared_category = TransactionCategory.all_objects.create( + name="Shared Category", owner=self.user1 + ) + self.shared_category.shared_with.add(self.user2) + + # User 1's private category + self.private_category = TransactionCategory.all_objects.create( + name="Private Category", owner=self.user1 + ) + + # User 1's public category + self.public_category = TransactionCategory.all_objects.create( + name="Public Category", owner=self.user1, visibility="public" + ) + + # User 1's tag shared with user 2 + self.shared_tag = TransactionTag.all_objects.create( + name="Shared Tag", owner=self.user1 + ) + self.shared_tag.shared_with.add(self.user2) + + # User 1's entity shared with user 2 + self.shared_entity = TransactionEntity.all_objects.create( + name="Shared Entity", owner=self.user1 + ) + self.shared_entity.shared_with.add(self.user2) + + def test_user_can_see_shared_categories(self): + """User2 should see categories shared with them.""" + response = self.client2.get("/api/categories/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + category_ids = [c["id"] for c in response.data["results"]] + self.assertIn(self.shared_category.id, category_ids) + self.assertNotIn(self.private_category.id, category_ids) + + def test_user_can_access_shared_category_detail(self): + """User2 should be able to access shared category details.""" + response = self.client2.get(f"/api/categories/{self.shared_category.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["name"], "Shared Category") + + def test_user_can_see_public_categories(self): + """User3 should see public categories.""" + response = self.client3.get("/api/categories/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + category_ids = [c["id"] for c in response.data["results"]] + self.assertIn(self.public_category.id, category_ids) + + def test_user_without_share_cannot_see_shared_category(self): + """User3 should NOT see category shared only with user2.""" + response = self.client3.get("/api/categories/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + category_ids = [c["id"] for c in response.data["results"]] + self.assertNotIn(self.shared_category.id, category_ids) + + def test_user_can_see_shared_tags(self): + """User2 should see tags shared with them.""" + response = self.client2.get("/api/tags/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + tag_ids = [t["id"] for t in response.data["results"]] + self.assertIn(self.shared_tag.id, tag_ids) + + def test_user_can_access_shared_tag_detail(self): + """User2 should be able to access shared tag details.""" + response = self.client2.get(f"/api/tags/{self.shared_tag.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["name"], "Shared Tag") + + def test_user_can_see_shared_entities(self): + """User2 should see entities shared with them.""" + response = self.client2.get("/api/entities/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + entity_ids = [e["id"] for e in response.data["results"]] + self.assertIn(self.shared_entity.id, entity_ids) + + def test_user_can_access_shared_entity_detail(self): + """User2 should be able to access shared entity details.""" + response = self.client2.get(f"/api/entities/{self.shared_entity.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["name"], "Shared Entity") + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class SharedDCAAccessTests(TestCase): + """Tests for shared DCA strategy access.""" + + def setUp(self): + """Set up test data with shared DCA strategies.""" + User = get_user_model() + + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + self.client2 = APIClient() + self.client2.force_authenticate(user=self.user2) + + self.user3 = User.objects.create_user( + email="user3@test.com", password="testpass123" + ) + self.client3 = APIClient() + self.client3.force_authenticate(user=self.user3) + + self.currency1 = Currency.objects.create( + code="BTC", name="Bitcoin", decimal_places=8, prefix="" + ) + self.currency2 = Currency.objects.create( + code="USD", name="US Dollar", decimal_places=2, prefix="$ " + ) + + # User 1's DCA strategy shared with user 2 + self.shared_strategy = DCAStrategy.all_objects.create( + name="Shared BTC Strategy", + target_currency=self.currency1, + payment_currency=self.currency2, + owner=self.user1, + ) + self.shared_strategy.shared_with.add(self.user2) + + # Entry in shared strategy + self.shared_entry = DCAEntry.objects.create( + strategy=self.shared_strategy, + date=date(2025, 1, 1), + amount_paid=Decimal("100.00"), + amount_received=Decimal("0.001"), + ) + + # User 1's private strategy + self.private_strategy = DCAStrategy.all_objects.create( + name="Private BTC Strategy", + target_currency=self.currency1, + payment_currency=self.currency2, + owner=self.user1, + ) + + def test_user_can_see_shared_dca_strategies(self): + """User2 should see DCA strategies shared with them.""" + response = self.client2.get("/api/dca/strategies/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + strategy_ids = [s["id"] for s in response.data["results"]] + self.assertIn(self.shared_strategy.id, strategy_ids) + self.assertNotIn(self.private_strategy.id, strategy_ids) + + def test_user_can_access_shared_dca_strategy_detail(self): + """User2 should be able to access shared strategy details.""" + response = self.client2.get(f"/api/dca/strategies/{self.shared_strategy.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["name"], "Shared BTC Strategy") + + def test_user_without_share_cannot_see_shared_strategy(self): + """User3 should NOT see strategy shared only with user2.""" + response = self.client3.get("/api/dca/strategies/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + strategy_ids = [s["id"] for s in response.data["results"]] + self.assertNotIn(self.shared_strategy.id, strategy_ids) + + def test_user_can_access_shared_strategy_actions(self): + """User2 should be able to access actions on shared strategy.""" + # investment_frequency + response1 = self.client2.get( + f"/api/dca/strategies/{self.shared_strategy.id}/investment_frequency/" + ) + self.assertEqual(response1.status_code, status.HTTP_200_OK) + + # price_comparison + response2 = self.client2.get( + f"/api/dca/strategies/{self.shared_strategy.id}/price_comparison/" + ) + self.assertEqual(response2.status_code, status.HTTP_200_OK) + + # current_price + response3 = self.client2.get( + f"/api/dca/strategies/{self.shared_strategy.id}/current_price/" + ) + self.assertEqual(response3.status_code, status.HTTP_200_OK) + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class SharedAccountGroupTests(TestCase): + """Tests for shared account group access.""" + + def setUp(self): + """Set up test data with shared account groups.""" + User = get_user_model() + + self.user1 = User.objects.create_user( + email="user1@test.com", password="testpass123" + ) + self.client1 = APIClient() + self.client1.force_authenticate(user=self.user1) + + self.user2 = User.objects.create_user( + email="user2@test.com", password="testpass123" + ) + self.client2 = APIClient() + self.client2.force_authenticate(user=self.user2) + + self.user3 = User.objects.create_user( + email="user3@test.com", password="testpass123" + ) + self.client3 = APIClient() + self.client3.force_authenticate(user=self.user3) + + # User 1's account group shared with user 2 + self.shared_group = AccountGroup.all_objects.create( + name="Shared Group", owner=self.user1 + ) + self.shared_group.shared_with.add(self.user2) + + # User 1's private account group + self.private_group = AccountGroup.all_objects.create( + name="Private Group", owner=self.user1 + ) + + # User 1's public account group + self.public_group = AccountGroup.all_objects.create( + name="Public Group", owner=self.user1, visibility="public" + ) + + def test_user_can_see_shared_account_groups(self): + """User2 should see account groups shared with them.""" + response = self.client2.get("/api/account-groups/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + group_ids = [g["id"] for g in response.data["results"]] + self.assertIn(self.shared_group.id, group_ids) + self.assertNotIn(self.private_group.id, group_ids) + + def test_user_can_access_shared_account_group_detail(self): + """User2 should be able to access shared account group details.""" + response = self.client2.get(f"/api/account-groups/{self.shared_group.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["name"], "Shared Group") + + def test_user_can_see_public_account_groups(self): + """User3 should see public account groups.""" + response = self.client3.get("/api/account-groups/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + group_ids = [g["id"] for g in response.data["results"]] + self.assertIn(self.public_group.id, group_ids) + + def test_user_without_share_cannot_access_shared_group(self): + """User3 should NOT be able to access shared account group.""" + response = self.client3.get(f"/api/account-groups/{self.shared_group.id}/") + + self.assertIn(response.status_code, ACCESS_DENIED_CODES) diff --git a/app/apps/api/views/accounts.py b/app/apps/api/views/accounts.py index 46d8e20..8220a5b 100644 --- a/app/apps/api/views/accounts.py +++ b/app/apps/api/views/accounts.py @@ -7,7 +7,11 @@ from rest_framework.response import Response from apps.accounts.models import AccountGroup, Account from apps.accounts.services import get_account_balance from apps.api.custom.pagination import CustomPageNumberPagination -from apps.api.serializers import AccountGroupSerializer, AccountSerializer, AccountBalanceSerializer +from apps.api.serializers import ( + AccountGroupSerializer, + AccountSerializer, + AccountBalanceSerializer, +) class AccountGroupViewSet(viewsets.ModelViewSet): @@ -17,13 +21,15 @@ class AccountGroupViewSet(viewsets.ModelViewSet): serializer_class = AccountGroupSerializer pagination_class = CustomPageNumberPagination filterset_fields = { - 'name': ['exact', 'icontains'], - 'owner': ['exact'], + "name": ["exact", "icontains"], + "owner": ["exact"], } - search_fields = ['name'] - ordering_fields = '__all__' - ordering = ['id'] + search_fields = ["name"] + ordering_fields = "__all__" + ordering = ["id"] + def get_queryset(self): + return AccountGroup.objects.all() @extend_schema_view( @@ -40,37 +46,37 @@ class AccountViewSet(viewsets.ModelViewSet): serializer_class = AccountSerializer pagination_class = CustomPageNumberPagination filterset_fields = { - 'name': ['exact', 'icontains'], - 'group': ['exact', 'isnull'], - 'currency': ['exact'], - 'exchange_currency': ['exact', 'isnull'], - 'is_asset': ['exact'], - 'is_archived': ['exact'], - 'owner': ['exact'], + "name": ["exact", "icontains"], + "group": ["exact", "isnull"], + "currency": ["exact"], + "exchange_currency": ["exact", "isnull"], + "is_asset": ["exact"], + "is_archived": ["exact"], + "owner": ["exact"], } - search_fields = ['name'] - ordering_fields = '__all__' - ordering = ['id'] + search_fields = ["name"] + ordering_fields = "__all__" + ordering = ["id"] def get_queryset(self): - return ( - Account.objects.all() - .select_related("group", "currency", "exchange_currency") + return Account.objects.all().select_related( + "group", "currency", "exchange_currency" ) @action(detail=True, methods=["get"], permission_classes=[IsAuthenticated]) def balance(self, request, pk=None): """Get current and projected balance for an account.""" account = self.get_object() - + current_balance = get_account_balance(account, paid_only=True) projected_balance = get_account_balance(account, paid_only=False) - - serializer = AccountBalanceSerializer({ - "current_balance": current_balance, - "projected_balance": projected_balance, - "currency": account.currency, - }) - - return Response(serializer.data) + serializer = AccountBalanceSerializer( + { + "current_balance": current_balance, + "projected_balance": projected_balance, + "currency": account.currency, + } + ) + + return Response(serializer.data) diff --git a/app/apps/api/views/dca.py b/app/apps/api/views/dca.py index 10165cb..9360682 100644 --- a/app/apps/api/views/dca.py +++ b/app/apps/api/views/dca.py @@ -6,18 +6,21 @@ from apps.api.serializers import DCAStrategySerializer, DCAEntrySerializer class DCAStrategyViewSet(viewsets.ModelViewSet): - queryset = DCAStrategy.all_objects.all() + queryset = DCAStrategy.objects.all() serializer_class = DCAStrategySerializer filterset_fields = { - 'name': ['exact', 'icontains'], - 'target_currency': ['exact'], - 'payment_currency': ['exact'], - 'notes': ['exact', 'icontains'], - 'created_at': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'updated_at': ['exact', 'gte', 'lte', 'gt', 'lt'], + "name": ["exact", "icontains"], + "target_currency": ["exact"], + "payment_currency": ["exact"], + "notes": ["exact", "icontains"], + "created_at": ["exact", "gte", "lte", "gt", "lt"], + "updated_at": ["exact", "gte", "lte", "gt", "lt"], } - search_fields = ['name', 'notes'] - ordering_fields = '__all__' + search_fields = ["name", "notes"] + ordering_fields = "__all__" + + def get_queryset(self): + return DCAStrategy.objects.all() @action(detail=True, methods=["get"]) def investment_frequency(self, request, pk=None): @@ -43,16 +46,21 @@ class DCAEntryViewSet(viewsets.ModelViewSet): queryset = DCAEntry.objects.all() serializer_class = DCAEntrySerializer filterset_fields = { - 'strategy': ['exact'], - 'date': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'amount_paid': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'amount_received': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'expense_transaction': ['exact', 'isnull'], - 'income_transaction': ['exact', 'isnull'], - 'notes': ['exact', 'icontains'], - 'created_at': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'updated_at': ['exact', 'gte', 'lte', 'gt', 'lt'], + "strategy": ["exact"], + "date": ["exact", "gte", "lte", "gt", "lt"], + "amount_paid": ["exact", "gte", "lte", "gt", "lt"], + "amount_received": ["exact", "gte", "lte", "gt", "lt"], + "expense_transaction": ["exact", "isnull"], + "income_transaction": ["exact", "isnull"], + "notes": ["exact", "icontains"], + "created_at": ["exact", "gte", "lte", "gt", "lt"], + "updated_at": ["exact", "gte", "lte", "gt", "lt"], } - search_fields = ['notes'] - ordering_fields = '__all__' - ordering = ['-date'] + search_fields = ["notes"] + ordering_fields = "__all__" + ordering = ["-date"] + + def get_queryset(self): + # Filter entries by strategies the user has access to + accessible_strategies = DCAStrategy.objects.all() + return DCAEntry.objects.filter(strategy__in=accessible_strategies) diff --git a/app/apps/api/views/transactions.py b/app/apps/api/views/transactions.py index f04b556..6e9165b 100644 --- a/app/apps/api/views/transactions.py +++ b/app/apps/api/views/transactions.py @@ -27,30 +27,33 @@ class TransactionViewSet(viewsets.ModelViewSet): serializer_class = TransactionSerializer pagination_class = CustomPageNumberPagination filterset_fields = { - 'account': ['exact'], - 'type': ['exact'], - 'is_paid': ['exact'], - 'date': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'reference_date': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'mute': ['exact'], - 'amount': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'description': ['exact', 'icontains'], - 'notes': ['exact', 'icontains'], - 'category': ['exact', 'isnull'], - 'installment_plan': ['exact', 'isnull'], - 'installment_id': ['exact', 'gte', 'lte'], - 'recurring_transaction': ['exact', 'isnull'], - 'internal_note': ['exact', 'icontains'], - 'internal_id': ['exact'], - 'deleted': ['exact'], - 'created_at': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'updated_at': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'deleted_at': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'], - 'owner': ['exact'], + "account": ["exact"], + "type": ["exact"], + "is_paid": ["exact"], + "date": ["exact", "gte", "lte", "gt", "lt"], + "reference_date": ["exact", "gte", "lte", "gt", "lt"], + "mute": ["exact"], + "amount": ["exact", "gte", "lte", "gt", "lt"], + "description": ["exact", "icontains"], + "notes": ["exact", "icontains"], + "category": ["exact", "isnull"], + "installment_plan": ["exact", "isnull"], + "installment_id": ["exact", "gte", "lte"], + "recurring_transaction": ["exact", "isnull"], + "internal_note": ["exact", "icontains"], + "internal_id": ["exact"], + "deleted": ["exact"], + "created_at": ["exact", "gte", "lte", "gt", "lt"], + "updated_at": ["exact", "gte", "lte", "gt", "lt"], + "deleted_at": ["exact", "gte", "lte", "gt", "lt", "isnull"], + "owner": ["exact"], } - search_fields = ['description', 'notes', 'internal_note'] - ordering_fields = '__all__' - ordering = ['-id'] + search_fields = ["description", "notes", "internal_note"] + ordering_fields = "__all__" + ordering = ["-id"] + + def get_queryset(self): + return Transaction.objects.all() def perform_create(self, serializer): instance = serializer.save() @@ -71,14 +74,17 @@ class TransactionCategoryViewSet(viewsets.ModelViewSet): serializer_class = TransactionCategorySerializer pagination_class = CustomPageNumberPagination filterset_fields = { - 'name': ['exact', 'icontains'], - 'mute': ['exact'], - 'active': ['exact'], - 'owner': ['exact'], + "name": ["exact", "icontains"], + "mute": ["exact"], + "active": ["exact"], + "owner": ["exact"], } - search_fields = ['name'] - ordering_fields = '__all__' - ordering = ['id'] + search_fields = ["name"] + ordering_fields = "__all__" + ordering = ["id"] + + def get_queryset(self): + return TransactionCategory.objects.all() class TransactionTagViewSet(viewsets.ModelViewSet): @@ -86,13 +92,16 @@ class TransactionTagViewSet(viewsets.ModelViewSet): serializer_class = TransactionTagSerializer pagination_class = CustomPageNumberPagination filterset_fields = { - 'name': ['exact', 'icontains'], - 'active': ['exact'], - 'owner': ['exact'], + "name": ["exact", "icontains"], + "active": ["exact"], + "owner": ["exact"], } - search_fields = ['name'] - ordering_fields = '__all__' - ordering = ['id'] + search_fields = ["name"] + ordering_fields = "__all__" + ordering = ["id"] + + def get_queryset(self): + return TransactionTag.objects.all() class TransactionEntityViewSet(viewsets.ModelViewSet): @@ -100,13 +109,16 @@ class TransactionEntityViewSet(viewsets.ModelViewSet): serializer_class = TransactionEntitySerializer pagination_class = CustomPageNumberPagination filterset_fields = { - 'name': ['exact', 'icontains'], - 'active': ['exact'], - 'owner': ['exact'], + "name": ["exact", "icontains"], + "active": ["exact"], + "owner": ["exact"], } - search_fields = ['name'] - ordering_fields = '__all__' - ordering = ['id'] + search_fields = ["name"] + ordering_fields = "__all__" + ordering = ["id"] + + def get_queryset(self): + return TransactionEntity.objects.all() class InstallmentPlanViewSet(viewsets.ModelViewSet): @@ -114,25 +126,28 @@ class InstallmentPlanViewSet(viewsets.ModelViewSet): serializer_class = InstallmentPlanSerializer pagination_class = CustomPageNumberPagination filterset_fields = { - 'account': ['exact'], - 'type': ['exact'], - 'description': ['exact', 'icontains'], - 'number_of_installments': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'installment_start': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'installment_total_number': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'start_date': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'reference_date': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'], - 'end_date': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'], - 'recurrence': ['exact'], - 'installment_amount': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'category': ['exact', 'isnull'], - 'notes': ['exact', 'icontains'], - 'add_description_to_transaction': ['exact'], - 'add_notes_to_transaction': ['exact'], + "account": ["exact"], + "type": ["exact"], + "description": ["exact", "icontains"], + "number_of_installments": ["exact", "gte", "lte", "gt", "lt"], + "installment_start": ["exact", "gte", "lte", "gt", "lt"], + "installment_total_number": ["exact", "gte", "lte", "gt", "lt"], + "start_date": ["exact", "gte", "lte", "gt", "lt"], + "reference_date": ["exact", "gte", "lte", "gt", "lt", "isnull"], + "end_date": ["exact", "gte", "lte", "gt", "lt", "isnull"], + "recurrence": ["exact"], + "installment_amount": ["exact", "gte", "lte", "gt", "lt"], + "category": ["exact", "isnull"], + "notes": ["exact", "icontains"], + "add_description_to_transaction": ["exact"], + "add_notes_to_transaction": ["exact"], } - search_fields = ['description', 'notes'] - ordering_fields = '__all__' - ordering = ['-id'] + search_fields = ["description", "notes"] + ordering_fields = "__all__" + ordering = ["-id"] + + def get_queryset(self): + return InstallmentPlan.objects.all() class RecurringTransactionViewSet(viewsets.ModelViewSet): @@ -140,25 +155,27 @@ class RecurringTransactionViewSet(viewsets.ModelViewSet): serializer_class = RecurringTransactionSerializer pagination_class = CustomPageNumberPagination filterset_fields = { - 'is_paused': ['exact'], - 'account': ['exact'], - 'type': ['exact'], - 'amount': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'description': ['exact', 'icontains'], - 'category': ['exact', 'isnull'], - 'notes': ['exact', 'icontains'], - 'reference_date': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'], - 'start_date': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'end_date': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'], - 'recurrence_type': ['exact'], - 'recurrence_interval': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'keep_at_most': ['exact', 'gte', 'lte', 'gt', 'lt'], - 'last_generated_date': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'], - 'last_generated_reference_date': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'], - 'add_description_to_transaction': ['exact'], - 'add_notes_to_transaction': ['exact'], + "is_paused": ["exact"], + "account": ["exact"], + "type": ["exact"], + "amount": ["exact", "gte", "lte", "gt", "lt"], + "description": ["exact", "icontains"], + "category": ["exact", "isnull"], + "notes": ["exact", "icontains"], + "reference_date": ["exact", "gte", "lte", "gt", "lt", "isnull"], + "start_date": ["exact", "gte", "lte", "gt", "lt"], + "end_date": ["exact", "gte", "lte", "gt", "lt", "isnull"], + "recurrence_type": ["exact"], + "recurrence_interval": ["exact", "gte", "lte", "gt", "lt"], + "keep_at_most": ["exact", "gte", "lte", "gt", "lt"], + "last_generated_date": ["exact", "gte", "lte", "gt", "lt", "isnull"], + "last_generated_reference_date": ["exact", "gte", "lte", "gt", "lt", "isnull"], + "add_description_to_transaction": ["exact"], + "add_notes_to_transaction": ["exact"], } - search_fields = ['description', 'notes'] - ordering_fields = '__all__' - ordering = ['-id'] + search_fields = ["description", "notes"] + ordering_fields = "__all__" + ordering = ["-id"] + def get_queryset(self): + return RecurringTransaction.objects.all()