From d3a816d91bbd7ff0aaed1ca6a2a2421bcc276c69 Mon Sep 17 00:00:00 2001 From: Herculino Trotta Date: Sun, 7 Dec 2025 00:32:18 -0300 Subject: [PATCH] feat(api): add endpoints for importing files and getting account balance --- app/apps/accounts/services.py | 33 +++ app/apps/accounts/tests.py | 134 +++++++++ app/apps/accounts/views/balance.py | 12 +- app/apps/api/serializers/__init__.py | 2 + app/apps/api/serializers/accounts.py | 9 + app/apps/api/serializers/imports.py | 41 +++ app/apps/api/tests/__init__.py | 4 + app/apps/api/tests/test_accounts.py | 99 +++++++ app/apps/api/tests/test_imports.py | 404 +++++++++++++++++++++++++++ app/apps/api/urls.py | 4 + app/apps/api/views/__init__.py | 2 + app/apps/api/views/accounts.py | 37 ++- app/apps/api/views/imports.py | 123 ++++++++ 13 files changed, 891 insertions(+), 13 deletions(-) create mode 100644 app/apps/accounts/services.py create mode 100644 app/apps/api/serializers/imports.py create mode 100644 app/apps/api/tests/__init__.py create mode 100644 app/apps/api/tests/test_accounts.py create mode 100644 app/apps/api/tests/test_imports.py create mode 100644 app/apps/api/views/imports.py diff --git a/app/apps/accounts/services.py b/app/apps/accounts/services.py new file mode 100644 index 0000000..8966fb6 --- /dev/null +++ b/app/apps/accounts/services.py @@ -0,0 +1,33 @@ +from decimal import Decimal + +from django.db import models + +from apps.accounts.models import Account +from apps.transactions.models import Transaction + + +def get_account_balance(account: Account, paid_only: bool = True) -> Decimal: + """ + Calculate account balance (income - expense). + + Args: + account: Account instance to calculate balance for. + paid_only: If True, only count paid transactions (current balance). + If False, count all transactions (projected balance). + + Returns: + Decimal: The calculated balance (income - expense). + """ + filters = {"account": account} + if paid_only: + filters["is_paid"] = True + + income = Transaction.objects.filter( + type=Transaction.Type.INCOME, **filters + ).aggregate(total=models.Sum("amount"))["total"] or Decimal("0") + + expense = Transaction.objects.filter( + type=Transaction.Type.EXPENSE, **filters + ).aggregate(total=models.Sum("amount"))["total"] or Decimal("0") + + return income - expense diff --git a/app/apps/accounts/tests.py b/app/apps/accounts/tests.py index 3d81d0d..34aab3f 100644 --- a/app/apps/accounts/tests.py +++ b/app/apps/accounts/tests.py @@ -1,3 +1,5 @@ +from datetime import date + from django.test import TestCase from apps.accounts.models import Account, AccountGroup @@ -39,3 +41,135 @@ class AccountTests(TestCase): exchange_currency=self.exchange_currency, ) self.assertEqual(account.exchange_currency, self.exchange_currency) + + +class GetAccountBalanceServiceTests(TestCase): + """Tests for the get_account_balance service function""" + + def setUp(self): + """Set up test data""" + from apps.transactions.models import Transaction + self.Transaction = Transaction + + self.currency = Currency.objects.create( + code="BRL", name="Brazilian Real", decimal_places=2, prefix="R$ " + ) + self.account_group = AccountGroup.objects.create(name="Service Test Group") + self.account = Account.objects.create( + name="Service Test Account", group=self.account_group, currency=self.currency + ) + + def test_balance_with_no_transactions(self): + """Test balance is 0 when no transactions exist""" + from apps.accounts.services import get_account_balance + from decimal import Decimal + + balance = get_account_balance(self.account, paid_only=True) + self.assertEqual(balance, Decimal("0")) + + def test_current_balance_only_counts_paid(self): + """Test current balance only counts paid transactions""" + from apps.accounts.services import get_account_balance + from decimal import Decimal + + # Paid income + self.Transaction.objects.create( + account=self.account, + type=self.Transaction.Type.INCOME, + amount=Decimal("100.00"), + is_paid=True, + date=date(2025, 1, 1), + description="Paid income", + ) + # Unpaid income (should not count) + self.Transaction.objects.create( + account=self.account, + type=self.Transaction.Type.INCOME, + amount=Decimal("50.00"), + is_paid=False, + date=date(2025, 1, 1), + description="Unpaid income", + ) + # Paid expense + self.Transaction.objects.create( + account=self.account, + type=self.Transaction.Type.EXPENSE, + amount=Decimal("30.00"), + is_paid=True, + date=date(2025, 1, 1), + description="Paid expense", + ) + + balance = get_account_balance(self.account, paid_only=True) + self.assertEqual(balance, Decimal("70.00")) # 100 - 30 + + def test_projected_balance_counts_all(self): + """Test projected balance counts all transactions""" + from apps.accounts.services import get_account_balance + from decimal import Decimal + + # Paid income + self.Transaction.objects.create( + account=self.account, + type=self.Transaction.Type.INCOME, + amount=Decimal("100.00"), + is_paid=True, + date=date(2025, 1, 1), + description="Paid income", + ) + # Unpaid income + self.Transaction.objects.create( + account=self.account, + type=self.Transaction.Type.INCOME, + amount=Decimal("50.00"), + is_paid=False, + date=date(2025, 1, 1), + description="Unpaid income", + ) + # Paid expense + self.Transaction.objects.create( + account=self.account, + type=self.Transaction.Type.EXPENSE, + amount=Decimal("30.00"), + is_paid=True, + date=date(2025, 1, 1), + description="Paid expense", + ) + # Unpaid expense + self.Transaction.objects.create( + account=self.account, + type=self.Transaction.Type.EXPENSE, + amount=Decimal("20.00"), + is_paid=False, + date=date(2025, 1, 1), + description="Unpaid expense", + ) + + balance = get_account_balance(self.account, paid_only=False) + self.assertEqual(balance, Decimal("100.00")) # (100 + 50) - (30 + 20) + + def test_balance_defaults_to_paid_only(self): + """Test that paid_only defaults to True""" + from apps.accounts.services import get_account_balance + from decimal import Decimal + + self.Transaction.objects.create( + account=self.account, + type=self.Transaction.Type.INCOME, + amount=Decimal("100.00"), + is_paid=True, + date=date(2025, 1, 1), + description="Paid", + ) + self.Transaction.objects.create( + account=self.account, + type=self.Transaction.Type.INCOME, + amount=Decimal("50.00"), + is_paid=False, + date=date(2025, 1, 1), + description="Unpaid", + ) + + balance = get_account_balance(self.account) # defaults to paid_only=True + self.assertEqual(balance, Decimal("100.00")) + diff --git a/app/apps/accounts/views/balance.py b/app/apps/accounts/views/balance.py index 5292f95..2749593 100644 --- a/app/apps/accounts/views/balance.py +++ b/app/apps/accounts/views/balance.py @@ -11,23 +11,13 @@ from django.utils.translation import gettext_lazy as _ from apps.accounts.forms import AccountBalanceFormSet from apps.accounts.models import Account, Transaction +from apps.accounts.services import get_account_balance from apps.common.decorators.htmx import only_htmx @only_htmx @login_required def account_reconciliation(request): - def get_account_balance(account): - income = Transaction.objects.filter( - account=account, type=Transaction.Type.INCOME, is_paid=True - ).aggregate(total=models.Sum("amount"))["total"] or Decimal("0") - - expense = Transaction.objects.filter( - account=account, type=Transaction.Type.EXPENSE, is_paid=True - ).aggregate(total=models.Sum("amount"))["total"] or Decimal("0") - - return income - expense - initial_data = [ { "account_id": account.id, diff --git a/app/apps/api/serializers/__init__.py b/app/apps/api/serializers/__init__.py index 0fd06bc..903fe3b 100644 --- a/app/apps/api/serializers/__init__.py +++ b/app/apps/api/serializers/__init__.py @@ -2,3 +2,5 @@ from .transactions import * from .accounts import * from .currencies import * from .dca import * +from .imports import * + diff --git a/app/apps/api/serializers/accounts.py b/app/apps/api/serializers/accounts.py index c7f5e45..db95f25 100644 --- a/app/apps/api/serializers/accounts.py +++ b/app/apps/api/serializers/accounts.py @@ -67,3 +67,12 @@ class AccountSerializer(serializers.ModelSerializer): setattr(instance, attr, value) instance.save() return instance + + +class AccountBalanceSerializer(serializers.Serializer): + """Serializer for account balance response.""" + + current_balance = serializers.DecimalField(max_digits=20, decimal_places=10) + projected_balance = serializers.DecimalField(max_digits=20, decimal_places=10) + currency = CurrencySerializer() + diff --git a/app/apps/api/serializers/imports.py b/app/apps/api/serializers/imports.py new file mode 100644 index 0000000..5a00cc2 --- /dev/null +++ b/app/apps/api/serializers/imports.py @@ -0,0 +1,41 @@ +from rest_framework import serializers + +from apps.import_app.models import ImportProfile, ImportRun + + +class ImportProfileSerializer(serializers.ModelSerializer): + """Serializer for listing import profiles.""" + + class Meta: + model = ImportProfile + fields = ["id", "name", "version", "yaml_config"] + + +class ImportRunSerializer(serializers.ModelSerializer): + """Serializer for listing import runs.""" + + class Meta: + model = ImportRun + fields = [ + "id", + "status", + "profile", + "file_name", + "logs", + "processed_rows", + "total_rows", + "successful_rows", + "skipped_rows", + "failed_rows", + "started_at", + "finished_at", + ] + + +class ImportFileSerializer(serializers.Serializer): + """Serializer for uploading a file to import using an existing profile.""" + + profile_id = serializers.PrimaryKeyRelatedField( + queryset=ImportProfile.objects.all(), source="profile" + ) + file = serializers.FileField() diff --git a/app/apps/api/tests/__init__.py b/app/apps/api/tests/__init__.py new file mode 100644 index 0000000..3c860ef --- /dev/null +++ b/app/apps/api/tests/__init__.py @@ -0,0 +1,4 @@ +# Import all test classes for Django test discovery +from .test_imports import * +from .test_accounts import * + diff --git a/app/apps/api/tests/test_accounts.py b/app/apps/api/tests/test_accounts.py new file mode 100644 index 0000000..50e7a94 --- /dev/null +++ b/app/apps/api/tests/test_accounts.py @@ -0,0 +1,99 @@ +from datetime import date +from decimal import Decimal + +from django.contrib.auth import get_user_model +from django.test import TestCase, override_settings +from rest_framework import status +from rest_framework.test import APIClient + +from apps.accounts.models import Account, AccountGroup +from apps.currencies.models import Currency +from apps.transactions.models import Transaction + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class AccountBalanceAPITests(TestCase): + """Tests for the Account Balance API endpoint""" + + def setUp(self): + """Set up test data""" + User = get_user_model() + self.user = User.objects.create_user( + email="testuser@test.com", password="testpass123" + ) + self.client = APIClient() + self.client.force_authenticate(user=self.user) + + self.currency = Currency.objects.create( + code="USD", name="US Dollar", decimal_places=2, prefix="$ " + ) + self.account_group = AccountGroup.objects.create(name="Test Group") + self.account = Account.objects.create( + name="Test Account", group=self.account_group, currency=self.currency + ) + + # Create some transactions + Transaction.objects.create( + account=self.account, + type=Transaction.Type.INCOME, + amount=Decimal("500.00"), + is_paid=True, + date=date(2025, 1, 1), + description="Paid income", + ) + Transaction.objects.create( + account=self.account, + type=Transaction.Type.INCOME, + amount=Decimal("200.00"), + is_paid=False, + date=date(2025, 1, 15), + description="Unpaid income", + ) + Transaction.objects.create( + account=self.account, + type=Transaction.Type.EXPENSE, + amount=Decimal("100.00"), + is_paid=True, + date=date(2025, 1, 10), + description="Paid expense", + ) + + def test_get_balance_success(self): + """Test successful balance retrieval""" + response = self.client.get(f"/api/accounts/{self.account.id}/balance/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("current_balance", response.data) + self.assertIn("projected_balance", response.data) + self.assertIn("currency", response.data) + + # Current: 500 - 100 = 400 + self.assertEqual(Decimal(response.data["current_balance"]), Decimal("400.00")) + # Projected: (500 + 200) - 100 = 600 + self.assertEqual(Decimal(response.data["projected_balance"]), Decimal("600.00")) + + # Check currency data + self.assertEqual(response.data["currency"]["code"], "USD") + + def test_get_balance_nonexistent_account(self): + """Test balance for non-existent account returns 404""" + response = self.client.get("/api/accounts/99999/balance/") + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_get_balance_unauthenticated(self): + """Test unauthenticated request returns 403""" + unauthenticated_client = APIClient() + response = unauthenticated_client.get( + f"/api/accounts/{self.account.id}/balance/" + ) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) diff --git a/app/apps/api/tests/test_imports.py b/app/apps/api/tests/test_imports.py new file mode 100644 index 0000000..9509e4d --- /dev/null +++ b/app/apps/api/tests/test_imports.py @@ -0,0 +1,404 @@ +from io import BytesIO +from unittest.mock import patch + +from django.contrib.auth import get_user_model +from django.core.files.uploadedfile import SimpleUploadedFile +from django.test import TestCase, override_settings +from rest_framework import status +from rest_framework.test import APIClient + +from apps.import_app.models import ImportProfile, ImportRun + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class ImportAPITests(TestCase): + """Tests for the Import API endpoint""" + + def setUp(self): + """Set up test data""" + User = get_user_model() + self.user = User.objects.create_user( + email="testuser@test.com", password="testpass123" + ) + self.client = APIClient() + self.client.force_authenticate(user=self.user) + + # Create a basic import profile with minimal valid YAML config + self.profile = ImportProfile.objects.create( + name="Test Profile", + version=ImportProfile.Versions.VERSION_1, + yaml_config=""" +file_type: csv +date_format: "%Y-%m-%d" +column_mapping: + date: + source: date + description: + source: description + amount: + source: amount + transaction_type: + detection_method: always_expense + is_paid: + detection_method: always_paid + account: + source: account + match_field: name +""", + ) + + @patch("apps.import_app.tasks.process_import.defer") + @patch("django.core.files.storage.FileSystemStorage.save") + @patch("django.core.files.storage.FileSystemStorage.path") + def test_create_import_success(self, mock_path, mock_save, mock_defer): + """Test successful file upload creates ImportRun and queues task""" + mock_save.return_value = "test_file.csv" + mock_path.return_value = "/usr/src/app/temp/test_file.csv" + + csv_content = b"date,description,amount,account\n2025-01-01,Test,100,Main" + file = SimpleUploadedFile( + "test_file.csv", csv_content, content_type="text/csv" + ) + + response = self.client.post( + "/api/import/import/", + {"profile_id": self.profile.id, "file": file}, + format="multipart", + ) + + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + self.assertIn("import_run_id", response.data) + self.assertEqual(response.data["status"], "queued") + + # Verify ImportRun was created + import_run = ImportRun.objects.get(id=response.data["import_run_id"]) + self.assertEqual(import_run.profile, self.profile) + self.assertEqual(import_run.file_name, "test_file.csv") + + # Verify task was deferred + mock_defer.assert_called_once_with( + import_run_id=import_run.id, + file_path="/usr/src/app/temp/test_file.csv", + user_id=self.user.id, + ) + + def test_create_import_missing_profile(self): + """Test request without profile_id returns 400""" + csv_content = b"date,description,amount\n2025-01-01,Test,100" + file = SimpleUploadedFile( + "test_file.csv", csv_content, content_type="text/csv" + ) + + response = self.client.post( + "/api/import/import/", + {"file": file}, + format="multipart", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn("profile_id", response.data) + + def test_create_import_missing_file(self): + """Test request without file returns 400""" + response = self.client.post( + "/api/import/import/", + {"profile_id": self.profile.id}, + format="multipart", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn("file", response.data) + + def test_create_import_invalid_profile(self): + """Test request with non-existent profile returns 400""" + csv_content = b"date,description,amount\n2025-01-01,Test,100" + file = SimpleUploadedFile( + "test_file.csv", csv_content, content_type="text/csv" + ) + + response = self.client.post( + "/api/import/import/", + {"profile_id": 99999, "file": file}, + format="multipart", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn("profile_id", response.data) + + @patch("apps.import_app.tasks.process_import.defer") + @patch("django.core.files.storage.FileSystemStorage.save") + @patch("django.core.files.storage.FileSystemStorage.path") + def test_create_import_xlsx(self, mock_path, mock_save, mock_defer): + """Test successful XLSX file upload""" + mock_save.return_value = "test_file.xlsx" + mock_path.return_value = "/usr/src/app/temp/test_file.xlsx" + + # Create a simple XLSX-like content (just for the upload test) + xlsx_content = BytesIO(b"PK\x03\x04") # XLSX files start with PK header + file = SimpleUploadedFile( + "test_file.xlsx", + xlsx_content.getvalue(), + content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ) + + response = self.client.post( + "/api/import/import/", + {"profile_id": self.profile.id, "file": file}, + format="multipart", + ) + + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + self.assertIn("import_run_id", response.data) + + def test_unauthenticated_request(self): + """Test unauthenticated request returns 403""" + unauthenticated_client = APIClient() + + csv_content = b"date,description,amount\n2025-01-01,Test,100" + file = SimpleUploadedFile( + "test_file.csv", csv_content, content_type="text/csv" + ) + + response = unauthenticated_client.post( + "/api/import/import/", + {"profile_id": self.profile.id, "file": file}, + format="multipart", + ) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class ImportProfileAPITests(TestCase): + """Tests for the Import Profile API endpoints""" + + def setUp(self): + """Set up test data""" + User = get_user_model() + self.user = User.objects.create_user( + email="testuser@test.com", password="testpass123" + ) + self.client = APIClient() + self.client.force_authenticate(user=self.user) + + self.profile1 = ImportProfile.objects.create( + name="Profile 1", + version=ImportProfile.Versions.VERSION_1, + yaml_config=""" +file_type: csv +date_format: "%Y-%m-%d" +column_mapping: + date: + source: date + description: + source: description + amount: + source: amount + transaction_type: + detection_method: always_expense + is_paid: + detection_method: always_paid + account: + source: account + match_field: name +""", + ) + self.profile2 = ImportProfile.objects.create( + name="Profile 2", + version=ImportProfile.Versions.VERSION_1, + yaml_config=""" +file_type: csv +date_format: "%Y-%m-%d" +column_mapping: + date: + source: date + description: + source: description + amount: + source: amount + transaction_type: + detection_method: always_income + is_paid: + detection_method: always_unpaid + account: + source: account + match_field: name +""", + ) + + def test_list_profiles(self): + """Test listing all profiles""" + response = self.client.get("/api/import/profiles/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 2) + self.assertEqual(len(response.data["results"]), 2) + + def test_retrieve_profile(self): + """Test retrieving a specific profile""" + response = self.client.get(f"/api/import/profiles/{self.profile1.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["id"], self.profile1.id) + self.assertEqual(response.data["name"], "Profile 1") + self.assertIn("yaml_config", response.data) + + def test_retrieve_nonexistent_profile(self): + """Test retrieving a non-existent profile returns 404""" + response = self.client.get("/api/import/profiles/99999/") + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_profiles_unauthenticated(self): + """Test unauthenticated request returns 403""" + unauthenticated_client = APIClient() + response = unauthenticated_client.get("/api/import/profiles/") + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + +@override_settings( + STORAGES={ + "default": {"BACKEND": "django.core.files.storage.FileSystemStorage"}, + "staticfiles": { + "BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage" + }, + }, + WHITENOISE_AUTOREFRESH=True, +) +class ImportRunAPITests(TestCase): + """Tests for the Import Run API endpoints""" + + def setUp(self): + """Set up test data""" + User = get_user_model() + self.user = User.objects.create_user( + email="testuser@test.com", password="testpass123" + ) + self.client = APIClient() + self.client.force_authenticate(user=self.user) + + self.profile1 = ImportProfile.objects.create( + name="Profile 1", + version=ImportProfile.Versions.VERSION_1, + yaml_config=""" +file_type: csv +date_format: "%Y-%m-%d" +column_mapping: + date: + source: date + description: + source: description + amount: + source: amount + transaction_type: + detection_method: always_expense + is_paid: + detection_method: always_paid + account: + source: account + match_field: name +""", + ) + self.profile2 = ImportProfile.objects.create( + name="Profile 2", + version=ImportProfile.Versions.VERSION_1, + yaml_config=""" +file_type: csv +date_format: "%Y-%m-%d" +column_mapping: + date: + source: date + description: + source: description + amount: + source: amount + transaction_type: + detection_method: always_income + is_paid: + detection_method: always_unpaid + account: + source: account + match_field: name +""", + ) + + # Create import runs + self.run1 = ImportRun.objects.create( + profile=self.profile1, + file_name="file1.csv", + status=ImportRun.Status.FINISHED, + ) + self.run2 = ImportRun.objects.create( + profile=self.profile1, + file_name="file2.csv", + status=ImportRun.Status.QUEUED, + ) + self.run3 = ImportRun.objects.create( + profile=self.profile2, + file_name="file3.csv", + status=ImportRun.Status.FINISHED, + ) + + def test_list_all_runs(self): + """Test listing all runs""" + response = self.client.get("/api/import/runs/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 3) + self.assertEqual(len(response.data["results"]), 3) + + def test_list_runs_by_profile(self): + """Test filtering runs by profile_id""" + response = self.client.get(f"/api/import/runs/?profile_id={self.profile1.id}") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 2) + for run in response.data["results"]: + self.assertEqual(run["profile"], self.profile1.id) + + def test_list_runs_by_other_profile(self): + """Test filtering runs by another profile_id""" + response = self.client.get(f"/api/import/runs/?profile_id={self.profile2.id}") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["profile"], self.profile2.id) + + def test_retrieve_run(self): + """Test retrieving a specific run""" + response = self.client.get(f"/api/import/runs/{self.run1.id}/") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["id"], self.run1.id) + self.assertEqual(response.data["file_name"], "file1.csv") + self.assertEqual(response.data["status"], "FINISHED") + + def test_retrieve_nonexistent_run(self): + """Test retrieving a non-existent run returns 404""" + response = self.client.get("/api/import/runs/99999/") + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_runs_unauthenticated(self): + """Test unauthenticated request returns 403""" + unauthenticated_client = APIClient() + response = unauthenticated_client.get("/api/import/runs/") + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) diff --git a/app/apps/api/urls.py b/app/apps/api/urls.py index b900064..ee2d1fe 100644 --- a/app/apps/api/urls.py +++ b/app/apps/api/urls.py @@ -16,7 +16,11 @@ router.register(r"currencies", views.CurrencyViewSet) router.register(r"exchange-rates", views.ExchangeRateViewSet) router.register(r"dca/strategies", views.DCAStrategyViewSet) router.register(r"dca/entries", views.DCAEntryViewSet) +router.register(r"import/profiles", views.ImportProfileViewSet, basename="import-profiles") +router.register(r"import/runs", views.ImportRunViewSet, basename="import-runs") +router.register(r"import/import", views.ImportViewSet, basename="import-import") urlpatterns = [ path("", include(router.urls)), ] + diff --git a/app/apps/api/views/__init__.py b/app/apps/api/views/__init__.py index 0fd06bc..903fe3b 100644 --- a/app/apps/api/views/__init__.py +++ b/app/apps/api/views/__init__.py @@ -2,3 +2,5 @@ from .transactions import * from .accounts import * from .currencies import * from .dca import * +from .imports import * + diff --git a/app/apps/api/views/accounts.py b/app/apps/api/views/accounts.py index b5c0fba..3a6db3e 100644 --- a/app/apps/api/views/accounts.py +++ b/app/apps/api/views/accounts.py @@ -1,11 +1,18 @@ +from drf_spectacular.utils import extend_schema, extend_schema_view from rest_framework import viewsets +from rest_framework.decorators import action +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response -from apps.api.custom.pagination import CustomPageNumberPagination from apps.accounts.models import AccountGroup, Account -from apps.api.serializers import AccountGroupSerializer, AccountSerializer +from apps.accounts.services import get_account_balance +from apps.api.custom.pagination import CustomPageNumberPagination +from apps.api.serializers import AccountGroupSerializer, AccountSerializer, AccountBalanceSerializer class AccountGroupViewSet(viewsets.ModelViewSet): + """ViewSet for managing account groups.""" + queryset = AccountGroup.objects.all() serializer_class = AccountGroupSerializer pagination_class = CustomPageNumberPagination @@ -14,7 +21,16 @@ class AccountGroupViewSet(viewsets.ModelViewSet): return AccountGroup.objects.all().order_by("id") +@extend_schema_view( + balance=extend_schema( + summary="Get account balance", + description="Returns the current and projected balance for the account, along with currency data.", + responses={200: AccountBalanceSerializer}, + ), +) class AccountViewSet(viewsets.ModelViewSet): + """ViewSet for managing accounts.""" + queryset = Account.objects.all() serializer_class = AccountSerializer pagination_class = CustomPageNumberPagination @@ -25,3 +41,20 @@ class AccountViewSet(viewsets.ModelViewSet): .order_by("id") .select_related("group", "currency", "exchange_currency") ) + + @action(detail=True, methods=["get"], permission_classes=[IsAuthenticated]) + def balance(self, request, pk=None): + """Get current and projected balance for an account.""" + account = self.get_object() + + current_balance = get_account_balance(account, paid_only=True) + projected_balance = get_account_balance(account, paid_only=False) + + serializer = AccountBalanceSerializer({ + "current_balance": current_balance, + "projected_balance": projected_balance, + "currency": account.currency, + }) + + return Response(serializer.data) + diff --git a/app/apps/api/views/imports.py b/app/apps/api/views/imports.py new file mode 100644 index 0000000..80f999e --- /dev/null +++ b/app/apps/api/views/imports.py @@ -0,0 +1,123 @@ +from django.core.files.storage import FileSystemStorage +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view, inline_serializer +from rest_framework import serializers as drf_serializers +from rest_framework import status, viewsets +from rest_framework.parsers import MultiPartParser +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response + +from apps.api.serializers import ImportFileSerializer, ImportProfileSerializer, ImportRunSerializer +from apps.import_app.models import ImportProfile, ImportRun +from apps.import_app.tasks import process_import + + +@extend_schema_view( + list=extend_schema( + summary="List import profiles", + description="Returns a paginated list of all available import profiles.", + ), + retrieve=extend_schema( + summary="Get import profile", + description="Returns the details of a specific import profile by ID.", + ), +) +class ImportProfileViewSet(viewsets.ReadOnlyModelViewSet): + """ViewSet for listing and retrieving import profiles.""" + + queryset = ImportProfile.objects.all() + serializer_class = ImportProfileSerializer + permission_classes = [IsAuthenticated] + + +@extend_schema_view( + list=extend_schema( + summary="List import runs", + description="Returns a paginated list of import runs. Optionally filter by profile_id.", + parameters=[ + OpenApiParameter( + name="profile_id", + type=int, + location=OpenApiParameter.QUERY, + description="Filter runs by profile ID", + required=False, + ), + ], + ), + retrieve=extend_schema( + summary="Get import run", + description="Returns the details of a specific import run by ID, including status and logs.", + ), +) +class ImportRunViewSet(viewsets.ReadOnlyModelViewSet): + """ViewSet for listing and retrieving import runs.""" + + queryset = ImportRun.objects.all().order_by("-id") + serializer_class = ImportRunSerializer + permission_classes = [IsAuthenticated] + + def get_queryset(self): + queryset = super().get_queryset() + profile_id = self.request.query_params.get("profile_id") + if profile_id: + queryset = queryset.filter(profile_id=profile_id) + return queryset + + +@extend_schema_view( + create=extend_schema( + summary="Import file", + description="Upload a CSV or XLSX file to import using an existing import profile. The import is queued and processed asynchronously.", + request={ + "multipart/form-data": { + "type": "object", + "properties": { + "profile_id": {"type": "integer", "description": "ID of the ImportProfile to use"}, + "file": {"type": "string", "format": "binary", "description": "CSV or XLSX file to import"}, + }, + "required": ["profile_id", "file"], + }, + }, + responses={ + 202: inline_serializer( + name="ImportResponse", + fields={ + "import_run_id": drf_serializers.IntegerField(), + "status": drf_serializers.CharField(), + }, + ), + }, + ), +) +class ImportViewSet(viewsets.ViewSet): + """ViewSet for importing data via file upload.""" + + permission_classes = [IsAuthenticated] + parser_classes = [MultiPartParser] + + def create(self, request): + serializer = ImportFileSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + + profile = serializer.validated_data["profile"] + uploaded_file = serializer.validated_data["file"] + + # Save file to temp location + fs = FileSystemStorage(location="/usr/src/app/temp") + filename = fs.save(uploaded_file.name, uploaded_file) + file_path = fs.path(filename) + + # Create ImportRun record + import_run = ImportRun.objects.create(profile=profile, file_name=filename) + + # Queue import task + process_import.defer( + import_run_id=import_run.id, + file_path=file_path, + user_id=request.user.id, + ) + + return Response( + {"import_run_id": import_run.id, "status": "queued"}, + status=status.HTTP_202_ACCEPTED, + )