feat(import): improve schema definition

This commit is contained in:
Herculino Trotta
2025-01-19 11:27:14 -03:00
parent fbb26b8442
commit 86dac632c4
3 changed files with 717 additions and 128 deletions

View File

@@ -1,8 +1 @@
from apps.import_app.schemas.v1 import ( import apps.import_app.schemas.v1 as version_1
ImportProfileSchema as SchemaV1,
ColumnMapping as ColumnMappingV1,
# TransformationRule as TransformationRuleV1,
ImportSettings as SettingsV1,
HashTransformationRule as HashTransformationRuleV1,
CompareDeduplicationRule as CompareDeduplicationRuleV1,
)

View File

@@ -1,5 +1,5 @@
from typing import Dict, List, Optional, Literal from typing import Dict, List, Optional, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator, field_validator
class CompareDeduplicationRule(BaseModel): class CompareDeduplicationRule(BaseModel):
@@ -9,6 +9,12 @@ class CompareDeduplicationRule(BaseModel):
) )
match_type: Literal["lax", "strict"] match_type: Literal["lax", "strict"]
@field_validator("fields", mode="before")
def coerce_fields_to_dict(cls, v):
    """Accept `fields` given as a list of single-key dicts and merge them into one dict.

    Any non-list value is passed through unchanged for normal validation.
    """
    if isinstance(v, list):
        # Merge e.g. [{"a": "x"}, {"b": "y"}] -> {"a": "x", "b": "y"}.
        # Distinct names avoid shadowing the parameter `v` inside the
        # comprehension (the original relied on comprehension scoping).
        return {key: val for item in v for key, val in item.items()}
    return v
class ReplaceTransformationRule(BaseModel): class ReplaceTransformationRule(BaseModel):
field: str field: str
@@ -17,6 +23,10 @@ class ReplaceTransformationRule(BaseModel):
) )
pattern: str = Field(..., description="Pattern to match") pattern: str = Field(..., description="Pattern to match")
replacement: str = Field(..., description="Value to replace with") replacement: str = Field(..., description="Value to replace with")
exclusive: bool = Field(
default=False,
description="If it should match against the last transformation or the original value",
)
class DateFormatTransformationRule(BaseModel): class DateFormatTransformationRule(BaseModel):
@@ -48,7 +58,7 @@ class SplitTransformationRule(BaseModel):
) )
class ImportSettings(BaseModel): class CSVImportSettings(BaseModel):
skip_errors: bool = Field( skip_errors: bool = Field(
default=False, default=False,
description="If True, errors during import will be logged and skipped", description="If True, errors during import will be logged and skipped",
@@ -56,7 +66,7 @@ class ImportSettings(BaseModel):
file_type: Literal["csv"] = "csv" file_type: Literal["csv"] = "csv"
delimiter: str = Field(default=",", description="CSV delimiter character") delimiter: str = Field(default=",", description="CSV delimiter character")
encoding: str = Field(default="utf-8", description="File encoding") encoding: str = Field(default="utf-8", description="File encoding")
skip_rows: int = Field( skip_lines: int = Field(
default=0, description="Number of rows to skip at the beginning of the file" default=0, description="Number of rows to skip at the beginning of the file"
) )
importing: Literal[ importing: Literal[
@@ -69,20 +79,7 @@ class ColumnMapping(BaseModel):
default=None, default=None,
description="CSV column header. If None, the field will be generated from transformations", description="CSV column header. If None, the field will be generated from transformations",
) )
target: Literal[ default: Optional[str] = None
"account",
"type",
"is_paid",
"date",
"reference_date",
"amount",
"notes",
"category",
"tags",
"entities",
"internal_note",
] = Field(..., description="Transaction field to map to")
default_value: Optional[str] = None
required: bool = False required: bool = False
transformations: Optional[ transformations: Optional[
List[ List[
@@ -95,10 +92,305 @@ class ColumnMapping(BaseModel):
] = Field(default_factory=list) ] = Field(default_factory=list)
class TransactionAccountMapping(ColumnMapping):
target: Literal["account"] = Field(..., description="Transaction field to map to")
type: Literal["id", "name"] = "name"
coerce_to: Literal["str|int"] = Field("str|int", frozen=True)
class TransactionTypeMapping(ColumnMapping):
target: Literal["type"] = Field(..., description="Transaction field to map to")
detection_method: Literal["sign", "always_income", "always_expense"] = "sign"
coerce_to: Literal["transaction_type"] = Field("transaction_type", frozen=True)
class TransactionIsPaidMapping(ColumnMapping):
target: Literal["is_paid"] = Field(..., description="Transaction field to map to")
detection_method: Literal["sign", "boolean", "always_paid", "always_unpaid"]
coerce_to: Literal["is_paid"] = Field("is_paid", frozen=True)
class TransactionDateMapping(ColumnMapping):
target: Literal["date"] = Field(..., description="Transaction field to map to")
format: List[str] | str
coerce_to: Literal["date"] = Field("date", frozen=True)
class TransactionReferenceDateMapping(ColumnMapping):
target: Literal["reference_date"] = Field(
..., description="Transaction field to map to"
)
format: List[str] | str
coerce_to: Literal["date"] = Field("date", frozen=True)
class TransactionAmountMapping(ColumnMapping):
target: Literal["amount"] = Field(..., description="Transaction field to map to")
coerce_to: Literal["positive_decimal"] = Field("positive_decimal", frozen=True)
class TransactionDescriptionMapping(ColumnMapping):
target: Literal["description"] = Field(
..., description="Transaction field to map to"
)
coerce_to: Literal["str"] = Field("str", frozen=True)
class TransactionNotesMapping(ColumnMapping):
target: Literal["notes"] = Field(..., description="Transaction field to map to")
coerce_to: Literal["str"] = Field("str", frozen=True)
class TransactionTagsMapping(ColumnMapping):
    """Maps a CSV column to a transaction's tags (comma-separated values)."""

    target: Literal["tags"] = Field(..., description="Transaction field to map to")
    # Grammar fix in the user-facing description ("they doesn't" -> "they don't").
    create: bool = Field(
        default=True, description="Create new tags if they don't exist"
    )
    coerce_to: Literal["list"] = Field("list", frozen=True)
class TransactionEntitiesMapping(ColumnMapping):
    """Maps a CSV column to a transaction's entities (comma-separated values)."""

    target: Literal["entities"] = Field(..., description="Transaction field to map to")
    # Grammar fix in the user-facing description ("they doesn't" -> "they don't").
    create: bool = Field(
        default=True, description="Create new entities if they don't exist"
    )
    coerce_to: Literal["list"] = Field("list", frozen=True)
class TransactionCategoryMapping(ColumnMapping):
target: Literal["category"] = Field(..., description="Transaction field to map to")
create: bool = Field(
default=True, description="Create category if it doesn't exist"
)
type: Literal["id", "name"] = "name"
coerce_to: Literal["str|int"] = Field("str|int", frozen=True)
class TransactionInternalMapping(ColumnMapping):
target: Literal["internal_note"] = Field(
..., description="Transaction field to map to"
)
coerce_to: Literal["str"] = Field("str", frozen=True)
class CategoryNameMapping(ColumnMapping):
target: Literal["category_name"] = Field(
..., description="Category field to map to"
)
coerce_to: Literal["str"] = Field("str", frozen=True)
class CategoryMuteMapping(ColumnMapping):
target: Literal["category_mute"] = Field(
..., description="Category field to map to"
)
coerce_to: Literal["bool"] = Field("bool", frozen=True)
class CategoryActiveMapping(ColumnMapping):
target: Literal["category_active"] = Field(
..., description="Category field to map to"
)
coerce_to: Literal["bool"] = Field("bool", frozen=True)
class TagNameMapping(ColumnMapping):
target: Literal["tag_name"] = Field(..., description="Tag field to map to")
coerce_to: Literal["str"] = Field("str", frozen=True)
class TagActiveMapping(ColumnMapping):
target: Literal["tag_active"] = Field(..., description="Tag field to map to")
coerce_to: Literal["bool"] = Field("bool", frozen=True)
class EntityNameMapping(ColumnMapping):
target: Literal["entity_name"] = Field(..., description="Entity field to map to")
coerce_to: Literal["str"] = Field("str", frozen=True)
class EntityActiveMapping(ColumnMapping):
    """Maps a CSV column to an entity's active flag.

    NOTE(review): the original target literal was the misspelled
    "entitiy_active"; it is kept for backward compatibility alongside the
    corrected "entity_active" spelling. Both strip to "active" when the
    model prefix is removed during row mapping.
    """

    target: Literal["entity_active", "entitiy_active"] = Field(
        ..., description="Entity field to map to"
    )
    coerce_to: Literal["bool"] = Field("bool", frozen=True)
class AccountNameMapping(ColumnMapping):
target: Literal["account_name"] = Field(..., description="Account field to map to")
coerce_to: Literal["str"] = Field("str", frozen=True)
class AccountGroupMapping(ColumnMapping):
target: Literal["account_group"] = Field(..., description="Account field to map to")
type: Literal["id", "name"]
coerce_to: Literal["str|int"] = Field("str|int", frozen=True)
class AccountCurrencyMapping(ColumnMapping):
target: Literal["account_currency"] = Field(
..., description="Account field to map to"
)
type: Literal["id", "name", "code"]
coerce_to: Literal["str|int"] = Field("str|int", frozen=True)
class AccountExchangeCurrencyMapping(ColumnMapping):
target: Literal["account_exchange_currency"] = Field(
..., description="Account field to map to"
)
type: Literal["id", "name", "code"]
coerce_to: Literal["str|int"] = Field("str|int", frozen=True)
class AccountIsAssetMapping(ColumnMapping):
target: Literal["account_is_asset"] = Field(
..., description="Account field to map to"
)
coerce_to: Literal["bool"] = Field("bool", frozen=True)
class AccountIsArchivedMapping(ColumnMapping):
target: Literal["account_is_archived"] = Field(
..., description="Account field to map to"
)
coerce_to: Literal["bool"] = Field("bool", frozen=True)
class CurrencyCodeMapping(ColumnMapping):
target: Literal["currency_code"] = Field(
..., description="Currency field to map to"
)
coerce_to: Literal["str"] = Field("str", frozen=True)
class CurrencyNameMapping(ColumnMapping):
target: Literal["currency_name"] = Field(
..., description="Currency field to map to"
)
coerce_to: Literal["str"] = Field("str", frozen=True)
class CurrencyDecimalPlacesMapping(ColumnMapping):
target: Literal["currency_decimal_places"] = Field(
..., description="Currency field to map to"
)
coerce_to: Literal["int"] = Field("int", frozen=True)
class CurrencyPrefixMapping(ColumnMapping):
target: Literal["currency_prefix"] = Field(
..., description="Currency field to map to"
)
coerce_to: Literal["str"] = Field("str", frozen=True)
class CurrencySuffixMapping(ColumnMapping):
target: Literal["currency_suffix"] = Field(
..., description="Currency field to map to"
)
coerce_to: Literal["str"] = Field("str", frozen=True)
class CurrencyExchangeMapping(ColumnMapping):
target: Literal["currency_exchange"] = Field(
..., description="Currency field to map to"
)
type: Literal["id", "name", "code"]
coerce_to: Literal["str|int"] = Field("str|int", frozen=True)
class ImportProfileSchema(BaseModel): class ImportProfileSchema(BaseModel):
settings: ImportSettings settings: CSVImportSettings
column_mapping: Dict[str, ColumnMapping] mapping: Dict[
str,
TransactionAccountMapping
| TransactionTypeMapping
| TransactionIsPaidMapping
| TransactionDateMapping
| TransactionReferenceDateMapping
| TransactionAmountMapping
| TransactionDescriptionMapping
| TransactionNotesMapping
| TransactionTagsMapping
| TransactionEntitiesMapping
| TransactionCategoryMapping
| TransactionInternalMapping
| CategoryNameMapping
| CategoryMuteMapping
| CategoryActiveMapping
| TagNameMapping
| TagActiveMapping
| EntityNameMapping
| EntityActiveMapping
| AccountNameMapping
| AccountGroupMapping
| AccountCurrencyMapping
| AccountExchangeCurrencyMapping
| AccountIsAssetMapping
| AccountIsArchivedMapping
| CurrencyCodeMapping
| CurrencyNameMapping
| CurrencyDecimalPlacesMapping
| CurrencyPrefixMapping
| CurrencySuffixMapping
| CurrencyExchangeMapping,
]
deduplication: List[CompareDeduplicationRule] = Field( deduplication: List[CompareDeduplicationRule] = Field(
default_factory=list, default_factory=list,
description="Rules for deduplicating records during import", description="Rules for deduplicating records during import",
) )
@model_validator(mode="after")
def validate_mappings(self) -> "ImportProfileSchema":
    """Ensure every configured mapping is valid for the selected import type.

    Raises:
        ValueError: if a mapping's class is not allowed for
            ``settings.importing`` (e.g. a transaction mapping in an
            accounts import).
    """
    import_type = self.settings.importing

    # Allowed mapping classes for each import type.
    allowed_mappings = {
        "transactions": (
            TransactionAccountMapping,
            TransactionTypeMapping,
            TransactionIsPaidMapping,
            TransactionDateMapping,
            TransactionReferenceDateMapping,
            TransactionAmountMapping,
            TransactionDescriptionMapping,
            TransactionNotesMapping,
            TransactionTagsMapping,
            TransactionEntitiesMapping,
            TransactionCategoryMapping,
            TransactionInternalMapping,
        ),
        "accounts": (
            AccountNameMapping,
            AccountGroupMapping,
            AccountCurrencyMapping,
            AccountExchangeCurrencyMapping,
            AccountIsAssetMapping,
            AccountIsArchivedMapping,
        ),
        "currencies": (
            CurrencyCodeMapping,
            CurrencyNameMapping,
            CurrencyDecimalPlacesMapping,
            CurrencyPrefixMapping,
            CurrencySuffixMapping,
            CurrencyExchangeMapping,
        ),
        "categories": (
            CategoryNameMapping,
            CategoryMuteMapping,
            CategoryActiveMapping,
        ),
        "tags": (TagNameMapping, TagActiveMapping),
        "entities": (EntityNameMapping, EntityActiveMapping),
    }

    allowed_types = allowed_mappings[import_type]
    for field_name, mapping in self.mapping.items():
        if not isinstance(mapping, allowed_types):
            # Name the offending mapping key so users can find it in the YAML.
            raise ValueError(
                f"Mapping '{field_name}' of type '{type(mapping).__name__}' is not "
                f"allowed when importing {import_type}. "
                f"Allowed types are: {', '.join(t.__name__ for t in allowed_types)}"
            )
    return self

View File

@@ -1,42 +1,56 @@
import csv import csv
import hashlib import hashlib
import logging
import os
import re import re
from datetime import datetime from datetime import datetime
from typing import Dict, Any, Literal from decimal import Decimal
from typing import Dict, Any, Literal, Union
import yaml import yaml
from django.db import transaction from django.db import transaction
from django.core.files.storage import default_storage
from django.utils import timezone from django.utils import timezone
from apps.accounts.models import Account, AccountGroup
from apps.currencies.models import Currency
from apps.import_app.models import ImportRun, ImportProfile from apps.import_app.models import ImportRun, ImportProfile
from apps.import_app.schemas import ( from apps.import_app.schemas import version_1
SchemaV1, from apps.transactions.models import (
ColumnMappingV1, Transaction,
SettingsV1, TransactionCategory,
HashTransformationRuleV1, TransactionTag,
CompareDeduplicationRuleV1, TransactionEntity,
) )
from apps.transactions.models import Transaction
logger = logging.getLogger(__name__)
class ImportService: class ImportService:
TEMP_DIR = "/usr/src/app/temp"
def __init__(self, import_run: ImportRun): def __init__(self, import_run: ImportRun):
self.import_run: ImportRun = import_run self.import_run: ImportRun = import_run
self.profile: ImportProfile = import_run.profile self.profile: ImportProfile = import_run.profile
self.config: SchemaV1 = self._load_config() self.config: version_1.ImportProfileSchema = self._load_config()
self.settings: SettingsV1 = self.config.settings self.settings: version_1.CSVImportSettings = self.config.settings
self.deduplication: list[CompareDeduplicationRuleV1] = self.config.deduplication self.deduplication: list[version_1.CompareDeduplicationRule] = (
self.mapping: Dict[str, ColumnMappingV1] = self.config.column_mapping self.config.deduplication
)
self.mapping: Dict[str, version_1.ColumnMapping] = self.config.mapping
def _load_config(self) -> SchemaV1: # Ensure temp directory exists
os.makedirs(self.TEMP_DIR, exist_ok=True)
def _load_config(self) -> version_1.ImportProfileSchema:
yaml_data = yaml.safe_load(self.profile.yaml_config) yaml_data = yaml.safe_load(self.profile.yaml_config)
try:
if self.profile.version == ImportProfile.Versions.VERSION_1: config = version_1.ImportProfileSchema(**yaml_data)
return SchemaV1(**yaml_data) except Exception as e:
self._log("error", f"Fatal error processing YAML config: {str(e)}")
raise ValueError(f"Unsupported version: {self.profile.version}") self._update_status("FAILED")
raise e
else:
return config
def _log(self, level: str, message: str, **kwargs) -> None: def _log(self, level: str, message: str, **kwargs) -> None:
"""Add a log entry to the import run logs""" """Add a log entry to the import run logs"""
@@ -53,6 +67,48 @@ class ImportService:
self.import_run.logs += log_line self.import_run.logs += log_line
self.import_run.save(update_fields=["logs"]) self.import_run.save(update_fields=["logs"])
def _update_totals(
self,
field: Literal["total", "processed", "successful", "skipped", "failed"],
value: int,
) -> None:
if field == "total":
self.import_run.total_rows = value
self.import_run.save(update_fields=["total_rows"])
elif field == "processed":
self.import_run.processed_rows = value
self.import_run.save(update_fields=["processed_rows"])
elif field == "successful":
self.import_run.successful_rows = value
self.import_run.save(update_fields=["successful_rows"])
elif field == "skipped":
self.import_run.skipped_rows = value
self.import_run.save(update_fields=["skipped_rows"])
elif field == "failed":
self.import_run.failed_rows = value
self.import_run.save(update_fields=["failed_rows"])
def _increment_totals(
self,
field: Literal["total", "processed", "successful", "skipped", "failed"],
value: int,
) -> None:
if field == "total":
self.import_run.total_rows = self.import_run.total_rows + value
self.import_run.save(update_fields=["total_rows"])
elif field == "processed":
self.import_run.processed_rows = self.import_run.processed_rows + value
self.import_run.save(update_fields=["processed_rows"])
elif field == "successful":
self.import_run.successful_rows = self.import_run.successful_rows + value
self.import_run.save(update_fields=["successful_rows"])
elif field == "skipped":
self.import_run.skipped_rows = self.import_run.skipped_rows + value
self.import_run.save(update_fields=["skipped_rows"])
elif field == "failed":
self.import_run.failed_rows = self.import_run.failed_rows + value
self.import_run.save(update_fields=["failed_rows"])
def _update_status( def _update_status(
self, new_status: Literal["PROCESSING", "FAILED", "FINISHED"] self, new_status: Literal["PROCESSING", "FAILED", "FINISHED"]
) -> None: ) -> None:
@@ -67,15 +123,12 @@ class ImportService:
@staticmethod @staticmethod
def _transform_value( def _transform_value(
value: str, mapping: ColumnMappingV1, row: Dict[str, str] = None value: str, mapping: version_1.ColumnMapping, row: Dict[str, str] = None
) -> Any: ) -> Any:
transformed = value transformed = value
for transform in mapping.transformations: for transform in mapping.transformations:
if transform.type == "hash": if transform.type == "hash":
if not isinstance(transform, HashTransformationRuleV1):
continue
# Collect all values to be hashed # Collect all values to be hashed
values_to_hash = [] values_to_hash = []
for field in transform.fields: for field in transform.fields:
@@ -88,47 +141,143 @@ class ImportService:
transformed = hashlib.sha256(concatenated.encode()).hexdigest() transformed = hashlib.sha256(concatenated.encode()).hexdigest()
elif transform.type == "replace": elif transform.type == "replace":
transformed = transformed.replace( if transform.exclusive:
transform.pattern, transform.replacement transformed = value.replace(
) transform.pattern, transform.replacement
)
else:
transformed = transformed.replace(
transform.pattern, transform.replacement
)
elif transform.type == "regex": elif transform.type == "regex":
transformed = re.sub( if transform.exclusive:
transform.pattern, transform.replacement, transformed transformed = re.sub(
) transform.pattern, transform.replacement, value
)
else:
transformed = re.sub(
transform.pattern, transform.replacement, transformed
)
elif transform.type == "date_format": elif transform.type == "date_format":
transformed = datetime.strptime( transformed = datetime.strptime(
transformed, transform.pattern transformed, transform.original_format
).strftime(transform.replacement) ).strftime(transform.new_format)
elif transform.type == "merge":
values_to_merge = []
for field in transform.fields:
if field in row:
values_to_merge.append(str(row[field]))
transformed = transform.separator.join(values_to_merge)
elif transform.type == "split":
parts = transformed.split(transform.separator)
if transform.index is not None:
transformed = parts[transform.index] if parts else ""
else:
transformed = parts
return transformed return transformed
def _map_row_to_transaction(self, row: Dict[str, str]) -> Dict[str, Any]: def _create_transaction(self, data: Dict[str, Any]) -> Transaction:
transaction_data = {} tags = []
entities = []
# Handle related objects first
if "category" in data:
category_name = data.pop("category")
category, _ = TransactionCategory.objects.get_or_create(name=category_name)
data["category"] = category
self.import_run.categories.add(category)
for field, mapping in self.mapping.items(): if "account" in data:
# If source is None, use None as the initial value account_id = data.pop("account")
value = row.get(mapping.source) if mapping.source else None account = None
if isinstance(account_id, str):
account = Account.objects.get(name=account_id)
elif isinstance(account_id, int):
account = Account.objects.get(id=account_id)
data["account"] = account
# self.import_run.acc.add(category)
# Use default_value if value is None if "tags" in data:
if value is None: tag_names = data.pop("tags").split(",")
value = mapping.default_value for tag_name in tag_names:
tag, _ = TransactionTag.objects.get_or_create(name=tag_name.strip())
tags.append(tag)
self.import_run.tags.add(tag)
if mapping.required and value is None and not mapping.transformations: if "entities" in data:
raise ValueError(f"Required field {field} is missing") entity_names = data.pop("entities").split(",")
for entity_name in entity_names:
entity, _ = TransactionEntity.objects.get_or_create(
name=entity_name.strip()
)
entities.append(entity)
self.import_run.entities.add(entity)
# Apply transformations even if initial value is None if "amount" in data:
if mapping.transformations: amount = data.pop("amount")
value = self._transform_value(value, mapping, row) data["amount"] = abs(Decimal(amount))
if value is not None: # Create the transaction
transaction_data[field] = value new_transaction = Transaction.objects.create(**data)
self.import_run.transactions.add(new_transaction)
return transaction_data # Add many-to-many relationships
if tags:
new_transaction.tags.set(tags)
if entities:
new_transaction.entities.set(entities)
return new_transaction
def _create_account(self, data: Dict[str, Any]) -> Account:
    """Create an Account from mapped row data.

    Resolves the related objects first: the group is get-or-created by
    name, and currency / exchange currency are looked up by code and
    tracked on the import run.
    """
    if "group" in data:
        group, _ = AccountGroup.objects.get_or_create(name=data.pop("group"))
        data["group"] = group

    # Both currency references are resolved the same way: by code.
    for currency_field in ("currency", "exchange_currency"):
        if currency_field in data:
            resolved = Currency.objects.get(code=data[currency_field])
            data[currency_field] = resolved
            self.import_run.currencies.add(resolved)

    return Account.objects.create(**data)
def _create_currency(self, data: Dict[str, Any]) -> Currency:
    """Create a Currency, resolving its optional exchange currency by code.

    Both the referenced exchange currency and the newly created currency
    are tracked on the import run.
    """
    if "exchange_currency" in data:
        exchange = Currency.objects.get(code=data["exchange_currency"])
        data["exchange_currency"] = exchange
        self.import_run.currencies.add(exchange)

    created = Currency.objects.create(**data)
    self.import_run.currencies.add(created)
    return created
def _create_category(self, data: Dict[str, Any]) -> TransactionCategory:
    """Create a TransactionCategory and track it on the import run."""
    created = TransactionCategory.objects.create(**data)
    self.import_run.categories.add(created)
    return created
def _create_tag(self, data: Dict[str, Any]) -> TransactionTag:
    """Create a TransactionTag and track it on the import run."""
    created = TransactionTag.objects.create(**data)
    self.import_run.tags.add(created)
    return created
def _create_entity(self, data: Dict[str, Any]) -> TransactionEntity:
    """Create a TransactionEntity and track it on the import run."""
    created = TransactionEntity.objects.create(**data)
    self.import_run.entities.add(created)
    return created
def _check_duplicate_transaction(self, transaction_data: Dict[str, Any]) -> bool: def _check_duplicate_transaction(self, transaction_data: Dict[str, Any]) -> bool:
for rule in self.deduplication: for rule in self.deduplication:
if rule.type == "compare": if rule.type == "compare":
query = Transaction.objects.all() query = Transaction.objects.all().values("id")
# Build query conditions for each field in the rule # Build query conditions for each field in the rule
for field, header in rule.fields.items(): for field, header in rule.fields.items():
@@ -146,65 +295,214 @@ class ImportService:
return False return False
def _process_csv(self, file_path): def _coerce_type(
with open(file_path, "r", encoding=self.settings.encoding) as csv_file: self, value: str, mapping: version_1.ColumnMapping
reader = csv.DictReader(csv_file, delimiter=self.settings.delimiter) ) -> Union[str, int, bool, Decimal, datetime, list]:
if not value:
return None
coerce_to = mapping.coerce_to
if "|" in coerce_to:
types = coerce_to.split("|")
for t in types:
try:
return self._coerce_single_type(value, t, mapping)
except ValueError:
continue
raise ValueError(
f"Could not coerce '{value}' to any of the types: {coerce_to}"
)
else:
return self._coerce_single_type(value, coerce_to, mapping)
def _coerce_single_type(
self, value: str, coerce_to: str, mapping: version_1.ColumnMapping
) -> Union[str, int, bool, Decimal, datetime.date, list]:
if coerce_to == "str":
return str(value)
elif coerce_to == "int":
if hasattr(mapping, "type") and mapping.type == "id":
return int(value)
elif hasattr(mapping, "type") and mapping.type in ["name", "code"]:
return str(value)
else:
return int(value)
elif coerce_to == "bool":
return value.lower() in ["true", "1", "yes", "y", "on"]
elif coerce_to == "positive_decimal":
return abs(Decimal(value))
elif coerce_to == "date":
if isinstance(
mapping,
(
version_1.TransactionDateMapping,
version_1.TransactionReferenceDateMapping,
),
):
formats = (
mapping.format
if isinstance(mapping.format, list)
else [mapping.format]
)
for fmt in formats:
try:
return datetime.strptime(value, fmt).date()
except ValueError:
continue
raise ValueError(
f"Could not parse date '{value}' with any of the provided formats"
)
else:
raise ValueError(
"Date coercion is only supported for TransactionDateMapping and TransactionReferenceDateMapping"
)
elif coerce_to == "list":
return (
value
if isinstance(value, list)
else [item.strip() for item in value.split(",") if item.strip()]
)
elif coerce_to == "transaction_type":
if isinstance(mapping, version_1.TransactionTypeMapping):
if mapping.detection_method == "sign":
return (
Transaction.Type.EXPENSE
if value.startswith("-")
else Transaction.Type.INCOME
)
elif mapping.detection_method == "always_income":
return Transaction.Type.INCOME
elif mapping.detection_method == "always_expense":
return Transaction.Type.EXPENSE
raise ValueError("Invalid transaction type detection method")
elif coerce_to == "is_paid":
if isinstance(mapping, version_1.TransactionIsPaidMapping):
if mapping.detection_method == "sign":
return not value.startswith("-")
elif mapping.detection_method == "boolean":
return value.lower() in ["true", "1", "yes", "y", "on"]
elif mapping.detection_method == "always_paid":
return True
elif mapping.detection_method == "always_unpaid":
return False
raise ValueError("Invalid is_paid detection method")
else:
raise ValueError(f"Unsupported coercion type: {coerce_to}")
def _map_row(self, row: Dict[str, str]) -> Dict[str, Any]:
    """Translate one CSV row into model field values via the profile mapping.

    For each configured mapping: read the source column (or the mapping's
    default), apply transformations, coerce the value, and store it under
    the target field name. None results are dropped.
    """
    mapped: Dict[str, Any] = {}
    for name, column in self.mapping.items():
        # No source column means the value is produced by transformations.
        raw = row.get(column.source) if column.source else None
        if raw is None:
            raw = column.default

        if column.required and raw is None and not column.transformations:
            raise ValueError(f"Required field {name} is missing")

        if column.transformations:
            raw = self._transform_value(raw, column, row)

        coerced = self._coerce_type(raw, column)
        if coerced is None:
            continue

        target = column.target
        if self.settings.importing == "transactions":
            mapped[target] = coerced
        else:
            # Non-transaction targets carry a model prefix
            # (e.g. "account_name"); strip it for the ORM kwargs.
            mapped[target.split("_", 1)[1]] = coerced
    return mapped
def _process_row(self, row: Dict[str, str], row_number: int) -> None:
    """Map and persist a single CSV row, updating run counters and logs.

    Depending on `settings.skip_errors`, an error either aborts the whole
    run (status FAILED, exception re-raised) or is logged and counted as a
    failed row.
    """
    try:
        mapped_data = self._map_row(row)
        if mapped_data:
            importing = self.settings.importing
            if importing == "transactions":
                # Deduplicate before creating anything.
                if self.deduplication and self._check_duplicate_transaction(
                    mapped_data
                ):
                    self._increment_totals("skipped", 1)
                    self._log("info", f"Skipped duplicate row {row_number}")
                    return
                self._create_transaction(mapped_data)
            elif importing == "accounts":
                self._create_account(mapped_data)
            elif importing == "currencies":
                self._create_currency(mapped_data)
            elif importing == "categories":
                self._create_category(mapped_data)
            elif importing == "tags":
                self._create_tag(mapped_data)
            elif importing == "entities":
                self._create_entity(mapped_data)
            self._increment_totals("successful", value=1)
            self._log("info", f"Successfully processed row {row_number}")
        # NOTE: the "processed" counter is intentionally NOT incremented
        # here — the caller (_process_csv) already increments it once per
        # row, and the previous in-method increment double-counted.
    except Exception as e:
        if not self.settings.skip_errors:
            self._log("error", f"Fatal error processing row {row_number}: {str(e)}")
            self._update_status("FAILED")
            raise
        else:
            self._log("warning", f"Error processing row {row_number}: {str(e)}")
            self._increment_totals("failed", value=1)
            # Non-fatal by configuration: log the traceback at warning
            # level (the old message wrongly said "Fatal" on this path).
            logger.warning(f"Error processing row {row_number}", exc_info=e)
def _process_csv(self, file_path):
# First pass: count rows
with open(file_path, "r", encoding=self.settings.encoding) as csv_file:
# Skip specified number of rows
for _ in range(self.settings.skip_lines):
next(csv_file)
reader = csv.DictReader(csv_file, delimiter=self.settings.delimiter)
self._update_totals("total", value=sum(1 for _ in reader))
with open(file_path, "r", encoding=self.settings.encoding) as csv_file:
# Skip specified number of rows
for _ in range(self.settings.skip_lines):
next(csv_file)
if self.settings.skip_lines:
self._log("info", f"Skipped {self.settings.skip_lines} initial lines")
# Count total rows
self.import_run.total_rows = sum(1 for _ in reader)
csv_file.seek(0)
reader = csv.DictReader(csv_file, delimiter=self.settings.delimiter) reader = csv.DictReader(csv_file, delimiter=self.settings.delimiter)
self._log("info", f"Starting import with {self.import_run.total_rows} rows") self._log("info", f"Starting import with {self.import_run.total_rows} rows")
# Skip specified number of rows with transaction.atomic():
for _ in range(self.settings.skip_rows): for row_number, row in enumerate(reader, start=1):
next(reader) self._process_row(row, row_number)
self._increment_totals("processed", value=1)
if self.settings.skip_rows: def _validate_file_path(self, file_path: str) -> str:
self._log("info", f"Skipped {self.settings.skip_rows} initial rows") """
Validates that the file path is within the allowed temporary directory.
for row_number, row in enumerate(reader, start=1): Returns the absolute path.
try: """
transaction_data = self._map_row_to_transaction(row) abs_path = os.path.abspath(file_path)
if not abs_path.startswith(self.TEMP_DIR):
if transaction_data: raise ValueError(f"Invalid file path. File must be in {self.TEMP_DIR}")
if self.deduplication and self._check_duplicate_transaction( return abs_path
transaction_data
):
self.import_run.skipped_rows += 1
self._log("info", f"Skipped duplicate row {row_number}")
continue
self.import_run.transactions.add(transaction_data)
self.import_run.successful_rows += 1
self._log("debug", f"Successfully processed row {row_number}")
self.import_run.processed_rows += 1
self.import_run.save(
update_fields=[
"processed_rows",
"successful_rows",
"skipped_rows",
]
)
except Exception as e:
if not self.settings.skip_errors:
self._log(
"error",
f"Fatal error processing row {row_number}: {str(e)}",
)
self._update_status("FAILED")
raise
else:
self._log(
"warning", f"Error processing row {row_number}: {str(e)}"
)
self.import_run.failed_rows += 1
self.import_run.save(update_fields=["failed_rows"])
def process_file(self, file_path: str): def process_file(self, file_path: str):
# Validate and get absolute path
file_path = self._validate_file_path(file_path)
self._update_status("PROCESSING") self._update_status("PROCESSING")
self.import_run.started_at = timezone.now() self.import_run.started_at = timezone.now()
self.import_run.save(update_fields=["started_at"]) self.import_run.save(update_fields=["started_at"])
@@ -232,6 +530,12 @@ class ImportService:
finally: finally:
self._log("info", "Cleaning up temporary files") self._log("info", "Cleaning up temporary files")
default_storage.delete(file_path) try:
if os.path.exists(file_path):
os.remove(file_path)
self._log("info", f"Deleted temporary file: {file_path}")
except OSError as e:
self._log("warning", f"Failed to delete temporary file: {str(e)}")
self.import_run.finished_at = timezone.now() self.import_run.finished_at = timezone.now()
self.import_run.save(update_fields=["finished_at"]) self.import_run.save(update_fields=["finished_at"])