import csv
import hashlib
import logging
import os
import re
from datetime import datetime, date
from decimal import Decimal, InvalidOperation
from typing import Dict, Any, Literal, Union

import openpyxl
import xlrd
import yaml
from cachalot.api import cachalot_disabled
from django.utils import timezone
from openpyxl.utils.exceptions import InvalidFileException

from apps.accounts.models import Account, AccountGroup
from apps.currencies.models import Currency
from apps.import_app.models import ImportRun, ImportProfile
from apps.import_app.schemas import version_1
from apps.transactions.models import (
    Transaction,
    TransactionCategory,
    TransactionTag,
    TransactionEntity,
)
from apps.rules.signals import transaction_created
from apps.import_app.schemas.v1 import (
    TransactionCategoryMapping,
    TransactionAccountMapping,
    TransactionTagsMapping,
    TransactionEntitiesMapping,
)

logger = logging.getLogger(__name__)


class ImportService:
    """Run a single ``ImportRun``: parse a CSV/Excel file and create model rows.

    The service reads the YAML configuration stored on the run's profile,
    maps every file row through the configured column mappings
    (transformations, then type coercion) and creates the object type named
    by ``settings.importing``. Progress counters, status and a textual log
    are persisted on the ``ImportRun`` as the import advances.
    """

    TEMP_DIR = "/usr/src/app/temp"

    # Public counter name -> ImportRun attribute holding that counter.
    _TOTAL_FIELDS = {
        "total": "total_rows",
        "processed": "processed_rows",
        "successful": "successful_rows",
        "skipped": "skipped_rows",
        "failed": "failed_rows",
    }

    def __init__(self, import_run: ImportRun):
        """Bind the service to ``import_run`` and load its profile config.

        Raises whatever ``_load_config`` raises when the YAML configuration
        cannot be parsed or validated.
        """
        self.import_run: ImportRun = import_run
        self.profile: ImportProfile = import_run.profile
        self.config: version_1.ImportProfileSchema = self._load_config()
        self.settings: version_1.CSVImportSettings | version_1.ExcelImportSettings = (
            self.config.settings
        )
        self.deduplication: list[version_1.CompareDeduplicationRule] = (
            self.config.deduplication
        )
        self.mapping: Dict[str, version_1.ColumnMapping] = self.config.mapping

        # Ensure temp directory exists
        os.makedirs(self.TEMP_DIR, exist_ok=True)

    def _load_config(self) -> version_1.ImportProfileSchema:
        """Parse the profile's YAML and build the schema object.

        Any failure (YAML syntax error, empty document, schema validation)
        is logged on the run, the run is marked ``FAILED`` and the original
        exception is re-raised.
        """
        try:
            # Parsing happens inside the try so that malformed YAML is
            # recorded on the run just like a schema validation error.
            yaml_data = yaml.safe_load(self.profile.yaml_config)
            return version_1.ImportProfileSchema(**yaml_data)
        except Exception as e:
            self._log("error", f"Fatal error processing YAML config: {str(e)}")
            self._update_status("FAILED")
            raise

    def _log(self, level: str, message: str, **kwargs) -> None:
        """Add a log entry to the import run logs.

        The entry is appended to ``import_run.logs`` (and saved) and also
        forwarded to the module logger for the ``info``, ``warning`` and
        ``error`` levels. Extra keyword arguments are rendered as
        ``key=value`` context after the message.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Format additional context if present
        context = ""
        if kwargs:
            context = " - " + ", ".join(f"{k}={v}" for k, v in kwargs.items())

        log_line = f"[{timestamp}] {level.upper()}: {message}{context}\n"

        # Append to existing logs
        self.import_run.logs += log_line
        self.import_run.save(update_fields=["logs"])

        if level == "info":
            logger.info(log_line)
        elif level == "warning":
            logger.warning(log_line)
        elif level == "error":
            logger.error(log_line, exc_info=True)

    def _update_totals(
        self,
        field: Literal["total", "processed", "successful", "skipped", "failed"],
        value: int,
    ) -> None:
        """Set the run counter named ``field`` to ``value`` and save it.

        Unknown field names are ignored.
        """
        attr = self._TOTAL_FIELDS.get(field)
        if attr is None:
            return
        setattr(self.import_run, attr, value)
        self.import_run.save(update_fields=[attr])

    def _increment_totals(
        self,
        field: Literal["total", "processed", "successful", "skipped", "failed"],
        value: int,
    ) -> None:
        """Add ``value`` to the run counter named ``field`` and save it.

        Unknown field names are ignored.
        """
        attr = self._TOTAL_FIELDS.get(field)
        if attr is None:
            return
        setattr(self.import_run, attr, getattr(self.import_run, attr) + value)
        self.import_run.save(update_fields=[attr])

    def _update_status(
        self, new_status: Literal["PROCESSING", "FAILED", "FINISHED"]
    ) -> None:
        """Store the ``ImportRun.Status`` member named ``new_status`` on the run."""
        if new_status == "PROCESSING":
            self.import_run.status = ImportRun.Status.PROCESSING
        elif new_status == "FAILED":
            self.import_run.status = ImportRun.Status.FAILED
        elif new_status == "FINISHED":
            self.import_run.status = ImportRun.Status.FINISHED

        self.import_run.save(update_fields=["status"])

    @staticmethod
    def _collect_field_values(
        fields, row: Dict[str, Any], mapped_data: Dict[str, Any]
    ) -> list[str]:
        """Return the stringified values of ``fields``, in order.

        A field is looked up in ``row`` first; a name prefixed with ``__``
        that is not a row column is looked up (without the prefix) in
        ``mapped_data``. Fields found in neither place are skipped.
        """
        collected = []
        for field in fields:
            if field in row:
                collected.append(str(row[field]))
            elif field.startswith("__") and mapped_data and field[2:] in mapped_data:
                collected.append(str(mapped_data[field[2:]]))
        return collected

    def _transform_value(
        self,
        value: str,
        mapping: version_1.ColumnMapping,
        row: Dict[str, str] | None = None,
        mapped_data: Dict[str, Any] | None = None,
    ) -> Any:
        """Apply ``mapping.transformations`` to ``value`` in order.

        Args:
            value: The raw value read for this mapping.
            mapping: Column mapping whose transformations are applied.
            row: The raw file row (column name -> value). Optional.
            mapped_data: Fields already mapped for this row; referenced by
                transformations through ``__``-prefixed names. Optional.

        Returns:
            The transformed value. A ``split`` without an index yields a list.
        """
        # Callers may omit either dict; treat that as "nothing to look up"
        # instead of failing on a membership test against None.
        row = row or {}
        mapped_data = mapped_data or {}
        transformed = value

        for transform in mapping.transformations:
            if transform.type == "hash":
                # Collect all values to be hashed
                values_to_hash = self._collect_field_values(
                    transform.fields, row, mapped_data
                )

                # Without any source value the current value is left untouched.
                if values_to_hash:
                    concatenated = "|".join(values_to_hash)
                    transformed = hashlib.sha256(concatenated.encode()).hexdigest()

            elif transform.type == "replace":
                # "exclusive" restarts from the original value, discarding
                # the result of earlier transformations.
                if transform.exclusive:
                    transformed = value.replace(
                        transform.pattern, transform.replacement
                    )
                else:
                    transformed = transformed.replace(
                        transform.pattern, transform.replacement
                    )
            elif transform.type == "regex":
                if transform.exclusive:
                    transformed = re.sub(
                        transform.pattern, transform.replacement, value
                    )
                else:
                    transformed = re.sub(
                        transform.pattern, transform.replacement, transformed
                    )
            elif transform.type == "date_format":
                transformed = datetime.strptime(
                    transformed, transform.original_format
                ).strftime(transform.new_format)
            elif transform.type == "merge":
                values_to_merge = self._collect_field_values(
                    transform.fields, row, mapped_data
                )
                transformed = transform.separator.join(values_to_merge)
            elif transform.type == "split":
                parts = transformed.split(transform.separator)
                if transform.index is None:
                    transformed = parts
                else:
                    # An index beyond the available parts means the piece is
                    # simply absent; fall back to an empty value.
                    try:
                        transformed = parts[transform.index]
                    except IndexError:
                        transformed = ""
            elif transform.type in ["add", "subtract"]:
                try:
                    source_value = Decimal(transformed)

                    # First check row data, then mapped data if not found
                    field_value = row.get(transform.field)
                    if field_value is None and transform.field.startswith("__"):
                        field_value = mapped_data.get(transform.field[2:])

                    if field_value is None:
                        raise KeyError(
                            f"Field '{transform.field}' not found in row or mapped data"
                        )

                    field_value = self._prepare_numeric_value(
                        str(field_value),
                        transform.thousand_separator,
                        transform.decimal_separator,
                    )

                    if transform.absolute_values:
                        source_value = abs(source_value)
                        field_value = abs(field_value)

                    if transform.type == "add":
                        transformed = str(source_value + field_value)
                    else:  # subtract
                        transformed = str(source_value - field_value)
                except (InvalidOperation, KeyError, AttributeError) as e:
                    # Best effort: keep the value as it was and carry on.
                    logger.warning(
                        f"Error in {transform.type} transformation: {e}. Values: {transformed}, {transform.field}"
                    )

        return transformed

    def _find_mapping(self, mapping_cls, target: str):
        """Return the first mapping of ``mapping_cls`` targeting ``target``, or None."""
        return next(
            (
                m
                for m in self.mapping.values()
                if isinstance(m, mapping_cls) and m.target == target
            ),
            None,
        )

    @staticmethod
    def _get_or_create_by_name(model, name):
        """Return the ``model`` row with this ``name``, saving a new one if absent."""
        try:
            return model.objects.get(name=name)
        except model.DoesNotExist:
            instance = model(name=name)
            instance.save()
            return instance

    def _create_transaction(self, data: Dict[str, Any]) -> Transaction:
        """Create a ``Transaction`` from mapped row ``data``.

        The ``category``, ``account``, ``tags`` and ``entities`` entries are
        popped from ``data`` and resolved to model instances (by id or by
        name, per the matching column mapping). Names are created on demand
        only when the mapping has a truthy ``create`` attribute. The new
        transaction and every resolved category/tag/entity are linked to the
        import run. Transaction rules are triggered when the settings ask
        for it.
        """
        tags = []
        entities = []

        # Handle related objects first
        if "category" in data:
            category_name = data.pop("category")
            category_mapping = self._find_mapping(
                TransactionCategoryMapping, "category"
            )
            category = None

            try:
                if category_mapping:
                    if category_mapping.type == "id":
                        category = TransactionCategory.objects.get(id=category_name)
                    elif getattr(category_mapping, "create", False):
                        category = self._get_or_create_by_name(
                            TransactionCategory, category_name
                        )
                    else:
                        category = TransactionCategory.objects.filter(
                            name=category_name
                        ).first()

                if category:
                    data["category"] = category
                    self.import_run.categories.add(category)
            except (TransactionCategory.DoesNotExist, ValueError):
                # Ignore if category doesn't exist and create is False or not set
                data["category"] = None

        if "account" in data:
            account_id = data.pop("account")
            account_mapping = self._find_mapping(TransactionAccountMapping, "account")

            try:
                if account_mapping and account_mapping.type == "id":
                    account = Account.objects.filter(id=account_id).first()
                else:  # name
                    account = Account.objects.filter(name=account_id).first()

                if account:
                    data["account"] = account
            except ValueError:
                # Ignore if account doesn't exist
                pass

        if "tags" in data:
            tag_names = data.pop("tags")
            tags_mapping = self._find_mapping(TransactionTagsMapping, "tags")

            for tag_name in tag_names:
                tag = None
                try:
                    if tags_mapping:
                        if tags_mapping.type == "id":
                            tag = TransactionTag.objects.filter(id=tag_name).first()
                        elif getattr(tags_mapping, "create", False):
                            tag = self._get_or_create_by_name(
                                TransactionTag, tag_name.strip()
                            )
                        else:
                            tag = TransactionTag.objects.filter(
                                name=tag_name.strip()
                            ).first()

                    if tag:
                        tags.append(tag)
                        self.import_run.tags.add(tag)
                except ValueError:
                    # Ignore if tag doesn't exist and create is False or not set
                    continue

        if "entities" in data:
            entity_names = data.pop("entities")
            entities_mapping = self._find_mapping(
                TransactionEntitiesMapping, "entities"
            )

            for entity_name in entity_names:
                entity = None
                try:
                    if entities_mapping:
                        if entities_mapping.type == "id":
                            # Entities live in TransactionEntity; looking the
                            # id up among tags would attach the wrong model.
                            entity = TransactionEntity.objects.filter(
                                id=entity_name
                            ).first()
                        elif getattr(entities_mapping, "create", False):
                            entity = self._get_or_create_by_name(
                                TransactionEntity, entity_name.strip()
                            )
                        else:
                            entity = TransactionEntity.objects.filter(
                                name=entity_name.strip()
                            ).first()

                    if entity:
                        entities.append(entity)
                        self.import_run.entities.add(entity)
                except ValueError:
                    # Ignore if entity doesn't exist and create is False or not set
                    continue

        # Create the transaction
        new_transaction = Transaction.objects.create(**data)
        self.import_run.transactions.add(new_transaction)

        # Add many-to-many relationships
        if tags:
            new_transaction.tags.set(tags)
        if entities:
            new_transaction.entities.set(entities)

        if self.settings.trigger_transaction_rules:
            transaction_created.send(sender=new_transaction)

        return new_transaction

    def _create_account(self, data: Dict[str, Any]) -> Account:
        """Create an ``Account`` from mapped row ``data``.

        ``group`` is resolved by name (created when missing); ``currency``
        and ``exchange_currency`` are resolved by code and linked to the
        import run. ``Currency.DoesNotExist`` propagates for unknown codes.
        """
        if "group" in data:
            group_name = data.pop("group")
            data["group"] = self._get_or_create_by_name(AccountGroup, group_name)

        # Handle currency references
        if "currency" in data:
            currency = Currency.objects.get(code=data["currency"])
            data["currency"] = currency
            self.import_run.currencies.add(currency)

        if "exchange_currency" in data:
            exchange_currency = Currency.objects.get(code=data["exchange_currency"])
            data["exchange_currency"] = exchange_currency
            self.import_run.currencies.add(exchange_currency)

        return Account.objects.create(**data)

    def _create_currency(self, data: Dict[str, Any]) -> Currency:
        """Create a ``Currency`` from ``data`` and link it to the import run.

        ``exchange_currency`` is resolved by code first;
        ``Currency.DoesNotExist`` propagates for an unknown code.
        """
        # Handle exchange currency reference
        if "exchange_currency" in data:
            exchange_currency = Currency.objects.get(code=data["exchange_currency"])
            data["exchange_currency"] = exchange_currency
            self.import_run.currencies.add(exchange_currency)

        currency = Currency.objects.create(**data)
        self.import_run.currencies.add(currency)
        return currency

    def _create_category(self, data: Dict[str, Any]) -> TransactionCategory:
        """Create a ``TransactionCategory`` and link it to the import run."""
        category = TransactionCategory.objects.create(**data)
        self.import_run.categories.add(category)
        return category

    def _create_tag(self, data: Dict[str, Any]) -> TransactionTag:
        """Create a ``TransactionTag`` and link it to the import run."""
        tag = TransactionTag.objects.create(**data)
        self.import_run.tags.add(tag)
        return tag

    def _create_entity(self, data: Dict[str, Any]) -> TransactionEntity:
        """Create a ``TransactionEntity`` and link it to the import run."""
        entity = TransactionEntity.objects.create(**data)
        self.import_run.entities.add(entity)
        return entity

    def _check_duplicate_transaction(self, transaction_data: Dict[str, Any]) -> bool:
        """Return True when any ``compare`` rule matches an existing transaction.

        For each rule, the rule's fields that are present in
        ``transaction_data`` are combined into one filter (exact match for
        ``strict`` rules, case-insensitive otherwise). A rule none of whose
        fields are present in the data is skipped.
        """
        for rule in self.deduplication:
            if rule.type != "compare":
                continue

            # Build query conditions for each field in the rule
            suffix = "" if rule.match_type == "strict" else "__iexact"
            conditions = {
                f"{field}{suffix}": transaction_data[field]
                for field in rule.fields
                if field in transaction_data
            }

            # With no condition the query would match every stored
            # transaction and flag every row as a duplicate.
            if not conditions:
                continue

            # If we found any matching transaction, it's a duplicate
            if Transaction.all_objects.all().values("id").filter(**conditions).exists():
                return True

        return False

    def _coerce_type(
        self, value: str, mapping: version_1.ColumnMapping
    ) -> Union[str, int, bool, Decimal, datetime, list, None]:
        """Coerce ``value`` to ``mapping.coerce_to``; falsy values become None."""
        if not value:
            return None

        coerce_to = mapping.coerce_to

        return self._coerce_single_type(value, coerce_to, mapping)

    @staticmethod
    def _coerce_single_type(
        value: str, coerce_to: str, mapping: version_1.ColumnMapping
    ) -> Union[str, int, bool, Decimal, date, list]:
        """Convert ``value`` to the type named by ``coerce_to``.

        Raises:
            ValueError: for an unsupported ``coerce_to``, an unparseable
                date, a date mapping of the wrong class, or an unknown
                detection method. ``int``/``Decimal`` conversion errors
                propagate unchanged.
        """
        if coerce_to == "str":
            return str(value)
        elif coerce_to == "int":
            return int(value)
        elif coerce_to == "str|int":
            # Ids are numeric; names, codes and anything else stay text.
            if hasattr(mapping, "type") and mapping.type == "id":
                return int(value)
            return str(value)
        elif coerce_to == "bool":
            return value.lower() in ["true", "1", "yes", "y", "on"]
        elif coerce_to == "positive_decimal":
            return abs(Decimal(value))
        elif coerce_to == "date":
            if not isinstance(
                mapping,
                (
                    version_1.TransactionDateMapping,
                    version_1.TransactionReferenceDateMapping,
                ),
            ):
                raise ValueError(
                    "Date coercion is only supported for TransactionDateMapping and TransactionReferenceDateMapping"
                )

            # Values may already be date objects (the .xls reader passes
            # datetime instances through). datetime is checked first since
            # it is a subclass of date.
            if isinstance(value, datetime):
                return value.date()
            elif isinstance(value, date):
                return value

            formats = (
                mapping.format
                if isinstance(mapping.format, list)
                else [mapping.format]
            )
            for fmt in formats:
                try:
                    return datetime.strptime(value, fmt).date()
                except ValueError:
                    continue
            raise ValueError(
                f"Could not parse date '{value}' with any of the provided formats"
            )
        elif coerce_to == "list":
            return (
                value
                if isinstance(value, list)
                else [item.strip() for item in value.split(",") if item.strip()]
            )
        elif coerce_to == "transaction_type":
            if isinstance(mapping, version_1.TransactionTypeMapping):
                if mapping.detection_method == "sign":
                    return (
                        Transaction.Type.EXPENSE
                        if value.startswith("-")
                        else Transaction.Type.INCOME
                    )
                elif mapping.detection_method == "always_income":
                    return Transaction.Type.INCOME
                elif mapping.detection_method == "always_expense":
                    return Transaction.Type.EXPENSE
            raise ValueError("Invalid transaction type detection method")
        elif coerce_to == "is_paid":
            if isinstance(mapping, version_1.TransactionIsPaidMapping):
                if mapping.detection_method == "boolean":
                    return value.lower() in ["true", "1", "yes", "y", "on"]
                elif mapping.detection_method == "always_paid":
                    return True
                elif mapping.detection_method == "always_unpaid":
                    return False
            raise ValueError("Invalid is_paid detection method")
        else:
            raise ValueError(f"Unsupported coercion type: {coerce_to}")

    def _map_row(self, row: Dict[str, str]) -> Dict[str, Any]:
        """Translate a raw file ``row`` into keyword data for object creation.

        For each mapping the value is taken from the first available source
        (a row column, or an already-mapped field referenced as ``__name``),
        falling back to ``mapping.default``; it is then transformed and
        coerced. For non-transaction imports the target's prefix up to the
        first underscore is dropped from the resulting key.

        Raises:
            ValueError: when a required field ends up without a value.
        """
        mapped_data = {}

        for field, mapping in self.mapping.items():
            value = None

            if isinstance(mapping.source, str):
                if mapping.source in row:
                    value = row[mapping.source]
                elif (
                    mapping.source.startswith("__")
                    and mapping.source[2:] in mapped_data
                ):
                    value = mapped_data[mapping.source[2:]]
            elif isinstance(mapping.source, list):
                for source in mapping.source:
                    if source in row:
                        value = row[source]
                        break
                    elif source.startswith("__") and source[2:] in mapped_data:
                        value = mapped_data[source[2:]]
                        break

            if value is None:
                value = mapping.default

            if mapping.transformations:
                value = self._transform_value(value, mapping, row, mapped_data)

            value = self._coerce_type(value, mapping)

            if mapping.required and value is None:
                raise ValueError(f"Required field {field} is missing")

            if value is not None:
                target = mapping.target
                if self.settings.importing == "transactions":
                    mapped_data[target] = value
                else:
                    field_name = target.split("_", 1)[1]
                    mapped_data[field_name] = value

        return mapped_data

    @staticmethod
    def _prepare_numeric_value(
        value: str, thousand_separator: str, decimal_separator: str
    ) -> Decimal:
        """Parse a localized number string into a ``Decimal``.

        Raises ``decimal.InvalidOperation`` when the cleaned text is not a
        valid number.
        """
        # Remove thousand separators
        if thousand_separator:
            value = value.replace(thousand_separator, "")

        # Replace decimal separator with dot
        if decimal_separator != ".":
            value = value.replace(decimal_separator, ".")

        return Decimal(value)

    def _process_row(self, row: Dict[str, str], row_number: int) -> None:
        """Map one row and create the configured object type from it.

        Duplicate transactions are counted as skipped. On failure the row is
        counted as failed and processing continues when
        ``settings.skip_errors`` is set; otherwise the run is marked
        ``FAILED`` and the exception is re-raised.
        """
        try:
            mapped_data = self._map_row(row)

            if mapped_data:
                # Handle different import types
                if self.settings.importing == "transactions":
                    if self.deduplication and self._check_duplicate_transaction(
                        mapped_data
                    ):
                        self._increment_totals("skipped", 1)
                        self._log("info", f"Skipped duplicate row {row_number}")
                        return
                    self._create_transaction(mapped_data)
                elif self.settings.importing == "accounts":
                    self._create_account(mapped_data)
                elif self.settings.importing == "currencies":
                    self._create_currency(mapped_data)
                elif self.settings.importing == "categories":
                    self._create_category(mapped_data)
                elif self.settings.importing == "tags":
                    self._create_tag(mapped_data)
                elif self.settings.importing == "entities":
                    self._create_entity(mapped_data)

                self._increment_totals("successful", value=1)
                self._log("info", f"Successfully processed row {row_number}")

            self._increment_totals("processed", value=1)

        except Exception as e:
            if not self.settings.skip_errors:
                self._log("error", f"Fatal error processing row {row_number}: {str(e)}")
                self._update_status("FAILED")
                raise

            self._log("warning", f"Error processing row {row_number}: {str(e)}")
            self._increment_totals("failed", value=1)

            # This path is non-fatal (the import continues), so the message
            # must not claim otherwise; the traceback is still recorded.
            logger.error(f"Error processing row {row_number}", exc_info=e)

    def _open_csv_reader(self, csv_file) -> csv.DictReader:
        """Skip the configured leading lines of ``csv_file`` and wrap it in a DictReader."""
        for _ in range(self.settings.skip_lines):
            # A file shorter than skip_lines simply yields no rows.
            if next(csv_file, None) is None:
                break
        return csv.DictReader(csv_file, delimiter=self.settings.delimiter)

    def _process_csv(self, file_path):
        """Import every data row of the CSV file at ``file_path``.

        The file is read twice: once to count rows for the run's total, then
        again to process them.
        """
        # newline="" lets the csv module handle line endings itself, so
        # quoted fields containing newlines are parsed correctly.
        # First pass: count rows
        with open(
            file_path, "r", encoding=self.settings.encoding, newline=""
        ) as csv_file:
            reader = self._open_csv_reader(csv_file)
            self._update_totals("total", value=sum(1 for _ in reader))

        with open(
            file_path, "r", encoding=self.settings.encoding, newline=""
        ) as csv_file:
            reader = self._open_csv_reader(csv_file)
            if self.settings.skip_lines:
                self._log("info", f"Skipped {self.settings.skip_lines} initial lines")

            self._log("info", f"Starting import with {self.import_run.total_rows} rows")

            for row_number, row in enumerate(reader, start=1):
                self._process_row(row, row_number)

    def _resolve_sheet_names(self, available: list) -> list:
        """Return the sheet names to import: all ``available`` for ``"*"``, else the configured one(s)."""
        sheets = self.settings.sheets
        if sheets == "*":
            return available
        return sheets if isinstance(sheets, list) else [sheets]

    def _record_sheet_row_failure(
        self, row_number: int, sheet_name: str, error: Exception
    ) -> None:
        """Log a skipped spreadsheet row and count it as failed."""
        self._log(
            "warning",
            f"Error processing row {row_number} in sheet '{sheet_name}': {str(error)}",
        )
        self._increment_totals("failed", value=1)

    def _process_xlsx(self, file_path) -> None:
        """Import the configured sheets of an ``.xlsx`` workbook."""
        workbook = openpyxl.load_workbook(file_path, read_only=True, data_only=True)
        try:
            sheets_to_process = self._resolve_sheet_names(workbook.sheetnames)

            # Calculate total rows
            # max_row can be None in read-only mode when the sheet carries
            # no dimension information; count such sheets as empty.
            total_rows = sum(
                max(0, (workbook[sheet_name].max_row or 0) - self.settings.start_row)
                for sheet_name in sheets_to_process
                if sheet_name in workbook.sheetnames
            )
            self._update_totals("total", value=total_rows)

            # Process sheets
            for sheet_name in sheets_to_process:
                if sheet_name not in workbook.sheetnames:
                    self._log(
                        "warning",
                        f"Sheet '{sheet_name}' not found in the Excel file. Skipping.",
                    )
                    continue

                sheet = workbook[sheet_name]
                self._log("info", f"Processing sheet: {sheet_name}")

                # The row at start_row holds the column headers.
                headers = [
                    str(cell.value or "") for cell in sheet[self.settings.start_row]
                ]

                for row_number, row in enumerate(
                    sheet.iter_rows(
                        min_row=self.settings.start_row + 1, values_only=True
                    ),
                    start=1,
                ):
                    try:
                        row_data = {
                            key: str(value) if value is not None else None
                            for key, value in zip(headers, row)
                        }
                        self._process_row(row_data, row_number)
                    except Exception as e:
                        if not self.settings.skip_errors:
                            raise
                        self._record_sheet_row_failure(row_number, sheet_name, e)
        finally:
            # Read-only workbooks keep the file open until closed explicitly.
            workbook.close()

    def _process_xls(self, file_path) -> None:
        """Import the configured sheets of a legacy ``.xls`` workbook."""
        workbook = xlrd.open_workbook(file_path)
        available = workbook.sheet_names()
        sheets_to_process = self._resolve_sheet_names(available)

        # Calculate total rows
        total_rows = sum(
            max(
                0,
                workbook.sheet_by_name(sheet_name).nrows - self.settings.start_row,
            )
            for sheet_name in sheets_to_process
            if sheet_name in available
        )
        self._update_totals("total", value=total_rows)

        # Process sheets
        for sheet_name in sheets_to_process:
            if sheet_name not in available:
                self._log(
                    "warning",
                    f"Sheet '{sheet_name}' not found in the Excel file. Skipping.",
                )
                continue

            sheet = workbook.sheet_by_name(sheet_name)
            self._log("info", f"Processing sheet: {sheet_name}")

            # xlrd rows are 0-based, so the header row is start_row - 1.
            headers = [
                str(sheet.cell_value(self.settings.start_row - 1, col) or "")
                for col in range(sheet.ncols)
            ]

            for row_index in range(self.settings.start_row, sheet.nrows):
                # 1-based number of the data row, the same numbering the
                # other readers report.
                row_number = row_index - self.settings.start_row + 1
                try:
                    row_data = {}
                    for col, key in enumerate(headers):
                        cell_type = sheet.cell_type(row_index, col)
                        cell_value = sheet.cell_value(row_index, col)

                        if cell_type == xlrd.XL_CELL_DATE:
                            # Convert Excel date to Python datetime
                            try:
                                python_date = datetime(
                                    *xlrd.xldate_as_tuple(
                                        cell_value, workbook.datemode
                                    )
                                )
                                row_data[key] = python_date
                            except Exception:
                                # If date conversion fails, use the original value
                                row_data[key] = (
                                    str(cell_value) if cell_value is not None else None
                                )
                        elif cell_value is None:
                            row_data[key] = None
                        else:
                            row_data[key] = str(cell_value)

                    self._process_row(row_data, row_number)
                except Exception as e:
                    if not self.settings.skip_errors:
                        raise
                    self._record_sheet_row_failure(row_number, sheet_name, e)

    def _process_excel(self, file_path):
        """Import the Excel file at ``file_path`` (``xlsx`` or legacy ``xls``).

        Raises:
            ValueError: when the workbook cannot be read as the configured
                file type.
        """
        try:
            if self.settings.file_type == "xlsx":
                self._process_xlsx(file_path)
            else:  # xls
                self._process_xls(file_path)
        except (InvalidFileException, xlrd.XLRDError) as e:
            raise ValueError(
                f"Invalid {self.settings.file_type.upper()} file format: {str(e)}"
            ) from e

    def _validate_file_path(self, file_path: str) -> str:
        """
        Validates that the file path is within the allowed temporary directory.
        Returns the absolute path.

        Raises ValueError when the path lies outside ``TEMP_DIR``.
        """
        abs_path = os.path.abspath(file_path)
        temp_dir = os.path.abspath(self.TEMP_DIR)

        # Compare whole path components: a plain string-prefix test would
        # also accept sibling directories such as "<TEMP_DIR>_other".
        if os.path.commonpath([abs_path, temp_dir]) != temp_dir:
            raise ValueError(f"Invalid file path. File must be in {self.TEMP_DIR}")
        return abs_path

    def process_file(self, file_path: str):
        """Run the import for the file at ``file_path``.

        The path must lie inside ``TEMP_DIR``. The run's status, timestamps
        and log are updated throughout, and the file is deleted afterwards
        whether or not the import succeeded.

        Raises:
            ValueError: when ``file_path`` is outside ``TEMP_DIR``.
            Exception: ``"Import failed"`` (chained to the original error)
                when processing fails.
        """
        with cachalot_disabled():
            # Validate and get absolute path
            file_path = self._validate_file_path(file_path)

            self._update_status("PROCESSING")
            self.import_run.started_at = timezone.now()
            self.import_run.save(update_fields=["started_at"])

            self._log("info", "Starting import process")

            try:
                if isinstance(self.settings, version_1.CSVImportSettings):
                    self._process_csv(file_path)
                elif isinstance(self.settings, version_1.ExcelImportSettings):
                    self._process_excel(file_path)

                self._update_status("FINISHED")
                self._log(
                    "info",
                    f"Import completed successfully. "
                    f"Successful: {self.import_run.successful_rows}, "
                    f"Failed: {self.import_run.failed_rows}, "
                    f"Skipped: {self.import_run.skipped_rows}",
                )

            except Exception as e:
                self._update_status("FAILED")
                self._log("error", f"Import failed: {str(e)}")
                # Chain the cause so the original traceback is preserved.
                raise Exception("Import failed") from e

            finally:
                self._log("info", "Cleaning up temporary files")
                try:
                    if os.path.exists(file_path):
                        os.remove(file_path)
                        self._log("info", f"Deleted temporary file: {file_path}")
                except OSError as e:
                    self._log("warning", f"Failed to delete temporary file: {str(e)}")

                self.import_run.finished_at = timezone.now()
                self.import_run.save(update_fields=["finished_at"])