mirror of
https://github.com/eitchtee/WYGIWYH.git
synced 2026-02-25 00:44:52 +01:00
Compare commits
318 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bf9f8bbf3a | ||
|
|
a461a33dc2 | ||
|
|
1213ffebeb | ||
|
|
c5a352cf4d | ||
|
|
cfcca54aa6 | ||
|
|
234f8cd669 | ||
|
|
43184140f0 | ||
|
|
acc325c150 | ||
|
|
46eb471a34 | ||
|
|
6dc14c73d6 | ||
|
|
f942924e7c | ||
|
|
aa6019e0a9 | ||
|
|
9dfbd346bc | ||
|
|
73b1d36dfd | ||
|
|
3662fb030a | ||
|
|
a423ee1032 | ||
|
|
72eb59d24f | ||
|
|
1a0247e028 | ||
|
|
281a0fccda | ||
|
|
59ce50299a | ||
|
|
be89509beb | ||
|
|
80cded234d | ||
|
|
030bb63586 | ||
|
|
66e8fc5884 | ||
|
|
363047337d | ||
|
|
c7e32d1576 | ||
|
|
157e59a1d1 | ||
|
|
d9c505ac79 | ||
|
|
7274a13f3c | ||
|
|
5d64665ddd | ||
|
|
e0d92d15c8 | ||
|
|
48dd658627 | ||
|
|
80dbbd02f0 | ||
|
|
4b7ca61c29 | ||
|
|
b2f04ae1f9 | ||
|
|
f34d4b5e28 | ||
|
|
d2ebfbd615 | ||
|
|
812abbe488 | ||
|
|
9602a4affc | ||
|
|
bf548c0747 | ||
|
|
55ad2be08b | ||
|
|
2cd58c2464 | ||
|
|
4675ba9d56 | ||
|
|
a25c992d5c | ||
|
|
2eadfe99a5 | ||
|
|
11086a726f | ||
|
|
cd99b40b0a | ||
|
|
63aa51dc0d | ||
|
|
4708c5bc7e | ||
|
|
5a8462c050 | ||
|
|
6cac02e01f | ||
|
|
8d12ceeebb | ||
|
|
4681d3ca1d | ||
|
|
60ded03ea9 | ||
|
|
b20d137dc3 | ||
|
|
29ca6eed6c | ||
|
|
fa85303f36 | ||
|
|
a5f4f43678 | ||
|
|
d807bd5da3 | ||
|
|
85314fb749 | ||
|
|
c4d5e93a41 | ||
|
|
86f0c4365e | ||
|
|
202592b940 | ||
|
|
aea149bd13 | ||
|
|
411365f101 | ||
|
|
2008476021 | ||
|
|
53afe5b8eb | ||
|
|
6193c7a048 | ||
|
|
41f81d90d7 | ||
|
|
bf623cf16b | ||
|
|
ec213330cd | ||
|
|
7aedf524c6 | ||
|
|
04602b1964 | ||
|
|
15cfc4f300 | ||
|
|
3463c7c62c | ||
|
|
7b76c10093 | ||
|
|
7ad26a2e7b | ||
|
|
7706ca2d5f | ||
|
|
56198e93ce | ||
|
|
a74323f739 | ||
|
|
e4efde177b | ||
|
|
5871a03ee2 | ||
|
|
67af4430e1 | ||
|
|
696dcdf951 | ||
|
|
e35bad0e08 | ||
|
|
904f7cac22 | ||
|
|
ccd73963ca | ||
|
|
b5469b0413 | ||
|
|
dae848d951 | ||
|
|
45a33ad0c0 | ||
|
|
89e50b17bd | ||
|
|
ac54ba3da1 | ||
|
|
2da610f15e | ||
|
|
4ab6c4c6b6 | ||
|
|
68dbedd938 | ||
|
|
2800c53346 | ||
|
|
132547a074 | ||
|
|
61ed87dc45 | ||
|
|
96c1227c4f | ||
|
|
33f1ac1785 | ||
|
|
e9e94a8343 | ||
|
|
ba24a53853 | ||
|
|
4955fbde33 | ||
|
|
d04067a91d | ||
|
|
01333a439b | ||
|
|
d26907ea94 | ||
|
|
c98d9d3ce9 | ||
|
|
bfa4d3dea3 | ||
|
|
90323049eb | ||
|
|
b62122ed23 | ||
|
|
f74946cba7 | ||
|
|
585652064a | ||
|
|
ea6f61d5e4 | ||
|
|
e986f7d802 | ||
|
|
26b218ae51 | ||
|
|
19f0bc1034 | ||
|
|
47d34f3c27 | ||
|
|
046e02d506 | ||
|
|
92c7a29b6a | ||
|
|
d95e5f71cc | ||
|
|
992c518dab | ||
|
|
29aa1c9d2b | ||
|
|
1b3b7a583d | ||
|
|
2d22f961ad | ||
|
|
71551d7651 | ||
|
|
62d58d1be3 | ||
|
|
21917437f2 | ||
|
|
59acb14d05 | ||
|
|
050f794f2b | ||
|
|
a5958c0937 | ||
|
|
ee73ada5ae | ||
|
|
736a116685 | ||
|
|
6c03c7b4eb | ||
|
|
960e537709 | ||
|
|
e32285ce75 | ||
|
|
73e8fdbf04 | ||
|
|
d4c15da051 | ||
|
|
187b3174d2 | ||
|
|
c90ea7ef16 | ||
|
|
54713ecfe2 | ||
|
|
cf693aa0c3 | ||
|
|
3580f1b132 | ||
|
|
febd9a8ae7 | ||
|
|
3809f82b60 | ||
|
|
3c6b52462a | ||
|
|
cc8a4c97a9 | ||
|
|
99fbb5f7db | ||
|
|
3d61068ecf | ||
|
|
f6f06f4d65 | ||
|
|
56346c26ee | ||
|
|
23b74d73e5 | ||
|
|
17697dc565 | ||
|
|
e9bc35d9b2 | ||
|
|
d6fbb71f41 | ||
|
|
9a9cf75bcd | ||
|
|
d6a8658fe1 | ||
|
|
211963ea7d | ||
|
|
776068a438 | ||
|
|
621799f445 | ||
|
|
124d29e965 | ||
|
|
bf4d23f15e | ||
|
|
020dd74f80 | ||
|
|
c7d70a1748 | ||
|
|
1025b80dda | ||
|
|
1ae245fe01 | ||
|
|
46c5efb8a9 | ||
|
|
abb0993435 | ||
|
|
a9e7692f99 | ||
|
|
531571798a | ||
|
|
7282aa20ee | ||
|
|
13f9950afa | ||
|
|
672cc5ebc7 | ||
|
|
8045e2c73a | ||
|
|
7c042d9299 | ||
|
|
aba47f0eed | ||
|
|
2010ccc92d | ||
|
|
d73d6cbf22 | ||
|
|
e5a9b6e921 | ||
|
|
dbd9774681 | ||
|
|
5a93a907e1 | ||
|
|
e0e159166b | ||
|
|
6c7594ad14 | ||
|
|
d3ea0e43da | ||
|
|
dde75416ca | ||
|
|
c9b346b791 | ||
|
|
9896044a15 | ||
|
|
eb65eb4590 | ||
|
|
017c70e8b2 | ||
|
|
64b0830909 | ||
|
|
25d99cbece | ||
|
|
033f0e1b0d | ||
|
|
35027ee0ae | ||
|
|
91904e959b | ||
|
|
a6a85ae3a2 | ||
|
|
b0f53f45f9 | ||
|
|
0f60f8d486 | ||
|
|
efb207a109 | ||
|
|
95b1481dd5 | ||
|
|
8de340b68b | ||
|
|
ef15b85386 | ||
|
|
45d939237d | ||
|
|
6bf262e514 | ||
|
|
f9d9137336 | ||
|
|
b532521f27 | ||
|
|
1e06e2d34d | ||
|
|
a33fa5e184 | ||
|
|
a2453695d8 | ||
|
|
3e929d0433 | ||
|
|
185fc464a5 | ||
|
|
647c009525 | ||
|
|
ba75492dcc | ||
|
|
8312baaf45 | ||
|
|
4d346dc278 | ||
|
|
70ff7fab38 | ||
|
|
6947c6affd | ||
|
|
dcab83f936 | ||
|
|
b228e4ec26 | ||
|
|
4071a1301f | ||
|
|
5c9db10710 | ||
|
|
19c92e0014 | ||
|
|
6459f2eb46 | ||
|
|
7926e081ef | ||
|
|
ceefe7075f | ||
|
|
ad3230fd83 | ||
|
|
c89b07ed93 | ||
|
|
201ccea842 | ||
|
|
32ada488b4 | ||
|
|
794d11a355 | ||
|
|
67f8f5fe89 | ||
|
|
9ac69fd92a | ||
|
|
069f1b450c | ||
|
|
2f388af928 | ||
|
|
beeb0579ce | ||
|
|
a8666da57b | ||
|
|
835316d0f3 | ||
|
|
f5feeb9617 | ||
|
|
09e380a480 | ||
|
|
3080df9b66 | ||
|
|
ebc41a8049 | ||
|
|
635628e30e | ||
|
|
819a58ac06 | ||
|
|
d433375522 | ||
|
|
c0150f71a8 | ||
|
|
6119698d38 | ||
|
|
f5ae231601 | ||
|
|
972d23abbd | ||
|
|
9a514a8a69 | ||
|
|
7325231548 | ||
|
|
570657371a | ||
|
|
67da60b5b0 | ||
|
|
84c047c5ab | ||
|
|
23f5d09bec | ||
|
|
2a19075e23 | ||
|
|
7f231175b2 | ||
|
|
062e84f864 | ||
|
|
5521eb20bf | ||
|
|
627b5d250b | ||
|
|
195a8a68d6 | ||
|
|
daf1f68b82 | ||
|
|
dd24fd56d3 | ||
|
|
7a2acb6497 | ||
|
|
9c339faa72 | ||
|
|
02376ad02b | ||
|
|
b53a4a0286 | ||
|
|
a1f618434b | ||
|
|
7b5be29f0d | ||
|
|
56a73b181a | ||
|
|
865618e054 | ||
|
|
9e912b2736 | ||
|
|
da7680e70f | ||
|
|
ab594eb511 | ||
|
|
cffaaa369a | ||
|
|
5f414e82ee | ||
|
|
f3bcef534e | ||
|
|
d140ff5b70 | ||
|
|
7eceacfe68 | ||
|
|
038438fba7 | ||
|
|
ee98a5ef12 | ||
|
|
28b12faaf0 | ||
|
|
d0f2742637 | ||
|
|
9c55dac866 | ||
|
|
e6d8b548b7 | ||
|
|
4f8c2215c1 | ||
|
|
851b34f07a | ||
|
|
546ed5c6af | ||
|
|
04ae7337f5 | ||
|
|
a3a8791e96 | ||
|
|
63069f0ec9 | ||
|
|
32b522dad2 | ||
|
|
0c20a079e3 | ||
|
|
7c9697f683 | ||
|
|
15d04230ae | ||
|
|
ecc09ca6a6 | ||
|
|
cd753c5dd5 | ||
|
|
a3b9952f80 | ||
|
|
e93969c035 | ||
|
|
6ec5b5df1e | ||
|
|
93e7adeea8 | ||
|
|
37b5a43c1f | ||
|
|
87a07c25d1 | ||
|
|
9e27fef5e5 | ||
|
|
2cbba53e06 | ||
|
|
d9e8be7efb | ||
|
|
7dc9ef9950 | ||
|
|
00e83cf6a2 | ||
|
|
039242b48a | ||
|
|
94e2bdf93d | ||
|
|
79b387ce60 | ||
|
|
43eb87d3ba | ||
|
|
0110220b72 | ||
|
|
f5c86f3d97 | ||
|
|
7b7f58d34d | ||
|
|
86112931d9 | ||
|
|
e6e0e4caea | ||
|
|
942154480e | ||
|
|
467131d9f1 | ||
|
|
fee1db8660 | ||
|
|
4f7fc1c9c8 |
@@ -10,6 +10,11 @@ SECRET_KEY=<GENERATE A SAFE SECRET KEY AND PLACE IT HERE>
|
||||
DJANGO_ALLOWED_HOSTS=localhost 127.0.0.1 [::1]
|
||||
OUTBOUND_PORT=9005
|
||||
|
||||
# Uncomment these variables to automatically create an admin account using these credentials on startup.
|
||||
# After your first successfull login you can remove these variables from your file for safety reasons.
|
||||
#ADMIN_EMAIL=<ENTER YOUR EMAIL>
|
||||
#ADMIN_PASSWORD=<YOUR SAFE PASSWORD>
|
||||
|
||||
SQL_DATABASE=wygiwyh
|
||||
SQL_USER=wygiwyh
|
||||
SQL_PASSWORD=<INSERT A SAFE PASSWORD HERE>
|
||||
|
||||
4
.github/FUNDING.yml
vendored
Normal file
4
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: eitchtee
|
||||
custom: ["https://www.paypal.com/donate/?hosted_button_id=FFWM4W9NQDMM6"]
|
||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -20,6 +20,10 @@ on:
|
||||
env:
|
||||
IMAGE_NAME: wygiwyh
|
||||
|
||||
concurrency:
|
||||
group: release
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
72
.github/workflows/translations.yml
vendored
Normal file
72
.github/workflows/translations.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
name: Django Translation Update
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
# Add manual trigger
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
reason:
|
||||
description: 'Reason for running'
|
||||
required: false
|
||||
default: 'Manual update of translation files'
|
||||
|
||||
# Ensure only one translation job runs at a time
|
||||
concurrency:
|
||||
group: django-translations
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
update-translations:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
# Skip on PRs from forks (which don't have write permissions)
|
||||
# Allow manual runs and pushes to main
|
||||
if: github.event_name == 'workflow_dispatch' || github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository)
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
token: ${{ secrets.PAT }}
|
||||
ref: ${{ github.head_ref }}
|
||||
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Install gettext
|
||||
run: sudo apt-get install -y gettext
|
||||
|
||||
- name: Run makemessages
|
||||
run: |
|
||||
cd app
|
||||
python manage.py makemessages -a
|
||||
|
||||
- name: Check for changes
|
||||
id: check_changes
|
||||
run: |
|
||||
if git diff --exit-code --quiet app/locale/; then
|
||||
echo "No translation changes detected"
|
||||
else
|
||||
echo "changes_detected=true" >> $GITHUB_OUTPUT
|
||||
echo "Translation changes detected"
|
||||
fi
|
||||
|
||||
- name: Commit translation files
|
||||
if: steps.check_changes.outputs.changes_detected == 'true'
|
||||
uses: stefanzweifel/git-auto-commit-action@v5
|
||||
with:
|
||||
push_options: --force
|
||||
commit_message: |
|
||||
chore(locale): update translation files
|
||||
|
||||
[skip ci] Automatically generated by Django makemessages workflow
|
||||
file_pattern: "app/locale/**/*.po"
|
||||
33
README.md
33
README.md
@@ -13,6 +13,7 @@
|
||||
<a href="#key-features">Features</a> •
|
||||
<a href="#how-to-use">Usage</a> •
|
||||
<a href="#how-it-works">How</a> •
|
||||
<a href="#help-us-translate-wygiwyh">Translate</a> •
|
||||
<a href="#caveats-and-warnings">Caveats and Warnings</a> •
|
||||
<a href="#built-with">Built with</a>
|
||||
</p>
|
||||
@@ -50,6 +51,17 @@ Frustrated by the lack of comprehensive options, I set out to build **WYGIWYH**
|
||||
* **Built-in Dollar-Cost Average (DCA) tracker**: Essential for tracking recurring investments, especially for crypto and stocks.
|
||||
* **API support for automation**: Seamlessly integrate with existing services to synchronize transactions.
|
||||
|
||||
# Demo
|
||||
|
||||
You can try WYGIWYH on [wygiwyh-demo.herculino.com](https://wygiwyh-demo.herculino.com/) with the credentials below:
|
||||
|
||||
> [!NOTE]
|
||||
> E-mail: `demo@demo.com`
|
||||
>
|
||||
> Password: `wygiwyhdemo`
|
||||
|
||||
Keep in mind that **any data you add will be wiped in 24 hours or less**. And that **most automation features like the API, Rules, Automatic Exchange Rates and Import/Export are disabled**.
|
||||
|
||||
# How To Use
|
||||
|
||||
To run this application, you'll need [Docker](https://docs.docker.com/engine/install/) with [docker-compose](https://docs.docker.com/compose/install/).
|
||||
@@ -75,12 +87,12 @@ $ nano .env # or any other editor you want to use
|
||||
# Run the app
|
||||
$ docker compose up -d
|
||||
|
||||
# Create the first admin account
|
||||
# Create the first admin account. This isn't required if you set the enviroment variables: ADMIN_EMAIL and ADMIN_PASSWORD.
|
||||
$ docker compose exec -it web python manage.py createsuperuser
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> If you're using Unraid, you don't need to follow these steps, use the app on the store. Make sure to read the [Unraid section](#unraid) and [Enviroment Variables](#enviroment-variables) for an explanation of all available variables
|
||||
> If you're using Unraid, you don't need to follow these steps, use the app on the store. Make sure to read the [Unraid section](#unraid) and [Environment Variables](#environment-variables) for an explanation of all available variables
|
||||
|
||||
## Running locally
|
||||
|
||||
@@ -110,29 +122,40 @@ WYGIWYH is available on the Unraid Store. You'll need to provision your own post
|
||||
|
||||
To create the first user, open the container's console using Unraid's UI, by clicking on WYGIWYH icon on the Docker page and selecting `Console`, then type `python manage.py createsuperuser`, you'll them be prompted to input your e-mail and password.
|
||||
|
||||
## Enviroment Variables
|
||||
## Environment Variables
|
||||
|
||||
| variable | type | default | explanation |
|
||||
|-------------------------------|-------------|-----------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| DJANGO_ALLOWED_HOSTS | string | localhost 127.0.0.1 | A list of space separated domains and IPs representing the host/domain names that WYGIWYH site can serve. [Click here](https://docs.djangoproject.com/en/5.1/ref/settings/#allowed-hosts) for more details |
|
||||
| HTTPS_ENABLED | true\|false | false | Whether to use secure cookies. If this is set to true, the cookie will be marked as “secure”, which means browsers may ensure that the cookie is only sent under an HTTPS connection |
|
||||
| URL | string | http://localhost http://127.0.0.1 | A list of space separated domains and IPs (with the protocol) representing the trusted origins for unsafe requests (e.g. POST). [Click here](https://docs.djangoproject.com/en/5.1/ref/settings/#csrf-trusted-origins ) for more details |
|
||||
| URL | string | http://localhost http://127.0.0.1 | A list of space separated domains and IPs (with the protocol) representing the trusted origins for unsafe requests (e.g. POST). [Click here](https://docs.djangoproject.com/en/5.1/ref/settings/#csrf-trusted-origins ) for more details |
|
||||
| SECRET_KEY | string | "" | This is used to provide cryptographic signing, and should be set to a unique, unpredictable value. |
|
||||
| DEBUG | true\|false | false | Turns DEBUG mode on or off, this is useful to gather more data about possible errors you're having. Don't use in production. |
|
||||
| SQL_DATABASE | string | None *required | The name of your postgres database |
|
||||
| SQL_USER | string | user | The username used to connect to your postgres database |
|
||||
| SQL_PASSWORD | string | password | The password used to connect to your postgres database |
|
||||
| SQL_HOST | string | localhost | The adress used to connect to your postgres database |
|
||||
| SQL_HOST | string | localhost | The address used to connect to your postgres database |
|
||||
| SQL_PORT | string | 5432 | The port used to connect to your postgres database |
|
||||
| SESSION_EXPIRY_TIME | int | 2678400 (31 days) | The age of session cookies, in seconds. E.g. how long you will stay logged in |
|
||||
| ENABLE_SOFT_DELETE | true\|false | false | Whether to enable transactions soft delete, if enabled, deleted transactions will remain in the database. Useful for imports and avoiding duplicate entries. |
|
||||
| KEEP_DELETED_TRANSACTIONS_FOR | int | 365 | Time in days to keep soft deleted transactions for. If 0, will keep all transactions indefinitely. Only works if ENABLE_SOFT_DELETE is true. |
|
||||
| TASK_WORKERS | int | 1 | How many workers to have for async tasks. One should be enough for most use cases |
|
||||
| DEMO | true\|false | false | If demo mode is enabled. |
|
||||
| ADMIN_EMAIL | string | None | Automatically creates an admin account with this email. Must have `ADMIN_PASSWORD` also set. |
|
||||
| ADMIN_PASSWORD | string | None | Automatically creates an admin account with this password. Must have `ADMIN_EMAIL` also set. |
|
||||
|
||||
# How it works
|
||||
|
||||
Check out our [Wiki](https://github.com/eitchtee/WYGIWYH/wiki) for more information.
|
||||
|
||||
# Help us translate WYGIWYH!
|
||||
<a href="https://translations.herculino.com/engage/wygiwyh/">
|
||||
<img src="https://translations.herculino.com/widget/wygiwyh/open-graph.png" alt="Translation status" />
|
||||
</a>
|
||||
|
||||
> [!NOTE]
|
||||
> Login with your github account
|
||||
|
||||
# Caveats and Warnings
|
||||
|
||||
- I'm not an accountant, some terms and even calculations might be wrong. Make sure to open an issue if you see anything that could be improved.
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
import logging
|
||||
|
||||
|
||||
class ProcrastinateFilter(logging.Filter):
|
||||
# from https://github.com/madzak/python-json-logger/blob/master/src/pythonjsonlogger/jsonlogger.py#L19
|
||||
_reserved_log_keys = frozenset(
|
||||
"""args asctime created exc_info exc_text filename
|
||||
funcName levelname levelno lineno module msecs message msg name pathname
|
||||
process processName relativeCreated stack_info thread threadName""".split()
|
||||
)
|
||||
|
||||
def filter(self, record: logging.LogRecord):
|
||||
record.procrastinate = {}
|
||||
for key, value in vars(record).items():
|
||||
if not key.startswith("_") and key not in self._reserved_log_keys | {
|
||||
"procrastinate"
|
||||
}:
|
||||
record.procrastinate[key] = value # type: ignore
|
||||
return True
|
||||
@@ -55,6 +55,7 @@ INSTALLED_APPS = [
|
||||
"hijack",
|
||||
"hijack.contrib.admin",
|
||||
"django_filters",
|
||||
"import_export",
|
||||
"apps.users.apps.UsersConfig",
|
||||
"procrastinate.contrib.django",
|
||||
"apps.transactions.apps.TransactionsConfig",
|
||||
@@ -63,6 +64,7 @@ INSTALLED_APPS = [
|
||||
"apps.common.apps.CommonConfig",
|
||||
"apps.net_worth.apps.NetWorthConfig",
|
||||
"apps.import_app.apps.ImportConfig",
|
||||
"apps.export_app.apps.ExportConfig",
|
||||
"apps.api.apps.ApiConfig",
|
||||
"cachalot",
|
||||
"rest_framework",
|
||||
@@ -75,6 +77,7 @@ INSTALLED_APPS = [
|
||||
]
|
||||
|
||||
MIDDLEWARE = [
|
||||
"django_browser_reload.middleware.BrowserReloadMiddleware",
|
||||
"apps.common.middleware.thread_local.ThreadLocalMiddleware",
|
||||
"debug_toolbar.middleware.DebugToolbarMiddleware",
|
||||
"django.middleware.security.SecurityMiddleware",
|
||||
@@ -87,7 +90,6 @@ MIDDLEWARE = [
|
||||
"apps.common.middleware.localization.LocalizationMiddleware",
|
||||
"django.contrib.messages.middleware.MessageMiddleware",
|
||||
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
||||
"django_browser_reload.middleware.BrowserReloadMiddleware",
|
||||
"hijack.middleware.HijackUserMiddleware",
|
||||
]
|
||||
|
||||
@@ -161,9 +163,105 @@ AUTH_USER_MODEL = "users.User"
|
||||
|
||||
LANGUAGE_CODE = "en"
|
||||
LANGUAGES = (
|
||||
("af", "Afrikaans"),
|
||||
("ar", "العربية"),
|
||||
("ar-dz", "العربية (الجزائر)"), # Algerian Arabic often uses the base name + region
|
||||
("ast", "Asturianu"),
|
||||
("az", "Azərbaycan"),
|
||||
("bg", "Български"),
|
||||
("be", "Беларуская"),
|
||||
("bn", "বাংলা"),
|
||||
("br", "Brezhoneg"),
|
||||
("bs", "Bosanski"),
|
||||
("ca", "Català"),
|
||||
("ckb", "کوردیی ناوەندی"), # Central Kurdish (Sorani)
|
||||
("cs", "Čeština"),
|
||||
("cy", "Cymraeg"),
|
||||
("da", "Dansk"),
|
||||
("de", "Deutsch"),
|
||||
("dsb", "Dolnoserbšćina"),
|
||||
("el", "Ελληνικά"),
|
||||
("en", "English"),
|
||||
("en-au", "English (Australia)"),
|
||||
("en-gb", "English (UK)"),
|
||||
("eo", "Esperanto"),
|
||||
("es", "Español"),
|
||||
("es-ar", "Español (Argentina)"),
|
||||
("es-co", "Español (Colombia)"),
|
||||
("es-mx", "Español (México)"),
|
||||
("es-ni", "Español (Nicaragua)"),
|
||||
("es-ve", "Español (Venezuela)"),
|
||||
("et", "Eesti"),
|
||||
("eu", "Euskara"),
|
||||
("fa", "فارسی"),
|
||||
("fi", "Suomi"),
|
||||
("fr", "Français"),
|
||||
("fy", "Frysk"),
|
||||
("ga", "Gaeilge"),
|
||||
("gd", "Gàidhlig"),
|
||||
("gl", "Galego"),
|
||||
("he", "עברית"),
|
||||
("hi", "हिन्दी"),
|
||||
("hr", "Hrvatski"),
|
||||
("hsb", "Hornjoserbšćina"),
|
||||
("hu", "Magyar"),
|
||||
("hy", "Հայերեն"),
|
||||
("ia", "Interlingua"),
|
||||
("id", "Bahasa Indonesia"),
|
||||
("ig", "Igbo"),
|
||||
("io", "Ido"),
|
||||
("is", "Íslenska"),
|
||||
("it", "Italiano"),
|
||||
("ja", "日本語"),
|
||||
("ka", "ქართული"),
|
||||
("kab", "Taqbaylit"),
|
||||
("kk", "Қазақша"),
|
||||
("km", "ខ្មែរ"),
|
||||
("kn", "ಕನ್ನಡ"),
|
||||
("ko", "한국어"),
|
||||
("ky", "Кыргызча"),
|
||||
("lb", "Lëtzebuergesch"),
|
||||
("lt", "Lietuvių"),
|
||||
("lv", "Latviešu"),
|
||||
("mk", "Македонски"),
|
||||
("ml", "മലയാളം"),
|
||||
("mn", "Монгол"),
|
||||
("mr", "मराठी"),
|
||||
("ms", "Bahasa Melayu"),
|
||||
("my", "မြန်မာဘာသာ"),
|
||||
("nb", "Norsk (Bokmål)"),
|
||||
("ne", "नेपाली"),
|
||||
("nl", "Nederlands"),
|
||||
("nn", "Norsk (Nynorsk)"),
|
||||
("os", "Ирон"), # Ossetic
|
||||
("pa", "ਪੰਜਾਬੀ"),
|
||||
("pl", "Polski"),
|
||||
("pt", "Português"),
|
||||
("pt-br", "Português (Brasil)"),
|
||||
("ro", "Română"),
|
||||
("ru", "Русский"),
|
||||
("sk", "Slovenčina"),
|
||||
("sl", "Slovenščina"),
|
||||
("sq", "Shqip"),
|
||||
("sr", "Српски"),
|
||||
("sr-latn", "Srpski (Latinica)"),
|
||||
("sv", "Svenska"),
|
||||
("sw", "Kiswahili"),
|
||||
("ta", "தமிழ்"),
|
||||
("te", "తెలుగు"),
|
||||
("tg", "Тоҷикӣ"),
|
||||
("th", "ไทย"),
|
||||
("tk", "Türkmençe"),
|
||||
("tr", "Türkçe"),
|
||||
("tt", "Татарча"),
|
||||
("udm", "Удмурт"),
|
||||
("ug", "ئۇيغۇرچە"),
|
||||
("uk", "Українська"),
|
||||
("ur", "اردو"),
|
||||
("uz", "Oʻzbekcha"),
|
||||
("vi", "Tiếng Việt"),
|
||||
("zh-hans", "简体中文"),
|
||||
("zh-hant", "繁體中文"),
|
||||
)
|
||||
|
||||
TIME_ZONE = os.getenv("TZ", "UTC")
|
||||
@@ -258,7 +356,10 @@ if DEBUG:
|
||||
REST_FRAMEWORK = {
|
||||
# Use Django's standard `django.contrib.auth` permissions,
|
||||
# or allow read-only access for unauthenticated users.
|
||||
"DEFAULT_PERMISSION_CLASSES": ["rest_framework.permissions.DjangoModelPermissions"],
|
||||
"DEFAULT_PERMISSION_CLASSES": [
|
||||
"apps.api.permissions.NotInDemoMode",
|
||||
"rest_framework.permissions.DjangoModelPermissions",
|
||||
],
|
||||
"DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination",
|
||||
"PAGE_SIZE": 10,
|
||||
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
|
||||
@@ -277,27 +378,16 @@ if "procrastinate" in sys.argv:
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"procrastinate": {
|
||||
"format": "[%(asctime)s] - %(levelname)s - %(name)s - %(message)s -> %(procrastinate)s",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
},
|
||||
"standard": {
|
||||
"format": "[%(asctime)s] - %(levelname)s - %(name)s - %(message)s",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
},
|
||||
},
|
||||
"filters": {
|
||||
"procrastinate": {
|
||||
"()": "WYGIWYH.logs.ProcrastinateFilter.ProcrastinateFilter",
|
||||
"name": "procrastinate",
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"procrastinate": {
|
||||
"level": "INFO",
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "procrastinate",
|
||||
"filters": ["procrastinate"],
|
||||
"formatter": "standard",
|
||||
},
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
@@ -308,10 +398,10 @@ if "procrastinate" in sys.argv:
|
||||
"loggers": {
|
||||
"procrastinate": {
|
||||
"handlers": ["procrastinate"],
|
||||
"propagate": True,
|
||||
"propagate": False,
|
||||
},
|
||||
"root": {
|
||||
"handlers": None,
|
||||
"handlers": ["console"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
@@ -402,3 +492,4 @@ PWA_SERVICE_WORKER_PATH = BASE_DIR / "templates" / "pwa" / "serviceworker.js"
|
||||
ENABLE_SOFT_DELETE = os.getenv("ENABLE_SOFT_DELETE", "false").lower() == "true"
|
||||
KEEP_DELETED_TRANSACTIONS_FOR = int(os.getenv("KEEP_DELETED_ENTRIES_FOR", "365"))
|
||||
APP_VERSION = os.getenv("APP_VERSION", "unknown")
|
||||
DEMO = os.getenv("DEMO", "false").lower() == "true"
|
||||
|
||||
@@ -49,4 +49,6 @@ urlpatterns = [
|
||||
path("", include("apps.dca.urls")),
|
||||
path("", include("apps.mini_tools.urls")),
|
||||
path("", include("apps.import_app.urls")),
|
||||
path("", include("apps.export_app.urls")),
|
||||
path("", include("apps.insights.urls")),
|
||||
]
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
from django.contrib import admin
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.common.admin import SharedObjectModelAdmin
|
||||
|
||||
|
||||
admin.site.register(Account)
|
||||
@admin.register(Account)
|
||||
class AccountModelAdmin(SharedObjectModelAdmin):
|
||||
pass
|
||||
|
||||
|
||||
@admin.register(AccountGroup)
|
||||
class AccountGroupModelAdmin(SharedObjectModelAdmin):
|
||||
pass
|
||||
|
||||
@@ -77,6 +77,8 @@ class AccountForm(forms.ModelForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.fields["group"].queryset = AccountGroup.objects.all()
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.form_method = "post"
|
||||
@@ -151,5 +153,11 @@ class AccountBalanceForm(forms.Form):
|
||||
decimal_places=self.currency_decimal_places
|
||||
)
|
||||
|
||||
self.fields["category"].queryset = TransactionCategory.objects.filter(
|
||||
active=True
|
||||
)
|
||||
|
||||
self.fields["tags"].queryset = TransactionTag.objects.filter(active=True)
|
||||
|
||||
|
||||
AccountBalanceFormSet = forms.formset_factory(AccountBalanceForm, extra=0)
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
# Generated by Django 5.1.6 on 2025-03-04 15:07
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('accounts', '0008_alter_account_name'),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='account',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='owned_accounts', to=settings.AUTH_USER_MODEL, verbose_name='Owner'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='account',
|
||||
name='shared_with',
|
||||
field=models.ManyToManyField(blank=True, related_name='shared_accounts', to=settings.AUTH_USER_MODEL, verbose_name='Shared With'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='accountgroup',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='owned_account_groups', to=settings.AUTH_USER_MODEL, verbose_name='Owner'),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,46 @@
|
||||
# Generated by Django 5.1.6 on 2025-03-05 02:42
|
||||
|
||||
import django.db.models.manager
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('accounts', '0009_account_owner_account_shared_with_accountgroup_owner'),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelManagers(
|
||||
name='account',
|
||||
managers=[
|
||||
('all_objects', django.db.models.manager.Manager()),
|
||||
],
|
||||
),
|
||||
migrations.AlterModelManagers(
|
||||
name='accountgroup',
|
||||
managers=[
|
||||
('all_objects', django.db.models.manager.Manager()),
|
||||
],
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='account',
|
||||
name='name',
|
||||
field=models.CharField(max_length=255, verbose_name='Name'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='accountgroup',
|
||||
name='name',
|
||||
field=models.CharField(max_length=255, verbose_name='Name'),
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name='account',
|
||||
unique_together={('owner', 'name')},
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name='accountgroup',
|
||||
unique_together={('owner', 'name')},
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,26 @@
|
||||
# Generated by Django 5.1.6 on 2025-03-05 04:19
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('accounts', '0010_alter_account_managers_alter_accountgroup_managers_and_more'),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='account',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='owned_accounts', to=settings.AUTH_USER_MODEL, verbose_name='Owner'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='accountgroup',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='owned_account_groups', to=settings.AUTH_USER_MODEL, verbose_name='Owner'),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,56 @@
|
||||
# Generated by Django 5.1.6 on 2025-03-05 23:46
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('accounts', '0011_alter_account_owner_alter_accountgroup_owner'),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelManagers(
|
||||
name='account',
|
||||
managers=[
|
||||
],
|
||||
),
|
||||
migrations.AlterModelManagers(
|
||||
name='accountgroup',
|
||||
managers=[
|
||||
],
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='account',
|
||||
name='visibility',
|
||||
field=models.CharField(choices=[('private', 'Private'), ('shared', 'Shared'), ('public', 'Public')], default='private', max_length=10),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='accountgroup',
|
||||
name='shared_with',
|
||||
field=models.ManyToManyField(blank=True, related_name='%(class)s_shared', to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='accountgroup',
|
||||
name='visibility',
|
||||
field=models.CharField(choices=[('private', 'Private'), ('shared', 'Shared'), ('public', 'Public')], default='private', max_length=10),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='account',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='%(class)s_owned', to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='account',
|
||||
name='shared_with',
|
||||
field=models.ManyToManyField(blank=True, related_name='%(class)s_shared', to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='accountgroup',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='%(class)s_owned', to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,23 @@
|
||||
# Generated by Django 5.1.6 on 2025-03-06 01:47
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('accounts', '0012_alter_account_managers_alter_accountgroup_managers_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='account',
|
||||
name='visibility',
|
||||
field=models.CharField(choices=[('private', 'Private'), ('public', 'Public')], default='private', max_length=10),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='accountgroup',
|
||||
name='visibility',
|
||||
field=models.CharField(choices=[('private', 'Private'), ('public', 'Public')], default='private', max_length=10),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,21 @@
|
||||
# Generated by Django 5.1.7 on 2025-03-09 21:54
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('accounts', '0013_alter_account_visibility_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name='account',
|
||||
options={'ordering': ['name', 'id'], 'verbose_name': 'Account', 'verbose_name_plural': 'Accounts'},
|
||||
),
|
||||
migrations.AlterModelOptions(
|
||||
name='accountgroup',
|
||||
options={'ordering': ['name', 'id'], 'verbose_name': 'Account Group', 'verbose_name_plural': 'Account Groups'},
|
||||
),
|
||||
]
|
||||
@@ -1,24 +1,32 @@
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.transactions.models import Transaction
|
||||
from apps.common.models import SharedObject, SharedObjectManager
|
||||
|
||||
|
||||
class AccountGroup(models.Model):
|
||||
name = models.CharField(max_length=255, verbose_name=_("Name"), unique=True)
|
||||
class AccountGroup(SharedObject):
|
||||
name = models.CharField(max_length=255, verbose_name=_("Name"))
|
||||
|
||||
objects = SharedObjectManager()
|
||||
all_objects = models.Manager() # Unfiltered manager
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Account Group")
|
||||
verbose_name_plural = _("Account Groups")
|
||||
db_table = "account_groups"
|
||||
unique_together = (("owner", "name"),)
|
||||
ordering = ["name", "id"]
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Account(models.Model):
|
||||
name = models.CharField(max_length=255, verbose_name=_("Name"), unique=True)
|
||||
class Account(SharedObject):
|
||||
name = models.CharField(max_length=255, verbose_name=_("Name"))
|
||||
group = models.ForeignKey(
|
||||
AccountGroup,
|
||||
on_delete=models.SET_NULL,
|
||||
@@ -55,9 +63,14 @@ class Account(models.Model):
|
||||
help_text=_("Archived accounts don't show up nor count towards your net worth"),
|
||||
)
|
||||
|
||||
objects = SharedObjectManager()
|
||||
all_objects = models.Manager() # Unfiltered manager
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Account")
|
||||
verbose_name_plural = _("Accounts")
|
||||
unique_together = (("owner", "name"),)
|
||||
ordering = ["name", "id"]
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@@ -1,19 +1,33 @@
|
||||
from django.test import TestCase
|
||||
from django.test import TestCase, Client
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.contrib.auth.models import User
|
||||
from django.db import IntegrityError, models
|
||||
from django.utils import timezone
|
||||
from django.urls import reverse
|
||||
from decimal import Decimal
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.accounts.forms import AccountForm
|
||||
from apps.transactions.models import Transaction, TransactionCategory
|
||||
|
||||
|
||||
class AccountTests(TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
self.owner1 = User.objects.create_user(username='testowner', password='password123')
|
||||
self.client = Client()
|
||||
self.client.login(username='testowner', password='password123')
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.exchange_currency = Currency.objects.create(
|
||||
self.eur = Currency.objects.create(
|
||||
code="EUR", name="Euro", decimal_places=2, prefix="€ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group", owner=self.owner1)
|
||||
self.reconciliation_category = TransactionCategory.objects.create(name='Reconciliation', owner=self.owner1, type='INFO')
|
||||
|
||||
|
||||
def test_account_creation(self):
|
||||
"""Test basic account creation"""
|
||||
@@ -35,7 +49,262 @@ class AccountTests(TestCase):
|
||||
"""Test account creation with exchange currency"""
|
||||
account = Account.objects.create(
|
||||
name="Exchange Account",
|
||||
owner=self.owner1, # Added owner
|
||||
group=self.account_group, # Added group
|
||||
currency=self.currency,
|
||||
exchange_currency=self.exchange_currency,
|
||||
exchange_currency=self.eur, # Changed to self.eur
|
||||
)
|
||||
self.assertEqual(account.exchange_currency, self.exchange_currency)
|
||||
self.assertEqual(account.exchange_currency, self.eur) # Changed to self.eur
|
||||
|
||||
def test_account_archiving(self):
|
||||
"""Test archiving and unarchiving an account"""
|
||||
account = Account.objects.create(
|
||||
name="Archivable Account",
|
||||
owner=self.owner1, # Added owner
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
is_asset=True, # Assuming default, can be anything for this test
|
||||
is_archived=False,
|
||||
)
|
||||
self.assertFalse(account.is_archived, "Account should initially be unarchived")
|
||||
|
||||
# Archive the account
|
||||
account.is_archived = True
|
||||
account.save()
|
||||
|
||||
archived_account = Account.objects.get(pk=account.pk)
|
||||
self.assertTrue(archived_account.is_archived, "Account should be archived")
|
||||
|
||||
# Unarchive the account
|
||||
archived_account.is_archived = False
|
||||
archived_account.save()
|
||||
|
||||
unarchived_account = Account.objects.get(pk=account.pk)
|
||||
self.assertFalse(unarchived_account.is_archived, "Account should be unarchived")
|
||||
|
||||
def test_account_exchange_currency_cannot_be_same_as_currency(self):
|
||||
"""Test that exchange_currency cannot be the same as currency."""
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
account = Account(
|
||||
name="Same Currency Account",
|
||||
owner=self.owner1, # Added owner
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
exchange_currency=self.currency, # Same as currency
|
||||
)
|
||||
account.full_clean()
|
||||
self.assertIn('exchange_currency', cm.exception.error_dict)
|
||||
# To check for a specific message (optional, might make test brittle):
|
||||
# self.assertTrue(any("cannot be the same as the main currency" in e.message
|
||||
# for e in cm.exception.error_dict['exchange_currency']))
|
||||
|
||||
def test_account_name_unique_per_owner(self):
|
||||
"""Test that account name is unique per owner."""
|
||||
owner1 = User.objects.create_user(username='owner1', password='password123')
|
||||
owner2 = User.objects.create_user(username='owner2', password='password123')
|
||||
|
||||
# Initial account for self.owner1 (owner1 from setUp)
|
||||
Account.objects.create(
|
||||
name="Unique Name Test",
|
||||
owner=self.owner1, # Changed to self.owner1
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
)
|
||||
|
||||
# Attempt to create another account with the same name and self.owner1 - should fail
|
||||
with self.assertRaises(IntegrityError):
|
||||
Account.objects.create(
|
||||
name="Unique Name Test",
|
||||
owner=self.owner1, # Changed to self.owner1
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
)
|
||||
|
||||
# Create account with the same name but for owner2 - should succeed
|
||||
try:
|
||||
Account.objects.create(
|
||||
name="Unique Name Test",
|
||||
owner=owner2, # owner2 is locally defined here, that's fine for this test
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
)
|
||||
except IntegrityError:
|
||||
self.fail("Creating account with same name but different owner failed unexpectedly.")
|
||||
|
||||
# Create account with a different name for self.owner1 - should succeed
|
||||
try:
|
||||
Account.objects.create(
|
||||
name="Another Name Test",
|
||||
owner=self.owner1, # Changed to self.owner1
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
)
|
||||
except IntegrityError:
|
||||
self.fail("Creating account with different name for the same owner failed unexpectedly.")
|
||||
|
||||
def test_account_form_valid_data(self):
|
||||
"""Test AccountForm with valid data."""
|
||||
form_data = {
|
||||
'name': 'Form Test Account',
|
||||
'group': self.account_group.pk,
|
||||
'currency': self.currency.pk,
|
||||
'exchange_currency': self.eur.pk,
|
||||
'is_asset': True,
|
||||
'is_archived': False,
|
||||
'description': 'A valid test account from form.'
|
||||
}
|
||||
form = AccountForm(data=form_data)
|
||||
self.assertTrue(form.is_valid(), form.errors.as_text())
|
||||
|
||||
account = form.save(commit=False)
|
||||
account.owner = self.owner1
|
||||
account.save()
|
||||
|
||||
self.assertEqual(account.name, 'Form Test Account')
|
||||
self.assertEqual(account.owner, self.owner1)
|
||||
self.assertEqual(account.group, self.account_group)
|
||||
self.assertEqual(account.currency, self.currency)
|
||||
self.assertEqual(account.exchange_currency, self.eur)
|
||||
self.assertTrue(account.is_asset)
|
||||
self.assertFalse(account.is_archived)
|
||||
|
||||
def test_account_form_missing_name(self):
|
||||
"""Test AccountForm with missing name."""
|
||||
form_data = {
|
||||
'group': self.account_group.pk,
|
||||
'currency': self.currency.pk,
|
||||
}
|
||||
form = AccountForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('name', form.errors)
|
||||
|
||||
def test_account_form_exchange_currency_same_as_currency(self):
|
||||
"""Test AccountForm where exchange_currency is the same as currency."""
|
||||
form_data = {
|
||||
'name': 'Same Currency Form Account',
|
||||
'group': self.account_group.pk,
|
||||
'currency': self.currency.pk,
|
||||
'exchange_currency': self.currency.pk, # Same as currency
|
||||
}
|
||||
form = AccountForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('exchange_currency', form.errors)
|
||||
|
||||
|
||||
class AccountGroupTests(TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test data for AccountGroup tests."""
|
||||
self.owner1 = User.objects.create_user(username='groupowner1', password='password123')
|
||||
self.owner2 = User.objects.create_user(username='groupowner2', password='password123')
|
||||
|
||||
def test_account_group_creation(self):
|
||||
"""Test basic AccountGroup creation."""
|
||||
group = AccountGroup.objects.create(name="Test Group", owner=self.owner1)
|
||||
self.assertEqual(group.name, "Test Group")
|
||||
self.assertEqual(group.owner, self.owner1)
|
||||
self.assertEqual(str(group), "Test Group") # Assuming __str__ returns the name
|
||||
|
||||
def test_account_group_name_unique_per_owner(self):
|
||||
"""Test that AccountGroup name is unique per owner."""
|
||||
# Initial group for owner1
|
||||
AccountGroup.objects.create(name="Unique Group Name", owner=self.owner1)
|
||||
|
||||
# Attempt to create another group with the same name and owner1 - should fail
|
||||
with self.assertRaises(IntegrityError):
|
||||
AccountGroup.objects.create(name="Unique Group Name", owner=self.owner1)
|
||||
|
||||
# Create group with the same name but for owner2 - should succeed
|
||||
try:
|
||||
AccountGroup.objects.create(name="Unique Group Name", owner=self.owner2)
|
||||
except IntegrityError:
|
||||
self.fail("Creating group with same name but different owner failed unexpectedly.")
|
||||
|
||||
# Create group with a different name for owner1 - should succeed
|
||||
try:
|
||||
AccountGroup.objects.create(name="Another Group Name", owner=self.owner1)
|
||||
except IntegrityError:
|
||||
self.fail("Creating group with different name for the same owner failed unexpectedly.")
|
||||
|
||||
def test_account_reconciliation_creates_transaction(self):
|
||||
"""Test that account_reconciliation view creates a transaction for the difference."""
|
||||
|
||||
# Helper function to get balance
|
||||
def get_balance(account):
|
||||
balance = account.transactions.filter(is_paid=True).aggregate(
|
||||
total_income=models.Sum('amount', filter=models.Q(type=Transaction.Type.INCOME)),
|
||||
total_expense=models.Sum('amount', filter=models.Q(type=Transaction.Type.EXPENSE)),
|
||||
total_transfer_in=models.Sum('amount', filter=models.Q(type=Transaction.Type.TRANSFER, transfer_to_account=account)),
|
||||
total_transfer_out=models.Sum('amount', filter=models.Q(type=Transaction.Type.TRANSFER, account=account))
|
||||
)['total_income'] or Decimal('0.00')
|
||||
balance -= account.transactions.filter(is_paid=True).aggregate(
|
||||
total_expense=models.Sum('amount', filter=models.Q(type=Transaction.Type.EXPENSE))
|
||||
)['total_expense'] or Decimal('0.00')
|
||||
# For transfers, a more complete logic might be needed if transfers are involved in reconciliation scope
|
||||
return balance
|
||||
|
||||
account_usd = Account.objects.create(
|
||||
name="USD Account for Recon",
|
||||
owner=self.owner1,
|
||||
currency=self.currency,
|
||||
group=self.account_group
|
||||
)
|
||||
account_eur = Account.objects.create(
|
||||
name="EUR Account for Recon",
|
||||
owner=self.owner1,
|
||||
currency=self.eur,
|
||||
group=self.account_group
|
||||
)
|
||||
|
||||
# Initial transactions
|
||||
Transaction.objects.create(account=account_usd, type=Transaction.Type.INCOME, amount=Decimal('100.00'), date=timezone.localdate(timezone.now()), description='Initial USD', category=self.reconciliation_category, owner=self.owner1, is_paid=True)
|
||||
Transaction.objects.create(account=account_eur, type=Transaction.Type.INCOME, amount=Decimal('200.00'), date=timezone.localdate(timezone.now()), description='Initial EUR', category=self.reconciliation_category, owner=self.owner1, is_paid=True)
|
||||
Transaction.objects.create(account=account_eur, type=Transaction.Type.EXPENSE, amount=Decimal('50.00'), date=timezone.localdate(timezone.now()), description='EUR Expense', category=self.reconciliation_category, owner=self.owner1, is_paid=True)
|
||||
|
||||
initial_usd_balance = get_balance(account_usd) # Should be 100.00
|
||||
initial_eur_balance = get_balance(account_eur) # Should be 150.00
|
||||
self.assertEqual(initial_usd_balance, Decimal('100.00'))
|
||||
self.assertEqual(initial_eur_balance, Decimal('150.00'))
|
||||
|
||||
initial_transaction_count = Transaction.objects.filter(owner=self.owner1).count() # Should be 3
|
||||
|
||||
formset_data = {
|
||||
'form-TOTAL_FORMS': '2',
|
||||
'form-INITIAL_FORMS': '2', # Based on view logic, it builds initial data for all accounts
|
||||
'form-MAX_NUM_FORMS': '', # Can be empty or a number >= TOTAL_FORMS
|
||||
'form-0-account_id': account_usd.id,
|
||||
'form-0-new_balance': '120.00', # New balance for USD account (implies +20 adjustment)
|
||||
'form-0-category': self.reconciliation_category.id,
|
||||
'form-1-account_id': account_eur.id,
|
||||
'form-1-new_balance': '150.00', # Same as current balance for EUR account (no adjustment)
|
||||
'form-1-category': self.reconciliation_category.id,
|
||||
}
|
||||
|
||||
response = self.client.post(
|
||||
reverse('accounts:account_reconciliation'),
|
||||
data=formset_data,
|
||||
HTTP_HX_REQUEST='true' # Required if view uses @only_htmx
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 204, response.content.decode()) # 204 No Content for successful HTMX POST
|
||||
|
||||
# Check that only one new transaction was created
|
||||
self.assertEqual(Transaction.objects.filter(owner=self.owner1).count(), initial_transaction_count + 1)
|
||||
|
||||
# Get the newly created transaction
|
||||
new_transaction = Transaction.objects.filter(
|
||||
account=account_usd,
|
||||
description="Balance reconciliation"
|
||||
).first()
|
||||
|
||||
self.assertIsNotNone(new_transaction)
|
||||
self.assertEqual(new_transaction.type, Transaction.Type.INCOME)
|
||||
self.assertEqual(new_transaction.amount, Decimal('20.00'))
|
||||
self.assertEqual(new_transaction.category, self.reconciliation_category)
|
||||
self.assertEqual(new_transaction.owner, self.owner1)
|
||||
self.assertTrue(new_transaction.is_paid)
|
||||
self.assertEqual(new_transaction.date, timezone.localdate(timezone.now()))
|
||||
|
||||
|
||||
# Verify final balances
|
||||
self.assertEqual(get_balance(account_usd), Decimal('120.00'))
|
||||
self.assertEqual(get_balance(account_eur), Decimal('150.00'))
|
||||
|
||||
@@ -16,11 +16,21 @@ urlpatterns = [
|
||||
views.account_edit,
|
||||
name="account_edit",
|
||||
),
|
||||
path(
|
||||
"account/<int:pk>/share/",
|
||||
views.account_share,
|
||||
name="account_share_settings",
|
||||
),
|
||||
path(
|
||||
"account/<int:pk>/delete/",
|
||||
views.account_delete,
|
||||
name="account_delete",
|
||||
),
|
||||
path(
|
||||
"account/<int:pk>/take-ownership/",
|
||||
views.account_take_ownership,
|
||||
name="account_take_ownership",
|
||||
),
|
||||
path("account-groups/", views.account_groups_index, name="account_groups_index"),
|
||||
path("account-groups/list/", views.account_groups_list, name="account_groups_list"),
|
||||
path("account-groups/add/", views.account_group_add, name="account_group_add"),
|
||||
@@ -34,4 +44,14 @@ urlpatterns = [
|
||||
views.account_group_delete,
|
||||
name="account_group_delete",
|
||||
),
|
||||
path(
|
||||
"account-groups/<int:pk>/take-ownership/",
|
||||
views.account_group_take_ownership,
|
||||
name="account_group_take_ownership",
|
||||
),
|
||||
path(
|
||||
"account-groups/<int:pk>/share/",
|
||||
views.account_group_share,
|
||||
name="account_group_share_settings",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -8,6 +8,8 @@ from django.views.decorators.http import require_http_methods
|
||||
from apps.accounts.forms import AccountGroupForm
|
||||
from apps.accounts.models import AccountGroup
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
from apps.common.models import SharedObject
|
||||
from apps.common.forms import SharedObjectForm
|
||||
|
||||
|
||||
@login_required
|
||||
@@ -63,6 +65,16 @@ def account_group_add(request, **kwargs):
|
||||
def account_group_edit(request, pk):
|
||||
account_group = get_object_or_404(AccountGroup, id=pk)
|
||||
|
||||
if account_group.owner and account_group.owner != request.user:
|
||||
messages.error(request, _("Only the owner can edit this"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
if request.method == "POST":
|
||||
form = AccountGroupForm(request.POST, instance=account_group)
|
||||
if form.is_valid():
|
||||
@@ -91,9 +103,15 @@ def account_group_edit(request, pk):
|
||||
def account_group_delete(request, pk):
|
||||
account_group = get_object_or_404(AccountGroup, id=pk)
|
||||
|
||||
account_group.delete()
|
||||
|
||||
messages.success(request, _("Account Group deleted successfully"))
|
||||
if (
|
||||
account_group.owner != request.user
|
||||
and request.user in account_group.shared_with.all()
|
||||
):
|
||||
account_group.shared_with.remove(request.user)
|
||||
messages.success(request, _("Item no longer shared with you"))
|
||||
else:
|
||||
account_group.delete()
|
||||
messages.success(request, _("Account Group deleted successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
@@ -101,3 +119,62 @@ def account_group_delete(request, pk):
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def account_group_take_ownership(request, pk):
|
||||
account_group = get_object_or_404(AccountGroup, id=pk)
|
||||
|
||||
if not account_group.owner:
|
||||
account_group.owner = request.user
|
||||
account_group.visibility = SharedObject.Visibility.private
|
||||
account_group.save()
|
||||
|
||||
messages.success(request, _("Ownership taken successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def account_group_share(request, pk):
|
||||
obj = get_object_or_404(AccountGroup, id=pk)
|
||||
|
||||
if obj.owner and obj.owner != request.user:
|
||||
messages.error(request, _("Only the owner can edit this"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
if request.method == "POST":
|
||||
form = SharedObjectForm(request.POST, instance=obj, user=request.user)
|
||||
if form.is_valid():
|
||||
form.save()
|
||||
messages.success(request, _("Configuration saved successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
else:
|
||||
form = SharedObjectForm(instance=obj, user=request.user)
|
||||
|
||||
return render(
|
||||
request,
|
||||
"accounts/fragments/share.html",
|
||||
{"form": form, "object": obj},
|
||||
)
|
||||
|
||||
@@ -8,6 +8,8 @@ from django.views.decorators.http import require_http_methods
|
||||
from apps.accounts.forms import AccountForm
|
||||
from apps.accounts.models import Account
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
from apps.common.models import SharedObject
|
||||
from apps.common.forms import SharedObjectForm
|
||||
|
||||
|
||||
@login_required
|
||||
@@ -62,6 +64,15 @@ def account_add(request, **kwargs):
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def account_edit(request, pk):
|
||||
account = get_object_or_404(Account, id=pk)
|
||||
if account.owner and account.owner != request.user:
|
||||
messages.error(request, _("Only the owner can edit this"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
if request.method == "POST":
|
||||
form = AccountForm(request.POST, instance=account)
|
||||
@@ -85,15 +96,77 @@ def account_edit(request, pk):
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def account_share(request, pk):
|
||||
obj = get_object_or_404(Account, id=pk)
|
||||
|
||||
if obj.owner and obj.owner != request.user:
|
||||
messages.error(request, _("Only the owner can edit this"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
if request.method == "POST":
|
||||
form = SharedObjectForm(request.POST, instance=obj, user=request.user)
|
||||
if form.is_valid():
|
||||
form.save()
|
||||
messages.success(request, _("Configuration saved successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
else:
|
||||
form = SharedObjectForm(instance=obj, user=request.user)
|
||||
|
||||
return render(
|
||||
request,
|
||||
"accounts/fragments/share.html",
|
||||
{"form": form, "object": obj},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["DELETE"])
|
||||
def account_delete(request, pk):
|
||||
account = get_object_or_404(Account, id=pk)
|
||||
|
||||
account.delete()
|
||||
|
||||
messages.success(request, _("Account deleted successfully"))
|
||||
if account.owner != request.user and request.user in account.shared_with.all():
|
||||
account.shared_with.remove(request.user)
|
||||
messages.success(request, _("Item no longer shared with you"))
|
||||
else:
|
||||
account.delete()
|
||||
messages.success(request, _("Account deleted successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def account_take_ownership(request, pk):
|
||||
account = get_object_or_404(Account, id=pk)
|
||||
|
||||
if not account.owner:
|
||||
account.owner = request.user
|
||||
account.visibility = SharedObject.Visibility.private
|
||||
account.save()
|
||||
|
||||
messages.success(request, _("Ownership taken successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
|
||||
@@ -38,9 +38,9 @@ def account_reconciliation(request):
|
||||
"prefix": account.currency.prefix,
|
||||
"current_balance": get_account_balance(account),
|
||||
}
|
||||
for account in Account.objects.filter(is_archived=False).select_related(
|
||||
"currency", "group"
|
||||
)
|
||||
for account in Account.objects.filter(is_archived=False)
|
||||
.select_related("currency", "group")
|
||||
.order_by("group", "name")
|
||||
]
|
||||
|
||||
if request.method == "POST":
|
||||
|
||||
6
app/apps/api/custom/pagination.py
Normal file
6
app/apps/api/custom/pagination.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
|
||||
|
||||
class CustomPageNumberPagination(PageNumberPagination):
|
||||
page_size = 100
|
||||
page_size_query_param = "page_size"
|
||||
@@ -1,8 +1,6 @@
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from rest_framework import serializers
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.transactions.models import (
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
@@ -29,7 +27,11 @@ class TransactionCategoryField(serializers.Field):
|
||||
_("Category with this ID does not exist.")
|
||||
)
|
||||
elif isinstance(data, str):
|
||||
category, created = TransactionCategory.objects.get_or_create(name=data)
|
||||
try:
|
||||
category = TransactionCategory.objects.get(name=data)
|
||||
except TransactionCategory.DoesNotExist:
|
||||
category = TransactionCategory(name=data)
|
||||
category.save()
|
||||
return category
|
||||
raise serializers.ValidationError(
|
||||
_("Invalid category data. Provide an ID or name.")
|
||||
@@ -39,7 +41,10 @@ class TransactionCategoryField(serializers.Field):
|
||||
def get_schema():
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "description": "TransactionTag ID or name"},
|
||||
"items": {
|
||||
"type": "string",
|
||||
"description": "TransactionCategory ID or name",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -65,7 +70,11 @@ class TransactionTagField(serializers.Field):
|
||||
_("Tag with this ID does not exist.")
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
tag, created = TransactionTag.objects.get_or_create(name=item)
|
||||
try:
|
||||
tag = TransactionTag.objects.get(name=item)
|
||||
except TransactionTag.DoesNotExist:
|
||||
tag = TransactionTag(name=item)
|
||||
tag.save()
|
||||
else:
|
||||
raise serializers.ValidationError(
|
||||
_("Invalid tag data. Provide an ID or name.")
|
||||
@@ -74,6 +83,13 @@ class TransactionTagField(serializers.Field):
|
||||
return tags
|
||||
|
||||
|
||||
@extend_schema_field(
|
||||
{
|
||||
"type": "array",
|
||||
"items": {"oneOf": [{"type": "string"}, {"type": "integer"}]},
|
||||
"description": "TransactionEntity ID or name. If the name doesn't exist, a new one will be created",
|
||||
}
|
||||
)
|
||||
class TransactionEntityField(serializers.Field):
|
||||
def to_representation(self, value):
|
||||
return [{"id": entity.id, "name": entity.name} for entity in value.all()]
|
||||
@@ -84,12 +100,16 @@ class TransactionEntityField(serializers.Field):
|
||||
if isinstance(item, int):
|
||||
try:
|
||||
entity = TransactionEntity.objects.get(pk=item)
|
||||
except TransactionTag.DoesNotExist:
|
||||
except TransactionEntity.DoesNotExist:
|
||||
raise serializers.ValidationError(
|
||||
_("Entity with this ID does not exist.")
|
||||
)
|
||||
elif isinstance(item, str):
|
||||
entity, created = TransactionEntity.objects.get_or_create(name=item)
|
||||
try:
|
||||
entity = TransactionEntity.objects.get(name=item)
|
||||
except TransactionEntity.DoesNotExist:
|
||||
entity = TransactionEntity(name=item)
|
||||
entity.save()
|
||||
else:
|
||||
raise serializers.ValidationError(
|
||||
_("Invalid entity data. Provide an ID or name.")
|
||||
|
||||
10
app/apps/api/permissions.py
Normal file
10
app/apps/api/permissions.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from rest_framework.permissions import BasePermission
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class NotInDemoMode(BasePermission):
|
||||
def has_permission(self, request, view):
|
||||
if settings.DEMO and not request.user.is_superuser:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -1,3 +1,4 @@
|
||||
from django.db.models import Q
|
||||
from rest_framework import serializers
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
|
||||
@@ -22,6 +23,7 @@ class AccountSerializer(serializers.ModelSerializer):
|
||||
write_only=True,
|
||||
allow_null=True,
|
||||
)
|
||||
|
||||
currency = CurrencySerializer(read_only=True)
|
||||
currency_id = serializers.PrimaryKeyRelatedField(
|
||||
queryset=Currency.objects.all(), source="currency", write_only=True
|
||||
@@ -50,6 +52,13 @@ class AccountSerializer(serializers.ModelSerializer):
|
||||
"is_asset",
|
||||
]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
request = self.context.get("request")
|
||||
if request and request.user.is_authenticated:
|
||||
# Reload the queryset to get an updated version with the requesting user
|
||||
self.fields["group_id"].queryset = AccountGroup.objects.all()
|
||||
|
||||
def create(self, validated_data):
|
||||
return Account.objects.create(**validated_data)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular import openapi
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
@@ -21,6 +23,7 @@ from apps.transactions.models import (
|
||||
TransactionEntity,
|
||||
RecurringTransaction,
|
||||
)
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
|
||||
|
||||
class TransactionCategorySerializer(serializers.ModelSerializer):
|
||||
@@ -29,6 +32,10 @@ class TransactionCategorySerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = TransactionCategory
|
||||
fields = "__all__"
|
||||
read_only_fields = [
|
||||
"id",
|
||||
"owner",
|
||||
]
|
||||
|
||||
|
||||
class TransactionTagSerializer(serializers.ModelSerializer):
|
||||
@@ -37,6 +44,10 @@ class TransactionTagSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = TransactionTag
|
||||
fields = "__all__"
|
||||
read_only_fields = [
|
||||
"id",
|
||||
"owner",
|
||||
]
|
||||
|
||||
|
||||
class TransactionEntitySerializer(serializers.ModelSerializer):
|
||||
@@ -45,12 +56,16 @@ class TransactionEntitySerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = TransactionEntity
|
||||
fields = "__all__"
|
||||
read_only_fields = [
|
||||
"id",
|
||||
"owner",
|
||||
]
|
||||
|
||||
|
||||
class InstallmentPlanSerializer(serializers.ModelSerializer):
|
||||
category = TransactionCategoryField(required=False)
|
||||
tags = TransactionTagField(required=False)
|
||||
entities = TransactionEntityField(required=False)
|
||||
category: str | int = TransactionCategoryField(required=False)
|
||||
tags: str | int = TransactionTagField(required=False)
|
||||
entities: str | int = TransactionEntityField(required=False)
|
||||
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
@@ -88,9 +103,9 @@ class InstallmentPlanSerializer(serializers.ModelSerializer):
|
||||
|
||||
|
||||
class RecurringTransactionSerializer(serializers.ModelSerializer):
|
||||
category = TransactionCategoryField(required=False)
|
||||
tags = TransactionTagField(required=False)
|
||||
entities = TransactionEntityField(required=False)
|
||||
category: str | int = TransactionCategoryField(required=False)
|
||||
tags: str | int = TransactionTagField(required=False)
|
||||
entities: str | int = TransactionEntityField(required=False)
|
||||
|
||||
class Meta:
|
||||
model = RecurringTransaction
|
||||
@@ -127,9 +142,9 @@ class RecurringTransactionSerializer(serializers.ModelSerializer):
|
||||
|
||||
|
||||
class TransactionSerializer(serializers.ModelSerializer):
|
||||
category = TransactionCategoryField(required=False)
|
||||
tags = TransactionTagField(required=False)
|
||||
entities = TransactionEntityField(required=False)
|
||||
category: str | int = TransactionCategoryField(required=False)
|
||||
tags: str | int = TransactionTagField(required=False)
|
||||
entities: str | int = TransactionEntityField(required=False)
|
||||
|
||||
exchanged_amount = serializers.SerializerMethodField()
|
||||
|
||||
@@ -155,8 +170,16 @@ class TransactionSerializer(serializers.ModelSerializer):
|
||||
"installment_plan",
|
||||
"recurring_transaction",
|
||||
"installment_id",
|
||||
"owner",
|
||||
"deleted_at",
|
||||
"deleted",
|
||||
]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.fields["account_id"].queryset = Account.objects.all()
|
||||
|
||||
def validate(self, data):
|
||||
if not self.partial:
|
||||
if "date" in data and "reference_date" not in data:
|
||||
@@ -192,5 +215,5 @@ class TransactionSerializer(serializers.ModelSerializer):
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def get_exchanged_amount(obj):
|
||||
def get_exchanged_amount(obj) -> Decimal:
|
||||
return obj.exchanged_amount()
|
||||
|
||||
124
app/apps/api/tests.py
Normal file
124
app/apps/api/tests.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework.test import APIClient
|
||||
from django.urls import reverse
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup # Added AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import TransactionCategory, Transaction
|
||||
from apps.rules.signals import transaction_created # Assuming this is the correct path
|
||||
|
||||
# Default page size for pagination, adjust if your project's default is different
|
||||
DEFAULT_PAGE_SIZE = 10
|
||||
|
||||
class APITestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(username='testuser', email='test@example.com', password='testpassword')
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.currency = Currency.objects.create(code="USD", name="US Dollar Test API", decimal_places=2)
|
||||
# Account model requires an AccountGroup
|
||||
self.account_group = AccountGroup.objects.create(name="API Test Group", owner=self.user)
|
||||
self.account = Account.objects.create(
|
||||
name="Test API Account",
|
||||
currency=self.currency,
|
||||
owner=self.user,
|
||||
group=self.account_group
|
||||
)
|
||||
self.category = TransactionCategory.objects.create(
|
||||
name="Test API Category",
|
||||
owner=self.user,
|
||||
type=TransactionCategory.TransactionType.EXPENSE # Default type, can be adjusted
|
||||
)
|
||||
# Remove the example test if it's no longer needed or update it
|
||||
# self.assertEqual(1 + 1, 2) # from test_example
|
||||
|
||||
def test_transactions_endpoint_authenticated_user(self):
|
||||
# User and client are now set up in self.setUp
|
||||
url = reverse('api:transaction-list') # Using 'api:' namespace
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@patch('apps.rules.signals.transaction_created.send')
|
||||
def test_create_transaction_api_success(self, mock_signal_send):
|
||||
url = reverse('api:transaction-list')
|
||||
data = {
|
||||
'account': self.account.pk, # Changed from account_id to account to match typical DRF serializer field names
|
||||
'type': Transaction.Type.EXPENSE.value, # Use enum value
|
||||
'date': date(2023, 1, 15).isoformat(),
|
||||
'amount': '123.45',
|
||||
'description': 'API Test Expense',
|
||||
'category': self.category.pk,
|
||||
'tags': [],
|
||||
'entities': []
|
||||
}
|
||||
|
||||
initial_transaction_count = Transaction.objects.count()
|
||||
response = self.client.post(url, data, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, 201, response.data) # Print response.data on failure
|
||||
self.assertEqual(Transaction.objects.count(), initial_transaction_count + 1)
|
||||
|
||||
created_transaction = Transaction.objects.latest('id') # Get the latest transaction
|
||||
|
||||
self.assertEqual(created_transaction.description, 'API Test Expense')
|
||||
self.assertEqual(created_transaction.amount, Decimal('123.45'))
|
||||
self.assertEqual(created_transaction.owner, self.user)
|
||||
self.assertEqual(created_transaction.account, self.account)
|
||||
self.assertEqual(created_transaction.category, self.category)
|
||||
|
||||
mock_signal_send.assert_called_once()
|
||||
# Check sender argument of the signal call
|
||||
self.assertEqual(mock_signal_send.call_args.kwargs['sender'], Transaction)
|
||||
self.assertEqual(mock_signal_send.call_args.kwargs['instance'], created_transaction)
|
||||
|
||||
|
||||
def test_create_transaction_api_invalid_data(self):
|
||||
url = reverse('api:transaction-list')
|
||||
data = {
|
||||
'account': self.account.pk,
|
||||
'type': 'INVALID_TYPE', # Invalid type
|
||||
'date': date(2023, 1, 15).isoformat(),
|
||||
'amount': 'not_a_number', # Invalid amount
|
||||
'description': 'API Test Invalid Data',
|
||||
'category': self.category.pk
|
||||
}
|
||||
response = self.client.post(url, data, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertIn('type', response.data)
|
||||
self.assertIn('amount', response.data)
|
||||
|
||||
def test_transaction_list_pagination(self):
|
||||
# Create more transactions than page size (e.g., DEFAULT_PAGE_SIZE + 5)
|
||||
num_to_create = DEFAULT_PAGE_SIZE + 5
|
||||
for i in range(num_to_create):
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=date(2023, 1, 1) + timedelta(days=i),
|
||||
amount=Decimal(f"{10 + i}.00"),
|
||||
description=f"Pag Test Transaction {i+1}",
|
||||
owner=self.user,
|
||||
category=self.category
|
||||
)
|
||||
|
||||
url = reverse('api:transaction-list')
|
||||
response = self.client.get(url)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn('count', response.data)
|
||||
self.assertEqual(response.data['count'], num_to_create)
|
||||
|
||||
self.assertIn('next', response.data)
|
||||
self.assertIsNotNone(response.data['next']) # Assuming count > page size
|
||||
|
||||
self.assertIn('previous', response.data) # Will be None for the first page
|
||||
# self.assertIsNone(response.data['previous']) # For the first page
|
||||
|
||||
self.assertIn('results', response.data)
|
||||
self.assertEqual(len(response.data['results']), DEFAULT_PAGE_SIZE)
|
||||
@@ -1,4 +1,6 @@
|
||||
from rest_framework import viewsets
|
||||
|
||||
from apps.api.custom.pagination import CustomPageNumberPagination
|
||||
from apps.accounts.models import AccountGroup, Account
|
||||
from apps.api.serializers import AccountGroupSerializer, AccountSerializer
|
||||
|
||||
@@ -6,12 +8,20 @@ from apps.api.serializers import AccountGroupSerializer, AccountSerializer
|
||||
class AccountGroupViewSet(viewsets.ModelViewSet):
|
||||
queryset = AccountGroup.objects.all()
|
||||
serializer_class = AccountGroupSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
|
||||
def get_queryset(self):
|
||||
return AccountGroup.objects.all().order_by("id")
|
||||
|
||||
|
||||
class AccountViewSet(viewsets.ModelViewSet):
|
||||
queryset = Account.objects.all()
|
||||
serializer_class = AccountSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset()
|
||||
return queryset.select_related("group", "currency", "exchange_currency")
|
||||
return (
|
||||
Account.objects.all()
|
||||
.order_by("id")
|
||||
.select_related("group", "currency", "exchange_currency")
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from rest_framework import viewsets
|
||||
|
||||
from apps.api.custom.pagination import CustomPageNumberPagination
|
||||
from apps.api.serializers import (
|
||||
TransactionSerializer,
|
||||
TransactionCategorySerializer,
|
||||
@@ -22,6 +23,7 @@ from apps.rules.signals import transaction_updated, transaction_created
|
||||
class TransactionViewSet(viewsets.ModelViewSet):
|
||||
queryset = Transaction.objects.all()
|
||||
serializer_class = TransactionSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
|
||||
def perform_create(self, serializer):
|
||||
instance = serializer.save()
|
||||
@@ -35,27 +37,50 @@ class TransactionViewSet(viewsets.ModelViewSet):
|
||||
kwargs["partial"] = True
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def get_queryset(self):
|
||||
return Transaction.objects.all().order_by("-id")
|
||||
|
||||
|
||||
class TransactionCategoryViewSet(viewsets.ModelViewSet):
|
||||
queryset = TransactionCategory.objects.all()
|
||||
serializer_class = TransactionCategorySerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionCategory.objects.all().order_by("id")
|
||||
|
||||
|
||||
class TransactionTagViewSet(viewsets.ModelViewSet):
|
||||
queryset = TransactionTag.objects.all()
|
||||
serializer_class = TransactionTagSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionTag.objects.all().order_by("id")
|
||||
|
||||
|
||||
class TransactionEntityViewSet(viewsets.ModelViewSet):
|
||||
queryset = TransactionEntity.objects.all()
|
||||
serializer_class = TransactionEntitySerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionEntity.objects.all().order_by("id")
|
||||
|
||||
|
||||
class InstallmentPlanViewSet(viewsets.ModelViewSet):
|
||||
queryset = InstallmentPlan.objects.all()
|
||||
serializer_class = InstallmentPlanSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
|
||||
def get_queryset(self):
|
||||
return InstallmentPlan.objects.all().order_by("-id")
|
||||
|
||||
|
||||
class RecurringTransactionViewSet(viewsets.ModelViewSet):
|
||||
queryset = RecurringTransaction.objects.all()
|
||||
serializer_class = RecurringTransactionSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
|
||||
def get_queryset(self):
|
||||
return RecurringTransaction.objects.all().order_by("-id")
|
||||
|
||||
100
app/apps/calendar_view/tests.py
Normal file
100
app/apps/calendar_view/tests.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from django.test import TestCase, Client
|
||||
from django.contrib.auth.models import User
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone # Though specific dates are used, good for general test setup
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import TransactionCategory, Transaction
|
||||
# from apps.calendar_view.utils.calendar import get_transactions_by_day # Not directly testing this util here
|
||||
|
||||
class CalendarViewTests(TestCase): # Renamed from CalendarViewTestCase to CalendarViewTests
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(username='testcalendaruser', password='password')
|
||||
self.client = Client()
|
||||
self.client.login(username='testcalendaruser', password='password')
|
||||
|
||||
self.currency_usd = Currency.objects.create(name="CV USD", code="CVUSD", decimal_places=2, prefix="$CV ")
|
||||
self.account_group = AccountGroup.objects.create(name="CV Group", owner=self.user)
|
||||
self.account_usd1 = Account.objects.create(
|
||||
name="CV Account USD 1",
|
||||
currency=self.currency_usd,
|
||||
owner=self.user,
|
||||
group=self.account_group
|
||||
)
|
||||
self.category_cv = TransactionCategory.objects.create(
|
||||
name="CV Cat",
|
||||
owner=self.user,
|
||||
type=TransactionCategory.TransactionType.INFO # Using INFO as a generic type
|
||||
)
|
||||
|
||||
# Transactions for specific dates
|
||||
self.t1 = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_cv,
|
||||
date=date(2023, 3, 5), amount=Decimal("10.00"),
|
||||
type=Transaction.Type.EXPENSE, is_paid=True, description="March 5th Tx"
|
||||
)
|
||||
self.t2 = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_cv,
|
||||
date=date(2023, 3, 10), amount=Decimal("20.00"),
|
||||
type=Transaction.Type.EXPENSE, is_paid=True, description="March 10th Tx"
|
||||
)
|
||||
self.t3 = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_cv,
|
||||
date=date(2023, 4, 5), amount=Decimal("30.00"),
|
||||
type=Transaction.Type.EXPENSE, is_paid=True, description="April 5th Tx"
|
||||
)
|
||||
|
||||
def test_calendar_list_view_context_data(self):
|
||||
# Assumes 'calendar_view:calendar_list' is the correct URL name for the main calendar view
|
||||
# The previous test used 'calendar_view:calendar'. I'll assume 'calendar_list' is the new/correct one.
|
||||
# If the view that shows the grid is named 'calendar', this should be adjusted.
|
||||
# Based on subtask, this is for calendar_list view.
|
||||
url = reverse('calendar_view:calendar_list', kwargs={'month': 3, 'year': 2023})
|
||||
response = self.client.get(url)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn('dates', response.context)
|
||||
|
||||
dates_context = response.context['dates']
|
||||
|
||||
entry_mar5 = next((d for d in dates_context if d['date'] == date(2023, 3, 5)), None)
|
||||
self.assertIsNotNone(entry_mar5, "Date March 5th not found in context.")
|
||||
self.assertIn(self.t1, entry_mar5['transactions'], "Transaction t1 not in March 5th transactions.")
|
||||
|
||||
entry_mar10 = next((d for d in dates_context if d['date'] == date(2023, 3, 10)), None)
|
||||
self.assertIsNotNone(entry_mar10, "Date March 10th not found in context.")
|
||||
self.assertIn(self.t2, entry_mar10['transactions'], "Transaction t2 not in March 10th transactions.")
|
||||
|
||||
for day_data in dates_context:
|
||||
self.assertNotIn(self.t3, day_data['transactions'], f"Transaction t3 (April 5th) found in March {day_data['date']} transactions.")
|
||||
|
||||
def test_calendar_transactions_list_view_specific_day(self):
|
||||
url = reverse('calendar_view:calendar_transactions_list', kwargs={'day': 5, 'month': 3, 'year': 2023})
|
||||
response = self.client.get(url)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn('transactions', response.context)
|
||||
|
||||
transactions_context = response.context['transactions']
|
||||
|
||||
self.assertIn(self.t1, transactions_context, "Transaction t1 (March 5th) not found in context for specific day view.")
|
||||
self.assertNotIn(self.t2, transactions_context, "Transaction t2 (March 10th) found in context for March 5th.")
|
||||
self.assertNotIn(self.t3, transactions_context, "Transaction t3 (April 5th) found in context for March 5th.")
|
||||
self.assertEqual(len(transactions_context), 1)
|
||||
|
||||
def test_calendar_view_authenticated_user_generic_month(self):
|
||||
# This is similar to the old test_calendar_view_authenticated_user.
|
||||
# It tests general access to the main calendar view (which might be 'calendar_list' or 'calendar')
|
||||
# Let's use the 'calendar' name as it was in the old test, assuming it's the main monthly view.
|
||||
# If 'calendar_list' is the actual main monthly view, this might be slightly redundant
|
||||
# with the setup of test_calendar_list_view_context_data but still good for general access check.
|
||||
url = reverse('calendar_view:calendar', args=[2023, 1]) # e.g. Jan 2023
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
# Further context checks could be added here if this view has a different structure than 'calendar_list'
|
||||
self.assertIn('dates', response.context) # Assuming it also provides 'dates'
|
||||
self.assertIn('current_month_date', response.context)
|
||||
self.assertEqual(response.context['current_month_date'], date(2023,1,1))
|
||||
7
app/apps/common/admin.py
Normal file
7
app/apps/common/admin.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from django.contrib import admin
|
||||
|
||||
|
||||
class SharedObjectModelAdmin(admin.ModelAdmin):
|
||||
def get_queryset(self, request):
|
||||
# Use the all_objects manager to show all transactions, including deleted ones
|
||||
return self.model.all_objects.all()
|
||||
15
app/apps/common/decorators/demo.py
Normal file
15
app/apps/common/decorators/demo.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from functools import wraps
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import PermissionDenied
|
||||
|
||||
|
||||
def disabled_on_demo(view):
|
||||
@wraps(view)
|
||||
def _view(request, *args, **kwargs):
|
||||
if settings.DEMO and not request.user.is_superuser:
|
||||
raise PermissionDenied
|
||||
|
||||
return view(request, *args, **kwargs)
|
||||
|
||||
return _view
|
||||
78
app/apps/common/decorators/user.py
Normal file
78
app/apps/common/decorators/user.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from functools import wraps
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import HttpResponse
|
||||
from django.shortcuts import redirect
|
||||
from django.urls import reverse, NoReverseMatch
|
||||
|
||||
|
||||
def is_superuser(view):
|
||||
@wraps(view)
|
||||
def _view(request, *args, **kwargs):
|
||||
if not request.user.is_superuser:
|
||||
raise PermissionDenied
|
||||
|
||||
return view(request, *args, **kwargs)
|
||||
|
||||
return _view
|
||||
|
||||
|
||||
def htmx_login_required(function=None, login_url=None):
|
||||
"""
|
||||
Decorator that checks if the user is logged in.
|
||||
|
||||
Allows overriding the default login URL.
|
||||
|
||||
If the user is not logged in:
|
||||
- If "hx-request" is present in the request header, it returns a 200 response
|
||||
with a "HX-Redirect" header containing the determined login URL (including the "next" parameter).
|
||||
- If "hx-request" is not present, it redirects to the determined login page normally.
|
||||
|
||||
Args:
|
||||
function: The view function to decorate.
|
||||
login_url: Optional. The URL or URL name to redirect to for login.
|
||||
Defaults to settings.LOGIN_URL.
|
||||
"""
|
||||
|
||||
def decorator(view_func):
|
||||
# Simplified @wraps usage - it handles necessary attribute assignments by default
|
||||
@wraps(view_func)
|
||||
def wrapped_view(request, *args, **kwargs):
|
||||
if request.user.is_authenticated:
|
||||
return view_func(request, *args, **kwargs)
|
||||
else:
|
||||
# Determine the login URL
|
||||
resolved_login_url = login_url
|
||||
if not resolved_login_url:
|
||||
resolved_login_url = settings.LOGIN_URL
|
||||
|
||||
# Try to reverse the URL name if it's not a path
|
||||
try:
|
||||
# Check if it looks like a URL path already
|
||||
if "/" not in resolved_login_url and "." not in resolved_login_url:
|
||||
login_url_path = reverse(resolved_login_url)
|
||||
else:
|
||||
login_url_path = resolved_login_url
|
||||
except NoReverseMatch:
|
||||
# If reverse fails, assume it's already a URL path
|
||||
login_url_path = resolved_login_url
|
||||
|
||||
# Construct the full redirect path with 'next' parameter
|
||||
# Ensure request.path is URL-encoded if needed, though Django usually handles this
|
||||
redirect_path = f"{login_url_path}?next={request.get_full_path()}" # Use get_full_path() to include query params
|
||||
|
||||
if request.headers.get("hx-request"):
|
||||
# For HTMX requests, return a 200 with the HX-Redirect header.
|
||||
response = HttpResponse()
|
||||
response["HX-Redirect"] = login_url_path
|
||||
return response
|
||||
else:
|
||||
# For regular requests, redirect to the login page.
|
||||
return redirect(redirect_path)
|
||||
|
||||
return wrapped_view
|
||||
|
||||
if function:
|
||||
return decorator(function)
|
||||
return decorator
|
||||
@@ -4,6 +4,7 @@ from django.db import transaction
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.tom_select import TomSelect, TomSelectMultiple
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
|
||||
|
||||
class DynamicModelChoiceField(forms.ModelChoiceField):
|
||||
@@ -12,15 +13,14 @@ class DynamicModelChoiceField(forms.ModelChoiceField):
|
||||
self.to_field_name = kwargs.pop("to_field_name", "pk")
|
||||
|
||||
self.create_field = kwargs.pop("create_field", None)
|
||||
if not self.create_field:
|
||||
raise ValueError("The 'create_field' parameter is required.")
|
||||
|
||||
self.queryset = kwargs.pop("queryset", model.objects.all())
|
||||
super().__init__(queryset=self.queryset, *args, **kwargs)
|
||||
self._created_instance = None
|
||||
|
||||
self.widget = TomSelect(clear_button=True, create=True)
|
||||
|
||||
super().__init__(queryset=self.queryset, *args, **kwargs)
|
||||
self._created_instance = None
|
||||
|
||||
def to_python(self, value):
|
||||
if value in self.empty_values:
|
||||
return None
|
||||
@@ -53,17 +53,27 @@ class DynamicModelChoiceField(forms.ModelChoiceField):
|
||||
else:
|
||||
raise self.model.DoesNotExist
|
||||
except self.model.DoesNotExist:
|
||||
try:
|
||||
with transaction.atomic():
|
||||
instance, _ = self.model.objects.update_or_create(
|
||||
**{self.create_field: value}
|
||||
)
|
||||
self._created_instance = instance
|
||||
return instance
|
||||
except Exception as e:
|
||||
if self.create_field:
|
||||
try:
|
||||
with transaction.atomic():
|
||||
# First try to get the object
|
||||
lookup = {self.create_field: value}
|
||||
try:
|
||||
instance = self.model.objects.get(**lookup)
|
||||
except self.model.DoesNotExist:
|
||||
# Create a new instance directly
|
||||
instance = self.model(**lookup)
|
||||
instance.save()
|
||||
|
||||
self._created_instance = instance
|
||||
return instance
|
||||
except Exception as e:
|
||||
raise ValidationError(_("Error creating new instance"))
|
||||
else:
|
||||
raise ValidationError(
|
||||
self.error_messages["invalid_choice"], code="invalid_choice"
|
||||
)
|
||||
|
||||
return super().clean(value)
|
||||
|
||||
def bound_data(self, data, initial):
|
||||
@@ -86,8 +96,6 @@ class DynamicModelMultipleChoiceField(forms.ModelMultipleChoiceField):
|
||||
|
||||
def __init__(self, model, **kwargs):
|
||||
"""
|
||||
Initialize the CreateIfNotExistsModelMultipleChoiceField.
|
||||
|
||||
Args:
|
||||
create_field (str): The name of the field to use when creating new instances.
|
||||
*args: Variable length argument list.
|
||||
@@ -119,33 +127,28 @@ class DynamicModelMultipleChoiceField(forms.ModelMultipleChoiceField):
|
||||
"""
|
||||
try:
|
||||
with transaction.atomic():
|
||||
instance, _ = self.model.objects.update_or_create(
|
||||
**{self.create_field: value}
|
||||
)
|
||||
return instance
|
||||
# Check if exists first without using update_or_create
|
||||
lookup = {self.create_field: value}
|
||||
try:
|
||||
# Use base manager to bypass distinct filters
|
||||
instance = self.model.objects.get(**lookup)
|
||||
return instance
|
||||
except self.model.DoesNotExist:
|
||||
# Create a new instance directly
|
||||
instance = self.model(**lookup)
|
||||
instance.save()
|
||||
return instance
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise ValidationError(_("Error creating new instance"))
|
||||
|
||||
def clean(self, value):
|
||||
"""
|
||||
Clean and validate the field value.
|
||||
|
||||
This method checks if each selected choice exists in the database.
|
||||
If a choice doesn't exist, it creates a new instance of the model.
|
||||
|
||||
Args:
|
||||
value (list): List of selected values.
|
||||
|
||||
Returns:
|
||||
list: A list containing all selected and newly created model instances.
|
||||
|
||||
Raises:
|
||||
ValidationError: If there's an error during the cleaning process.
|
||||
"""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
string_values = set(str(v) for v in value)
|
||||
|
||||
# Get existing objects first
|
||||
existing_objects = list(
|
||||
self.queryset.filter(**{f"{self.create_field}__in": string_values})
|
||||
)
|
||||
@@ -153,13 +156,11 @@ class DynamicModelMultipleChoiceField(forms.ModelMultipleChoiceField):
|
||||
str(getattr(obj, self.create_field)) for obj in existing_objects
|
||||
)
|
||||
|
||||
# Create new objects for missing values
|
||||
new_values = string_values - existing_values
|
||||
new_objects = []
|
||||
|
||||
for new_value in new_values:
|
||||
try:
|
||||
new_objects.append(self._create_new_instance(new_value))
|
||||
except ValidationError as e:
|
||||
raise ValidationError(_("Error creating new instance"))
|
||||
new_objects.append(self._create_new_instance(new_value))
|
||||
|
||||
return existing_objects + new_objects
|
||||
|
||||
@@ -20,7 +20,15 @@ class MonthYearModelField(models.DateField):
|
||||
# Set the day to 1
|
||||
return date.replace(day=1).date()
|
||||
except ValueError:
|
||||
raise ValidationError(_("Invalid date format. Use YYYY-MM."))
|
||||
try:
|
||||
# Also accept YYYY-MM-DD format (for loaddata)
|
||||
return (
|
||||
datetime.datetime.strptime(value, "%Y-%m-%d").replace(day=1).date()
|
||||
)
|
||||
except ValueError:
|
||||
raise ValidationError(
|
||||
_("Invalid date format. Use YYYY-MM or YYYY-MM-DD.")
|
||||
)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
kwargs["widget"] = MonthYearWidget
|
||||
|
||||
113
app/apps/common/forms.py
Normal file
113
app/apps/common/forms.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from django import forms
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.core.exceptions import ValidationError
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Submit, Div, HTML
|
||||
|
||||
from apps.common.widgets.tom_select import TomSelect, TomSelectMultiple
|
||||
from apps.common.models import SharedObject
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class SharedObjectForm(forms.Form):
|
||||
"""
|
||||
Generic form for editing visibility and sharing settings
|
||||
for models inheriting from SharedObject.
|
||||
"""
|
||||
|
||||
owner = forms.ModelChoiceField(
|
||||
queryset=User.objects.all(),
|
||||
required=False,
|
||||
label=_("Owner"),
|
||||
widget=TomSelect(clear_button=False),
|
||||
help_text=_(
|
||||
"The owner of this object, if empty all users can see, edit and take ownership."
|
||||
),
|
||||
)
|
||||
shared_with_users = forms.ModelMultipleChoiceField(
|
||||
queryset=User.objects.all(),
|
||||
required=False,
|
||||
widget=TomSelectMultiple(clear_button=True),
|
||||
label=_("Shared with users"),
|
||||
help_text=_("Select users to share this object with"),
|
||||
)
|
||||
visibility = forms.ChoiceField(
|
||||
choices=SharedObject.Visibility.choices,
|
||||
required=True,
|
||||
label=_("Visibility"),
|
||||
help_text=_(
|
||||
"Private: Only shown for the owner and shared users. Only editable by the owner."
|
||||
"<br/>"
|
||||
"Public: Shown for all users. Only editable by the owner."
|
||||
),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
fields = ["visibility", "shared_with_users"]
|
||||
widgets = {
|
||||
"visibility": TomSelect(clear_button=False),
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Get the current user to filter available sharing options
|
||||
self.user = kwargs.pop("user", None)
|
||||
self.instance = kwargs.pop("instance", None)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Pre-populate shared users if instance exists
|
||||
if self.instance:
|
||||
self.fields["shared_with_users"].initial = self.instance.shared_with.all()
|
||||
self.fields["visibility"].initial = self.instance.visibility
|
||||
self.fields["owner"].initial = self.instance.owner
|
||||
|
||||
# Set up crispy form helper
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_method = "post"
|
||||
self.helper.form_tag = False
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Field("owner"),
|
||||
Field("visibility"),
|
||||
HTML("<hr>"),
|
||||
Field("shared_with_users"),
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Save"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
owner = cleaned_data.get("owner")
|
||||
shared_with_users = cleaned_data.get("shared_with_users", [])
|
||||
|
||||
# Raise validation error if owner is in shared_with_users
|
||||
if owner and owner in shared_with_users:
|
||||
self.add_error(
|
||||
"shared_with_users",
|
||||
ValidationError(
|
||||
_("You cannot share this item with its owner."),
|
||||
code="invalid_share",
|
||||
),
|
||||
)
|
||||
|
||||
return cleaned_data
|
||||
|
||||
def save(self):
|
||||
instance = self.instance
|
||||
|
||||
instance.visibility = self.cleaned_data["visibility"]
|
||||
instance.owner = self.cleaned_data["owner"]
|
||||
|
||||
instance.save()
|
||||
|
||||
# Clear and set shared_with users
|
||||
instance.shared_with.set(self.cleaned_data.get("shared_with_users", []))
|
||||
|
||||
return instance
|
||||
0
app/apps/common/management/__init__.py
Normal file
0
app/apps/common/management/__init__.py
Normal file
0
app/apps/common/management/commands/__init__.py
Normal file
0
app/apps/common/management/commands/__init__.py
Normal file
137
app/apps/common/management/commands/setup_users.py
Normal file
137
app/apps/common/management/commands/setup_users.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.conf import settings
|
||||
from django.db import IntegrityError
|
||||
|
||||
# Get the custom User model if defined, otherwise the default User model
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = (
|
||||
"Creates a superuser from environment variables (ADMIN_EMAIL, ADMIN_PASSWORD) "
|
||||
"and optionally creates a demo user (demo@demo.com) if settings.DEMO is True."
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
self.stdout.write("Starting user setup...")
|
||||
|
||||
# --- Create Superuser ---
|
||||
admin_email = os.environ.get("ADMIN_EMAIL")
|
||||
admin_password = os.environ.get("ADMIN_PASSWORD")
|
||||
|
||||
if admin_email and admin_password:
|
||||
self.stdout.write(f"Attempting to create superuser: {admin_email}")
|
||||
# Use email as username for simplicity, requires USERNAME_FIELD='email'
|
||||
# or adapt if your USERNAME_FIELD is different.
|
||||
# If USERNAME_FIELD is 'username', you might need ADMIN_USERNAME env var.
|
||||
username_field = User.USERNAME_FIELD # Get the actual username field name
|
||||
|
||||
# Check if the user already exists by email or username
|
||||
user_exists_kwargs = {"email": admin_email}
|
||||
if username_field != "email":
|
||||
# Assume username should also be the email if not explicitly provided
|
||||
user_exists_kwargs[username_field] = admin_email
|
||||
|
||||
if User.objects.filter(**user_exists_kwargs).exists():
|
||||
self.stdout.write(
|
||||
self.style.WARNING(
|
||||
f"Superuser with email '{admin_email}' (or corresponding username) already exists. Skipping creation."
|
||||
)
|
||||
)
|
||||
else:
|
||||
try:
|
||||
create_kwargs = {
|
||||
username_field: admin_email, # Use email as username by default
|
||||
"email": admin_email,
|
||||
"password": admin_password,
|
||||
}
|
||||
User.objects.create_superuser(**create_kwargs)
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
f"Superuser '{admin_email}' created successfully."
|
||||
)
|
||||
)
|
||||
except IntegrityError as e:
|
||||
self.stdout.write(
|
||||
self.style.ERROR(
|
||||
f"Failed to create superuser '{admin_email}'. IntegrityError: {e}"
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
self.stdout.write(
|
||||
self.style.ERROR(
|
||||
f"An unexpected error occurred creating superuser '{admin_email}': {e}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.stdout.write(
|
||||
self.style.NOTICE(
|
||||
"ADMIN_EMAIL or ADMIN_PASSWORD environment variables not set. Skipping superuser creation."
|
||||
)
|
||||
)
|
||||
|
||||
self.stdout.write("---") # Separator
|
||||
|
||||
# --- Create Demo User ---
|
||||
# Use getattr to safely check for the DEMO setting, default to False if not present
|
||||
create_demo_user = getattr(settings, "DEMO", False)
|
||||
|
||||
if create_demo_user:
|
||||
demo_email = "demo@demo.com"
|
||||
demo_password = (
|
||||
"wygiwyhdemo" # Consider making this an env var too for security
|
||||
)
|
||||
demo_username = demo_email # Using email as username for consistency
|
||||
|
||||
self.stdout.write(
|
||||
f"DEMO setting is True. Attempting to create demo user: {demo_email}"
|
||||
)
|
||||
|
||||
username_field = User.USERNAME_FIELD # Get the actual username field name
|
||||
|
||||
# Check if the user already exists by email or username
|
||||
user_exists_kwargs = {"email": demo_email}
|
||||
if username_field != "email":
|
||||
user_exists_kwargs[username_field] = demo_username
|
||||
|
||||
if User.objects.filter(**user_exists_kwargs).exists():
|
||||
self.stdout.write(
|
||||
self.style.WARNING(
|
||||
f"Demo user with email '{demo_email}' (or corresponding username) already exists. Skipping creation."
|
||||
)
|
||||
)
|
||||
else:
|
||||
try:
|
||||
create_kwargs = {
|
||||
username_field: demo_username,
|
||||
"email": demo_email,
|
||||
"password": demo_password,
|
||||
}
|
||||
User.objects.create_user(**create_kwargs)
|
||||
self.stdout.write(
|
||||
self.style.SUCCESS(
|
||||
f"Demo user '{demo_email}' created successfully."
|
||||
)
|
||||
)
|
||||
except IntegrityError as e:
|
||||
self.stdout.write(
|
||||
self.style.ERROR(
|
||||
f"Failed to create demo user '{demo_email}'. IntegrityError: {e}"
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
self.stdout.write(
|
||||
self.style.ERROR(
|
||||
f"An unexpected error occurred creating demo user '{demo_email}': {e}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.stdout.write(
|
||||
self.style.NOTICE(
|
||||
"DEMO setting is not True (or not set). Skipping demo user creation."
|
||||
)
|
||||
)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS("User setup command finished."))
|
||||
@@ -56,6 +56,16 @@ def get_current_user():
|
||||
if request:
|
||||
return getattr(request, "user", None)
|
||||
|
||||
return getattr(_thread_locals, "user", None)
|
||||
|
||||
|
||||
def write_current_user(user):
|
||||
_thread_locals.user = user
|
||||
|
||||
|
||||
def delete_current_user():
|
||||
del _thread_locals.user
|
||||
|
||||
|
||||
class ThreadLocalMiddleware(MiddlewareMixin):
|
||||
"""Simple middleware that adds the request object in thread local storage."""
|
||||
|
||||
84
app/apps/common/models.py
Normal file
84
app/apps/common/models.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from django.db import models
|
||||
from django.conf import settings
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
|
||||
|
||||
class SharedObjectManager(models.Manager):
|
||||
def get_queryset(self):
|
||||
"""Return only objects the user can access"""
|
||||
user = get_current_user()
|
||||
base_qs = super().get_queryset()
|
||||
|
||||
if user and user.is_authenticated:
|
||||
return base_qs.filter(
|
||||
Q(visibility="public")
|
||||
| Q(owner=user)
|
||||
| Q(shared_with=user)
|
||||
| Q(visibility="private", owner=None)
|
||||
).distinct()
|
||||
|
||||
return base_qs.filter(visibility="public")
|
||||
|
||||
|
||||
class SharedObject(models.Model):
|
||||
# Access control enum
|
||||
class Visibility(models.TextChoices):
|
||||
private = "private", _("Private")
|
||||
is_paid = "public", _("Public")
|
||||
|
||||
# Core sharing fields
|
||||
owner = models.ForeignKey(
|
||||
settings.AUTH_USER_MODEL,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="%(class)s_owned",
|
||||
null=True,
|
||||
blank=True,
|
||||
)
|
||||
visibility = models.CharField(
|
||||
max_length=10, choices=Visibility.choices, default=Visibility.private
|
||||
)
|
||||
shared_with = models.ManyToManyField(
|
||||
settings.AUTH_USER_MODEL, related_name="%(class)s_shared", blank=True
|
||||
)
|
||||
|
||||
# Use as abstract base class
|
||||
class Meta:
|
||||
abstract = True
|
||||
indexes = [
|
||||
models.Index(fields=["visibility"]),
|
||||
]
|
||||
|
||||
def is_accessible_by(self, user):
|
||||
"""Check if a user can access this object"""
|
||||
return (
|
||||
self.visibility == "public"
|
||||
or self.owner == user
|
||||
or (self.visibility == "shared" and user in self.shared_with.all())
|
||||
)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.pk and not self.owner:
|
||||
self.owner = get_current_user()
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
||||
class OwnedObject(models.Model):
|
||||
owner = models.ForeignKey(
|
||||
settings.AUTH_USER_MODEL,
|
||||
on_delete=models.CASCADE,
|
||||
related_name="%(class)s_owned",
|
||||
null=True,
|
||||
blank=True,
|
||||
)
|
||||
|
||||
# Use as abstract base class
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.pk and not self.owner:
|
||||
self.owner = get_current_user()
|
||||
super().save(*args, **kwargs)
|
||||
@@ -1,7 +1,9 @@
|
||||
import logging
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
from django.core import management
|
||||
from django.db import DEFAULT_DB_ALIAS
|
||||
|
||||
from procrastinate import builtin_tasks
|
||||
from procrastinate.contrib.django import app
|
||||
@@ -40,3 +42,40 @@ async def remove_expired_sessions(timestamp=None):
|
||||
"Error while executing 'remove_expired_sessions' task",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
@app.periodic(cron="0 8 * * *")
|
||||
@app.task(name="reset_demo_data")
|
||||
def reset_demo_data(timestamp=None):
|
||||
"""
|
||||
Wipes the database and loads fresh demo data if DEMO mode is active.
|
||||
Runs daily at 8:00 AM.
|
||||
"""
|
||||
if not settings.DEMO:
|
||||
return # Exit if not in demo mode
|
||||
|
||||
logger.info("Demo mode active. Starting daily data reset...")
|
||||
|
||||
try:
|
||||
# 1. Flush the database (wipe all data)
|
||||
logger.info("Flushing the database...")
|
||||
|
||||
management.call_command(
|
||||
"flush", "--noinput", database=DEFAULT_DB_ALIAS, verbosity=1
|
||||
)
|
||||
logger.info("Database flushed successfully.")
|
||||
|
||||
# 2. Load data from the fixture
|
||||
# TO-DO: Roll dates over based on today's date
|
||||
fixture_name = "fixtures/demo_data.json"
|
||||
logger.info(f"Loading data from fixture: {fixture_name}...")
|
||||
management.call_command(
|
||||
"loaddata", fixture_name, database=DEFAULT_DB_ALIAS, verbosity=1
|
||||
)
|
||||
logger.info(f"Data loaded successfully from {fixture_name}.")
|
||||
|
||||
logger.info("Daily demo data reset completed.")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during daily demo data reset: {e}")
|
||||
raise
|
||||
|
||||
183
app/apps/common/tests.py
Normal file
183
app/apps/common/tests.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from django.test import TestCase, RequestFactory
|
||||
from django.template import Template, Context
|
||||
from django.urls import reverse, resolve, NoReverseMatch
|
||||
from django.contrib.auth.models import User
|
||||
from decimal import Decimal # Keep existing imports if they are from other tests
|
||||
from app.apps.common.functions.decimals import truncate_decimal # Keep existing imports
|
||||
|
||||
# Helper to create a dummy request with resolver_match
|
||||
def setup_request_for_view(factory, view_name_or_url, user=None, namespace=None, view_name_for_resolver=None):
|
||||
try:
|
||||
url = reverse(view_name_or_url)
|
||||
except NoReverseMatch:
|
||||
url = view_name_or_url # Assume it's already a URL path
|
||||
|
||||
request = factory.get(url)
|
||||
if user:
|
||||
request.user = user
|
||||
|
||||
try:
|
||||
# For resolver_match, we need to simulate how Django does it.
|
||||
# It needs specific view_name and namespace if applicable.
|
||||
# If view_name_for_resolver is provided, use that for resolving,
|
||||
# otherwise, assume view_name_or_url is the view name for resolver_match.
|
||||
resolver_match_source = view_name_for_resolver if view_name_for_resolver else view_name_or_url
|
||||
|
||||
# If it's a namespaced view name like 'app:view', resolve might handle it directly.
|
||||
# If namespace is separately provided, it means the view_name itself is not namespaced.
|
||||
resolved_match = resolve(url) # Resolve the URL to get func, args, kwargs, etc.
|
||||
|
||||
# Ensure resolver_match has the correct attributes, especially 'view_name' and 'namespace'
|
||||
if hasattr(resolved_match, 'view_name'):
|
||||
if ':' in resolved_match.view_name and not namespace: # e.g. 'app_name:view_name'
|
||||
request.resolver_match = resolved_match
|
||||
elif namespace and resolved_match.namespace == namespace and resolved_match.url_name == resolver_match_source.split(':')[-1]:
|
||||
request.resolver_match = resolved_match
|
||||
elif not namespace and resolved_match.url_name == resolver_match_source:
|
||||
request.resolver_match = resolved_match
|
||||
else: # Fallback or if specific view_name/namespace parts are needed for resolver_match
|
||||
# This part is tricky without knowing the exact structure of resolver_match expected by the tag
|
||||
# Forcing the view_name and namespace if they are explicitly passed.
|
||||
if namespace:
|
||||
resolved_match.namespace = namespace
|
||||
if view_name_for_resolver: # This should be the non-namespaced view name part
|
||||
resolved_match.view_name = f"{namespace}:{view_name_for_resolver.split(':')[-1]}" if namespace else view_name_for_resolver.split(':')[-1]
|
||||
resolved_match.url_name = view_name_for_resolver.split(':')[-1]
|
||||
|
||||
request.resolver_match = resolved_match
|
||||
|
||||
else: # Fallback if resolve() doesn't directly give a full resolver_match object as expected
|
||||
request.resolver_match = None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not resolve URL or set resolver_match for '{view_name_or_url}' (or '{view_name_for_resolver}') for test setup: {e}")
|
||||
request.resolver_match = None
|
||||
return request
|
||||
|
||||
class CommonTestCase(TestCase): # Keep existing test class if other tests depend on it
|
||||
def test_example(self): # Example of an old test
|
||||
self.assertEqual(1 + 1, 2)
|
||||
|
||||
def test_truncate_decimal_function(self): # Example of an old test from problem description
|
||||
test_cases = [
|
||||
(Decimal('123.456'), 0, Decimal('123')),
|
||||
(Decimal('123.456'), 1, Decimal('123.4')),
|
||||
(Decimal('123.456'), 2, Decimal('123.45')),
|
||||
]
|
||||
for value, places, expected in test_cases:
|
||||
with self.subTest(value=value, places=places, expected=expected):
|
||||
self.assertEqual(truncate_decimal(value, places), expected)
|
||||
|
||||
|
||||
class CommonTemplateTagsTests(TestCase):
|
||||
def setUp(self):
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user('testuser', 'password123')
|
||||
|
||||
# Using view names that should exist in a typical Django project with auth
|
||||
# Ensure these URLs are part of your project's urlpatterns for tests to pass.
|
||||
self.view_name_login = 'login' # Typically 'login' or 'account_login'
|
||||
self.namespace_login = None # Often no namespace for basic auth views, or 'account'
|
||||
|
||||
self.view_name_admin = 'admin:index' # Admin index
|
||||
self.namespace_admin = 'admin'
|
||||
|
||||
# Check if these can be reversed, skip tests if not.
|
||||
try:
|
||||
reverse(self.view_name_login)
|
||||
except NoReverseMatch:
|
||||
self.view_name_login = None # Mark as unusable
|
||||
print(f"Warning: Could not reverse '{self.view_name_login}'. Some active_link tests might be skipped.")
|
||||
try:
|
||||
reverse(self.view_name_admin)
|
||||
except NoReverseMatch:
|
||||
self.view_name_admin = None # Mark as unusable
|
||||
print(f"Warning: Could not reverse '{self.view_name_admin}'. Some active_link tests might be skipped.")
|
||||
|
||||
def test_active_link_view_match(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='" + self.view_name_login + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "active")
|
||||
|
||||
def test_active_link_view_no_match(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='non_existent_view_name' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "")
|
||||
|
||||
def test_active_link_view_match_custom_class(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='" + self.view_name_login + "' css_class='custom-active' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "custom-active")
|
||||
|
||||
def test_active_link_view_no_match_inactive_class(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='non_existent_view_name' inactive_class='custom-inactive' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "custom-inactive")
|
||||
|
||||
def test_active_link_namespace_match(self):
|
||||
if not self.view_name_admin: self.skipTest("Admin URL not reversible.")
|
||||
# The view_name_admin is already namespaced 'admin:index'
|
||||
request = setup_request_for_view(self.factory, self.view_name_admin, self.user,
|
||||
namespace=self.namespace_admin, view_name_for_resolver=self.view_name_admin)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_admin}.")
|
||||
# Ensure the resolver_match has the namespace set correctly by setup_request_for_view
|
||||
self.assertEqual(request.resolver_match.namespace, self.namespace_admin, "Namespace not correctly set in resolver_match for test.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link namespaces='" + self.namespace_admin + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "active")
|
||||
|
||||
def test_active_link_multiple_views_one_match(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='other_app:other_view||" + self.view_name_login + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "active")
|
||||
|
||||
def test_active_link_no_request_in_context(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible for placeholder view name.")
|
||||
template_str = "{% load active_link %} {% active_link views='" + self.view_name_login + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({})) # Empty context, no 'request'
|
||||
self.assertEqual(rendered.strip(), "")
|
||||
|
||||
def test_active_link_request_without_resolver_match(self):
|
||||
request = self.factory.get('/some_unresolved_url/') # This URL won't resolve
|
||||
request.user = self.user
|
||||
request.resolver_match = None # Explicitly set to None, as resolve() would fail
|
||||
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible for placeholder view name.")
|
||||
template_str = "{% load active_link %} {% active_link views='" + self.view_name_login + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "")
|
||||
@@ -15,10 +15,11 @@ from cachalot.api import invalidate
|
||||
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
from apps.transactions.models import Transaction
|
||||
from apps.common.decorators.user import htmx_login_required
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@htmx_login_required
|
||||
@require_http_methods(["GET"])
|
||||
def toasts(request):
|
||||
return render(request, "common/fragments/toasts.html")
|
||||
|
||||
@@ -19,6 +19,8 @@ class AirDatePickerInput(widgets.DateInput):
|
||||
format=None,
|
||||
clear_button=True,
|
||||
auto_close=True,
|
||||
read_only=True,
|
||||
toggle_selected=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -26,6 +28,10 @@ class AirDatePickerInput(widgets.DateInput):
|
||||
super().__init__(attrs=attrs, format=format, *args, **kwargs)
|
||||
self.clear_button = clear_button
|
||||
self.auto_close = auto_close
|
||||
self.read_only = read_only
|
||||
self.toggle_selected = (
|
||||
toggle_selected if isinstance(toggle_selected, bool) else self.clear_button
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_current_language():
|
||||
@@ -47,9 +53,13 @@ class AirDatePickerInput(widgets.DateInput):
|
||||
attrs["data-now-button-txt"] = _("Today")
|
||||
attrs["data-auto-close"] = str(self.auto_close).lower()
|
||||
attrs["data-clear-button"] = str(self.clear_button).lower()
|
||||
attrs["data-toggle-selected"] = str(self.toggle_selected).lower()
|
||||
attrs["data-language"] = self._get_current_language()
|
||||
attrs["data-date-format"] = django_to_airdatepicker_datetime(self._get_format())
|
||||
|
||||
if self.read_only:
|
||||
attrs["readonly"] = True
|
||||
|
||||
return attrs
|
||||
|
||||
def format_value(self, value):
|
||||
@@ -89,6 +99,8 @@ class AirDateTimePickerInput(widgets.DateTimeInput):
|
||||
timepicker=True,
|
||||
clear_button=True,
|
||||
auto_close=True,
|
||||
read_only=True,
|
||||
toggle_selected=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -97,6 +109,10 @@ class AirDateTimePickerInput(widgets.DateTimeInput):
|
||||
self.timepicker = timepicker
|
||||
self.clear_button = clear_button
|
||||
self.auto_close = auto_close
|
||||
self.read_only = read_only
|
||||
self.toggle_selected = (
|
||||
toggle_selected if isinstance(toggle_selected, bool) else self.clear_button
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_current_language():
|
||||
@@ -123,11 +139,15 @@ class AirDateTimePickerInput(widgets.DateTimeInput):
|
||||
attrs["data-now-button-txt"] = _("Now")
|
||||
attrs["data-timepicker"] = str(self.timepicker).lower()
|
||||
attrs["data-auto-close"] = str(self.auto_close).lower()
|
||||
attrs["data-toggle-selected"] = str(self.toggle_selected).lower()
|
||||
attrs["data-clear-button"] = str(self.clear_button).lower()
|
||||
attrs["data-language"] = self._get_current_language()
|
||||
attrs["data-date-format"] = date_format
|
||||
attrs["data-time-format"] = time_format
|
||||
|
||||
if self.read_only:
|
||||
attrs["readonly"] = True
|
||||
|
||||
return attrs
|
||||
|
||||
def format_value(self, value):
|
||||
@@ -227,3 +247,56 @@ class AirMonthYearPickerInput(AirDatePickerInput):
|
||||
except (ValueError, KeyError):
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
class AirYearPickerInput(AirDatePickerInput):
|
||||
def __init__(self, attrs=None, format=None, *args, **kwargs):
|
||||
super().__init__(attrs=attrs, format=format, *args, **kwargs)
|
||||
# Store the display format for AirDatepicker
|
||||
self.display_format = "yyyy"
|
||||
# Store the Python format for internal use
|
||||
self.python_format = "%Y"
|
||||
|
||||
def build_attrs(self, base_attrs, extra_attrs=None):
|
||||
attrs = super().build_attrs(base_attrs, extra_attrs)
|
||||
|
||||
# Add data attributes for AirDatepicker configuration
|
||||
attrs["data-now-button-txt"] = _("Today")
|
||||
attrs["data-date-format"] = "yyyy"
|
||||
|
||||
return attrs
|
||||
|
||||
def format_value(self, value):
|
||||
"""Format the value for display in the widget."""
|
||||
if value:
|
||||
self.attrs["data-value"] = (
|
||||
value # We use this to dynamically select the initial date on AirDatePicker
|
||||
)
|
||||
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
value = datetime.datetime.strptime(value, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
return value
|
||||
if isinstance(value, (datetime.datetime, datetime.date)):
|
||||
# Use Django's date translation
|
||||
return f"{value.year}"
|
||||
return value
|
||||
|
||||
def value_from_datadict(self, data, files, name):
|
||||
"""Convert the value from the widget format back to a format Django can handle."""
|
||||
value = super().value_from_datadict(data, files, name)
|
||||
if value:
|
||||
try:
|
||||
# Split the value into month name and year
|
||||
year_str = value
|
||||
year = int(year_str)
|
||||
|
||||
if year:
|
||||
# Return the first day of the month in Django's expected format
|
||||
return datetime.date(year, 1, 1).strftime("%Y-%m-%d")
|
||||
except (ValueError, KeyError):
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from django.forms import widgets, SelectMultiple
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
@@ -129,3 +130,28 @@ class TomSelect(widgets.Select):
|
||||
|
||||
class TomSelectMultiple(SelectMultiple, TomSelect):
|
||||
pass
|
||||
|
||||
|
||||
class TransactionSelect(TomSelect):
|
||||
def __init__(self, income: bool = True, expense: bool = True, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.load_income = income
|
||||
self.load_expense = expense
|
||||
self.create = False
|
||||
|
||||
def build_attrs(self, base_attrs, extra_attrs=None):
|
||||
attrs = super().build_attrs(base_attrs, extra_attrs)
|
||||
|
||||
if self.load_income and self.load_expense:
|
||||
attrs["data-load"] = reverse("transactions_search")
|
||||
elif self.load_income and not self.load_expense:
|
||||
attrs["data-load"] = reverse(
|
||||
"transactions_search", kwargs={"filter_type": "income"}
|
||||
)
|
||||
elif self.load_expense and not self.load_income:
|
||||
attrs["data-load"] = reverse(
|
||||
"transactions_search", kwargs={"filter_type": "expenses"}
|
||||
)
|
||||
|
||||
return attrs
|
||||
|
||||
@@ -6,8 +6,10 @@ from django.utils import timezone
|
||||
|
||||
from apps.currencies.exchange_rates.providers import (
|
||||
SynthFinanceProvider,
|
||||
SynthFinanceStockProvider,
|
||||
CoinGeckoFreeProvider,
|
||||
CoinGeckoProProvider,
|
||||
TransitiveRateProvider,
|
||||
)
|
||||
from apps.currencies.models import ExchangeRateService, ExchangeRate, Currency
|
||||
|
||||
@@ -17,8 +19,10 @@ logger = logging.getLogger(__name__)
|
||||
# Map service types to provider classes
|
||||
PROVIDER_MAPPING = {
|
||||
"synth_finance": SynthFinanceProvider,
|
||||
"synth_finance_stock": SynthFinanceStockProvider,
|
||||
"coingecko_free": CoinGeckoFreeProvider,
|
||||
"coingecko_pro": CoinGeckoProProvider,
|
||||
"transitive": TransitiveRateProvider,
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +35,7 @@ class ExchangeRateFetcher:
|
||||
service.fetch_interval
|
||||
)
|
||||
should_fetch = current_hour not in blocked_hours
|
||||
logger.debug(
|
||||
logger.info(
|
||||
f"NOT_ON check for {service.name}: "
|
||||
f"current_hour={current_hour}, "
|
||||
f"blocked_hours={blocked_hours}, "
|
||||
@@ -43,18 +47,35 @@ class ExchangeRateFetcher:
|
||||
allowed_hours = ExchangeRateService._parse_hour_ranges(
|
||||
service.fetch_interval
|
||||
)
|
||||
return current_hour in allowed_hours
|
||||
|
||||
should_fetch = current_hour in allowed_hours
|
||||
|
||||
logger.info(
|
||||
f"ON check for {service.name}: "
|
||||
f"current_hour={current_hour}, "
|
||||
f"allowed_hours={allowed_hours}, "
|
||||
f"should_fetch={should_fetch}"
|
||||
)
|
||||
|
||||
return should_fetch
|
||||
|
||||
if service.interval_type == ExchangeRateService.IntervalType.EVERY:
|
||||
try:
|
||||
interval_hours = int(service.fetch_interval)
|
||||
|
||||
if service.last_fetch is None:
|
||||
return True
|
||||
hours_since_last = (
|
||||
timezone.now() - service.last_fetch
|
||||
).total_seconds() / 3600
|
||||
|
||||
# Round down to nearest hour
|
||||
now = timezone.now().replace(minute=0, second=0, microsecond=0)
|
||||
last_fetch = service.last_fetch.replace(
|
||||
minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
hours_since_last = (now - last_fetch).total_seconds() / 3600
|
||||
should_fetch = hours_since_last >= interval_hours
|
||||
logger.debug(
|
||||
|
||||
logger.info(
|
||||
f"EVERY check for {service.name}: "
|
||||
f"hours_since_last={hours_since_last:.1f}, "
|
||||
f"interval={interval_hours}, "
|
||||
|
||||
@@ -3,11 +3,11 @@ import time
|
||||
|
||||
import requests
|
||||
from decimal import Decimal
|
||||
from typing import Tuple, List
|
||||
from typing import Tuple, List, Optional, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from apps.currencies.models import Currency
|
||||
from apps.currencies.models import Currency, ExchangeRate
|
||||
from apps.currencies.exchange_rates.base import ExchangeRateProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -150,3 +150,159 @@ class CoinGeckoProProvider(CoinGeckoFreeProvider):
|
||||
super().__init__(api_key)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({"x-cg-pro-api-key": api_key})
|
||||
|
||||
|
||||
class SynthFinanceStockProvider(ExchangeRateProvider):
|
||||
"""Implementation for Synth Finance API Real-Time Prices endpoint (synthfinance.com)"""
|
||||
|
||||
BASE_URL = "https://api.synthfinance.com/tickers"
|
||||
rates_inverted = True
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
super().__init__(api_key)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update(
|
||||
{"Authorization": f"Bearer {self.api_key}", "accept": "application/json"}
|
||||
)
|
||||
|
||||
def get_rates(
|
||||
self, target_currencies: QuerySet, exchange_currencies: set
|
||||
) -> List[Tuple[Currency, Currency, Decimal]]:
|
||||
results = []
|
||||
|
||||
for currency in target_currencies:
|
||||
if currency.exchange_currency not in exchange_currencies:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Same currency has rate of 1
|
||||
if currency.code == currency.exchange_currency.code:
|
||||
rate = Decimal("1")
|
||||
results.append((currency.exchange_currency, currency, rate))
|
||||
continue
|
||||
|
||||
# Fetch real-time price for this ticker
|
||||
response = self.session.get(
|
||||
f"{self.BASE_URL}/{currency.code}/real-time"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Use fair market value as the rate
|
||||
rate = Decimal(data["data"]["fair_market_value"])
|
||||
results.append((currency.exchange_currency, currency, rate))
|
||||
|
||||
# Log API usage
|
||||
credits_used = data["meta"]["credits_used"]
|
||||
credits_remaining = data["meta"]["credits_remaining"]
|
||||
logger.info(
|
||||
f"Synth Finance API call for {currency.code}: {credits_used} credits used, {credits_remaining} remaining"
|
||||
)
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"Error fetching rate from Synth Finance API for ticker {currency.code}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
except KeyError as e:
|
||||
logger.error(
|
||||
f"Unexpected response structure from Synth Finance API for ticker {currency.code}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error processing Synth Finance data for ticker {currency.code}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class TransitiveRateProvider(ExchangeRateProvider):
|
||||
"""Calculates exchange rates through paths of existing rates"""
|
||||
|
||||
rates_inverted = True
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
super().__init__(api_key) # API key not needed but maintaining interface
|
||||
|
||||
@classmethod
|
||||
def requires_api_key(cls) -> bool:
|
||||
return False
|
||||
|
||||
def get_rates(
|
||||
self, target_currencies: QuerySet, exchange_currencies: set
|
||||
) -> List[Tuple[Currency, Currency, Decimal]]:
|
||||
results = []
|
||||
|
||||
# Get recent rates for building the graph
|
||||
recent_rates = ExchangeRate.objects.all()
|
||||
|
||||
# Build currency graph
|
||||
currency_graph = self._build_currency_graph(recent_rates)
|
||||
|
||||
for target in target_currencies:
|
||||
if (
|
||||
not target.exchange_currency
|
||||
or target.exchange_currency not in exchange_currencies
|
||||
):
|
||||
continue
|
||||
|
||||
# Find path and calculate rate
|
||||
from_id = target.exchange_currency.id
|
||||
to_id = target.id
|
||||
|
||||
path, rate = self._find_conversion_path(currency_graph, from_id, to_id)
|
||||
|
||||
if path and rate:
|
||||
path_codes = [Currency.objects.get(id=cid).code for cid in path]
|
||||
logger.info(
|
||||
f"Found conversion path: {' -> '.join(path_codes)}, rate: {rate}"
|
||||
)
|
||||
results.append((target.exchange_currency, target, rate))
|
||||
else:
|
||||
logger.debug(
|
||||
f"No conversion path found for {target.exchange_currency.code}->{target.code}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _build_currency_graph(rates) -> Dict[int, Dict[int, Decimal]]:
|
||||
"""Build a graph representation of currency relationships"""
|
||||
graph = {}
|
||||
|
||||
for rate in rates:
|
||||
# Add both directions to make the graph bidirectional
|
||||
if rate.from_currency_id not in graph:
|
||||
graph[rate.from_currency_id] = {}
|
||||
graph[rate.from_currency_id][rate.to_currency_id] = rate.rate
|
||||
|
||||
if rate.to_currency_id not in graph:
|
||||
graph[rate.to_currency_id] = {}
|
||||
graph[rate.to_currency_id][rate.from_currency_id] = Decimal("1") / rate.rate
|
||||
|
||||
return graph
|
||||
|
||||
@staticmethod
|
||||
def _find_conversion_path(
|
||||
graph, from_id, to_id
|
||||
) -> Tuple[Optional[list], Optional[Decimal]]:
|
||||
"""Find the shortest path between currencies using breadth-first search"""
|
||||
if from_id not in graph or to_id not in graph:
|
||||
return None, None
|
||||
|
||||
queue = [(from_id, [from_id], Decimal("1"))]
|
||||
visited = {from_id}
|
||||
|
||||
while queue:
|
||||
current, path, current_rate = queue.pop(0)
|
||||
|
||||
if current == to_id:
|
||||
return path, current_rate
|
||||
|
||||
for neighbor, rate in graph.get(current, {}).items():
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append((neighbor, path + [neighbor], current_rate * rate))
|
||||
|
||||
return None, None
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.1.6 on 2025-03-02 01:28
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0011_remove_exchangerateservice_fetch_interval_hours_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='exchangerateservice',
|
||||
name='service_type',
|
||||
field=models.CharField(choices=[('synth_finance', 'Synth Finance'), ('synth_finance_stock', 'Synth Finance Stock'), ('coingecko_free', 'CoinGecko (Demo/Free)'), ('coingecko_pro', 'CoinGecko (Pro)')], max_length=255, verbose_name='Service Type'),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.1.6 on 2025-03-02 01:58
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0012_alter_exchangerateservice_service_type'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='exchangerateservice',
|
||||
name='service_type',
|
||||
field=models.CharField(choices=[('synth_finance', 'Synth Finance'), ('synth_finance_stock', 'Synth Finance Stock'), ('coingecko_free', 'CoinGecko (Demo/Free)'), ('coingecko_pro', 'CoinGecko (Pro)'), ('transitive', 'Transitive (Calculated from Existing Rates)')], max_length=255, verbose_name='Service Type'),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,17 @@
|
||||
# Generated by Django 5.1.7 on 2025-03-09 21:54
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0013_alter_exchangerateservice_service_type'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelOptions(
|
||||
name='currency',
|
||||
options={'ordering': ['name', 'id'], 'verbose_name': 'Currency', 'verbose_name_plural': 'Currencies'},
|
||||
),
|
||||
]
|
||||
@@ -38,6 +38,7 @@ class Currency(models.Model):
|
||||
class Meta:
|
||||
verbose_name = _("Currency")
|
||||
verbose_name_plural = _("Currencies")
|
||||
ordering = ["name", "id"]
|
||||
|
||||
def clean(self):
|
||||
super().clean()
|
||||
@@ -92,8 +93,10 @@ class ExchangeRateService(models.Model):
|
||||
|
||||
class ServiceType(models.TextChoices):
|
||||
SYNTH_FINANCE = "synth_finance", "Synth Finance"
|
||||
SYNTH_FINANCE_STOCK = "synth_finance_stock", "Synth Finance Stock"
|
||||
COINGECKO_FREE = "coingecko_free", "CoinGecko (Demo/Free)"
|
||||
COINGECKO_PRO = "coingecko_pro", "CoinGecko (Pro)"
|
||||
TRANSITIVE = "transitive", "Transitive (Calculated from Existing Rates)"
|
||||
|
||||
class IntervalType(models.TextChoices):
|
||||
ON = "on", _("On")
|
||||
@@ -204,11 +207,11 @@ class ExchangeRateService(models.Model):
|
||||
}
|
||||
)
|
||||
hours = int(self.fetch_interval)
|
||||
if hours < 0 or hours > 23:
|
||||
if hours < 1 or hours > 24:
|
||||
raise ValidationError(
|
||||
{
|
||||
"fetch_interval": _(
|
||||
"'Every X hours' interval must be between 0 and 23."
|
||||
"'Every X hours' interval must be between 1 and 24."
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -4,8 +4,12 @@ from django.core.exceptions import ValidationError
|
||||
from django.db import IntegrityError
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
from django.contrib.auth.models import User # Added for ERS owner
|
||||
from datetime import date # Added for CurrencyConversionUtilsTests
|
||||
from apps.currencies.utils.convert import get_exchange_rate, convert # Added convert
|
||||
from unittest.mock import patch # Added patch
|
||||
|
||||
from apps.currencies.models import Currency, ExchangeRate
|
||||
from apps.currencies.models import Currency, ExchangeRate, ExchangeRateService
|
||||
|
||||
|
||||
class CurrencyTests(TestCase):
|
||||
@@ -52,6 +56,163 @@ class CurrencyTests(TestCase):
|
||||
with self.assertRaises(IntegrityError):
|
||||
Currency.objects.create(code="USD2", name="US Dollar", decimal_places=2)
|
||||
|
||||
def test_currency_exchange_currency_cannot_be_self(self):
|
||||
"""Test that a currency's exchange_currency cannot be itself."""
|
||||
currency = Currency.objects.create(
|
||||
code="XYZ", name="Test XYZ", decimal_places=2
|
||||
)
|
||||
currency.exchange_currency = currency # Set exchange_currency to self
|
||||
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
currency.full_clean()
|
||||
|
||||
self.assertIn('exchange_currency', cm.exception.error_dict)
|
||||
# Optionally, check for a specific error message if known:
|
||||
# self.assertTrue(any("cannot be the same as the currency itself" in e.message
|
||||
# for e in cm.exception.error_dict['exchange_currency']))
|
||||
|
||||
|
||||
class ExchangeRateServiceTests(TestCase):
|
||||
def setUp(self):
|
||||
self.owner = User.objects.create_user(username='ers_owner', password='password123')
|
||||
self.base_currency = Currency.objects.create(code="BSC", name="Base Service Coin", decimal_places=2)
|
||||
self.default_ers_params = {
|
||||
'name': "Test ERS",
|
||||
'owner': self.owner,
|
||||
'base_currency': self.base_currency,
|
||||
'provider_class': "dummy.provider.ClassName", # Placeholder
|
||||
}
|
||||
|
||||
def _create_ers_instance(self, interval_type, fetch_interval, **kwargs):
|
||||
params = {**self.default_ers_params, 'interval_type': interval_type, 'fetch_interval': fetch_interval, **kwargs}
|
||||
return ExchangeRateService(**params)
|
||||
|
||||
# Tests for IntervalType.EVERY
|
||||
def test_ers_interval_every_valid_integer(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.EVERY, "12")
|
||||
try:
|
||||
ers.full_clean()
|
||||
except ValidationError:
|
||||
self.fail("ValidationError raised unexpectedly for valid 'EVERY' interval '12'.")
|
||||
|
||||
def test_ers_interval_every_invalid_not_integer(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.EVERY, "abc")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_every_invalid_too_low(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.EVERY, "0")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_every_invalid_too_high(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.EVERY, "25") # Max is 24 for 'EVERY'
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
# Tests for IntervalType.ON (and by extension NOT_ON, as validation logic is shared)
|
||||
def test_ers_interval_on_not_on_valid_single_hour(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "5")
|
||||
try:
|
||||
ers.full_clean() # Should normalize to "5" if not already
|
||||
except ValidationError:
|
||||
self.fail("ValidationError raised unexpectedly for valid 'ON' interval '5'.")
|
||||
self.assertEqual(ers.fetch_interval, "5")
|
||||
|
||||
|
||||
def test_ers_interval_on_not_on_valid_multiple_hours(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "1,8,22")
|
||||
try:
|
||||
ers.full_clean()
|
||||
except ValidationError:
|
||||
self.fail("ValidationError raised unexpectedly for valid 'ON' interval '1,8,22'.")
|
||||
self.assertEqual(ers.fetch_interval, "1,8,22")
|
||||
|
||||
|
||||
def test_ers_interval_on_not_on_valid_range(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "0-4")
|
||||
ers.full_clean() # Should not raise ValidationError
|
||||
self.assertEqual(ers.fetch_interval, "0,1,2,3,4")
|
||||
|
||||
def test_ers_interval_on_not_on_valid_mixed(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "1-3,8,10-12")
|
||||
ers.full_clean() # Should not raise ValidationError
|
||||
self.assertEqual(ers.fetch_interval, "1,2,3,8,10,11,12")
|
||||
|
||||
def test_ers_interval_on_not_on_invalid_char(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "1-3,a")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_on_not_on_invalid_hour_too_high(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "24") # Max is 23 for 'ON' type hours
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_on_not_on_invalid_range_format(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "5-1")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_on_not_on_invalid_range_value_too_high(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "20-24") # 24 is invalid hour
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_on_not_on_empty_interval(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
@patch('apps.currencies.exchange_rates.fetcher.PROVIDER_MAPPING')
|
||||
def test_get_provider_valid_service_type(self, mock_provider_mapping):
|
||||
"""Test get_provider returns a configured provider instance for a valid service_type."""
|
||||
|
||||
class MockSynthFinanceProvider:
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
|
||||
# Configure the mock PROVIDER_MAPPING
|
||||
mock_provider_mapping.get.return_value = MockSynthFinanceProvider
|
||||
|
||||
service_instance = self._create_ers_instance(
|
||||
interval_type=ExchangeRateService.IntervalType.EVERY, # Needs some valid interval type
|
||||
fetch_interval="1", # Needs some valid fetch interval
|
||||
service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE,
|
||||
api_key="test_key"
|
||||
)
|
||||
# Ensure the service_type is correctly passed to the mock
|
||||
# The actual get_provider method uses PROVIDER_MAPPING[self.service_type]
|
||||
# So, we should make the mock_provider_mapping behave like a dict for the specific key
|
||||
mock_provider_mapping = {ExchangeRateService.ServiceType.SYNTH_FINANCE: MockSynthFinanceProvider}
|
||||
|
||||
with patch('apps.currencies.exchange_rates.fetcher.PROVIDER_MAPPING', mock_provider_mapping):
|
||||
provider = service_instance.get_provider()
|
||||
|
||||
self.assertIsInstance(provider, MockSynthFinanceProvider)
|
||||
self.assertEqual(provider.key, "test_key")
|
||||
|
||||
@patch('apps.currencies.exchange_rates.fetcher.PROVIDER_MAPPING', {}) # Empty mapping
|
||||
def test_get_provider_invalid_service_type(self, mock_provider_mapping_empty):
|
||||
"""Test get_provider raises KeyError for an invalid or unmapped service_type."""
|
||||
service_instance = self._create_ers_instance(
|
||||
interval_type=ExchangeRateService.IntervalType.EVERY,
|
||||
fetch_interval="1",
|
||||
service_type="UNMAPPED_SERVICE_TYPE", # A type not in the (mocked) mapping
|
||||
api_key="any_key"
|
||||
)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
service_instance.get_provider()
|
||||
|
||||
|
||||
class ExchangeRateTests(TestCase):
|
||||
def setUp(self):
|
||||
@@ -83,10 +244,169 @@ class ExchangeRateTests(TestCase):
|
||||
rate=Decimal("0.85"),
|
||||
date=date,
|
||||
)
|
||||
with self.assertRaises(Exception): # Could be IntegrityError
|
||||
with self.assertRaises(IntegrityError):
|
||||
ExchangeRate.objects.create(
|
||||
from_currency=self.usd,
|
||||
to_currency=self.eur,
|
||||
rate=Decimal("0.86"),
|
||||
date=date,
|
||||
)
|
||||
|
||||
def test_from_and_to_currency_cannot_be_same(self):
|
||||
"""Test that from_currency and to_currency cannot be the same."""
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
rate = ExchangeRate(
|
||||
from_currency=self.usd,
|
||||
to_currency=self.usd, # Same as from_currency
|
||||
rate=Decimal("1.00"),
|
||||
date=timezone.now().date(),
|
||||
)
|
||||
rate.full_clean()
|
||||
|
||||
# Check if the error message is as expected or if the error is associated with a specific field.
|
||||
# The exact key ('to_currency' or '__all__') depends on how the model's clean() method is implemented.
|
||||
# Assuming the validation error is raised with a message like "From and to currency cannot be the same."
|
||||
# and is a non-field error or specifically tied to 'to_currency'.
|
||||
self.assertTrue(
|
||||
'__all__' in cm.exception.error_dict or 'to_currency' in cm.exception.error_dict,
|
||||
"ValidationError should be for '__all__' or 'to_currency'"
|
||||
)
|
||||
# Optionally, check for a specific message if it's consistent:
|
||||
# found_message = False
|
||||
# if '__all__' in cm.exception.error_dict:
|
||||
# found_message = any("cannot be the same" in e.message for e in cm.exception.error_dict['__all__'])
|
||||
# if not found_message and 'to_currency' in cm.exception.error_dict:
|
||||
# found_message = any("cannot be the same" in e.message for e in cm.exception.error_dict['to_currency'])
|
||||
# self.assertTrue(found_message, "Error message about currencies being the same not found.")
|
||||
|
||||
|
||||
class CurrencyConversionUtilsTests(TestCase):
|
||||
def setUp(self):
|
||||
self.usd = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2, prefix="$", suffix="")
|
||||
self.eur = Currency.objects.create(code="EUR", name="Euro", decimal_places=2, prefix="€", suffix="")
|
||||
self.gbp = Currency.objects.create(code="GBP", name="British Pound", decimal_places=2, prefix="£", suffix="")
|
||||
|
||||
# Rates for USD <-> EUR
|
||||
self.usd_eur_rate_10th = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.90"), date=date(2023, 1, 10))
|
||||
self.usd_eur_rate_15th = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.92"), date=date(2023, 1, 15))
|
||||
ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.88"), date=date(2023, 1, 5))
|
||||
|
||||
# Rate for GBP <-> USD (for inverse lookup)
|
||||
self.gbp_usd_rate_10th = ExchangeRate.objects.create(from_currency=self.gbp, to_currency=self.usd, rate=Decimal("1.25"), date=date(2023, 1, 10))
|
||||
|
||||
def test_get_direct_rate_closest_date(self):
|
||||
"""Test fetching a direct rate, ensuring the closest date is chosen."""
|
||||
result = get_exchange_rate(self.usd, self.eur, date(2023, 1, 16))
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.effective_rate, Decimal("0.92"))
|
||||
self.assertEqual(result.original_from_currency, self.usd)
|
||||
self.assertEqual(result.original_to_currency, self.eur)
|
||||
|
||||
def test_get_inverse_rate_closest_date(self):
|
||||
"""Test fetching an inverse rate, ensuring the closest date and correct calculation."""
|
||||
# We are looking for USD to GBP. We have GBP to USD on 2023-01-10.
|
||||
# Target date is 2023-01-12.
|
||||
result = get_exchange_rate(self.usd, self.gbp, date(2023, 1, 12))
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.effective_rate, Decimal("1") / self.gbp_usd_rate_10th.rate)
|
||||
self.assertEqual(result.original_from_currency, self.gbp) # original_from_currency should be GBP
|
||||
self.assertEqual(result.original_to_currency, self.usd) # original_to_currency should be USD
|
||||
|
||||
def test_get_rate_exact_date_preference(self):
|
||||
"""Test that an exact date match is preferred over a closer one."""
|
||||
# Existing rate is on 2023-01-15 (0.92)
|
||||
# Add an exact match for the query date
|
||||
exact_date_rate = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.91"), date=date(2023, 1, 16))
|
||||
|
||||
result = get_exchange_rate(self.usd, self.eur, date(2023, 1, 16))
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.effective_rate, Decimal("0.91"))
|
||||
self.assertEqual(result.original_from_currency, self.usd)
|
||||
self.assertEqual(result.original_to_currency, self.eur)
|
||||
|
||||
def test_get_rate_no_matching_pair(self):
|
||||
"""Test that None is returned if no direct or inverse rate exists between the pair."""
|
||||
# No rates exist for EUR <-> GBP in the setUp
|
||||
result = get_exchange_rate(self.eur, self.gbp, date(2023, 1, 10))
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_get_rate_prefer_direct_over_inverse_same_diff(self):
|
||||
"""Test that a direct rate is preferred over an inverse if date differences are equal."""
|
||||
# We have GBP-USD on 2023-01-10 (self.gbp_usd_rate_10th)
|
||||
# This means an inverse USD-GBP rate is available for 2023-01-10.
|
||||
# Add a direct USD-GBP rate for the same date.
|
||||
direct_usd_gbp_rate = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.gbp, rate=Decimal("0.80"), date=date(2023, 1, 10))
|
||||
|
||||
result = get_exchange_rate(self.usd, self.gbp, date(2023, 1, 10))
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.effective_rate, Decimal("0.80"))
|
||||
self.assertEqual(result.original_from_currency, self.usd)
|
||||
self.assertEqual(result.original_to_currency, self.gbp)
|
||||
|
||||
# Now test the EUR to USD case from the problem description
|
||||
# Add EUR to USD, rate 1.1, date 2023-01-10
|
||||
eur_usd_direct_rate = ExchangeRate.objects.create(from_currency=self.eur, to_currency=self.usd, rate=Decimal("1.1"), date=date(2023, 1, 10))
|
||||
# We also have USD to EUR on 2023-01-10 (rate 0.90), which would be an inverse match for EUR to USD.
|
||||
|
||||
result_eur_usd = get_exchange_rate(self.eur, self.usd, date(2023, 1, 10))
|
||||
self.assertIsNotNone(result_eur_usd)
|
||||
self.assertEqual(result_eur_usd.effective_rate, Decimal("1.1"))
|
||||
self.assertEqual(result_eur_usd.original_from_currency, self.eur)
|
||||
self.assertEqual(result_eur_usd.original_to_currency, self.usd)
|
||||
|
||||
def test_convert_successful_direct(self):
|
||||
"""Test successful conversion using a direct rate."""
|
||||
# Uses self.usd_eur_rate_15th (0.92) as it's closest to 2023-01-16
|
||||
converted_amount, prefix, suffix, dp = convert(Decimal('100'), self.usd, self.eur, date(2023, 1, 16))
|
||||
self.assertEqual(converted_amount, Decimal('92.00'))
|
||||
self.assertEqual(prefix, self.eur.prefix)
|
||||
self.assertEqual(suffix, self.eur.suffix)
|
||||
self.assertEqual(dp, self.eur.decimal_places)
|
||||
|
||||
def test_convert_successful_inverse(self):
|
||||
"""Test successful conversion using an inverse rate."""
|
||||
# Uses self.gbp_usd_rate_10th (GBP to USD @ 1.25), so USD to GBP is 1/1.25 = 0.8
|
||||
# Target date 2023-01-12, closest is 2023-01-10
|
||||
converted_amount, prefix, suffix, dp = convert(Decimal('100'), self.usd, self.gbp, date(2023, 1, 12))
|
||||
expected_amount = Decimal('100') * (Decimal('1') / self.gbp_usd_rate_10th.rate)
|
||||
self.assertEqual(converted_amount, expected_amount.quantize(Decimal('0.01')))
|
||||
self.assertEqual(prefix, self.gbp.prefix)
|
||||
self.assertEqual(suffix, self.gbp.suffix)
|
||||
self.assertEqual(dp, self.gbp.decimal_places)
|
||||
|
||||
def test_convert_no_rate_found(self):
|
||||
"""Test conversion when no exchange rate is found."""
|
||||
result_tuple = convert(Decimal('100'), self.eur, self.gbp, date(2023, 1, 10))
|
||||
self.assertEqual(result_tuple, (None, None, None, None))
|
||||
|
||||
def test_convert_same_currency(self):
|
||||
"""Test conversion when from_currency and to_currency are the same."""
|
||||
result_tuple = convert(Decimal('100'), self.usd, self.usd, date(2023, 1, 10))
|
||||
self.assertEqual(result_tuple, (None, None, None, None))
|
||||
|
||||
def test_convert_zero_amount(self):
|
||||
"""Test conversion when the amount is zero."""
|
||||
result_tuple = convert(Decimal('0'), self.usd, self.eur, date(2023, 1, 10))
|
||||
self.assertEqual(result_tuple, (None, None, None, None))
|
||||
|
||||
@patch('apps.currencies.utils.convert.timezone')
|
||||
def test_convert_no_date_uses_today(self, mock_timezone):
|
||||
"""Test conversion uses today's date when no date is provided."""
|
||||
# Mock timezone.now().date() to return a specific date
|
||||
mock_today = date(2023, 1, 16)
|
||||
mock_timezone.now.return_value.date.return_value = mock_today
|
||||
|
||||
# This should use self.usd_eur_rate_15th (0.92) as it's closest to mocked "today" (2023-01-16)
|
||||
converted_amount, prefix, suffix, dp = convert(Decimal('100'), self.usd, self.eur)
|
||||
|
||||
self.assertEqual(converted_amount, Decimal('92.00'))
|
||||
self.assertEqual(prefix, self.eur.prefix)
|
||||
self.assertEqual(suffix, self.eur.suffix)
|
||||
self.assertEqual(dp, self.eur.decimal_places)
|
||||
|
||||
# Verify that timezone.now().date() was called (indirectly, by get_exchange_rate)
|
||||
# This specific assertion for get_exchange_rate being called with a specific date
|
||||
# would require patching get_exchange_rate itself, which is more complex.
|
||||
# For now, we rely on the correct outcome given the mocked date.
|
||||
# A more direct way to test date passing is if convert took get_exchange_rate as a dependency.
|
||||
mock_timezone.now.return_value.date.assert_called_once()
|
||||
|
||||
@@ -11,9 +11,11 @@ from apps.common.decorators.htmx import only_htmx
|
||||
from apps.currencies.forms import ExchangeRateForm, ExchangeRateServiceForm
|
||||
from apps.currencies.models import ExchangeRate, ExchangeRateService
|
||||
from apps.currencies.tasks import manual_fetch_exchange_rates
|
||||
from apps.common.decorators.demo import disabled_on_demo
|
||||
|
||||
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET"])
|
||||
def exchange_rates_services_index(request):
|
||||
return render(
|
||||
@@ -24,6 +26,7 @@ def exchange_rates_services_index(request):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET"])
|
||||
def exchange_rates_services_list(request):
|
||||
services = ExchangeRateService.objects.all()
|
||||
@@ -37,6 +40,7 @@ def exchange_rates_services_list(request):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def exchange_rate_service_add(request):
|
||||
if request.method == "POST":
|
||||
@@ -63,6 +67,7 @@ def exchange_rate_service_add(request):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def exchange_rate_service_edit(request, pk):
|
||||
service = get_object_or_404(ExchangeRateService, id=pk)
|
||||
@@ -91,6 +96,7 @@ def exchange_rate_service_edit(request, pk):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["DELETE"])
|
||||
def exchange_rate_service_delete(request, pk):
|
||||
service = get_object_or_404(ExchangeRateService, id=pk)
|
||||
@@ -109,6 +115,7 @@ def exchange_rate_service_delete(request, pk):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET"])
|
||||
def exchange_rate_service_force_fetch(request):
|
||||
manual_fetch_exchange_rates.defer()
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
from django.contrib import admin
|
||||
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.common.admin import SharedObjectModelAdmin
|
||||
|
||||
|
||||
# Register your models here.
|
||||
admin.site.register(DCAStrategy)
|
||||
admin.site.register(DCAEntry)
|
||||
|
||||
|
||||
@admin.register(DCAStrategy)
|
||||
class TransactionEntityModelAdmin(SharedObjectModelAdmin):
|
||||
def get_queryset(self, request):
|
||||
return DCAStrategy.all_objects.all()
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_bootstrap5.bootstrap5 import Switch, BS5Accordion
|
||||
from crispy_forms.bootstrap import FormActions, AccordionGroup
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Row, Column
|
||||
from crispy_forms.layout import Layout, Row, Column, HTML
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.datepicker import AirDatePickerInput
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.common.widgets.tom_select import TransactionSelect
|
||||
from apps.transactions.models import Transaction, TransactionTag, TransactionCategory
|
||||
from apps.common.fields.forms.dynamic_select import (
|
||||
DynamicModelChoiceField,
|
||||
DynamicModelMultipleChoiceField,
|
||||
)
|
||||
|
||||
|
||||
class DCAStrategyForm(forms.ModelForm):
|
||||
@@ -53,6 +61,75 @@ class DCAStrategyForm(forms.ModelForm):
|
||||
|
||||
|
||||
class DCAEntryForm(forms.ModelForm):
|
||||
create_transaction = forms.BooleanField(
|
||||
label=_("Create transaction"), initial=False, required=False
|
||||
)
|
||||
|
||||
from_account = forms.ModelChoiceField(
|
||||
queryset=Account.objects.filter(is_archived=False),
|
||||
label=_("From Account"),
|
||||
widget=TomSelect(clear_button=False, group_by="group"),
|
||||
required=False,
|
||||
)
|
||||
to_account = forms.ModelChoiceField(
|
||||
queryset=Account.objects.filter(is_archived=False),
|
||||
label=_("To Account"),
|
||||
widget=TomSelect(clear_button=False, group_by="group"),
|
||||
required=False,
|
||||
)
|
||||
|
||||
from_category = DynamicModelChoiceField(
|
||||
create_field="name",
|
||||
model=TransactionCategory,
|
||||
required=False,
|
||||
label=_("Category"),
|
||||
queryset=TransactionCategory.objects.filter(active=True),
|
||||
)
|
||||
to_category = DynamicModelChoiceField(
|
||||
create_field="name",
|
||||
model=TransactionCategory,
|
||||
required=False,
|
||||
label=_("Category"),
|
||||
queryset=TransactionCategory.objects.filter(active=True),
|
||||
)
|
||||
|
||||
from_tags = DynamicModelMultipleChoiceField(
|
||||
model=TransactionTag,
|
||||
to_field_name="name",
|
||||
create_field="name",
|
||||
required=False,
|
||||
label=_("Tags"),
|
||||
queryset=TransactionTag.objects.filter(active=True),
|
||||
)
|
||||
to_tags = DynamicModelMultipleChoiceField(
|
||||
model=TransactionTag,
|
||||
to_field_name="name",
|
||||
create_field="name",
|
||||
required=False,
|
||||
label=_("Tags"),
|
||||
queryset=TransactionTag.objects.filter(active=True),
|
||||
)
|
||||
|
||||
expense_transaction = DynamicModelChoiceField(
|
||||
model=Transaction,
|
||||
to_field_name="id",
|
||||
label=_("Expense Transaction"),
|
||||
required=False,
|
||||
queryset=Transaction.objects.none(),
|
||||
widget=TransactionSelect(clear_button=True, income=False, expense=True),
|
||||
help_text=_("Type to search for a transaction to link to this entry"),
|
||||
)
|
||||
|
||||
income_transaction = DynamicModelChoiceField(
|
||||
model=Transaction,
|
||||
to_field_name="id",
|
||||
label=_("Income Transaction"),
|
||||
required=False,
|
||||
queryset=Transaction.objects.none(),
|
||||
widget=TransactionSelect(clear_button=True, income=True, expense=False),
|
||||
help_text=_("Type to search for a transaction to link to this entry"),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = DCAEntry
|
||||
fields = [
|
||||
@@ -60,13 +137,19 @@ class DCAEntryForm(forms.ModelForm):
|
||||
"amount_paid",
|
||||
"amount_received",
|
||||
"notes",
|
||||
"expense_transaction",
|
||||
"income_transaction",
|
||||
]
|
||||
widgets = {
|
||||
"notes": forms.Textarea(attrs={"rows": 3}),
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
strategy = kwargs.pop("strategy", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.strategy = strategy if strategy else self.instance.strategy
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.layout = Layout(
|
||||
@@ -75,18 +158,66 @@ class DCAEntryForm(forms.ModelForm):
|
||||
Column("amount_paid", css_class="form-group col-md-6"),
|
||||
Column("amount_received", css_class="form-group col-md-6"),
|
||||
),
|
||||
Row(
|
||||
Column("expense_transaction", css_class="form-group col-md-6"),
|
||||
Column("income_transaction", css_class="form-group col-md-6"),
|
||||
),
|
||||
"notes",
|
||||
BS5Accordion(
|
||||
AccordionGroup(
|
||||
_("Create transaction"),
|
||||
Switch("create_transaction"),
|
||||
Row(
|
||||
Column(
|
||||
Row(
|
||||
Column(
|
||||
"from_account",
|
||||
css_class="form-group",
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
"from_category",
|
||||
css_class="form-group col-md-6 mb-0",
|
||||
),
|
||||
Column(
|
||||
"from_tags", css_class="form-group col-md-6 mb-0"
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
),
|
||||
css_class="p-1 mx-1 my-3 border rounded-3",
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Row(
|
||||
Column(
|
||||
"to_account",
|
||||
css_class="form-group",
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
"to_category", css_class="form-group col-md-6 mb-0"
|
||||
),
|
||||
Column("to_tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
),
|
||||
),
|
||||
css_class="p-1 mx-1 my-3 border rounded-3",
|
||||
),
|
||||
active=False,
|
||||
),
|
||||
AccordionGroup(
|
||||
_("Link transaction"),
|
||||
"income_transaction",
|
||||
"expense_transaction",
|
||||
),
|
||||
flush=False,
|
||||
always_open=False,
|
||||
css_class="mb-3",
|
||||
),
|
||||
)
|
||||
|
||||
if self.instance and self.instance.pk:
|
||||
# decimal_places = self.instance.account.currency.decimal_places
|
||||
# self.fields["amount"].widget = ArbitraryDecimalDisplayNumberInput(
|
||||
# decimal_places=decimal_places
|
||||
# )
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
@@ -95,7 +226,6 @@ class DCAEntryForm(forms.ModelForm):
|
||||
),
|
||||
)
|
||||
else:
|
||||
# self.fields["amount"].widget = ArbitraryDecimalDisplayNumberInput()
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
@@ -107,3 +237,136 @@ class DCAEntryForm(forms.ModelForm):
|
||||
self.fields["amount_paid"].widget = ArbitraryDecimalDisplayNumberInput()
|
||||
self.fields["amount_received"].widget = ArbitraryDecimalDisplayNumberInput()
|
||||
self.fields["date"].widget = AirDatePickerInput(clear_button=False)
|
||||
|
||||
expense_transaction = None
|
||||
income_transaction = None
|
||||
if self.instance and self.instance.pk:
|
||||
# Edit mode - get from instance
|
||||
expense_transaction = self.instance.expense_transaction
|
||||
income_transaction = self.instance.income_transaction
|
||||
elif self.data.get("expense_transaction"):
|
||||
# Form validation - get from submitted data
|
||||
try:
|
||||
expense_transaction = Transaction.objects.get(
|
||||
id=self.data["expense_transaction"]
|
||||
)
|
||||
income_transaction = Transaction.objects.get(
|
||||
id=self.data["income_transaction"]
|
||||
)
|
||||
except Transaction.DoesNotExist:
|
||||
pass
|
||||
|
||||
# If we have a current transaction, ensure it's in the queryset
|
||||
if income_transaction:
|
||||
self.fields["income_transaction"].queryset = Transaction.objects.filter(
|
||||
id=income_transaction.id
|
||||
)
|
||||
if expense_transaction:
|
||||
self.fields["expense_transaction"].queryset = Transaction.objects.filter(
|
||||
id=expense_transaction.id
|
||||
)
|
||||
|
||||
self.fields["from_account"].queryset = Account.objects.filter(
|
||||
is_archived=False,
|
||||
)
|
||||
|
||||
self.fields["from_category"].queryset = TransactionCategory.objects.filter(
|
||||
active=True
|
||||
)
|
||||
self.fields["from_tags"].queryset = TransactionTag.objects.filter(active=True)
|
||||
|
||||
self.fields["to_account"].queryset = Account.objects.filter(
|
||||
is_archived=False,
|
||||
)
|
||||
|
||||
self.fields["to_category"].queryset = TransactionCategory.objects.filter(
|
||||
active=True
|
||||
)
|
||||
self.fields["to_tags"].queryset = TransactionTag.objects.filter(active=True)
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
|
||||
if cleaned_data.get("create_transaction"):
|
||||
from_account = cleaned_data.get("from_account")
|
||||
to_account = cleaned_data.get("to_account")
|
||||
|
||||
if not from_account and not to_account:
|
||||
raise forms.ValidationError(
|
||||
{
|
||||
"from_account": _("You must provide an account."),
|
||||
"to_account": _("You must provide an account."),
|
||||
}
|
||||
)
|
||||
elif not from_account and to_account:
|
||||
raise forms.ValidationError(
|
||||
{"from_account": _("You must provide an account.")}
|
||||
)
|
||||
elif not to_account and from_account:
|
||||
raise forms.ValidationError(
|
||||
{"to_account": _("You must provide an account.")}
|
||||
)
|
||||
|
||||
if from_account == to_account:
|
||||
raise forms.ValidationError(
|
||||
_("From and To accounts must be different.")
|
||||
)
|
||||
|
||||
return cleaned_data
|
||||
|
||||
def save(self, **kwargs):
|
||||
instance = super().save(commit=False)
|
||||
|
||||
if self.cleaned_data.get("create_transaction"):
|
||||
from_account = self.cleaned_data["from_account"]
|
||||
to_account = self.cleaned_data["to_account"]
|
||||
from_amount = instance.amount_paid
|
||||
to_amount = instance.amount_received
|
||||
date = instance.date
|
||||
description = _("DCA for %(strategy_name)s") % {
|
||||
"strategy_name": self.strategy.name
|
||||
}
|
||||
from_category = self.cleaned_data.get("from_category")
|
||||
to_category = self.cleaned_data.get("to_category")
|
||||
notes = self.cleaned_data.get("notes")
|
||||
|
||||
# Create "From" transaction
|
||||
from_transaction = Transaction.objects.create(
|
||||
account=from_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
is_paid=True,
|
||||
date=date,
|
||||
amount=from_amount,
|
||||
description=description,
|
||||
category=from_category,
|
||||
notes=notes,
|
||||
)
|
||||
from_transaction.tags.set(self.cleaned_data.get("from_tags", []))
|
||||
|
||||
# Create "To" transaction
|
||||
to_transaction = Transaction.objects.create(
|
||||
account=to_account,
|
||||
type=Transaction.Type.INCOME,
|
||||
is_paid=True,
|
||||
date=date,
|
||||
amount=to_amount,
|
||||
description=description,
|
||||
category=to_category,
|
||||
notes=notes,
|
||||
)
|
||||
to_transaction.tags.set(self.cleaned_data.get("to_tags", []))
|
||||
|
||||
instance.expense_transaction = from_transaction
|
||||
instance.income_transaction = to_transaction
|
||||
else:
|
||||
if instance.expense_transaction:
|
||||
instance.expense_transaction.amount = instance.amount_paid
|
||||
instance.expense_transaction.save()
|
||||
if instance.income_transaction:
|
||||
instance.income_transaction.amount = instance.amount_received
|
||||
instance.income_transaction.save()
|
||||
|
||||
instance.strategy = self.strategy
|
||||
instance.save()
|
||||
|
||||
return instance
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
# Generated by Django 5.1.6 on 2025-03-07 18:20
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('dca', '0002_alter_dcaentry_amount_paid_and_more'),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='dcastrategy',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='%(class)s_owned', to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='dcastrategy',
|
||||
name='shared_with',
|
||||
field=models.ManyToManyField(blank=True, related_name='%(class)s_shared', to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='dcastrategy',
|
||||
name='visibility',
|
||||
field=models.CharField(choices=[('private', 'Private'), ('public', 'Public')], default='private', max_length=10),
|
||||
),
|
||||
]
|
||||
@@ -1,16 +1,15 @@
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from statistics import mean, stdev
|
||||
|
||||
from django.db import models
|
||||
from django.template.defaultfilters import date
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.models import SharedObject, SharedObjectManager
|
||||
from apps.currencies.utils.convert import convert, get_exchange_rate
|
||||
|
||||
|
||||
class DCAStrategy(models.Model):
|
||||
class DCAStrategy(SharedObject):
|
||||
name = models.CharField(max_length=255, verbose_name=_("Name"))
|
||||
target_currency = models.ForeignKey(
|
||||
"currencies.Currency",
|
||||
@@ -28,6 +27,9 @@ class DCAStrategy(models.Model):
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
updated_at = models.DateTimeField(auto_now=True)
|
||||
|
||||
objects = SharedObjectManager()
|
||||
all_objects = models.Manager() # Unfiltered manager
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("DCA Strategy")
|
||||
verbose_name_plural = _("DCA Strategies")
|
||||
|
||||
@@ -1,3 +1,344 @@
|
||||
from django.test import TestCase
|
||||
from django.test import TestCase, Client
|
||||
from django.contrib.auth.models import User
|
||||
from django.urls import reverse
|
||||
from django.forms import NON_FIELD_ERRORS
|
||||
from apps.currencies.models import Currency
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.dca.forms import DCAStrategyForm, DCAEntryForm # Added DCAEntryForm
|
||||
from apps.accounts.models import Account, AccountGroup # Added Account models
|
||||
from apps.transactions.models import TransactionCategory, Transaction # Added Transaction models
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
from unittest.mock import patch
|
||||
|
||||
# Create your tests here.
|
||||
class DCATests(TestCase):
|
||||
def setUp(self):
|
||||
self.owner = User.objects.create_user(username='testowner', password='password123')
|
||||
self.client = Client()
|
||||
self.client.login(username='testowner', password='password123')
|
||||
|
||||
self.payment_curr = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2)
|
||||
self.target_curr = Currency.objects.create(code="BTC", name="Bitcoin", decimal_places=8)
|
||||
|
||||
# AccountGroup for accounts
|
||||
self.account_group = AccountGroup.objects.create(name="DCA Test Group", owner=self.owner)
|
||||
|
||||
# Accounts for transactions
|
||||
self.account1 = Account.objects.create(
|
||||
name="Payment Account USD",
|
||||
owner=self.owner,
|
||||
currency=self.payment_curr,
|
||||
group=self.account_group
|
||||
)
|
||||
self.account2 = Account.objects.create(
|
||||
name="Target Account BTC",
|
||||
owner=self.owner,
|
||||
currency=self.target_curr,
|
||||
group=self.account_group
|
||||
)
|
||||
|
||||
# TransactionCategory for transactions
|
||||
# Using INFO type as it's generic. TRANSFER might imply specific paired transaction logic not relevant here.
|
||||
self.category1 = TransactionCategory.objects.create(
|
||||
name="DCA Category",
|
||||
owner=self.owner,
|
||||
type=TransactionCategory.TransactionType.INFO
|
||||
)
|
||||
|
||||
|
||||
self.strategy1 = DCAStrategy.objects.create(
|
||||
name="Test Strategy 1",
|
||||
owner=self.owner,
|
||||
payment_currency=self.payment_curr,
|
||||
target_currency=self.target_curr
|
||||
)
|
||||
|
||||
self.entries1 = [
|
||||
DCAEntry.objects.create(
|
||||
strategy=self.strategy1,
|
||||
date=date(2023, 1, 1),
|
||||
amount_paid=Decimal('100.00'),
|
||||
amount_received=Decimal('0.010')
|
||||
),
|
||||
DCAEntry.objects.create(
|
||||
strategy=self.strategy1,
|
||||
date=date(2023, 2, 1),
|
||||
amount_paid=Decimal('150.00'),
|
||||
amount_received=Decimal('0.012')
|
||||
),
|
||||
DCAEntry.objects.create(
|
||||
strategy=self.strategy1,
|
||||
date=date(2023, 3, 1),
|
||||
amount_paid=Decimal('120.00'),
|
||||
amount_received=Decimal('0.008')
|
||||
)
|
||||
]
|
||||
|
||||
def test_strategy_index_view_authenticated_user(self):
|
||||
# Uses self.client and self.owner from setUp
|
||||
response = self.client.get(reverse('dca:dca_strategy_index'))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_strategy_totals_and_average_price(self):
|
||||
self.assertEqual(self.strategy1.total_entries(), 3)
|
||||
self.assertEqual(self.strategy1.total_invested(), Decimal('370.00')) # 100 + 150 + 120
|
||||
self.assertEqual(self.strategy1.total_received(), Decimal('0.030')) # 0.01 + 0.012 + 0.008
|
||||
|
||||
expected_avg_price = Decimal('370.00') / Decimal('0.030')
|
||||
# Match precision of the model method if it's specific, e.g. quantize
|
||||
# For now, direct comparison. The model might return a Decimal that needs specific quantizing.
|
||||
self.assertEqual(self.strategy1.average_entry_price(), expected_avg_price)
|
||||
|
||||
def test_strategy_average_price_no_received(self):
|
||||
strategy2 = DCAStrategy.objects.create(
|
||||
name="Test Strategy 2",
|
||||
owner=self.owner,
|
||||
payment_currency=self.payment_curr,
|
||||
target_currency=self.target_curr
|
||||
)
|
||||
DCAEntry.objects.create(
|
||||
strategy=strategy2,
|
||||
date=date(2023, 4, 1),
|
||||
amount_paid=Decimal('100.00'),
|
||||
amount_received=Decimal('0') # Total received is zero
|
||||
)
|
||||
self.assertEqual(strategy2.total_received(), Decimal('0'))
|
||||
self.assertEqual(strategy2.average_entry_price(), Decimal('0'))
|
||||
|
||||
@patch('apps.dca.models.convert')
|
||||
def test_dca_entry_value_and_pl(self, mock_convert):
|
||||
entry = self.entries1[0] # amount_paid=100, amount_received=0.010
|
||||
|
||||
# Simulate current price: 1 target_curr = 20,000 payment_curr
|
||||
# So, 0.010 target_curr should be 0.010 * 20000 = 200 payment_curr
|
||||
simulated_converted_value = entry.amount_received * Decimal('20000')
|
||||
mock_convert.return_value = (
|
||||
simulated_converted_value,
|
||||
self.payment_curr.prefix,
|
||||
self.payment_curr.suffix,
|
||||
self.payment_curr.decimal_places
|
||||
)
|
||||
|
||||
current_val = entry.current_value()
|
||||
self.assertEqual(current_val, Decimal('200.00'))
|
||||
|
||||
# Profit/Loss = current_value - amount_paid = 200 - 100 = 100
|
||||
self.assertEqual(entry.profit_loss(), Decimal('100.00'))
|
||||
|
||||
# P/L % = (profit_loss / amount_paid) * 100 = (100 / 100) * 100 = 100
|
||||
self.assertEqual(entry.profit_loss_percentage(), Decimal('100.00'))
|
||||
|
||||
# Check that convert was called correctly by current_value()
|
||||
# current_value calls convert(self.amount_received, self.strategy.target_currency, self.strategy.payment_currency)
|
||||
# The date argument defaults to None if not passed, which is the case here.
|
||||
mock_convert.assert_called_once_with(
|
||||
entry.amount_received,
|
||||
self.strategy1.target_currency,
|
||||
self.strategy1.payment_currency,
|
||||
None # Date argument is optional and defaults to None
|
||||
)
|
||||
|
||||
@patch('apps.dca.models.convert')
|
||||
def test_dca_strategy_value_and_pl(self, mock_convert):
|
||||
|
||||
def side_effect_func(amount_to_convert, from_currency, to_currency, date=None):
|
||||
if from_currency == self.target_curr and to_currency == self.payment_curr:
|
||||
# Simulate current price: 1 target_curr = 20,000 payment_curr
|
||||
converted_value = amount_to_convert * Decimal('20000')
|
||||
return (converted_value, self.payment_curr.prefix, self.payment_curr.suffix, self.payment_curr.decimal_places)
|
||||
# Fallback for any other unexpected calls, though not expected in this test
|
||||
return (Decimal('0'), '', '', 2)
|
||||
|
||||
mock_convert.side_effect = side_effect_func
|
||||
|
||||
# strategy1 entries:
|
||||
# 1: paid 100, received 0.010. Current value = 0.010 * 20000 = 200
|
||||
# 2: paid 150, received 0.012. Current value = 0.012 * 20000 = 240
|
||||
# 3: paid 120, received 0.008. Current value = 0.008 * 20000 = 160
|
||||
# Total current value = 200 + 240 + 160 = 600
|
||||
self.assertEqual(self.strategy1.current_total_value(), Decimal('600.00'))
|
||||
|
||||
# Total invested = 100 + 150 + 120 = 370
|
||||
# Total profit/loss = current_total_value - total_invested = 600 - 370 = 230
|
||||
self.assertEqual(self.strategy1.total_profit_loss(), Decimal('230.00'))
|
||||
|
||||
# Total P/L % = (total_profit_loss / total_invested) * 100
|
||||
# (230 / 370) * 100 = 62.162162...
|
||||
expected_pl_percentage = (Decimal('230.00') / Decimal('370.00')) * Decimal('100')
|
||||
self.assertAlmostEqual(self.strategy1.total_profit_loss_percentage(), expected_pl_percentage, places=2)
|
||||
|
||||
def test_dca_strategy_form_valid_data(self):
|
||||
form_data = {
|
||||
'name': 'Form Test Strategy',
|
||||
'target_currency': self.target_curr.pk,
|
||||
'payment_currency': self.payment_curr.pk
|
||||
}
|
||||
form = DCAStrategyForm(data=form_data)
|
||||
self.assertTrue(form.is_valid(), form.errors.as_text())
|
||||
|
||||
strategy = form.save(commit=False)
|
||||
strategy.owner = self.owner
|
||||
strategy.save()
|
||||
|
||||
self.assertEqual(strategy.name, 'Form Test Strategy')
|
||||
self.assertEqual(strategy.owner, self.owner)
|
||||
self.assertEqual(strategy.target_currency, self.target_curr)
|
||||
self.assertEqual(strategy.payment_currency, self.payment_curr)
|
||||
|
||||
def test_dca_strategy_form_missing_name(self):
|
||||
form_data = {
|
||||
'target_currency': self.target_curr.pk,
|
||||
'payment_currency': self.payment_curr.pk
|
||||
}
|
||||
form = DCAStrategyForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('name', form.errors)
|
||||
|
||||
def test_dca_strategy_form_missing_target_currency(self):
|
||||
form_data = {
|
||||
'name': 'Form Test Missing Target',
|
||||
'payment_currency': self.payment_curr.pk
|
||||
}
|
||||
form = DCAStrategyForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('target_currency', form.errors)
|
||||
|
||||
# Tests for DCAEntryForm clean method
|
||||
def test_dca_entry_form_clean_create_transaction_missing_accounts(self):
|
||||
data = {
|
||||
'date': date(2023, 1, 1),
|
||||
'amount_paid': Decimal('100.00'),
|
||||
'amount_received': Decimal('0.01'),
|
||||
'create_transaction': True,
|
||||
# from_account and to_account are missing
|
||||
}
|
||||
form = DCAEntryForm(data=data, strategy=self.strategy1, owner=self.owner)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('from_account', form.errors)
|
||||
self.assertIn('to_account', form.errors)
|
||||
|
||||
def test_dca_entry_form_clean_create_transaction_same_accounts(self):
|
||||
data = {
|
||||
'date': date(2023, 1, 1),
|
||||
'amount_paid': Decimal('100.00'),
|
||||
'amount_received': Decimal('0.01'),
|
||||
'create_transaction': True,
|
||||
'from_account': self.account1.pk,
|
||||
'to_account': self.account1.pk, # Same as from_account
|
||||
'from_category': self.category1.pk,
|
||||
'to_category': self.category1.pk,
|
||||
}
|
||||
form = DCAEntryForm(data=data, strategy=self.strategy1, owner=self.owner)
|
||||
self.assertFalse(form.is_valid())
|
||||
# Check for non-field error or specific field error based on form implementation
|
||||
self.assertTrue(NON_FIELD_ERRORS in form.errors or 'to_account' in form.errors)
|
||||
if NON_FIELD_ERRORS in form.errors:
|
||||
self.assertTrue(any("From and To accounts must be different" in error for error in form.errors[NON_FIELD_ERRORS]))
|
||||
|
||||
|
||||
# Tests for DCAEntryForm save method
|
||||
def test_dca_entry_form_save_create_transactions(self):
|
||||
data = {
|
||||
'date': date(2023, 5, 1),
|
||||
'amount_paid': Decimal('200.00'),
|
||||
'amount_received': Decimal('0.025'),
|
||||
'create_transaction': True,
|
||||
'from_account': self.account1.pk,
|
||||
'to_account': self.account2.pk,
|
||||
'from_category': self.category1.pk,
|
||||
'to_category': self.category1.pk,
|
||||
'description': 'Test DCA entry transaction creation'
|
||||
}
|
||||
form = DCAEntryForm(data=data, strategy=self.strategy1, owner=self.owner)
|
||||
|
||||
if not form.is_valid():
|
||||
print(form.errors.as_json()) # Print errors if form is invalid
|
||||
self.assertTrue(form.is_valid())
|
||||
|
||||
entry = form.save()
|
||||
|
||||
self.assertIsNotNone(entry.pk)
|
||||
self.assertEqual(entry.strategy, self.strategy1)
|
||||
self.assertIsNotNone(entry.expense_transaction)
|
||||
self.assertIsNotNone(entry.income_transaction)
|
||||
|
||||
# Check expense transaction
|
||||
expense_tx = entry.expense_transaction
|
||||
self.assertEqual(expense_tx.account, self.account1)
|
||||
self.assertEqual(expense_tx.type, Transaction.Type.EXPENSE)
|
||||
self.assertEqual(expense_tx.amount, data['amount_paid'])
|
||||
self.assertEqual(expense_tx.category, self.category1)
|
||||
self.assertEqual(expense_tx.owner, self.owner)
|
||||
self.assertEqual(expense_tx.date, data['date'])
|
||||
self.assertIn(str(entry.id)[:8], expense_tx.description) # Check if part of entry ID is in description
|
||||
|
||||
# Check income transaction
|
||||
income_tx = entry.income_transaction
|
||||
self.assertEqual(income_tx.account, self.account2)
|
||||
self.assertEqual(income_tx.type, Transaction.Type.INCOME)
|
||||
self.assertEqual(income_tx.amount, data['amount_received'])
|
||||
self.assertEqual(income_tx.category, self.category1)
|
||||
self.assertEqual(income_tx.owner, self.owner)
|
||||
self.assertEqual(income_tx.date, data['date'])
|
||||
self.assertIn(str(entry.id)[:8], income_tx.description)
|
||||
|
||||
|
||||
def test_dca_entry_form_save_update_linked_transactions(self):
|
||||
# 1. Create an initial DCAEntry with linked transactions
|
||||
initial_data = {
|
||||
'date': date(2023, 6, 1),
|
||||
'amount_paid': Decimal('50.00'),
|
||||
'amount_received': Decimal('0.005'),
|
||||
'create_transaction': True,
|
||||
'from_account': self.account1.pk,
|
||||
'to_account': self.account2.pk,
|
||||
'from_category': self.category1.pk,
|
||||
'to_category': self.category1.pk,
|
||||
}
|
||||
initial_form = DCAEntryForm(data=initial_data, strategy=self.strategy1, owner=self.owner)
|
||||
self.assertTrue(initial_form.is_valid(), initial_form.errors.as_json())
|
||||
initial_entry = initial_form.save()
|
||||
|
||||
self.assertIsNotNone(initial_entry.expense_transaction)
|
||||
self.assertIsNotNone(initial_entry.income_transaction)
|
||||
|
||||
# 2. Data for updating the form
|
||||
update_data = {
|
||||
'date': initial_entry.date, # Keep date same or change, as needed
|
||||
'amount_paid': Decimal('55.00'), # New value
|
||||
'amount_received': Decimal('0.006'), # New value
|
||||
# 'create_transaction': False, # Or not present, form should not create new if instance has linked tx
|
||||
'from_account': initial_entry.expense_transaction.account.pk, # Keep same accounts
|
||||
'to_account': initial_entry.income_transaction.account.pk,
|
||||
'from_category': initial_entry.expense_transaction.category.pk,
|
||||
'to_category': initial_entry.income_transaction.category.pk,
|
||||
}
|
||||
|
||||
# When create_transaction is not checked (or False), it means we are not creating *new* transactions,
|
||||
# but if the instance already has linked transactions, they *should* be updated.
|
||||
# The form's save method should handle this.
|
||||
|
||||
update_form = DCAEntryForm(data=update_data, instance=initial_entry, strategy=initial_entry.strategy, owner=self.owner)
|
||||
|
||||
if not update_form.is_valid():
|
||||
print(update_form.errors.as_json()) # Print errors if form is invalid
|
||||
self.assertTrue(update_form.is_valid())
|
||||
|
||||
updated_entry = update_form.save()
|
||||
|
||||
# Refresh from DB to ensure changes are saved and reflected
|
||||
updated_entry.refresh_from_db()
|
||||
if updated_entry.expense_transaction: # Check if it exists before trying to refresh
|
||||
updated_entry.expense_transaction.refresh_from_db()
|
||||
if updated_entry.income_transaction: # Check if it exists before trying to refresh
|
||||
updated_entry.income_transaction.refresh_from_db()
|
||||
|
||||
|
||||
self.assertEqual(updated_entry.amount_paid, Decimal('55.00'))
|
||||
self.assertEqual(updated_entry.amount_received, Decimal('0.006'))
|
||||
|
||||
self.assertIsNotNone(updated_entry.expense_transaction, "Expense transaction should still be linked.")
|
||||
self.assertEqual(updated_entry.expense_transaction.amount, Decimal('55.00'))
|
||||
|
||||
self.assertIsNotNone(updated_entry.income_transaction, "Income transaction should still be linked.")
|
||||
self.assertEqual(updated_entry.income_transaction.amount, Decimal('0.006'))
|
||||
|
||||
@@ -12,6 +12,16 @@ urlpatterns = [
|
||||
views.strategy_delete,
|
||||
name="dca_strategy_delete",
|
||||
),
|
||||
path(
|
||||
"dca/<int:strategy_id>/take-ownership/",
|
||||
views.strategy_take_ownership,
|
||||
name="dca_strategy_take_ownership",
|
||||
),
|
||||
path(
|
||||
"dca/<int:pk>/share/",
|
||||
views.strategy_share,
|
||||
name="dca_strategy_share_settings",
|
||||
),
|
||||
path(
|
||||
"dca/<int:strategy_id>/",
|
||||
views.strategy_detail_index,
|
||||
|
||||
@@ -11,6 +11,8 @@ from django.views.decorators.http import require_http_methods
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
from apps.dca.forms import DCAEntryForm, DCAStrategyForm
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.common.models import SharedObject
|
||||
from apps.common.forms import SharedObjectForm
|
||||
|
||||
|
||||
@login_required
|
||||
@@ -57,6 +59,16 @@ def strategy_add(request):
|
||||
def strategy_edit(request, strategy_id):
|
||||
dca_strategy = get_object_or_404(DCAStrategy, id=strategy_id)
|
||||
|
||||
if dca_strategy.owner and dca_strategy.owner != request.user:
|
||||
messages.error(request, _("Only the owner can edit this"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
if request.method == "POST":
|
||||
form = DCAStrategyForm(request.POST, instance=dca_strategy)
|
||||
if form.is_valid():
|
||||
@@ -85,9 +97,15 @@ def strategy_edit(request, strategy_id):
|
||||
def strategy_delete(request, strategy_id):
|
||||
dca_strategy = get_object_or_404(DCAStrategy, id=strategy_id)
|
||||
|
||||
dca_strategy.delete()
|
||||
|
||||
messages.success(request, _("DCA strategy deleted successfully"))
|
||||
if (
|
||||
dca_strategy.owner != request.user
|
||||
and request.user in dca_strategy.shared_with.all()
|
||||
):
|
||||
dca_strategy.shared_with.remove(request.user)
|
||||
messages.success(request, _("Item no longer shared with you"))
|
||||
else:
|
||||
dca_strategy.delete()
|
||||
messages.success(request, _("DCA strategy deleted successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
@@ -97,6 +115,65 @@ def strategy_delete(request, strategy_id):
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def strategy_take_ownership(request, strategy_id):
|
||||
dca_strategy = get_object_or_404(DCAStrategy, id=strategy_id)
|
||||
|
||||
if not dca_strategy.owner:
|
||||
dca_strategy.owner = request.user
|
||||
dca_strategy.visibility = SharedObject.Visibility.private
|
||||
dca_strategy.save()
|
||||
|
||||
messages.success(request, _("Ownership taken successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def strategy_share(request, pk):
|
||||
obj = get_object_or_404(DCAStrategy, id=pk)
|
||||
|
||||
if obj.owner and obj.owner != request.user:
|
||||
messages.error(request, _("Only the owner can edit this"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
|
||||
if request.method == "POST":
|
||||
form = SharedObjectForm(request.POST, instance=obj, user=request.user)
|
||||
if form.is_valid():
|
||||
form.save()
|
||||
messages.success(request, _("Configuration saved successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated, hide_offcanvas",
|
||||
},
|
||||
)
|
||||
else:
|
||||
form = SharedObjectForm(instance=obj, user=request.user)
|
||||
|
||||
return render(
|
||||
request,
|
||||
"dca/fragments/strategy/share.html",
|
||||
{"form": form, "object": obj},
|
||||
)
|
||||
|
||||
|
||||
@login_required
|
||||
def strategy_detail_index(request, strategy_id):
|
||||
strategy = get_object_or_404(DCAStrategy, id=strategy_id)
|
||||
@@ -155,11 +232,9 @@ def strategy_detail(request, strategy_id):
|
||||
def strategy_entry_add(request, strategy_id):
|
||||
strategy = get_object_or_404(DCAStrategy, id=strategy_id)
|
||||
if request.method == "POST":
|
||||
form = DCAEntryForm(request.POST)
|
||||
form = DCAEntryForm(request.POST, strategy=strategy)
|
||||
if form.is_valid():
|
||||
entry = form.save(commit=False)
|
||||
entry.strategy = strategy
|
||||
entry.save()
|
||||
entry = form.save()
|
||||
messages.success(request, _("Entry added successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
@@ -169,7 +244,7 @@ def strategy_entry_add(request, strategy_id):
|
||||
},
|
||||
)
|
||||
else:
|
||||
form = DCAEntryForm()
|
||||
form = DCAEntryForm(strategy=strategy)
|
||||
|
||||
return render(
|
||||
request,
|
||||
|
||||
0
app/apps/export_app/__init__.py
Normal file
0
app/apps/export_app/__init__.py
Normal file
3
app/apps/export_app/admin.py
Normal file
3
app/apps/export_app/admin.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from django.contrib import admin
|
||||
|
||||
# Register your models here.
|
||||
6
app/apps/export_app/apps.py
Normal file
6
app/apps/export_app/apps.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class ExportConfig(AppConfig):
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "apps.export_app"
|
||||
198
app/apps/export_app/forms.py
Normal file
198
app/apps/export_app/forms.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, HTML
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
|
||||
|
||||
class ExportForm(forms.Form):
|
||||
users = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Users"),
|
||||
initial=True,
|
||||
)
|
||||
accounts = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Accounts"),
|
||||
initial=True,
|
||||
)
|
||||
currencies = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Currencies"),
|
||||
initial=True,
|
||||
)
|
||||
transactions = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Transactions"),
|
||||
initial=True,
|
||||
)
|
||||
categories = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Categories"),
|
||||
initial=True,
|
||||
)
|
||||
tags = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Tags"),
|
||||
initial=False,
|
||||
)
|
||||
entities = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Entities"),
|
||||
initial=False,
|
||||
)
|
||||
recurring_transactions = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Recurring Transactions"),
|
||||
initial=True,
|
||||
)
|
||||
installment_plans = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Installment Plans"),
|
||||
initial=True,
|
||||
)
|
||||
exchange_rates = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Exchange Rates"),
|
||||
initial=False,
|
||||
)
|
||||
exchange_rates_services = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Automatic Exchange Rates"),
|
||||
initial=False,
|
||||
)
|
||||
rules = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Rules"),
|
||||
initial=True,
|
||||
)
|
||||
dca = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("DCA"),
|
||||
initial=False,
|
||||
)
|
||||
import_profiles = forms.BooleanField(
|
||||
required=False,
|
||||
widget=forms.CheckboxInput(),
|
||||
label=_("Import Profiles"),
|
||||
initial=True,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.form_method = "post"
|
||||
self.helper.layout = Layout(
|
||||
"users",
|
||||
"accounts",
|
||||
"currencies",
|
||||
"transactions",
|
||||
"categories",
|
||||
"entities",
|
||||
"tags",
|
||||
"installment_plans",
|
||||
"recurring_transactions",
|
||||
"exchange_rates_services",
|
||||
"exchange_rates",
|
||||
"rules",
|
||||
"dca",
|
||||
"import_profiles",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Export"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class RestoreForm(forms.Form):
|
||||
zip_file = forms.FileField(
|
||||
required=False,
|
||||
help_text=_("Import a ZIP file exported from WYGIWYH"),
|
||||
label=_("ZIP File"),
|
||||
)
|
||||
users = forms.FileField(required=False, label=_("Users"))
|
||||
accounts = forms.FileField(required=False, label=_("Accounts"))
|
||||
currencies = forms.FileField(required=False, label=_("Currencies"))
|
||||
transactions_categories = forms.FileField(required=False, label=_("Categories"))
|
||||
transactions_tags = forms.FileField(required=False, label=_("Tags"))
|
||||
transactions_entities = forms.FileField(required=False, label=_("Entities"))
|
||||
transactions = forms.FileField(required=False, label=_("Transactions"))
|
||||
installment_plans = forms.FileField(required=False, label=_("Installment Plans"))
|
||||
recurring_transactions = forms.FileField(
|
||||
required=False, label=_("Recurring Transactions")
|
||||
)
|
||||
automatic_exchange_rates = forms.FileField(
|
||||
required=False, label=_("Automatic Exchange Rates")
|
||||
)
|
||||
exchange_rates = forms.FileField(required=False, label=_("Exchange Rates"))
|
||||
transaction_rules = forms.FileField(required=False, label=_("Transaction rules"))
|
||||
transaction_rules_actions = forms.FileField(
|
||||
required=False, label=_("Edit transaction action")
|
||||
)
|
||||
transaction_rules_update_or_create = forms.FileField(
|
||||
required=False, label=_("Update or create transaction actions")
|
||||
)
|
||||
dca_strategies = forms.FileField(required=False, label=_("DCA Strategies"))
|
||||
dca_entries = forms.FileField(required=False, label=_("DCA Entries"))
|
||||
import_profiles = forms.FileField(required=False, label=_("Import Profiles"))
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.form_method = "post"
|
||||
self.helper.layout = Layout(
|
||||
"zip_file",
|
||||
HTML("<hr />"),
|
||||
"users",
|
||||
"accounts",
|
||||
"currencies",
|
||||
"transactions",
|
||||
"transactions_categories",
|
||||
"transactions_entities",
|
||||
"transactions_tags",
|
||||
"installment_plans",
|
||||
"recurring_transactions",
|
||||
"automatic_exchange_rates",
|
||||
"exchange_rates",
|
||||
"transaction_rules",
|
||||
"transaction_rules_actions",
|
||||
"transaction_rules_update_or_create",
|
||||
"dca_strategies",
|
||||
"dca_entries",
|
||||
"import_profiles",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Restore"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super().clean()
|
||||
if not cleaned_data.get("zip_file") and not any(
|
||||
cleaned_data.get(field) for field in self.fields if field != "zip_file"
|
||||
):
|
||||
raise forms.ValidationError(
|
||||
_("Please upload either a ZIP file or at least one CSV file")
|
||||
)
|
||||
return cleaned_data
|
||||
0
app/apps/export_app/migrations/__init__.py
Normal file
0
app/apps/export_app/migrations/__init__.py
Normal file
3
app/apps/export_app/models.py
Normal file
3
app/apps/export_app/models.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from django.db import models
|
||||
|
||||
# Create your models here.
|
||||
0
app/apps/export_app/resources/__init__.py
Normal file
0
app/apps/export_app/resources/__init__.py
Normal file
29
app/apps/export_app/resources/accounts.py
Normal file
29
app/apps/export_app/resources/accounts.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from import_export import fields, resources, widgets
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.export_app.widgets.foreign_key import AutoCreateForeignKeyWidget
|
||||
from apps.currencies.models import Currency
|
||||
|
||||
|
||||
class AccountResource(resources.ModelResource):
|
||||
group = fields.Field(
|
||||
attribute="group",
|
||||
column_name="group",
|
||||
widget=AutoCreateForeignKeyWidget(AccountGroup, "name"),
|
||||
)
|
||||
currency = fields.Field(
|
||||
attribute="currency",
|
||||
column_name="currency",
|
||||
widget=widgets.ForeignKeyWidget(Currency, "name"),
|
||||
)
|
||||
exchange_currency = fields.Field(
|
||||
attribute="exchange_currency",
|
||||
column_name="exchange_currency",
|
||||
widget=widgets.ForeignKeyWidget(Currency, "name"),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = Account
|
||||
|
||||
def get_queryset(self):
|
||||
return Account.all_objects.all()
|
||||
52
app/apps/export_app/resources/currencies.py
Normal file
52
app/apps/export_app/resources/currencies.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from import_export import fields, resources, widgets
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.currencies.models import Currency, ExchangeRate, ExchangeRateService
|
||||
from apps.export_app.widgets.foreign_key import SkipMissingForeignKeyWidget
|
||||
from apps.export_app.widgets.numbers import UniversalDecimalWidget
|
||||
|
||||
|
||||
class CurrencyResource(resources.ModelResource):
|
||||
exchange_currency = fields.Field(
|
||||
attribute="exchange_currency",
|
||||
column_name="exchange_currency",
|
||||
widget=SkipMissingForeignKeyWidget(Currency, "name"),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = Currency
|
||||
|
||||
|
||||
class ExchangeRateResource(resources.ModelResource):
|
||||
from_currency = fields.Field(
|
||||
attribute="from_currency",
|
||||
column_name="from_currency",
|
||||
widget=widgets.ForeignKeyWidget(Currency, "name"),
|
||||
)
|
||||
to_currency = fields.Field(
|
||||
attribute="to_currency",
|
||||
column_name="to_currency",
|
||||
widget=widgets.ForeignKeyWidget(Currency, "name"),
|
||||
)
|
||||
rate = fields.Field(
|
||||
attribute="rate", column_name="rate", widget=UniversalDecimalWidget()
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = ExchangeRate
|
||||
|
||||
|
||||
class ExchangeRateServiceResource(resources.ModelResource):
|
||||
target_currencies = fields.Field(
|
||||
attribute="target_currencies",
|
||||
column_name="target_currencies",
|
||||
widget=widgets.ManyToManyWidget(Currency, field="name"),
|
||||
)
|
||||
target_accounts = fields.Field(
|
||||
attribute="target_accounts",
|
||||
column_name="target_accounts",
|
||||
widget=widgets.ManyToManyWidget(Account, field="name"),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = ExchangeRateService
|
||||
38
app/apps/export_app/resources/dca.py
Normal file
38
app/apps/export_app/resources/dca.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from import_export import fields, resources
|
||||
from import_export.widgets import ForeignKeyWidget
|
||||
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.currencies.models import Currency
|
||||
from apps.export_app.widgets.numbers import UniversalDecimalWidget
|
||||
|
||||
|
||||
class DCAStrategyResource(resources.ModelResource):
|
||||
target_currency = fields.Field(
|
||||
attribute="target_currency",
|
||||
column_name="target_currency",
|
||||
widget=ForeignKeyWidget(Currency, "name"),
|
||||
)
|
||||
payment_currency = fields.Field(
|
||||
attribute="payment_currency",
|
||||
column_name="payment_currency",
|
||||
widget=ForeignKeyWidget(Currency, "name"),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = DCAStrategy
|
||||
|
||||
|
||||
class DCAEntryResource(resources.ModelResource):
|
||||
amount_paid = fields.Field(
|
||||
attribute="amount_paid",
|
||||
column_name="amount_paid",
|
||||
widget=UniversalDecimalWidget(),
|
||||
)
|
||||
amount_received = fields.Field(
|
||||
attribute="amount_received",
|
||||
column_name="amount_received",
|
||||
widget=UniversalDecimalWidget(),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = DCAEntry
|
||||
8
app/apps/export_app/resources/import_app.py
Normal file
8
app/apps/export_app/resources/import_app.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from import_export import resources
|
||||
|
||||
from apps.import_app.models import ImportProfile
|
||||
|
||||
|
||||
class ImportProfileResource(resources.ModelResource):
|
||||
class Meta:
|
||||
model = ImportProfile
|
||||
25
app/apps/export_app/resources/rules.py
Normal file
25
app/apps/export_app/resources/rules.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from import_export import fields, resources
|
||||
from import_export.widgets import ForeignKeyWidget
|
||||
|
||||
from apps.export_app.widgets.foreign_key import AutoCreateForeignKeyWidget
|
||||
from apps.export_app.widgets.many_to_many import AutoCreateManyToManyWidget
|
||||
from apps.rules.models import (
|
||||
TransactionRule,
|
||||
TransactionRuleAction,
|
||||
UpdateOrCreateTransactionRuleAction,
|
||||
)
|
||||
|
||||
|
||||
class TransactionRuleResource(resources.ModelResource):
|
||||
class Meta:
|
||||
model = TransactionRule
|
||||
|
||||
|
||||
class TransactionRuleActionResource(resources.ModelResource):
|
||||
class Meta:
|
||||
model = TransactionRuleAction
|
||||
|
||||
|
||||
class UpdateOrCreateTransactionRuleResource(resources.ModelResource):
|
||||
class Meta:
|
||||
model = UpdateOrCreateTransactionRuleAction
|
||||
158
app/apps/export_app/resources/transactions.py
Normal file
158
app/apps/export_app/resources/transactions.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from import_export import fields, resources
|
||||
from import_export.widgets import ForeignKeyWidget
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.export_app.widgets.foreign_key import AutoCreateForeignKeyWidget
|
||||
from apps.export_app.widgets.many_to_many import AutoCreateManyToManyWidget
|
||||
from apps.export_app.widgets.string import EmptyStringToNoneField
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
TransactionEntity,
|
||||
RecurringTransaction,
|
||||
InstallmentPlan,
|
||||
)
|
||||
from apps.export_app.widgets.numbers import UniversalDecimalWidget
|
||||
|
||||
|
||||
class TransactionResource(resources.ModelResource):
|
||||
account = fields.Field(
|
||||
attribute="account",
|
||||
column_name="account",
|
||||
widget=ForeignKeyWidget(Account, "name"),
|
||||
)
|
||||
|
||||
category = fields.Field(
|
||||
attribute="category",
|
||||
column_name="category",
|
||||
widget=AutoCreateForeignKeyWidget(TransactionCategory, "name"),
|
||||
)
|
||||
|
||||
tags = fields.Field(
|
||||
attribute="tags",
|
||||
column_name="tags",
|
||||
widget=AutoCreateManyToManyWidget(TransactionTag, field="name"),
|
||||
)
|
||||
|
||||
entities = fields.Field(
|
||||
attribute="entities",
|
||||
column_name="entities",
|
||||
widget=AutoCreateManyToManyWidget(TransactionEntity, field="name"),
|
||||
)
|
||||
|
||||
internal_id = EmptyStringToNoneField(
|
||||
column_name="internal_id", attribute="internal_id"
|
||||
)
|
||||
|
||||
amount = fields.Field(
|
||||
attribute="amount",
|
||||
column_name="amount",
|
||||
widget=UniversalDecimalWidget(),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = Transaction
|
||||
|
||||
def get_queryset(self):
|
||||
return Transaction.userless_all_objects.all()
|
||||
|
||||
|
||||
class TransactionTagResource(resources.ModelResource):
|
||||
class Meta:
|
||||
model = TransactionTag
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionTag.all_objects.all()
|
||||
|
||||
|
||||
class TransactionEntityResource(resources.ModelResource):
|
||||
class Meta:
|
||||
model = TransactionEntity
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionEntity.all_objects.all()
|
||||
|
||||
|
||||
class TransactionCategoyResource(resources.ModelResource):
|
||||
class Meta:
|
||||
model = TransactionCategory
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionCategory.all_objects.all()
|
||||
|
||||
|
||||
class RecurringTransactionResource(resources.ModelResource):
|
||||
account = fields.Field(
|
||||
attribute="account",
|
||||
column_name="account",
|
||||
widget=ForeignKeyWidget(Account, "name"),
|
||||
)
|
||||
|
||||
category = fields.Field(
|
||||
attribute="category",
|
||||
column_name="category",
|
||||
widget=AutoCreateForeignKeyWidget(TransactionCategory, "name"),
|
||||
)
|
||||
|
||||
tags = fields.Field(
|
||||
attribute="tags",
|
||||
column_name="tags",
|
||||
widget=AutoCreateManyToManyWidget(TransactionTag, field="name"),
|
||||
)
|
||||
|
||||
entities = fields.Field(
|
||||
attribute="entities",
|
||||
column_name="entities",
|
||||
widget=AutoCreateManyToManyWidget(TransactionEntity, field="name"),
|
||||
)
|
||||
|
||||
amount = fields.Field(
|
||||
attribute="amount",
|
||||
column_name="amount",
|
||||
widget=UniversalDecimalWidget(),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = RecurringTransaction
|
||||
|
||||
def get_queryset(self):
|
||||
return RecurringTransaction.all_objects.all()
|
||||
|
||||
|
||||
class InstallmentPlanResource(resources.ModelResource):
|
||||
account = fields.Field(
|
||||
attribute="account",
|
||||
column_name="account",
|
||||
widget=ForeignKeyWidget(Account, "name"),
|
||||
)
|
||||
|
||||
category = fields.Field(
|
||||
attribute="category",
|
||||
column_name="category",
|
||||
widget=AutoCreateForeignKeyWidget(TransactionCategory, "name"),
|
||||
)
|
||||
|
||||
tags = fields.Field(
|
||||
attribute="tags",
|
||||
column_name="tags",
|
||||
widget=AutoCreateManyToManyWidget(TransactionTag, field="name"),
|
||||
)
|
||||
|
||||
entities = fields.Field(
|
||||
attribute="entities",
|
||||
column_name="entities",
|
||||
widget=AutoCreateManyToManyWidget(TransactionEntity, field="name"),
|
||||
)
|
||||
|
||||
installment_amount = fields.Field(
|
||||
attribute="installment_amount",
|
||||
column_name="installment_amount",
|
||||
widget=UniversalDecimalWidget(),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = InstallmentPlan
|
||||
|
||||
def get_queryset(self):
|
||||
return InstallmentPlan.all_objects.all()
|
||||
161
app/apps/export_app/resources/users.py
Normal file
161
app/apps/export_app/resources/users.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from import_export import resources, fields
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
from apps.users.models import UserSettings
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class UserResource(resources.ModelResource):
|
||||
# User fields
|
||||
email = fields.Field(attribute="email", column_name="Email")
|
||||
|
||||
# UserSettings fields - for export only
|
||||
hide_amounts = fields.Field(
|
||||
attribute="settings__hide_amounts", column_name="Hide Amounts", readonly=True
|
||||
)
|
||||
mute_sounds = fields.Field(
|
||||
attribute="settings__mute_sounds", column_name="Mute Sounds", readonly=True
|
||||
)
|
||||
date_format = fields.Field(
|
||||
attribute="settings__date_format", column_name="Date Format", readonly=True
|
||||
)
|
||||
datetime_format = fields.Field(
|
||||
attribute="settings__datetime_format",
|
||||
column_name="Datetime Format",
|
||||
readonly=True,
|
||||
)
|
||||
number_format = fields.Field(
|
||||
attribute="settings__number_format", column_name="Number Format", readonly=True
|
||||
)
|
||||
language = fields.Field(
|
||||
attribute="settings__language", column_name="Language", readonly=True
|
||||
)
|
||||
timezone = fields.Field(
|
||||
attribute="settings__timezone", column_name="Timezone", readonly=True
|
||||
)
|
||||
start_page = fields.Field(
|
||||
attribute="settings__start_page", column_name="Start Page", readonly=True
|
||||
)
|
||||
|
||||
# Human-readable fields for choice values
|
||||
start_page_display = fields.Field(column_name="Start Page Display", readonly=True)
|
||||
language_display = fields.Field(column_name="Language Display", readonly=True)
|
||||
timezone_display = fields.Field(column_name="Timezone Display", readonly=True)
|
||||
|
||||
@staticmethod
|
||||
def dehydrate_start_page_display(user):
|
||||
if hasattr(user, "settings"):
|
||||
return dict(UserSettings.StartPage.choices).get(
|
||||
user.settings.start_page, ""
|
||||
)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def dehydrate_language_display(user):
|
||||
if hasattr(user, "settings"):
|
||||
languages = dict([("auto", "Auto")] + list(settings.LANGUAGES))
|
||||
return languages.get(user.settings.language, user.settings.language)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def dehydrate_timezone_display(user):
|
||||
if hasattr(user, "settings"):
|
||||
if user.settings.timezone == "auto":
|
||||
return "Auto"
|
||||
return user.settings.timezone
|
||||
return ""
|
||||
|
||||
def after_init_instance(self, instance, new, row, **kwargs):
|
||||
"""
|
||||
Store settings data on the instance to be used after save
|
||||
"""
|
||||
# Process boolean fields properly
|
||||
hide_amounts = row.get("Hide Amounts", "").lower() == "true"
|
||||
mute_sounds = row.get("Mute Sounds", "").lower() == "true"
|
||||
|
||||
# Store settings data on the instance for later use
|
||||
instance._settings_data = {
|
||||
"hide_amounts": hide_amounts,
|
||||
"mute_sounds": mute_sounds,
|
||||
"date_format": row.get("Date Format", "SHORT_DATE_FORMAT"),
|
||||
"datetime_format": row.get("Datetime Format", "SHORT_DATETIME_FORMAT"),
|
||||
"number_format": row.get("Number Format", "AA"),
|
||||
"language": row.get("Language", "auto"),
|
||||
"timezone": row.get("Timezone", "auto"),
|
||||
"start_page": row.get("Start Page", UserSettings.StartPage.MONTHLY),
|
||||
}
|
||||
|
||||
return instance
|
||||
|
||||
def after_save_instance(self, instance, row, **kwargs):
|
||||
"""
|
||||
Create or update UserSettings after User is saved
|
||||
"""
|
||||
if not hasattr(instance, "_settings_data"):
|
||||
return
|
||||
|
||||
settings_data = instance._settings_data
|
||||
|
||||
# Create or update UserSettings
|
||||
try:
|
||||
user_settings = UserSettings.objects.get(user=instance)
|
||||
# Update existing settings
|
||||
for key, value in settings_data.items():
|
||||
setattr(user_settings, key, value)
|
||||
user_settings.save()
|
||||
except UserSettings.DoesNotExist:
|
||||
# Create new settings
|
||||
UserSettings.objects.create(user=instance, **settings_data)
|
||||
|
||||
def get_queryset(self):
|
||||
"""
|
||||
Ensure settings are prefetched when exporting users
|
||||
"""
|
||||
return super().get_queryset().select_related("settings")
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
import_id_fields = ["id"]
|
||||
fields = (
|
||||
"id",
|
||||
"email",
|
||||
"first_name",
|
||||
"last_name",
|
||||
"is_active",
|
||||
"date_joined",
|
||||
"password",
|
||||
"hide_amounts",
|
||||
"mute_sounds",
|
||||
"date_format",
|
||||
"datetime_format",
|
||||
"number_format",
|
||||
"language",
|
||||
"language_display",
|
||||
"timezone",
|
||||
"timezone_display",
|
||||
"start_page",
|
||||
"start_page_display",
|
||||
)
|
||||
export_order = (
|
||||
"id",
|
||||
"email",
|
||||
"first_name",
|
||||
"last_name",
|
||||
"is_active",
|
||||
"date_joined",
|
||||
"password",
|
||||
"hide_amounts",
|
||||
"mute_sounds",
|
||||
"date_format",
|
||||
"datetime_format",
|
||||
"number_format",
|
||||
"language",
|
||||
"language_display",
|
||||
"timezone",
|
||||
"timezone_display",
|
||||
"start_page",
|
||||
"start_page_display",
|
||||
)
|
||||
164
app/apps/export_app/tests.py
Normal file
164
app/apps/export_app/tests.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from django.test import TestCase, Client
|
||||
from django.contrib.auth.models import User
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from unittest.mock import patch, MagicMock
|
||||
from io import BytesIO
|
||||
import zipfile # Added for zip file creation
|
||||
from django.core.files.uploadedfile import InMemoryUploadedFile # Added for file upload testing
|
||||
|
||||
# Dataset from tablib is not directly imported, its behavior will be mocked.
|
||||
# Resource classes are also mocked by path string.
|
||||
|
||||
from apps.export_app.forms import ExportForm, RestoreForm # Added RestoreForm
|
||||
|
||||
|
||||
class ExportAppTests(TestCase):
|
||||
def setUp(self):
|
||||
self.superuser = User.objects.create_superuser(
|
||||
username='super',
|
||||
email='super@example.com',
|
||||
password='password'
|
||||
)
|
||||
self.client = Client()
|
||||
self.client.login(username='super', password='password')
|
||||
|
||||
@patch('apps.export_app.views.UserResource')
|
||||
def test_export_form_single_selection_csv_response(self, mock_UserResource):
|
||||
# Configure the mock UserResource
|
||||
mock_user_resource_instance = mock_UserResource.return_value
|
||||
|
||||
# Mock the export() method's return value (which is a Dataset object)
|
||||
# Then, mock the 'csv' attribute of this Dataset object
|
||||
mock_dataset = MagicMock() # Using MagicMock for the dataset
|
||||
mock_dataset.csv = "user_id,username\n1,testuser"
|
||||
mock_user_resource_instance.export.return_value = mock_dataset
|
||||
|
||||
post_data = {'users': True} # Other fields default to False or their initial values
|
||||
|
||||
response = self.client.post(reverse('export_app:export_form'), data=post_data)
|
||||
|
||||
mock_user_resource_instance.export.assert_called_once()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response['Content-Type'], 'text/csv')
|
||||
self.assertIn("attachment; filename=", response['Content-Disposition'])
|
||||
self.assertIn(".csv", response['Content-Disposition'])
|
||||
# Check if the filename contains 'users'
|
||||
self.assertIn("users_export_", response['Content-Disposition'].lower())
|
||||
self.assertEqual(response.content.decode(), "user_id,username\n1,testuser")
|
||||
|
||||
@patch('apps.export_app.views.AccountResource') # Mock AccountResource first
|
||||
@patch('apps.export_app.views.UserResource') # Then UserResource
|
||||
def test_export_form_multiple_selections_zip_response(self, mock_UserResource, mock_AccountResource):
|
||||
# Configure UserResource mock
|
||||
mock_user_instance = mock_UserResource.return_value
|
||||
mock_user_dataset = MagicMock()
|
||||
mock_user_dataset.csv = "user_data_here"
|
||||
mock_user_instance.export.return_value = mock_user_dataset
|
||||
|
||||
# Configure AccountResource mock
|
||||
mock_account_instance = mock_AccountResource.return_value
|
||||
mock_account_dataset = MagicMock()
|
||||
mock_account_dataset.csv = "account_data_here"
|
||||
mock_account_instance.export.return_value = mock_account_dataset
|
||||
|
||||
post_data = {
|
||||
'users': True,
|
||||
'accounts': True
|
||||
# other fields default to False or their initial values
|
||||
}
|
||||
|
||||
response = self.client.post(reverse('export_app:export_form'), data=post_data)
|
||||
|
||||
mock_user_instance.export.assert_called_once()
|
||||
mock_account_instance.export.assert_called_once()
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response['Content-Type'], 'application/zip')
|
||||
self.assertIn("attachment; filename=", response['Content-Disposition'])
|
||||
self.assertIn(".zip", response['Content-Disposition'])
|
||||
# Add zip file content check if possible and required later
|
||||
|
||||
def test_export_form_no_selection(self):
|
||||
# Get all field names from ExportForm and set them to False
|
||||
# This ensures that if new export options are added, this test still tries to unselect them.
|
||||
form_fields = ExportForm.base_fields.keys()
|
||||
post_data = {field: False for field in form_fields}
|
||||
|
||||
response = self.client.post(reverse('export_app:export_form'), data=post_data)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
# The expected message is "You have to select at least one export"
|
||||
# This message is translatable, so using _() for comparison if the view returns translated string.
|
||||
# The view returns HttpResponse(_("You have to select at least one export"))
|
||||
self.assertEqual(response.content.decode('utf-8'), _("You have to select at least one export"))
|
||||
|
||||
# Placeholder for zip content check, if to be implemented
|
||||
# import zipfile
|
||||
# def test_zip_contents(self):
|
||||
# # ... (setup response with zip data) ...
|
||||
# with zipfile.ZipFile(BytesIO(response.content), 'r') as zipf:
|
||||
# self.assertIn('users.csv', zipf.namelist())
|
||||
# self.assertIn('accounts.csv', zipf.namelist())
|
||||
# user_csv_content = zipf.read('users.csv').decode()
|
||||
# self.assertEqual(user_csv_content, "user_data_here")
|
||||
# account_csv_content = zipf.read('accounts.csv').decode()
|
||||
# self.assertEqual(account_csv_content, "account_data_here")
|
||||
|
||||
@patch('apps.export_app.views.process_imports')
|
||||
def test_import_form_valid_zip_calls_process_imports(self, mock_process_imports):
|
||||
# Create a mock ZIP file content
|
||||
zip_content_buffer = BytesIO()
|
||||
with zipfile.ZipFile(zip_content_buffer, 'w') as zf:
|
||||
zf.writestr('dummy.csv', 'some,data')
|
||||
zip_content_buffer.seek(0)
|
||||
|
||||
# Create an InMemoryUploadedFile instance
|
||||
mock_zip_file = InMemoryUploadedFile(
|
||||
zip_content_buffer,
|
||||
'zip_file', # field_name
|
||||
'test_export.zip', # file_name
|
||||
'application/zip', # content_type
|
||||
zip_content_buffer.getbuffer().nbytes, # size
|
||||
None # charset
|
||||
)
|
||||
|
||||
post_data = {'zip_file': mock_zip_file}
|
||||
url = reverse('export_app:restore_form')
|
||||
|
||||
response = self.client.post(url, data=post_data, format='multipart')
|
||||
|
||||
mock_process_imports.assert_called_once()
|
||||
# Check the second argument passed to process_imports (the form's cleaned_data['zip_file'])
|
||||
# The first argument (args[0]) is the request object.
|
||||
# The second argument (args[1]) is the form instance.
|
||||
# We need to check the 'zip_file' attribute of the cleaned_data of the form instance.
|
||||
# However, it's simpler to check the UploadedFile object directly if that's what process_imports receives.
|
||||
# Based on the task: "The second argument to process_imports is form.cleaned_data['zip_file']"
|
||||
# This means that process_imports is called as process_imports(request, form.cleaned_data['zip_file'], ...)
|
||||
# Let's assume process_imports signature is process_imports(request, file_obj, ...)
|
||||
# So, call_args[0][1] would be the file_obj.
|
||||
|
||||
# Actually, the view calls process_imports(request, form)
|
||||
# So, we check form.cleaned_data['zip_file'] on the passed form instance
|
||||
called_form_instance = mock_process_imports.call_args[0][1] # The form instance
|
||||
self.assertEqual(called_form_instance.cleaned_data['zip_file'], mock_zip_file)
|
||||
|
||||
self.assertEqual(response.status_code, 204)
|
||||
# The HX-Trigger header might have multiple values, ensure both are present
|
||||
self.assertIn("hide_offcanvas", response.headers['HX-Trigger'])
|
||||
self.assertIn("updated", response.headers['HX-Trigger'])
|
||||
|
||||
|
||||
def test_import_form_no_file_selected(self):
|
||||
post_data = {} # No file selected
|
||||
url = reverse('export_app:restore_form')
|
||||
|
||||
response = self.client.post(url, data=post_data)
|
||||
|
||||
self.assertEqual(response.status_code, 200) # Form re-rendered with errors
|
||||
# Check that the specific error message from RestoreForm.clean() is present
|
||||
expected_error_message = _("Please upload either a ZIP file or at least one CSV file")
|
||||
self.assertContains(response, expected_error_message)
|
||||
# Also check for the HX-Trigger which is always set
|
||||
self.assertIn("updated", response.headers['HX-Trigger'])
|
||||
8
app/apps/export_app/urls.py
Normal file
8
app/apps/export_app/urls.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from django.urls import path
|
||||
import apps.export_app.views as views
|
||||
|
||||
urlpatterns = [
|
||||
path("export/", views.export_index, name="export_index"),
|
||||
path("export/form/", views.export_form, name="export_form"),
|
||||
path("export/restore/", views.import_form, name="restore_form"),
|
||||
]
|
||||
296
app/apps/export_app/views.py
Normal file
296
app/apps/export_app/views.py
Normal file
@@ -0,0 +1,296 @@
|
||||
import logging
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
|
||||
from django.contrib import messages
|
||||
from django.contrib.auth.decorators import login_required, user_passes_test
|
||||
from django.db import transaction
|
||||
from django.http import HttpResponse
|
||||
from django.shortcuts import render
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.views.decorators.http import require_http_methods
|
||||
from tablib import Dataset
|
||||
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
from apps.export_app.forms import ExportForm, RestoreForm
|
||||
from apps.export_app.resources.accounts import AccountResource
|
||||
from apps.export_app.resources.currencies import (
|
||||
CurrencyResource,
|
||||
ExchangeRateResource,
|
||||
ExchangeRateServiceResource,
|
||||
)
|
||||
from apps.export_app.resources.dca import (
|
||||
DCAStrategyResource,
|
||||
DCAEntryResource,
|
||||
)
|
||||
from apps.export_app.resources.import_app import (
|
||||
ImportProfileResource,
|
||||
)
|
||||
from apps.export_app.resources.rules import (
|
||||
TransactionRuleResource,
|
||||
TransactionRuleActionResource,
|
||||
UpdateOrCreateTransactionRuleResource,
|
||||
)
|
||||
from apps.export_app.resources.transactions import (
|
||||
TransactionResource,
|
||||
TransactionTagResource,
|
||||
TransactionEntityResource,
|
||||
TransactionCategoyResource,
|
||||
InstallmentPlanResource,
|
||||
RecurringTransactionResource,
|
||||
)
|
||||
from apps.export_app.resources.users import UserResource
|
||||
from apps.common.decorators.demo import disabled_on_demo
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@user_passes_test(lambda u: u.is_superuser)
|
||||
@require_http_methods(["GET"])
|
||||
def export_index(request):
|
||||
return render(request, "export_app/pages/index.html")
|
||||
|
||||
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@user_passes_test(lambda u: u.is_superuser)
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def export_form(request):
|
||||
timestamp = timezone.localtime(timezone.now()).strftime("%Y-%m-%dT%H-%M-%S")
|
||||
|
||||
if request.method == "POST":
|
||||
form = ExportForm(request.POST)
|
||||
if form.is_valid():
|
||||
zip_buffer = BytesIO()
|
||||
|
||||
export_users = form.cleaned_data.get("users", False)
|
||||
export_accounts = form.cleaned_data.get("accounts", False)
|
||||
export_currencies = form.cleaned_data.get("currencies", False)
|
||||
export_transactions = form.cleaned_data.get("transactions", False)
|
||||
export_categories = form.cleaned_data.get("categories", False)
|
||||
export_tags = form.cleaned_data.get("tags", False)
|
||||
export_entities = form.cleaned_data.get("entities", False)
|
||||
export_installment_plans = form.cleaned_data.get("installment_plans", False)
|
||||
export_recurring_transactions = form.cleaned_data.get(
|
||||
"recurring_transactions", False
|
||||
)
|
||||
|
||||
export_exchange_rates_services = form.cleaned_data.get(
|
||||
"exchange_rates_services", False
|
||||
)
|
||||
export_exchange_rates = form.cleaned_data.get("exchange_rates", False)
|
||||
export_rules = form.cleaned_data.get("rules", False)
|
||||
export_dca = form.cleaned_data.get("dca", False)
|
||||
export_import_profiles = form.cleaned_data.get("import_profiles", False)
|
||||
|
||||
exports = []
|
||||
if export_users:
|
||||
exports.append((UserResource().export(), "users"))
|
||||
if export_accounts:
|
||||
exports.append((AccountResource().export(), "accounts"))
|
||||
if export_currencies:
|
||||
exports.append((CurrencyResource().export(), "currencies"))
|
||||
if export_transactions:
|
||||
exports.append((TransactionResource().export(), "transactions"))
|
||||
if export_categories:
|
||||
exports.append(
|
||||
(TransactionCategoyResource().export(), "transactions_categories")
|
||||
)
|
||||
if export_tags:
|
||||
exports.append((TransactionTagResource().export(), "transactions_tags"))
|
||||
if export_entities:
|
||||
exports.append(
|
||||
(TransactionEntityResource().export(), "transactions_entities")
|
||||
)
|
||||
if export_installment_plans:
|
||||
exports.append(
|
||||
(InstallmentPlanResource().export(), "installment_plans")
|
||||
)
|
||||
if export_recurring_transactions:
|
||||
exports.append(
|
||||
(RecurringTransactionResource().export(), "recurring_transactions")
|
||||
)
|
||||
if export_exchange_rates_services:
|
||||
exports.append(
|
||||
(ExchangeRateServiceResource().export(), "automatic_exchange_rates")
|
||||
)
|
||||
if export_exchange_rates:
|
||||
exports.append((ExchangeRateResource().export(), "exchange_rates"))
|
||||
if export_rules:
|
||||
exports.append(
|
||||
(TransactionRuleResource().export(), "transaction_rules")
|
||||
)
|
||||
exports.append(
|
||||
(
|
||||
TransactionRuleActionResource().export(),
|
||||
"transaction_rules_actions",
|
||||
)
|
||||
)
|
||||
exports.append(
|
||||
(
|
||||
UpdateOrCreateTransactionRuleResource().export(),
|
||||
"transaction_rules_update_or_create",
|
||||
)
|
||||
)
|
||||
if export_dca:
|
||||
exports.append((DCAStrategyResource().export(), "dca_strategies"))
|
||||
exports.append(
|
||||
(
|
||||
DCAEntryResource().export(),
|
||||
"dca_entries",
|
||||
)
|
||||
)
|
||||
if export_import_profiles:
|
||||
exports.append((ImportProfileResource().export(), "import_profiles"))
|
||||
|
||||
if len(exports) >= 2:
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
|
||||
for dataset, name in exports:
|
||||
zip_file.writestr(f"{name}.csv", dataset.csv)
|
||||
|
||||
response = HttpResponse(
|
||||
zip_buffer.getvalue(),
|
||||
content_type="application/zip",
|
||||
headers={
|
||||
"HX-Trigger": "hide_offcanvas, updated",
|
||||
"Content-Disposition": f'attachment; filename="{timestamp}_WYGIWYH_export.zip"',
|
||||
},
|
||||
)
|
||||
return response
|
||||
elif len(exports) == 1:
|
||||
dataset, name = exports[0]
|
||||
|
||||
response = HttpResponse(
|
||||
dataset.csv,
|
||||
content_type="text/csv",
|
||||
headers={
|
||||
"HX-Trigger": "hide_offcanvas, updated",
|
||||
"Content-Disposition": f'attachment; filename="{timestamp}_WYGIWYH_export_{name}.csv"',
|
||||
},
|
||||
)
|
||||
return response
|
||||
else:
|
||||
return HttpResponse(
|
||||
_("You have to select at least one export"),
|
||||
)
|
||||
|
||||
else:
|
||||
form = ExportForm()
|
||||
|
||||
return render(request, "export_app/fragments/export.html", context={"form": form})
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@user_passes_test(lambda u: u.is_superuser)
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def import_form(request):
|
||||
if request.method == "POST":
|
||||
form = RestoreForm(request.POST, request.FILES)
|
||||
if form.is_valid():
|
||||
try:
|
||||
process_imports(request, form.cleaned_data)
|
||||
messages.success(request, _("Data restored successfully"))
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "hide_offcanvas, updated",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error importing", exc_info=e)
|
||||
messages.error(
|
||||
request,
|
||||
_(
|
||||
"There was an error restoring your data. Check the logs for more details."
|
||||
),
|
||||
)
|
||||
else:
|
||||
form = RestoreForm()
|
||||
|
||||
response = render(request, "export_app/fragments/restore.html", {"form": form})
|
||||
response["HX-Trigger"] = "updated"
|
||||
return response
|
||||
|
||||
|
||||
def process_imports(request, cleaned_data):
|
||||
# Define import order to handle dependencies
|
||||
import_order = [
|
||||
("users", UserResource),
|
||||
("currencies", CurrencyResource),
|
||||
(
|
||||
"currencies",
|
||||
CurrencyResource,
|
||||
), # We do a double pass because exchange_currency may not exist when currency is initially created
|
||||
("accounts", AccountResource),
|
||||
("transactions_categories", TransactionCategoyResource),
|
||||
("transactions_tags", TransactionTagResource),
|
||||
("transactions_entities", TransactionEntityResource),
|
||||
("automatic_exchange_rates", ExchangeRateServiceResource),
|
||||
("exchange_rates", ExchangeRateResource),
|
||||
("installment_plans", InstallmentPlanResource),
|
||||
("recurring_transactions", RecurringTransactionResource),
|
||||
("transactions", TransactionResource),
|
||||
("dca_strategies", DCAStrategyResource),
|
||||
("dca_entries", DCAEntryResource),
|
||||
("import_profiles", ImportProfileResource),
|
||||
("transaction_rules", TransactionRuleResource),
|
||||
("transaction_rules_actions", TransactionRuleActionResource),
|
||||
("transaction_rules_update_or_create", UpdateOrCreateTransactionRuleResource),
|
||||
]
|
||||
|
||||
def import_dataset(content, resource_class, field_name):
|
||||
try:
|
||||
# Create a new resource instance
|
||||
resource = resource_class()
|
||||
|
||||
# Create dataset from CSV content
|
||||
dataset = Dataset()
|
||||
dataset.load(content, format="csv")
|
||||
|
||||
# Perform the import
|
||||
result = resource.import_data(
|
||||
dataset,
|
||||
dry_run=False,
|
||||
raise_errors=True,
|
||||
collect_failed_rows=True,
|
||||
use_transactions=False,
|
||||
skip_unchanged=True,
|
||||
)
|
||||
|
||||
if result.has_errors():
|
||||
raise ImportError(f"Failed rows: {result.failed_dataset}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error importing {field_name}: {str(e)}")
|
||||
raise ImportError(f"Error importing {field_name}: {str(e)}")
|
||||
|
||||
with transaction.atomic():
|
||||
files = {}
|
||||
|
||||
if zip_file := cleaned_data.get("zip_file"):
|
||||
# Process ZIP file
|
||||
with zipfile.ZipFile(zip_file) as z:
|
||||
for filename in z.namelist():
|
||||
name = filename.replace(".csv", "")
|
||||
with z.open(filename) as f:
|
||||
content = f.read().decode("utf-8")
|
||||
|
||||
files[name] = content
|
||||
|
||||
for field_name, resource_class in import_order:
|
||||
if field_name in files.keys():
|
||||
content = files[field_name]
|
||||
import_dataset(content, resource_class, field_name)
|
||||
else:
|
||||
# Process individual files
|
||||
for field_name, resource_class in import_order:
|
||||
if csv_file := cleaned_data.get(field_name):
|
||||
content = csv_file.read().decode("utf-8")
|
||||
import_dataset(content, resource_class, field_name)
|
||||
0
app/apps/export_app/widgets/__init__.py
Normal file
0
app/apps/export_app/widgets/__init__.py
Normal file
22
app/apps/export_app/widgets/foreign_key.py
Normal file
22
app/apps/export_app/widgets/foreign_key.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from import_export.widgets import ForeignKeyWidget
|
||||
|
||||
|
||||
class AutoCreateForeignKeyWidget(ForeignKeyWidget):
|
||||
def clean(self, value, row=None, *args, **kwargs):
|
||||
if value:
|
||||
try:
|
||||
return super().clean(value, row, **kwargs)
|
||||
except self.model.DoesNotExist:
|
||||
return self.model.objects.create(name=value)
|
||||
return None
|
||||
|
||||
|
||||
class SkipMissingForeignKeyWidget(ForeignKeyWidget):
|
||||
def clean(self, value, row=None, *args, **kwargs):
|
||||
if not value:
|
||||
return None
|
||||
|
||||
try:
|
||||
return super().clean(value, row, *args, **kwargs)
|
||||
except self.model.DoesNotExist:
|
||||
return None
|
||||
21
app/apps/export_app/widgets/many_to_many.py
Normal file
21
app/apps/export_app/widgets/many_to_many.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from import_export.widgets import ManyToManyWidget
|
||||
|
||||
|
||||
class AutoCreateManyToManyWidget(ManyToManyWidget):
|
||||
def clean(self, value, row=None, *args, **kwargs):
|
||||
if not value:
|
||||
return []
|
||||
|
||||
values = value.split(self.separator)
|
||||
cleaned_values = []
|
||||
|
||||
for val in values:
|
||||
val = val.strip()
|
||||
if val:
|
||||
try:
|
||||
obj = self.model.objects.get(**{self.field: val})
|
||||
except self.model.DoesNotExist:
|
||||
obj = self.model.objects.create(name=val)
|
||||
cleaned_values.append(obj)
|
||||
|
||||
return cleaned_values
|
||||
18
app/apps/export_app/widgets/numbers.py
Normal file
18
app/apps/export_app/widgets/numbers.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from import_export.widgets import NumberWidget
|
||||
|
||||
|
||||
class UniversalDecimalWidget(NumberWidget):
|
||||
def clean(self, value, row=None, *args, **kwargs):
|
||||
if self.is_empty(value):
|
||||
return None
|
||||
# Replace comma with dot if present
|
||||
if isinstance(value, str):
|
||||
value = value.replace(",", ".")
|
||||
return Decimal(str(value))
|
||||
|
||||
def render(self, value, obj=None, **kwargs):
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value).replace(",", ".")
|
||||
7
app/apps/export_app/widgets/string.py
Normal file
7
app/apps/export_app/widgets/string.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from import_export import fields
|
||||
|
||||
|
||||
class EmptyStringToNoneField(fields.Field):
|
||||
def clean(self, data, **kwargs):
|
||||
value = super().clean(data)
|
||||
return None if value == "" else value
|
||||
@@ -47,6 +47,34 @@ class SplitTransformationRule(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class AddTransformationRule(BaseModel):
|
||||
type: Literal["add"]
|
||||
field: str = Field(..., description="Field to add to the source value")
|
||||
absolute_values: bool = Field(
|
||||
default=False, description="Use absolute values for addition"
|
||||
)
|
||||
thousand_separator: str = Field(
|
||||
default="", description="Thousand separator character"
|
||||
)
|
||||
decimal_separator: str = Field(
|
||||
default=".", description="Decimal separator character"
|
||||
)
|
||||
|
||||
|
||||
class SubtractTransformationRule(BaseModel):
|
||||
type: Literal["subtract"]
|
||||
field: str = Field(..., description="Field to subtract from the source value")
|
||||
absolute_values: bool = Field(
|
||||
default=False, description="Use absolute values for subtraction"
|
||||
)
|
||||
thousand_separator: str = Field(
|
||||
default="", description="Thousand separator character"
|
||||
)
|
||||
decimal_separator: str = Field(
|
||||
default=".", description="Decimal separator character"
|
||||
)
|
||||
|
||||
|
||||
class CSVImportSettings(BaseModel):
|
||||
skip_errors: bool = Field(
|
||||
default=False,
|
||||
@@ -64,6 +92,20 @@ class CSVImportSettings(BaseModel):
|
||||
]
|
||||
|
||||
|
||||
class ExcelImportSettings(BaseModel):
|
||||
skip_errors: bool = Field(
|
||||
default=False,
|
||||
description="If True, errors during import will be logged and skipped",
|
||||
)
|
||||
file_type: Literal["xls", "xlsx"]
|
||||
trigger_transaction_rules: bool = True
|
||||
importing: Literal[
|
||||
"transactions", "accounts", "currencies", "categories", "tags", "entities"
|
||||
]
|
||||
start_row: int = Field(default=1, description="Where your header is located")
|
||||
sheets: list[str] | str = "*"
|
||||
|
||||
|
||||
class ColumnMapping(BaseModel):
|
||||
source: Optional[str] | Optional[list[str]] = Field(
|
||||
default=None,
|
||||
@@ -78,6 +120,8 @@ class ColumnMapping(BaseModel):
|
||||
| HashTransformationRule
|
||||
| MergeTransformationRule
|
||||
| SplitTransformationRule
|
||||
| AddTransformationRule
|
||||
| SubtractTransformationRule
|
||||
]
|
||||
] = Field(default_factory=list)
|
||||
|
||||
@@ -86,7 +130,6 @@ class TransactionAccountMapping(ColumnMapping):
|
||||
target: Literal["account"] = Field(..., description="Transaction field to map to")
|
||||
type: Literal["id", "name"] = "name"
|
||||
coerce_to: Literal["str|int"] = Field("str|int", frozen=True)
|
||||
required: bool = Field(True, frozen=True)
|
||||
|
||||
|
||||
class TransactionTypeMapping(ColumnMapping):
|
||||
@@ -105,7 +148,6 @@ class TransactionDateMapping(ColumnMapping):
|
||||
target: Literal["date"] = Field(..., description="Transaction field to map to")
|
||||
format: List[str] | str
|
||||
coerce_to: Literal["date"] = Field("date", frozen=True)
|
||||
required: bool = Field(True, frozen=True)
|
||||
|
||||
|
||||
class TransactionReferenceDateMapping(ColumnMapping):
|
||||
@@ -119,7 +161,6 @@ class TransactionReferenceDateMapping(ColumnMapping):
|
||||
class TransactionAmountMapping(ColumnMapping):
|
||||
target: Literal["amount"] = Field(..., description="Transaction field to map to")
|
||||
coerce_to: Literal["positive_decimal"] = Field("positive_decimal", frozen=True)
|
||||
required: bool = Field(True, frozen=True)
|
||||
|
||||
|
||||
class TransactionDescriptionMapping(ColumnMapping):
|
||||
@@ -301,7 +342,7 @@ class CurrencyExchangeMapping(ColumnMapping):
|
||||
|
||||
|
||||
class ImportProfileSchema(BaseModel):
|
||||
settings: CSVImportSettings
|
||||
settings: CSVImportSettings | ExcelImportSettings
|
||||
mapping: Dict[
|
||||
str,
|
||||
TransactionAccountMapping
|
||||
|
||||
@@ -3,14 +3,16 @@ import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Dict, Any, Literal, Union
|
||||
|
||||
import cachalot.api
|
||||
import openpyxl
|
||||
import xlrd
|
||||
import yaml
|
||||
from cachalot.api import cachalot_disabled
|
||||
from django.utils import timezone
|
||||
from openpyxl.utils.exceptions import InvalidFileException
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
@@ -40,7 +42,9 @@ class ImportService:
|
||||
self.import_run: ImportRun = import_run
|
||||
self.profile: ImportProfile = import_run.profile
|
||||
self.config: version_1.ImportProfileSchema = self._load_config()
|
||||
self.settings: version_1.CSVImportSettings = self.config.settings
|
||||
self.settings: version_1.CSVImportSettings | version_1.ExcelImportSettings = (
|
||||
self.config.settings
|
||||
)
|
||||
self.deduplication: list[version_1.CompareDeduplicationRule] = (
|
||||
self.config.deduplication
|
||||
)
|
||||
@@ -75,6 +79,13 @@ class ImportService:
|
||||
self.import_run.logs += log_line
|
||||
self.import_run.save(update_fields=["logs"])
|
||||
|
||||
if level == "info":
|
||||
logger.info(log_line)
|
||||
elif level == "warning":
|
||||
logger.warning(log_line)
|
||||
elif level == "error":
|
||||
logger.error(log_line, exc_info=True)
|
||||
|
||||
def _update_totals(
|
||||
self,
|
||||
field: Literal["total", "processed", "successful", "skipped", "failed"],
|
||||
@@ -129,9 +140,12 @@ class ImportService:
|
||||
|
||||
self.import_run.save(update_fields=["status"])
|
||||
|
||||
@staticmethod
|
||||
def _transform_value(
|
||||
value: str, mapping: version_1.ColumnMapping, row: Dict[str, str] = None
|
||||
self,
|
||||
value: str,
|
||||
mapping: version_1.ColumnMapping,
|
||||
row: Dict[str, str] = None,
|
||||
mapped_data: Dict[str, Any] = None,
|
||||
) -> Any:
|
||||
transformed = value
|
||||
|
||||
@@ -142,8 +156,12 @@ class ImportService:
|
||||
for field in transform.fields:
|
||||
if field in row:
|
||||
values_to_hash.append(str(row[field]))
|
||||
|
||||
# Create hash from concatenated values
|
||||
elif (
|
||||
field.startswith("__")
|
||||
and mapped_data
|
||||
and field[2:] in mapped_data
|
||||
):
|
||||
values_to_hash.append(str(mapped_data[field[2:]]))
|
||||
if values_to_hash:
|
||||
concatenated = "|".join(values_to_hash)
|
||||
transformed = hashlib.sha256(concatenated.encode()).hexdigest()
|
||||
@@ -157,6 +175,7 @@ class ImportService:
|
||||
transformed = transformed.replace(
|
||||
transform.pattern, transform.replacement
|
||||
)
|
||||
|
||||
elif transform.type == "regex":
|
||||
if transform.exclusive:
|
||||
transformed = re.sub(
|
||||
@@ -166,16 +185,25 @@ class ImportService:
|
||||
transformed = re.sub(
|
||||
transform.pattern, transform.replacement, transformed
|
||||
)
|
||||
|
||||
elif transform.type == "date_format":
|
||||
transformed = datetime.strptime(
|
||||
transformed, transform.original_format
|
||||
).strftime(transform.new_format)
|
||||
|
||||
elif transform.type == "merge":
|
||||
values_to_merge = []
|
||||
for field in transform.fields:
|
||||
if field in row:
|
||||
values_to_merge.append(str(row[field]))
|
||||
elif (
|
||||
field.startswith("__")
|
||||
and mapped_data
|
||||
and field[2:] in mapped_data
|
||||
):
|
||||
values_to_merge.append(str(mapped_data[field[2:]]))
|
||||
transformed = transform.separator.join(values_to_merge)
|
||||
|
||||
elif transform.type == "split":
|
||||
parts = transformed.split(transform.separator)
|
||||
if transform.index is not None:
|
||||
@@ -183,6 +211,38 @@ class ImportService:
|
||||
else:
|
||||
transformed = parts
|
||||
|
||||
elif transform.type in ["add", "subtract"]:
|
||||
try:
|
||||
source_value = Decimal(transformed)
|
||||
|
||||
# First check row data, then mapped data if not found
|
||||
field_value = row.get(transform.field)
|
||||
if field_value is None and transform.field.startswith("__"):
|
||||
field_value = mapped_data.get(transform.field[2:])
|
||||
|
||||
if field_value is None:
|
||||
raise KeyError(
|
||||
f"Field '{transform.field}' not found in row or mapped data"
|
||||
)
|
||||
|
||||
field_value = self._prepare_numeric_value(
|
||||
str(field_value),
|
||||
transform.thousand_separator,
|
||||
transform.decimal_separator,
|
||||
)
|
||||
|
||||
if transform.absolute_values:
|
||||
source_value = abs(source_value)
|
||||
field_value = abs(field_value)
|
||||
|
||||
if transform.type == "add":
|
||||
transformed = str(source_value + field_value)
|
||||
else: # subtract
|
||||
transformed = str(source_value - field_value)
|
||||
except (InvalidOperation, KeyError, AttributeError) as e:
|
||||
logger.warning(
|
||||
f"Error in {transform.type} transformation: {e}. Values: {transformed}, {transform.field}"
|
||||
)
|
||||
return transformed
|
||||
|
||||
def _create_transaction(self, data: Dict[str, Any]) -> Transaction:
|
||||
@@ -208,14 +268,17 @@ class ImportService:
|
||||
category = TransactionCategory.objects.get(id=category_name)
|
||||
else: # name
|
||||
if getattr(category_mapping, "create", False):
|
||||
category, _ = TransactionCategory.objects.get_or_create(
|
||||
name=category_name
|
||||
)
|
||||
try:
|
||||
category = TransactionCategory.objects.get(
|
||||
name=category_name
|
||||
)
|
||||
except TransactionCategory.DoesNotExist:
|
||||
category = TransactionCategory(name=category_name)
|
||||
category.save()
|
||||
else:
|
||||
category = TransactionCategory.objects.filter(
|
||||
name=category_name
|
||||
).first()
|
||||
|
||||
if category:
|
||||
data["category"] = category
|
||||
self.import_run.categories.add(category)
|
||||
@@ -265,9 +328,13 @@ class ImportService:
|
||||
tag = TransactionTag.objects.filter(id=tag_name).first()
|
||||
else: # name
|
||||
if getattr(tags_mapping, "create", False):
|
||||
tag, _ = TransactionTag.objects.get_or_create(
|
||||
name=tag_name.strip()
|
||||
)
|
||||
try:
|
||||
tag = TransactionTag.objects.get(
|
||||
name=tag_name.strip()
|
||||
)
|
||||
except TransactionTag.DoesNotExist:
|
||||
tag = TransactionTag(name=tag_name.strip())
|
||||
tag.save()
|
||||
else:
|
||||
tag = TransactionTag.objects.filter(
|
||||
name=tag_name.strip()
|
||||
@@ -301,9 +368,13 @@ class ImportService:
|
||||
).first()
|
||||
else: # name
|
||||
if getattr(entities_mapping, "create", False):
|
||||
entity, _ = TransactionEntity.objects.get_or_create(
|
||||
name=entity_name.strip()
|
||||
)
|
||||
try:
|
||||
entity = TransactionEntity.objects.get(
|
||||
name=entity_name.strip()
|
||||
)
|
||||
except TransactionEntity.DoesNotExist:
|
||||
entity = TransactionEntity(name=entity_name.strip())
|
||||
entity.save()
|
||||
else:
|
||||
entity = TransactionEntity.objects.filter(
|
||||
name=entity_name.strip()
|
||||
@@ -334,7 +405,11 @@ class ImportService:
|
||||
def _create_account(self, data: Dict[str, Any]) -> Account:
|
||||
if "group" in data:
|
||||
group_name = data.pop("group")
|
||||
group, _ = AccountGroup.objects.get_or_create(name=group_name)
|
||||
try:
|
||||
group = AccountGroup.objects.get(name=group_name)
|
||||
except AccountGroup.DoesNotExist:
|
||||
group = AccountGroup(name=group_name)
|
||||
group.save()
|
||||
data["group"] = group
|
||||
|
||||
# Handle currency references
|
||||
@@ -399,7 +474,7 @@ class ImportService:
|
||||
|
||||
def _coerce_type(
|
||||
self, value: str, mapping: version_1.ColumnMapping
|
||||
) -> Union[str, int, bool, Decimal, datetime, list]:
|
||||
) -> Union[str, int, bool, Decimal, datetime, list, None]:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
@@ -434,6 +509,11 @@ class ImportService:
|
||||
version_1.TransactionReferenceDateMapping,
|
||||
),
|
||||
):
|
||||
if isinstance(value, datetime):
|
||||
return value.date()
|
||||
elif isinstance(value, date):
|
||||
return value
|
||||
|
||||
formats = (
|
||||
mapping.format
|
||||
if isinstance(mapping.format, list)
|
||||
@@ -484,28 +564,30 @@ class ImportService:
|
||||
|
||||
def _map_row(self, row: Dict[str, str]) -> Dict[str, Any]:
|
||||
mapped_data = {}
|
||||
|
||||
for field, mapping in self.mapping.items():
|
||||
value = None
|
||||
|
||||
if isinstance(mapping.source, str):
|
||||
value = row.get(mapping.source, None)
|
||||
if mapping.source in row:
|
||||
value = row[mapping.source]
|
||||
elif (
|
||||
mapping.source.startswith("__")
|
||||
and mapping.source[2:] in mapped_data
|
||||
):
|
||||
value = mapped_data[mapping.source[2:]]
|
||||
elif isinstance(mapping.source, list):
|
||||
for source in mapping.source:
|
||||
value = row.get(source, None)
|
||||
if value:
|
||||
if source in row:
|
||||
value = row[source]
|
||||
break
|
||||
elif source.startswith("__") and source[2:] in mapped_data:
|
||||
value = mapped_data[source[2:]]
|
||||
break
|
||||
else:
|
||||
# If source is None, use None as the initial value
|
||||
value = None
|
||||
|
||||
# Use default_value if value is None
|
||||
if not value:
|
||||
if value is None:
|
||||
value = mapping.default
|
||||
|
||||
# Apply transformations
|
||||
if mapping.transformations:
|
||||
value = self._transform_value(value, mapping, row)
|
||||
value = self._transform_value(value, mapping, row, mapped_data)
|
||||
|
||||
value = self._coerce_type(value, mapping)
|
||||
|
||||
@@ -513,17 +595,29 @@ class ImportService:
|
||||
raise ValueError(f"Required field {field} is missing")
|
||||
|
||||
if value is not None:
|
||||
# Remove the prefix from the target field
|
||||
target = mapping.target
|
||||
if self.settings.importing == "transactions":
|
||||
mapped_data[target] = value
|
||||
else:
|
||||
# Remove the model prefix (e.g., "account_" from "account_name")
|
||||
field_name = target.split("_", 1)[1]
|
||||
mapped_data[field_name] = value
|
||||
|
||||
return mapped_data
|
||||
|
||||
@staticmethod
|
||||
def _prepare_numeric_value(
|
||||
value: str, thousand_separator: str, decimal_separator: str
|
||||
) -> Decimal:
|
||||
# Remove thousand separators
|
||||
if thousand_separator:
|
||||
value = value.replace(thousand_separator, "")
|
||||
|
||||
# Replace decimal separator with dot
|
||||
if decimal_separator != ".":
|
||||
value = value.replace(decimal_separator, ".")
|
||||
|
||||
return Decimal(value)
|
||||
|
||||
def _process_row(self, row: Dict[str, str], row_number: int) -> None:
|
||||
try:
|
||||
mapped_data = self._map_row(row)
|
||||
@@ -589,6 +683,151 @@ class ImportService:
|
||||
for row_number, row in enumerate(reader, start=1):
|
||||
self._process_row(row, row_number)
|
||||
|
||||
def _process_excel(self, file_path):
|
||||
try:
|
||||
if self.settings.file_type == "xlsx":
|
||||
workbook = openpyxl.load_workbook(
|
||||
file_path, read_only=True, data_only=True
|
||||
)
|
||||
sheets_to_process = (
|
||||
workbook.sheetnames
|
||||
if self.settings.sheets == "*"
|
||||
else (
|
||||
self.settings.sheets
|
||||
if isinstance(self.settings.sheets, list)
|
||||
else [self.settings.sheets]
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate total rows
|
||||
total_rows = sum(
|
||||
max(0, workbook[sheet_name].max_row - self.settings.start_row)
|
||||
for sheet_name in sheets_to_process
|
||||
if sheet_name in workbook.sheetnames
|
||||
)
|
||||
self._update_totals("total", value=total_rows)
|
||||
|
||||
# Process sheets
|
||||
for sheet_name in sheets_to_process:
|
||||
if sheet_name not in workbook.sheetnames:
|
||||
self._log(
|
||||
"warning",
|
||||
f"Sheet '{sheet_name}' not found in the Excel file. Skipping.",
|
||||
)
|
||||
continue
|
||||
|
||||
sheet = workbook[sheet_name]
|
||||
self._log("info", f"Processing sheet: {sheet_name}")
|
||||
headers = [
|
||||
str(cell.value or "") for cell in sheet[self.settings.start_row]
|
||||
]
|
||||
|
||||
for row_number, row in enumerate(
|
||||
sheet.iter_rows(
|
||||
min_row=self.settings.start_row + 1, values_only=True
|
||||
),
|
||||
start=1,
|
||||
):
|
||||
try:
|
||||
row_data = {
|
||||
key: str(value) if value is not None else None
|
||||
for key, value in zip(headers, row)
|
||||
}
|
||||
self._process_row(row_data, row_number)
|
||||
except Exception as e:
|
||||
if self.settings.skip_errors:
|
||||
self._log(
|
||||
"warning",
|
||||
f"Error processing row {row_number} in sheet '{sheet_name}': {str(e)}",
|
||||
)
|
||||
self._increment_totals("failed", value=1)
|
||||
else:
|
||||
raise
|
||||
|
||||
workbook.close()
|
||||
|
||||
else: # xls
|
||||
workbook = xlrd.open_workbook(file_path)
|
||||
sheets_to_process = (
|
||||
workbook.sheet_names()
|
||||
if self.settings.sheets == "*"
|
||||
else (
|
||||
self.settings.sheets
|
||||
if isinstance(self.settings.sheets, list)
|
||||
else [self.settings.sheets]
|
||||
)
|
||||
)
|
||||
# Calculate total rows
|
||||
total_rows = sum(
|
||||
max(
|
||||
0,
|
||||
workbook.sheet_by_name(sheet_name).nrows
|
||||
- self.settings.start_row,
|
||||
)
|
||||
for sheet_name in sheets_to_process
|
||||
if sheet_name in workbook.sheet_names()
|
||||
)
|
||||
self._update_totals("total", value=total_rows)
|
||||
# Process sheets
|
||||
for sheet_name in sheets_to_process:
|
||||
if sheet_name not in workbook.sheet_names():
|
||||
self._log(
|
||||
"warning",
|
||||
f"Sheet '{sheet_name}' not found in the Excel file. Skipping.",
|
||||
)
|
||||
continue
|
||||
sheet = workbook.sheet_by_name(sheet_name)
|
||||
self._log("info", f"Processing sheet: {sheet_name}")
|
||||
headers = [
|
||||
str(sheet.cell_value(self.settings.start_row - 1, col) or "")
|
||||
for col in range(sheet.ncols)
|
||||
]
|
||||
for row_number in range(self.settings.start_row, sheet.nrows):
|
||||
try:
|
||||
row_data = {}
|
||||
for col, key in enumerate(headers):
|
||||
cell_type = sheet.cell_type(row_number, col)
|
||||
cell_value = sheet.cell_value(row_number, col)
|
||||
|
||||
if cell_type == xlrd.XL_CELL_DATE:
|
||||
# Convert Excel date to Python datetime
|
||||
try:
|
||||
python_date = datetime(
|
||||
*xlrd.xldate_as_tuple(
|
||||
cell_value, workbook.datemode
|
||||
)
|
||||
)
|
||||
row_data[key] = python_date
|
||||
except Exception:
|
||||
# If date conversion fails, use the original value
|
||||
row_data[key] = (
|
||||
str(cell_value)
|
||||
if cell_value is not None
|
||||
else None
|
||||
)
|
||||
elif cell_value is None:
|
||||
row_data[key] = None
|
||||
else:
|
||||
row_data[key] = str(cell_value)
|
||||
|
||||
self._process_row(
|
||||
row_data, row_number - self.settings.start_row + 1
|
||||
)
|
||||
except Exception as e:
|
||||
if self.settings.skip_errors:
|
||||
self._log(
|
||||
"warning",
|
||||
f"Error processing row {row_number} in sheet '{sheet_name}': {str(e)}",
|
||||
)
|
||||
self._increment_totals("failed", value=1)
|
||||
else:
|
||||
raise
|
||||
|
||||
except (InvalidFileException, xlrd.XLRDError) as e:
|
||||
raise ValueError(
|
||||
f"Invalid {self.settings.file_type.upper()} file format: {str(e)}"
|
||||
)
|
||||
|
||||
def _validate_file_path(self, file_path: str) -> str:
|
||||
"""
|
||||
Validates that the file path is within the allowed temporary directory.
|
||||
@@ -611,8 +850,10 @@ class ImportService:
|
||||
self._log("info", "Starting import process")
|
||||
|
||||
try:
|
||||
if self.settings.file_type == "csv":
|
||||
if isinstance(self.settings, version_1.CSVImportSettings):
|
||||
self._process_csv(file_path)
|
||||
elif isinstance(self.settings, version_1.ExcelImportSettings):
|
||||
self._process_excel(file_path)
|
||||
|
||||
self._update_status("FINISHED")
|
||||
self._log(
|
||||
@@ -639,4 +880,3 @@ class ImportService:
|
||||
|
||||
self.import_run.finished_at = timezone.now()
|
||||
self.import_run.save(update_fields=["finished_at"])
|
||||
cachalot.api.invalidate()
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
|
||||
import cachalot.api
|
||||
from django.contrib.auth import get_user_model
|
||||
from procrastinate.contrib.django import app
|
||||
|
||||
from apps.common.middleware.thread_local import write_current_user, delete_current_user
|
||||
from apps.import_app.models import ImportRun
|
||||
from apps.import_app.services import ImportServiceV1
|
||||
|
||||
@@ -10,12 +11,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.task(name="process_import")
|
||||
def process_import(import_run_id: int, file_path: str):
|
||||
def process_import(import_run_id: int, file_path: str, user_id: int):
|
||||
user = get_user_model().objects.get(id=user_id)
|
||||
write_current_user(user)
|
||||
|
||||
try:
|
||||
import_run = ImportRun.objects.get(id=import_run_id)
|
||||
import_service = ImportServiceV1(import_run)
|
||||
import_service.process_file(file_path)
|
||||
cachalot.api.invalidate()
|
||||
delete_current_user()
|
||||
except ImportRun.DoesNotExist:
|
||||
cachalot.api.invalidate()
|
||||
delete_current_user()
|
||||
raise ValueError(f"ImportRun with id {import_run_id} not found")
|
||||
|
||||
@@ -1,3 +1,424 @@
|
||||
from django.test import TestCase
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import IntegrityError
|
||||
import yaml
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
# Create your tests here.
|
||||
from django.test import TestCase
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import IntegrityError
|
||||
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
from apps.import_app.forms import ImportProfileForm
|
||||
from apps.import_app.services.v1 import ImportService
|
||||
from apps.import_app.schemas import version_1
|
||||
from apps.transactions.models import Transaction # For Transaction.Type
|
||||
from unittest.mock import patch
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
|
||||
class ImportProfileTests(TestCase):
|
||||
|
||||
def test_import_profile_valid_yaml_v1(self):
|
||||
valid_yaml_config = """
|
||||
settings:
|
||||
file_type: csv
|
||||
delimiter: ','
|
||||
encoding: utf-8
|
||||
skip_lines: 0
|
||||
trigger_transaction_rules: true
|
||||
importing: transactions
|
||||
mapping:
|
||||
date:
|
||||
target: date
|
||||
source: Transaction Date
|
||||
format: '%Y-%m-%d'
|
||||
amount:
|
||||
target: amount
|
||||
source: Amount
|
||||
description:
|
||||
target: description
|
||||
source: Narrative
|
||||
account:
|
||||
target: account
|
||||
source: Account Name
|
||||
type: name
|
||||
type:
|
||||
target: type
|
||||
source: Credit Debit
|
||||
detection_method: sign # Assumes positive is income, negative is expense
|
||||
is_paid:
|
||||
target: is_paid
|
||||
detection_method: always_paid
|
||||
deduplication: []
|
||||
"""
|
||||
profile = ImportProfile(
|
||||
name="Test Valid Profile V1",
|
||||
yaml_config=valid_yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
try:
|
||||
profile.full_clean()
|
||||
except ValidationError as e:
|
||||
self.fail(f"Valid YAML config raised ValidationError: {e.error_dict}")
|
||||
|
||||
# Optional: Save and retrieve
|
||||
profile.save()
|
||||
retrieved_profile = ImportProfile.objects.get(pk=profile.pk)
|
||||
self.assertIsNotNone(retrieved_profile)
|
||||
self.assertEqual(retrieved_profile.name, "Test Valid Profile V1")
|
||||
|
||||
def test_import_profile_invalid_yaml_syntax_v1(self):
|
||||
invalid_yaml = "settings: { file_type: csv, delimiter: ','" # Malformed YAML
|
||||
profile = ImportProfile(
|
||||
name="Test Invalid Syntax V1",
|
||||
yaml_config=invalid_yaml,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
profile.full_clean()
|
||||
|
||||
self.assertIn('yaml_config', cm.exception.error_dict)
|
||||
self.assertTrue(any("YAML" in error.message.lower() or "syntax" in error.message.lower() for error in cm.exception.error_dict['yaml_config']))
|
||||
|
||||
def test_import_profile_schema_validation_error_v1(self):
|
||||
schema_error_yaml = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
mapping:
|
||||
date: # Missing 'format' which is required for TransactionDateMapping
|
||||
target: date
|
||||
source: Transaction Date
|
||||
"""
|
||||
profile = ImportProfile(
|
||||
name="Test Schema Error V1",
|
||||
yaml_config=schema_error_yaml,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
profile.full_clean()
|
||||
|
||||
self.assertIn('yaml_config', cm.exception.error_dict)
|
||||
# Pydantic errors usually mention the field and "field required" or similar
|
||||
self.assertTrue(any("format" in error.message.lower() and "field required" in error.message.lower()
|
||||
for error in cm.exception.error_dict['yaml_config']),
|
||||
f"Error messages: {[e.message for e in cm.exception.error_dict['yaml_config']]}")
|
||||
|
||||
|
||||
def test_import_profile_custom_validate_mappings_error_v1(self):
|
||||
custom_validate_yaml = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions # Importing transactions
|
||||
mapping:
|
||||
account_name: # This is an AccountNameMapping, not suitable for 'transactions' importing setting
|
||||
target: account_name
|
||||
source: AccName
|
||||
"""
|
||||
profile = ImportProfile(
|
||||
name="Test Custom Validate Error V1",
|
||||
yaml_config=custom_validate_yaml,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
profile.full_clean()
|
||||
|
||||
self.assertIn('yaml_config', cm.exception.error_dict)
|
||||
# Check for the specific message raised by custom_validate_mappings
|
||||
# The message is "Mapping type AccountNameMapping not allowed for importing 'transactions'."
|
||||
self.assertTrue(any("mapping type accountnamemapping not allowed for importing 'transactions'" in error.message.lower()
|
||||
for error in cm.exception.error_dict['yaml_config']),
|
||||
f"Error messages: {[e.message for e in cm.exception.error_dict['yaml_config']]}")
|
||||
|
||||
|
||||
def test_import_profile_name_unique(self):
|
||||
valid_yaml_config = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
mapping:
|
||||
date:
|
||||
target: date
|
||||
source: Date
|
||||
format: '%Y-%m-%d'
|
||||
""" # Minimal valid YAML for this test
|
||||
|
||||
ImportProfile.objects.create(
|
||||
name="Unique Name Test",
|
||||
yaml_config=valid_yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
|
||||
profile2 = ImportProfile(
|
||||
name="Unique Name Test", # Same name
|
||||
yaml_config=valid_yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
|
||||
# full_clean should catch this because of the unique constraint on the model field.
|
||||
# Django's Model.full_clean() calls Model.validate_unique().
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
profile2.full_clean()
|
||||
|
||||
self.assertIn('name', cm.exception.error_dict)
|
||||
self.assertTrue(any("already exists" in error.message.lower() for error in cm.exception.error_dict['name']))
|
||||
|
||||
# As a fallback, or for more direct DB constraint testing, also test IntegrityError on save if full_clean didn't catch it.
|
||||
# This will only be reached if the full_clean() above somehow passes.
|
||||
# try:
|
||||
# profile2.save()
|
||||
# except IntegrityError:
|
||||
# pass # Expected if full_clean didn't catch it
|
||||
# else:
|
||||
# if 'name' not in cm.exception.error_dict: # If full_clean passed and save also passed
|
||||
# self.fail("IntegrityError not raised for duplicate name on save(), and full_clean() didn't catch it.")
|
||||
|
||||
def test_import_profile_form_valid_data(self):
|
||||
valid_yaml_config = """
|
||||
settings:
|
||||
file_type: csv
|
||||
delimiter: ','
|
||||
encoding: utf-8
|
||||
skip_lines: 0
|
||||
trigger_transaction_rules: true
|
||||
importing: transactions
|
||||
mapping:
|
||||
date:
|
||||
target: date
|
||||
source: Transaction Date
|
||||
format: '%Y-%m-%d'
|
||||
amount:
|
||||
target: amount
|
||||
source: Amount
|
||||
description:
|
||||
target: description
|
||||
source: Narrative
|
||||
account:
|
||||
target: account
|
||||
source: Account Name
|
||||
type: name
|
||||
type:
|
||||
target: type
|
||||
source: Credit Debit
|
||||
detection_method: sign
|
||||
is_paid:
|
||||
target: is_paid
|
||||
detection_method: always_paid
|
||||
deduplication: []
|
||||
"""
|
||||
form_data = {
|
||||
'name': 'Form Test Valid',
|
||||
'yaml_config': valid_yaml_config,
|
||||
'version': ImportProfile.Versions.VERSION_1
|
||||
}
|
||||
form = ImportProfileForm(data=form_data)
|
||||
self.assertTrue(form.is_valid(), f"Form errors: {form.errors.as_json()}")
|
||||
|
||||
profile = form.save()
|
||||
self.assertIsNotNone(profile.pk)
|
||||
self.assertEqual(profile.name, 'Form Test Valid')
|
||||
# YAMLField might re-serialize the YAML, so direct string comparison might be brittle
|
||||
# if spacing/ordering changes. However, for now, let's assume it's stored as provided or close enough.
|
||||
# A more robust check would be to load both YAMLs and compare the resulting dicts.
|
||||
self.assertEqual(profile.yaml_config.strip(), valid_yaml_config.strip())
|
||||
self.assertEqual(profile.version, ImportProfile.Versions.VERSION_1)
|
||||
|
||||
def test_import_profile_form_invalid_yaml(self):
|
||||
# Using a YAML that causes a schema validation error (missing 'format' for date mapping)
|
||||
invalid_yaml_for_form = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
mapping:
|
||||
date:
|
||||
target: date
|
||||
source: Transaction Date
|
||||
"""
|
||||
form_data = {
|
||||
'name': 'Form Test Invalid',
|
||||
'yaml_config': invalid_yaml_for_form,
|
||||
'version': ImportProfile.Versions.VERSION_1
|
||||
}
|
||||
form = ImportProfileForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('yaml_config', form.errors)
|
||||
# Check for a message indicating schema validation failure
|
||||
self.assertTrue(any("field required" in error.lower() for error in form.errors['yaml_config']))
|
||||
|
||||
|
||||
class ImportServiceTests(TestCase):
|
||||
# ... (existing setUp and other test methods from previous task) ...
|
||||
def setUp(self):
|
||||
minimal_yaml_config = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
mapping:
|
||||
description:
|
||||
target: description
|
||||
source: Desc
|
||||
"""
|
||||
self.profile = ImportProfile.objects.create(
|
||||
name="Test Service Profile",
|
||||
yaml_config=minimal_yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
self.import_run = ImportRun.objects.create(
|
||||
profile=self.profile,
|
||||
status=ImportRun.Status.PENDING
|
||||
)
|
||||
# self.service is initialized in each test to allow specific mapping_config
|
||||
# or to re-initialize if service state changes (though it shouldn't for these private methods)
|
||||
|
||||
# Tests for _transform_value
|
||||
def test_transform_value_replace(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.ColumnMapping(target="description", source="Desc") # Basic mapping
|
||||
mapping_config.transformations = [
|
||||
version_1.ReplaceTransformationRule(type="replace", pattern="old", replacement="new")
|
||||
]
|
||||
transformed_value = service._transform_value("this is old text", mapping_config)
|
||||
self.assertEqual(transformed_value, "this is new text")
|
||||
|
||||
def test_transform_value_date_format(self):
|
||||
service = ImportService(self.import_run)
|
||||
# DateFormatTransformationRule is typically part of a DateMapping, but testing transform directly
|
||||
mapping_config = version_1.TransactionDateMapping(target="date", source="Date", format="%d/%m/%Y") # format is for final coercion
|
||||
mapping_config.transformations = [
|
||||
version_1.DateFormatTransformationRule(type="date_format", original_format="%Y-%m-%d", new_format="%d/%m/%Y")
|
||||
]
|
||||
transformed_value = service._transform_value("2023-01-15", mapping_config)
|
||||
self.assertEqual(transformed_value, "15/01/2023")
|
||||
|
||||
def test_transform_value_regex_replace(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.ColumnMapping(target="description", source="Desc")
|
||||
mapping_config.transformations = [
|
||||
version_1.ReplaceTransformationRule(type="regex", pattern=r"\\d+", replacement="NUM")
|
||||
]
|
||||
transformed_value = service._transform_value("abc123xyz456", mapping_config)
|
||||
self.assertEqual(transformed_value, "abcNUMxyzNUM")
|
||||
|
||||
# Tests for _coerce_type
|
||||
def test_coerce_type_string_to_decimal(self):
|
||||
service = ImportService(self.import_run)
|
||||
# TransactionAmountMapping has coerce_to="positive_decimal" by default
|
||||
mapping_config = version_1.TransactionAmountMapping(target="amount", source="Amt")
|
||||
|
||||
coerced = service._coerce_type("123.45", mapping_config)
|
||||
self.assertEqual(coerced, Decimal("123.45"))
|
||||
|
||||
coerced_neg = service._coerce_type("-123.45", mapping_config)
|
||||
self.assertEqual(coerced_neg, Decimal("123.45")) # positive_decimal behavior
|
||||
|
||||
# Test with coerce_to="decimal"
|
||||
mapping_config_decimal = version_1.TransactionAmountMapping(target="amount", source="Amt", coerce_to="decimal")
|
||||
coerced_neg_decimal = service._coerce_type("-123.45", mapping_config_decimal)
|
||||
self.assertEqual(coerced_neg_decimal, Decimal("-123.45"))
|
||||
|
||||
|
||||
def test_coerce_type_string_to_date(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.TransactionDateMapping(target="date", source="Dt", format="%Y-%m-%d")
|
||||
coerced = service._coerce_type("2023-01-15", mapping_config)
|
||||
self.assertEqual(coerced, date(2023, 1, 15))
|
||||
|
||||
def test_coerce_type_string_to_transaction_type_sign(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.TransactionTypeMapping(target="type", source="TType", detection_method="sign")
|
||||
|
||||
self.assertEqual(service._coerce_type("100.00", mapping_config), Transaction.Type.INCOME)
|
||||
self.assertEqual(service._coerce_type("-100.00", mapping_config), Transaction.Type.EXPENSE)
|
||||
self.assertEqual(service._coerce_type("0.00", mapping_config), Transaction.Type.EXPENSE) # Sign detection treats 0 as expense
|
||||
self.assertEqual(service._coerce_type("+200", mapping_config), Transaction.Type.INCOME)
|
||||
|
||||
def test_coerce_type_string_to_transaction_type_keywords(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.TransactionTypeMapping(
|
||||
target="type",
|
||||
source="TType",
|
||||
detection_method="keywords",
|
||||
income_keywords=["credit", "dep"],
|
||||
expense_keywords=["debit", "wdrl"]
|
||||
)
|
||||
self.assertEqual(service._coerce_type("Monthly Credit", mapping_config), Transaction.Type.INCOME)
|
||||
self.assertEqual(service._coerce_type("ATM WDRL", mapping_config), Transaction.Type.EXPENSE)
|
||||
self.assertIsNone(service._coerce_type("Unknown Type", mapping_config)) # No keyword match
|
||||
|
||||
@patch('apps.import_app.services.v1.os.remove')
|
||||
def test_process_file_simple_csv_transactions(self, mock_os_remove):
|
||||
simple_transactions_yaml = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
delimiter: ','
|
||||
skip_lines: 0
|
||||
mapping:
|
||||
date: {target: date, source: Date, format: '%Y-%m-%d'}
|
||||
amount: {target: amount, source: Amount}
|
||||
description: {target: description, source: Description}
|
||||
type: {target: type, source: Type, detection_method: always_income}
|
||||
account: {target: account, source: AccountName, type: name}
|
||||
"""
|
||||
self.profile.yaml_config = simple_transactions_yaml
|
||||
self.profile.save()
|
||||
self.import_run.refresh_from_db() # Ensure import_run has the latest profile reference if needed
|
||||
|
||||
csv_content = "Date,Amount,Description,Type,AccountName\n2023-01-01,100.00,Test Deposit,INCOME,TestAcc"
|
||||
|
||||
temp_file_path = None
|
||||
try:
|
||||
# Ensure TEMP_DIR exists if ImportService relies on it being pre-existing
|
||||
# For NamedTemporaryFile, dir just needs to be a valid directory path.
|
||||
# If ImportService.TEMP_DIR is a class variable pointing to a specific path,
|
||||
# it should be created or mocked if it doesn't exist by default.
|
||||
# For simplicity, let's assume it exists or tempfile handles it gracefully.
|
||||
# If ImportService.TEMP_DIR is not guaranteed, use default temp dir.
|
||||
temp_dir = getattr(ImportService, 'TEMP_DIR', None)
|
||||
if temp_dir and not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w+', delete=False, dir=temp_dir, suffix='.csv', encoding='utf-8') as tmp_file:
|
||||
tmp_file.write(csv_content)
|
||||
temp_file_path = tmp_file.name
|
||||
|
||||
self.addCleanup(lambda: os.remove(temp_file_path) if temp_file_path and os.path.exists(temp_file_path) else None)
|
||||
|
||||
service = ImportService(self.import_run)
|
||||
|
||||
with patch.object(service, '_create_transaction') as mock_create_transaction:
|
||||
service.process_file(temp_file_path)
|
||||
|
||||
self.import_run.refresh_from_db() # Refresh to get updated status and counts
|
||||
self.assertEqual(self.import_run.status, ImportRun.Status.FINISHED)
|
||||
self.assertEqual(self.import_run.total_rows, 1)
|
||||
self.assertEqual(self.import_run.successful_rows, 1)
|
||||
|
||||
mock_create_transaction.assert_called_once()
|
||||
|
||||
# The first argument to _create_transaction is the row_data dictionary
|
||||
args_dict = mock_create_transaction.call_args[0][0]
|
||||
|
||||
self.assertEqual(args_dict['date'], date(2023, 1, 1))
|
||||
self.assertEqual(args_dict['amount'], Decimal('100.00'))
|
||||
self.assertEqual(args_dict['description'], "Test Deposit")
|
||||
self.assertEqual(args_dict['type'], Transaction.Type.INCOME)
|
||||
|
||||
# Account 'TestAcc' does not exist, so _map_row should resolve 'account' to None.
|
||||
# This assumes the default behavior of AccountMapping(type='name') when an account is not found
|
||||
# and creation of new accounts from mapping is not enabled/implemented in _map_row for this test.
|
||||
self.assertIsNone(args_dict.get('account'),
|
||||
"Account should be None as 'TestAcc' is not created in this test setup.")
|
||||
|
||||
mock_os_remove.assert_called_once_with(temp_file_path)
|
||||
|
||||
finally:
|
||||
# This cleanup is now handled by self.addCleanup, but kept for safety if addCleanup fails early.
|
||||
if temp_file_path and os.path.exists(temp_file_path) and not mock_os_remove.called:
|
||||
# If mock_os_remove was not called (e.g., an error before service.process_file finished),
|
||||
# we might need to manually clean up if addCleanup didn't register or run.
|
||||
# However, addCleanup is generally robust.
|
||||
pass
|
||||
|
||||
@@ -2,7 +2,6 @@ from django.urls import path
|
||||
import apps.import_app.views as views
|
||||
|
||||
urlpatterns = [
|
||||
path("import/", views.import_view, name="import"),
|
||||
path(
|
||||
"import/presets/",
|
||||
views.import_presets_list,
|
||||
|
||||
@@ -13,22 +13,11 @@ from apps.import_app.forms import ImportRunFileUploadForm, ImportProfileForm
|
||||
from apps.import_app.models import ImportRun, ImportProfile
|
||||
from apps.import_app.services import PresetService
|
||||
from apps.import_app.tasks import process_import
|
||||
|
||||
|
||||
def import_view(request):
|
||||
import_profile = ImportProfile.objects.get(id=2)
|
||||
shutil.copyfile(
|
||||
"/usr/src/app/apps/import_app/teste2.csv", "/usr/src/app/temp/teste2.csv"
|
||||
)
|
||||
ir = ImportRun.objects.create(profile=import_profile, file_name="teste.csv")
|
||||
process_import.defer(
|
||||
import_run_id=ir.id,
|
||||
file_path="/usr/src/app/temp/teste2.csv",
|
||||
)
|
||||
return HttpResponse("Hello, world. You're at the polls page.")
|
||||
from apps.common.decorators.demo import disabled_on_demo
|
||||
|
||||
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET"])
|
||||
def import_presets_list(request):
|
||||
presets = PresetService.get_all_presets()
|
||||
@@ -40,6 +29,7 @@ def import_presets_list(request):
|
||||
|
||||
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def import_profile_index(request):
|
||||
return render(
|
||||
@@ -50,6 +40,7 @@ def import_profile_index(request):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def import_profile_list(request):
|
||||
profiles = ImportProfile.objects.all()
|
||||
@@ -63,6 +54,7 @@ def import_profile_list(request):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def import_profile_add(request):
|
||||
message = request.POST.get("message", None)
|
||||
@@ -98,6 +90,7 @@ def import_profile_add(request):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def import_profile_edit(request, profile_id):
|
||||
profile = get_object_or_404(ImportProfile, id=profile_id)
|
||||
@@ -127,6 +120,7 @@ def import_profile_edit(request, profile_id):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["DELETE"])
|
||||
def import_profile_delete(request, profile_id):
|
||||
profile = ImportProfile.objects.get(id=profile_id)
|
||||
@@ -145,6 +139,7 @@ def import_profile_delete(request, profile_id):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def import_runs_list(request, profile_id):
|
||||
profile = ImportProfile.objects.get(id=profile_id)
|
||||
@@ -160,6 +155,7 @@ def import_runs_list(request, profile_id):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def import_run_log(request, profile_id, run_id):
|
||||
run = ImportRun.objects.get(profile__id=profile_id, id=run_id)
|
||||
@@ -173,6 +169,7 @@ def import_run_log(request, profile_id, run_id):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["GET", "POST"])
|
||||
def import_run_add(request, profile_id):
|
||||
profile = ImportProfile.objects.get(id=profile_id)
|
||||
@@ -189,7 +186,11 @@ def import_run_add(request, profile_id):
|
||||
import_run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||
|
||||
# Defer the procrastinate task
|
||||
process_import.defer(import_run_id=import_run.id, file_path=file_path)
|
||||
process_import.defer(
|
||||
import_run_id=import_run.id,
|
||||
file_path=file_path,
|
||||
user_id=request.user.id,
|
||||
)
|
||||
|
||||
messages.success(request, _("Import Run queued successfully"))
|
||||
|
||||
@@ -211,6 +212,7 @@ def import_run_add(request, profile_id):
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@disabled_on_demo
|
||||
@require_http_methods(["DELETE"])
|
||||
def import_run_delete(request, profile_id, run_id):
|
||||
run = ImportRun.objects.get(profile__id=profile_id, id=run_id)
|
||||
|
||||
0
app/apps/insights/__init__.py
Normal file
0
app/apps/insights/__init__.py
Normal file
3
app/apps/insights/admin.py
Normal file
3
app/apps/insights/admin.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from django.contrib import admin
|
||||
|
||||
# Register your models here.
|
||||
6
app/apps/insights/apps.py
Normal file
6
app/apps/insights/apps.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class InsightsConfig(AppConfig):
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "apps.insights"
|
||||
133
app/apps/insights/forms.py
Normal file
133
app/apps/insights/forms.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Row, Column
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.datepicker import (
|
||||
AirMonthYearPickerInput,
|
||||
AirYearPickerInput,
|
||||
AirDatePickerInput,
|
||||
)
|
||||
from apps.transactions.models import TransactionCategory
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
|
||||
|
||||
class SingleMonthForm(forms.Form):
|
||||
month = forms.DateField(
|
||||
widget=AirMonthYearPickerInput(clear_button=False), label="", required=True
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.disable_csrf = True
|
||||
|
||||
self.helper.layout = Layout(Field("month"))
|
||||
|
||||
|
||||
class SingleYearForm(forms.Form):
|
||||
year = forms.DateField(
|
||||
widget=AirYearPickerInput(clear_button=False), label="", required=True
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.disable_csrf = True
|
||||
|
||||
self.helper.layout = Layout(Field("year"))
|
||||
|
||||
|
||||
class MonthRangeForm(forms.Form):
|
||||
month_from = forms.DateField(
|
||||
widget=AirMonthYearPickerInput(clear_button=False), label="", required=True
|
||||
)
|
||||
month_to = forms.DateField(
|
||||
widget=AirMonthYearPickerInput(clear_button=False), label="", required=True
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.disable_csrf = True
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("month_from", css_class="form-group col-md-6"),
|
||||
Column("month_to", css_class="form-group col-md-6"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class YearRangeForm(forms.Form):
|
||||
year_from = forms.DateField(
|
||||
widget=AirYearPickerInput(clear_button=False), label="", required=True
|
||||
)
|
||||
year_to = forms.DateField(
|
||||
widget=AirYearPickerInput(clear_button=False), label="", required=True
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.disable_csrf = True
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("year_from", css_class="form-group col-md-6"),
|
||||
Column("year_to", css_class="form-group col-md-6"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DateRangeForm(forms.Form):
|
||||
date_from = forms.DateField(
|
||||
widget=AirDatePickerInput(clear_button=False), label="", required=True
|
||||
)
|
||||
date_to = forms.DateField(
|
||||
widget=AirDatePickerInput(clear_button=False), label="", required=True
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.disable_csrf = True
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("date_from", css_class="form-group col-md-6"),
|
||||
Column("date_to", css_class="form-group col-md-6"),
|
||||
css_class="mb-0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CategoryForm(forms.Form):
|
||||
category = forms.ModelChoiceField(
|
||||
required=False,
|
||||
label=_("Category"),
|
||||
empty_label=_("Uncategorized"),
|
||||
queryset=TransactionCategory.objects.all(),
|
||||
widget=TomSelect(clear_button=True),
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.fields["category"].queryset = TransactionCategory.objects.all()
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.disable_csrf = True
|
||||
|
||||
self.helper.layout = Layout("category")
|
||||
0
app/apps/insights/migrations/__init__.py
Normal file
0
app/apps/insights/migrations/__init__.py
Normal file
3
app/apps/insights/models.py
Normal file
3
app/apps/insights/models.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from django.db import models
|
||||
|
||||
# Create your models here.
|
||||
303
app/apps/insights/tests.py
Normal file
303
app/apps/insights/tests.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth.models import User
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from decimal import Decimal
|
||||
from datetime import date, timedelta
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import TransactionCategory, Transaction
|
||||
from apps.insights.utils.category_explorer import get_category_sums_by_account, get_category_sums_by_currency
|
||||
from apps.insights.utils.sankey import generate_sankey_data_by_account
|
||||
|
||||
class InsightsUtilsTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(username='testinsightsuser', password='password')
|
||||
|
||||
self.currency_usd = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2)
|
||||
self.currency_eur = Currency.objects.create(code="EUR", name="Euro", decimal_places=2)
|
||||
|
||||
# It's good practice to have an AccountGroup for accounts
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group", owner=self.user)
|
||||
|
||||
self.category_food = TransactionCategory.objects.create(name="Food", owner=self.user, type=TransactionCategory.TransactionType.EXPENSE)
|
||||
self.category_salary = TransactionCategory.objects.create(name="Salary", owner=self.user, type=TransactionCategory.TransactionType.INCOME)
|
||||
|
||||
self.account_usd_1 = Account.objects.create(name="USD Account 1", owner=self.user, currency=self.currency_usd, group=self.account_group)
|
||||
self.account_usd_2 = Account.objects.create(name="USD Account 2", owner=self.user, currency=self.currency_usd, group=self.account_group)
|
||||
self.account_eur_1 = Account.objects.create(name="EUR Account 1", owner=self.user, currency=self.currency_eur, group=self.account_group)
|
||||
|
||||
today = date.today()
|
||||
|
||||
# T1: Acc USD1, Food, Expense 50 (paid)
|
||||
Transaction.objects.create(
|
||||
description="Groceries USD1 Food Paid", account=self.account_usd_1, category=self.category_food,
|
||||
type=Transaction.Type.EXPENSE, amount=Decimal('50.00'), date=today, is_paid=True, owner=self.user
|
||||
)
|
||||
# T2: Acc USD1, Food, Expense 20 (unpaid/projected)
|
||||
Transaction.objects.create(
|
||||
description="Restaurant USD1 Food Unpaid", account=self.account_usd_1, category=self.category_food,
|
||||
type=Transaction.Type.EXPENSE, amount=Decimal('20.00'), date=today, is_paid=False, owner=self.user
|
||||
)
|
||||
# T3: Acc USD2, Food, Expense 30 (paid)
|
||||
Transaction.objects.create(
|
||||
description="Snacks USD2 Food Paid", account=self.account_usd_2, category=self.category_food,
|
||||
type=Transaction.Type.EXPENSE, amount=Decimal('30.00'), date=today, is_paid=True, owner=self.user
|
||||
)
|
||||
# T4: Acc USD1, Salary, Income 1000 (paid)
|
||||
Transaction.objects.create(
|
||||
description="Salary USD1 Paid", account=self.account_usd_1, category=self.category_salary,
|
||||
type=Transaction.Type.INCOME, amount=Decimal('1000.00'), date=today, is_paid=True, owner=self.user
|
||||
)
|
||||
# T5: Acc EUR1, Food, Expense 40 (paid, different currency)
|
||||
Transaction.objects.create(
|
||||
description="Groceries EUR1 Food Paid", account=self.account_eur_1, category=self.category_food,
|
||||
type=Transaction.Type.EXPENSE, amount=Decimal('40.00'), date=today, is_paid=True, owner=self.user
|
||||
)
|
||||
# T6: Acc USD2, Salary, Income 200 (unpaid/projected)
|
||||
Transaction.objects.create(
|
||||
description="Bonus USD2 Salary Unpaid", account=self.account_usd_2, category=self.category_salary,
|
||||
type=Transaction.Type.INCOME, amount=Decimal('200.00'), date=today, is_paid=False, owner=self.user
|
||||
)
|
||||
|
||||
def test_get_category_sums_by_account_for_food(self):
|
||||
qs = Transaction.objects.filter(owner=self.user) # Filter by user for safety in shared DB environments
|
||||
result = get_category_sums_by_account(qs, category=self.category_food)
|
||||
|
||||
expected_labels = sorted([self.account_eur_1.name, self.account_usd_1.name, self.account_usd_2.name])
|
||||
self.assertEqual(result['labels'], expected_labels)
|
||||
|
||||
# Expected data structure: {account_name: {'current_income': D('0'), ...}, ...}
|
||||
# Then the util function transforms this.
|
||||
# Let's map labels to their expected index for easier assertion
|
||||
label_to_idx = {name: i for i, name in enumerate(expected_labels)}
|
||||
|
||||
# Initialize expected data arrays based on sorted labels length
|
||||
num_labels = len(expected_labels)
|
||||
expected_current_income = [Decimal('0.00')] * num_labels
|
||||
expected_current_expenses = [Decimal('0.00')] * num_labels
|
||||
expected_projected_income = [Decimal('0.00')] * num_labels
|
||||
expected_projected_expenses = [Decimal('0.00')] * num_labels
|
||||
|
||||
# Populate expected data based on transactions for FOOD category
|
||||
# T1: Acc USD1, Food, Expense 50 (paid) -> account_usd_1, current_expenses = -50
|
||||
expected_current_expenses[label_to_idx[self.account_usd_1.name]] = Decimal('-50.00')
|
||||
# T2: Acc USD1, Food, Expense 20 (unpaid/projected) -> account_usd_1, projected_expenses = -20
|
||||
expected_projected_expenses[label_to_idx[self.account_usd_1.name]] = Decimal('-20.00')
|
||||
# T3: Acc USD2, Food, Expense 30 (paid) -> account_usd_2, current_expenses = -30
|
||||
expected_current_expenses[label_to_idx[self.account_usd_2.name]] = Decimal('-30.00')
|
||||
# T5: Acc EUR1, Food, Expense 40 (paid) -> account_eur_1, current_expenses = -40
|
||||
expected_current_expenses[label_to_idx[self.account_eur_1.name]] = Decimal('-40.00')
|
||||
|
||||
self.assertEqual(result['datasets'][0]['data'], [float(x) for x in expected_current_income]) # Current Income
|
||||
self.assertEqual(result['datasets'][1]['data'], [float(x) for x in expected_current_expenses]) # Current Expenses
|
||||
self.assertEqual(result['datasets'][2]['data'], [float(x) for x in expected_projected_income]) # Projected Income
|
||||
self.assertEqual(result['datasets'][3]['data'], [float(x) for x in expected_projected_expenses]) # Projected Expenses
|
||||
|
||||
self.assertEqual(result['datasets'][0]['label'], "Current Income")
|
||||
self.assertEqual(result['datasets'][1]['label'], "Current Expenses")
|
||||
self.assertEqual(result['datasets'][2]['label'], "Projected Income")
|
||||
self.assertEqual(result['datasets'][3]['label'], "Projected Expenses")
|
||||
|
||||
def test_generate_sankey_data_by_account(self):
|
||||
qs = Transaction.objects.filter(owner=self.user)
|
||||
result = generate_sankey_data_by_account(qs)
|
||||
|
||||
nodes = result['nodes']
|
||||
flows = result['flows']
|
||||
|
||||
# Helper to find a node by a unique part of its ID
|
||||
def find_node_by_id_part(id_part):
|
||||
found_nodes = [n for n in nodes if id_part in n['id']]
|
||||
self.assertEqual(len(found_nodes), 1, f"Node with ID part '{id_part}' not found or not unique. Found: {found_nodes}")
|
||||
return found_nodes[0]
|
||||
|
||||
# Helper to find a flow by unique parts of its source and target node IDs
|
||||
def find_flow_by_node_id_parts(from_id_part, to_id_part):
|
||||
found_flows = [
|
||||
f for f in flows
|
||||
if from_id_part in f['from_node'] and to_id_part in f['to_node']
|
||||
]
|
||||
self.assertEqual(len(found_flows), 1, f"Flow from '{from_id_part}' to '{to_id_part}' not found or not unique. Found: {found_flows}")
|
||||
return found_flows[0]
|
||||
|
||||
# Calculate total volumes by currency (sum of absolute amounts of ALL transactions)
|
||||
total_volume_usd = sum(abs(t.amount) for t in qs if t.account.currency == self.currency_usd) # 50+20+30+1000+200 = 1300
|
||||
total_volume_eur = sum(abs(t.amount) for t in qs if t.account.currency == self.currency_eur) # 40
|
||||
|
||||
self.assertEqual(total_volume_usd, Decimal('1300.00'))
|
||||
self.assertEqual(total_volume_eur, Decimal('40.00'))
|
||||
|
||||
# --- Assertions for Account USD 1 ---
|
||||
acc_usd_1_id_part = f"_{self.account_usd_1.id}"
|
||||
|
||||
node_income_salary_usd1 = find_node_by_id_part(f"income_{self.category_salary.name.lower()}{acc_usd_1_id_part}")
|
||||
self.assertEqual(node_income_salary_usd1['name'], self.category_salary.name)
|
||||
|
||||
node_account_usd1 = find_node_by_id_part(f"account_{self.account_usd_1.name.lower().replace(' ', '_')}{acc_usd_1_id_part}")
|
||||
self.assertEqual(node_account_usd1['name'], self.account_usd_1.name)
|
||||
|
||||
node_expense_food_usd1 = find_node_by_id_part(f"expense_{self.category_food.name.lower()}{acc_usd_1_id_part}")
|
||||
self.assertEqual(node_expense_food_usd1['name'], self.category_food.name)
|
||||
|
||||
node_saved_usd1 = find_node_by_id_part(f"savings_saved{acc_usd_1_id_part}")
|
||||
self.assertEqual(node_saved_usd1['name'], _("Saved"))
|
||||
|
||||
# Flow 1: Salary (T4) to account_usd_1
|
||||
flow_salary_to_usd1 = find_flow_by_node_id_parts(node_income_salary_usd1['id'], node_account_usd1['id'])
|
||||
self.assertEqual(flow_salary_to_usd1['original_amount'], 1000.0)
|
||||
self.assertEqual(flow_salary_to_usd1['currency']['code'], self.currency_usd.code)
|
||||
self.assertAlmostEqual(flow_salary_to_usd1['percentage'], (1000.0 / float(total_volume_usd)) * 100, places=2)
|
||||
self.assertAlmostEqual(flow_salary_to_usd1['flow'], (1000.0 / float(total_volume_usd)), places=4)
|
||||
|
||||
# Flow 2: account_usd_1 to Food (T1)
|
||||
flow_usd1_to_food = find_flow_by_node_id_parts(node_account_usd1['id'], node_expense_food_usd1['id'])
|
||||
self.assertEqual(flow_usd1_to_food['original_amount'], 50.0) # T1 is 50
|
||||
self.assertEqual(flow_usd1_to_food['currency']['code'], self.currency_usd.code)
|
||||
self.assertAlmostEqual(flow_usd1_to_food['percentage'], (50.0 / float(total_volume_usd)) * 100, places=2)
|
||||
|
||||
# Flow 3: account_usd_1 to Saved
|
||||
# Net paid for account_usd_1: 1000 (T4 income) - 50 (T1 expense) = 950
|
||||
flow_usd1_to_saved = find_flow_by_node_id_parts(node_account_usd1['id'], node_saved_usd1['id'])
|
||||
self.assertEqual(flow_usd1_to_saved['original_amount'], 950.0)
|
||||
self.assertEqual(flow_usd1_to_saved['currency']['code'], self.currency_usd.code)
|
||||
self.assertAlmostEqual(flow_usd1_to_saved['percentage'], (950.0 / float(total_volume_usd)) * 100, places=2)
|
||||
|
||||
# --- Assertions for Account USD 2 ---
|
||||
acc_usd_2_id_part = f"_{self.account_usd_2.id}"
|
||||
node_account_usd2 = find_node_by_id_part(f"account_{self.account_usd_2.name.lower().replace(' ', '_')}{acc_usd_2_id_part}")
|
||||
node_expense_food_usd2 = find_node_by_id_part(f"expense_{self.category_food.name.lower()}{acc_usd_2_id_part}")
|
||||
# T6 (Salary for USD2) is unpaid, so no income node/flow for it.
|
||||
# Net paid for account_usd_2 is -30 (T3 expense). So no "Saved" node.
|
||||
|
||||
# Flow: account_usd_2 to Food (T3)
|
||||
flow_usd2_to_food = find_flow_by_node_id_parts(node_account_usd2['id'], node_expense_food_usd2['id'])
|
||||
self.assertEqual(flow_usd2_to_food['original_amount'], 30.0) # T3 is 30
|
||||
self.assertEqual(flow_usd2_to_food['currency']['code'], self.currency_usd.code)
|
||||
self.assertAlmostEqual(flow_usd2_to_food['percentage'], (30.0 / float(total_volume_usd)) * 100, places=2)
|
||||
|
||||
# Check no "Saved" node for account_usd_2
|
||||
saved_nodes_usd2 = [n for n in nodes if f"savings_saved{acc_usd_2_id_part}" in n['id']]
|
||||
self.assertEqual(len(saved_nodes_usd2), 0, "Should be no 'Saved' node for account_usd_2 as net is negative.")
|
||||
|
||||
# --- Assertions for Account EUR 1 ---
|
||||
acc_eur_1_id_part = f"_{self.account_eur_1.id}"
|
||||
node_account_eur1 = find_node_by_id_part(f"account_{self.account_eur_1.name.lower().replace(' ', '_')}{acc_eur_1_id_part}")
|
||||
node_expense_food_eur1 = find_node_by_id_part(f"expense_{self.category_food.name.lower()}{acc_eur_1_id_part}")
|
||||
# Net paid for account_eur_1 is -40 (T5 expense). No "Saved" node.
|
||||
|
||||
# Flow: account_eur_1 to Food (T5)
|
||||
flow_eur1_to_food = find_flow_by_node_id_parts(node_account_eur1['id'], node_expense_food_eur1['id'])
|
||||
self.assertEqual(flow_eur1_to_food['original_amount'], 40.0) # T5 is 40
|
||||
self.assertEqual(flow_eur1_to_food['currency']['code'], self.currency_eur.code)
|
||||
self.assertAlmostEqual(flow_eur1_to_food['percentage'], (40.0 / float(total_volume_eur)) * 100, places=2) # (40/40)*100 = 100%
|
||||
|
||||
# Check no "Saved" node for account_eur_1
|
||||
saved_nodes_eur1 = [n for n in nodes if f"savings_saved{acc_eur_1_id_part}" in n['id']]
|
||||
self.assertEqual(len(saved_nodes_eur1), 0, "Should be no 'Saved' node for account_eur_1 as net is negative.")
|
||||
|
||||
def test_get_category_sums_by_currency_for_food(self):
|
||||
qs = Transaction.objects.filter(owner=self.user)
|
||||
result = get_category_sums_by_currency(qs, category=self.category_food)
|
||||
|
||||
expected_labels = sorted([self.currency_eur.name, self.currency_usd.name])
|
||||
self.assertEqual(result['labels'], expected_labels)
|
||||
|
||||
label_to_idx = {name: i for i, name in enumerate(expected_labels)}
|
||||
num_labels = len(expected_labels)
|
||||
|
||||
expected_current_income = [Decimal('0.00')] * num_labels
|
||||
expected_current_expenses = [Decimal('0.00')] * num_labels
|
||||
expected_projected_income = [Decimal('0.00')] * num_labels
|
||||
expected_projected_expenses = [Decimal('0.00')] * num_labels
|
||||
|
||||
# Food Transactions:
|
||||
# T1: USD Account 1, Food, Expense 50 (paid)
|
||||
# T2: USD Account 1, Food, Expense 20 (unpaid/projected)
|
||||
# T3: USD Account 2, Food, Expense 30 (paid)
|
||||
# T5: EUR Account 1, Food, Expense 40 (paid)
|
||||
|
||||
# Current Expenses:
|
||||
expected_current_expenses[label_to_idx[self.currency_eur.name]] = Decimal('-40.00') # T5
|
||||
expected_current_expenses[label_to_idx[self.currency_usd.name]] = Decimal('-50.00') + Decimal('-30.00') # T1 + T3
|
||||
|
||||
# Projected Expenses:
|
||||
expected_projected_expenses[label_to_idx[self.currency_usd.name]] = Decimal('-20.00') # T2
|
||||
|
||||
self.assertEqual(result['datasets'][0]['data'], [float(x) for x in expected_current_income])
|
||||
self.assertEqual(result['datasets'][1]['data'], [float(x) for x in expected_current_expenses])
|
||||
self.assertEqual(result['datasets'][2]['data'], [float(x) for x in expected_projected_income])
|
||||
self.assertEqual(result['datasets'][3]['data'], [float(x) for x in expected_projected_expenses])
|
||||
|
||||
self.assertEqual(result['datasets'][0]['label'], "Current Income")
|
||||
self.assertEqual(result['datasets'][1]['label'], "Current Expenses")
|
||||
self.assertEqual(result['datasets'][2]['label'], "Projected Income")
|
||||
self.assertEqual(result['datasets'][3]['label'], "Projected Expenses")
|
||||
|
||||
def test_get_category_sums_by_currency_for_salary(self):
|
||||
qs = Transaction.objects.filter(owner=self.user)
|
||||
result = get_category_sums_by_currency(qs, category=self.category_salary)
|
||||
|
||||
# Salary Transactions:
|
||||
# T4: USD Account 1, Salary, Income 1000 (paid)
|
||||
# T6: USD Account 2, Salary, Income 200 (unpaid/projected)
|
||||
# All are USD
|
||||
expected_labels = [self.currency_usd.name] # Only USD has salary transactions
|
||||
self.assertEqual(result['labels'], expected_labels)
|
||||
|
||||
label_to_idx = {name: i for i, name in enumerate(expected_labels)}
|
||||
num_labels = len(expected_labels)
|
||||
|
||||
expected_current_income = [Decimal('0.00')] * num_labels
|
||||
expected_current_expenses = [Decimal('0.00')] * num_labels
|
||||
expected_projected_income = [Decimal('0.00')] * num_labels
|
||||
expected_projected_expenses = [Decimal('0.00')] * num_labels
|
||||
|
||||
# Current Income:
|
||||
expected_current_income[label_to_idx[self.currency_usd.name]] = Decimal('1000.00') # T4
|
||||
|
||||
# Projected Income:
|
||||
expected_projected_income[label_to_idx[self.currency_usd.name]] = Decimal('200.00') # T6
|
||||
|
||||
self.assertEqual(result['datasets'][0]['data'], [float(x) for x in expected_current_income])
|
||||
self.assertEqual(result['datasets'][1]['data'], [float(x) for x in expected_current_expenses])
|
||||
self.assertEqual(result['datasets'][2]['data'], [float(x) for x in expected_projected_income])
|
||||
self.assertEqual(result['datasets'][3]['data'], [float(x) for x in expected_projected_expenses])
|
||||
|
||||
self.assertEqual(result['datasets'][0]['label'], "Current Income")
|
||||
self.assertEqual(result['datasets'][1]['label'], "Current Expenses")
|
||||
self.assertEqual(result['datasets'][2]['label'], "Projected Income")
|
||||
self.assertEqual(result['datasets'][3]['label'], "Projected Expenses")
|
||||
|
||||
|
||||
def test_get_category_sums_by_account_for_salary(self):
|
||||
qs = Transaction.objects.filter(owner=self.user)
|
||||
result = get_category_sums_by_account(qs, category=self.category_salary)
|
||||
|
||||
# Only accounts with salary transactions should appear
|
||||
expected_labels = sorted([self.account_usd_1.name, self.account_usd_2.name])
|
||||
self.assertEqual(result['labels'], expected_labels)
|
||||
|
||||
label_to_idx = {name: i for i, name in enumerate(expected_labels)}
|
||||
num_labels = len(expected_labels)
|
||||
|
||||
expected_current_income = [Decimal('0.00')] * num_labels
|
||||
expected_current_expenses = [Decimal('0.00')] * num_labels
|
||||
expected_projected_income = [Decimal('0.00')] * num_labels
|
||||
expected_projected_expenses = [Decimal('0.00')] * num_labels
|
||||
|
||||
# Populate expected data based on transactions for SALARY category
|
||||
# T4: Acc USD1, Salary, Income 1000 (paid) -> account_usd_1, current_income = 1000
|
||||
expected_current_income[label_to_idx[self.account_usd_1.name]] = Decimal('1000.00')
|
||||
# T6: Acc USD2, Salary, Income 200 (unpaid/projected) -> account_usd_2, projected_income = 200
|
||||
expected_projected_income[label_to_idx[self.account_usd_2.name]] = Decimal('200.00')
|
||||
|
||||
self.assertEqual(result['datasets'][0]['data'], [float(x) for x in expected_current_income])
|
||||
self.assertEqual(result['datasets'][1]['data'], [float(x) for x in expected_current_expenses])
|
||||
self.assertEqual(result['datasets'][2]['data'], [float(x) for x in expected_projected_income])
|
||||
self.assertEqual(result['datasets'][3]['data'], [float(x) for x in expected_projected_expenses])
|
||||
|
||||
self.assertEqual(result['datasets'][0]['label'], "Current Income")
|
||||
self.assertEqual(result['datasets'][1]['label'], "Current Expenses")
|
||||
self.assertEqual(result['datasets'][2]['label'], "Projected Income")
|
||||
self.assertEqual(result['datasets'][3]['label'], "Projected Expenses")
|
||||
52
app/apps/insights/urls.py
Normal file
52
app/apps/insights/urls.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from django.urls import path
|
||||
|
||||
from . import views
|
||||
|
||||
urlpatterns = [
|
||||
path("insights/", views.index, name="insights_index"),
|
||||
path(
|
||||
"insights/sankey/account/",
|
||||
views.sankey_by_account,
|
||||
name="insights_sankey_by_account",
|
||||
),
|
||||
path(
|
||||
"insights/sankey/currency/",
|
||||
views.sankey_by_currency,
|
||||
name="insights_sankey_by_currency",
|
||||
),
|
||||
path(
|
||||
"insights/category-explorer/",
|
||||
views.category_explorer_index,
|
||||
name="category_explorer_index",
|
||||
),
|
||||
path(
|
||||
"insights/category-explorer/account/",
|
||||
views.category_sum_by_account,
|
||||
name="category_sum_by_account",
|
||||
),
|
||||
path(
|
||||
"insights/category-explorer/currency/",
|
||||
views.category_sum_by_currency,
|
||||
name="category_sum_by_currency",
|
||||
),
|
||||
path(
|
||||
"insights/category-overview/",
|
||||
views.category_overview,
|
||||
name="category_overview",
|
||||
),
|
||||
path(
|
||||
"insights/late-transactions/",
|
||||
views.late_transactions,
|
||||
name="insights_late_transactions",
|
||||
),
|
||||
path(
|
||||
"insights/latest-transactions/",
|
||||
views.latest_transactions,
|
||||
name="insights_latest_transactions",
|
||||
),
|
||||
path(
|
||||
"insights/emergency-fund/",
|
||||
views.emergency_fund,
|
||||
name="insights_emergency_fund",
|
||||
),
|
||||
]
|
||||
0
app/apps/insights/utils/__init__.py
Normal file
0
app/apps/insights/utils/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user