diff --git a/app/apps/import_app/schemas/__init__.py b/app/apps/import_app/schemas/__init__.py index f68ce79..530268d 100644 --- a/app/apps/import_app/schemas/__init__.py +++ b/app/apps/import_app/schemas/__init__.py @@ -1,8 +1 @@ -from apps.import_app.schemas.v1 import ( - ImportProfileSchema as SchemaV1, - ColumnMapping as ColumnMappingV1, - # TransformationRule as TransformationRuleV1, - ImportSettings as SettingsV1, - HashTransformationRule as HashTransformationRuleV1, - CompareDeduplicationRule as CompareDeduplicationRuleV1, -) +import apps.import_app.schemas.v1 as version_1 diff --git a/app/apps/import_app/schemas/v1.py b/app/apps/import_app/schemas/v1.py index 1cc7dc5..043f2a9 100644 --- a/app/apps/import_app/schemas/v1.py +++ b/app/apps/import_app/schemas/v1.py @@ -1,5 +1,5 @@ from typing import Dict, List, Optional, Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator, field_validator class CompareDeduplicationRule(BaseModel): @@ -9,6 +9,12 @@ class CompareDeduplicationRule(BaseModel): ) match_type: Literal["lax", "strict"] + @field_validator("fields", mode="before") + def coerce_fields_to_dict(cls, v): + if isinstance(v, list): + return {k: v for d in v for k, v in d.items()} + return v + class ReplaceTransformationRule(BaseModel): field: str @@ -17,6 +23,10 @@ class ReplaceTransformationRule(BaseModel): ) pattern: str = Field(..., description="Pattern to match") replacement: str = Field(..., description="Value to replace with") + exclusive: bool = Field( + default=False, + description="If it should match against the last transformation or the original value", + ) class DateFormatTransformationRule(BaseModel): @@ -48,7 +58,7 @@ class SplitTransformationRule(BaseModel): ) -class ImportSettings(BaseModel): +class CSVImportSettings(BaseModel): skip_errors: bool = Field( default=False, description="If True, errors during import will be logged and skipped", @@ -56,7 +66,7 @@ class ImportSettings(BaseModel): 
file_type: Literal["csv"] = "csv" delimiter: str = Field(default=",", description="CSV delimiter character") encoding: str = Field(default="utf-8", description="File encoding") - skip_rows: int = Field( + skip_lines: int = Field( default=0, description="Number of rows to skip at the beginning of the file" ) importing: Literal[ @@ -69,20 +79,7 @@ class ColumnMapping(BaseModel): default=None, description="CSV column header. If None, the field will be generated from transformations", ) - target: Literal[ - "account", - "type", - "is_paid", - "date", - "reference_date", - "amount", - "notes", - "category", - "tags", - "entities", - "internal_note", - ] = Field(..., description="Transaction field to map to") - default_value: Optional[str] = None + default: Optional[str] = None required: bool = False transformations: Optional[ List[ @@ -95,10 +92,305 @@ class ColumnMapping(BaseModel): ] = Field(default_factory=list) +class TransactionAccountMapping(ColumnMapping): + target: Literal["account"] = Field(..., description="Transaction field to map to") + type: Literal["id", "name"] = "name" + coerce_to: Literal["str|int"] = Field("str|int", frozen=True) + + +class TransactionTypeMapping(ColumnMapping): + target: Literal["type"] = Field(..., description="Transaction field to map to") + detection_method: Literal["sign", "always_income", "always_expense"] = "sign" + coerce_to: Literal["transaction_type"] = Field("transaction_type", frozen=True) + + +class TransactionIsPaidMapping(ColumnMapping): + target: Literal["is_paid"] = Field(..., description="Transaction field to map to") + detection_method: Literal["sign", "boolean", "always_paid", "always_unpaid"] + coerce_to: Literal["is_paid"] = Field("is_paid", frozen=True) + + +class TransactionDateMapping(ColumnMapping): + target: Literal["date"] = Field(..., description="Transaction field to map to") + format: List[str] | str + coerce_to: Literal["date"] = Field("date", frozen=True) + + +class 
TransactionReferenceDateMapping(ColumnMapping): + target: Literal["reference_date"] = Field( + ..., description="Transaction field to map to" + ) + format: List[str] | str + coerce_to: Literal["date"] = Field("date", frozen=True) + + +class TransactionAmountMapping(ColumnMapping): + target: Literal["amount"] = Field(..., description="Transaction field to map to") + coerce_to: Literal["positive_decimal"] = Field("positive_decimal", frozen=True) + + +class TransactionDescriptionMapping(ColumnMapping): + target: Literal["description"] = Field( + ..., description="Transaction field to map to" + ) + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class TransactionNotesMapping(ColumnMapping): + target: Literal["notes"] = Field(..., description="Transaction field to map to") + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class TransactionTagsMapping(ColumnMapping): + target: Literal["tags"] = Field(..., description="Transaction field to map to") + create: bool = Field( + default=True, description="Create new tags if they don't exist" + ) + coerce_to: Literal["list"] = Field("list", frozen=True) + + +class TransactionEntitiesMapping(ColumnMapping): + target: Literal["entities"] = Field(..., description="Transaction field to map to") + create: bool = Field( + default=True, description="Create new entities if they don't exist" + ) + coerce_to: Literal["list"] = Field("list", frozen=True) + + +class TransactionCategoryMapping(ColumnMapping): + target: Literal["category"] = Field(..., description="Transaction field to map to") + create: bool = Field( + default=True, description="Create category if it doesn't exist" + ) + type: Literal["id", "name"] = "name" + coerce_to: Literal["str|int"] = Field("str|int", frozen=True) + + +class TransactionInternalMapping(ColumnMapping): + target: Literal["internal_note"] = Field( + ..., description="Transaction field to map to" + ) + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class 
CategoryNameMapping(ColumnMapping): + target: Literal["category_name"] = Field( + ..., description="Category field to map to" + ) + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class CategoryMuteMapping(ColumnMapping): + target: Literal["category_mute"] = Field( + ..., description="Category field to map to" + ) + coerce_to: Literal["bool"] = Field("bool", frozen=True) + + +class CategoryActiveMapping(ColumnMapping): + target: Literal["category_active"] = Field( + ..., description="Category field to map to" + ) + coerce_to: Literal["bool"] = Field("bool", frozen=True) + + +class TagNameMapping(ColumnMapping): + target: Literal["tag_name"] = Field(..., description="Tag field to map to") + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class TagActiveMapping(ColumnMapping): + target: Literal["tag_active"] = Field(..., description="Tag field to map to") + coerce_to: Literal["bool"] = Field("bool", frozen=True) + + +class EntityNameMapping(ColumnMapping): + target: Literal["entity_name"] = Field(..., description="Entity field to map to") + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class EntityActiveMapping(ColumnMapping): + target: Literal["entity_active"] = Field(..., description="Entity field to map to") + coerce_to: Literal["bool"] = Field("bool", frozen=True) + + +class AccountNameMapping(ColumnMapping): + target: Literal["account_name"] = Field(..., description="Account field to map to") + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class AccountGroupMapping(ColumnMapping): + target: Literal["account_group"] = Field(..., description="Account field to map to") + type: Literal["id", "name"] + coerce_to: Literal["str|int"] = Field("str|int", frozen=True) + + +class AccountCurrencyMapping(ColumnMapping): + target: Literal["account_currency"] = Field( + ..., description="Account field to map to" + ) + type: Literal["id", "name", "code"] + coerce_to: Literal["str|int"] = Field("str|int", frozen=True) + + +class 
AccountExchangeCurrencyMapping(ColumnMapping): + target: Literal["account_exchange_currency"] = Field( + ..., description="Account field to map to" + ) + type: Literal["id", "name", "code"] + coerce_to: Literal["str|int"] = Field("str|int", frozen=True) + + +class AccountIsAssetMapping(ColumnMapping): + target: Literal["account_is_asset"] = Field( + ..., description="Account field to map to" + ) + coerce_to: Literal["bool"] = Field("bool", frozen=True) + + +class AccountIsArchivedMapping(ColumnMapping): + target: Literal["account_is_archived"] = Field( + ..., description="Account field to map to" + ) + coerce_to: Literal["bool"] = Field("bool", frozen=True) + + +class CurrencyCodeMapping(ColumnMapping): + target: Literal["currency_code"] = Field( + ..., description="Currency field to map to" + ) + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class CurrencyNameMapping(ColumnMapping): + target: Literal["currency_name"] = Field( + ..., description="Currency field to map to" + ) + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class CurrencyDecimalPlacesMapping(ColumnMapping): + target: Literal["currency_decimal_places"] = Field( + ..., description="Currency field to map to" + ) + coerce_to: Literal["int"] = Field("int", frozen=True) + + +class CurrencyPrefixMapping(ColumnMapping): + target: Literal["currency_prefix"] = Field( + ..., description="Currency field to map to" + ) + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class CurrencySuffixMapping(ColumnMapping): + target: Literal["currency_suffix"] = Field( + ..., description="Currency field to map to" + ) + coerce_to: Literal["str"] = Field("str", frozen=True) + + +class CurrencyExchangeMapping(ColumnMapping): + target: Literal["currency_exchange"] = Field( + ..., description="Currency field to map to" + ) + type: Literal["id", "name", "code"] + coerce_to: Literal["str|int"] = Field("str|int", frozen=True) + + class ImportProfileSchema(BaseModel): - settings: ImportSettings - 
column_mapping: Dict[str, ColumnMapping] + settings: CSVImportSettings + mapping: Dict[ + str, + TransactionAccountMapping + | TransactionTypeMapping + | TransactionIsPaidMapping + | TransactionDateMapping + | TransactionReferenceDateMapping + | TransactionAmountMapping + | TransactionDescriptionMapping + | TransactionNotesMapping + | TransactionTagsMapping + | TransactionEntitiesMapping + | TransactionCategoryMapping + | TransactionInternalMapping + | CategoryNameMapping + | CategoryMuteMapping + | CategoryActiveMapping + | TagNameMapping + | TagActiveMapping + | EntityNameMapping + | EntityActiveMapping + | AccountNameMapping + | AccountGroupMapping + | AccountCurrencyMapping + | AccountExchangeCurrencyMapping + | AccountIsAssetMapping + | AccountIsArchivedMapping + | CurrencyCodeMapping + | CurrencyNameMapping + | CurrencyDecimalPlacesMapping + | CurrencyPrefixMapping + | CurrencySuffixMapping + | CurrencyExchangeMapping, + ] deduplication: List[CompareDeduplicationRule] = Field( default_factory=list, description="Rules for deduplicating records during import", ) + + @model_validator(mode="after") + def validate_mappings(self) -> "ImportProfileSchema": + import_type = self.settings.importing + + # Define allowed mapping types for each import type + allowed_mappings = { + "transactions": ( + TransactionAccountMapping, + TransactionTypeMapping, + TransactionIsPaidMapping, + TransactionDateMapping, + TransactionReferenceDateMapping, + TransactionAmountMapping, + TransactionDescriptionMapping, + TransactionNotesMapping, + TransactionTagsMapping, + TransactionEntitiesMapping, + TransactionCategoryMapping, + TransactionInternalMapping, + ), + "accounts": ( + AccountNameMapping, + AccountGroupMapping, + AccountCurrencyMapping, + AccountExchangeCurrencyMapping, + AccountIsAssetMapping, + AccountIsArchivedMapping, + ), + "currencies": ( + CurrencyCodeMapping, + CurrencyNameMapping, + CurrencyDecimalPlacesMapping, + CurrencyPrefixMapping, + CurrencySuffixMapping, + 
CurrencyExchangeMapping, + ), + "categories": ( + CategoryNameMapping, + CategoryMuteMapping, + CategoryActiveMapping, + ), + "tags": (TagNameMapping, TagActiveMapping), + "entities": (EntityNameMapping, EntityActiveMapping), + } + + allowed_types = allowed_mappings[import_type] + + for field_name, mapping in self.mapping.items(): + if not isinstance(mapping, allowed_types): + raise ValueError( + f"Mapping type '{type(mapping).__name__}' is not allowed when importing {import_type}. " + f"Allowed types are: {', '.join(t.__name__ for t in allowed_types)}" + ) + + return self diff --git a/app/apps/import_app/services/v1.py b/app/apps/import_app/services/v1.py index 333eb6e..069115b 100644 --- a/app/apps/import_app/services/v1.py +++ b/app/apps/import_app/services/v1.py @@ -1,42 +1,56 @@ import csv import hashlib +import logging +import os import re from datetime import datetime -from typing import Dict, Any, Literal +from decimal import Decimal +from typing import Dict, Any, Literal, Union import yaml - from django.db import transaction -from django.core.files.storage import default_storage from django.utils import timezone +from apps.accounts.models import Account, AccountGroup +from apps.currencies.models import Currency from apps.import_app.models import ImportRun, ImportProfile -from apps.import_app.schemas import ( - SchemaV1, - ColumnMappingV1, - SettingsV1, - HashTransformationRuleV1, - CompareDeduplicationRuleV1, +from apps.import_app.schemas import version_1 +from apps.transactions.models import ( + Transaction, + TransactionCategory, + TransactionTag, + TransactionEntity, ) -from apps.transactions.models import Transaction + +logger = logging.getLogger(__name__) class ImportService: + TEMP_DIR = "/usr/src/app/temp" + def __init__(self, import_run: ImportRun): self.import_run: ImportRun = import_run self.profile: ImportProfile = import_run.profile - self.config: SchemaV1 = self._load_config() - self.settings: SettingsV1 = self.config.settings - 
self.deduplication: list[CompareDeduplicationRuleV1] = self.config.deduplication - self.mapping: Dict[str, ColumnMappingV1] = self.config.column_mapping + self.config: version_1.ImportProfileSchema = self._load_config() + self.settings: version_1.CSVImportSettings = self.config.settings + self.deduplication: list[version_1.CompareDeduplicationRule] = ( + self.config.deduplication + ) + self.mapping: Dict[str, version_1.ColumnMapping] = self.config.mapping - def _load_config(self) -> SchemaV1: + # Ensure temp directory exists + os.makedirs(self.TEMP_DIR, exist_ok=True) + + def _load_config(self) -> version_1.ImportProfileSchema: yaml_data = yaml.safe_load(self.profile.yaml_config) - - if self.profile.version == ImportProfile.Versions.VERSION_1: - return SchemaV1(**yaml_data) - - raise ValueError(f"Unsupported version: {self.profile.version}") + try: + config = version_1.ImportProfileSchema(**yaml_data) + except Exception as e: + self._log("error", f"Fatal error processing YAML config: {str(e)}") + self._update_status("FAILED") + raise e + else: + return config def _log(self, level: str, message: str, **kwargs) -> None: """Add a log entry to the import run logs""" @@ -53,6 +67,48 @@ class ImportService: self.import_run.logs += log_line self.import_run.save(update_fields=["logs"]) + def _update_totals( + self, + field: Literal["total", "processed", "successful", "skipped", "failed"], + value: int, + ) -> None: + if field == "total": + self.import_run.total_rows = value + self.import_run.save(update_fields=["total_rows"]) + elif field == "processed": + self.import_run.processed_rows = value + self.import_run.save(update_fields=["processed_rows"]) + elif field == "successful": + self.import_run.successful_rows = value + self.import_run.save(update_fields=["successful_rows"]) + elif field == "skipped": + self.import_run.skipped_rows = value + self.import_run.save(update_fields=["skipped_rows"]) + elif field == "failed": + self.import_run.failed_rows = value + 
self.import_run.save(update_fields=["failed_rows"]) + + def _increment_totals( + self, + field: Literal["total", "processed", "successful", "skipped", "failed"], + value: int, + ) -> None: + if field == "total": + self.import_run.total_rows = self.import_run.total_rows + value + self.import_run.save(update_fields=["total_rows"]) + elif field == "processed": + self.import_run.processed_rows = self.import_run.processed_rows + value + self.import_run.save(update_fields=["processed_rows"]) + elif field == "successful": + self.import_run.successful_rows = self.import_run.successful_rows + value + self.import_run.save(update_fields=["successful_rows"]) + elif field == "skipped": + self.import_run.skipped_rows = self.import_run.skipped_rows + value + self.import_run.save(update_fields=["skipped_rows"]) + elif field == "failed": + self.import_run.failed_rows = self.import_run.failed_rows + value + self.import_run.save(update_fields=["failed_rows"]) + def _update_status( self, new_status: Literal["PROCESSING", "FAILED", "FINISHED"] ) -> None: @@ -67,15 +123,12 @@ class ImportService: @staticmethod def _transform_value( - value: str, mapping: ColumnMappingV1, row: Dict[str, str] = None + value: str, mapping: version_1.ColumnMapping, row: Dict[str, str] = None ) -> Any: transformed = value for transform in mapping.transformations: if transform.type == "hash": - if not isinstance(transform, HashTransformationRuleV1): - continue - # Collect all values to be hashed values_to_hash = [] for field in transform.fields: @@ -88,47 +141,143 @@ class ImportService: transformed = hashlib.sha256(concatenated.encode()).hexdigest() elif transform.type == "replace": - transformed = transformed.replace( - transform.pattern, transform.replacement - ) + if transform.exclusive: + transformed = value.replace( + transform.pattern, transform.replacement + ) + else: + transformed = transformed.replace( + transform.pattern, transform.replacement + ) elif transform.type == "regex": - transformed = 
re.sub( - transform.pattern, transform.replacement, transformed - ) + if transform.exclusive: + transformed = re.sub( + transform.pattern, transform.replacement, value + ) + else: + transformed = re.sub( + transform.pattern, transform.replacement, transformed + ) elif transform.type == "date_format": transformed = datetime.strptime( - transformed, transform.pattern - ).strftime(transform.replacement) + transformed, transform.original_format + ).strftime(transform.new_format) + elif transform.type == "merge": + values_to_merge = [] + for field in transform.fields: + if field in row: + values_to_merge.append(str(row[field])) + transformed = transform.separator.join(values_to_merge) + elif transform.type == "split": + parts = transformed.split(transform.separator) + if transform.index is not None: + transformed = parts[transform.index] if parts else "" + else: + transformed = parts return transformed - def _map_row_to_transaction(self, row: Dict[str, str]) -> Dict[str, Any]: - transaction_data = {} + def _create_transaction(self, data: Dict[str, Any]) -> Transaction: + tags = [] + entities = [] + # Handle related objects first + if "category" in data: + category_name = data.pop("category") + category, _ = TransactionCategory.objects.get_or_create(name=category_name) + data["category"] = category + self.import_run.categories.add(category) - for field, mapping in self.mapping.items(): - # If source is None, use None as the initial value - value = row.get(mapping.source) if mapping.source else None + if "account" in data: + account_id = data.pop("account") + account = None + if isinstance(account_id, str): + account = Account.objects.get(name=account_id) + elif isinstance(account_id, int): + account = Account.objects.get(id=account_id) + data["account"] = account + # self.import_run.acc.add(category) - # Use default_value if value is None - if value is None: - value = mapping.default_value + if "tags" in data: + tag_names = data.pop("tags").split(",") + for tag_name in 
tag_names: + tag, _ = TransactionTag.objects.get_or_create(name=tag_name.strip()) + tags.append(tag) + self.import_run.tags.add(tag) - if mapping.required and value is None and not mapping.transformations: - raise ValueError(f"Required field {field} is missing") + if "entities" in data: + entity_names = data.pop("entities").split(",") + for entity_name in entity_names: + entity, _ = TransactionEntity.objects.get_or_create( + name=entity_name.strip() + ) + entities.append(entity) + self.import_run.entities.add(entity) - # Apply transformations even if initial value is None - if mapping.transformations: - value = self._transform_value(value, mapping, row) + if "amount" in data: + amount = data.pop("amount") + data["amount"] = abs(Decimal(amount)) - if value is not None: - transaction_data[field] = value + # Create the transaction + new_transaction = Transaction.objects.create(**data) + self.import_run.transactions.add(new_transaction) - return transaction_data + # Add many-to-many relationships + if tags: + new_transaction.tags.set(tags) + if entities: + new_transaction.entities.set(entities) + + return new_transaction + + def _create_account(self, data: Dict[str, Any]) -> Account: + if "group" in data: + group_name = data.pop("group") + group, _ = AccountGroup.objects.get_or_create(name=group_name) + data["group"] = group + + # Handle currency references + if "currency" in data: + currency = Currency.objects.get(code=data["currency"]) + data["currency"] = currency + self.import_run.currencies.add(currency) + + if "exchange_currency" in data: + exchange_currency = Currency.objects.get(code=data["exchange_currency"]) + data["exchange_currency"] = exchange_currency + self.import_run.currencies.add(exchange_currency) + + return Account.objects.create(**data) + + def _create_currency(self, data: Dict[str, Any]) -> Currency: + # Handle exchange currency reference + if "exchange_currency" in data: + exchange_currency = Currency.objects.get(code=data["exchange_currency"]) + 
data["exchange_currency"] = exchange_currency + self.import_run.currencies.add(exchange_currency) + + currency = Currency.objects.create(**data) + self.import_run.currencies.add(currency) + return currency + + def _create_category(self, data: Dict[str, Any]) -> TransactionCategory: + category = TransactionCategory.objects.create(**data) + self.import_run.categories.add(category) + return category + + def _create_tag(self, data: Dict[str, Any]) -> TransactionTag: + tag = TransactionTag.objects.create(**data) + self.import_run.tags.add(tag) + return tag + + def _create_entity(self, data: Dict[str, Any]) -> TransactionEntity: + entity = TransactionEntity.objects.create(**data) + self.import_run.entities.add(entity) + return entity def _check_duplicate_transaction(self, transaction_data: Dict[str, Any]) -> bool: for rule in self.deduplication: if rule.type == "compare": - query = Transaction.objects.all() + query = Transaction.objects.all().values("id") # Build query conditions for each field in the rule for field, header in rule.fields.items(): @@ -146,65 +295,214 @@ class ImportService: return False - def _process_csv(self, file_path): - with open(file_path, "r", encoding=self.settings.encoding) as csv_file: - reader = csv.DictReader(csv_file, delimiter=self.settings.delimiter) + def _coerce_type( + self, value: str, mapping: version_1.ColumnMapping + ) -> Union[str, int, bool, Decimal, datetime, list]: + if not value: + return None + + coerce_to = mapping.coerce_to + + if "|" in coerce_to: + types = coerce_to.split("|") + for t in types: + try: + return self._coerce_single_type(value, t, mapping) + except ValueError: + continue + raise ValueError( + f"Could not coerce '{value}' to any of the types: {coerce_to}" + ) + else: + return self._coerce_single_type(value, coerce_to, mapping) + + def _coerce_single_type( + self, value: str, coerce_to: str, mapping: version_1.ColumnMapping + ) -> Union[str, int, bool, Decimal, datetime.date, list]: + if coerce_to == "str": + 
return str(value) + elif coerce_to == "int": + if hasattr(mapping, "type") and mapping.type == "id": + return int(value) + elif hasattr(mapping, "type") and mapping.type in ["name", "code"]: + return str(value) + else: + return int(value) + elif coerce_to == "bool": + return value.lower() in ["true", "1", "yes", "y", "on"] + elif coerce_to == "positive_decimal": + return abs(Decimal(value)) + elif coerce_to == "date": + if isinstance( + mapping, + ( + version_1.TransactionDateMapping, + version_1.TransactionReferenceDateMapping, + ), + ): + formats = ( + mapping.format + if isinstance(mapping.format, list) + else [mapping.format] + ) + for fmt in formats: + try: + return datetime.strptime(value, fmt).date() + except ValueError: + continue + raise ValueError( + f"Could not parse date '{value}' with any of the provided formats" + ) + else: + raise ValueError( + "Date coercion is only supported for TransactionDateMapping and TransactionReferenceDateMapping" + ) + elif coerce_to == "list": + return ( + value + if isinstance(value, list) + else [item.strip() for item in value.split(",") if item.strip()] + ) + elif coerce_to == "transaction_type": + if isinstance(mapping, version_1.TransactionTypeMapping): + if mapping.detection_method == "sign": + return ( + Transaction.Type.EXPENSE + if value.startswith("-") + else Transaction.Type.INCOME + ) + elif mapping.detection_method == "always_income": + return Transaction.Type.INCOME + elif mapping.detection_method == "always_expense": + return Transaction.Type.EXPENSE + raise ValueError("Invalid transaction type detection method") + elif coerce_to == "is_paid": + if isinstance(mapping, version_1.TransactionIsPaidMapping): + if mapping.detection_method == "sign": + return not value.startswith("-") + elif mapping.detection_method == "boolean": + return value.lower() in ["true", "1", "yes", "y", "on"] + elif mapping.detection_method == "always_paid": + return True + elif mapping.detection_method == "always_unpaid": + return 
False + raise ValueError("Invalid is_paid detection method") + else: + raise ValueError(f"Unsupported coercion type: {coerce_to}") + + def _map_row(self, row: Dict[str, str]) -> Dict[str, Any]: + mapped_data = {} + + for field, mapping in self.mapping.items(): + # If source is None, use None as the initial value + value = row.get(mapping.source) if mapping.source else None + + # Use default_value if value is None + if value is None: + value = mapping.default + + if mapping.required and value is None and not mapping.transformations: + raise ValueError(f"Required field {field} is missing") + + # Apply transformations + if mapping.transformations: + value = self._transform_value(value, mapping, row) + + value = self._coerce_type(value, mapping) + + if value is not None: + # Remove the prefix from the target field + target = mapping.target + if self.settings.importing == "transactions": + mapped_data[target] = value + else: + # Remove the model prefix (e.g., "account_" from "account_name") + field_name = target.split("_", 1)[1] + mapped_data[field_name] = value + + return mapped_data + + def _process_row(self, row: Dict[str, str], row_number: int) -> None: + try: + mapped_data = self._map_row(row) + + if mapped_data: + # Handle different import types + if self.settings.importing == "transactions": + if self.deduplication and self._check_duplicate_transaction( + mapped_data + ): + self._increment_totals("skipped", 1) + self._log("info", f"Skipped duplicate row {row_number}") + return + self._create_transaction(mapped_data) + elif self.settings.importing == "accounts": + self._create_account(mapped_data) + elif self.settings.importing == "currencies": + self._create_currency(mapped_data) + elif self.settings.importing == "categories": + self._create_category(mapped_data) + elif self.settings.importing == "tags": + self._create_tag(mapped_data) + elif self.settings.importing == "entities": + self._create_entity(mapped_data) + + self._increment_totals("successful", 
value=1) + self._log("info", f"Successfully processed row {row_number}") + + self._increment_totals("processed", value=1) + + except Exception as e: + if not self.settings.skip_errors: + self._log("error", f"Fatal error processing row {row_number}: {str(e)}") + self._update_status("FAILED") + raise + else: + self._log("warning", f"Error processing row {row_number}: {str(e)}") + self._increment_totals("failed", value=1) + + logger.error(f"Fatal error processing row {row_number}", exc_info=e) + + def _process_csv(self, file_path): + # First pass: count rows + with open(file_path, "r", encoding=self.settings.encoding) as csv_file: + # Skip specified number of rows + for _ in range(self.settings.skip_lines): + next(csv_file) + + reader = csv.DictReader(csv_file, delimiter=self.settings.delimiter) + self._update_totals("total", value=sum(1 for _ in reader)) + + with open(file_path, "r", encoding=self.settings.encoding) as csv_file: + # Skip specified number of rows + for _ in range(self.settings.skip_lines): + next(csv_file) + if self.settings.skip_lines: + self._log("info", f"Skipped {self.settings.skip_lines} initial lines") - # Count total rows - self.import_run.total_rows = sum(1 for _ in reader) - csv_file.seek(0) reader = csv.DictReader(csv_file, delimiter=self.settings.delimiter) self._log("info", f"Starting import with {self.import_run.total_rows} rows") - # Skip specified number of rows - for _ in range(self.settings.skip_rows): - next(reader) + with transaction.atomic(): + for row_number, row in enumerate(reader, start=1): + self._process_row(row, row_number) + self._increment_totals("processed", value=1) - if self.settings.skip_rows: - self._log("info", f"Skipped {self.settings.skip_rows} initial rows") - - for row_number, row in enumerate(reader, start=1): - try: - transaction_data = self._map_row_to_transaction(row) - - if transaction_data: - if self.deduplication and self._check_duplicate_transaction( - transaction_data - ): - self.import_run.skipped_rows 
+= 1 - self._log("info", f"Skipped duplicate row {row_number}") - continue - - self.import_run.transactions.add(transaction_data) - self.import_run.successful_rows += 1 - self._log("debug", f"Successfully processed row {row_number}") - - self.import_run.processed_rows += 1 - self.import_run.save( - update_fields=[ - "processed_rows", - "successful_rows", - "skipped_rows", - ] - ) - - except Exception as e: - if not self.settings.skip_errors: - self._log( - "error", - f"Fatal error processing row {row_number}: {str(e)}", - ) - self._update_status("FAILED") - raise - else: - self._log( - "warning", f"Error processing row {row_number}: {str(e)}" - ) - self.import_run.failed_rows += 1 - self.import_run.save(update_fields=["failed_rows"]) + def _validate_file_path(self, file_path: str) -> str: + """ + Validates that the file path is within the allowed temporary directory. + Returns the absolute path. + """ + abs_path = os.path.abspath(file_path) + if not abs_path.startswith(self.TEMP_DIR): + raise ValueError(f"Invalid file path. File must be in {self.TEMP_DIR}") + return abs_path def process_file(self, file_path: str): + # Validate and get absolute path + file_path = self._validate_file_path(file_path) + self._update_status("PROCESSING") self.import_run.started_at = timezone.now() self.import_run.save(update_fields=["started_at"]) @@ -232,6 +530,12 @@ class ImportService: finally: self._log("info", "Cleaning up temporary files") - default_storage.delete(file_path) + try: + if os.path.exists(file_path): + os.remove(file_path) + self._log("info", f"Deleted temporary file: {file_path}") + except OSError as e: + self._log("warning", f"Failed to delete temporary file: {str(e)}") + self.import_run.finished_at = timezone.now() self.import_run.save(update_fields=["finished_at"])