mirror of
https://github.com/eitchtee/WYGIWYH.git
synced 2026-05-20 06:37:24 +02:00
Compare commits
386 Commits
internal_port
...
weblate
| Author | SHA1 | Date | |
|---|---|---|---|
| b34c0557fa | |||
| 2af4066aab | |||
| d72ff3cdf5 | |||
| 63c69e5c6a | |||
| 78171183cc | |||
| 34a2b6bfd4 | |||
| 1dc24f855e | |||
| 1390aff07d | |||
| 8fc11b0acf | |||
| 9a30a0d3c0 | |||
| 10eecd09ff | |||
| 2cfb3fb12e | |||
| 47af8b135b | |||
| 39d0e63375 | |||
| 792154eba2 | |||
| dc76ed3156 | |||
| e627dd50be | |||
| 5527389196 | |||
| be24ca014e | |||
| 7c7056536e | |||
| 66d5d7a83b | |||
| 4d3ce087d6 | |||
| aeaf9fac43 | |||
| dcedf53b83 | |||
| 549648bd6b | |||
| 79149abdd2 | |||
| d66d1530bb | |||
| 4b9b6484d3 | |||
| a02944bdae | |||
| 3ede5304f1 | |||
| 27041695b8 | |||
| 0c927a2fe9 | |||
| c52db80c64 | |||
| 330ce8069c | |||
| 2989f11b01 | |||
| 738bb7fb74 | |||
| 79e50cd853 | |||
| 0a23c3ad5b | |||
| 43c7749102 | |||
| c1c4ccda8c | |||
| 615a689c61 | |||
| c5ccc42f99 | |||
| 2baa8b21e8 | |||
| 2e554141ba | |||
| 73ec6dc0fe | |||
| e19449ff99 | |||
| e81651119c | |||
| 55e9ef1b3f | |||
| c414179135 | |||
| 14c507de0f | |||
| 4722690fe9 | |||
| 493619a4ff | |||
| fb4aec88f1 | |||
| 4a35e770a4 | |||
| 83b81edbae | |||
| a7dc2c955e | |||
| ce2ae562c6 | |||
| de2881ffd4 | |||
| 838bf22498 | |||
| d3797ae4a5 | |||
| 0532397afd | |||
| 8106dc58e5 | |||
| 5986cf675b | |||
| 80da9142f1 | |||
| 766516d248 | |||
| 3fd0fba1b8 | |||
| c787565c04 | |||
| 0413921dbe | |||
| 9ecf8279b4 | |||
| 86cf625158 | |||
| ea097ab6f0 | |||
| b1201b51bb | |||
| 4c1d20215c | |||
| 27e85c4776 | |||
| 5a73cd20da | |||
| e305fab300 | |||
| c11f525373 | |||
| ea5d86dbf8 | |||
| a1d3539e3c | |||
| 1028a11c8b | |||
| e387a5e2a8 | |||
| 624dc382cf | |||
| f88699b333 | |||
| ca98dc073b | |||
| 63ba7af3c8 | |||
| 2d0dee4a9b | |||
| 0000a9ee03 | |||
| 41adb37fdb | |||
| 496651173e | |||
| 8836f06b80 | |||
| e98a48b3a7 | |||
| f9bc9f449b | |||
| 26eb1ae813 | |||
| 29a2cb9813 | |||
| be79e1b25a | |||
| 3fd08466a7 | |||
| 6896cdcdca | |||
| 2532930a64 | |||
| 24a1ef2d0a | |||
| 163f2f4e5b | |||
| ede63acf5f | |||
| a8ba3d8754 | |||
| e2f1156264 | |||
| d5bbad7887 | |||
| 7ebacff6e4 | |||
| df8ef5d04c | |||
| fa2a8b8c65 | |||
| e44ac5dab6 | |||
| f9261d1283 | |||
| 4c73c1cae5 | |||
| 0315a56f88 | |||
| 44d6b8b53c | |||
| 3e4d7c6b1f | |||
| 63868514f9 | |||
| 9055a24327 | |||
| 9dc963ed7b | |||
| 49cac0588e | |||
| 3b2b6d6473 | |||
| db30bcbeb7 | |||
| a122733a47 | |||
| 37f3e4d99a | |||
| d756286135 | |||
| 06a7378fd8 | |||
| ab4075c500 | |||
| 96318f003d | |||
| 1a0412264a | |||
| 2588404876 | |||
| fdc273103b | |||
| c015b78cd6 | |||
| 50e5492ea1 | |||
| 796089cdb3 | |||
| c83b1bf2d6 | |||
| b074ef7929 | |||
| ec7e33b3b0 | |||
| 72fedea0db | |||
| 0a03745ce6 | |||
| ff4bd79634 | |||
| 383b42e26d | |||
| 48e43ac031 | |||
| 21c60c4059 | |||
| dd6a390e6b | |||
| 0c961a8250 | |||
| e28c651973 | |||
| 7687ff81c3 | |||
| b2d78c9190 | |||
| b0815e00c7 | |||
| fbe9726338 | |||
| 0df3a57a33 | |||
| f86613b17a | |||
| ffa4644e1b | |||
| 6611559696 | |||
| b455a0251a | |||
| 9d7c3212f1 | |||
| 0da3185996 | |||
| 6c90e1bb7f | |||
| c6543c0841 | |||
| d4740b8406 | |||
| 5a51795e6a | |||
| 64d7765357 | |||
| 070e11ca77 | |||
| 39f66b620a | |||
| ad164866e0 | |||
| 05c465cb34 | |||
| 92cf526b76 | |||
| 639236b890 | |||
| 519a85d256 | |||
| 700d35b5d5 | |||
| 10e51971db | |||
| ec0d5fc121 | |||
| 01f91352d6 | |||
| 63ce57a315 | |||
| eadeb649a1 | |||
| a2871d5289 | |||
| f2a362bc0f | |||
| 2076903740 | |||
| c752c0b16e | |||
| 1674766253 | |||
| 7ea9d56132 | |||
| 3699c6c671 | |||
| d7c255aa14 | |||
| d17b9d5736 | |||
| c7ff6db0bf | |||
| a4c7753f69 | |||
| 7e08028557 | |||
| 5eaf5086d2 | |||
| c949c6cea0 | |||
| 71c0e9a271 | |||
| bc65980511 | |||
| ecdb1a52cc | |||
| afc06582b4 | |||
| 07cb0a2a0f | |||
| 05ede58c36 | |||
| 20b6366a18 | |||
| b0101dae1a | |||
| a3d38ff9e0 | |||
| 776e2117a0 | |||
| edcad37926 | |||
| 2d51d21035 | |||
| 94f5c25829 | |||
| 88a5c103e5 | |||
| 3dce9e1c55 | |||
| 41d8564e8b | |||
| 5ee2fd244f | |||
| 0545fb7651 | |||
| 7bd1d2d751 | |||
| 9a4ec449df | |||
| f918351303 | |||
| ef66b3a1e5 | |||
| 7486660223 | |||
| 00d5ccda34 | |||
| 1656eec601 | |||
| 64b96ed2f3 | |||
| 1f5e4f132d | |||
| edf056b68c | |||
| 35865ce21c | |||
| 8f06c06d32 | |||
| 15eaa2239a | |||
| fd7214df95 | |||
| e531c63de3 | |||
| 5a79dd5424 | |||
| 315dd1479a | |||
| 67f79effab | |||
| c168886968 | |||
| 272c34d3b3 | |||
| 43ce79ae65 | |||
| 4aa29545ec | |||
| fd1fcb832c | |||
| b5fd928a5d | |||
| 2dc398f82b | |||
| cf7d4b1404 | |||
| e9c3af1a85 | |||
| b121e8e982 | |||
| 606e6b3843 | |||
| 6e46b5abb8 | |||
| 5b4dab93a1 | |||
| 29b6ee3af3 | |||
| 484686b709 | |||
| 938c128d07 | |||
| 8123f7f3cb | |||
| 547dc90d9e | |||
| dc33fda5d3 | |||
| 92960d1b9a | |||
| 1978a467cb | |||
| 5bdafbba91 | |||
| 16de87376a | |||
| e8e1144fdd | |||
| 157f357a7a | |||
| d77eddbd26 | |||
| fb1b383962 | |||
| 11998475c5 | |||
| 21363e23a1 | |||
| d3a816d91b | |||
| 9c92bbd3cf | |||
| c55d688956 | |||
| 231b9065c9 | |||
| 01ea0de4b3 | |||
| c57fa1630b | |||
| 92f7bcfd9e | |||
| 58b855f55e | |||
| d4d51301b3 | |||
| aed3fb11fe | |||
| 70d427bec4 | |||
| b6f52458db | |||
| 8d76c40b7e | |||
| a43e3d158f | |||
| 588ae2de6e | |||
| 4b97ba681a | |||
| 1a903507ad | |||
| bf920df771 | |||
| 23ae6f3d54 | |||
| 49f28834e9 | |||
| 4351027b87 | |||
| c37aa6e059 | |||
| 8a5a54dcbd | |||
| 24ee8ecd68 | |||
| a14332bb80 | |||
| 32747071fe | |||
| 24fa9cde51 | |||
| 372ec2f30f | |||
| fffba037a6 | |||
| 43488147d8 | |||
| 31a31e9922 | |||
| 7af6280b29 | |||
| 40389396e3 | |||
| 21845d501e | |||
| 5f098e11a3 | |||
| d2de0684fb | |||
| eb4723e890 | |||
| 890cc90420 | |||
| 307af9e40a | |||
| 1eeb0b0f5e | |||
| 605ece705e | |||
| 2ae57e83cb | |||
| af72e3f44e | |||
| e2e1c5cff5 | |||
| ed3d58f1fd | |||
| b58f894dc6 | |||
| 2ed7fa44c0 | |||
| 7c3120cd43 | |||
| 2bc5e24e51 | |||
| d3f8a637bc | |||
| b02b6451d2 | |||
| 0b0d760bab | |||
| b38ed37bc5 | |||
| 7e37948616 | |||
| 2afb6b1f5f | |||
| cd54df6f2d | |||
| 3e4ace8993 | |||
| a878af28f1 | |||
| 0a4d4c12b9 | |||
| 9ade58a003 | |||
| 89b2d0118d | |||
| 232d5003b8 | |||
| 133d70d3d1 | |||
| e70608eaaf | |||
| a63367a772 | |||
| baef86b6cb | |||
| 3011b32fa6 | |||
| 910decfe00 | |||
| e600d87968 | |||
| dd82289488 | |||
| 1e816ec80a | |||
| 3b5626cbd1 | |||
| a819ceaa43 | |||
| de28dbb0f0 | |||
| cfb34a4dc3 | |||
| efdcfc192a | |||
| a7856a6671 | |||
| 7b8e3b528a | |||
| cc3244a034 | |||
| 2121a68c82 | |||
| f35002f862 | |||
| 73a992256d | |||
| 9f1098d6b9 | |||
| 2c0936b7e5 | |||
| 5fb717c3fe | |||
| c5f94fb34d | |||
| 29cdec4577 | |||
| 82efd48e53 | |||
| 5a3a0b7e5c | |||
| 41a5900f12 | |||
| 2dbdd02350 | |||
| fa0cde1a4e | |||
| 623d91d26f | |||
| 57200437dc | |||
| 6f4a2b687c | |||
| 8bb40be41c | |||
| 66c1cf2371 | |||
| 4b23836544 | |||
| 585af1270f | |||
| a0cc51b2ec | |||
| 6a5de7d94d | |||
| 6d9687de0b | |||
| e9acf1dd8f | |||
| 698e05bd06 | |||
| 90b3778e36 | |||
| 85a773bc01 | |||
| 355016a7a5 | |||
| f04fcf99b7 | |||
| 0fb389e7e8 | |||
| 63898aeef0 | |||
| 4fdf00d098 | |||
| 025cc585d5 | |||
| 17018d87cd | |||
| 1e5f4f6583 | |||
| a99851cf9b | |||
| 9fb1ad4861 | |||
| 66c3abfe37 | |||
| 8ca64f5820 | |||
| e743821570 | |||
| 5c698d8735 | |||
| 3e5aa90df0 | |||
| b2add14238 | |||
| a052c00aa8 | |||
| 7f343708e0 | |||
| 22e95c7f4a | |||
| 7645153f77 | |||
| 1abfed9abf | |||
| eea0ab009d | |||
| 29446def22 | |||
| 9dce5e9efe | |||
| 695e2cb322 | |||
| b135ec3b15 | |||
| bb3cc5da6c | |||
| ca7fe24a8a | |||
| 483ba74010 |
@@ -0,0 +1 @@
|
||||
__pycache__/
|
||||
@@ -12,7 +12,7 @@ on:
|
||||
required: true
|
||||
type: string
|
||||
ref:
|
||||
description: 'Git ref to checkout (branch, tag, or SHA)'
|
||||
description: 'Git ref to checkout'
|
||||
required: true
|
||||
default: 'main'
|
||||
type: string
|
||||
@@ -29,73 +29,57 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write # Needed if you switch to GHCR, good practice
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.ref }}
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
|
||||
- name: Checkout code (non-manual)
|
||||
uses: actions/checkout@v4
|
||||
if: github.event_name != 'workflow_dispatch'
|
||||
ref: ${{ inputs.ref || github.ref }}
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# This action handles all the logic for tags (nightly vs release vs custom)
|
||||
- name: Docker Metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
# Logic for Push to Main -> nightly
|
||||
type=raw,value=nightly,enable=${{ github.event_name == 'push' }}
|
||||
# Logic for Release -> semver and latest
|
||||
type=semver,pattern={{version}},enable=${{ github.event_name == 'release' }}
|
||||
type=raw,value=latest,enable=${{ github.event_name == 'release' }}
|
||||
# Logic for Manual Dispatch -> custom input
|
||||
type=raw,value=${{ inputs.tag }},enable=${{ github.event_name == 'workflow_dispatch' }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build and push nightly image
|
||||
if: github.event_name == 'push'
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/prod/django/Dockerfile
|
||||
push: true
|
||||
provenance: false
|
||||
# Pass the calculated tags from the meta step
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
VERSION=nightly
|
||||
tags: ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:nightly
|
||||
VERSION=${{ steps.meta.outputs.version }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Build and push release image
|
||||
if: github.event_name == 'release'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/prod/django/Dockerfile
|
||||
push: true
|
||||
provenance: false
|
||||
build-args: |
|
||||
VERSION=${{ github.event.release.tag_name }}
|
||||
tags: |
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:latest
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:${{ github.event.release.tag_name }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Build and push custom image
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/prod/django/Dockerfile
|
||||
push: true
|
||||
provenance: false
|
||||
build-args: |
|
||||
VERSION=${{ github.event.inputs.tag }}
|
||||
tags: ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:${{ github.event.inputs.tag }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
# --- CACHE CONFIGURATION ---
|
||||
# We set a specific 'scope' key.
|
||||
# This allows the Release tag to see the cache created by the Main branch.
|
||||
cache-from: type=gha,scope=build-cache
|
||||
cache-to: type=gha,mode=max,scope=build-cache
|
||||
|
||||
@@ -32,15 +32,16 @@ jobs:
|
||||
token: ${{ secrets.PAT }}
|
||||
ref: ${{ github.head_ref }}
|
||||
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v4
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
enable-cache: true
|
||||
|
||||
- name: Set up Python 3.11
|
||||
run: uv python install 3.11
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
run: uv sync --frozen --no-dev
|
||||
|
||||
- name: Install gettext
|
||||
run: sudo apt-get install -y gettext
|
||||
@@ -48,7 +49,7 @@ jobs:
|
||||
- name: Run makemessages
|
||||
run: |
|
||||
cd app
|
||||
python manage.py makemessages -a
|
||||
uv run python manage.py makemessages -a
|
||||
|
||||
- name: Check for changes
|
||||
id: check_changes
|
||||
|
||||
@@ -123,6 +123,7 @@ celerybeat.pid
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.prod.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
@@ -161,5 +162,6 @@ cython_debug/
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
node_modules/
|
||||
postgres_data/
|
||||
.prod.env
|
||||
Vendored
+29
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Docker: Dev",
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"command": "docker compose --env-file .env -f docker-compose.dev.yml up --build",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"postDebugTask": "Docker: Dev Down"
|
||||
},
|
||||
{
|
||||
"name": "Docker: Dev (no rebuild)",
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"command": "docker compose --env-file .env -f docker-compose.dev.yml up",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"postDebugTask": "Docker: Dev Down"
|
||||
},
|
||||
{
|
||||
"name": "Docker: Prod",
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"command": "docker compose --env-file .prod.env -f docker-compose.prod.yml up --build",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"postDebugTask": "Docker: Prod Down"
|
||||
}
|
||||
]
|
||||
}
|
||||
Vendored
+8
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"djlint.showInstallError": false,
|
||||
"files.associations": {
|
||||
"*.css": "tailwindcss"
|
||||
},
|
||||
"tailwindCSS.experimental.configFile": "frontend/src/styles/tailwind.css",
|
||||
"djlint.profile": "django",
|
||||
}
|
||||
Vendored
+119
@@ -0,0 +1,119 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"label": "Docker: Dev",
|
||||
"type": "shell",
|
||||
"command": "docker",
|
||||
"args": [
|
||||
"compose",
|
||||
"--env-file",
|
||||
".env",
|
||||
"-f",
|
||||
"docker-compose.dev.yml",
|
||||
"up",
|
||||
"--build"
|
||||
],
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"group": "build",
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Docker: Dev (no rebuild)",
|
||||
"type": "shell",
|
||||
"command": "docker",
|
||||
"args": [
|
||||
"compose",
|
||||
"--env-file",
|
||||
".env",
|
||||
"-f",
|
||||
"docker-compose.dev.yml",
|
||||
"up"
|
||||
],
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Docker: Dev Refresh Vite Deps",
|
||||
"type": "shell",
|
||||
"command": "docker compose --env-file .env -f docker-compose.dev.yml rm -sfv vite; docker compose --env-file .env -f docker-compose.dev.yml up --build",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Docker: Dev Down",
|
||||
"type": "shell",
|
||||
"command": "docker",
|
||||
"args": [
|
||||
"compose",
|
||||
"--env-file",
|
||||
".env",
|
||||
"-f",
|
||||
"docker-compose.dev.yml",
|
||||
"down"
|
||||
],
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Docker: Prod",
|
||||
"type": "shell",
|
||||
"command": "docker",
|
||||
"args": [
|
||||
"compose",
|
||||
"--env-file",
|
||||
".prod.env",
|
||||
"-f",
|
||||
"docker-compose.prod.yml",
|
||||
"up",
|
||||
"--build"
|
||||
],
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Docker: Prod Down",
|
||||
"type": "shell",
|
||||
"command": "docker",
|
||||
"args": [
|
||||
"compose",
|
||||
"--env-file",
|
||||
".prod.env",
|
||||
"-f",
|
||||
"docker-compose.prod.yml",
|
||||
"down"
|
||||
],
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": []
|
||||
},
|
||||
{
|
||||
"label": "Django: Runserver localhost:8000",
|
||||
"type": "shell",
|
||||
"command": "${command:python.interpreterPath}",
|
||||
"args": [
|
||||
"manage.py",
|
||||
"runserver",
|
||||
"localhost:8000"
|
||||
],
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}/app",
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
}
|
||||
},
|
||||
"problemMatcher": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -13,6 +13,7 @@
|
||||
<a href="#key-features">Features</a> •
|
||||
<a href="#how-to-use">Usage</a> •
|
||||
<a href="#how-it-works">How</a> •
|
||||
<a href="#mcp-server">MCP Server</a> •
|
||||
<a href="#help-us-translate-wygiwyh">Translate</a> •
|
||||
<a href="#caveats-and-warnings">Caveats and Warnings</a> •
|
||||
<a href="#built-with">Built with</a>
|
||||
@@ -144,7 +145,10 @@ To create the first user, open the container's console using Unraid's UI, by cli
|
||||
| DEMO | true\|false | false | If demo mode is enabled. |
|
||||
| ADMIN_EMAIL | string | None | Automatically creates an admin account with this email. Must have `ADMIN_PASSWORD` also set. |
|
||||
| ADMIN_PASSWORD | string | None | Automatically creates an admin account with this password. Must have `ADMIN_EMAIL` also set. |
|
||||
| CHECK_FOR_UPDATES | bool | true | Check and notify users about new versions. The check is done by doing a single query to Github's API every 12 hours. |
|
||||
| CHECK_FOR_UPDATES | true\|false | true | Check and notify users about new versions. The check is done by doing a single query to Github's API every 12 hours. |
|
||||
| DJANGO_VITE_DEV_MODE | true\|false | false | Enables Vite dev server mode for frontend development. When true, assets are served from Vite's dev server instead of the build manifest. For development only! |
|
||||
| DJANGO_VITE_DEV_SERVER_PORT | int | 5173 | The port where Vite's dev server is running. Only used when DJANGO_VITE_DEV_MODE is true. For development only! |
|
||||
| DJANGO_VITE_DEV_SERVER_HOST | string | localhost | The host where Vite's dev server is running. Only used when DJANGO_VITE_DEV_MODE is true. For development only! |
|
||||
|
||||
## OIDC Configuration
|
||||
|
||||
@@ -153,6 +157,13 @@ WYGIWYH supports login via OpenID Connect (OIDC) through `django-allauth`. This
|
||||
> [!NOTE]
|
||||
> Currently only OpenID Connect is supported as a provider, open an issue if you need something else.
|
||||
|
||||
> [!Caution]
|
||||
> WYGIWYH automatically connects OIDC accounts to existing local accounts with matching email addresses.
|
||||
> This means if a user already exists with email `user@example.com` and someone logs in via OIDC with the same email, the OIDC account will be automatically linked to the existing account without requiring user confirmation.
|
||||
> This is only recommended for trusted OIDC providers that verify email addresses and where you control who can create accounts.
|
||||
|
||||
### Configuration
|
||||
|
||||
To configure OIDC, you need to set the following environment variables:
|
||||
|
||||
| Variable | Description |
|
||||
@@ -183,6 +194,10 @@ Check out our [Wiki](https://github.com/eitchtee/WYGIWYH/wiki) for more informat
|
||||
> [!NOTE]
|
||||
> Login with your github account
|
||||
|
||||
# MCP Server
|
||||
|
||||
[IZIme07](https://github.com/IZIme07) has kindly created an MCP Server for WYGIWYH that you can self-host. [Check it out at MCP-WYGIWYH](https://github.com/ReNewator/MCP-WYGIWYH)!
|
||||
|
||||
# Caveats and Warnings
|
||||
|
||||
- I'm not an accountant, some terms and even calculations might be wrong. Make sure to open an issue if you see anything that could be improved.
|
||||
|
||||
+68
-31
@@ -11,6 +11,7 @@ https://docs.djangoproject.com/en/5.1/ref/settings/
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -46,7 +47,7 @@ INSTALLED_APPS = [
|
||||
"django.contrib.sites",
|
||||
"whitenoise.runserver_nostatic",
|
||||
"django.contrib.staticfiles",
|
||||
"webpack_boilerplate",
|
||||
"django_vite",
|
||||
"django.contrib.humanize",
|
||||
"django.contrib.postgres",
|
||||
"django_browser_reload",
|
||||
@@ -69,6 +70,7 @@ INSTALLED_APPS = [
|
||||
"apps.api.apps.ApiConfig",
|
||||
"cachalot",
|
||||
"rest_framework",
|
||||
"rest_framework.authtoken",
|
||||
"drf_spectacular",
|
||||
"django_cotton",
|
||||
"apps.rules.apps.RulesConfig",
|
||||
@@ -128,12 +130,23 @@ STORAGES = {
|
||||
|
||||
WHITENOISE_MANIFEST_STRICT = False
|
||||
|
||||
|
||||
def immutable_file_test(path, url):
|
||||
# Match vite (rollup)-generated hashes, à la, `some_file-CSliV9zW.js`
|
||||
return re.match(r"^.+[.-][0-9a-zA-Z_-]{8,12}\..+$", url)
|
||||
|
||||
|
||||
WHITENOISE_IMMUTABLE_FILE_TEST = immutable_file_test
|
||||
|
||||
WSGI_APPLICATION = "WYGIWYH.wsgi.application"
|
||||
|
||||
|
||||
# Database
|
||||
# https://docs.djangoproject.com/en/5.1/ref/settings/#databases
|
||||
|
||||
THREADS = int(os.getenv("GUNICORN_THREADS", 1))
|
||||
MAX_POOL_SIZE = THREADS + 1
|
||||
|
||||
DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
@@ -142,6 +155,17 @@ DATABASES = {
|
||||
"PASSWORD": os.getenv("SQL_PASSWORD", "password"),
|
||||
"HOST": os.getenv("SQL_HOST", "localhost"),
|
||||
"PORT": os.getenv("SQL_PORT", "5432"),
|
||||
"CONN_MAX_AGE": 0,
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"OPTIONS": {
|
||||
"pool": {
|
||||
"min_size": 1,
|
||||
"max_size": MAX_POOL_SIZE,
|
||||
"timeout": 10,
|
||||
"max_lifetime": 600,
|
||||
"max_idle": 300,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,7 +313,7 @@ STATIC_URL = "static/"
|
||||
STATIC_ROOT = BASE_DIR / "static_files"
|
||||
|
||||
STATICFILES_DIRS = [
|
||||
ROOT_DIR / "frontend/build",
|
||||
ROOT_DIR / "frontend" / "build",
|
||||
BASE_DIR / "static",
|
||||
]
|
||||
|
||||
@@ -305,9 +329,11 @@ CACHES = {
|
||||
}
|
||||
}
|
||||
|
||||
WEBPACK_LOADER = {
|
||||
"MANIFEST_FILE": ROOT_DIR / "frontend/build/manifest.json",
|
||||
}
|
||||
DJANGO_VITE_ASSETS_PATH = STATIC_ROOT
|
||||
DJANGO_VITE_MANIFEST_PATH = DJANGO_VITE_ASSETS_PATH / "manifest.json"
|
||||
DJANGO_VITE_DEV_MODE = os.getenv("DJANGO_VITE_DEV_MODE", "false").lower() == "true"
|
||||
DJANGO_VITE_DEV_SERVER_PORT = int(os.getenv("DJANGO_VITE_DEV_SERVER_PORT", "5173"))
|
||||
DJANGO_VITE_DEV_SERVER_HOST = os.getenv("DJANGO_VITE_DEV_SERVER_HOST", "localhost")
|
||||
|
||||
# Default primary key field type
|
||||
# https://docs.djangoproject.com/en/5.1/ref/settings/#default-auto-field
|
||||
@@ -350,17 +376,26 @@ ACCOUNT_EMAIL_VERIFICATION = "none"
|
||||
SOCIALACCOUNT_LOGIN_ON_GET = True
|
||||
SOCIALACCOUNT_ONLY = True
|
||||
SOCIALACCOUNT_AUTO_SIGNUP = os.getenv("OIDC_ALLOW_SIGNUP", "true").lower() == "true"
|
||||
SOCIALACCOUNT_EMAIL_AUTHENTICATION = True
|
||||
SOCIALACCOUNT_EMAIL_AUTHENTICATION_AUTO_CONNECT = True
|
||||
ACCOUNT_ADAPTER = "allauth.account.adapter.DefaultAccountAdapter"
|
||||
SOCIALACCOUNT_ADAPTER = "allauth.socialaccount.adapter.DefaultSocialAccountAdapter"
|
||||
SOCIALACCOUNT_ADAPTER = "apps.users.adapters.AutoConnectSocialAccountAdapter"
|
||||
|
||||
# CRISPY FORMS
|
||||
CRISPY_ALLOWED_TEMPLATE_PACKS = ["bootstrap5", "crispy_forms/pure_text"]
|
||||
CRISPY_TEMPLATE_PACK = "bootstrap5"
|
||||
CRISPY_ALLOWED_TEMPLATE_PACKS = [
|
||||
"crispy_forms/pure_text",
|
||||
"crispy-daisyui",
|
||||
]
|
||||
CRISPY_TEMPLATE_PACK = "crispy-daisyui"
|
||||
|
||||
SESSION_EXPIRE_AT_BROWSER_CLOSE = False
|
||||
SESSION_COOKIE_AGE = int(os.getenv("SESSION_EXPIRY_TIME", 2678400)) # 31 days
|
||||
SESSION_COOKIE_SECURE = os.getenv("HTTPS_ENABLED", "false").lower() == "true"
|
||||
|
||||
HTTPS_ENABLED = os.getenv("HTTPS_ENABLED", "false").lower() == "true"
|
||||
ACCOUNT_DEFAULT_HTTP_PROTOCOL = "https" if HTTPS_ENABLED else "http"
|
||||
SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https") if HTTPS_ENABLED else None
|
||||
|
||||
DEBUG_TOOLBAR_CONFIG = {
|
||||
"ROOT_TAG_EXTRA_ATTRS": "hx-preserve",
|
||||
# "SHOW_TOOLBAR_CALLBACK": lambda r: False, # disables it
|
||||
@@ -405,8 +440,16 @@ REST_FRAMEWORK = {
|
||||
"apps.api.permissions.NotInDemoMode",
|
||||
"rest_framework.permissions.DjangoModelPermissions",
|
||||
],
|
||||
"DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination",
|
||||
"PAGE_SIZE": 10,
|
||||
'DEFAULT_FILTER_BACKENDS': [
|
||||
'django_filters.rest_framework.DjangoFilterBackend',
|
||||
'rest_framework.filters.OrderingFilter',
|
||||
],
|
||||
'DEFAULT_AUTHENTICATION_CLASSES': [
|
||||
'rest_framework.authentication.BasicAuthentication',
|
||||
'rest_framework.authentication.SessionAuthentication',
|
||||
'rest_framework.authentication.TokenAuthentication',
|
||||
],
|
||||
"DEFAULT_PAGINATION_CLASS": "apps.api.custom.pagination.CustomPageNumberPagination",
|
||||
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
|
||||
}
|
||||
|
||||
@@ -421,7 +464,7 @@ SPECTACULAR_SETTINGS = {
|
||||
if "procrastinate" in sys.argv:
|
||||
LOGGING = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"disable_existing_loggers": True,
|
||||
"formatters": {
|
||||
"standard": {
|
||||
"format": "[%(asctime)s] - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -429,26 +472,19 @@ if "procrastinate" in sys.argv:
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"procrastinate": {
|
||||
"level": "INFO",
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "standard",
|
||||
},
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "standard",
|
||||
"level": "INFO",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"procrastinate": {
|
||||
"handlers": ["procrastinate"],
|
||||
"propagate": False,
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"loggers": {
|
||||
"procrastinate": {
|
||||
"level": "INFO",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -468,20 +504,21 @@ else:
|
||||
"formatter": "standard",
|
||||
"level": "INFO",
|
||||
},
|
||||
"procrastinate": {
|
||||
"level": "INFO",
|
||||
"class": "logging.StreamHandler",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"procrastinate": {
|
||||
"handlers": None,
|
||||
"propagate": False,
|
||||
},
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"level": "INFO",
|
||||
},
|
||||
"loggers": {
|
||||
"procrastinate": {
|
||||
"handlers": [],
|
||||
"propagate": False,
|
||||
},
|
||||
"allauth": {
|
||||
"handlers": ["console"],
|
||||
"level": "DEBUG",
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+18
-29
@@ -1,23 +1,21 @@
|
||||
from crispy_bootstrap5.bootstrap5 import Switch
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.common.fields.forms.dynamic_select import (
|
||||
DynamicModelChoiceField,
|
||||
DynamicModelMultipleChoiceField,
|
||||
)
|
||||
from apps.common.widgets.crispy.daisyui import Switch
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import TransactionCategory, TransactionTag
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Column, Row
|
||||
from crispy_forms.layout import Column, Field, Layout, Row
|
||||
from django import forms
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.accounts.models import AccountGroup
|
||||
from apps.common.fields.forms.dynamic_select import (
|
||||
DynamicModelMultipleChoiceField,
|
||||
DynamicModelChoiceField,
|
||||
)
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.transactions.models import TransactionCategory, TransactionTag
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.currencies.models import Currency
|
||||
|
||||
|
||||
class AccountGroupForm(forms.ModelForm):
|
||||
class Meta:
|
||||
@@ -38,17 +36,13 @@ class AccountGroupForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -108,17 +102,13 @@ class AccountForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -156,9 +146,8 @@ class AccountBalanceForm(forms.Form):
|
||||
self.helper.layout = Layout(
|
||||
"new_balance",
|
||||
Row(
|
||||
Column("category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("category"),
|
||||
Column("tags"),
|
||||
),
|
||||
Field("account_id"),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
def get_account_balance(account: Account, paid_only: bool = True) -> Decimal:
|
||||
"""
|
||||
Calculate account balance (income - expense).
|
||||
|
||||
Args:
|
||||
account: Account instance to calculate balance for.
|
||||
paid_only: If True, only count paid transactions (current balance).
|
||||
If False, count all transactions (projected balance).
|
||||
|
||||
Returns:
|
||||
Decimal: The calculated balance (income - expense).
|
||||
"""
|
||||
filters = {"account": account}
|
||||
if paid_only:
|
||||
filters["is_paid"] = True
|
||||
|
||||
income = Transaction.objects.filter(
|
||||
type=Transaction.Type.INCOME, **filters
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
expense = Transaction.objects.filter(
|
||||
type=Transaction.Type.EXPENSE, **filters
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
return income - expense
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import date
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
@@ -39,3 +41,135 @@ class AccountTests(TestCase):
|
||||
exchange_currency=self.exchange_currency,
|
||||
)
|
||||
self.assertEqual(account.exchange_currency, self.exchange_currency)
|
||||
|
||||
|
||||
class GetAccountBalanceServiceTests(TestCase):
|
||||
"""Tests for the get_account_balance service function"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
from apps.transactions.models import Transaction
|
||||
self.Transaction = Transaction
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="BRL", name="Brazilian Real", decimal_places=2, prefix="R$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Service Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Service Test Account", group=self.account_group, currency=self.currency
|
||||
)
|
||||
|
||||
def test_balance_with_no_transactions(self):
|
||||
"""Test balance is 0 when no transactions exist"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
balance = get_account_balance(self.account, paid_only=True)
|
||||
self.assertEqual(balance, Decimal("0"))
|
||||
|
||||
def test_current_balance_only_counts_paid(self):
|
||||
"""Test current balance only counts paid transactions"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
# Paid income
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid income",
|
||||
)
|
||||
# Unpaid income (should not count)
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid income",
|
||||
)
|
||||
# Paid expense
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.EXPENSE,
|
||||
amount=Decimal("30.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid expense",
|
||||
)
|
||||
|
||||
balance = get_account_balance(self.account, paid_only=True)
|
||||
self.assertEqual(balance, Decimal("70.00")) # 100 - 30
|
||||
|
||||
def test_projected_balance_counts_all(self):
|
||||
"""Test projected balance counts all transactions"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
# Paid income
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid income",
|
||||
)
|
||||
# Unpaid income
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid income",
|
||||
)
|
||||
# Paid expense
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.EXPENSE,
|
||||
amount=Decimal("30.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid expense",
|
||||
)
|
||||
# Unpaid expense
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.EXPENSE,
|
||||
amount=Decimal("20.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid expense",
|
||||
)
|
||||
|
||||
balance = get_account_balance(self.account, paid_only=False)
|
||||
self.assertEqual(balance, Decimal("100.00")) # (100 + 50) - (30 + 20)
|
||||
|
||||
def test_balance_defaults_to_paid_only(self):
|
||||
"""Test that paid_only defaults to True"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid",
|
||||
)
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid",
|
||||
)
|
||||
|
||||
balance = get_account_balance(self.account) # defaults to paid_only=True
|
||||
self.assertEqual(balance, Decimal("100.00"))
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def account_groups_index(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def account_groups_list(request):
|
||||
account_groups = AccountGroup.objects.all().order_by("id")
|
||||
account_groups = AccountGroup.objects.all().order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"account_groups/fragments/list.html",
|
||||
|
||||
@@ -25,7 +25,7 @@ def accounts_index(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def accounts_list(request):
|
||||
accounts = Account.objects.all().order_by("id")
|
||||
accounts = Account.objects.all().order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"accounts/fragments/list.html",
|
||||
|
||||
@@ -11,23 +11,13 @@ from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.accounts.forms import AccountBalanceFormSet
|
||||
from apps.accounts.models import Account, Transaction
|
||||
from apps.accounts.services import get_account_balance
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
def account_reconciliation(request):
|
||||
def get_account_balance(account):
|
||||
income = Transaction.objects.filter(
|
||||
account=account, type=Transaction.Type.INCOME, is_paid=True
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
expense = Transaction.objects.filter(
|
||||
account=account, type=Transaction.Type.EXPENSE, is_paid=True
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
return income - expense
|
||||
|
||||
initial_data = [
|
||||
{
|
||||
"account_id": account.id,
|
||||
|
||||
@@ -10,15 +10,19 @@ from apps.transactions.models import (
|
||||
|
||||
@extend_schema_field(
|
||||
{
|
||||
"oneOf": [{"type": "string"}, {"type": "integer"}],
|
||||
"description": "TransactionCategory ID or name. If the name doesn't exist, a new one will be created",
|
||||
"oneOf": [{"type": "string"}, {"type": "integer"}, {"type": "null"}],
|
||||
"description": "TransactionCategory ID or name. If the name doesn't exist, a new one will be created. Can be null if no category is assigned.",
|
||||
}
|
||||
)
|
||||
class TransactionCategoryField(serializers.Field):
|
||||
def to_representation(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
return {"id": value.id, "name": value.name}
|
||||
|
||||
def to_internal_value(self, data):
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, int):
|
||||
try:
|
||||
return TransactionCategory.objects.get(pk=data)
|
||||
|
||||
@@ -2,3 +2,5 @@ from .transactions import *
|
||||
from .accounts import *
|
||||
from .currencies import *
|
||||
from .dca import *
|
||||
from .imports import *
|
||||
|
||||
|
||||
@@ -67,3 +67,12 @@ class AccountSerializer(serializers.ModelSerializer):
|
||||
setattr(instance, attr, value)
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
|
||||
class AccountBalanceSerializer(serializers.Serializer):
|
||||
"""Serializer for account balance response."""
|
||||
|
||||
current_balance = serializers.DecimalField(max_digits=20, decimal_places=10)
|
||||
projected_balance = serializers.DecimalField(max_digits=20, decimal_places=10)
|
||||
currency = CurrencySerializer()
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
|
||||
|
||||
class ImportProfileSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for listing import profiles."""
|
||||
|
||||
class Meta:
|
||||
model = ImportProfile
|
||||
fields = ["id", "name", "version", "yaml_config"]
|
||||
|
||||
|
||||
class ImportRunSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for listing import runs."""
|
||||
|
||||
class Meta:
|
||||
model = ImportRun
|
||||
fields = [
|
||||
"id",
|
||||
"status",
|
||||
"profile",
|
||||
"file_name",
|
||||
"logs",
|
||||
"processed_rows",
|
||||
"total_rows",
|
||||
"successful_rows",
|
||||
"skipped_rows",
|
||||
"failed_rows",
|
||||
"started_at",
|
||||
"finished_at",
|
||||
]
|
||||
|
||||
|
||||
class ImportFileSerializer(serializers.Serializer):
|
||||
"""Serializer for uploading a file to import using an existing profile."""
|
||||
|
||||
profile_id = serializers.PrimaryKeyRelatedField(
|
||||
queryset=ImportProfile.objects.all(), source="profile"
|
||||
)
|
||||
file = serializers.FileField()
|
||||
@@ -0,0 +1,5 @@
|
||||
# Import all test classes for Django test discovery
|
||||
from .test_imports import *
|
||||
from .test_accounts import *
|
||||
from .test_data_isolation import *
|
||||
from .test_shared_access import *
|
||||
@@ -0,0 +1,99 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class AccountBalanceAPITests(TestCase):
|
||||
"""Tests for the Account Balance API endpoint"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Test Account", group=self.account_group, currency=self.currency
|
||||
)
|
||||
|
||||
# Create some transactions
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("500.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid income",
|
||||
)
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("200.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 15),
|
||||
description="Unpaid income",
|
||||
)
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 10),
|
||||
description="Paid expense",
|
||||
)
|
||||
|
||||
def test_get_balance_success(self):
|
||||
"""Test successful balance retrieval"""
|
||||
response = self.client.get(f"/api/accounts/{self.account.id}/balance/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertIn("current_balance", response.data)
|
||||
self.assertIn("projected_balance", response.data)
|
||||
self.assertIn("currency", response.data)
|
||||
|
||||
# Current: 500 - 100 = 400
|
||||
self.assertEqual(Decimal(response.data["current_balance"]), Decimal("400.00"))
|
||||
# Projected: (500 + 200) - 100 = 600
|
||||
self.assertEqual(Decimal(response.data["projected_balance"]), Decimal("600.00"))
|
||||
|
||||
# Check currency data
|
||||
self.assertEqual(response.data["currency"]["code"], "USD")
|
||||
|
||||
def test_get_balance_nonexistent_account(self):
|
||||
"""Test balance for non-existent account returns 404"""
|
||||
response = self.client.get("/api/accounts/99999/balance/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_get_balance_unauthenticated(self):
|
||||
"""Test unauthenticated request returns 401"""
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get(
|
||||
f"/api/accounts/{self.account.id}/balance/"
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
@@ -0,0 +1,719 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
TransactionEntity,
|
||||
InstallmentPlan,
|
||||
RecurringTransaction,
|
||||
)
|
||||
|
||||
|
||||
ACCESS_DENIED_CODES = [status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND]
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class AccountDataIsolationTests(TestCase):
|
||||
"""Tests to ensure users cannot access other users' accounts."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with two distinct users."""
|
||||
User = get_user_model()
|
||||
|
||||
# User 1 - the requester
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
# User 2 - owner of data that user1 should NOT access
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
# Shared currency
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's account
|
||||
self.user1_account_group = AccountGroup.all_objects.create(
|
||||
name="User1 Group", owner=self.user1
|
||||
)
|
||||
self.user1_account = Account.all_objects.create(
|
||||
name="User1 Account",
|
||||
group=self.user1_account_group,
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
# User 2's account (private, should be invisible to user1)
|
||||
self.user2_account_group = AccountGroup.all_objects.create(
|
||||
name="User2 Group", owner=self.user2
|
||||
)
|
||||
self.user2_account = Account.all_objects.create(
|
||||
name="User2 Account",
|
||||
group=self.user2_account_group,
|
||||
currency=self.currency,
|
||||
owner=self.user2,
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_accounts_in_list(self):
|
||||
"""GET /api/accounts/ should only return user's own accounts."""
|
||||
response = self.client1.get("/api/accounts/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
# User1 should only see their own account
|
||||
account_ids = [acc["id"] for acc in response.data["results"]]
|
||||
self.assertIn(self.user1_account.id, account_ids)
|
||||
self.assertNotIn(self.user2_account.id, account_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_account_detail(self):
|
||||
"""GET /api/accounts/{id}/ should deny access to other user's account."""
|
||||
response = self.client1.get(f"/api/accounts/{self.user2_account.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_account(self):
|
||||
"""PATCH on other user's account should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/accounts/{self.user2_account.id}/",
|
||||
{"name": "Hacked Account"},
|
||||
)
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
# Verify account name wasn't changed
|
||||
self.user2_account.refresh_from_db()
|
||||
self.assertEqual(self.user2_account.name, "User2 Account")
|
||||
|
||||
def test_user_cannot_delete_other_users_account(self):
|
||||
"""DELETE on other user's account should deny access."""
|
||||
response = self.client1.delete(f"/api/accounts/{self.user2_account.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
# Verify account still exists
|
||||
self.assertTrue(Account.all_objects.filter(id=self.user2_account.id).exists())
|
||||
|
||||
def test_user_cannot_get_balance_of_other_users_account(self):
|
||||
"""Balance action on other user's account should deny access."""
|
||||
response = self.client1.get(f"/api/accounts/{self.user2_account.id}/balance/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_can_access_own_account(self):
|
||||
"""User can access their own account normally."""
|
||||
response = self.client1.get(f"/api/accounts/{self.user1_account.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "User1 Account")
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class AccountGroupDataIsolationTests(TestCase):
|
||||
"""Tests to ensure users cannot access other users' account groups."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with two distinct users."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
# User 1's account group
|
||||
self.user1_group = AccountGroup.all_objects.create(
|
||||
name="User1 Group", owner=self.user1
|
||||
)
|
||||
|
||||
# User 2's account group
|
||||
self.user2_group = AccountGroup.all_objects.create(
|
||||
name="User2 Group", owner=self.user2
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_account_groups(self):
|
||||
"""GET /api/account-groups/ should only return user's own groups."""
|
||||
response = self.client1.get("/api/account-groups/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
group_ids = [grp["id"] for grp in response.data["results"]]
|
||||
self.assertIn(self.user1_group.id, group_ids)
|
||||
self.assertNotIn(self.user2_group.id, group_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_account_group_detail(self):
|
||||
"""GET /api/account-groups/{id}/ should deny access to other user's group."""
|
||||
response = self.client1.get(f"/api/account-groups/{self.user2_group.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_account_group(self):
|
||||
"""PATCH on other user's account group should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/account-groups/{self.user2_group.id}/",
|
||||
{"name": "Hacked Group"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.user2_group.refresh_from_db()
|
||||
self.assertEqual(self.user2_group.name, "User2 Group")
|
||||
|
||||
def test_user_cannot_delete_other_users_account_group(self):
|
||||
"""DELETE on other user's account group should deny access."""
|
||||
response = self.client1.delete(f"/api/account-groups/{self.user2_group.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(
|
||||
AccountGroup.all_objects.filter(id=self.user2_group.id).exists()
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class TransactionDataIsolationTests(TestCase):
|
||||
"""Tests to ensure users cannot access other users' transactions."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with transactions for two distinct users."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's account and transaction
|
||||
self.user1_account = Account.all_objects.create(
|
||||
name="User1 Account", currency=self.currency, owner=self.user1
|
||||
)
|
||||
self.user1_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.user1_account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="User1 Income",
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
# User 2's account and transaction
|
||||
self.user2_account = Account.all_objects.create(
|
||||
name="User2 Account", currency=self.currency, owner=self.user2
|
||||
)
|
||||
self.user2_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.user2_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="User2 Expense",
|
||||
owner=self.user2,
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_transactions_in_list(self):
|
||||
"""GET /api/transactions/ should only return user's own transactions."""
|
||||
response = self.client1.get("/api/transactions/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
transaction_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.user1_transaction.id, transaction_ids)
|
||||
self.assertNotIn(self.user2_transaction.id, transaction_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_transaction_detail(self):
|
||||
"""GET /api/transactions/{id}/ should deny access to other user's transaction."""
|
||||
response = self.client1.get(f"/api/transactions/{self.user2_transaction.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_transaction(self):
|
||||
"""PATCH on other user's transaction should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/transactions/{self.user2_transaction.id}/",
|
||||
{"description": "Hacked Transaction"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.user2_transaction.refresh_from_db()
|
||||
self.assertEqual(self.user2_transaction.description, "User2 Expense")
|
||||
|
||||
def test_user_cannot_delete_other_users_transaction(self):
|
||||
"""DELETE on other user's transaction should deny access."""
|
||||
response = self.client1.delete(
|
||||
f"/api/transactions/{self.user2_transaction.id}/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(
|
||||
Transaction.userless_all_objects.filter(
|
||||
id=self.user2_transaction.id
|
||||
).exists()
|
||||
)
|
||||
|
||||
def test_user_cannot_create_transaction_in_other_users_account(self):
|
||||
"""POST /api/transactions/ with other user's account should fail."""
|
||||
response = self.client1.post(
|
||||
"/api/transactions/",
|
||||
{
|
||||
"account": self.user2_account.id,
|
||||
"type": "IN",
|
||||
"amount": "100.00",
|
||||
"date": "2025-01-15",
|
||||
"description": "Sneaky transaction",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
# Should deny access - 400 (validation error), 403, or 404
|
||||
self.assertIn(
|
||||
response.status_code,
|
||||
ACCESS_DENIED_CODES + [status.HTTP_400_BAD_REQUEST],
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class CategoryTagEntityIsolationTests(TestCase):
|
||||
"""Tests for isolation of categories, tags, and entities between users."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
# User 1's categories, tags, entities
|
||||
self.user1_category = TransactionCategory.all_objects.create(
|
||||
name="User1 Category", owner=self.user1
|
||||
)
|
||||
self.user1_tag = TransactionTag.all_objects.create(
|
||||
name="User1 Tag", owner=self.user1
|
||||
)
|
||||
self.user1_entity = TransactionEntity.all_objects.create(
|
||||
name="User1 Entity", owner=self.user1
|
||||
)
|
||||
|
||||
# User 2's categories, tags, entities
|
||||
self.user2_category = TransactionCategory.all_objects.create(
|
||||
name="User2 Category", owner=self.user2
|
||||
)
|
||||
self.user2_tag = TransactionTag.all_objects.create(
|
||||
name="User2 Tag", owner=self.user2
|
||||
)
|
||||
self.user2_entity = TransactionEntity.all_objects.create(
|
||||
name="User2 Entity", owner=self.user2
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_categories(self):
|
||||
"""GET /api/categories/ should only return user's own categories."""
|
||||
response = self.client1.get("/api/categories/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
category_ids = [c["id"] for c in response.data["results"]]
|
||||
self.assertIn(self.user1_category.id, category_ids)
|
||||
self.assertNotIn(self.user2_category.id, category_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_category_detail(self):
|
||||
"""GET /api/categories/{id}/ should deny access to other user's category."""
|
||||
response = self.client1.get(f"/api/categories/{self.user2_category.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_see_other_users_tags(self):
|
||||
"""GET /api/tags/ should only return user's own tags."""
|
||||
response = self.client1.get("/api/tags/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
tag_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.user1_tag.id, tag_ids)
|
||||
self.assertNotIn(self.user2_tag.id, tag_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_tag_detail(self):
|
||||
"""GET /api/tags/{id}/ should deny access to other user's tag."""
|
||||
response = self.client1.get(f"/api/tags/{self.user2_tag.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_see_other_users_entities(self):
|
||||
"""GET /api/entities/ should only return user's own entities."""
|
||||
response = self.client1.get("/api/entities/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
entity_ids = [e["id"] for e in response.data["results"]]
|
||||
self.assertIn(self.user1_entity.id, entity_ids)
|
||||
self.assertNotIn(self.user2_entity.id, entity_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_entity_detail(self):
|
||||
"""GET /api/entities/{id}/ should deny access to other user's entity."""
|
||||
response = self.client1.get(f"/api/entities/{self.user2_entity.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_category(self):
|
||||
"""PATCH on other user's category should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/categories/{self.user2_category.id}/",
|
||||
{"name": "Hacked Category"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_delete_other_users_tag(self):
|
||||
"""DELETE on other user's tag should deny access."""
|
||||
response = self.client1.delete(f"/api/tags/{self.user2_tag.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(
|
||||
TransactionTag.all_objects.filter(id=self.user2_tag.id).exists()
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class DCADataIsolationTests(TestCase):
|
||||
"""Tests to ensure users cannot access other users' DCA strategies and entries."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
self.currency1 = Currency.objects.create(
|
||||
code="BTC", name="Bitcoin", decimal_places=8, prefix=""
|
||||
)
|
||||
self.currency2 = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's DCA strategy and entry
|
||||
self.user1_strategy = DCAStrategy.all_objects.create(
|
||||
name="User1 BTC Strategy",
|
||||
target_currency=self.currency1,
|
||||
payment_currency=self.currency2,
|
||||
owner=self.user1,
|
||||
)
|
||||
self.user1_entry = DCAEntry.objects.create(
|
||||
strategy=self.user1_strategy,
|
||||
date=date(2025, 1, 1),
|
||||
amount_paid=Decimal("100.00"),
|
||||
amount_received=Decimal("0.001"),
|
||||
)
|
||||
|
||||
# User 2's DCA strategy and entry
|
||||
self.user2_strategy = DCAStrategy.all_objects.create(
|
||||
name="User2 BTC Strategy",
|
||||
target_currency=self.currency1,
|
||||
payment_currency=self.currency2,
|
||||
owner=self.user2,
|
||||
)
|
||||
self.user2_entry = DCAEntry.objects.create(
|
||||
strategy=self.user2_strategy,
|
||||
date=date(2025, 1, 1),
|
||||
amount_paid=Decimal("200.00"),
|
||||
amount_received=Decimal("0.002"),
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_dca_strategies(self):
|
||||
"""GET /api/dca/strategies/ should only return user's own strategies."""
|
||||
response = self.client1.get("/api/dca/strategies/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
strategy_ids = [s["id"] for s in response.data["results"]]
|
||||
self.assertIn(self.user1_strategy.id, strategy_ids)
|
||||
self.assertNotIn(self.user2_strategy.id, strategy_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_dca_strategy_detail(self):
|
||||
"""GET /api/dca/strategies/{id}/ should deny access to other user's strategy."""
|
||||
response = self.client1.get(f"/api/dca/strategies/{self.user2_strategy.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_access_other_users_dca_entries(self):
|
||||
"""GET /api/dca/entries/ filtered by other user's strategy should return empty."""
|
||||
response = self.client1.get(
|
||||
f"/api/dca/entries/?strategy={self.user2_strategy.id}"
|
||||
)
|
||||
|
||||
# Either OK with empty results or error
|
||||
if response.status_code == status.HTTP_200_OK:
|
||||
entry_ids = [e["id"] for e in response.data["results"]]
|
||||
self.assertNotIn(self.user2_entry.id, entry_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_dca_entry_detail(self):
|
||||
"""GET /api/dca/entries/{id}/ should deny access to other user's entry."""
|
||||
response = self.client1.get(f"/api/dca/entries/{self.user2_entry.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_access_other_users_strategy_investment_frequency(self):
|
||||
"""investment_frequency action on other user's strategy should deny access."""
|
||||
response = self.client1.get(
|
||||
f"/api/dca/strategies/{self.user2_strategy.id}/investment_frequency/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_access_other_users_strategy_price_comparison(self):
|
||||
"""price_comparison action on other user's strategy should deny access."""
|
||||
response = self.client1.get(
|
||||
f"/api/dca/strategies/{self.user2_strategy.id}/price_comparison/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_access_other_users_strategy_current_price(self):
|
||||
"""current_price action on other user's strategy should deny access."""
|
||||
response = self.client1.get(
|
||||
f"/api/dca/strategies/{self.user2_strategy.id}/current_price/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_dca_strategy(self):
|
||||
"""PATCH on other user's DCA strategy should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/dca/strategies/{self.user2_strategy.id}/",
|
||||
{"name": "Hacked Strategy"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_delete_other_users_dca_entry(self):
|
||||
"""DELETE on other user's DCA entry should deny access."""
|
||||
response = self.client1.delete(f"/api/dca/entries/{self.user2_entry.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(DCAEntry.objects.filter(id=self.user2_entry.id).exists())
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class InstallmentRecurringIsolationTests(TestCase):
|
||||
"""Tests for isolation of installment plans and recurring transactions."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's account
|
||||
self.user1_account = Account.all_objects.create(
|
||||
name="User1 Account", currency=self.currency, owner=self.user1
|
||||
)
|
||||
|
||||
# User 2's account
|
||||
self.user2_account = Account.all_objects.create(
|
||||
name="User2 Account", currency=self.currency, owner=self.user2
|
||||
)
|
||||
|
||||
# User 1's installment plan
|
||||
self.user1_installment = InstallmentPlan.all_objects.create(
|
||||
account=self.user1_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
description="User1 Installment",
|
||||
number_of_installments=12,
|
||||
start_date=date(2025, 1, 1),
|
||||
installment_amount=Decimal("100.00"),
|
||||
)
|
||||
|
||||
# User 2's installment plan
|
||||
self.user2_installment = InstallmentPlan.all_objects.create(
|
||||
account=self.user2_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
description="User2 Installment",
|
||||
number_of_installments=6,
|
||||
start_date=date(2025, 1, 1),
|
||||
installment_amount=Decimal("200.00"),
|
||||
)
|
||||
|
||||
# User 1's recurring transaction
|
||||
self.user1_recurring = RecurringTransaction.all_objects.create(
|
||||
account=self.user1_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("50.00"),
|
||||
description="User1 Recurring",
|
||||
start_date=date(2025, 1, 1),
|
||||
recurrence_type=RecurringTransaction.RecurrenceType.MONTH,
|
||||
recurrence_interval=1,
|
||||
)
|
||||
|
||||
# User 2's recurring transaction
|
||||
self.user2_recurring = RecurringTransaction.all_objects.create(
|
||||
account=self.user2_account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("1000.00"),
|
||||
description="User2 Recurring",
|
||||
start_date=date(2025, 1, 1),
|
||||
recurrence_type=RecurringTransaction.RecurrenceType.MONTH,
|
||||
recurrence_interval=1,
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_installment_plans(self):
|
||||
"""GET /api/installment-plans/ should only return user's own plans."""
|
||||
response = self.client1.get("/api/installment-plans/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
plan_ids = [p["id"] for p in response.data["results"]]
|
||||
self.assertIn(self.user1_installment.id, plan_ids)
|
||||
self.assertNotIn(self.user2_installment.id, plan_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_installment_plan_detail(self):
|
||||
"""GET /api/installment-plans/{id}/ should deny access to other user's plan."""
|
||||
response = self.client1.get(
|
||||
f"/api/installment-plans/{self.user2_installment.id}/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_see_other_users_recurring_transactions(self):
|
||||
"""GET /api/recurring-transactions/ should only return user's own recurring."""
|
||||
response = self.client1.get("/api/recurring-transactions/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
recurring_ids = [r["id"] for r in response.data["results"]]
|
||||
self.assertIn(self.user1_recurring.id, recurring_ids)
|
||||
self.assertNotIn(self.user2_recurring.id, recurring_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_recurring_transaction_detail(self):
|
||||
"""GET /api/recurring-transactions/{id}/ should deny access to other user's recurring."""
|
||||
response = self.client1.get(
|
||||
f"/api/recurring-transactions/{self.user2_recurring.id}/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_installment_plan(self):
|
||||
"""PATCH on other user's installment plan should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/installment-plans/{self.user2_installment.id}/",
|
||||
{"description": "Hacked Installment"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_delete_other_users_recurring_transaction(self):
|
||||
"""DELETE on other user's recurring transaction should deny access."""
|
||||
response = self.client1.delete(
|
||||
f"/api/recurring-transactions/{self.user2_recurring.id}/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(
|
||||
RecurringTransaction.all_objects.filter(id=self.user2_recurring.id).exists()
|
||||
)
|
||||
@@ -0,0 +1,404 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class ImportAPITests(TestCase):
|
||||
"""Tests for the Import API endpoint"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
# Create a basic import profile with minimal valid YAML config
|
||||
self.profile = ImportProfile.objects.create(
|
||||
name="Test Profile",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_expense
|
||||
is_paid:
|
||||
detection_method: always_paid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
|
||||
@patch("apps.import_app.tasks.process_import.defer")
|
||||
@patch("django.core.files.storage.FileSystemStorage.save")
|
||||
@patch("django.core.files.storage.FileSystemStorage.path")
|
||||
def test_create_import_success(self, mock_path, mock_save, mock_defer):
|
||||
"""Test successful file upload creates ImportRun and queues task"""
|
||||
mock_save.return_value = "test_file.csv"
|
||||
mock_path.return_value = "/usr/src/app/temp/test_file.csv"
|
||||
|
||||
csv_content = b"date,description,amount,account\n2025-01-01,Test,100,Main"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||
self.assertIn("import_run_id", response.data)
|
||||
self.assertEqual(response.data["status"], "queued")
|
||||
|
||||
# Verify ImportRun was created
|
||||
import_run = ImportRun.objects.get(id=response.data["import_run_id"])
|
||||
self.assertEqual(import_run.profile, self.profile)
|
||||
self.assertEqual(import_run.file_name, "test_file.csv")
|
||||
|
||||
# Verify task was deferred
|
||||
mock_defer.assert_called_once_with(
|
||||
import_run_id=import_run.id,
|
||||
file_path="/usr/src/app/temp/test_file.csv",
|
||||
user_id=self.user.id,
|
||||
)
|
||||
|
||||
def test_create_import_missing_profile(self):
|
||||
"""Test request without profile_id returns 400"""
|
||||
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("profile_id", response.data)
|
||||
|
||||
def test_create_import_missing_file(self):
|
||||
"""Test request without file returns 400"""
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("file", response.data)
|
||||
|
||||
def test_create_import_invalid_profile(self):
|
||||
"""Test request with non-existent profile returns 400"""
|
||||
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": 99999, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("profile_id", response.data)
|
||||
|
||||
@patch("apps.import_app.tasks.process_import.defer")
|
||||
@patch("django.core.files.storage.FileSystemStorage.save")
|
||||
@patch("django.core.files.storage.FileSystemStorage.path")
|
||||
def test_create_import_xlsx(self, mock_path, mock_save, mock_defer):
|
||||
"""Test successful XLSX file upload"""
|
||||
mock_save.return_value = "test_file.xlsx"
|
||||
mock_path.return_value = "/usr/src/app/temp/test_file.xlsx"
|
||||
|
||||
# Create a simple XLSX-like content (just for the upload test)
|
||||
xlsx_content = BytesIO(b"PK\x03\x04") # XLSX files start with PK header
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.xlsx",
|
||||
xlsx_content.getvalue(),
|
||||
content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||
self.assertIn("import_run_id", response.data)
|
||||
|
||||
def test_unauthenticated_request(self):
|
||||
"""Test unauthenticated request returns 401"""
|
||||
unauthenticated_client = APIClient()
|
||||
|
||||
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = unauthenticated_client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class ImportProfileAPITests(TestCase):
|
||||
"""Tests for the Import Profile API endpoints"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.profile1 = ImportProfile.objects.create(
|
||||
name="Profile 1",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_expense
|
||||
is_paid:
|
||||
detection_method: always_paid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
self.profile2 = ImportProfile.objects.create(
|
||||
name="Profile 2",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_income
|
||||
is_paid:
|
||||
detection_method: always_unpaid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
|
||||
def test_list_profiles(self):
|
||||
"""Test listing all profiles"""
|
||||
response = self.client.get("/api/import/profiles/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 2)
|
||||
self.assertEqual(len(response.data["results"]), 2)
|
||||
|
||||
def test_retrieve_profile(self):
|
||||
"""Test retrieving a specific profile"""
|
||||
response = self.client.get(f"/api/import/profiles/{self.profile1.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["id"], self.profile1.id)
|
||||
self.assertEqual(response.data["name"], "Profile 1")
|
||||
self.assertIn("yaml_config", response.data)
|
||||
|
||||
def test_retrieve_nonexistent_profile(self):
|
||||
"""Test retrieving a non-existent profile returns 404"""
|
||||
response = self.client.get("/api/import/profiles/99999/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_profiles_unauthenticated(self):
|
||||
"""Test unauthenticated request returns 401"""
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get("/api/import/profiles/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class ImportRunAPITests(TestCase):
|
||||
"""Tests for the Import Run API endpoints"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.profile1 = ImportProfile.objects.create(
|
||||
name="Profile 1",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_expense
|
||||
is_paid:
|
||||
detection_method: always_paid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
self.profile2 = ImportProfile.objects.create(
|
||||
name="Profile 2",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_income
|
||||
is_paid:
|
||||
detection_method: always_unpaid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
|
||||
# Create import runs
|
||||
self.run1 = ImportRun.objects.create(
|
||||
profile=self.profile1,
|
||||
file_name="file1.csv",
|
||||
status=ImportRun.Status.FINISHED,
|
||||
)
|
||||
self.run2 = ImportRun.objects.create(
|
||||
profile=self.profile1,
|
||||
file_name="file2.csv",
|
||||
status=ImportRun.Status.QUEUED,
|
||||
)
|
||||
self.run3 = ImportRun.objects.create(
|
||||
profile=self.profile2,
|
||||
file_name="file3.csv",
|
||||
status=ImportRun.Status.FINISHED,
|
||||
)
|
||||
|
||||
def test_list_all_runs(self):
|
||||
"""Test listing all runs"""
|
||||
response = self.client.get("/api/import/runs/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 3)
|
||||
self.assertEqual(len(response.data["results"]), 3)
|
||||
|
||||
def test_list_runs_by_profile(self):
|
||||
"""Test filtering runs by profile_id"""
|
||||
response = self.client.get(f"/api/import/runs/?profile_id={self.profile1.id}")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 2)
|
||||
for run in response.data["results"]:
|
||||
self.assertEqual(run["profile"], self.profile1.id)
|
||||
|
||||
def test_list_runs_by_other_profile(self):
|
||||
"""Test filtering runs by another profile_id"""
|
||||
response = self.client.get(f"/api/import/runs/?profile_id={self.profile2.id}")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 1)
|
||||
self.assertEqual(response.data["results"][0]["profile"], self.profile2.id)
|
||||
|
||||
def test_retrieve_run(self):
|
||||
"""Test retrieving a specific run"""
|
||||
response = self.client.get(f"/api/import/runs/{self.run1.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["id"], self.run1.id)
|
||||
self.assertEqual(response.data["file_name"], "file1.csv")
|
||||
self.assertEqual(response.data["status"], "FINISHED")
|
||||
|
||||
def test_retrieve_nonexistent_run(self):
|
||||
"""Test retrieving a non-existent run returns 404"""
|
||||
response = self.client.get("/api/import/runs/99999/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_runs_unauthenticated(self):
|
||||
"""Test unauthenticated request returns 401"""
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get("/api/import/runs/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
@@ -0,0 +1,587 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
TransactionEntity,
|
||||
)
|
||||
|
||||
|
||||
ACCESS_DENIED_CODES = [status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND]
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class SharedAccountAccessTests(TestCase):
|
||||
"""Tests for shared account access via shared_with field."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with shared accounts."""
|
||||
User = get_user_model()
|
||||
|
||||
# User 1 - owner
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
# User 2 - will have shared access
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
# User 3 - no shared access
|
||||
self.user3 = User.objects.create_user(
|
||||
email="user3@test.com", password="testpass123"
|
||||
)
|
||||
self.client3 = APIClient()
|
||||
self.client3.force_authenticate(user=self.user3)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's account shared with user 2
|
||||
self.shared_account = Account.all_objects.create(
|
||||
name="Shared Account",
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
visibility="private",
|
||||
)
|
||||
self.shared_account.shared_with.add(self.user2)
|
||||
|
||||
# User 1's private account (not shared)
|
||||
self.private_account = Account.all_objects.create(
|
||||
name="Private Account",
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
visibility="private",
|
||||
)
|
||||
|
||||
# Transaction in shared account
|
||||
self.shared_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.shared_account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Shared Transaction",
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
# Transaction in private account
|
||||
self.private_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.private_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Private Transaction",
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
def test_user_can_see_accounts_shared_with_them(self):
|
||||
"""User2 should see the account shared with them."""
|
||||
response = self.client2.get("/api/accounts/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
account_ids = [acc["id"] for acc in response.data["results"]]
|
||||
self.assertIn(self.shared_account.id, account_ids)
|
||||
|
||||
def test_user_cannot_see_accounts_not_shared_with_them(self):
|
||||
"""User2 should NOT see user1's private (non-shared) account."""
|
||||
response = self.client2.get("/api/accounts/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
account_ids = [acc["id"] for acc in response.data["results"]]
|
||||
self.assertNotIn(self.private_account.id, account_ids)
|
||||
|
||||
def test_user_can_access_shared_account_detail(self):
|
||||
"""User2 should be able to access shared account details."""
|
||||
response = self.client2.get(f"/api/accounts/{self.shared_account.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Account")
|
||||
|
||||
def test_user_without_share_cannot_access_shared_account(self):
|
||||
"""User3 should NOT be able to access the shared account."""
|
||||
response = self.client3.get(f"/api/accounts/{self.shared_account.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_can_see_transactions_in_shared_account(self):
|
||||
"""User2 should see transactions in the shared account."""
|
||||
response = self.client2.get("/api/transactions/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
transaction_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.shared_transaction.id, transaction_ids)
|
||||
self.assertNotIn(self.private_transaction.id, transaction_ids)
|
||||
|
||||
def test_user_can_access_transaction_in_shared_account(self):
|
||||
"""User2 should be able to access transaction details in shared account."""
|
||||
response = self.client2.get(f"/api/transactions/{self.shared_transaction.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["description"], "Shared Transaction")
|
||||
|
||||
def test_user_cannot_access_transaction_in_non_shared_account(self):
|
||||
"""User2 should NOT access transactions in user1's private account."""
|
||||
response = self.client2.get(f"/api/transactions/{self.private_transaction.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_can_get_balance_of_shared_account(self):
|
||||
"""User2 should be able to get balance of shared account."""
|
||||
response = self.client2.get(f"/api/accounts/{self.shared_account.id}/balance/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertIn("current_balance", response.data)
|
||||
|
||||
def test_sharing_works_with_multiple_users(self):
|
||||
"""Account shared with multiple users should be accessible by all."""
|
||||
# Add user3 to shared_with
|
||||
self.shared_account.shared_with.add(self.user3)
|
||||
|
||||
# User2 still has access
|
||||
response2 = self.client2.get(f"/api/accounts/{self.shared_account.id}/")
|
||||
self.assertEqual(response2.status_code, status.HTTP_200_OK)
|
||||
|
||||
# User3 now has access
|
||||
response3 = self.client3.get(f"/api/accounts/{self.shared_account.id}/")
|
||||
self.assertEqual(response3.status_code, status.HTTP_200_OK)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class PublicVisibilityTests(TestCase):
|
||||
"""Tests for public visibility access."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with public accounts."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's public account
|
||||
self.public_account = Account.all_objects.create(
|
||||
name="Public Account",
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
visibility="public",
|
||||
)
|
||||
|
||||
# User 1's private account
|
||||
self.private_account = Account.all_objects.create(
|
||||
name="Private Account",
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
visibility="private",
|
||||
)
|
||||
|
||||
# Transaction in public account
|
||||
self.public_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.public_account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Public Transaction",
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
def test_user_can_see_public_accounts(self):
|
||||
"""User2 should see user1's public account."""
|
||||
response = self.client2.get("/api/accounts/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
account_ids = [acc["id"] for acc in response.data["results"]]
|
||||
self.assertIn(self.public_account.id, account_ids)
|
||||
self.assertNotIn(self.private_account.id, account_ids)
|
||||
|
||||
def test_user_can_access_public_account_detail(self):
|
||||
"""User2 should be able to access public account details."""
|
||||
response = self.client2.get(f"/api/accounts/{self.public_account.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Public Account")
|
||||
|
||||
def test_user_can_see_transactions_in_public_accounts(self):
|
||||
"""User2 should see transactions in public accounts."""
|
||||
response = self.client2.get("/api/transactions/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
transaction_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.public_transaction.id, transaction_ids)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class SharedCategoryTagEntityTests(TestCase):
|
||||
"""Tests for shared categories, tags, and entities."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with shared categories/tags/entities."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
self.user3 = User.objects.create_user(
|
||||
email="user3@test.com", password="testpass123"
|
||||
)
|
||||
self.client3 = APIClient()
|
||||
self.client3.force_authenticate(user=self.user3)
|
||||
|
||||
# User 1's category shared with user 2
|
||||
self.shared_category = TransactionCategory.all_objects.create(
|
||||
name="Shared Category", owner=self.user1
|
||||
)
|
||||
self.shared_category.shared_with.add(self.user2)
|
||||
|
||||
# User 1's private category
|
||||
self.private_category = TransactionCategory.all_objects.create(
|
||||
name="Private Category", owner=self.user1
|
||||
)
|
||||
|
||||
# User 1's public category
|
||||
self.public_category = TransactionCategory.all_objects.create(
|
||||
name="Public Category", owner=self.user1, visibility="public"
|
||||
)
|
||||
|
||||
# User 1's tag shared with user 2
|
||||
self.shared_tag = TransactionTag.all_objects.create(
|
||||
name="Shared Tag", owner=self.user1
|
||||
)
|
||||
self.shared_tag.shared_with.add(self.user2)
|
||||
|
||||
# User 1's entity shared with user 2
|
||||
self.shared_entity = TransactionEntity.all_objects.create(
|
||||
name="Shared Entity", owner=self.user1
|
||||
)
|
||||
self.shared_entity.shared_with.add(self.user2)
|
||||
|
||||
def test_user_can_see_shared_categories(self):
|
||||
"""User2 should see categories shared with them."""
|
||||
response = self.client2.get("/api/categories/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
category_ids = [c["id"] for c in response.data["results"]]
|
||||
self.assertIn(self.shared_category.id, category_ids)
|
||||
self.assertNotIn(self.private_category.id, category_ids)
|
||||
|
||||
def test_user_can_access_shared_category_detail(self):
|
||||
"""User2 should be able to access shared category details."""
|
||||
response = self.client2.get(f"/api/categories/{self.shared_category.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Category")
|
||||
|
||||
def test_user_can_see_public_categories(self):
|
||||
"""User3 should see public categories."""
|
||||
response = self.client3.get("/api/categories/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
category_ids = [c["id"] for c in response.data["results"]]
|
||||
self.assertIn(self.public_category.id, category_ids)
|
||||
|
||||
def test_user_without_share_cannot_see_shared_category(self):
|
||||
"""User3 should NOT see category shared only with user2."""
|
||||
response = self.client3.get("/api/categories/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
category_ids = [c["id"] for c in response.data["results"]]
|
||||
self.assertNotIn(self.shared_category.id, category_ids)
|
||||
|
||||
def test_user_can_see_shared_tags(self):
|
||||
"""User2 should see tags shared with them."""
|
||||
response = self.client2.get("/api/tags/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
tag_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.shared_tag.id, tag_ids)
|
||||
|
||||
def test_user_can_access_shared_tag_detail(self):
|
||||
"""User2 should be able to access shared tag details."""
|
||||
response = self.client2.get(f"/api/tags/{self.shared_tag.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Tag")
|
||||
|
||||
def test_user_can_see_shared_entities(self):
|
||||
"""User2 should see entities shared with them."""
|
||||
response = self.client2.get("/api/entities/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
entity_ids = [e["id"] for e in response.data["results"]]
|
||||
self.assertIn(self.shared_entity.id, entity_ids)
|
||||
|
||||
def test_user_can_access_shared_entity_detail(self):
|
||||
"""User2 should be able to access shared entity details."""
|
||||
response = self.client2.get(f"/api/entities/{self.shared_entity.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Entity")
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class SharedDCAAccessTests(TestCase):
|
||||
"""Tests for shared DCA strategy access."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with shared DCA strategies."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
self.user3 = User.objects.create_user(
|
||||
email="user3@test.com", password="testpass123"
|
||||
)
|
||||
self.client3 = APIClient()
|
||||
self.client3.force_authenticate(user=self.user3)
|
||||
|
||||
self.currency1 = Currency.objects.create(
|
||||
code="BTC", name="Bitcoin", decimal_places=8, prefix=""
|
||||
)
|
||||
self.currency2 = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's DCA strategy shared with user 2
|
||||
self.shared_strategy = DCAStrategy.all_objects.create(
|
||||
name="Shared BTC Strategy",
|
||||
target_currency=self.currency1,
|
||||
payment_currency=self.currency2,
|
||||
owner=self.user1,
|
||||
)
|
||||
self.shared_strategy.shared_with.add(self.user2)
|
||||
|
||||
# Entry in shared strategy
|
||||
self.shared_entry = DCAEntry.objects.create(
|
||||
strategy=self.shared_strategy,
|
||||
date=date(2025, 1, 1),
|
||||
amount_paid=Decimal("100.00"),
|
||||
amount_received=Decimal("0.001"),
|
||||
)
|
||||
|
||||
# User 1's private strategy
|
||||
self.private_strategy = DCAStrategy.all_objects.create(
|
||||
name="Private BTC Strategy",
|
||||
target_currency=self.currency1,
|
||||
payment_currency=self.currency2,
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
def test_user_can_see_shared_dca_strategies(self):
|
||||
"""User2 should see DCA strategies shared with them."""
|
||||
response = self.client2.get("/api/dca/strategies/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
strategy_ids = [s["id"] for s in response.data["results"]]
|
||||
self.assertIn(self.shared_strategy.id, strategy_ids)
|
||||
self.assertNotIn(self.private_strategy.id, strategy_ids)
|
||||
|
||||
def test_user_can_access_shared_dca_strategy_detail(self):
|
||||
"""User2 should be able to access shared strategy details."""
|
||||
response = self.client2.get(f"/api/dca/strategies/{self.shared_strategy.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared BTC Strategy")
|
||||
|
||||
def test_user_without_share_cannot_see_shared_strategy(self):
|
||||
"""User3 should NOT see strategy shared only with user2."""
|
||||
response = self.client3.get("/api/dca/strategies/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
strategy_ids = [s["id"] for s in response.data["results"]]
|
||||
self.assertNotIn(self.shared_strategy.id, strategy_ids)
|
||||
|
||||
def test_user_can_access_shared_strategy_actions(self):
|
||||
"""User2 should be able to access actions on shared strategy."""
|
||||
# investment_frequency
|
||||
response1 = self.client2.get(
|
||||
f"/api/dca/strategies/{self.shared_strategy.id}/investment_frequency/"
|
||||
)
|
||||
self.assertEqual(response1.status_code, status.HTTP_200_OK)
|
||||
|
||||
# price_comparison
|
||||
response2 = self.client2.get(
|
||||
f"/api/dca/strategies/{self.shared_strategy.id}/price_comparison/"
|
||||
)
|
||||
self.assertEqual(response2.status_code, status.HTTP_200_OK)
|
||||
|
||||
# current_price
|
||||
response3 = self.client2.get(
|
||||
f"/api/dca/strategies/{self.shared_strategy.id}/current_price/"
|
||||
)
|
||||
self.assertEqual(response3.status_code, status.HTTP_200_OK)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class SharedAccountGroupTests(TestCase):
|
||||
"""Tests for shared account group access."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with shared account groups."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
self.user3 = User.objects.create_user(
|
||||
email="user3@test.com", password="testpass123"
|
||||
)
|
||||
self.client3 = APIClient()
|
||||
self.client3.force_authenticate(user=self.user3)
|
||||
|
||||
# User 1's account group shared with user 2
|
||||
self.shared_group = AccountGroup.all_objects.create(
|
||||
name="Shared Group", owner=self.user1
|
||||
)
|
||||
self.shared_group.shared_with.add(self.user2)
|
||||
|
||||
# User 1's private account group
|
||||
self.private_group = AccountGroup.all_objects.create(
|
||||
name="Private Group", owner=self.user1
|
||||
)
|
||||
|
||||
# User 1's public account group
|
||||
self.public_group = AccountGroup.all_objects.create(
|
||||
name="Public Group", owner=self.user1, visibility="public"
|
||||
)
|
||||
|
||||
def test_user_can_see_shared_account_groups(self):
|
||||
"""User2 should see account groups shared with them."""
|
||||
response = self.client2.get("/api/account-groups/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
group_ids = [g["id"] for g in response.data["results"]]
|
||||
self.assertIn(self.shared_group.id, group_ids)
|
||||
self.assertNotIn(self.private_group.id, group_ids)
|
||||
|
||||
def test_user_can_access_shared_account_group_detail(self):
|
||||
"""User2 should be able to access shared account group details."""
|
||||
response = self.client2.get(f"/api/account-groups/{self.shared_group.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Group")
|
||||
|
||||
def test_user_can_see_public_account_groups(self):
|
||||
"""User3 should see public account groups."""
|
||||
response = self.client3.get("/api/account-groups/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
group_ids = [g["id"] for g in response.data["results"]]
|
||||
self.assertIn(self.public_group.id, group_ids)
|
||||
|
||||
def test_user_without_share_cannot_access_shared_group(self):
|
||||
"""User3 should NOT be able to access shared account group."""
|
||||
response = self.client3.get(f"/api/account-groups/{self.shared_group.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
@@ -16,7 +16,11 @@ router.register(r"currencies", views.CurrencyViewSet)
|
||||
router.register(r"exchange-rates", views.ExchangeRateViewSet)
|
||||
router.register(r"dca/strategies", views.DCAStrategyViewSet)
|
||||
router.register(r"dca/entries", views.DCAEntryViewSet)
|
||||
router.register(r"import/profiles", views.ImportProfileViewSet, basename="import-profiles")
|
||||
router.register(r"import/runs", views.ImportRunViewSet, basename="import-runs")
|
||||
router.register(r"import/import", views.ImportViewSet, basename="import-import")
|
||||
|
||||
urlpatterns = [
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
|
||||
|
||||
@@ -2,3 +2,5 @@ from .transactions import *
|
||||
from .accounts import *
|
||||
from .currencies import *
|
||||
from .dca import *
|
||||
from .imports import *
|
||||
|
||||
|
||||
@@ -1,27 +1,79 @@
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework import viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.api.custom.pagination import CustomPageNumberPagination
|
||||
from apps.accounts.models import AccountGroup, Account
|
||||
from apps.api.serializers import AccountGroupSerializer, AccountSerializer
|
||||
from apps.accounts.services import get_account_balance
|
||||
from apps.api.serializers import (
|
||||
AccountGroupSerializer,
|
||||
AccountSerializer,
|
||||
AccountBalanceSerializer,
|
||||
)
|
||||
|
||||
|
||||
class AccountGroupViewSet(viewsets.ModelViewSet):
|
||||
"""ViewSet for managing account groups."""
|
||||
|
||||
queryset = AccountGroup.objects.all()
|
||||
serializer_class = AccountGroupSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return AccountGroup.objects.all().order_by("id")
|
||||
return AccountGroup.objects.all()
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
balance=extend_schema(
|
||||
summary="Get account balance",
|
||||
description="Returns the current and projected balance for the account, along with currency data.",
|
||||
responses={200: AccountBalanceSerializer},
|
||||
),
|
||||
)
|
||||
class AccountViewSet(viewsets.ModelViewSet):
|
||||
"""ViewSet for managing accounts."""
|
||||
|
||||
queryset = Account.objects.all()
|
||||
serializer_class = AccountSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"group": ["exact", "isnull"],
|
||||
"currency": ["exact"],
|
||||
"exchange_currency": ["exact", "isnull"],
|
||||
"is_asset": ["exact"],
|
||||
"is_archived": ["exact"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return (
|
||||
Account.objects.all()
|
||||
.order_by("id")
|
||||
.select_related("group", "currency", "exchange_currency")
|
||||
return Account.objects.all().select_related(
|
||||
"group", "currency", "exchange_currency"
|
||||
)
|
||||
|
||||
@action(detail=True, methods=["get"], permission_classes=[IsAuthenticated])
|
||||
def balance(self, request, pk=None):
|
||||
"""Get current and projected balance for an account."""
|
||||
account = self.get_object()
|
||||
|
||||
current_balance = get_account_balance(account, paid_only=True)
|
||||
projected_balance = get_account_balance(account, paid_only=False)
|
||||
|
||||
serializer = AccountBalanceSerializer(
|
||||
{
|
||||
"current_balance": current_balance,
|
||||
"projected_balance": projected_balance,
|
||||
"currency": account.currency,
|
||||
}
|
||||
)
|
||||
|
||||
return Response(serializer.data)
|
||||
|
||||
@@ -9,8 +9,28 @@ from apps.currencies.models import ExchangeRate
|
||||
class CurrencyViewSet(viewsets.ModelViewSet):
|
||||
queryset = Currency.objects.all()
|
||||
serializer_class = CurrencySerializer
|
||||
filterset_fields = {
|
||||
'name': ['exact', 'icontains'],
|
||||
'code': ['exact', 'icontains'],
|
||||
'decimal_places': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'prefix': ['exact', 'icontains'],
|
||||
'suffix': ['exact', 'icontains'],
|
||||
'exchange_currency': ['exact'],
|
||||
'is_archived': ['exact'],
|
||||
}
|
||||
search_fields = '__all__'
|
||||
ordering_fields = '__all__'
|
||||
|
||||
|
||||
class ExchangeRateViewSet(viewsets.ModelViewSet):
|
||||
queryset = ExchangeRate.objects.all()
|
||||
serializer_class = ExchangeRateSerializer
|
||||
filterset_fields = {
|
||||
'from_currency': ['exact'],
|
||||
'to_currency': ['exact'],
|
||||
'rate': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'date': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'automatic': ['exact'],
|
||||
}
|
||||
search_fields = '__all__'
|
||||
ordering_fields = '__all__'
|
||||
|
||||
@@ -8,6 +8,19 @@ from apps.api.serializers import DCAStrategySerializer, DCAEntrySerializer
|
||||
class DCAStrategyViewSet(viewsets.ModelViewSet):
|
||||
queryset = DCAStrategy.objects.all()
|
||||
serializer_class = DCAStrategySerializer
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"target_currency": ["exact"],
|
||||
"payment_currency": ["exact"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"created_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"updated_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
}
|
||||
search_fields = ["name", "notes"]
|
||||
ordering_fields = "__all__"
|
||||
|
||||
def get_queryset(self):
|
||||
return DCAStrategy.objects.all()
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
def investment_frequency(self, request, pk=None):
|
||||
@@ -32,10 +45,22 @@ class DCAStrategyViewSet(viewsets.ModelViewSet):
|
||||
class DCAEntryViewSet(viewsets.ModelViewSet):
|
||||
queryset = DCAEntry.objects.all()
|
||||
serializer_class = DCAEntrySerializer
|
||||
filterset_fields = {
|
||||
"strategy": ["exact"],
|
||||
"date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"amount_paid": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"amount_received": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"expense_transaction": ["exact", "isnull"],
|
||||
"income_transaction": ["exact", "isnull"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"created_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"updated_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
}
|
||||
search_fields = ["notes"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["-date"]
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = DCAEntry.objects.all()
|
||||
strategy_id = self.request.query_params.get("strategy", None)
|
||||
if strategy_id is not None:
|
||||
queryset = queryset.filter(strategy_id=strategy_id)
|
||||
return queryset
|
||||
# Filter entries by strategies the user has access to
|
||||
accessible_strategies = DCAStrategy.objects.all()
|
||||
return DCAEntry.objects.filter(strategy__in=accessible_strategies)
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
from django.core.files.storage import FileSystemStorage
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view, inline_serializer
|
||||
from rest_framework import serializers as drf_serializers
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.parsers import MultiPartParser
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.api.serializers import ImportFileSerializer, ImportProfileSerializer, ImportRunSerializer
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
from apps.import_app.tasks import process_import
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
summary="List import profiles",
|
||||
description="Returns a paginated list of all available import profiles.",
|
||||
),
|
||||
retrieve=extend_schema(
|
||||
summary="Get import profile",
|
||||
description="Returns the details of a specific import profile by ID.",
|
||||
),
|
||||
)
|
||||
class ImportProfileViewSet(viewsets.ReadOnlyModelViewSet):
|
||||
"""ViewSet for listing and retrieving import profiles."""
|
||||
|
||||
queryset = ImportProfile.objects.all()
|
||||
serializer_class = ImportProfileSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
filterset_fields = {
|
||||
'name': ['exact', 'icontains'],
|
||||
'yaml_config': ['exact', 'icontains'],
|
||||
'version': ['exact'],
|
||||
}
|
||||
search_fields = ['name', 'yaml_config']
|
||||
ordering_fields = '__all__'
|
||||
ordering = ['name']
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
summary="List import runs",
|
||||
description="Returns a paginated list of import runs. Optionally filter by profile_id.",
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="profile_id",
|
||||
type=int,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Filter runs by profile ID",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
retrieve=extend_schema(
|
||||
summary="Get import run",
|
||||
description="Returns the details of a specific import run by ID, including status and logs.",
|
||||
),
|
||||
)
|
||||
class ImportRunViewSet(viewsets.ReadOnlyModelViewSet):
|
||||
"""ViewSet for listing and retrieving import runs."""
|
||||
|
||||
queryset = ImportRun.objects.all().order_by("-id")
|
||||
serializer_class = ImportRunSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
filterset_fields = {
|
||||
'status': ['exact'],
|
||||
'profile': ['exact'],
|
||||
'file_name': ['exact', 'icontains'],
|
||||
'logs': ['exact', 'icontains'],
|
||||
'processed_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'total_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'successful_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'skipped_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'failed_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'started_at': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'],
|
||||
'finished_at': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'],
|
||||
}
|
||||
search_fields = ['file_name', 'logs']
|
||||
ordering_fields = '__all__'
|
||||
ordering = ['-id']
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset()
|
||||
profile_id = self.request.query_params.get("profile_id")
|
||||
if profile_id:
|
||||
queryset = queryset.filter(profile_id=profile_id)
|
||||
return queryset
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
create=extend_schema(
|
||||
summary="Import file",
|
||||
description="Upload a CSV or XLSX file to import using an existing import profile. The import is queued and processed asynchronously.",
|
||||
request={
|
||||
"multipart/form-data": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"profile_id": {"type": "integer", "description": "ID of the ImportProfile to use"},
|
||||
"file": {"type": "string", "format": "binary", "description": "CSV or XLSX file to import"},
|
||||
},
|
||||
"required": ["profile_id", "file"],
|
||||
},
|
||||
},
|
||||
responses={
|
||||
202: inline_serializer(
|
||||
name="ImportResponse",
|
||||
fields={
|
||||
"import_run_id": drf_serializers.IntegerField(),
|
||||
"status": drf_serializers.CharField(),
|
||||
},
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
class ImportViewSet(viewsets.ViewSet):
|
||||
"""ViewSet for importing data via file upload."""
|
||||
|
||||
permission_classes = [IsAuthenticated]
|
||||
parser_classes = [MultiPartParser]
|
||||
|
||||
def create(self, request):
|
||||
serializer = ImportFileSerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
profile = serializer.validated_data["profile"]
|
||||
uploaded_file = serializer.validated_data["file"]
|
||||
|
||||
# Save file to temp location
|
||||
fs = FileSystemStorage(location="/usr/src/app/temp")
|
||||
filename = fs.save(uploaded_file.name, uploaded_file)
|
||||
file_path = fs.path(filename)
|
||||
|
||||
# Create ImportRun record
|
||||
import_run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||
|
||||
# Queue import task
|
||||
process_import.defer(
|
||||
import_run_id=import_run.id,
|
||||
file_path=file_path,
|
||||
user_id=request.user.id,
|
||||
)
|
||||
|
||||
return Response(
|
||||
{"import_run_id": import_run.id, "status": "queued"},
|
||||
status=status.HTTP_202_ACCEPTED,
|
||||
)
|
||||
@@ -2,7 +2,6 @@ from copy import deepcopy
|
||||
|
||||
from rest_framework import viewsets
|
||||
|
||||
from apps.api.custom.pagination import CustomPageNumberPagination
|
||||
from apps.api.serializers import (
|
||||
TransactionSerializer,
|
||||
TransactionCategorySerializer,
|
||||
@@ -25,14 +24,41 @@ from apps.rules.signals import transaction_updated, transaction_created
|
||||
class TransactionViewSet(viewsets.ModelViewSet):
|
||||
queryset = Transaction.objects.all()
|
||||
serializer_class = TransactionSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"account": ["exact"],
|
||||
"type": ["exact"],
|
||||
"is_paid": ["exact"],
|
||||
"date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"reference_date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"mute": ["exact"],
|
||||
"amount": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"description": ["exact", "icontains"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"category": ["exact", "isnull"],
|
||||
"installment_plan": ["exact", "isnull"],
|
||||
"installment_id": ["exact", "gte", "lte"],
|
||||
"recurring_transaction": ["exact", "isnull"],
|
||||
"internal_note": ["exact", "icontains"],
|
||||
"internal_id": ["exact"],
|
||||
"deleted": ["exact"],
|
||||
"created_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"updated_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"deleted_at": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["description", "notes", "internal_note"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["-id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return Transaction.objects.all()
|
||||
|
||||
def perform_create(self, serializer):
|
||||
instance = serializer.save()
|
||||
transaction_created.send(sender=instance)
|
||||
|
||||
def perform_update(self, serializer):
|
||||
old_data = deepcopy(Transaction.objects.get(pk=serializer.data["pk"]))
|
||||
old_data = deepcopy(self.get_object())
|
||||
instance = serializer.save()
|
||||
transaction_updated.send(sender=instance, old_data=old_data)
|
||||
|
||||
@@ -40,50 +66,109 @@ class TransactionViewSet(viewsets.ModelViewSet):
|
||||
kwargs["partial"] = True
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def get_queryset(self):
|
||||
return Transaction.objects.all().order_by("-id")
|
||||
|
||||
|
||||
class TransactionCategoryViewSet(viewsets.ModelViewSet):
|
||||
queryset = TransactionCategory.objects.all()
|
||||
serializer_class = TransactionCategorySerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"mute": ["exact"],
|
||||
"active": ["exact"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionCategory.objects.all().order_by("id")
|
||||
return TransactionCategory.objects.all()
|
||||
|
||||
|
||||
class TransactionTagViewSet(viewsets.ModelViewSet):
|
||||
queryset = TransactionTag.objects.all()
|
||||
serializer_class = TransactionTagSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"active": ["exact"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionTag.objects.all().order_by("id")
|
||||
return TransactionTag.objects.all()
|
||||
|
||||
|
||||
class TransactionEntityViewSet(viewsets.ModelViewSet):
|
||||
queryset = TransactionEntity.objects.all()
|
||||
serializer_class = TransactionEntitySerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"active": ["exact"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionEntity.objects.all().order_by("id")
|
||||
return TransactionEntity.objects.all()
|
||||
|
||||
|
||||
class InstallmentPlanViewSet(viewsets.ModelViewSet):
|
||||
queryset = InstallmentPlan.objects.all()
|
||||
serializer_class = InstallmentPlanSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"account": ["exact"],
|
||||
"type": ["exact"],
|
||||
"description": ["exact", "icontains"],
|
||||
"number_of_installments": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"installment_start": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"installment_total_number": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"start_date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"reference_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"end_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"recurrence": ["exact"],
|
||||
"installment_amount": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"category": ["exact", "isnull"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"add_description_to_transaction": ["exact"],
|
||||
"add_notes_to_transaction": ["exact"],
|
||||
}
|
||||
search_fields = ["description", "notes"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["-id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return InstallmentPlan.objects.all().order_by("-id")
|
||||
return InstallmentPlan.objects.all()
|
||||
|
||||
|
||||
class RecurringTransactionViewSet(viewsets.ModelViewSet):
|
||||
queryset = RecurringTransaction.objects.all()
|
||||
serializer_class = RecurringTransactionSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"is_paused": ["exact"],
|
||||
"account": ["exact"],
|
||||
"type": ["exact"],
|
||||
"amount": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"description": ["exact", "icontains"],
|
||||
"category": ["exact", "isnull"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"reference_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"start_date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"end_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"recurrence_type": ["exact"],
|
||||
"recurrence_interval": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"keep_at_most": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"last_generated_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"last_generated_reference_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"add_description_to_transaction": ["exact"],
|
||||
"add_notes_to_transaction": ["exact"],
|
||||
}
|
||||
search_fields = ["description", "notes"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["-id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return RecurringTransaction.objects.all().order_by("-id")
|
||||
return RecurringTransaction.objects.all()
|
||||
|
||||
@@ -23,3 +23,6 @@ class CommonConfig(AppConfig):
|
||||
# Delete the cache for update checks to prevent false-positives when the app is restarted
|
||||
# this will be recreated by the check_for_updates task
|
||||
cache.delete("update_check")
|
||||
|
||||
# Register system checks for required environment variables
|
||||
from apps.common import checks # noqa: F401
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Django System Checks for required environment variables.
|
||||
|
||||
This module validates that required environment variables (those without defaults)
|
||||
are present before the application starts.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from django.core.checks import Error, register
|
||||
|
||||
|
||||
# List of environment variables that are required (no default values)
|
||||
# Based on the README.md documentation
|
||||
REQUIRED_ENV_VARS = [
|
||||
("SECRET_KEY", "This is used to provide cryptographic signing."),
|
||||
("SQL_DATABASE", "The name of your postgres database."),
|
||||
]
|
||||
|
||||
# List of environment variables that must be valid integers if set
|
||||
INT_ENV_VARS = [
|
||||
("TASK_WORKERS", "How many workers to have for async tasks."),
|
||||
("SESSION_EXPIRY_TIME", "The age of session cookies, in seconds."),
|
||||
("INTERNAL_PORT", "The port on which the app listens on."),
|
||||
("DJANGO_VITE_DEV_SERVER_PORT", "The port where Vite's dev server is running"),
|
||||
]
|
||||
|
||||
|
||||
@register()
|
||||
def check_required_env_vars(app_configs, **kwargs):
|
||||
"""
|
||||
Check that all required environment variables are set.
|
||||
|
||||
Returns a list of Error objects for any missing required variables.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
for var_name, description in REQUIRED_ENV_VARS:
|
||||
value = os.getenv(var_name)
|
||||
if not value:
|
||||
errors.append(
|
||||
Error(
|
||||
f"Required environment variable '{var_name}' is not set.",
|
||||
hint=f"{description} Please set this variable in your .env file or environment.",
|
||||
id="wygiwyh.E001",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@register()
|
||||
def check_int_env_vars(app_configs, **kwargs):
|
||||
"""
|
||||
Check that environment variables that should be integers are valid.
|
||||
|
||||
Returns a list of Error objects for any invalid integer variables.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
for var_name, description in INT_ENV_VARS:
|
||||
value = os.getenv(var_name)
|
||||
if value is not None:
|
||||
try:
|
||||
int(value)
|
||||
except ValueError:
|
||||
errors.append(
|
||||
Error(
|
||||
f"Environment variable '{var_name}' must be a valid integer, got '{value}'.",
|
||||
hint=f"{description}",
|
||||
id="wygiwyh.E002",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@register()
|
||||
def check_soft_delete_config(app_configs, **kwargs):
|
||||
"""
|
||||
Check that KEEP_DELETED_TRANSACTIONS_FOR is a valid integer when ENABLE_SOFT_DELETE is enabled.
|
||||
|
||||
Returns a list of Error objects if the configuration is invalid.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
enable_soft_delete = os.getenv("ENABLE_SOFT_DELETE", "false").lower() == "true"
|
||||
|
||||
if enable_soft_delete:
|
||||
keep_deleted_for = os.getenv("KEEP_DELETED_TRANSACTIONS_FOR")
|
||||
if keep_deleted_for is not None:
|
||||
try:
|
||||
int(keep_deleted_for)
|
||||
except ValueError:
|
||||
errors.append(
|
||||
Error(
|
||||
f"Environment variable 'KEEP_DELETED_TRANSACTIONS_FOR' must be a valid integer when ENABLE_SOFT_DELETE is enabled, got '{keep_deleted_for}'.",
|
||||
hint="Time in days to keep soft deleted transactions for. Set to 0 to keep all transactions indefinitely.",
|
||||
id="wygiwyh.E003",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
+11
-16
@@ -1,14 +1,13 @@
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from django import forms
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.core.exceptions import ValidationError
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Submit, Div, HTML
|
||||
|
||||
from apps.common.widgets.tom_select import TomSelect, TomSelectMultiple
|
||||
from apps.common.models import SharedObject
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.tom_select import TomSelect, TomSelectMultiple
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import HTML, Div, Field, Layout, Submit
|
||||
from django import forms
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
@@ -39,6 +38,7 @@ class SharedObjectForm(forms.Form):
|
||||
choices=SharedObject.Visibility.choices,
|
||||
required=True,
|
||||
label=_("Visibility"),
|
||||
widget=TomSelect(clear_button=False),
|
||||
help_text=_(
|
||||
"Private: Only shown for the owner and shared users. Only editable by the owner."
|
||||
"<br/>"
|
||||
@@ -48,9 +48,6 @@ class SharedObjectForm(forms.Form):
|
||||
|
||||
class Meta:
|
||||
fields = ["visibility", "shared_with_users"]
|
||||
widgets = {
|
||||
"visibility": TomSelect(clear_button=False),
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Get the current user to filter available sharing options
|
||||
@@ -73,12 +70,10 @@ class SharedObjectForm(forms.Form):
|
||||
self.helper.layout = Layout(
|
||||
Field("owner"),
|
||||
Field("visibility"),
|
||||
HTML("<hr>"),
|
||||
HTML('<hr class="hr my-3">'),
|
||||
Field("shared_with_users"),
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Save"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Save"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,47 @@
|
||||
import functools
|
||||
import inspect
|
||||
|
||||
import procrastinate
|
||||
from django.db import close_old_connections
|
||||
|
||||
|
||||
_CONNECTION_CLEANUP_WRAPPED = "_wygiwyh_connection_cleanup_wrapped"
|
||||
|
||||
|
||||
def _wrap_task_with_django_connection_cleanup(task):
|
||||
if getattr(task.func, _CONNECTION_CLEANUP_WRAPPED, False):
|
||||
return
|
||||
|
||||
func = task.func
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapped(*args, **kwargs):
|
||||
close_old_connections()
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
finally:
|
||||
close_old_connections()
|
||||
|
||||
wrapped = async_wrapped
|
||||
else:
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapped(*args, **kwargs):
|
||||
close_old_connections()
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
close_old_connections()
|
||||
|
||||
wrapped = sync_wrapped
|
||||
|
||||
setattr(wrapped, _CONNECTION_CLEANUP_WRAPPED, True)
|
||||
task.func = wrapped
|
||||
|
||||
|
||||
def on_app_ready(app: procrastinate.App):
|
||||
"""This function is ran upon procrastinate initialization."""
|
||||
...
|
||||
for task in set(app.tasks.values()):
|
||||
_wrap_task_with_django_connection_cleanup(task)
|
||||
|
||||
@@ -17,7 +17,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.periodic(cron="0 4 * * *")
|
||||
@app.task(queueing_lock="remove_old_jobs", pass_context=True, name="remove_old_jobs")
|
||||
@app.task(
|
||||
lock="remove_old_jobs",
|
||||
queueing_lock="remove_old_jobs",
|
||||
pass_context=True,
|
||||
name="remove_old_jobs",
|
||||
)
|
||||
async def remove_old_jobs(context, timestamp):
|
||||
try:
|
||||
return await builtin_tasks.remove_old_jobs(
|
||||
@@ -36,7 +41,11 @@ async def remove_old_jobs(context, timestamp):
|
||||
|
||||
|
||||
@app.periodic(cron="0 6 1 * *")
|
||||
@app.task(queueing_lock="remove_expired_sessions", name="remove_expired_sessions")
|
||||
@app.task(
|
||||
lock="remove_expired_sessions",
|
||||
queueing_lock="remove_expired_sessions",
|
||||
name="remove_expired_sessions",
|
||||
)
|
||||
async def remove_expired_sessions(timestamp=None):
|
||||
"""Cleanup expired sessions by using Django management command."""
|
||||
try:
|
||||
@@ -49,7 +58,7 @@ async def remove_expired_sessions(timestamp=None):
|
||||
|
||||
|
||||
@app.periodic(cron="0 8 * * *")
|
||||
@app.task(name="reset_demo_data")
|
||||
@app.task(lock="reset_demo_data", name="reset_demo_data")
|
||||
def reset_demo_data(timestamp=None):
|
||||
"""
|
||||
Wipes the database and loads fresh demo data if DEMO mode is active.
|
||||
@@ -86,9 +95,7 @@ def reset_demo_data(timestamp=None):
|
||||
|
||||
|
||||
@app.periodic(cron="0 */12 * * *") # Every 12 hours
|
||||
@app.task(
|
||||
name="check_for_updates",
|
||||
)
|
||||
@app.task(lock="check_for_updates", name="check_for_updates")
|
||||
def check_for_updates(timestamp=None):
|
||||
if not settings.CHECK_FOR_UPDATES:
|
||||
return "CHECK_FOR_UPDATES is disabled"
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from django import forms, template
|
||||
|
||||
register = template.Library()
|
||||
|
||||
|
||||
@register.filter
|
||||
def is_input(field):
|
||||
return isinstance(field.field.widget, forms.TextInput)
|
||||
|
||||
|
||||
@register.filter
|
||||
def is_textarea(field):
|
||||
return isinstance(field.field.widget, forms.Textarea)
|
||||
@@ -11,7 +11,7 @@ def toast_bg(tags):
|
||||
elif "warning" in tags:
|
||||
return "warning"
|
||||
elif "error" in tags:
|
||||
return "danger"
|
||||
return "error"
|
||||
elif "info" in tags:
|
||||
return "info"
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import procrastinate
|
||||
from django.db import connection
|
||||
from django.test import SimpleTestCase, TransactionTestCase
|
||||
from procrastinate.testing import InMemoryConnector
|
||||
|
||||
from apps.common.procrastinate import on_app_ready
|
||||
|
||||
|
||||
def make_app_with_task(func):
|
||||
app = procrastinate.App(connector=InMemoryConnector())
|
||||
task = app.task(name="sample_task")(func)
|
||||
|
||||
return app, task
|
||||
|
||||
|
||||
class ProcrastinateConnectionCleanupTests(SimpleTestCase):
|
||||
def test_app_ready_closes_old_connections_around_sync_tasks(self):
|
||||
calls = []
|
||||
|
||||
def sample_task(value):
|
||||
calls.append(("task", value))
|
||||
return value * 2
|
||||
|
||||
app, task = make_app_with_task(sample_task)
|
||||
|
||||
with patch(
|
||||
"apps.common.procrastinate.close_old_connections",
|
||||
create=True,
|
||||
side_effect=lambda: calls.append(("cleanup", None)),
|
||||
):
|
||||
on_app_ready(app)
|
||||
|
||||
result = task.func(3)
|
||||
|
||||
self.assertEqual(result, 6)
|
||||
self.assertEqual(
|
||||
calls,
|
||||
[
|
||||
("cleanup", None),
|
||||
("task", 3),
|
||||
("cleanup", None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_app_ready_closes_old_connections_when_sync_task_raises(self):
|
||||
calls = []
|
||||
|
||||
def sample_task():
|
||||
calls.append(("task", None))
|
||||
raise RuntimeError("boom")
|
||||
|
||||
app, task = make_app_with_task(sample_task)
|
||||
|
||||
with patch(
|
||||
"apps.common.procrastinate.close_old_connections",
|
||||
create=True,
|
||||
side_effect=lambda: calls.append(("cleanup", None)),
|
||||
):
|
||||
on_app_ready(app)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
task.func()
|
||||
|
||||
self.assertEqual(
|
||||
calls,
|
||||
[
|
||||
("cleanup", None),
|
||||
("task", None),
|
||||
("cleanup", None),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ProcrastinateConnectionRecoveryTests(TransactionTestCase):
|
||||
def test_wrapped_task_recovers_from_closed_django_connection(self):
|
||||
def sample_task():
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("SELECT 1")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
app, task = make_app_with_task(sample_task)
|
||||
on_app_ready(app)
|
||||
|
||||
connection.ensure_connection()
|
||||
connection.connection.close()
|
||||
|
||||
self.assertEqual(task.func(), 1)
|
||||
@@ -0,0 +1,5 @@
|
||||
from crispy_forms.layout import Field
|
||||
|
||||
|
||||
class Switch(Field):
|
||||
template = "crispy-daisyui/layout/switch.html"
|
||||
@@ -1,15 +1,14 @@
|
||||
import datetime
|
||||
|
||||
from django.forms import widgets
|
||||
from django.utils import formats, translation, dates
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.functions.format import get_format
|
||||
from apps.common.utils.django import (
|
||||
django_to_python_datetime,
|
||||
django_to_airdatepicker_datetime,
|
||||
django_to_airdatepicker_datetime_separated,
|
||||
django_to_python_datetime,
|
||||
)
|
||||
from apps.common.functions.format import get_format
|
||||
from django.forms import widgets
|
||||
from django.utils import dates, formats, translation
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class AirDatePickerInput(widgets.DateInput):
|
||||
@@ -52,6 +51,8 @@ class AirDatePickerInput(widgets.DateInput):
|
||||
def build_attrs(self, base_attrs, extra_attrs=None):
|
||||
attrs = super().build_attrs(base_attrs, extra_attrs)
|
||||
|
||||
attrs["class"] = attrs.get("class", "") + " input"
|
||||
|
||||
attrs["data-now-button-txt"] = _("Today")
|
||||
attrs["data-auto-close"] = str(self.auto_close).lower()
|
||||
attrs["data-clear-button"] = str(self.clear_button).lower()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from django.forms import widgets, SelectMultiple
|
||||
from django.forms import SelectMultiple, widgets
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
@@ -17,7 +17,7 @@ class TomSelect(widgets.Select):
|
||||
checkboxes=False,
|
||||
group_by=None,
|
||||
*args,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(attrs, *args, **kwargs)
|
||||
self.remove_button = remove_button
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.utils import timezone
|
||||
@@ -258,7 +257,10 @@ class ExchangeRateFetcher:
|
||||
processed_pairs.add((from_currency.id, to_currency.id))
|
||||
|
||||
service.last_fetch = timezone.now()
|
||||
service.failure_count = 0
|
||||
service.save()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching rates for {service.name}: {e}")
|
||||
service.failure_count += 1
|
||||
service.save()
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
from crispy_bootstrap5.bootstrap5 import Switch
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Row, Column
|
||||
from django import forms
|
||||
from django.forms import CharField
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.crispy.daisyui import Switch
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.datepicker import AirDateTimePickerInput
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.currencies.models import Currency, ExchangeRate, ExchangeRateService
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Column, Layout, Row
|
||||
from django import forms
|
||||
from django.forms import CharField
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class CurrencyForm(forms.ModelForm):
|
||||
@@ -51,17 +50,13 @@ class CurrencyForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -89,17 +84,13 @@ class ExchangeRateForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -132,8 +123,8 @@ class ExchangeRateServiceForm(forms.ModelForm):
|
||||
Switch("singleton"),
|
||||
"api_key",
|
||||
Row(
|
||||
Column("interval_type", css_class="form-group col-md-6"),
|
||||
Column("fetch_interval", css_class="form-group col-md-6"),
|
||||
Column("interval_type"),
|
||||
Column("fetch_interval"),
|
||||
),
|
||||
"target_currencies",
|
||||
"target_accounts",
|
||||
@@ -142,16 +133,12 @@ class ExchangeRateServiceForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.2.10 on 2026-01-10 06:08
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0022_currency_is_archived'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='exchangerateservice',
|
||||
name='failure_count',
|
||||
field=models.PositiveIntegerField(default=0),
|
||||
),
|
||||
]
|
||||
@@ -136,6 +136,8 @@ class ExchangeRateService(models.Model):
|
||||
null=True, blank=True, verbose_name=_("Last Successful Fetch")
|
||||
)
|
||||
|
||||
failure_count = models.PositiveIntegerField(default=0)
|
||||
|
||||
target_currencies = models.ManyToManyField(
|
||||
Currency,
|
||||
verbose_name=_("Target Currencies"),
|
||||
@@ -237,7 +239,7 @@ class ExchangeRateService(models.Model):
|
||||
hours = self._parse_hour_ranges(self.fetch_interval)
|
||||
# Store in normalized format (optional)
|
||||
self.fetch_interval = ",".join(str(h) for h in sorted(hours))
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
raise ValidationError(
|
||||
{
|
||||
"fetch_interval": _(
|
||||
@@ -248,7 +250,7 @@ class ExchangeRateService(models.Model):
|
||||
)
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ValidationError(
|
||||
{
|
||||
"fetch_interval": _(
|
||||
|
||||
@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.periodic(cron="0 * * * *") # Run every hour
|
||||
@app.task(name="automatic_fetch_exchange_rates")
|
||||
@app.task(lock="automatic_fetch_exchange_rates", name="automatic_fetch_exchange_rates")
|
||||
def automatic_fetch_exchange_rates(timestamp=None):
|
||||
"""Fetch exchange rates for all due services"""
|
||||
fetcher = ExchangeRateFetcher()
|
||||
@@ -19,7 +19,7 @@ def automatic_fetch_exchange_rates(timestamp=None):
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
|
||||
@app.task(name="manual_fetch_exchange_rates")
|
||||
@app.task(lock="manual_fetch_exchange_rates", name="manual_fetch_exchange_rates")
|
||||
def manual_fetch_exchange_rates(timestamp=None):
|
||||
"""Fetch exchange rates for all due services"""
|
||||
fetcher = ExchangeRateFetcher()
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Tests package for currencies app
|
||||
@@ -0,0 +1,109 @@
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.currencies.models import Currency, ExchangeRateService
|
||||
from apps.currencies.exchange_rates.fetcher import ExchangeRateFetcher
|
||||
|
||||
|
||||
class ExchangeRateServiceFailureTrackingTests(TestCase):
|
||||
"""Tests for the failure count tracking functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
self.usd = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.eur = Currency.objects.create(
|
||||
code="EUR", name="Euro", decimal_places=2, prefix="€ "
|
||||
)
|
||||
self.eur.exchange_currency = self.usd
|
||||
self.eur.save()
|
||||
|
||||
self.service = ExchangeRateService.objects.create(
|
||||
name="Test Service",
|
||||
service_type=ExchangeRateService.ServiceType.FRANKFURTER,
|
||||
is_active=True,
|
||||
)
|
||||
self.service.target_currencies.add(self.eur)
|
||||
|
||||
def test_failure_count_increments_on_provider_error(self):
|
||||
"""Test that failure_count increments when provider raises an exception."""
|
||||
self.assertEqual(self.service.failure_count, 0)
|
||||
|
||||
with patch.object(
|
||||
self.service, "get_provider", side_effect=Exception("API Error")
|
||||
):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 1)
|
||||
|
||||
def test_failure_count_resets_on_success(self):
|
||||
"""Test that failure_count resets to 0 on successful fetch."""
|
||||
# Set initial failure count
|
||||
self.service.failure_count = 5
|
||||
self.service.save()
|
||||
|
||||
# Mock a successful provider
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.requires_api_key.return_value = False
|
||||
mock_provider.get_rates.return_value = [(self.usd, self.eur, Decimal("0.85"))]
|
||||
mock_provider.rates_inverted = False
|
||||
|
||||
with patch.object(self.service, "get_provider", return_value=mock_provider):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 0)
|
||||
|
||||
def test_failure_count_accumulates_across_fetches(self):
|
||||
"""Test that failure_count accumulates with consecutive failures."""
|
||||
self.assertEqual(self.service.failure_count, 0)
|
||||
|
||||
with patch.object(
|
||||
self.service, "get_provider", side_effect=Exception("API Error")
|
||||
):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 1)
|
||||
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 2)
|
||||
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 3)
|
||||
|
||||
def test_last_fetch_not_updated_on_failure(self):
|
||||
"""Test that last_fetch is NOT updated when a failure occurs."""
|
||||
original_last_fetch = self.service.last_fetch
|
||||
self.assertIsNone(original_last_fetch)
|
||||
|
||||
with patch.object(
|
||||
self.service, "get_provider", side_effect=Exception("API Error")
|
||||
):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
|
||||
self.service.refresh_from_db()
|
||||
self.assertIsNone(self.service.last_fetch)
|
||||
self.assertEqual(self.service.failure_count, 1)
|
||||
|
||||
def test_last_fetch_updated_on_success(self):
|
||||
"""Test that last_fetch IS updated when fetch succeeds."""
|
||||
self.assertIsNone(self.service.last_fetch)
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.requires_api_key.return_value = False
|
||||
mock_provider.get_rates.return_value = [(self.usd, self.eur, Decimal("0.85"))]
|
||||
mock_provider.rates_inverted = False
|
||||
|
||||
with patch.object(self.service, "get_provider", return_value=mock_provider):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
|
||||
self.service.refresh_from_db()
|
||||
self.assertIsNotNone(self.service.last_fetch)
|
||||
self.assertEqual(self.service.failure_count, 0)
|
||||
@@ -23,7 +23,7 @@ def currencies_index(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def currencies_list(request):
|
||||
currencies = Currency.objects.all().order_by("id")
|
||||
currencies = Currency.objects.all().order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"currencies/fragments/list.html",
|
||||
|
||||
+25
-47
@@ -1,22 +1,20 @@
|
||||
from crispy_bootstrap5.bootstrap5 import Switch, BS5Accordion
|
||||
from crispy_forms.bootstrap import FormActions, AccordionGroup
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Row, Column, HTML
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.datepicker import AirDatePickerInput
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.common.widgets.tom_select import TransactionSelect
|
||||
from apps.transactions.models import Transaction, TransactionTag, TransactionCategory
|
||||
from apps.common.fields.forms.dynamic_select import (
|
||||
DynamicModelChoiceField,
|
||||
DynamicModelMultipleChoiceField,
|
||||
)
|
||||
from apps.common.widgets.crispy.daisyui import Switch
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.datepicker import AirDatePickerInput
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect, TransactionSelect
|
||||
from apps.dca.models import DCAEntry, DCAStrategy
|
||||
from apps.transactions.models import Transaction, TransactionCategory, TransactionTag
|
||||
from crispy_forms.bootstrap import AccordionGroup, FormActions, Accordion
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import HTML, Column, Layout, Row
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class DCAStrategyForm(forms.ModelForm):
|
||||
@@ -36,8 +34,8 @@ class DCAStrategyForm(forms.ModelForm):
|
||||
self.helper.layout = Layout(
|
||||
"name",
|
||||
Row(
|
||||
Column("payment_currency", css_class="form-group col-md-6"),
|
||||
Column("target_currency", css_class="form-group col-md-6"),
|
||||
Column("payment_currency"),
|
||||
Column("target_currency"),
|
||||
),
|
||||
"notes",
|
||||
)
|
||||
@@ -45,17 +43,13 @@ class DCAStrategyForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -155,11 +149,11 @@ class DCAEntryForm(forms.ModelForm):
|
||||
self.helper.layout = Layout(
|
||||
"date",
|
||||
Row(
|
||||
Column("amount_paid", css_class="form-group col-md-6"),
|
||||
Column("amount_received", css_class="form-group col-md-6"),
|
||||
Column("amount_paid"),
|
||||
Column("amount_received"),
|
||||
),
|
||||
"notes",
|
||||
BS5Accordion(
|
||||
Accordion(
|
||||
AccordionGroup(
|
||||
_("Create transaction"),
|
||||
Switch("create_transaction"),
|
||||
@@ -168,19 +162,11 @@ class DCAEntryForm(forms.ModelForm):
|
||||
Row(
|
||||
Column(
|
||||
"from_account",
|
||||
css_class="form-group",
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
"from_category",
|
||||
css_class="form-group col-md-6 mb-0",
|
||||
),
|
||||
Column(
|
||||
"from_tags", css_class="form-group col-md-6 mb-0"
|
||||
),
|
||||
css_class="form-row",
|
||||
Column("from_category"),
|
||||
Column("from_tags"),
|
||||
),
|
||||
),
|
||||
css_class="p-1 mx-1 my-3 border rounded-3",
|
||||
@@ -192,14 +178,10 @@ class DCAEntryForm(forms.ModelForm):
|
||||
"to_account",
|
||||
css_class="form-group",
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
"to_category", css_class="form-group col-md-6 mb-0"
|
||||
),
|
||||
Column("to_tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("to_category"),
|
||||
Column("to_tags"),
|
||||
),
|
||||
),
|
||||
css_class="p-1 mx-1 my-3 border rounded-3",
|
||||
@@ -220,17 +202,13 @@ class DCAEntryForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# apps/dca_tracker/views.py
|
||||
from django.contrib import messages
|
||||
from django.contrib.auth.decorators import login_required
|
||||
from django.db.models import Sum, Avg
|
||||
@@ -23,7 +22,7 @@ def strategy_index(request):
|
||||
@only_htmx
|
||||
@login_required
|
||||
def strategy_list(request):
|
||||
strategies = DCAStrategy.objects.all().order_by("created_at")
|
||||
strategies = DCAStrategy.objects.all().order_by("name")
|
||||
return render(
|
||||
request, "dca/fragments/strategy/list.html", {"strategies": strategies}
|
||||
)
|
||||
@@ -234,7 +233,7 @@ def strategy_entry_add(request, strategy_id):
|
||||
if request.method == "POST":
|
||||
form = DCAEntryForm(request.POST, strategy=strategy)
|
||||
if form.is_valid():
|
||||
entry = form.save()
|
||||
form.save()
|
||||
messages.success(request, _("Entry added successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, HTML
|
||||
from crispy_forms.layout import HTML, Layout
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
|
||||
|
||||
class ExportForm(forms.Form):
|
||||
users = forms.BooleanField(
|
||||
@@ -115,9 +114,7 @@ class ExportForm(forms.Form):
|
||||
"dca",
|
||||
"import_profiles",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Export"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Export"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -162,7 +159,7 @@ class RestoreForm(forms.Form):
|
||||
self.helper.form_method = "post"
|
||||
self.helper.layout = Layout(
|
||||
"zip_file",
|
||||
HTML("<hr />"),
|
||||
HTML('<hr class="hr my-3"/>'),
|
||||
"users",
|
||||
"accounts",
|
||||
"currencies",
|
||||
@@ -181,9 +178,7 @@ class RestoreForm(forms.Form):
|
||||
"dca_entries",
|
||||
"import_profiles",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Restore"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Restore"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from import_export import fields, resources
|
||||
from import_export.widgets import ForeignKeyWidget
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.export_app.widgets.foreign_key import AutoCreateForeignKeyWidget
|
||||
from apps.export_app.widgets.foreign_key import (
|
||||
AllObjectsForeignKeyWidget,
|
||||
AutoCreateForeignKeyWidget,
|
||||
)
|
||||
from apps.export_app.widgets.many_to_many import AutoCreateManyToManyWidget
|
||||
from apps.export_app.widgets.string import EmptyStringToNoneField
|
||||
from apps.transactions.models import (
|
||||
@@ -20,7 +22,7 @@ class TransactionResource(resources.ModelResource):
|
||||
account = fields.Field(
|
||||
attribute="account",
|
||||
column_name="account",
|
||||
widget=ForeignKeyWidget(Account, "name"),
|
||||
widget=AllObjectsForeignKeyWidget(Account, "name"),
|
||||
)
|
||||
|
||||
category = fields.Field(
|
||||
@@ -86,7 +88,7 @@ class RecurringTransactionResource(resources.ModelResource):
|
||||
account = fields.Field(
|
||||
attribute="account",
|
||||
column_name="account",
|
||||
widget=ForeignKeyWidget(Account, "name"),
|
||||
widget=AllObjectsForeignKeyWidget(Account, "name"),
|
||||
)
|
||||
|
||||
category = fields.Field(
|
||||
@@ -119,12 +121,16 @@ class RecurringTransactionResource(resources.ModelResource):
|
||||
def get_queryset(self):
|
||||
return RecurringTransaction.all_objects.all()
|
||||
|
||||
def dehydrate_account_owner(self, obj):
|
||||
"""Export the account's owner ID for proper import matching."""
|
||||
return obj.account.owner_id if obj.account else None
|
||||
|
||||
|
||||
class InstallmentPlanResource(resources.ModelResource):
|
||||
account = fields.Field(
|
||||
attribute="account",
|
||||
column_name="account",
|
||||
widget=ForeignKeyWidget(Account, "name"),
|
||||
widget=AllObjectsForeignKeyWidget(Account, "name"),
|
||||
)
|
||||
|
||||
category = fields.Field(
|
||||
@@ -156,3 +162,7 @@ class InstallmentPlanResource(resources.ModelResource):
|
||||
|
||||
def get_queryset(self):
|
||||
return InstallmentPlan.all_objects.all()
|
||||
|
||||
def dehydrate_account_owner(self, obj):
|
||||
"""Export the account's owner ID for proper import matching."""
|
||||
return obj.account.owner_id if obj.account else None
|
||||
|
||||
@@ -1,6 +1,60 @@
|
||||
from import_export.widgets import ForeignKeyWidget
|
||||
|
||||
|
||||
class AllObjectsForeignKeyWidget(ForeignKeyWidget):
|
||||
"""
|
||||
ForeignKeyWidget that uses 'all_objects' manager for lookups,
|
||||
bypassing user-filtered managers like SharedObjectManager.
|
||||
Also filters by owner if available in the row data.
|
||||
"""
|
||||
|
||||
def get_queryset(self, value, row, *args, **kwargs):
|
||||
# Use all_objects manager if available, otherwise fall back to default
|
||||
if hasattr(self.model, "all_objects"):
|
||||
qs = self.model.all_objects.all()
|
||||
# Filter by owner if the row has an owner field and the model has owner
|
||||
if row:
|
||||
# Check for direct owner field first
|
||||
owner_id = row.get("owner") if "owner" in row else None
|
||||
# Fall back to account_owner for models like InstallmentPlan
|
||||
if not owner_id and "account_owner" in row:
|
||||
owner_id = row.get("account_owner")
|
||||
# If still no owner, try to get it from the existing record's account
|
||||
# This handles backward compatibility with older exports
|
||||
if not owner_id and "id" in row and row.get("id"):
|
||||
try:
|
||||
# Try to find the existing record and get owner from its account
|
||||
from apps.transactions.models import (
|
||||
InstallmentPlan,
|
||||
RecurringTransaction,
|
||||
)
|
||||
|
||||
record_id = row.get("id")
|
||||
# Try to find the existing InstallmentPlan or RecurringTransaction
|
||||
for model_class in [InstallmentPlan, RecurringTransaction]:
|
||||
try:
|
||||
existing = model_class.all_objects.get(id=record_id)
|
||||
if existing.account:
|
||||
owner_id = existing.account.owner_id
|
||||
break
|
||||
except model_class.DoesNotExist:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
# Final fallback: use the current logged-in user
|
||||
# This handles restoring to a fresh database with older exports
|
||||
if not owner_id:
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
|
||||
user = get_current_user()
|
||||
if user and user.is_authenticated:
|
||||
owner_id = user.id
|
||||
if owner_id:
|
||||
qs = qs.filter(owner_id=owner_id)
|
||||
return qs
|
||||
return super().get_queryset(value, row, *args, **kwargs)
|
||||
|
||||
|
||||
class AutoCreateForeignKeyWidget(ForeignKeyWidget):
|
||||
def clean(self, value, row=None, *args, **kwargs):
|
||||
if value:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.import_app.models import ImportProfile
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import (
|
||||
@@ -6,9 +8,6 @@ from crispy_forms.layout import (
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.import_app.models import ImportProfile
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
|
||||
|
||||
class ImportProfileForm(forms.ModelForm):
|
||||
class Meta:
|
||||
@@ -30,17 +29,13 @@ class ImportProfileForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -57,8 +52,6 @@ class ImportRunFileUploadForm(forms.Form):
|
||||
self.helper.layout = Layout(
|
||||
"file",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Import"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Import"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -106,6 +106,17 @@ class ExcelImportSettings(BaseModel):
|
||||
sheets: list[str] | str = "*"
|
||||
|
||||
|
||||
class QIFImportSettings(BaseModel):
|
||||
skip_errors: bool = Field(
|
||||
default=False,
|
||||
description="If True, errors during import will be logged and skipped",
|
||||
)
|
||||
file_type: Literal["qif"] = "qif"
|
||||
importing: Literal["transactions"] = "transactions"
|
||||
encoding: str = Field(default="utf-8", description="File encoding")
|
||||
date_format: str = Field(..., description="Date format (e.g. %d/%m/%Y)")
|
||||
|
||||
|
||||
class ColumnMapping(BaseModel):
|
||||
source: Optional[str] | Optional[list[str]] = Field(
|
||||
default=None,
|
||||
@@ -342,7 +353,7 @@ class CurrencyExchangeMapping(ColumnMapping):
|
||||
|
||||
|
||||
class ImportProfileSchema(BaseModel):
|
||||
settings: CSVImportSettings | ExcelImportSettings
|
||||
settings: CSVImportSettings | ExcelImportSettings | QIFImportSettings
|
||||
mapping: Dict[
|
||||
str,
|
||||
TransactionAccountMapping
|
||||
|
||||
@@ -3,6 +3,8 @@ import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from django.db import transaction
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Dict, Any, Literal, Union
|
||||
@@ -11,6 +13,7 @@ import openpyxl
|
||||
import xlrd
|
||||
import yaml
|
||||
from cachalot.api import cachalot_disabled
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.utils import timezone
|
||||
from openpyxl.utils.exceptions import InvalidFileException
|
||||
|
||||
@@ -363,7 +366,7 @@ class ImportService:
|
||||
try:
|
||||
if entities_mapping:
|
||||
if entities_mapping.type == "id":
|
||||
entity = TransactionTag.objects.filter(
|
||||
entity = TransactionEntity.objects.filter(
|
||||
id=entity_name
|
||||
).first()
|
||||
else: # name
|
||||
@@ -459,11 +462,12 @@ class ImportService:
|
||||
# Build query conditions for each field in the rule
|
||||
for field in rule.fields:
|
||||
if field in transaction_data:
|
||||
if rule.match_type == "strict":
|
||||
query = query.filter(**{field: transaction_data[field]})
|
||||
else: # lax matching
|
||||
query = query.filter(
|
||||
**{f"{field}__iexact": transaction_data[field]}
|
||||
value = transaction_data[field]
|
||||
query = self._apply_deduplication_filter(
|
||||
query=query,
|
||||
field=field,
|
||||
value=value,
|
||||
match_type=rule.match_type,
|
||||
)
|
||||
|
||||
# If we found any matching transaction, it's a duplicate
|
||||
@@ -472,14 +476,95 @@ class ImportService:
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_int_like(value: Any) -> bool:
|
||||
try:
|
||||
int(value)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _apply_deduplication_filter(
|
||||
self,
|
||||
query,
|
||||
field: str,
|
||||
value: Any,
|
||||
match_type: Literal["lax", "strict"],
|
||||
):
|
||||
if isinstance(value, list):
|
||||
return self._apply_list_deduplication_filter(
|
||||
query=query,
|
||||
field=field,
|
||||
values=value,
|
||||
match_type=match_type,
|
||||
)
|
||||
|
||||
# Use __iexact only for string fields; non-string types
|
||||
# (date, Decimal, bool, int, etc.) don't support UPPER()
|
||||
if match_type == "strict" or not isinstance(value, str):
|
||||
return query.filter(**{field: value})
|
||||
|
||||
return query.filter(**{f"{field}__iexact": value})
|
||||
|
||||
def _apply_list_deduplication_filter(
|
||||
self,
|
||||
query,
|
||||
field: str,
|
||||
values: list[Any],
|
||||
match_type: Literal["lax", "strict"],
|
||||
):
|
||||
clean_values = [v for v in values if v not in (None, "")]
|
||||
if not clean_values:
|
||||
return query
|
||||
|
||||
try:
|
||||
model_field = Transaction._meta.get_field(field)
|
||||
except FieldDoesNotExist:
|
||||
return query.filter(**{f"{field}__in": clean_values})
|
||||
|
||||
if getattr(model_field, "many_to_many", False):
|
||||
# For m2m fields (e.g., entities/tags), apply one filter per value so
|
||||
# all provided values must be present in the matched transaction.
|
||||
if all(self._is_int_like(v) for v in clean_values):
|
||||
for value in clean_values:
|
||||
query = query.filter(**{f"{field}__id": int(value)})
|
||||
else:
|
||||
for value in clean_values:
|
||||
lookup = (
|
||||
f"{field}__name"
|
||||
if match_type == "strict"
|
||||
else f"{field}__name__iexact"
|
||||
)
|
||||
query = query.filter(**{lookup: str(value).strip()})
|
||||
|
||||
return query.distinct()
|
||||
|
||||
return query.filter(**{f"{field}__in": clean_values})
|
||||
|
||||
def _coerce_type(
|
||||
self, value: str, mapping: version_1.ColumnMapping
|
||||
) -> Union[str, int, bool, Decimal, datetime, list, None]:
|
||||
coerce_to = mapping.coerce_to
|
||||
|
||||
# Handle detection methods that don't require a source value
|
||||
if coerce_to == "transaction_type" and isinstance(
|
||||
mapping, version_1.TransactionTypeMapping
|
||||
):
|
||||
if mapping.detection_method == "always_income":
|
||||
return Transaction.Type.INCOME
|
||||
elif mapping.detection_method == "always_expense":
|
||||
return Transaction.Type.EXPENSE
|
||||
elif coerce_to == "is_paid" and isinstance(
|
||||
mapping, version_1.TransactionIsPaidMapping
|
||||
):
|
||||
if mapping.detection_method == "always_paid":
|
||||
return True
|
||||
elif mapping.detection_method == "always_unpaid":
|
||||
return False
|
||||
|
||||
if not value:
|
||||
return None
|
||||
|
||||
coerce_to = mapping.coerce_to
|
||||
|
||||
return self._coerce_single_type(value, coerce_to, mapping)
|
||||
|
||||
@staticmethod
|
||||
@@ -828,6 +913,219 @@ class ImportService:
|
||||
f"Invalid {self.settings.file_type.upper()} file format: {str(e)}"
|
||||
)
|
||||
|
||||
def _parse_and_import_qif(self, content_lines: list[str], filename: str) -> None:
|
||||
# Infer account from filename (remove extension)
|
||||
account_name = os.path.splitext(os.path.basename(filename))[0]
|
||||
|
||||
current_transaction = {}
|
||||
raw_lines_buffer = []
|
||||
|
||||
account = Account.objects.filter(name=account_name).first()
|
||||
if not account:
|
||||
raise ValueError(f"Account '{account_name}' not found.")
|
||||
|
||||
row_number = 0
|
||||
for line in content_lines:
|
||||
row_number += 1
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
raw_lines_buffer.append(line)
|
||||
|
||||
if line == "^":
|
||||
if current_transaction:
|
||||
# Deduplication using hash of raw lines
|
||||
raw_content = "".join(raw_lines_buffer)
|
||||
internal_id = hashlib.sha256(
|
||||
raw_content.encode("utf-8")
|
||||
).hexdigest()
|
||||
|
||||
# Reset buffer for next transaction
|
||||
raw_lines_buffer = []
|
||||
|
||||
try:
|
||||
with transaction.atomic():
|
||||
if Transaction.objects.filter(
|
||||
internal_id=internal_id
|
||||
).exists():
|
||||
self._increment_totals("skipped", 1)
|
||||
self._log(
|
||||
"info",
|
||||
f"Skipped duplicate transaction from {filename}",
|
||||
)
|
||||
current_transaction = {}
|
||||
continue
|
||||
|
||||
# Handle Account
|
||||
if account:
|
||||
current_transaction["account"] = account
|
||||
else:
|
||||
acc = Account.objects.filter(name=account_name).first()
|
||||
if acc:
|
||||
current_transaction["account"] = acc
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Account '{account_name}' not found."
|
||||
)
|
||||
|
||||
current_transaction["internal_id"] = internal_id
|
||||
|
||||
# Handle Description/Memo mapping
|
||||
if "memo" in current_transaction:
|
||||
current_transaction["description"] = (
|
||||
current_transaction.pop("memo")
|
||||
)
|
||||
|
||||
# Handle Payee mapping
|
||||
entities = []
|
||||
if "payee" in current_transaction:
|
||||
payee_name = current_transaction.pop("payee")
|
||||
# "Treat the payee (P) as the entity. Use existing or create"
|
||||
entity, _ = TransactionEntity.objects.get_or_create(
|
||||
name=payee_name
|
||||
)
|
||||
entities.append(entity)
|
||||
|
||||
# Handle Label/Category
|
||||
category = None
|
||||
tags = []
|
||||
if "label" in current_transaction:
|
||||
label = current_transaction.pop("label")
|
||||
if label.startswith("[") and label.endswith("]"):
|
||||
# Transfer: set label as description, ignore category/tags
|
||||
clean_label = label[1:-1]
|
||||
current_transaction["description"] = clean_label
|
||||
else:
|
||||
parts = label.split(":")
|
||||
if parts:
|
||||
cat_name = parts[0].strip()
|
||||
if cat_name:
|
||||
category, _ = (
|
||||
TransactionCategory.objects.get_or_create(
|
||||
name=cat_name
|
||||
)
|
||||
)
|
||||
|
||||
if len(parts) > 1:
|
||||
for tag_name in parts[1:]:
|
||||
tag_name = tag_name.strip()
|
||||
if tag_name:
|
||||
tag, _ = (
|
||||
TransactionTag.objects.get_or_create(
|
||||
name=tag_name
|
||||
)
|
||||
)
|
||||
tags.append(tag)
|
||||
|
||||
current_transaction["category"] = category
|
||||
|
||||
# Create transaction
|
||||
new_trans = Transaction.objects.create(
|
||||
**current_transaction
|
||||
)
|
||||
if entities:
|
||||
new_trans.entities.set(entities)
|
||||
if tags:
|
||||
new_trans.tags.set(tags)
|
||||
|
||||
self.import_run.transactions.add(new_trans)
|
||||
self._increment_totals("successful", 1)
|
||||
|
||||
except Exception as e:
|
||||
if not self.settings.skip_errors:
|
||||
raise e
|
||||
self._log(
|
||||
"warning",
|
||||
f"Error processing transaction in {filename}: {str(e)}",
|
||||
)
|
||||
self._increment_totals("failed", 1)
|
||||
|
||||
# Reset for next transaction
|
||||
current_transaction = {}
|
||||
else:
|
||||
# Empty transaction record (orphaned ^)
|
||||
raw_lines_buffer = []
|
||||
pass
|
||||
self._increment_totals("processed", 1)
|
||||
continue
|
||||
|
||||
if line.startswith("!"):
|
||||
continue
|
||||
|
||||
code = line[0]
|
||||
value = line[1:]
|
||||
|
||||
if code == "D":
|
||||
try:
|
||||
current_transaction["date"] = datetime.strptime(
|
||||
value, self.settings.date_format
|
||||
).date()
|
||||
except ValueError:
|
||||
self._log(
|
||||
"warning",
|
||||
f"Could not parse date '{value}' using format '{self.settings.date_format}' in {filename}",
|
||||
)
|
||||
if not self.settings.skip_errors:
|
||||
raise ValueError(f"Invalid date format '{value}'")
|
||||
|
||||
elif code == "T":
|
||||
try:
|
||||
cleaned_value = value.replace(",", "")
|
||||
amount = Decimal(cleaned_value)
|
||||
if amount < 0:
|
||||
current_transaction["type"] = Transaction.Type.EXPENSE
|
||||
current_transaction["amount"] = abs(amount)
|
||||
else:
|
||||
current_transaction["type"] = Transaction.Type.INCOME
|
||||
current_transaction["amount"] = amount
|
||||
except InvalidOperation:
|
||||
self._log(
|
||||
"warning", f"Could not parse amount '{value}' in {filename}"
|
||||
)
|
||||
if not self.settings.skip_errors:
|
||||
raise ValueError(f"Invalid amount format '{value}'")
|
||||
|
||||
elif code == "P":
|
||||
current_transaction["payee"] = value
|
||||
elif code == "M":
|
||||
current_transaction["memo"] = value
|
||||
elif code == "L":
|
||||
current_transaction["label"] = value
|
||||
elif code == "N":
|
||||
pass
|
||||
|
||||
def _process_qif(self, file_path):
|
||||
def process_logic():
|
||||
if zipfile.is_zipfile(file_path):
|
||||
try:
|
||||
with zipfile.ZipFile(file_path, "r") as zf:
|
||||
for filename in zf.namelist():
|
||||
if filename.lower().endswith(
|
||||
".qif"
|
||||
) and not filename.startswith("__MACOSX"):
|
||||
self._log(
|
||||
"info", f"Processing QIF from ZIP: {filename}"
|
||||
)
|
||||
with zf.open(filename) as f:
|
||||
content = f.read().decode(self.settings.encoding)
|
||||
self._parse_and_import_qif(
|
||||
content.splitlines(), filename
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error processing ZIP file: {str(e)}")
|
||||
else:
|
||||
with open(file_path, "r", encoding=self.settings.encoding) as f:
|
||||
self._parse_and_import_qif(
|
||||
f.readlines(), os.path.basename(file_path)
|
||||
)
|
||||
|
||||
if not self.settings.skip_errors:
|
||||
with transaction.atomic():
|
||||
process_logic()
|
||||
else:
|
||||
process_logic()
|
||||
|
||||
def _validate_file_path(self, file_path: str) -> str:
|
||||
"""
|
||||
Validates that the file path is within the allowed temporary directory.
|
||||
@@ -854,6 +1152,8 @@ class ImportService:
|
||||
self._process_csv(file_path)
|
||||
elif isinstance(self.settings, version_1.ExcelImportSettings):
|
||||
self._process_excel(file_path)
|
||||
elif isinstance(self.settings, version_1.QIFImportSettings):
|
||||
self._process_qif(file_path)
|
||||
|
||||
self._update_status("FINISHED")
|
||||
self._log(
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from django.test import TestCase
|
||||
|
||||
# Create your tests here.
|
||||
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Tests for ImportService v1, specifically for deduplication logic.
|
||||
|
||||
These tests verify that the _check_duplicate_transaction method handles
|
||||
different field types correctly, particularly ensuring that __iexact
|
||||
is only used for string fields (not dates, decimals, etc.).
|
||||
"""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
from apps.import_app.services.v1 import ImportService
|
||||
from apps.transactions.models import Transaction, TransactionEntity
|
||||
|
||||
|
||||
class DeduplicationTests(TestCase):
|
||||
"""Tests for transaction deduplication during import."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Test Account", group=self.account_group, currency=self.currency
|
||||
)
|
||||
|
||||
# Create an existing transaction for deduplication tests
|
||||
self.existing_transaction = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=date(2024, 1, 15),
|
||||
amount=Decimal("100.00"),
|
||||
description="Existing Transaction",
|
||||
internal_id="ABC123",
|
||||
)
|
||||
|
||||
def _create_import_service_with_deduplication(
|
||||
self, fields: list[str], match_type: str = "lax"
|
||||
) -> ImportService:
|
||||
"""Helper to create an ImportService with specific deduplication rules."""
|
||||
yaml_config = f"""
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
trigger_transaction_rules: false
|
||||
mapping:
|
||||
date_field:
|
||||
source: date
|
||||
target: date
|
||||
format: "%Y-%m-%d"
|
||||
amount_field:
|
||||
source: amount
|
||||
target: amount
|
||||
description_field:
|
||||
source: description
|
||||
target: description
|
||||
account_field:
|
||||
source: account
|
||||
target: account
|
||||
type: id
|
||||
deduplication:
|
||||
- type: compare
|
||||
fields: {fields}
|
||||
match_type: {match_type}
|
||||
"""
|
||||
profile = ImportProfile.objects.create(
|
||||
name=f"Test Profile {match_type} {'_'.join(fields)}",
|
||||
yaml_config=yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
)
|
||||
import_run = ImportRun.objects.create(
|
||||
profile=profile,
|
||||
file_name="test.csv",
|
||||
)
|
||||
return ImportService(import_run)
|
||||
|
||||
def test_deduplication_with_date_field_strict_match(self):
|
||||
"""Test that date fields work with strict matching."""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date"], match_type="strict"
|
||||
)
|
||||
|
||||
# Should find duplicate when date matches
|
||||
is_duplicate = service._check_duplicate_transaction({"date": date(2024, 1, 15)})
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when date differs
|
||||
is_duplicate = service._check_duplicate_transaction({"date": date(2024, 2, 20)})
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_date_field_lax_match(self):
|
||||
"""
|
||||
Test that date fields use strict matching even when match_type is 'lax'.
|
||||
|
||||
This is the fix for the UPPER(date) PostgreSQL error. Date fields
|
||||
cannot use __iexact, so they should fall back to strict matching.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate when date matches (using strict comparison)
|
||||
is_duplicate = service._check_duplicate_transaction({"date": date(2024, 1, 15)})
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when date differs
|
||||
is_duplicate = service._check_duplicate_transaction({"date": date(2024, 2, 20)})
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_amount_field_lax_match(self):
|
||||
"""
|
||||
Test that Decimal fields use strict matching even when match_type is 'lax'.
|
||||
|
||||
Decimal fields cannot use __iexact, so they should fall back to strict matching.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["amount"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate when amount matches
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"amount": Decimal("100.00")}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when amount differs
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"amount": Decimal("200.00")}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_string_field_lax_match(self):
|
||||
"""
|
||||
Test that string fields use case-insensitive matching with match_type 'lax'.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["description"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate with case-insensitive match
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "EXISTING TRANSACTION"}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should find duplicate with exact case match
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "Existing Transaction"}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when description differs
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "Different Transaction"}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_string_field_strict_match(self):
|
||||
"""
|
||||
Test that string fields use case-sensitive matching with match_type 'strict'.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["description"], match_type="strict"
|
||||
)
|
||||
|
||||
# Should NOT find duplicate with different case (strict matching)
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "EXISTING TRANSACTION"}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
# Should find duplicate with exact case match
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "Existing Transaction"}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
def test_deduplication_with_multiple_fields_mixed_types(self):
|
||||
"""
|
||||
Test deduplication with multiple fields of different types.
|
||||
|
||||
Verifies that string fields use __iexact while non-string fields
|
||||
use strict matching, all in the same deduplication rule.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date", "amount", "description"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate when all fields match (with case-insensitive description)
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
"amount": Decimal("100.00"),
|
||||
"description": "existing transaction", # lowercase should match
|
||||
}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should NOT find duplicate when date differs
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 2, 20),
|
||||
"amount": Decimal("100.00"),
|
||||
"description": "existing transaction",
|
||||
}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
# Should NOT find duplicate when amount differs
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
"amount": Decimal("999.99"),
|
||||
"description": "existing transaction",
|
||||
}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_internal_id_lax_match(self):
|
||||
"""Test deduplication with internal_id field using lax matching."""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["internal_id"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate with case-insensitive match
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"internal_id": "abc123"} # lowercase should match ABC123
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should find duplicate with exact match
|
||||
is_duplicate = service._check_duplicate_transaction({"internal_id": "ABC123"})
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when internal_id differs
|
||||
is_duplicate = service._check_duplicate_transaction({"internal_id": "XYZ789"})
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_no_duplicate_when_no_transactions_exist(self):
|
||||
"""Test that no duplicate is found when there are no matching transactions."""
|
||||
# Hard delete to bypass signals that require user context
|
||||
self.existing_transaction.hard_delete()
|
||||
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date", "amount"], match_type="lax"
|
||||
)
|
||||
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
"amount": Decimal("100.00"),
|
||||
}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_missing_field_in_data(self):
|
||||
"""Test that missing fields in transaction_data are handled gracefully."""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date", "nonexistent_field"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should still work, only checking the fields that exist
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
def test_deduplication_with_entities_list_value(self):
|
||||
"""Test that list values for m2m entities deduplicate correctly."""
|
||||
entity = TransactionEntity.objects.create(name="DB Vertrieb GmbH")
|
||||
self.existing_transaction.entities.add(entity)
|
||||
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date", "amount", "entities"], match_type="strict"
|
||||
)
|
||||
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
"amount": Decimal("100.00"),
|
||||
"entities": ["DB Vertrieb GmbH"],
|
||||
}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
def test_deduplication_with_entities_list_value_not_matching(self):
|
||||
"""Test that non-matching entity list values are not marked duplicate."""
|
||||
entity = TransactionEntity.objects.create(name="DB Vertrieb GmbH")
|
||||
self.existing_transaction.entities.add(entity)
|
||||
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date", "amount", "entities"], match_type="strict"
|
||||
)
|
||||
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
"amount": Decimal("100.00"),
|
||||
"entities": ["Different Entity"],
|
||||
}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
@@ -0,0 +1,259 @@
|
||||
from decimal import Decimal
|
||||
import os
|
||||
import shutil
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth import get_user_model
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.common.middleware.thread_local import write_current_user, delete_current_user
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
from apps.import_app.services.v1 import ImportService
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
)
|
||||
|
||||
|
||||
class QIFImportTests(TestCase):
|
||||
def setUp(self):
|
||||
# Patch TEMP_DIR for testing
|
||||
self.original_temp_dir = ImportService.TEMP_DIR
|
||||
self.test_dir = os.path.abspath("temp_test_import")
|
||||
ImportService.TEMP_DIR = self.test_dir
|
||||
os.makedirs(self.test_dir, exist_ok=True)
|
||||
|
||||
# Create user and set context
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="test@example.com", password="password"
|
||||
)
|
||||
write_current_user(self.user)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="BRL", name="Real", decimal_places=2, prefix="R$ "
|
||||
)
|
||||
self.group = AccountGroup.objects.create(name="Test Group", owner=self.user)
|
||||
self.account = Account.objects.create(
|
||||
name="bradesco-checking",
|
||||
group=self.group,
|
||||
currency=self.currency,
|
||||
owner=self.user,
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
delete_current_user()
|
||||
ImportService.TEMP_DIR = self.original_temp_dir
|
||||
if os.path.exists(self.test_dir):
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
def test_import_single_qif_valid_mapping(self):
|
||||
content = """!Type:Bank
|
||||
D04/01/2015
|
||||
T8069.46
|
||||
PMy Payee -> Entity
|
||||
MNote -> Desc
|
||||
LOld Cat:New Tag
|
||||
^
|
||||
D05/01/2015
|
||||
T-100.00
|
||||
PSupermarket
|
||||
MWeekly shopping
|
||||
L[Transfer]
|
||||
^
|
||||
"""
|
||||
filename = "bradesco-checking.qif"
|
||||
file_path = os.path.join(self.test_dir, filename)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
yaml_config = """
|
||||
settings:
|
||||
file_type: qif
|
||||
importing: transactions
|
||||
date_format: "%d/%m/%Y"
|
||||
mapping: {}
|
||||
"""
|
||||
profile = ImportProfile.objects.create(
|
||||
name="QIF Profile",
|
||||
yaml_config=yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
)
|
||||
run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||
service = ImportService(run)
|
||||
|
||||
service.process_file(file_path)
|
||||
|
||||
self.assertEqual(Transaction.objects.count(), 2)
|
||||
|
||||
# Transaction 1: Income, Category+Tag
|
||||
t1 = Transaction.objects.get(description="Note -> Desc")
|
||||
self.assertEqual(t1.amount, Decimal("8069.46"))
|
||||
self.assertEqual(t1.type, Transaction.Type.INCOME)
|
||||
self.assertEqual(t1.category.name, "Old Cat")
|
||||
self.assertTrue(t1.tags.filter(name="New Tag").exists())
|
||||
self.assertTrue(t1.entities.filter(name="My Payee -> Entity").exists())
|
||||
self.assertEqual(t1.account, self.account)
|
||||
|
||||
# Transaction 2: Expense, Transfer ([Transfer] -> Description)
|
||||
t2 = Transaction.objects.get(description="Transfer")
|
||||
self.assertEqual(t2.amount, Decimal("100.00"))
|
||||
self.assertEqual(t2.type, Transaction.Type.EXPENSE)
|
||||
self.assertIsNone(t2.category)
|
||||
self.assertFalse(t2.tags.exists())
|
||||
self.assertTrue(t2.entities.filter(name="Supermarket").exists())
|
||||
self.assertEqual(t2.description, "Transfer")
|
||||
|
||||
def test_import_deduplication_hash(self):
|
||||
# Same content twice. Should result in only 1 transaction due to hash deduplication.
|
||||
content = """!Type:Bank
|
||||
D04/01/2015
|
||||
T100.00
|
||||
POK
|
||||
^
|
||||
"""
|
||||
filename = "bradesco-checking.qif"
|
||||
file_path = os.path.join(self.test_dir, filename)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
yaml_config = """
|
||||
settings:
|
||||
file_type: qif
|
||||
importing: transactions
|
||||
date_format: "%d/%m/%Y"
|
||||
mapping: {}
|
||||
"""
|
||||
profile = ImportProfile.objects.create(
|
||||
name="QIF Profile",
|
||||
yaml_config=yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
)
|
||||
run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||
service = ImportService(run)
|
||||
|
||||
# First run
|
||||
service.process_file(file_path)
|
||||
self.assertEqual(Transaction.objects.count(), 1)
|
||||
|
||||
# Service deletes file after processing, so recreate it for second run
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
# Second run - Duplicate content
|
||||
service.process_file(file_path)
|
||||
self.assertEqual(Transaction.objects.count(), 1)
|
||||
|
||||
def test_import_strict_error_rollback(self):
|
||||
# atomic check.
|
||||
# Transaction 1 valid, Transaction 2 invalid date.
|
||||
content = """!Type:Bank
|
||||
D04/01/2015
|
||||
T100.00
|
||||
POK
|
||||
^
|
||||
DINVALID
|
||||
T100.00
|
||||
PBad
|
||||
^
|
||||
"""
|
||||
filename = "bradesco-checking.qif"
|
||||
file_path = os.path.join(self.test_dir, filename)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
yaml_config = """
|
||||
settings:
|
||||
file_type: qif
|
||||
importing: transactions
|
||||
date_format: "%d/%m/%Y"
|
||||
skip_errors: false
|
||||
mapping: {}
|
||||
"""
|
||||
profile = ImportProfile.objects.create(
|
||||
name="QIF Profile",
|
||||
yaml_config=yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
)
|
||||
run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||
service = ImportService(run)
|
||||
|
||||
with self.assertRaises(Exception) as cm:
|
||||
service.process_file(file_path)
|
||||
self.assertEqual(str(cm.exception), "Import failed")
|
||||
|
||||
# Should be 0 transactions because of atomic rollback
|
||||
self.assertEqual(Transaction.objects.count(), 0)
|
||||
|
||||
def test_import_missing_account(self):
|
||||
# File with account name that doesn't exist
|
||||
content = """!Type:Bank
|
||||
D04/01/2015
|
||||
T100.00
|
||||
POK
|
||||
^
|
||||
"""
|
||||
filename = "missing-account.qif"
|
||||
file_path = os.path.join(self.test_dir, filename)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
yaml_config = """
|
||||
settings:
|
||||
file_type: qif
|
||||
importing: transactions
|
||||
date_format: "%d/%m/%Y"
|
||||
mapping: {}
|
||||
"""
|
||||
profile = ImportProfile.objects.create(
|
||||
name="QIF Profile",
|
||||
yaml_config=yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
)
|
||||
run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||
service = ImportService(run)
|
||||
|
||||
# Should fail because account doesn't exist
|
||||
with self.assertRaises(Exception) as cm:
|
||||
service.process_file(file_path)
|
||||
self.assertEqual(str(cm.exception), "Import failed")
|
||||
|
||||
def test_import_skip_errors(self):
|
||||
# skip_errors: true.
|
||||
# Transaction 1 valid, Transaction 2 invalid date.
|
||||
content = """!Type:Bank
|
||||
D04/01/2015
|
||||
T100.00
|
||||
POK
|
||||
^
|
||||
DINVALID
|
||||
T100.00
|
||||
PBad
|
||||
^
|
||||
"""
|
||||
filename = "bradesco-checking.qif"
|
||||
file_path = os.path.join(self.test_dir, filename)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
yaml_config = """
|
||||
settings:
|
||||
file_type: qif
|
||||
importing: transactions
|
||||
date_format: "%d/%m/%Y"
|
||||
skip_errors: true
|
||||
mapping: {}
|
||||
"""
|
||||
profile = ImportProfile.objects.create(
|
||||
name="QIF Profile",
|
||||
yaml_config=yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
)
|
||||
run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||
service = ImportService(run)
|
||||
|
||||
service.process_file(file_path)
|
||||
|
||||
# Should be 1 transaction (valid one)
|
||||
self.assertEqual(Transaction.objects.count(), 1)
|
||||
self.assertEqual(
|
||||
Transaction.objects.first().description, ""
|
||||
) # empty desc if no memo
|
||||
+12
-13
@@ -1,15 +1,14 @@
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Row, Column
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.datepicker import (
|
||||
AirDatePickerInput,
|
||||
AirMonthYearPickerInput,
|
||||
AirYearPickerInput,
|
||||
AirDatePickerInput,
|
||||
)
|
||||
from apps.transactions.models import TransactionCategory
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.transactions.models import TransactionCategory
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Column, Field, Layout, Row
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class SingleMonthForm(forms.Form):
|
||||
@@ -59,8 +58,8 @@ class MonthRangeForm(forms.Form):
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("month_from", css_class="form-group col-md-6"),
|
||||
Column("month_to", css_class="form-group col-md-6"),
|
||||
Column("month_from"),
|
||||
Column("month_to"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -82,8 +81,8 @@ class YearRangeForm(forms.Form):
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("year_from", css_class="form-group col-md-6"),
|
||||
Column("year_to", css_class="form-group col-md-6"),
|
||||
Column("year_from"),
|
||||
Column("year_to"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -105,8 +104,8 @@ class DateRangeForm(forms.Form):
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("date_from", css_class="form-group col-md-6"),
|
||||
Column("date_to", css_class="form-group col-md-6"),
|
||||
Column("date_from"),
|
||||
Column("date_to"),
|
||||
css_class="mb-0",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -49,4 +49,14 @@ urlpatterns = [
|
||||
views.emergency_fund,
|
||||
name="insights_emergency_fund",
|
||||
),
|
||||
path(
|
||||
"insights/year-by-year/",
|
||||
views.year_by_year,
|
||||
name="insights_year_by_year",
|
||||
),
|
||||
path(
|
||||
"insights/month-by-month/",
|
||||
views.month_by_month,
|
||||
name="insights_month_by_month",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import Sum, Case, When, Value
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.currencies.models import Currency
|
||||
from apps.currencies.utils.convert import convert
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
def get_month_by_month_data(year=None, group_by="categories"):
|
||||
"""
|
||||
Aggregate transaction totals by month for a specific year, grouped by categories, tags, or entities.
|
||||
|
||||
Args:
|
||||
year: The year to filter transactions (defaults to current year)
|
||||
group_by: One of "categories", "tags", or "entities"
|
||||
|
||||
Returns:
|
||||
{
|
||||
"year": 2025,
|
||||
"available_years": [2025, 2024, ...],
|
||||
"months": [1, 2, 3, ..., 12],
|
||||
"items": {
|
||||
item_id: {
|
||||
"name": "Item Name",
|
||||
"month_totals": {
|
||||
1: {"currencies": {...}},
|
||||
...
|
||||
},
|
||||
"total": {"currencies": {...}}
|
||||
},
|
||||
...
|
||||
},
|
||||
"month_totals": {...},
|
||||
"grand_total": {"currencies": {...}}
|
||||
}
|
||||
"""
|
||||
if year is None:
|
||||
year = timezone.localdate(timezone.now()).year
|
||||
|
||||
# Base queryset - all paid transactions, non-muted
|
||||
transactions = Transaction.objects.filter(
|
||||
is_paid=True,
|
||||
account__is_archived=False,
|
||||
).exclude(account__currency__is_archived=True)
|
||||
|
||||
# Get available years for the selector
|
||||
available_years = list(
|
||||
transactions.values_list("reference_date__year", flat=True)
|
||||
.distinct()
|
||||
.order_by("-reference_date__year")
|
||||
)
|
||||
|
||||
# Filter by the selected year
|
||||
transactions = transactions.filter(reference_date__year=year)
|
||||
|
||||
# Define grouping fields based on group_by parameter
|
||||
if group_by == "tags":
|
||||
group_field = "tags"
|
||||
name_field = "tags__name"
|
||||
elif group_by == "entities":
|
||||
group_field = "entities"
|
||||
name_field = "entities__name"
|
||||
else: # Default to categories
|
||||
group_field = "category"
|
||||
name_field = "category__name"
|
||||
|
||||
# Months 1-12
|
||||
months = list(range(1, 13))
|
||||
|
||||
if not available_years:
|
||||
return {
|
||||
"year": year,
|
||||
"available_years": [],
|
||||
"months": months,
|
||||
"items": {},
|
||||
"month_totals": {},
|
||||
"grand_total": {"currencies": {}},
|
||||
}
|
||||
|
||||
# Aggregate by group, month, and currency
|
||||
metrics = (
|
||||
transactions.values(
|
||||
group_field,
|
||||
name_field,
|
||||
"reference_date__month",
|
||||
"account__currency",
|
||||
"account__currency__code",
|
||||
"account__currency__name",
|
||||
"account__currency__decimal_places",
|
||||
"account__currency__prefix",
|
||||
"account__currency__suffix",
|
||||
"account__currency__exchange_currency",
|
||||
)
|
||||
.annotate(
|
||||
expense_total=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(type=Transaction.Type.EXPENSE, then="amount"),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
income_total=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(type=Transaction.Type.INCOME, then="amount"),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
)
|
||||
.order_by(name_field, "reference_date__month")
|
||||
)
|
||||
|
||||
# Build result structure
|
||||
result = {
|
||||
"year": year,
|
||||
"available_years": available_years,
|
||||
"months": months,
|
||||
"items": OrderedDict(),
|
||||
"month_totals": {},
|
||||
"grand_total": {"currencies": {}},
|
||||
}
|
||||
|
||||
# Store currency info for later use in totals
|
||||
currency_info = {}
|
||||
|
||||
for metric in metrics:
|
||||
item_id = metric[group_field]
|
||||
item_name = metric[name_field]
|
||||
month = metric["reference_date__month"]
|
||||
currency_id = metric["account__currency"]
|
||||
|
||||
# Use a consistent key for None (uncategorized/untagged/no entity)
|
||||
item_key = item_id if item_id is not None else "__none__"
|
||||
|
||||
if item_key not in result["items"]:
|
||||
result["items"][item_key] = {
|
||||
"name": item_name,
|
||||
"month_totals": {},
|
||||
"total": {"currencies": {}},
|
||||
}
|
||||
|
||||
if month not in result["items"][item_key]["month_totals"]:
|
||||
result["items"][item_key]["month_totals"][month] = {"currencies": {}}
|
||||
|
||||
# Calculate final total (income - expense)
|
||||
final_total = metric["income_total"] - metric["expense_total"]
|
||||
|
||||
# Store currency info for totals calculation
|
||||
if currency_id not in currency_info:
|
||||
currency_info[currency_id] = {
|
||||
"code": metric["account__currency__code"],
|
||||
"name": metric["account__currency__name"],
|
||||
"decimal_places": metric["account__currency__decimal_places"],
|
||||
"prefix": metric["account__currency__prefix"],
|
||||
"suffix": metric["account__currency__suffix"],
|
||||
"exchange_currency_id": metric["account__currency__exchange_currency"],
|
||||
}
|
||||
|
||||
currency_data = {
|
||||
"currency": {
|
||||
"code": metric["account__currency__code"],
|
||||
"name": metric["account__currency__name"],
|
||||
"decimal_places": metric["account__currency__decimal_places"],
|
||||
"prefix": metric["account__currency__prefix"],
|
||||
"suffix": metric["account__currency__suffix"],
|
||||
},
|
||||
"final_total": final_total,
|
||||
"income_total": metric["income_total"],
|
||||
"expense_total": metric["expense_total"],
|
||||
}
|
||||
|
||||
# Handle currency conversion if exchange currency is set
|
||||
if metric["account__currency__exchange_currency"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=metric["account__currency__exchange_currency"]
|
||||
)
|
||||
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=final_total,
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
|
||||
if converted_amount is not None:
|
||||
currency_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
result["items"][item_key]["month_totals"][month]["currencies"][currency_id] = (
|
||||
currency_data
|
||||
)
|
||||
|
||||
# Accumulate item total (across all months for this item)
|
||||
if currency_id not in result["items"][item_key]["total"]["currencies"]:
|
||||
result["items"][item_key]["total"]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["items"][item_key]["total"]["currencies"][currency_id][
|
||||
"final_total"
|
||||
] += final_total
|
||||
|
||||
# Accumulate month total (across all items for this month)
|
||||
if month not in result["month_totals"]:
|
||||
result["month_totals"][month] = {"currencies": {}}
|
||||
if currency_id not in result["month_totals"][month]["currencies"]:
|
||||
result["month_totals"][month]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["month_totals"][month]["currencies"][currency_id]["final_total"] += (
|
||||
final_total
|
||||
)
|
||||
|
||||
# Accumulate grand total
|
||||
if currency_id not in result["grand_total"]["currencies"]:
|
||||
result["grand_total"]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["grand_total"]["currencies"][currency_id]["final_total"] += final_total
|
||||
|
||||
# Add currency conversion for item totals
|
||||
for item_key, item_data in result["items"].items():
|
||||
for currency_id, total_data in item_data["total"]["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
# Add currency conversion for month totals
|
||||
for month, month_data in result["month_totals"].items():
|
||||
for currency_id, total_data in month_data["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
# Add currency conversion for grand total
|
||||
for currency_id, total_data in result["grand_total"]["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,303 @@
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import Sum, Case, When, Value
|
||||
from django.db.models.functions import Coalesce
|
||||
|
||||
from apps.currencies.models import Currency
|
||||
from apps.currencies.utils.convert import convert
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
def get_year_by_year_data(group_by="categories"):
|
||||
"""
|
||||
Aggregate transaction totals by year for categories, tags, or entities.
|
||||
|
||||
Args:
|
||||
group_by: One of "categories", "tags", or "entities"
|
||||
|
||||
Returns:
|
||||
{
|
||||
"years": [2025, 2024, ...], # Sorted descending
|
||||
"items": {
|
||||
item_id: {
|
||||
"name": "Item Name",
|
||||
"year_totals": {
|
||||
2025: {"currencies": {...}},
|
||||
...
|
||||
},
|
||||
"total": {"currencies": {...}} # Sum across all years
|
||||
},
|
||||
...
|
||||
},
|
||||
"year_totals": { # Sum across all items for each year
|
||||
2025: {"currencies": {...}},
|
||||
...
|
||||
},
|
||||
"grand_total": {"currencies": {...}} # Sum of everything
|
||||
}
|
||||
"""
|
||||
# Base queryset - all paid transactions, non-muted
|
||||
transactions = Transaction.objects.filter(
|
||||
is_paid=True,
|
||||
account__is_archived=False,
|
||||
).exclude(account__currency__is_archived=True)
|
||||
|
||||
# Define grouping fields based on group_by parameter
|
||||
if group_by == "tags":
|
||||
group_field = "tags"
|
||||
name_field = "tags__name"
|
||||
elif group_by == "entities":
|
||||
group_field = "entities"
|
||||
name_field = "entities__name"
|
||||
else: # Default to categories
|
||||
group_field = "category"
|
||||
name_field = "category__name"
|
||||
|
||||
# Get all unique years with transactions
|
||||
years = (
|
||||
transactions.values_list("reference_date__year", flat=True)
|
||||
.distinct()
|
||||
.order_by("-reference_date__year")
|
||||
)
|
||||
years = list(years)
|
||||
|
||||
if not years:
|
||||
return {
|
||||
"years": [],
|
||||
"items": {},
|
||||
"year_totals": {},
|
||||
"grand_total": {"currencies": {}},
|
||||
}
|
||||
|
||||
# Aggregate by group, year, and currency
|
||||
metrics = (
|
||||
transactions.values(
|
||||
group_field,
|
||||
name_field,
|
||||
"reference_date__year",
|
||||
"account__currency",
|
||||
"account__currency__code",
|
||||
"account__currency__name",
|
||||
"account__currency__decimal_places",
|
||||
"account__currency__prefix",
|
||||
"account__currency__suffix",
|
||||
"account__currency__exchange_currency",
|
||||
)
|
||||
.annotate(
|
||||
expense_total=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(type=Transaction.Type.EXPENSE, then="amount"),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
income_total=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(type=Transaction.Type.INCOME, then="amount"),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
)
|
||||
.order_by(name_field, "-reference_date__year")
|
||||
)
|
||||
|
||||
# Build result structure
|
||||
result = {
|
||||
"years": years,
|
||||
"items": OrderedDict(),
|
||||
"year_totals": {}, # Totals per year across all items
|
||||
"grand_total": {"currencies": {}}, # Grand total across everything
|
||||
}
|
||||
|
||||
# Store currency info for later use in totals
|
||||
currency_info = {}
|
||||
|
||||
for metric in metrics:
|
||||
item_id = metric[group_field]
|
||||
item_name = metric[name_field]
|
||||
year = metric["reference_date__year"]
|
||||
currency_id = metric["account__currency"]
|
||||
|
||||
# Use a consistent key for None (uncategorized/untagged/no entity)
|
||||
item_key = item_id if item_id is not None else "__none__"
|
||||
|
||||
if item_key not in result["items"]:
|
||||
result["items"][item_key] = {
|
||||
"name": item_name,
|
||||
"year_totals": {},
|
||||
"total": {"currencies": {}}, # Total for this item across all years
|
||||
}
|
||||
|
||||
if year not in result["items"][item_key]["year_totals"]:
|
||||
result["items"][item_key]["year_totals"][year] = {"currencies": {}}
|
||||
|
||||
# Calculate final total (income - expense)
|
||||
final_total = metric["income_total"] - metric["expense_total"]
|
||||
|
||||
# Store currency info for totals calculation
|
||||
if currency_id not in currency_info:
|
||||
currency_info[currency_id] = {
|
||||
"code": metric["account__currency__code"],
|
||||
"name": metric["account__currency__name"],
|
||||
"decimal_places": metric["account__currency__decimal_places"],
|
||||
"prefix": metric["account__currency__prefix"],
|
||||
"suffix": metric["account__currency__suffix"],
|
||||
"exchange_currency_id": metric["account__currency__exchange_currency"],
|
||||
}
|
||||
|
||||
currency_data = {
|
||||
"currency": {
|
||||
"code": metric["account__currency__code"],
|
||||
"name": metric["account__currency__name"],
|
||||
"decimal_places": metric["account__currency__decimal_places"],
|
||||
"prefix": metric["account__currency__prefix"],
|
||||
"suffix": metric["account__currency__suffix"],
|
||||
},
|
||||
"final_total": final_total,
|
||||
"income_total": metric["income_total"],
|
||||
"expense_total": metric["expense_total"],
|
||||
}
|
||||
|
||||
# Handle currency conversion if exchange currency is set
|
||||
if metric["account__currency__exchange_currency"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=metric["account__currency__exchange_currency"]
|
||||
)
|
||||
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=final_total,
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
|
||||
if converted_amount is not None:
|
||||
currency_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
result["items"][item_key]["year_totals"][year]["currencies"][currency_id] = (
|
||||
currency_data
|
||||
)
|
||||
|
||||
# Accumulate item total (across all years for this item)
|
||||
if currency_id not in result["items"][item_key]["total"]["currencies"]:
|
||||
result["items"][item_key]["total"]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["items"][item_key]["total"]["currencies"][currency_id][
|
||||
"final_total"
|
||||
] += final_total
|
||||
|
||||
# Accumulate year total (across all items for this year)
|
||||
if year not in result["year_totals"]:
|
||||
result["year_totals"][year] = {"currencies": {}}
|
||||
if currency_id not in result["year_totals"][year]["currencies"]:
|
||||
result["year_totals"][year]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["year_totals"][year]["currencies"][currency_id]["final_total"] += (
|
||||
final_total
|
||||
)
|
||||
|
||||
# Accumulate grand total
|
||||
if currency_id not in result["grand_total"]["currencies"]:
|
||||
result["grand_total"]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["grand_total"]["currencies"][currency_id]["final_total"] += final_total
|
||||
|
||||
# Add currency conversion for item totals
|
||||
for item_key, item_data in result["items"].items():
|
||||
for currency_id, total_data in item_data["total"]["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
# Add currency conversion for year totals
|
||||
for year, year_data in result["year_totals"].items():
|
||||
for currency_id, total_data in year_data["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
# Add currency conversion for grand total
|
||||
for currency_id, total_data in result["grand_total"]["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
@@ -26,6 +26,8 @@ from apps.insights.utils.sankey import (
|
||||
generate_sankey_data_by_currency,
|
||||
)
|
||||
from apps.insights.utils.transactions import get_transactions
|
||||
from apps.insights.utils.year_by_year import get_year_by_year_data
|
||||
from apps.insights.utils.month_by_month import get_month_by_month_data
|
||||
from apps.transactions.models import TransactionCategory, Transaction
|
||||
from apps.transactions.utils.calculations import calculate_currency_totals
|
||||
|
||||
@@ -74,7 +76,9 @@ def index(request):
|
||||
def sankey_by_account(request):
|
||||
# Get filtered transactions
|
||||
|
||||
transactions = get_transactions(request, include_untracked_accounts=True)
|
||||
transactions = get_transactions(
|
||||
request, include_untracked_accounts=True, include_silent=True
|
||||
)
|
||||
|
||||
# Generate Sankey data
|
||||
sankey_data = generate_sankey_data_by_account(transactions)
|
||||
@@ -91,7 +95,9 @@ def sankey_by_account(request):
|
||||
@require_http_methods(["GET"])
|
||||
def sankey_by_currency(request):
|
||||
# Get filtered transactions
|
||||
transactions = get_transactions(request)
|
||||
transactions = get_transactions(
|
||||
request, include_silent=True, include_untracked_accounts=True
|
||||
)
|
||||
|
||||
# Generate Sankey data
|
||||
sankey_data = generate_sankey_data_by_currency(transactions)
|
||||
@@ -302,3 +308,71 @@ def emergency_fund(request):
|
||||
"insights/fragments/emergency_fund.html",
|
||||
{"data": currency_net_worth},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def year_by_year(request):
|
||||
if "group_by" in request.GET:
|
||||
group_by = request.GET["group_by"]
|
||||
request.session["insights_year_by_year_group_by"] = group_by
|
||||
else:
|
||||
group_by = request.session.get("insights_year_by_year_group_by", "categories")
|
||||
|
||||
# Validate group_by value
|
||||
if group_by not in ("categories", "tags", "entities"):
|
||||
group_by = "categories"
|
||||
|
||||
data = get_year_by_year_data(group_by=group_by)
|
||||
|
||||
return render(
|
||||
request,
|
||||
"insights/fragments/year_by_year.html",
|
||||
{
|
||||
"data": data,
|
||||
"group_by": group_by,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def month_by_month(request):
|
||||
# Handle year selection
|
||||
if "year" in request.GET:
|
||||
try:
|
||||
year = int(request.GET["year"])
|
||||
request.session["insights_month_by_month_year"] = year
|
||||
except (ValueError, TypeError):
|
||||
year = request.session.get(
|
||||
"insights_month_by_month_year", timezone.localdate(timezone.now()).year
|
||||
)
|
||||
else:
|
||||
year = request.session.get(
|
||||
"insights_month_by_month_year", timezone.localdate(timezone.now()).year
|
||||
)
|
||||
|
||||
# Handle group_by selection
|
||||
if "group_by" in request.GET:
|
||||
group_by = request.GET["group_by"]
|
||||
request.session["insights_month_by_month_group_by"] = group_by
|
||||
else:
|
||||
group_by = request.session.get("insights_month_by_month_group_by", "categories")
|
||||
|
||||
# Validate group_by value
|
||||
if group_by not in ("categories", "tags", "entities"):
|
||||
group_by = "categories"
|
||||
|
||||
data = get_month_by_month_data(year=year, group_by=group_by)
|
||||
|
||||
return render(
|
||||
request,
|
||||
"insights/fragments/month_by_month.html",
|
||||
{
|
||||
"data": data,
|
||||
"group_by": group_by,
|
||||
"selected_year": year,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,331 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class MonthlySummaryFilterBehaviorTests(TestCase):
|
||||
"""Tests for monthly summary views filter behavior.
|
||||
|
||||
These tests verify that:
|
||||
1. Views work correctly without any filters
|
||||
2. Views work correctly with filters applied
|
||||
3. The filter detection logic properly uses different querysets
|
||||
4. Calculated values reflect the applied filters
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client.login(username="testuser@test.com", password="testpass123")
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Test Account",
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
is_asset=False,
|
||||
)
|
||||
self.category = TransactionCategory.objects.create(
|
||||
name="Test Category", owner=self.user
|
||||
)
|
||||
self.tag = TransactionTag.objects.create(name="TestTag", owner=self.user)
|
||||
|
||||
# Create test transactions for December 2025
|
||||
# Income: 1000 (paid)
|
||||
self.income_transaction = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.INCOME,
|
||||
is_paid=True,
|
||||
date=date(2025, 12, 10),
|
||||
reference_date=date(2025, 12, 1),
|
||||
amount=Decimal("1000.00"),
|
||||
description="December Income",
|
||||
owner=self.user,
|
||||
)
|
||||
|
||||
# Expense: 200 (paid)
|
||||
self.expense_transaction = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
is_paid=True,
|
||||
date=date(2025, 12, 15),
|
||||
reference_date=date(2025, 12, 1),
|
||||
amount=Decimal("200.00"),
|
||||
description="December Expense",
|
||||
category=self.category,
|
||||
owner=self.user,
|
||||
)
|
||||
self.expense_transaction.tags.add(self.tag)
|
||||
|
||||
# Expense: 150 (projected/unpaid)
|
||||
self.projected_expense = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
is_paid=False,
|
||||
date=date(2025, 12, 20),
|
||||
reference_date=date(2025, 12, 1),
|
||||
amount=Decimal("150.00"),
|
||||
description="Projected Expense",
|
||||
owner=self.user,
|
||||
)
|
||||
|
||||
def _get_currency_data(self, context_dict):
|
||||
"""Helper to extract data for our test currency from context dict.
|
||||
|
||||
The context dict is keyed by currency ID, so we need to find
|
||||
the entry for our currency.
|
||||
"""
|
||||
if not context_dict:
|
||||
return None
|
||||
for currency_id, data in context_dict.items():
|
||||
if data.get("currency", {}).get("code") == "USD":
|
||||
return data
|
||||
return None
|
||||
|
||||
# --- monthly_summary view tests ---
|
||||
|
||||
def test_monthly_summary_no_filter_returns_200(self):
|
||||
"""Test that monthly_summary returns 200 without filters"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_monthly_summary_no_filter_includes_all_transactions(self):
|
||||
"""Without filters, summary should include all transactions"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should have the income: 1000
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["income_current"], Decimal("1000.00"))
|
||||
|
||||
# expense_current should have paid expense: 200
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_current"], Decimal("200.00"))
|
||||
|
||||
# expense_projected should have unpaid expense: 150
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_projected"], Decimal("150.00"))
|
||||
|
||||
def test_monthly_summary_type_filter_only_income(self):
|
||||
"""With type=IN filter, summary should only include income"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?type=IN",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should still have 1000
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["income_current"], Decimal("1000.00"))
|
||||
|
||||
# expense_current should be empty/zero (filtered out)
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_current", 0), Decimal("0"))
|
||||
|
||||
# expense_projected should be empty/zero (filtered out)
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_projected", 0), Decimal("0"))
|
||||
|
||||
def test_monthly_summary_type_filter_only_expenses(self):
|
||||
"""With type=EX filter, summary should only include expenses"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?type=EX",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should be empty/zero (filtered out)
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("income_current", 0), Decimal("0"))
|
||||
|
||||
# expense_current should have 200
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_current"], Decimal("200.00"))
|
||||
|
||||
# expense_projected should have 150
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_projected"], Decimal("150.00"))
|
||||
|
||||
def test_monthly_summary_is_paid_filter_only_paid(self):
|
||||
"""With is_paid=1 filter, summary should only include paid transactions"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?is_paid=1",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should have 1000 (paid)
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["income_current"], Decimal("1000.00"))
|
||||
|
||||
# expense_current should have 200 (paid)
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_current"], Decimal("200.00"))
|
||||
|
||||
# expense_projected should be empty/zero (filtered out - unpaid)
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_projected", 0), Decimal("0"))
|
||||
|
||||
def test_monthly_summary_is_paid_filter_only_unpaid(self):
|
||||
"""With is_paid=0 filter, summary should only include unpaid transactions"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?is_paid=0",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should be empty/zero (filtered out - paid)
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("income_current", 0), Decimal("0"))
|
||||
|
||||
# expense_current should be empty/zero (filtered out - paid)
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_current", 0), Decimal("0"))
|
||||
|
||||
# expense_projected should have 150 (unpaid)
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_projected"], Decimal("150.00"))
|
||||
|
||||
def test_monthly_summary_description_filter(self):
|
||||
"""With description filter, summary should only include matching transactions"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?description=Income",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# Only income matches "Income" description
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["income_current"], Decimal("1000.00"))
|
||||
|
||||
# Expenses should be filtered out
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_current", 0), Decimal("0"))
|
||||
|
||||
def test_monthly_summary_amount_filter(self):
|
||||
"""With amount filter, summary should only include transactions in range"""
|
||||
# Filter to only get transactions between 100 and 250 (should get 200 and 150)
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?from_amount=100&to_amount=250",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# Income (1000) should be filtered out
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("income_current", 0), Decimal("0"))
|
||||
|
||||
# expense_current should have 200
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_current"], Decimal("200.00"))
|
||||
|
||||
# expense_projected should have 150
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_projected"], Decimal("150.00"))
|
||||
|
||||
# --- monthly_account_summary view tests ---
|
||||
|
||||
def test_monthly_account_summary_no_filter_returns_200(self):
|
||||
"""Test that monthly_account_summary returns 200 without filters"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/accounts/",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_monthly_account_summary_with_filter_returns_200(self):
|
||||
"""Test that monthly_account_summary returns 200 with filter"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/accounts/?type=IN",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# --- monthly_currency_summary view tests ---
|
||||
|
||||
def test_monthly_currency_summary_no_filter_returns_200(self):
|
||||
"""Test that monthly_currency_summary returns 200 without filters"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/currencies/",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_monthly_currency_summary_with_filter_returns_200(self):
|
||||
"""Test that monthly_currency_summary returns 200 with filter"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/currencies/?type=EX",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
@@ -2,7 +2,8 @@ from django.contrib.auth.decorators import login_required
|
||||
from django.db.models import (
|
||||
Q,
|
||||
)
|
||||
from django.http import HttpResponse
|
||||
from django.http import HttpResponse, Http404
|
||||
|
||||
from django.shortcuts import render, redirect
|
||||
from django.utils import timezone
|
||||
from django.views.decorators.http import require_http_methods
|
||||
@@ -36,8 +37,6 @@ def monthly_overview(request, month: int, year: int):
|
||||
summary_tab = request.session.get("monthly_summary_tab", "summary")
|
||||
|
||||
if month < 1 or month > 12:
|
||||
from django.http import Http404
|
||||
|
||||
raise Http404("Month is out of range")
|
||||
|
||||
next_month = 1 if month == 12 else month + 1
|
||||
@@ -76,6 +75,8 @@ def transactions_list(request, month: int, year: int):
|
||||
if order != request.session.get("monthly_transactions_order", "default"):
|
||||
request.session["monthly_transactions_order"] = order
|
||||
|
||||
today = timezone.localdate(timezone.now())
|
||||
|
||||
f = TransactionsFilter(request.GET)
|
||||
transactions_filtered = f.qs.filter(
|
||||
reference_date__year=year,
|
||||
@@ -93,12 +94,28 @@ def transactions_list(request, month: int, year: int):
|
||||
"dca_income_entries",
|
||||
)
|
||||
|
||||
# Late transactions: date < today and is_paid = False (only shown for default ordering)
|
||||
late_transactions = None
|
||||
if order == "default":
|
||||
late_transactions = transactions_filtered.filter(
|
||||
date__lt=today,
|
||||
is_paid=False,
|
||||
).order_by("date", "id")
|
||||
# Exclude late transactions from the main list
|
||||
transactions_filtered = transactions_filtered.exclude(
|
||||
date__lt=today,
|
||||
is_paid=False,
|
||||
)
|
||||
|
||||
transactions_filtered = default_order(transactions_filtered, order=order)
|
||||
|
||||
return render(
|
||||
request,
|
||||
"monthly_overview/fragments/list.html",
|
||||
context={"transactions": transactions_filtered},
|
||||
context={
|
||||
"transactions": transactions_filtered,
|
||||
"late_transactions": late_transactions,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -107,17 +124,48 @@ def transactions_list(request, month: int, year: int):
|
||||
@require_http_methods(["GET"])
|
||||
def monthly_summary(request, month: int, year: int):
|
||||
# Base queryset with all required filters
|
||||
base_queryset = (
|
||||
Transaction.objects.filter(
|
||||
base_queryset = Transaction.objects.filter(
|
||||
reference_date__year=year,
|
||||
reference_date__month=month,
|
||||
account__is_asset=False,
|
||||
)
|
||||
.exclude(Q(Q(category__mute=True) & ~Q(category=None)) | Q(mute=True))
|
||||
.exclude(account__in=request.user.untracked_accounts.all())
|
||||
)
|
||||
|
||||
data = calculate_currency_totals(base_queryset, ignore_empty=True)
|
||||
# Apply filters and check if any are active
|
||||
f = TransactionsFilter(request.GET, queryset=base_queryset)
|
||||
|
||||
# Check if any filter has a non-default value
|
||||
# Default values are: type=['IN', 'EX'], is_paid=['1', '0'], everything else empty
|
||||
has_active_filter = False
|
||||
if f.form.is_valid():
|
||||
for name, value in f.form.cleaned_data.items():
|
||||
# Skip fields with default/empty values
|
||||
if not value:
|
||||
continue
|
||||
# Skip type if it has both default values
|
||||
if name == "type" and set(value) == {"IN", "EX"}:
|
||||
continue
|
||||
# Skip is_paid if it has both default values (values are strings)
|
||||
if name == "is_paid" and set(value) == {"1", "0"}:
|
||||
continue
|
||||
# Skip mute_status if it has both default values
|
||||
if name == "mute_status" and set(value) == {"active", "muted"}:
|
||||
continue
|
||||
# If we get here, there's an active filter
|
||||
has_active_filter = True
|
||||
break
|
||||
|
||||
if has_active_filter:
|
||||
queryset = f.qs
|
||||
else:
|
||||
queryset = (
|
||||
base_queryset.exclude(
|
||||
Q(Q(category__mute=True) & ~Q(category=None)) | Q(mute=True)
|
||||
)
|
||||
.exclude(account__in=request.user.untracked_accounts.all())
|
||||
.exclude(account__is_asset=True)
|
||||
)
|
||||
|
||||
data = calculate_currency_totals(queryset, ignore_empty=True)
|
||||
|
||||
percentages = calculate_percentage_distribution(data)
|
||||
|
||||
context = {
|
||||
@@ -132,6 +180,7 @@ def monthly_summary(request, month: int, year: int):
|
||||
currency_totals=data, month=month, year=year
|
||||
),
|
||||
"percentages": percentages,
|
||||
"has_active_filter": has_active_filter,
|
||||
}
|
||||
|
||||
return render(
|
||||
@@ -149,9 +198,38 @@ def monthly_account_summary(request, month: int, year: int):
|
||||
base_queryset = Transaction.objects.filter(
|
||||
reference_date__year=year,
|
||||
reference_date__month=month,
|
||||
).exclude(Q(Q(category__mute=True) & ~Q(category=None)) | Q(mute=True))
|
||||
)
|
||||
|
||||
account_data = calculate_account_totals(transactions_queryset=base_queryset.all())
|
||||
# Apply filters and check if any are active
|
||||
f = TransactionsFilter(request.GET, queryset=base_queryset)
|
||||
|
||||
# Check if any filter has a non-default value
|
||||
has_active_filter = False
|
||||
if f.form.is_valid():
|
||||
for name, value in f.form.cleaned_data.items():
|
||||
if not value:
|
||||
continue
|
||||
if name == "type" and set(value) == {"IN", "EX"}:
|
||||
continue
|
||||
if name == "is_paid" and set(value) == {"1", "0"}:
|
||||
continue
|
||||
if name == "mute_status" and set(value) == {"active", "muted"}:
|
||||
continue
|
||||
has_active_filter = True
|
||||
break
|
||||
|
||||
if has_active_filter:
|
||||
queryset = f.qs
|
||||
else:
|
||||
queryset = (
|
||||
base_queryset.exclude(
|
||||
Q(Q(category__mute=True) & ~Q(category=None)) | Q(mute=True)
|
||||
)
|
||||
.exclude(account__in=request.user.untracked_accounts.all())
|
||||
.exclude(account__is_asset=True)
|
||||
)
|
||||
|
||||
account_data = calculate_account_totals(transactions_queryset=queryset.all())
|
||||
account_percentages = calculate_percentage_distribution(account_data)
|
||||
|
||||
context = {
|
||||
@@ -171,16 +249,41 @@ def monthly_account_summary(request, month: int, year: int):
|
||||
@require_http_methods(["GET"])
|
||||
def monthly_currency_summary(request, month: int, year: int):
|
||||
# Base queryset with all required filters
|
||||
base_queryset = (
|
||||
Transaction.objects.filter(
|
||||
base_queryset = Transaction.objects.filter(
|
||||
reference_date__year=year,
|
||||
reference_date__month=month,
|
||||
)
|
||||
.exclude(Q(Q(category__mute=True) & ~Q(category=None)) | Q(mute=True))
|
||||
|
||||
# Apply filters and check if any are active
|
||||
f = TransactionsFilter(request.GET, queryset=base_queryset)
|
||||
|
||||
# Check if any filter has a non-default value
|
||||
has_active_filter = False
|
||||
if f.form.is_valid():
|
||||
for name, value in f.form.cleaned_data.items():
|
||||
if not value:
|
||||
continue
|
||||
if name == "type" and set(value) == {"IN", "EX"}:
|
||||
continue
|
||||
if name == "is_paid" and set(value) == {"1", "0"}:
|
||||
continue
|
||||
if name == "mute_status" and set(value) == {"active", "muted"}:
|
||||
continue
|
||||
has_active_filter = True
|
||||
break
|
||||
|
||||
if has_active_filter:
|
||||
queryset = f.qs
|
||||
else:
|
||||
queryset = (
|
||||
base_queryset.exclude(
|
||||
Q(Q(category__mute=True) & ~Q(category=None)) | Q(mute=True)
|
||||
)
|
||||
.exclude(account__in=request.user.untracked_accounts.all())
|
||||
.exclude(account__is_asset=True)
|
||||
)
|
||||
|
||||
currency_data = calculate_currency_totals(base_queryset.all(), ignore_empty=True)
|
||||
currency_data = calculate_currency_totals(queryset.all(), ignore_empty=True)
|
||||
currency_percentages = calculate_percentage_distribution(currency_data)
|
||||
|
||||
context = {
|
||||
|
||||
@@ -182,3 +182,29 @@ def calculate_historical_account_balance(queryset):
|
||||
historical_account_balance[date_filter(end_date, "b Y")] = month_data
|
||||
|
||||
return historical_account_balance
|
||||
|
||||
|
||||
def calculate_monthly_net_worth_difference(historical_net_worth):
|
||||
diff_dict = OrderedDict()
|
||||
if not historical_net_worth:
|
||||
return diff_dict
|
||||
|
||||
# Get all currencies
|
||||
currencies = set()
|
||||
for data in historical_net_worth.values():
|
||||
currencies.update(data.keys())
|
||||
|
||||
# Initialize prev_values for all currencies
|
||||
prev_values = {currency: Decimal("0.00") for currency in currencies}
|
||||
|
||||
for month, values in historical_net_worth.items():
|
||||
diff_values = {}
|
||||
for currency in sorted(list(currencies)):
|
||||
current_val = values.get(currency, Decimal("0.00"))
|
||||
prev_val = prev_values.get(currency, Decimal("0.00"))
|
||||
diff_values[currency] = current_val - prev_val
|
||||
|
||||
diff_dict[month] = diff_values
|
||||
prev_values = values.copy()
|
||||
|
||||
return diff_dict
|
||||
|
||||
@@ -8,6 +8,7 @@ from django.views.decorators.http import require_http_methods
|
||||
from apps.net_worth.utils.calculate_net_worth import (
|
||||
calculate_historical_currency_net_worth,
|
||||
calculate_historical_account_balance,
|
||||
calculate_monthly_net_worth_difference,
|
||||
)
|
||||
from apps.transactions.models import Transaction
|
||||
from apps.transactions.utils.calculations import (
|
||||
@@ -96,6 +97,38 @@ def net_worth(request):
|
||||
|
||||
chart_data_currency_json = json.dumps(chart_data_currency, cls=DjangoJSONEncoder)
|
||||
|
||||
monthly_difference_data = calculate_monthly_net_worth_difference(
|
||||
historical_net_worth=historical_currency_net_worth
|
||||
)
|
||||
|
||||
diff_labels = (
|
||||
list(monthly_difference_data.keys()) if monthly_difference_data else []
|
||||
)
|
||||
diff_currencies = (
|
||||
list(monthly_difference_data[diff_labels[0]].keys())
|
||||
if monthly_difference_data and diff_labels
|
||||
else []
|
||||
)
|
||||
|
||||
diff_datasets = []
|
||||
for i, currency in enumerate(diff_currencies):
|
||||
data = [
|
||||
float(month_data.get(currency, 0))
|
||||
for month_data in monthly_difference_data.values()
|
||||
]
|
||||
diff_datasets.append(
|
||||
{
|
||||
"label": currency,
|
||||
"data": data,
|
||||
"borderWidth": 3,
|
||||
}
|
||||
)
|
||||
|
||||
chart_data_monthly_difference = {"labels": diff_labels, "datasets": diff_datasets}
|
||||
chart_data_monthly_difference_json = json.dumps(
|
||||
chart_data_monthly_difference, cls=DjangoJSONEncoder
|
||||
)
|
||||
|
||||
historical_account_balance = calculate_historical_account_balance(
|
||||
queryset=transactions_account_queryset
|
||||
)
|
||||
@@ -140,6 +173,7 @@ def net_worth(request):
|
||||
"chart_data_accounts_json": chart_data_accounts_json,
|
||||
"accounts": accounts,
|
||||
"type": view_type,
|
||||
"chart_data_monthly_difference_json": chart_data_monthly_difference_json,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
+56
-70
@@ -1,20 +1,22 @@
|
||||
from crispy_bootstrap5.bootstrap5 import Switch, BS5Accordion
|
||||
from crispy_forms.bootstrap import FormActions, AccordionGroup
|
||||
from crispy_forms.bootstrap import Alert
|
||||
from apps.common.fields.forms.dynamic_select import DynamicModelChoiceField
|
||||
from apps.common.widgets.crispy.daisyui import Switch
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.tom_select import TomSelect, TransactionSelect
|
||||
from apps.rules.models import (
|
||||
TransactionRule,
|
||||
TransactionRuleAction,
|
||||
UpdateOrCreateTransactionRuleAction,
|
||||
)
|
||||
from apps.transactions.forms import BulkEditTransactionForm
|
||||
from apps.transactions.models import Transaction
|
||||
from crispy_forms.bootstrap import AccordionGroup, FormActions, Accordion
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Row, Column, HTML
|
||||
from crispy_forms.layout import HTML, Column, Field, Layout, Row
|
||||
from django import forms
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.tom_select import TomSelect, TransactionSelect
|
||||
from apps.rules.models import TransactionRule, UpdateOrCreateTransactionRuleAction
|
||||
from apps.rules.models import TransactionRuleAction
|
||||
from apps.common.fields.forms.dynamic_select import DynamicModelChoiceField
|
||||
from apps.transactions.forms import BulkEditTransactionForm
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
class TransactionRuleForm(forms.ModelForm):
|
||||
class Meta:
|
||||
@@ -35,7 +37,6 @@ class TransactionRuleForm(forms.ModelForm):
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.form_method = "post"
|
||||
# TO-DO: Add helper with available commands
|
||||
self.helper.layout = Layout(
|
||||
Switch("active"),
|
||||
"name",
|
||||
@@ -48,22 +49,21 @@ class TransactionRuleForm(forms.ModelForm):
|
||||
Switch("sequenced"),
|
||||
"description",
|
||||
"trigger",
|
||||
Alert(
|
||||
_("You can add actions to this rule in the next screen."), dismiss=False
|
||||
),
|
||||
)
|
||||
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -97,17 +97,13 @@ class TransactionRuleActionForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -214,148 +210,148 @@ class UpdateOrCreateTransactionRuleActionForm(forms.ModelForm):
|
||||
|
||||
self.helper.layout = Layout(
|
||||
"order",
|
||||
BS5Accordion(
|
||||
Accordion(
|
||||
AccordionGroup(
|
||||
_("Search Criteria"),
|
||||
Field("filter", rows=1),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_type_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_type", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_is_paid_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_is_paid", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_mute_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_mute", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_account_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_account", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_entities_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_entities", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_date_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_date", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_reference_date_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_reference_date", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_description_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_description", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_amount_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_amount", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_category_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_category", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_tags_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_tags", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_notes_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_notes", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_internal_note_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_internal_note", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Field("search_internal_id_operator"),
|
||||
css_class="form-group col-md-4",
|
||||
css_class="col-span-12 md:col-span-4",
|
||||
),
|
||||
Column(
|
||||
Field("search_internal_id", rows=1),
|
||||
css_class="form-group col-md-8",
|
||||
css_class="col-span-12 md:col-span-8",
|
||||
),
|
||||
),
|
||||
active=True,
|
||||
@@ -386,17 +382,13 @@ class UpdateOrCreateTransactionRuleActionForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -427,9 +419,7 @@ class DryRunCreatedTransacion(forms.Form):
|
||||
self.helper.layout = Layout(
|
||||
"transaction",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Test"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Test"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -464,9 +454,7 @@ class DryRunDeletedTransacion(forms.Form):
|
||||
self.helper.layout = Layout(
|
||||
"transaction",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Test"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Test"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -496,13 +484,11 @@ class DryRunUpdatedTransactionForm(BulkEditTransactionForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.helper.layout.insert(0, "transaction")
|
||||
self.helper.layout.insert(1, HTML("<hr/>"))
|
||||
self.helper.layout.insert(1, HTML('<hr class="hr my-3" />'))
|
||||
|
||||
# Change submit button
|
||||
self.helper.layout[-1] = FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Test"), css_class="btn btn-outline-primary w-100"
|
||||
)
|
||||
NoClassSubmit("submit", _("Test"), css_class="btn btn-primary")
|
||||
)
|
||||
|
||||
if self.data.get("transaction"):
|
||||
|
||||
@@ -365,7 +365,9 @@ def check_for_transaction_rules(
|
||||
|
||||
if processed_action.set_category:
|
||||
value = simple.eval(processed_action.set_category)
|
||||
if isinstance(value, int):
|
||||
if value is None:
|
||||
transaction.category = None
|
||||
elif isinstance(value, int):
|
||||
transaction.category = TransactionCategory.objects.get(id=value)
|
||||
else:
|
||||
transaction.category = TransactionCategory.objects.get(name=value)
|
||||
@@ -458,7 +460,9 @@ def check_for_transaction_rules(
|
||||
transaction.account = account
|
||||
|
||||
elif field == TransactionRuleAction.Field.category:
|
||||
if isinstance(new_value, int):
|
||||
if new_value is None:
|
||||
transaction.category = None
|
||||
elif isinstance(new_value, int):
|
||||
category = TransactionCategory.objects.get(id=new_value)
|
||||
transaction.category = category
|
||||
elif isinstance(new_value, str):
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.currencies.models import Currency
|
||||
from apps.rules.models import TransactionRule, UpdateOrCreateTransactionRuleAction
|
||||
from apps.rules.tasks import check_for_transaction_rules
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
def run_check_for_transaction_rules_without_worker_wrapper(**kwargs):
|
||||
task_func = check_for_transaction_rules.func
|
||||
task_func = getattr(task_func, "__wrapped__", task_func)
|
||||
|
||||
return task_func(**kwargs)
|
||||
|
||||
|
||||
class CheckForTransactionRulesTests(TransactionTestCase):
|
||||
def setUp(self):
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="rules@example.com",
|
||||
password="testpass123",
|
||||
)
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD",
|
||||
name="US Dollar",
|
||||
decimal_places=2,
|
||||
)
|
||||
self.account = Account.objects.create(
|
||||
name="Main Account",
|
||||
currency=self.currency,
|
||||
owner=self.user,
|
||||
)
|
||||
|
||||
@patch("apps.rules.signals.check_for_transaction_rules.defer")
|
||||
def test_update_or_create_action_can_clear_category_from_none_expression(
|
||||
self, mock_defer
|
||||
):
|
||||
source_transaction = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("10.00"),
|
||||
date=date(2026, 5, 4),
|
||||
reference_date=date(2026, 5, 1),
|
||||
description="Source without category",
|
||||
category=None,
|
||||
owner=self.user,
|
||||
)
|
||||
rule = TransactionRule.objects.create(
|
||||
active=True,
|
||||
on_create=False,
|
||||
on_update=True,
|
||||
name="Copy transaction",
|
||||
trigger="True",
|
||||
owner=self.user,
|
||||
)
|
||||
UpdateOrCreateTransactionRuleAction.objects.create(
|
||||
rule=rule,
|
||||
set_account="account_id",
|
||||
set_type="'EX'",
|
||||
set_date="date",
|
||||
set_reference_date="reference_date",
|
||||
set_amount="amount",
|
||||
set_description="'Generated transaction'",
|
||||
set_category="category_name",
|
||||
)
|
||||
|
||||
run_check_for_transaction_rules_without_worker_wrapper(
|
||||
instance_id=source_transaction.id,
|
||||
user_id=self.user.id,
|
||||
signal="transaction_updated",
|
||||
)
|
||||
|
||||
generated_transaction = Transaction.objects.get(
|
||||
description="Generated transaction"
|
||||
)
|
||||
self.assertIsNone(generated_transaction.category)
|
||||
@@ -564,7 +564,7 @@ def dry_run_rule_updated(request, pk):
|
||||
|
||||
response = render(
|
||||
request,
|
||||
"rules/fragments/transaction_rule/dry_run/created.html",
|
||||
"rules/fragments/transaction_rule/dry_run/updated.html",
|
||||
{"form": form, "rule": rule, "logs": logs, "results": results},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,11 +1,4 @@
|
||||
import django_filters
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Row, Column
|
||||
from django import forms
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_filters import Filter
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.common.fields.month_year import MonthYearFormField
|
||||
from apps.common.widgets.datepicker import AirDatePickerInput
|
||||
@@ -15,15 +8,26 @@ from apps.currencies.models import Currency
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
TransactionEntity,
|
||||
TransactionTag,
|
||||
)
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Column, Field, Layout, Row
|
||||
from django import forms
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_filters import Filter
|
||||
|
||||
SITUACAO_CHOICES = (
|
||||
("1", _("Paid")),
|
||||
("0", _("Projected")),
|
||||
)
|
||||
|
||||
MUTE_STATUS_CHOICES = (
|
||||
("active", _("Active")),
|
||||
("muted", _("Muted")),
|
||||
)
|
||||
|
||||
|
||||
def content_filter(queryset, name, value):
|
||||
queryset = queryset.filter(
|
||||
@@ -79,6 +83,11 @@ class TransactionsFilter(django_filters.FilterSet):
|
||||
choices=SITUACAO_CHOICES,
|
||||
field_name="is_paid",
|
||||
)
|
||||
mute_status = django_filters.MultipleChoiceFilter(
|
||||
choices=MUTE_STATUS_CHOICES,
|
||||
method="filter_mute_status",
|
||||
label=_("Mute Status"),
|
||||
)
|
||||
date_start = django_filters.DateFilter(
|
||||
field_name="date",
|
||||
lookup_expr="gte",
|
||||
@@ -141,6 +150,9 @@ class TransactionsFilter(django_filters.FilterSet):
|
||||
if data.get("is_paid") is None:
|
||||
data.setlist("is_paid", ["1", "0"])
|
||||
|
||||
if data.get("mute_status") is None:
|
||||
data.setlist("mute_status", ["active", "muted"])
|
||||
|
||||
super().__init__(data, *args, **kwargs)
|
||||
|
||||
self.form.helper = FormHelper()
|
||||
@@ -156,17 +168,19 @@ class TransactionsFilter(django_filters.FilterSet):
|
||||
"is_paid",
|
||||
template="transactions/widgets/transaction_type_filter_buttons.html",
|
||||
),
|
||||
Field(
|
||||
"mute_status",
|
||||
template="transactions/widgets/transaction_type_filter_buttons.html",
|
||||
),
|
||||
Field("description"),
|
||||
Row(Column("date_start"), Column("date_end")),
|
||||
Row(
|
||||
Column("reference_date_start", css_class="form-group col-md-6 mb-0"),
|
||||
Column("reference_date_end", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("reference_date_start"),
|
||||
Column("reference_date_end"),
|
||||
),
|
||||
Row(
|
||||
Column("from_amount", css_class="form-group col-md-6 mb-0"),
|
||||
Column("to_amount", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("from_amount"),
|
||||
Column("to_amount"),
|
||||
),
|
||||
Field("account", size=1),
|
||||
Field("currency", size=1),
|
||||
@@ -271,3 +285,36 @@ class TransactionsFilter(django_filters.FilterSet):
|
||||
return queryset.filter(q).distinct()
|
||||
|
||||
return queryset
|
||||
|
||||
@staticmethod
|
||||
def filter_mute_status(queryset, name, value):
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
|
||||
if not value:
|
||||
return queryset
|
||||
|
||||
value = list(value)
|
||||
|
||||
# If both are selected, return all
|
||||
if "active" in value and "muted" in value:
|
||||
return queryset
|
||||
|
||||
user = get_current_user()
|
||||
|
||||
# Only Active selected: exclude muted transactions
|
||||
if "active" in value:
|
||||
return (
|
||||
queryset.exclude(account__untracked_by=user)
|
||||
.filter(
|
||||
mute=False,
|
||||
)
|
||||
.filter(Q(category__mute=False) | Q(category__isnull=True))
|
||||
)
|
||||
|
||||
# Only Muted selected: include only muted transactions
|
||||
if "muted" in value:
|
||||
return queryset.filter(
|
||||
Q(account__untracked_by=user) | Q(category__mute=True) | Q(mute=True)
|
||||
)
|
||||
|
||||
return queryset
|
||||
|
||||
+99
-168
@@ -1,39 +1,39 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from crispy_bootstrap5.bootstrap5 import Switch, BS5Accordion
|
||||
from crispy_forms.bootstrap import FormActions, AccordionGroup, AppendedText
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import (
|
||||
Layout,
|
||||
Row,
|
||||
Column,
|
||||
Field,
|
||||
Div,
|
||||
HTML,
|
||||
)
|
||||
from django import forms
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.common.fields.forms.dynamic_select import (
|
||||
DynamicModelChoiceField,
|
||||
DynamicModelMultipleChoiceField,
|
||||
)
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
from apps.common.widgets.crispy.daisyui import Switch
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.datepicker import AirDatePickerInput, AirMonthYearPickerInput
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.rules.signals import transaction_created, transaction_updated
|
||||
from apps.transactions.models import (
|
||||
InstallmentPlan,
|
||||
QuickTransaction,
|
||||
RecurringTransaction,
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
InstallmentPlan,
|
||||
RecurringTransaction,
|
||||
TransactionEntity,
|
||||
QuickTransaction,
|
||||
TransactionTag,
|
||||
)
|
||||
from crispy_forms.bootstrap import AccordionGroup, AppendedText, FormActions, Accordion
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import (
|
||||
HTML,
|
||||
Column,
|
||||
Div,
|
||||
Field,
|
||||
Layout,
|
||||
Row,
|
||||
)
|
||||
from django import forms
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class TransactionForm(forms.ModelForm):
|
||||
@@ -117,6 +117,9 @@ class TransactionForm(forms.ModelForm):
|
||||
self.fields["account"].queryset = Account.objects.filter(
|
||||
is_archived=False,
|
||||
)
|
||||
user_settings = get_current_user().settings
|
||||
if user_settings.default_account:
|
||||
self.fields["account"].initial = user_settings.default_account
|
||||
|
||||
self.fields["category"].queryset = TransactionCategory.objects.filter(
|
||||
active=True
|
||||
@@ -134,21 +137,18 @@ class TransactionForm(forms.ModelForm):
|
||||
),
|
||||
Field("is_paid", template="transactions/widgets/paid_toggle_button.html"),
|
||||
Row(
|
||||
Column("account", css_class="form-group col-md-6 mb-0"),
|
||||
Column("entities", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("account"),
|
||||
Column("entities"),
|
||||
),
|
||||
Row(
|
||||
Column(Field("date"), css_class="form-group col-md-6 mb-0"),
|
||||
Column(Field("reference_date"), css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column(Field("date")),
|
||||
Column(Field("reference_date")),
|
||||
),
|
||||
"description",
|
||||
Field("amount", inputmode="decimal"),
|
||||
Row(
|
||||
Column("category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("category"),
|
||||
Column("tags"),
|
||||
),
|
||||
"notes",
|
||||
)
|
||||
@@ -164,20 +164,18 @@ class TransactionForm(forms.ModelForm):
|
||||
Field("is_paid", template="transactions/widgets/paid_toggle_button.html"),
|
||||
"account",
|
||||
Row(
|
||||
Column(Field("date"), css_class="form-group col-md-6 mb-0"),
|
||||
Column(Field("reference_date"), css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column(Field("date")),
|
||||
Column(Field("reference_date")),
|
||||
),
|
||||
"description",
|
||||
Field("amount", inputmode="decimal"),
|
||||
BS5Accordion(
|
||||
Accordion(
|
||||
AccordionGroup(
|
||||
_("More"),
|
||||
"entities",
|
||||
Row(
|
||||
Column("category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("category"),
|
||||
Column("tags"),
|
||||
),
|
||||
"notes",
|
||||
active=False,
|
||||
@@ -187,9 +185,7 @@ class TransactionForm(forms.ModelForm):
|
||||
css_class="mb-3",
|
||||
),
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -202,29 +198,25 @@ class TransactionForm(forms.ModelForm):
|
||||
)
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.fields["amount"].widget = ArbitraryDecimalDisplayNumberInput()
|
||||
self.helper.layout.append(
|
||||
Div(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
NoClassSubmit(
|
||||
"submit_and_similar",
|
||||
_("Save and add similar"),
|
||||
css_class="btn btn-outline-primary",
|
||||
css_class="btn btn-primary btn-soft",
|
||||
),
|
||||
NoClassSubmit(
|
||||
"submit_and_another",
|
||||
_("Save and add another"),
|
||||
css_class="btn btn-outline-primary",
|
||||
css_class="btn btn-primary btn-soft",
|
||||
),
|
||||
css_class="d-grid gap-2",
|
||||
css_class="flex flex-col gap-2 mt-3",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -348,18 +340,16 @@ class QuickTransactionForm(forms.ModelForm):
|
||||
),
|
||||
Field("is_paid", template="transactions/widgets/paid_toggle_button.html"),
|
||||
"name",
|
||||
HTML("<hr />"),
|
||||
HTML('<hr class="hr my-3" />'),
|
||||
Row(
|
||||
Column("account", css_class="form-group col-md-6 mb-0"),
|
||||
Column("entities", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("account"),
|
||||
Column("entities"),
|
||||
),
|
||||
"description",
|
||||
Field("amount", inputmode="decimal"),
|
||||
Row(
|
||||
Column("category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("category"),
|
||||
Column("tags"),
|
||||
),
|
||||
"notes",
|
||||
Switch("mute"),
|
||||
@@ -372,19 +362,14 @@ class QuickTransactionForm(forms.ModelForm):
|
||||
)
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.fields["amount"].widget = ArbitraryDecimalDisplayNumberInput()
|
||||
self.helper.layout.append(
|
||||
Div(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary"
|
||||
),
|
||||
css_class="d-grid gap-2",
|
||||
FormActions(
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -481,27 +466,22 @@ class BulkEditTransactionForm(forms.Form):
|
||||
template="transactions/widgets/unselectable_paid_toggle_button.html",
|
||||
),
|
||||
Row(
|
||||
Column("account", css_class="form-group col-md-6 mb-0"),
|
||||
Column("entities", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("account"),
|
||||
Column("entities"),
|
||||
),
|
||||
Row(
|
||||
Column(Field("date"), css_class="form-group col-md-6 mb-0"),
|
||||
Column(Field("reference_date"), css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column(Field("date")),
|
||||
Column(Field("reference_date")),
|
||||
),
|
||||
"description",
|
||||
Field("amount", inputmode="decimal"),
|
||||
Row(
|
||||
Column("category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("category"),
|
||||
Column("tags"),
|
||||
),
|
||||
"notes",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -600,62 +580,34 @@ class TransferForm(forms.Form):
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column(Field("date"), css_class="form-group col-md-6 mb-0"),
|
||||
Column(Field("date")),
|
||||
Column(
|
||||
Field("reference_date"),
|
||||
css_class="form-group col-md-6 mb-0",
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
Field("description"),
|
||||
Field("notes"),
|
||||
Switch("mute"),
|
||||
Row(
|
||||
Column(
|
||||
Row(
|
||||
Column(
|
||||
"from_account",
|
||||
css_class="form-group col-md-6 mb-0",
|
||||
Column("from_account"),
|
||||
Column(Field("from_amount")),
|
||||
Column("from_category"),
|
||||
Column("from_tags"),
|
||||
css_class="bg-base-100 rounded-box p-4 border-base-content/60 border my-3",
|
||||
),
|
||||
Column(
|
||||
Field("from_amount"),
|
||||
css_class="form-group col-md-6 mb-0",
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
Row(
|
||||
Column("from_category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("from_tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
),
|
||||
),
|
||||
css_class="p-1 mx-1 my-3 border rounded-3",
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
Row(
|
||||
Column(
|
||||
"to_account",
|
||||
css_class="form-group col-md-6 mb-0",
|
||||
),
|
||||
Column(
|
||||
Field("to_amount"),
|
||||
css_class="form-group col-md-6 mb-0",
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
Row(
|
||||
Column("to_category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("to_tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
),
|
||||
),
|
||||
css_class="p-1 mx-1 my-3 border rounded-3",
|
||||
Column("to_category"),
|
||||
Column("to_tags"),
|
||||
css_class="bg-base-100 rounded-box p-4 border-base-content/60 border",
|
||||
),
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Transfer"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Transfer"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -820,6 +772,9 @@ class InstallmentPlanForm(forms.ModelForm):
|
||||
).distinct()
|
||||
else:
|
||||
self.fields["account"].queryset = Account.objects.filter(is_archived=False)
|
||||
user_settings = get_current_user().settings
|
||||
if user_settings.default_account:
|
||||
self.fields["account"].initial = user_settings.default_account
|
||||
|
||||
self.fields["category"].queryset = TransactionCategory.objects.filter(
|
||||
active=True
|
||||
@@ -841,30 +796,26 @@ class InstallmentPlanForm(forms.ModelForm):
|
||||
template="transactions/widgets/income_expense_toggle_buttons.html",
|
||||
),
|
||||
Row(
|
||||
Column("account", css_class="form-group col-md-6 mb-0"),
|
||||
Column("entities", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("account"),
|
||||
Column("entities"),
|
||||
),
|
||||
"description",
|
||||
Switch("add_description_to_transaction"),
|
||||
"notes",
|
||||
Switch("add_notes_to_transaction"),
|
||||
Row(
|
||||
Column("number_of_installments", css_class="form-group col-md-6 mb-0"),
|
||||
Column("installment_start", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("number_of_installments"),
|
||||
Column("installment_start"),
|
||||
),
|
||||
Row(
|
||||
Column("start_date", css_class="form-group col-md-4 mb-0"),
|
||||
Column("reference_date", css_class="form-group col-md-4 mb-0"),
|
||||
Column("recurrence", css_class="form-group col-md-4 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("start_date", css_class="col-span-12 md:col-span-4"),
|
||||
Column("reference_date", css_class="col-span-12 md:col-span-4"),
|
||||
Column("recurrence", css_class="col-span-12 md:col-span-4"),
|
||||
),
|
||||
"installment_amount",
|
||||
Row(
|
||||
Column("category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("category"),
|
||||
Column("tags"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -874,17 +825,13 @@ class InstallmentPlanForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -917,17 +864,13 @@ class TransactionTagForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -949,17 +892,13 @@ class TransactionEntityForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -984,17 +923,13 @@ class TransactionCategoryForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1083,6 +1018,10 @@ class RecurringTransactionForm(forms.ModelForm):
|
||||
else:
|
||||
self.fields["account"].queryset = Account.objects.filter(is_archived=False)
|
||||
|
||||
user_settings = get_current_user().settings
|
||||
if user_settings.default_account:
|
||||
self.fields["account"].initial = user_settings.default_account
|
||||
|
||||
self.fields["category"].queryset = TransactionCategory.objects.filter(
|
||||
active=True
|
||||
)
|
||||
@@ -1103,30 +1042,26 @@ class RecurringTransactionForm(forms.ModelForm):
|
||||
template="transactions/widgets/income_expense_toggle_buttons.html",
|
||||
),
|
||||
Row(
|
||||
Column("account", css_class="form-group col-md-6 mb-0"),
|
||||
Column("entities", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("account"),
|
||||
Column("entities"),
|
||||
),
|
||||
"description",
|
||||
Switch("add_description_to_transaction"),
|
||||
"amount",
|
||||
Row(
|
||||
Column("category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("category"),
|
||||
Column("tags"),
|
||||
),
|
||||
"notes",
|
||||
Switch("add_notes_to_transaction"),
|
||||
Row(
|
||||
Column("start_date", css_class="form-group col-md-6 mb-0"),
|
||||
Column("reference_date", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("start_date"),
|
||||
Column("reference_date"),
|
||||
),
|
||||
Row(
|
||||
Column("recurrence_interval", css_class="form-group col-md-4 mb-0"),
|
||||
Column("recurrence_type", css_class="form-group col-md-4 mb-0"),
|
||||
Column("end_date", css_class="form-group col-md-4 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("recurrence_interval", css_class="col-span-12 md:col-span-4"),
|
||||
Column("recurrence_type", css_class="col-span-12 md:col-span-4"),
|
||||
Column("end_date", css_class="col-span-12 md:col-span-4"),
|
||||
),
|
||||
AppendedText("keep_at_most", _("future transactions")),
|
||||
)
|
||||
@@ -1138,17 +1073,13 @@ class RecurringTransactionForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -2,6 +2,18 @@ import decimal
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
from apps.common.fields.month_year import MonthYearModelField
|
||||
from apps.common.functions.decimals import truncate_decimal
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
from apps.common.models import (
|
||||
OwnedObject,
|
||||
OwnedObjectManager,
|
||||
SharedObject,
|
||||
SharedObjectManager,
|
||||
)
|
||||
from apps.common.templatetags.decimal import drop_trailing_zeros, localize_number
|
||||
from apps.currencies.utils.convert import convert
|
||||
from apps.transactions.validators import validate_decimal_places, validate_non_negative
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from django.conf import settings
|
||||
from django.core.validators import MinValueValidator
|
||||
@@ -12,19 +24,6 @@ from django.template.defaultfilters import date
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.fields.month_year import MonthYearModelField
|
||||
from apps.common.functions.decimals import truncate_decimal
|
||||
from apps.common.templatetags.decimal import localize_number, drop_trailing_zeros
|
||||
from apps.currencies.utils.convert import convert
|
||||
from apps.transactions.validators import validate_decimal_places, validate_non_negative
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
from apps.common.models import (
|
||||
SharedObject,
|
||||
SharedObjectManager,
|
||||
OwnedObject,
|
||||
OwnedObjectManager,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@@ -381,22 +380,41 @@ class Transaction(OwnedObject):
|
||||
db_table = "transactions"
|
||||
default_manager_name = "objects"
|
||||
|
||||
def clean_fields(self, *args, **kwargs):
|
||||
def clean(self):
|
||||
super().clean()
|
||||
|
||||
# Convert empty internal_id to None to allow multiple "empty" values with unique constraint
|
||||
if self.internal_id == "":
|
||||
self.internal_id = None
|
||||
|
||||
# Only process amount and reference_date if account exists
|
||||
# If account is missing, Django's required field validation will handle it
|
||||
try:
|
||||
account = self.account
|
||||
except Transaction.account.RelatedObjectDoesNotExist:
|
||||
# Account doesn't exist, skip processing that depends on it
|
||||
# Django will add the required field error
|
||||
return
|
||||
|
||||
# Validate and normalize amount
|
||||
if isinstance(self.amount, (str, int, float)):
|
||||
self.amount = decimal.Decimal(str(self.amount))
|
||||
|
||||
self.amount = truncate_decimal(
|
||||
value=self.amount, decimal_places=self.account.currency.decimal_places
|
||||
value=self.amount, decimal_places=account.currency.decimal_places
|
||||
)
|
||||
|
||||
# Normalize reference_date
|
||||
if self.reference_date:
|
||||
self.reference_date = self.reference_date.replace(day=1)
|
||||
elif not self.reference_date and self.date:
|
||||
self.reference_date = self.date.replace(day=1)
|
||||
|
||||
super().clean_fields(*args, **kwargs)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
# This is here so Django validation doesn't trigger an error before clean() is ran
|
||||
if not self.reference_date and self.date:
|
||||
self.reference_date = self.date.replace(day=1)
|
||||
|
||||
# This is not recommended as it will run twice on some cases like form and API saves.
|
||||
# We only do this here because we forgot to independently call it on multiple places.
|
||||
self.full_clean()
|
||||
@@ -856,9 +874,7 @@ class RecurringTransaction(models.Model):
|
||||
notes=self.notes if self.add_notes_to_transaction else "",
|
||||
owner=self.account.owner,
|
||||
)
|
||||
if self.tags.exists():
|
||||
created_transaction.tags.set(self.tags.all())
|
||||
if self.entities.exists():
|
||||
created_transaction.entities.set(self.entities.all())
|
||||
|
||||
def get_recurrence_delta(self):
|
||||
|
||||
@@ -13,7 +13,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.periodic(cron="0 0 * * *")
|
||||
@app.task(name="generate_recurring_transactions")
|
||||
@app.task(
|
||||
lock="generate_recurring_transactions", name="generate_recurring_transactions"
|
||||
)
|
||||
def generate_recurring_transactions(timestamp=None):
|
||||
try:
|
||||
RecurringTransaction.generate_upcoming_transactions()
|
||||
@@ -26,7 +28,7 @@ def generate_recurring_transactions(timestamp=None):
|
||||
|
||||
|
||||
@app.periodic(cron="10 1 * * *")
|
||||
@app.task(name="cleanup_deleted_transactions")
|
||||
@app.task(lock="cleanup_deleted_transactions", name="cleanup_deleted_transactions")
|
||||
def cleanup_deleted_transactions(timestamp=None):
|
||||
if settings.ENABLE_SOFT_DELETE and settings.KEEP_DELETED_TRANSACTIONS_FOR == 0:
|
||||
return "KEEP_DELETED_TRANSACTIONS_FOR is 0, no cleanup performed."
|
||||
|
||||
@@ -3,7 +3,6 @@ from decimal import Decimal
|
||||
from django import template
|
||||
from django.utils.formats import number_format
|
||||
|
||||
|
||||
register = template.Library()
|
||||
|
||||
|
||||
@@ -13,13 +12,27 @@ def _format_string(prefix, amount, decimal_places, suffix):
|
||||
value=abs(amount), decimal_pos=decimal_places, force_grouping=True
|
||||
)
|
||||
if amount < 0:
|
||||
return "-", prefix, formatted_amount, suffix
|
||||
return f"-{prefix}{formatted_amount}{suffix}"
|
||||
else:
|
||||
return "", prefix, formatted_amount, suffix
|
||||
return f"{prefix}{formatted_amount}{suffix}"
|
||||
else:
|
||||
return "ERR"
|
||||
return "", "", "ERR", ""
|
||||
|
||||
|
||||
@register.simple_tag(name="currency_display")
|
||||
def currency_display(amount, prefix, suffix, decimal_places):
|
||||
return _format_string(prefix, amount, decimal_places, suffix)
|
||||
def currency_display(amount, prefix, suffix, decimal_places, string=False):
|
||||
sign, prefix, amount, suffix = _format_string(
|
||||
prefix, amount, decimal_places, suffix
|
||||
)
|
||||
|
||||
if string:
|
||||
return f"{sign}{prefix}{amount}{suffix}"
|
||||
|
||||
return {
|
||||
"sign": sign,
|
||||
"prefix": prefix,
|
||||
"amount": amount,
|
||||
"suffix": suffix,
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
from datetime import date, timedelta
|
||||
|
||||
from django.test import TestCase
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.transactions.models import (
|
||||
@@ -127,6 +125,70 @@ class TransactionTests(TestCase):
|
||||
datetime.datetime(day=1, month=2, year=2000).date(),
|
||||
)
|
||||
|
||||
def test_empty_internal_id_converts_to_none(self):
|
||||
"""Test that empty string internal_id is converted to None"""
|
||||
transaction = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=timezone.now().date(),
|
||||
amount=Decimal("100.00"),
|
||||
description="Test transaction",
|
||||
internal_id="", # Empty string should become None
|
||||
)
|
||||
self.assertIsNone(transaction.internal_id)
|
||||
|
||||
def test_unique_internal_id_works(self):
|
||||
"""Test that unique non-empty internal_id values work correctly"""
|
||||
transaction1 = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=timezone.now().date(),
|
||||
amount=Decimal("100.00"),
|
||||
description="Test transaction 1",
|
||||
internal_id="unique-id-123",
|
||||
)
|
||||
transaction2 = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=timezone.now().date(),
|
||||
amount=Decimal("100.00"),
|
||||
description="Test transaction 2",
|
||||
internal_id="unique-id-456",
|
||||
)
|
||||
self.assertEqual(transaction1.internal_id, "unique-id-123")
|
||||
self.assertEqual(transaction2.internal_id, "unique-id-456")
|
||||
|
||||
def test_multiple_transactions_with_empty_internal_id(self):
|
||||
"""Test that multiple transactions can have empty/None internal_id"""
|
||||
transaction1 = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=timezone.now().date(),
|
||||
amount=Decimal("100.00"),
|
||||
description="Test transaction 1",
|
||||
internal_id="",
|
||||
)
|
||||
transaction2 = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=timezone.now().date(),
|
||||
amount=Decimal("100.00"),
|
||||
description="Test transaction 2",
|
||||
internal_id="",
|
||||
)
|
||||
transaction3 = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=timezone.now().date(),
|
||||
amount=Decimal("100.00"),
|
||||
description="Test transaction 3",
|
||||
internal_id=None,
|
||||
)
|
||||
# All should be saved successfully with None internal_id
|
||||
self.assertIsNone(transaction1.internal_id)
|
||||
self.assertIsNone(transaction2.internal_id)
|
||||
self.assertIsNone(transaction3.internal_id)
|
||||
|
||||
|
||||
class InstallmentPlanTests(TestCase):
|
||||
def setUp(self):
|
||||
@@ -0,0 +1,174 @@
|
||||
from datetime import date
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class TransactionSimpleAddViewTests(TestCase):
|
||||
"""Tests for the transaction_simple_add view with query parameters"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client.login(username="testuser@test.com", password="testpass123")
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Test Account", group=self.account_group, currency=self.currency
|
||||
)
|
||||
self.category = TransactionCategory.objects.create(name="Test Category")
|
||||
self.tag = TransactionTag.objects.create(name="TestTag")
|
||||
|
||||
def test_get_returns_form_with_default_values(self):
|
||||
"""Test GET request returns 200 and form with defaults"""
|
||||
response = self.client.get("/add/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn("form", response.context)
|
||||
|
||||
def test_get_with_type_param(self):
|
||||
"""Test type param sets form initial value"""
|
||||
response = self.client.get("/add/?type=EX")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("type"), Transaction.Type.EXPENSE)
|
||||
|
||||
def test_get_with_account_param(self):
|
||||
"""Test account param sets form initial value"""
|
||||
response = self.client.get(f"/add/?account={self.account.id}")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("account"), self.account.id)
|
||||
|
||||
def test_get_with_is_paid_param_true(self):
|
||||
"""Test is_paid param with true value"""
|
||||
response = self.client.get("/add/?is_paid=true")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertTrue(form.initial.get("is_paid"))
|
||||
|
||||
def test_get_with_is_paid_param_false(self):
|
||||
"""Test is_paid param with false value"""
|
||||
response = self.client.get("/add/?is_paid=false")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertFalse(form.initial.get("is_paid"))
|
||||
|
||||
def test_get_with_amount_param(self):
|
||||
"""Test amount param sets form initial value"""
|
||||
response = self.client.get("/add/?amount=150.50")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("amount"), "150.50")
|
||||
|
||||
def test_get_with_description_param(self):
|
||||
"""Test description param sets form initial value"""
|
||||
response = self.client.get("/add/?description=Test%20Transaction")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("description"), "Test Transaction")
|
||||
|
||||
def test_get_with_notes_param(self):
|
||||
"""Test notes param sets form initial value"""
|
||||
response = self.client.get("/add/?notes=Some%20notes")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("notes"), "Some notes")
|
||||
|
||||
def test_get_with_category_param(self):
|
||||
"""Test category param sets form initial value"""
|
||||
response = self.client.get(f"/add/?category={self.category.id}")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("category"), self.category.id)
|
||||
|
||||
def test_get_with_tags_param(self):
|
||||
"""Test tags param as comma-separated names"""
|
||||
response = self.client.get("/add/?tags=TestTag,AnotherTag")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("tags"), ["TestTag", "AnotherTag"])
|
||||
|
||||
def test_get_with_all_params(self):
|
||||
"""Test all params together work correctly"""
|
||||
url = (
|
||||
f"/add/?type=EX&account={self.account.id}&is_paid=true"
|
||||
f"&amount=200.00&description=Full%20Test¬es=Test%20notes"
|
||||
f"&category={self.category.id}&tags=TestTag"
|
||||
)
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("type"), Transaction.Type.EXPENSE)
|
||||
self.assertEqual(form.initial.get("account"), self.account.id)
|
||||
self.assertTrue(form.initial.get("is_paid"))
|
||||
self.assertEqual(form.initial.get("amount"), "200.00")
|
||||
self.assertEqual(form.initial.get("description"), "Full Test")
|
||||
self.assertEqual(form.initial.get("notes"), "Test notes")
|
||||
self.assertEqual(form.initial.get("category"), self.category.id)
|
||||
self.assertEqual(form.initial.get("tags"), ["TestTag"])
|
||||
|
||||
def test_post_creates_transaction(self):
|
||||
"""Test form submission creates transaction"""
|
||||
data = {
|
||||
"account": self.account.id,
|
||||
"type": "EX",
|
||||
"is_paid": True,
|
||||
"date": timezone.now().date().isoformat(),
|
||||
"amount": "100.00",
|
||||
"description": "Test Transaction",
|
||||
}
|
||||
response = self.client.post("/add/", data)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(
|
||||
Transaction.objects.filter(description="Test Transaction").exists()
|
||||
)
|
||||
|
||||
def test_get_with_date_param(self):
|
||||
"""Test date param overrides expected date"""
|
||||
response = self.client.get("/add/?date=2025-06-15")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("date"), date(2025, 6, 15))
|
||||
|
||||
def test_get_with_reference_date_param(self):
|
||||
"""Test reference_date param sets form initial value"""
|
||||
response = self.client.get("/add/?reference_date=2025-07-01")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("reference_date"), date(2025, 7, 1))
|
||||
|
||||
def test_get_with_account_name_param(self):
|
||||
"""Test account param by name (case-insensitive)"""
|
||||
response = self.client.get("/add/?account=Test%20Account")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("account"), self.account.id)
|
||||
|
||||
def test_get_with_category_name_param(self):
|
||||
"""Test category param by name (case-insensitive)"""
|
||||
response = self.client.get("/add/?category=Test%20Category")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
form = response.context["form"]
|
||||
self.assertEqual(form.initial.get("category"), self.category.id)
|
||||
@@ -35,7 +35,7 @@ def categories_list(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def categories_table_active(request):
|
||||
categories = TransactionCategory.objects.filter(active=True).order_by("id")
|
||||
categories = TransactionCategory.objects.filter(active=True).order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"categories/fragments/table.html",
|
||||
@@ -47,7 +47,7 @@ def categories_table_active(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def categories_table_archived(request):
|
||||
categories = TransactionCategory.objects.filter(active=False).order_by("id")
|
||||
categories = TransactionCategory.objects.filter(active=False).order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"categories/fragments/table.html",
|
||||
|
||||
@@ -35,7 +35,7 @@ def entities_list(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def entities_table_active(request):
|
||||
entities = TransactionEntity.objects.filter(active=True).order_by("id")
|
||||
entities = TransactionEntity.objects.filter(active=True).order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"entities/fragments/table.html",
|
||||
@@ -47,7 +47,7 @@ def entities_table_active(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def entities_table_archived(request):
|
||||
entities = TransactionEntity.objects.filter(active=False).order_by("id")
|
||||
entities = TransactionEntity.objects.filter(active=False).order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"entities/fragments/table.html",
|
||||
|
||||
@@ -137,6 +137,7 @@ def quick_transaction_add_as_transaction(request, quick_transaction_id):
|
||||
"category",
|
||||
"tags",
|
||||
"entities",
|
||||
"internal_id",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -206,6 +207,7 @@ def quick_transaction_add_as_quick_transaction(request, transaction_id):
|
||||
"recurring_transaction",
|
||||
"deleted",
|
||||
"deleted_at",
|
||||
"internal_id",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ def tags_list(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def tags_table_active(request):
|
||||
tags = TransactionTag.objects.filter(active=True).order_by("id")
|
||||
tags = TransactionTag.objects.filter(active=True).order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"tags/fragments/table.html",
|
||||
@@ -47,7 +47,7 @@ def tags_table_active(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def tags_table_archived(request):
|
||||
tags = TransactionTag.objects.filter(active=False).order_by("id")
|
||||
tags = TransactionTag.objects.filter(active=False).order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"tags/fragments/table.html",
|
||||
|
||||
@@ -142,26 +142,105 @@ def transaction_simple_add(request):
|
||||
year=year,
|
||||
).date()
|
||||
|
||||
# Build initial data from query parameters
|
||||
initial_data = {
|
||||
"date": expected_date,
|
||||
"type": transaction_type,
|
||||
}
|
||||
|
||||
# Handle date param (ISO format: YYYY-MM-DD) - overrides expected_date
|
||||
date_param = request.GET.get("date")
|
||||
if date_param:
|
||||
try:
|
||||
initial_data["date"] = datetime.datetime.strptime(
|
||||
date_param, "%Y-%m-%d"
|
||||
).date()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Handle reference_date param (ISO format: YYYY-MM-DD)
|
||||
reference_date_param = request.GET.get("reference_date")
|
||||
if reference_date_param:
|
||||
try:
|
||||
initial_data["reference_date"] = datetime.datetime.strptime(
|
||||
reference_date_param, "%Y-%m-%d"
|
||||
).date()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Handle account param (by ID or name)
|
||||
account_param = request.GET.get("account")
|
||||
if account_param:
|
||||
try:
|
||||
initial_data["account"] = int(account_param)
|
||||
except (ValueError, TypeError):
|
||||
# Try to find by name
|
||||
from apps.accounts.models import Account
|
||||
|
||||
account = Account.objects.filter(
|
||||
name__iexact=account_param, is_archived=False
|
||||
).first()
|
||||
if account:
|
||||
initial_data["account"] = account.pk
|
||||
|
||||
# Handle is_paid param (boolean)
|
||||
is_paid = request.GET.get("is_paid")
|
||||
if is_paid is not None:
|
||||
initial_data["is_paid"] = is_paid.lower() in ("true", "1", "yes")
|
||||
|
||||
# Handle amount param (decimal)
|
||||
amount = request.GET.get("amount")
|
||||
if amount:
|
||||
try:
|
||||
initial_data["amount"] = amount
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Handle description param (string)
|
||||
description = request.GET.get("description")
|
||||
if description:
|
||||
initial_data["description"] = description
|
||||
|
||||
# Handle notes param (string)
|
||||
notes = request.GET.get("notes")
|
||||
if notes:
|
||||
initial_data["notes"] = notes
|
||||
|
||||
# Handle category param (by ID or name)
|
||||
category_param = request.GET.get("category")
|
||||
if category_param:
|
||||
try:
|
||||
initial_data["category"] = int(category_param)
|
||||
except (ValueError, TypeError):
|
||||
# Try to find by name
|
||||
from apps.transactions.models import TransactionCategory
|
||||
|
||||
category = TransactionCategory.objects.filter(
|
||||
name__iexact=category_param, active=True
|
||||
).first()
|
||||
if category:
|
||||
initial_data["category"] = category.pk
|
||||
|
||||
# Handle tags param (comma-separated names)
|
||||
tags = request.GET.get("tags")
|
||||
if tags:
|
||||
initial_data["tags"] = [t.strip() for t in tags.split(",") if t.strip()]
|
||||
|
||||
# Handle entities param (comma-separated names)
|
||||
entities = request.GET.get("entities")
|
||||
if entities:
|
||||
initial_data["entities"] = [e.strip() for e in entities.split(",") if e.strip()]
|
||||
|
||||
if request.method == "POST":
|
||||
form = TransactionForm(request.POST)
|
||||
if form.is_valid():
|
||||
form.save()
|
||||
messages.success(request, _("Transaction added successfully"))
|
||||
|
||||
form = TransactionForm(
|
||||
initial={
|
||||
"date": expected_date,
|
||||
"type": transaction_type,
|
||||
},
|
||||
)
|
||||
# Only reset form after successful save
|
||||
form = TransactionForm(initial=initial_data)
|
||||
|
||||
else:
|
||||
form = TransactionForm(
|
||||
initial={
|
||||
"date": expected_date,
|
||||
"type": transaction_type,
|
||||
},
|
||||
)
|
||||
form = TransactionForm(initial=initial_data)
|
||||
|
||||
return render(
|
||||
request,
|
||||
@@ -388,7 +467,7 @@ def transaction_pay(request, transaction_id):
|
||||
context={"transaction": transaction, **request.GET},
|
||||
)
|
||||
response.headers["HX-Trigger"] = (
|
||||
f'{"paid" if new_is_paid else "unpaid"}, selective_update'
|
||||
f"{'paid' if new_is_paid else 'unpaid'}, selective_update"
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -483,6 +562,8 @@ def transaction_all_list(request):
|
||||
if order != request.session.get("all_transactions_order", "default"):
|
||||
request.session["all_transactions_order"] = order
|
||||
|
||||
today = timezone.localdate(timezone.now())
|
||||
|
||||
transactions = Transaction.objects.prefetch_related(
|
||||
"account",
|
||||
"account__group",
|
||||
@@ -496,12 +577,27 @@ def transaction_all_list(request):
|
||||
"dca_income_entries",
|
||||
).all()
|
||||
|
||||
transactions = default_order(transactions, order=order)
|
||||
|
||||
f = TransactionsFilter(request.GET, queryset=transactions)
|
||||
|
||||
# Late transactions: date < today and is_paid = False (only shown for default ordering on first page)
|
||||
late_transactions = None
|
||||
page_number = request.GET.get("page", 1)
|
||||
paginator = Paginator(f.qs, 100)
|
||||
if order == "default" and str(page_number) == "1":
|
||||
late_transactions = f.qs.filter(
|
||||
date__lt=today,
|
||||
is_paid=False,
|
||||
).order_by("date", "id")
|
||||
# Exclude late transactions from the main paginated list
|
||||
main_transactions = f.qs.exclude(
|
||||
date__lt=today,
|
||||
is_paid=False,
|
||||
)
|
||||
else:
|
||||
main_transactions = f.qs
|
||||
|
||||
main_transactions = default_order(main_transactions, order=order)
|
||||
|
||||
paginator = Paginator(main_transactions, 100)
|
||||
page_obj = paginator.get_page(page_number)
|
||||
|
||||
return render(
|
||||
@@ -510,6 +606,7 @@ def transaction_all_list(request):
|
||||
{
|
||||
"page_obj": page_obj,
|
||||
"paginator": paginator,
|
||||
"late_transactions": late_transactions,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import logging
|
||||
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
User = get_user_model()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoConnectSocialAccountAdapter(DefaultSocialAccountAdapter):
|
||||
"""
|
||||
Custom adapter to automatically connect social accounts to existing users
|
||||
with the same email address.
|
||||
|
||||
SECURITY WARNING:
|
||||
This adapter automatically connects OIDC accounts to existing local accounts
|
||||
based on email matching.
|
||||
|
||||
If your OIDC provider allows unverified emails, this could lead to
|
||||
ACCOUNT TAKEOVER attacks where an attacker creates an OIDC account
|
||||
with someone else's email and gains access to their account.
|
||||
"""
|
||||
|
||||
def pre_social_login(self, request, sociallogin):
|
||||
"""
|
||||
Invoked just after a user successfully authenticates via a
|
||||
social provider, but before the login is actually processed.
|
||||
|
||||
If a user with the same email already exists, connect the social
|
||||
account to that existing user instead of creating a new account.
|
||||
"""
|
||||
# If the social account is already connected to a user, do nothing
|
||||
if sociallogin.is_existing:
|
||||
return
|
||||
|
||||
# Check if we have an email from the social provider
|
||||
if not sociallogin.email_addresses:
|
||||
logger.warning(
|
||||
"OIDC login attempted without email address. "
|
||||
f"Provider: {sociallogin.account.provider}"
|
||||
)
|
||||
return
|
||||
|
||||
# Get the email from the social login
|
||||
email = sociallogin.email_addresses[0].email.lower()
|
||||
|
||||
# Try to find an existing user with this email
|
||||
try:
|
||||
user = User.objects.get(email__iexact=email)
|
||||
|
||||
# Log this connection for security audit trail
|
||||
logger.info(
|
||||
f"Auto-connecting OIDC account to existing user. "
|
||||
f"Email: {email}, Provider: {sociallogin.account.provider}, "
|
||||
f"User ID: {user.id}"
|
||||
)
|
||||
|
||||
# Connect the social account to the existing user
|
||||
sociallogin.connect(request, user)
|
||||
|
||||
except User.DoesNotExist:
|
||||
# No user with this email exists, proceed with normal signup flow
|
||||
logger.debug(
|
||||
f"No existing user found for email {email}. "
|
||||
"Proceeding with new account creation."
|
||||
)
|
||||
pass
|
||||
except User.MultipleObjectsReturned:
|
||||
# Multiple users with the same email (shouldn't happen with unique constraint)
|
||||
logger.error(
|
||||
f"Multiple users found with email {email}. "
|
||||
"This should not happen with unique constraint. "
|
||||
"Blocking auto-connect."
|
||||
)
|
||||
# Let the default behavior handle this
|
||||
pass
|
||||
+49
-31
@@ -1,35 +1,45 @@
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.users.models import UserSettings
|
||||
from apps.accounts.models import Account
|
||||
from crispy_forms.bootstrap import (
|
||||
FormActions,
|
||||
)
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Submit, Row, Column, Field, Div, HTML
|
||||
from crispy_forms.layout import HTML, Column, Div, Field, Layout, Row, Submit
|
||||
from django import forms
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.forms import (
|
||||
UsernameField,
|
||||
AuthenticationForm,
|
||||
UserCreationForm,
|
||||
UsernameField,
|
||||
)
|
||||
from django.db import transaction
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.users.models import UserSettings
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
|
||||
|
||||
class LoginForm(AuthenticationForm):
|
||||
username = UsernameField(
|
||||
label=_("E-mail"),
|
||||
widget=forms.EmailInput(
|
||||
attrs={"class": "form-control", "placeholder": "E-mail", "name": "email"}
|
||||
attrs={
|
||||
"class": "input",
|
||||
"placeholder": _("E-mail"),
|
||||
"name": "email",
|
||||
"autocomplete": "email",
|
||||
}
|
||||
),
|
||||
)
|
||||
password = forms.CharField(
|
||||
label=_("Password"),
|
||||
strip=False,
|
||||
widget=forms.PasswordInput(
|
||||
attrs={"class": "form-control", "placeholder": "Senha"}
|
||||
attrs={
|
||||
"class": "input",
|
||||
"placeholder": _("Password"),
|
||||
"autocomplete": "current-password",
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@@ -45,7 +55,7 @@ class LoginForm(AuthenticationForm):
|
||||
self.helper.layout = Layout(
|
||||
"username",
|
||||
"password",
|
||||
Submit("Submit", "Login", css_class="btn btn-primary w-100"),
|
||||
Submit("Submit", "Login", css_class="w-full mt-3"),
|
||||
)
|
||||
|
||||
|
||||
@@ -108,6 +118,15 @@ class UserSettingsForm(forms.ModelForm):
|
||||
label=_("Number Format"),
|
||||
)
|
||||
|
||||
default_account = forms.ModelChoiceField(
|
||||
queryset=Account.objects.filter(
|
||||
is_archived=False,
|
||||
),
|
||||
label=_("Default Account"),
|
||||
widget=TomSelect(clear_button=False, group_by="group"),
|
||||
required=False,
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = UserSettings
|
||||
fields = [
|
||||
@@ -118,29 +137,36 @@ class UserSettingsForm(forms.ModelForm):
|
||||
"datetime_format",
|
||||
"number_format",
|
||||
"volume",
|
||||
"default_account",
|
||||
]
|
||||
widgets = {
|
||||
"default_account": TomSelect(clear_button=False, group_by="group"),
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.fields["default_account"].queryset = Account.objects.filter(
|
||||
is_archived=False,
|
||||
)
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.form_method = "post"
|
||||
self.helper.layout = Layout(
|
||||
"language",
|
||||
"timezone",
|
||||
HTML("<hr />"),
|
||||
HTML('<hr class="hr my-3" />'),
|
||||
"date_format",
|
||||
"datetime_format",
|
||||
"number_format",
|
||||
HTML("<hr />"),
|
||||
HTML('<hr class="hr my-3" />'),
|
||||
"start_page",
|
||||
HTML("<hr />"),
|
||||
"default_account",
|
||||
HTML('<hr class="hr my-3" />'),
|
||||
"volume",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Save"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Save"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -191,8 +217,8 @@ class UserUpdateForm(forms.ModelForm):
|
||||
# Define the layout using Crispy Forms, including the new fields
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("first_name", css_class="form-group col-md-6"),
|
||||
Column("last_name", css_class="form-group col-md-6"),
|
||||
Column("first_name"),
|
||||
Column("last_name"),
|
||||
css_class="row",
|
||||
),
|
||||
Field("email"),
|
||||
@@ -213,17 +239,13 @@ class UserUpdateForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -354,8 +376,8 @@ class UserAddForm(UserCreationForm):
|
||||
self.helper.layout = Layout(
|
||||
Field("email"),
|
||||
Row(
|
||||
Column("first_name", css_class="form-group col-md-6"),
|
||||
Column("last_name", css_class="form-group col-md-6"),
|
||||
Column("first_name"),
|
||||
Column("last_name"),
|
||||
css_class="row",
|
||||
),
|
||||
# UserCreationForm provides 'password1' and 'password2' fields
|
||||
@@ -375,17 +397,13 @@ class UserAddForm(UserCreationForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
# Generated by Django 5.2.9 on 2026-02-15 21:35
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("accounts", "0016_account_untracked_by"),
|
||||
("users", "0023_alter_usersettings_timezone"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="usersettings",
|
||||
name="default_account",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="accounts.account",
|
||||
verbose_name="Default account",
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,20 @@
|
||||
# Generated by Django 5.2.9 on 2026-02-16 01:32
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('accounts', '0016_account_untracked_by'),
|
||||
('users', '0024_usersettings_default_account'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='usersettings',
|
||||
name='default_account',
|
||||
field=models.ForeignKey(blank=True, help_text='Selects the account by default when creating new transactions', null=True, on_delete=django.db.models.deletion.SET_NULL, to='accounts.account', verbose_name='Default account'),
|
||||
),
|
||||
]
|
||||
@@ -510,6 +510,14 @@ class UserSettings(models.Model):
|
||||
default=StartPage.MONTHLY,
|
||||
verbose_name=_("Start page"),
|
||||
)
|
||||
default_account = models.ForeignKey(
|
||||
"accounts.Account",
|
||||
on_delete=models.SET_NULL,
|
||||
verbose_name=_("Default account"),
|
||||
help_text=_("Selects the account by default when creating new transactions"),
|
||||
blank=True,
|
||||
null=True,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.user.email}'s settings"
|
||||
|
||||
@@ -18,10 +18,15 @@ urlpatterns = [
|
||||
name="toggle_sound_playing",
|
||||
),
|
||||
path(
|
||||
"user/toggle-sidebar/",
|
||||
"user/session/toggle-sidebar/",
|
||||
views.toggle_sidebar_status,
|
||||
name="toggle_sidebar_status",
|
||||
),
|
||||
path(
|
||||
"user/session/toggle-theme/",
|
||||
views.toggle_theme,
|
||||
name="toggle_theme",
|
||||
),
|
||||
path(
|
||||
"user/settings/",
|
||||
views.update_settings,
|
||||
|
||||
+31
-13
@@ -1,27 +1,26 @@
|
||||
from apps.common.decorators.demo import disabled_on_demo
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
from apps.common.decorators.user import htmx_login_required, is_superuser
|
||||
from apps.users.forms import (
|
||||
LoginForm,
|
||||
UserAddForm,
|
||||
UserSettingsForm,
|
||||
UserUpdateForm,
|
||||
)
|
||||
from apps.users.models import UserSettings
|
||||
from django.contrib import messages
|
||||
from django.contrib.auth import logout, get_user_model
|
||||
from django.contrib.auth import get_user_model, logout
|
||||
from django.contrib.auth.decorators import login_required
|
||||
from django.contrib.auth.views import (
|
||||
LoginView,
|
||||
)
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.http import HttpResponse
|
||||
from django.shortcuts import redirect, render, get_object_or_404
|
||||
from django.shortcuts import get_object_or_404, redirect, render
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.views.decorators.http import require_http_methods
|
||||
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
from apps.common.decorators.user import is_superuser, htmx_login_required
|
||||
from apps.users.forms import (
|
||||
LoginForm,
|
||||
UserSettingsForm,
|
||||
UserUpdateForm,
|
||||
UserAddForm,
|
||||
)
|
||||
from apps.users.models import UserSettings
|
||||
from apps.common.decorators.demo import disabled_on_demo
|
||||
|
||||
|
||||
def logout_view(request):
|
||||
logout(request)
|
||||
@@ -118,6 +117,7 @@ def update_settings(request):
|
||||
|
||||
@only_htmx
|
||||
@htmx_login_required
|
||||
@require_http_methods(["GET"])
|
||||
def toggle_sidebar_status(request):
|
||||
if not request.session.get("sidebar_status"):
|
||||
request.session["sidebar_status"] = "floating"
|
||||
@@ -134,6 +134,24 @@ def toggle_sidebar_status(request):
|
||||
)
|
||||
|
||||
|
||||
@htmx_login_required
|
||||
@require_http_methods(["GET"])
|
||||
def toggle_theme(request):
|
||||
if not request.session.get("theme"):
|
||||
request.session["theme"] = "wygiwyh_dark"
|
||||
|
||||
if request.session["theme"] == "wygiwyh_dark":
|
||||
request.session["theme"] = "wygiwyh_light"
|
||||
elif request.session["theme"] == "wygiwyh_light":
|
||||
request.session["theme"] = "wygiwyh_dark"
|
||||
else:
|
||||
request.session["theme"] = "wygiwyh_light"
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
)
|
||||
|
||||
|
||||
@htmx_login_required
|
||||
@is_superuser
|
||||
@require_http_methods(["GET"])
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
settings:
|
||||
file_type: qif
|
||||
importing: transactions
|
||||
encoding: cp1252
|
||||
date_format: "%d/%m/%Y"
|
||||
skip_errors: true
|
||||
|
||||
mapping: {}
|
||||
|
||||
deduplicate: []
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user