mirror of
https://github.com/eitchtee/WYGIWYH.git
synced 2026-01-18 14:57:00 +01:00
405 lines
13 KiB
Python
405 lines
13 KiB
Python
from io import BytesIO
|
|
from unittest.mock import patch
|
|
|
|
from django.contrib.auth import get_user_model
|
|
from django.core.files.uploadedfile import SimpleUploadedFile
|
|
from django.test import TestCase, override_settings
|
|
from rest_framework import status
|
|
from rest_framework.test import APIClient
|
|
|
|
from apps.import_app.models import ImportProfile, ImportRun
|
|
|
|
|
|
@override_settings(
|
|
STORAGES={
|
|
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
|
"staticfiles": {
|
|
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
|
},
|
|
},
|
|
WHITENOISE_AUTOREFRESH=True,
|
|
)
|
|
class ImportAPITests(TestCase):
|
|
"""Tests for the Import API endpoint"""
|
|
|
|
def setUp(self):
|
|
"""Set up test data"""
|
|
User = get_user_model()
|
|
self.user = User.objects.create_user(
|
|
email="testuser@test.com", password="testpass123"
|
|
)
|
|
self.client = APIClient()
|
|
self.client.force_authenticate(user=self.user)
|
|
|
|
# Create a basic import profile with minimal valid YAML config
|
|
self.profile = ImportProfile.objects.create(
|
|
name="Test Profile",
|
|
version=ImportProfile.Versions.VERSION_1,
|
|
yaml_config="""
|
|
file_type: csv
|
|
date_format: "%Y-%m-%d"
|
|
column_mapping:
|
|
date:
|
|
source: date
|
|
description:
|
|
source: description
|
|
amount:
|
|
source: amount
|
|
transaction_type:
|
|
detection_method: always_expense
|
|
is_paid:
|
|
detection_method: always_paid
|
|
account:
|
|
source: account
|
|
match_field: name
|
|
""",
|
|
)
|
|
|
|
@patch("apps.import_app.tasks.process_import.defer")
|
|
@patch("django.core.files.storage.FileSystemStorage.save")
|
|
@patch("django.core.files.storage.FileSystemStorage.path")
|
|
def test_create_import_success(self, mock_path, mock_save, mock_defer):
|
|
"""Test successful file upload creates ImportRun and queues task"""
|
|
mock_save.return_value = "test_file.csv"
|
|
mock_path.return_value = "/usr/src/app/temp/test_file.csv"
|
|
|
|
csv_content = b"date,description,amount,account\n2025-01-01,Test,100,Main"
|
|
file = SimpleUploadedFile(
|
|
"test_file.csv", csv_content, content_type="text/csv"
|
|
)
|
|
|
|
response = self.client.post(
|
|
"/api/import/import/",
|
|
{"profile_id": self.profile.id, "file": file},
|
|
format="multipart",
|
|
)
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
|
self.assertIn("import_run_id", response.data)
|
|
self.assertEqual(response.data["status"], "queued")
|
|
|
|
# Verify ImportRun was created
|
|
import_run = ImportRun.objects.get(id=response.data["import_run_id"])
|
|
self.assertEqual(import_run.profile, self.profile)
|
|
self.assertEqual(import_run.file_name, "test_file.csv")
|
|
|
|
# Verify task was deferred
|
|
mock_defer.assert_called_once_with(
|
|
import_run_id=import_run.id,
|
|
file_path="/usr/src/app/temp/test_file.csv",
|
|
user_id=self.user.id,
|
|
)
|
|
|
|
def test_create_import_missing_profile(self):
|
|
"""Test request without profile_id returns 400"""
|
|
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
|
file = SimpleUploadedFile(
|
|
"test_file.csv", csv_content, content_type="text/csv"
|
|
)
|
|
|
|
response = self.client.post(
|
|
"/api/import/import/",
|
|
{"file": file},
|
|
format="multipart",
|
|
)
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
|
self.assertIn("profile_id", response.data)
|
|
|
|
def test_create_import_missing_file(self):
|
|
"""Test request without file returns 400"""
|
|
response = self.client.post(
|
|
"/api/import/import/",
|
|
{"profile_id": self.profile.id},
|
|
format="multipart",
|
|
)
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
|
self.assertIn("file", response.data)
|
|
|
|
def test_create_import_invalid_profile(self):
|
|
"""Test request with non-existent profile returns 400"""
|
|
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
|
file = SimpleUploadedFile(
|
|
"test_file.csv", csv_content, content_type="text/csv"
|
|
)
|
|
|
|
response = self.client.post(
|
|
"/api/import/import/",
|
|
{"profile_id": 99999, "file": file},
|
|
format="multipart",
|
|
)
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
|
self.assertIn("profile_id", response.data)
|
|
|
|
@patch("apps.import_app.tasks.process_import.defer")
|
|
@patch("django.core.files.storage.FileSystemStorage.save")
|
|
@patch("django.core.files.storage.FileSystemStorage.path")
|
|
def test_create_import_xlsx(self, mock_path, mock_save, mock_defer):
|
|
"""Test successful XLSX file upload"""
|
|
mock_save.return_value = "test_file.xlsx"
|
|
mock_path.return_value = "/usr/src/app/temp/test_file.xlsx"
|
|
|
|
# Create a simple XLSX-like content (just for the upload test)
|
|
xlsx_content = BytesIO(b"PK\x03\x04") # XLSX files start with PK header
|
|
file = SimpleUploadedFile(
|
|
"test_file.xlsx",
|
|
xlsx_content.getvalue(),
|
|
content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
|
)
|
|
|
|
response = self.client.post(
|
|
"/api/import/import/",
|
|
{"profile_id": self.profile.id, "file": file},
|
|
format="multipart",
|
|
)
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
|
self.assertIn("import_run_id", response.data)
|
|
|
|
def test_unauthenticated_request(self):
|
|
"""Test unauthenticated request returns 403"""
|
|
unauthenticated_client = APIClient()
|
|
|
|
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
|
file = SimpleUploadedFile(
|
|
"test_file.csv", csv_content, content_type="text/csv"
|
|
)
|
|
|
|
response = unauthenticated_client.post(
|
|
"/api/import/import/",
|
|
{"profile_id": self.profile.id, "file": file},
|
|
format="multipart",
|
|
)
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
|
|
|
|
|
@override_settings(
|
|
STORAGES={
|
|
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
|
"staticfiles": {
|
|
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
|
},
|
|
},
|
|
WHITENOISE_AUTOREFRESH=True,
|
|
)
|
|
class ImportProfileAPITests(TestCase):
|
|
"""Tests for the Import Profile API endpoints"""
|
|
|
|
def setUp(self):
|
|
"""Set up test data"""
|
|
User = get_user_model()
|
|
self.user = User.objects.create_user(
|
|
email="testuser@test.com", password="testpass123"
|
|
)
|
|
self.client = APIClient()
|
|
self.client.force_authenticate(user=self.user)
|
|
|
|
self.profile1 = ImportProfile.objects.create(
|
|
name="Profile 1",
|
|
version=ImportProfile.Versions.VERSION_1,
|
|
yaml_config="""
|
|
file_type: csv
|
|
date_format: "%Y-%m-%d"
|
|
column_mapping:
|
|
date:
|
|
source: date
|
|
description:
|
|
source: description
|
|
amount:
|
|
source: amount
|
|
transaction_type:
|
|
detection_method: always_expense
|
|
is_paid:
|
|
detection_method: always_paid
|
|
account:
|
|
source: account
|
|
match_field: name
|
|
""",
|
|
)
|
|
self.profile2 = ImportProfile.objects.create(
|
|
name="Profile 2",
|
|
version=ImportProfile.Versions.VERSION_1,
|
|
yaml_config="""
|
|
file_type: csv
|
|
date_format: "%Y-%m-%d"
|
|
column_mapping:
|
|
date:
|
|
source: date
|
|
description:
|
|
source: description
|
|
amount:
|
|
source: amount
|
|
transaction_type:
|
|
detection_method: always_income
|
|
is_paid:
|
|
detection_method: always_unpaid
|
|
account:
|
|
source: account
|
|
match_field: name
|
|
""",
|
|
)
|
|
|
|
def test_list_profiles(self):
|
|
"""Test listing all profiles"""
|
|
response = self.client.get("/api/import/profiles/")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
self.assertEqual(response.data["count"], 2)
|
|
self.assertEqual(len(response.data["results"]), 2)
|
|
|
|
def test_retrieve_profile(self):
|
|
"""Test retrieving a specific profile"""
|
|
response = self.client.get(f"/api/import/profiles/{self.profile1.id}/")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
self.assertEqual(response.data["id"], self.profile1.id)
|
|
self.assertEqual(response.data["name"], "Profile 1")
|
|
self.assertIn("yaml_config", response.data)
|
|
|
|
def test_retrieve_nonexistent_profile(self):
|
|
"""Test retrieving a non-existent profile returns 404"""
|
|
response = self.client.get("/api/import/profiles/99999/")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
|
|
|
def test_profiles_unauthenticated(self):
|
|
"""Test unauthenticated request returns 403"""
|
|
unauthenticated_client = APIClient()
|
|
response = unauthenticated_client.get("/api/import/profiles/")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
|
|
|
|
|
@override_settings(
|
|
STORAGES={
|
|
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
|
"staticfiles": {
|
|
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
|
},
|
|
},
|
|
WHITENOISE_AUTOREFRESH=True,
|
|
)
|
|
class ImportRunAPITests(TestCase):
|
|
"""Tests for the Import Run API endpoints"""
|
|
|
|
def setUp(self):
|
|
"""Set up test data"""
|
|
User = get_user_model()
|
|
self.user = User.objects.create_user(
|
|
email="testuser@test.com", password="testpass123"
|
|
)
|
|
self.client = APIClient()
|
|
self.client.force_authenticate(user=self.user)
|
|
|
|
self.profile1 = ImportProfile.objects.create(
|
|
name="Profile 1",
|
|
version=ImportProfile.Versions.VERSION_1,
|
|
yaml_config="""
|
|
file_type: csv
|
|
date_format: "%Y-%m-%d"
|
|
column_mapping:
|
|
date:
|
|
source: date
|
|
description:
|
|
source: description
|
|
amount:
|
|
source: amount
|
|
transaction_type:
|
|
detection_method: always_expense
|
|
is_paid:
|
|
detection_method: always_paid
|
|
account:
|
|
source: account
|
|
match_field: name
|
|
""",
|
|
)
|
|
self.profile2 = ImportProfile.objects.create(
|
|
name="Profile 2",
|
|
version=ImportProfile.Versions.VERSION_1,
|
|
yaml_config="""
|
|
file_type: csv
|
|
date_format: "%Y-%m-%d"
|
|
column_mapping:
|
|
date:
|
|
source: date
|
|
description:
|
|
source: description
|
|
amount:
|
|
source: amount
|
|
transaction_type:
|
|
detection_method: always_income
|
|
is_paid:
|
|
detection_method: always_unpaid
|
|
account:
|
|
source: account
|
|
match_field: name
|
|
""",
|
|
)
|
|
|
|
# Create import runs
|
|
self.run1 = ImportRun.objects.create(
|
|
profile=self.profile1,
|
|
file_name="file1.csv",
|
|
status=ImportRun.Status.FINISHED,
|
|
)
|
|
self.run2 = ImportRun.objects.create(
|
|
profile=self.profile1,
|
|
file_name="file2.csv",
|
|
status=ImportRun.Status.QUEUED,
|
|
)
|
|
self.run3 = ImportRun.objects.create(
|
|
profile=self.profile2,
|
|
file_name="file3.csv",
|
|
status=ImportRun.Status.FINISHED,
|
|
)
|
|
|
|
def test_list_all_runs(self):
|
|
"""Test listing all runs"""
|
|
response = self.client.get("/api/import/runs/")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
self.assertEqual(response.data["count"], 3)
|
|
self.assertEqual(len(response.data["results"]), 3)
|
|
|
|
def test_list_runs_by_profile(self):
|
|
"""Test filtering runs by profile_id"""
|
|
response = self.client.get(f"/api/import/runs/?profile_id={self.profile1.id}")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
self.assertEqual(response.data["count"], 2)
|
|
for run in response.data["results"]:
|
|
self.assertEqual(run["profile"], self.profile1.id)
|
|
|
|
def test_list_runs_by_other_profile(self):
|
|
"""Test filtering runs by another profile_id"""
|
|
response = self.client.get(f"/api/import/runs/?profile_id={self.profile2.id}")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
self.assertEqual(response.data["count"], 1)
|
|
self.assertEqual(response.data["results"][0]["profile"], self.profile2.id)
|
|
|
|
def test_retrieve_run(self):
|
|
"""Test retrieving a specific run"""
|
|
response = self.client.get(f"/api/import/runs/{self.run1.id}/")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
self.assertEqual(response.data["id"], self.run1.id)
|
|
self.assertEqual(response.data["file_name"], "file1.csv")
|
|
self.assertEqual(response.data["status"], "FINISHED")
|
|
|
|
def test_retrieve_nonexistent_run(self):
|
|
"""Test retrieving a non-existent run returns 404"""
|
|
response = self.client.get("/api/import/runs/99999/")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
|
|
|
def test_runs_unauthenticated(self):
|
|
"""Test unauthenticated request returns 403"""
|
|
unauthenticated_client = APIClient()
|
|
response = unauthenticated_client.get("/api/import/runs/")
|
|
|
|
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|