mirror of
https://github.com/eitchtee/WYGIWYH.git
synced 2026-01-12 12:20:26 +01:00
Compare commits
582 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e4d7c6b1f | ||
|
|
63868514f9 | ||
|
|
9055a24327 | ||
|
|
9dc963ed7b | ||
|
|
49cac0588e | ||
|
|
3b2b6d6473 | ||
|
|
db30bcbeb7 | ||
|
|
a122733a47 | ||
|
|
37f3e4d99a | ||
|
|
d756286135 | ||
|
|
06a7378fd8 | ||
|
|
ab4075c500 | ||
|
|
96318f003d | ||
|
|
1a0412264a | ||
|
|
2588404876 | ||
|
|
fdc273103b | ||
|
|
c015b78cd6 | ||
|
|
50e5492ea1 | ||
|
|
796089cdb3 | ||
|
|
c83b1bf2d6 | ||
|
|
b074ef7929 | ||
|
|
ec7e33b3b0 | ||
|
|
72fedea0db | ||
|
|
0a03745ce6 | ||
|
|
ff4bd79634 | ||
|
|
383b42e26d | ||
|
|
48e43ac031 | ||
|
|
21c60c4059 | ||
|
|
dd6a390e6b | ||
|
|
0c961a8250 | ||
|
|
e28c651973 | ||
|
|
7687ff81c3 | ||
|
|
b2d78c9190 | ||
|
|
b0815e00c7 | ||
|
|
fbe9726338 | ||
|
|
0df3a57a33 | ||
|
|
f86613b17a | ||
|
|
ffa4644e1b | ||
|
|
6611559696 | ||
|
|
b455a0251a | ||
|
|
9d7c3212f1 | ||
|
|
0da3185996 | ||
|
|
6c90e1bb7f | ||
|
|
c6543c0841 | ||
|
|
d4740b8406 | ||
|
|
5a51795e6a | ||
|
|
64d7765357 | ||
|
|
070e11ca77 | ||
|
|
39f66b620a | ||
|
|
ad164866e0 | ||
|
|
05c465cb34 | ||
|
|
92cf526b76 | ||
|
|
639236b890 | ||
|
|
519a85d256 | ||
|
|
700d35b5d5 | ||
|
|
10e51971db | ||
|
|
ec0d5fc121 | ||
|
|
01f91352d6 | ||
|
|
63ce57a315 | ||
|
|
eadeb649a1 | ||
|
|
a2871d5289 | ||
|
|
f2a362bc0f | ||
|
|
2076903740 | ||
|
|
c752c0b16e | ||
|
|
1674766253 | ||
|
|
7ea9d56132 | ||
|
|
3699c6c671 | ||
|
|
d7c255aa14 | ||
|
|
d17b9d5736 | ||
|
|
c7ff6db0bf | ||
|
|
a4c7753f69 | ||
|
|
7e08028557 | ||
|
|
5eaf5086d2 | ||
|
|
c949c6cea0 | ||
|
|
71c0e9a271 | ||
|
|
bc65980511 | ||
|
|
ecdb1a52cc | ||
|
|
afc06582b4 | ||
|
|
07cb0a2a0f | ||
|
|
05ede58c36 | ||
|
|
20b6366a18 | ||
|
|
b0101dae1a | ||
|
|
a3d38ff9e0 | ||
|
|
776e2117a0 | ||
|
|
edcad37926 | ||
|
|
2d51d21035 | ||
|
|
94f5c25829 | ||
|
|
88a5c103e5 | ||
|
|
3dce9e1c55 | ||
|
|
41d8564e8b | ||
|
|
5ee2fd244f | ||
|
|
0545fb7651 | ||
|
|
7bd1d2d751 | ||
|
|
9a4ec449df | ||
|
|
f918351303 | ||
|
|
ef66b3a1e5 | ||
|
|
7486660223 | ||
|
|
00d5ccda34 | ||
|
|
1656eec601 | ||
|
|
64b96ed2f3 | ||
|
|
1f5e4f132d | ||
|
|
edf056b68c | ||
|
|
35865ce21c | ||
|
|
8f06c06d32 | ||
|
|
15eaa2239a | ||
|
|
fd7214df95 | ||
|
|
e531c63de3 | ||
|
|
5a79dd5424 | ||
|
|
315dd1479a | ||
|
|
67f79effab | ||
|
|
c168886968 | ||
|
|
272c34d3b3 | ||
|
|
43ce79ae65 | ||
|
|
4aa29545ec | ||
|
|
fd1fcb832c | ||
|
|
b5fd928a5d | ||
|
|
2dc398f82b | ||
|
|
cf7d4b1404 | ||
|
|
e9c3af1a85 | ||
|
|
b121e8e982 | ||
|
|
606e6b3843 | ||
|
|
6e46b5abb8 | ||
|
|
5b4dab93a1 | ||
|
|
29b6ee3af3 | ||
|
|
484686b709 | ||
|
|
938c128d07 | ||
|
|
8123f7f3cb | ||
|
|
547dc90d9e | ||
|
|
dc33fda5d3 | ||
|
|
92960d1b9a | ||
|
|
1978a467cb | ||
|
|
5bdafbba91 | ||
|
|
16de87376a | ||
|
|
e8e1144fdd | ||
|
|
157f357a7a | ||
|
|
d77eddbd26 | ||
|
|
fb1b383962 | ||
|
|
11998475c5 | ||
|
|
21363e23a1 | ||
|
|
d3a816d91b | ||
|
|
9c92bbd3cf | ||
|
|
c55d688956 | ||
|
|
231b9065c9 | ||
|
|
01ea0de4b3 | ||
|
|
c57fa1630b | ||
|
|
92f7bcfd9e | ||
|
|
58b855f55e | ||
|
|
d4d51301b3 | ||
|
|
aed3fb11fe | ||
|
|
70d427bec4 | ||
|
|
b6f52458db | ||
|
|
8d76c40b7e | ||
|
|
a43e3d158f | ||
|
|
588ae2de6e | ||
|
|
4b97ba681a | ||
|
|
1a903507ad | ||
|
|
bf920df771 | ||
|
|
23ae6f3d54 | ||
|
|
49f28834e9 | ||
|
|
4351027b87 | ||
|
|
c37aa6e059 | ||
|
|
8a5a54dcbd | ||
|
|
24ee8ecd68 | ||
|
|
a14332bb80 | ||
|
|
32747071fe | ||
|
|
24fa9cde51 | ||
|
|
372ec2f30f | ||
|
|
fffba037a6 | ||
|
|
43488147d8 | ||
|
|
31a31e9922 | ||
|
|
7af6280b29 | ||
|
|
40389396e3 | ||
|
|
21845d501e | ||
|
|
5f098e11a3 | ||
|
|
d2de0684fb | ||
|
|
eb4723e890 | ||
|
|
890cc90420 | ||
|
|
307af9e40a | ||
|
|
1eeb0b0f5e | ||
|
|
605ece705e | ||
|
|
2ae57e83cb | ||
|
|
af72e3f44e | ||
|
|
e2e1c5cff5 | ||
|
|
ed3d58f1fd | ||
|
|
b58f894dc6 | ||
|
|
2ed7fa44c0 | ||
|
|
7c3120cd43 | ||
|
|
2bc5e24e51 | ||
|
|
d3f8a637bc | ||
|
|
b02b6451d2 | ||
|
|
0b0d760bab | ||
|
|
b38ed37bc5 | ||
|
|
5d7dd622f5 | ||
|
|
7e37948616 | ||
|
|
2afb6b1f5f | ||
|
|
cd54df6f2d | ||
|
|
3e4ace8993 | ||
|
|
a878af28f1 | ||
|
|
0a4d4c12b9 | ||
|
|
9ade58a003 | ||
|
|
89b2d0118d | ||
|
|
232d5003b8 | ||
|
|
133d70d3d1 | ||
|
|
e70608eaaf | ||
|
|
a63367a772 | ||
|
|
baef86b6cb | ||
|
|
3011b32fa6 | ||
|
|
910decfe00 | ||
|
|
e600d87968 | ||
|
|
dd82289488 | ||
|
|
1e816ec80a | ||
|
|
3b5626cbd1 | ||
|
|
a819ceaa43 | ||
|
|
de28dbb0f0 | ||
|
|
cfb34a4dc3 | ||
|
|
efdcfc192a | ||
|
|
a7856a6671 | ||
|
|
7b8e3b528a | ||
|
|
cc3244a034 | ||
|
|
2121a68c82 | ||
|
|
f35002f862 | ||
|
|
73a992256d | ||
|
|
9f1098d6b9 | ||
|
|
2c0936b7e5 | ||
|
|
5fb717c3fe | ||
|
|
c5f94fb34d | ||
|
|
29cdec4577 | ||
|
|
82efd48e53 | ||
|
|
5a3a0b7e5c | ||
|
|
41a5900f12 | ||
|
|
2dbdd02350 | ||
|
|
fa0cde1a4e | ||
|
|
623d91d26f | ||
|
|
57200437dc | ||
|
|
6f4a2b687c | ||
|
|
8bb40be41c | ||
|
|
66c1cf2371 | ||
|
|
4b23836544 | ||
|
|
585af1270f | ||
|
|
a0cc51b2ec | ||
|
|
6a5de7d94d | ||
|
|
6d9687de0b | ||
|
|
e9acf1dd8f | ||
|
|
698e05bd06 | ||
|
|
90b3778e36 | ||
|
|
85a773bc01 | ||
|
|
355016a7a5 | ||
|
|
f04fcf99b7 | ||
|
|
0fb389e7e8 | ||
|
|
63898aeef0 | ||
|
|
4fdf00d098 | ||
|
|
025cc585d5 | ||
|
|
17018d87cd | ||
|
|
1e5f4f6583 | ||
|
|
a99851cf9b | ||
|
|
9fb1ad4861 | ||
|
|
66c3abfe37 | ||
|
|
8ca64f5820 | ||
|
|
e743821570 | ||
|
|
5c698d8735 | ||
|
|
3e5aa90df0 | ||
|
|
b2add14238 | ||
|
|
a052c00aa8 | ||
|
|
7f343708e0 | ||
|
|
22e95c7f4a | ||
|
|
7645153f77 | ||
|
|
1abfed9abf | ||
|
|
eea0ab009d | ||
|
|
29446def22 | ||
|
|
9dce5e9efe | ||
|
|
695e2cb322 | ||
|
|
b135ec3b15 | ||
|
|
bb3cc5da6c | ||
|
|
ca7fe24a8a | ||
|
|
483ba74010 | ||
|
|
f2abeff31a | ||
|
|
666eaff167 | ||
|
|
d72454f854 | ||
|
|
333aa81923 | ||
|
|
41b8cfd1e7 | ||
|
|
1fa7985b01 | ||
|
|
38392a6322 | ||
|
|
637c62319b | ||
|
|
f91fe67629 | ||
|
|
9eb1818a20 | ||
|
|
50ac679e33 | ||
|
|
2a463c63b8 | ||
|
|
dce65f2faf | ||
|
|
a053cb3947 | ||
|
|
2d43072120 | ||
|
|
70bdee065e | ||
|
|
95db27a32f | ||
|
|
d6d4e6a102 | ||
|
|
bc0f30fead | ||
|
|
a9a86fc491 | ||
|
|
c3b5f2bf39 | ||
|
|
19128e5aed | ||
|
|
9b5c6d3413 | ||
|
|
73c873a2ad | ||
|
|
9d2be22a77 | ||
|
|
6a3d31f37d | ||
|
|
3be3a3c14b | ||
|
|
a5b0f4efb7 | ||
|
|
6da50db417 | ||
|
|
a6c1daf902 | ||
|
|
6a271fb3d7 | ||
|
|
2cf9a9dd0f | ||
|
|
64b32316ca | ||
|
|
0deaabe719 | ||
|
|
b14342af2e | ||
|
|
efe020efb3 | ||
|
|
2c14ce6366 | ||
|
|
8c133f92ce | ||
|
|
2dd887b0d9 | ||
|
|
f3c9d8faea | ||
|
|
8be7758dc0 | ||
|
|
8f5204a17b | ||
|
|
05dd782df5 | ||
|
|
187fe43283 | ||
|
|
cae73376db | ||
|
|
7225454a6e | ||
|
|
70c8c1e07c | ||
|
|
2235bdeabb | ||
|
|
d724300513 | ||
|
|
eacafa1def | ||
|
|
c738f5ee29 | ||
|
|
c392a2c988 | ||
|
|
17ea859fd2 | ||
|
|
8aae6f928f | ||
|
|
7c43b06b9f | ||
|
|
72904266bf | ||
|
|
e16e279911 | ||
|
|
670bee4325 | ||
|
|
3e2c1184ce | ||
|
|
731f351eef | ||
|
|
b7056e7aa1 | ||
|
|
accceed630 | ||
|
|
76346cb503 | ||
|
|
3df8952ea2 | ||
|
|
9bd067da96 | ||
|
|
1abe9e9f62 | ||
|
|
1a86b5dea4 | ||
|
|
8f2f5a16c2 | ||
|
|
4565dc770b | ||
|
|
23673def09 | ||
|
|
dd2b9ead7e | ||
|
|
2078e9f3e4 | ||
|
|
e6bab57ab4 | ||
|
|
38d50a78f4 | ||
|
|
0d947f9ba6 | ||
|
|
99c85a56bb | ||
|
|
ab1c074f27 | ||
|
|
abf3a148cc | ||
|
|
2733c92da5 | ||
|
|
9bfbe54ed5 | ||
|
|
5b27dea07c | ||
|
|
791e1000a3 | ||
|
|
7301d9f475 | ||
|
|
47a44e96f8 | ||
|
|
7d247eb737 | ||
|
|
373616e7bb | ||
|
|
bf3c11d582 | ||
|
|
b4b1c10db9 | ||
|
|
5ca531f47d | ||
|
|
07673cb528 | ||
|
|
67c6d81897 | ||
|
|
3c85da46b0 | ||
|
|
d263936be7 | ||
|
|
1524063d5a | ||
|
|
c3a403b8f0 | ||
|
|
1c1018adae | ||
|
|
350d5adf25 | ||
|
|
7e4defb9cc | ||
|
|
7121e4bc09 | ||
|
|
4540e48fe5 | ||
|
|
d06b51421f | ||
|
|
e096912e41 | ||
|
|
f0ad6e16fe | ||
|
|
734a302fa7 | ||
|
|
89b1b7bcb7 | ||
|
|
37b40f89bb | ||
|
|
0c63552d1b | ||
|
|
7db517e848 | ||
|
|
7e3ed6cf94 | ||
|
|
e10a88c00e | ||
|
|
b912a33b93 | ||
|
|
d9fb3627cc | ||
|
|
78ffa68ba4 | ||
|
|
37f4ead058 | ||
|
|
61630fca5b | ||
|
|
910d4c84a3 | ||
|
|
be1f29d8c1 | ||
|
|
9784d840cc | ||
|
|
db5ce13ff3 | ||
|
|
a2b943d949 | ||
|
|
d8098b4486 | ||
|
|
f8cff6623f | ||
|
|
65c61f76ff | ||
|
|
74899f63ab | ||
|
|
66a5e6d613 | ||
|
|
e0ab32ec03 | ||
|
|
a912e4a511 | ||
|
|
57ba672c91 | ||
|
|
20c6989ffb | ||
|
|
c6cd525c49 | ||
|
|
55c4b920ee | ||
|
|
7f8261b9cc | ||
|
|
9102654eab | ||
|
|
1ff49a8a04 | ||
|
|
846dd1fd73 | ||
|
|
9eed3b6692 | ||
|
|
b7c53a3c2d | ||
|
|
b378c8f6f7 | ||
|
|
ccc4deb1d8 | ||
|
|
d3ecf55375 | ||
|
|
580f3e7345 | ||
|
|
0e5843094b | ||
|
|
ed65945d19 | ||
|
|
18d8837c64 | ||
|
|
067d819077 | ||
|
|
bbaae4746a | ||
|
|
d2e5c1d6b3 | ||
|
|
ffef61d514 | ||
|
|
9020f6f972 | ||
|
|
540235c1b0 | ||
|
|
9070bc5705 | ||
|
|
ba5a6c9772 | ||
|
|
2f33b5989f | ||
|
|
5f24d05540 | ||
|
|
31cf62e277 | ||
|
|
15d990007e | ||
|
|
3d5bc9cd3f | ||
|
|
a544dc4943 | ||
|
|
b1178198e9 | ||
|
|
02a488bfff | ||
|
|
b05285947b | ||
|
|
d7b7dd28c7 | ||
|
|
9353d498ef | ||
|
|
4f6903e8e4 | ||
|
|
7d3d6ea2fc | ||
|
|
cce9c7a7a5 | ||
|
|
f80f74a01a | ||
|
|
df47ffc49c | ||
|
|
4f35647a22 | ||
|
|
368342853f | ||
|
|
5a675f674d | ||
|
|
9ef8fdec49 | ||
|
|
f29a8d8bc0 | ||
|
|
8c43365ec0 | ||
|
|
2cdcc4ee26 | ||
|
|
84852012f9 | ||
|
|
edf0e2c66f | ||
|
|
f90a31f2b9 | ||
|
|
dd1f6a6ef2 | ||
|
|
57f98ba171 | ||
|
|
f2e93f7df9 | ||
|
|
26cfa493b3 | ||
|
|
c6e003ed86 | ||
|
|
c09ad0e49d | ||
|
|
9250131396 | ||
|
|
5f503149ce | ||
|
|
d45b4f2942 | ||
|
|
4a8493c7d9 | ||
|
|
c39c3ccacb | ||
|
|
4bb8ef6582 | ||
|
|
d711ccca69 | ||
|
|
76d59f1038 | ||
|
|
5b6c123fa1 | ||
|
|
782ab11ae4 | ||
|
|
8db885f47d | ||
|
|
01bd8710d8 | ||
|
|
569d08711c | ||
|
|
a285f055e4 | ||
|
|
6aae9b1207 | ||
|
|
9d2206f8a4 | ||
|
|
d7e3c50c2c | ||
|
|
789fd4eb80 | ||
|
|
586b3a5d44 | ||
|
|
9248e8bd77 | ||
|
|
c44247f6a5 | ||
|
|
8ba89434f8 | ||
|
|
f2f41981a3 | ||
|
|
1153fd6b0a | ||
|
|
76822224a0 | ||
|
|
31b2b98eb9 | ||
|
|
d7a4e79321 | ||
|
|
985f07e792 | ||
|
|
5465bb1eeb | ||
|
|
451a85a998 | ||
|
|
54c74e7c07 | ||
|
|
d6e9e123b7 | ||
|
|
80c9c43a02 | ||
|
|
3e34f088fc | ||
|
|
5b9e5c6003 | ||
|
|
c266b8809f | ||
|
|
8cda4116bc | ||
|
|
c2510b2261 | ||
|
|
dcdaf756f9 | ||
|
|
50ca08165a | ||
|
|
f85618fa01 | ||
|
|
635f87a8ad | ||
|
|
1a073ba53d | ||
|
|
5412e5b12c | ||
|
|
2103ba1b38 | ||
|
|
04fb15224c | ||
|
|
2fc526beac | ||
|
|
cc3ca4f4a3 | ||
|
|
8d3844c431 | ||
|
|
5e7e918085 | ||
|
|
c3f02320b5 | ||
|
|
da8bbbfb0b | ||
|
|
e3f74538d2 | ||
|
|
d8234950c6 | ||
|
|
58f19ce1ca | ||
|
|
ef5f3580a0 | ||
|
|
efe0f99cb4 | ||
|
|
dccb5079ad | ||
|
|
6c90150661 | ||
|
|
c33d6fab69 | ||
|
|
c0c57a6d77 | ||
|
|
f19d58a2bd | ||
|
|
dfe99093e9 | ||
|
|
d737e573cc | ||
|
|
805d3f419e | ||
|
|
9777aac746 | ||
|
|
61b782104d | ||
|
|
79dec2b515 | ||
|
|
db23e162c4 | ||
|
|
d81d89d9f6 | ||
|
|
6826cfe79a | ||
|
|
0832ec75ca | ||
|
|
3090f632de | ||
|
|
6b4fbee7a6 | ||
|
|
e7fe6622cd | ||
|
|
3017593ed5 | ||
|
|
ceb8e9ea97 | ||
|
|
9b5b7683dd | ||
|
|
514600e34a | ||
|
|
07dd805b07 | ||
|
|
905e9b4c54 | ||
|
|
60d367dec5 | ||
|
|
6e0842a697 | ||
|
|
858934b7c5 | ||
|
|
47d9e4078c | ||
|
|
fa6f3e87c0 | ||
|
|
5f101af879 | ||
|
|
b27633a28e | ||
|
|
7716eee0b3 | ||
|
|
37c447ae0a | ||
|
|
e544d7068b | ||
|
|
8d93da99c1 | ||
|
|
cc87477a2e | ||
|
|
e86e0b8c08 | ||
|
|
eb0c872c50 | ||
|
|
b4578df242 | ||
|
|
756de12835 | ||
|
|
d573d02657 | ||
|
|
250b352217 | ||
|
|
b4e9446cf6 | ||
|
|
90944f0179 | ||
|
|
008d34b1d0 | ||
|
|
46dfc7dad4 | ||
|
|
22900b5d9e | ||
|
|
0c48e9fe3d | ||
|
|
b2e100d1b0 | ||
|
|
e49b38a442 | ||
|
|
1f2902eea9 | ||
|
|
7d60db8716 | ||
|
|
873b0baed7 | ||
|
|
2313c97761 | ||
|
|
9cd7337153 | ||
|
|
d3b354e2b8 | ||
|
|
e137666e99 | ||
|
|
4291a5b97d | ||
|
|
c8d316857f | ||
|
|
3395a96949 | ||
|
|
8ab9624619 | ||
|
|
f9056c3a45 | ||
|
|
a9df684ee2 | ||
|
|
e4d07c94d4 | ||
|
|
5d5d172b3b | ||
|
|
99f746b6be |
1
.dockerignore
Normal file
1
.dockerignore
Normal file
@@ -0,0 +1 @@
|
||||
__pycache__/
|
||||
@@ -31,3 +31,10 @@ ENABLE_SOFT_DELETE=false
|
||||
KEEP_DELETED_TRANSACTIONS_FOR=365
|
||||
|
||||
TASK_WORKERS=1 # This only work if you're using the single container option. Increase to have more open queues via procrastinate, you probably don't need to increase this.
|
||||
|
||||
# OIDC Configuration. Uncomment the lines below if you want to add OIDC login to your instance
|
||||
#OIDC_CLIENT_NAME=""
|
||||
#OIDC_CLIENT_ID=""
|
||||
#OIDC_CLIENT_SECRET=""
|
||||
#OIDC_SERVER_URL=""
|
||||
#OIDC_ALLOW_SIGNUP=true
|
||||
|
||||
74
.github/workflows/release.yml
vendored
74
.github/workflows/release.yml
vendored
@@ -12,7 +12,7 @@ on:
|
||||
required: true
|
||||
type: string
|
||||
ref:
|
||||
description: 'Git ref to checkout (branch, tag, or SHA)'
|
||||
description: 'Git ref to checkout'
|
||||
required: true
|
||||
default: 'main'
|
||||
type: string
|
||||
@@ -29,73 +29,57 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write # Needed if you switch to GHCR, good practice
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.inputs.ref }}
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
|
||||
- name: Checkout code (non-manual)
|
||||
uses: actions/checkout@v4
|
||||
if: github.event_name != 'workflow_dispatch'
|
||||
ref: ${{ inputs.ref || github.ref }}
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# This action handles all the logic for tags (nightly vs release vs custom)
|
||||
- name: Docker Metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
# Logic for Push to Main -> nightly
|
||||
type=raw,value=nightly,enable=${{ github.event_name == 'push' }}
|
||||
# Logic for Release -> semver and latest
|
||||
type=semver,pattern={{version}},enable=${{ github.event_name == 'release' }}
|
||||
type=raw,value=latest,enable=${{ github.event_name == 'release' }}
|
||||
# Logic for Manual Dispatch -> custom input
|
||||
type=raw,value=${{ inputs.tag }},enable=${{ github.event_name == 'workflow_dispatch' }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build and push nightly image
|
||||
if: github.event_name == 'push'
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/prod/django/Dockerfile
|
||||
push: true
|
||||
provenance: false
|
||||
# Pass the calculated tags from the meta step
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
VERSION=nightly
|
||||
tags: ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:nightly
|
||||
VERSION=${{ steps.meta.outputs.version }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Build and push release image
|
||||
if: github.event_name == 'release'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/prod/django/Dockerfile
|
||||
push: true
|
||||
provenance: false
|
||||
build-args: |
|
||||
VERSION=${{ github.event.release.tag_name }}
|
||||
tags: |
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:latest
|
||||
${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:${{ github.event.release.tag_name }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Build and push custom image
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/prod/django/Dockerfile
|
||||
push: true
|
||||
provenance: false
|
||||
build-args: |
|
||||
VERSION=${{ github.event.inputs.tag }}
|
||||
tags: ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:${{ github.event.inputs.tag }}
|
||||
platforms: linux/amd64,linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
# --- CACHE CONFIGURATION ---
|
||||
# We set a specific 'scope' key.
|
||||
# This allows the Release tag to see the cache created by the Main branch.
|
||||
cache-from: type=gha,scope=build-cache
|
||||
cache-to: type=gha,mode=max,scope=build-cache
|
||||
|
||||
15
.github/workflows/translations.yml
vendored
15
.github/workflows/translations.yml
vendored
@@ -32,15 +32,16 @@ jobs:
|
||||
token: ${{ secrets.PAT }}
|
||||
ref: ${{ github.head_ref }}
|
||||
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v4
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
enable-cache: true
|
||||
|
||||
- name: Set up Python 3.11
|
||||
run: uv python install 3.11
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
run: uv sync --frozen --no-dev
|
||||
|
||||
- name: Install gettext
|
||||
run: sudo apt-get install -y gettext
|
||||
@@ -48,7 +49,7 @@ jobs:
|
||||
- name: Run makemessages
|
||||
run: |
|
||||
cd app
|
||||
python manage.py makemessages -a
|
||||
uv run python manage.py makemessages -a
|
||||
|
||||
- name: Check for changes
|
||||
id: check_changes
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -123,6 +123,7 @@ celerybeat.pid
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.prod.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
@@ -160,3 +161,7 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
node_modules/
|
||||
postgres_data/
|
||||
.prod.env
|
||||
|
||||
8
.vscode/settings.json
vendored
Normal file
8
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"djlint.showInstallError": false,
|
||||
"files.associations": {
|
||||
"*.css": "tailwindcss"
|
||||
},
|
||||
"tailwindCSS.experimental.configFile": "frontend/src/styles/tailwind.css",
|
||||
"djlint.profile": "django",
|
||||
}
|
||||
47
README.md
47
README.md
@@ -13,6 +13,7 @@
|
||||
<a href="#key-features">Features</a> •
|
||||
<a href="#how-to-use">Usage</a> •
|
||||
<a href="#how-it-works">How</a> •
|
||||
<a href="#mcp-server">MCP Server</a> •
|
||||
<a href="#help-us-translate-wygiwyh">Translate</a> •
|
||||
<a href="#caveats-and-warnings">Caveats and Warnings</a> •
|
||||
<a href="#built-with">Built with</a>
|
||||
@@ -29,15 +30,15 @@ Managing money can feel unnecessarily complex, but it doesn’t have to be. WYGI
|
||||
|
||||
By sticking to this straightforward approach, you avoid dipping into your savings while still keeping tabs on where your money goes.
|
||||
|
||||
While this philosophy is simple, finding tools to make it work wasn’t. I initially used a spreadsheet, which served me well for years—until it became unwieldy as I started managing multiple currencies, accounts, and investments. I tried various financial management apps, but none met my key requirements:
|
||||
While this philosophy is simple, finding tools to make it work wasn’t. I initially used a spreadsheet, which served me well for years, until it became unwieldy as I started managing multiple currencies, accounts, and investments. I tried various financial management apps, but none met my key requirements:
|
||||
|
||||
1. **Multi-currency support** to track income and expenses in different currencies.
|
||||
2. **Not a budgeting app** — as I dislike budgeting constraints.
|
||||
2. **Not a budgeting app** as I dislike budgeting constraints.
|
||||
3. **Web app usability** (ideally with mobile support, though optional).
|
||||
4. **Automation-ready API** to integrate with other tools and services.
|
||||
5. **Custom transaction rules** for credit card billing cycles or similar quirks.
|
||||
|
||||
Frustrated by the lack of comprehensive options, I set out to build **WYGIWYH** — an opinionated yet powerful tool that I believe will resonate with like-minded users.
|
||||
Frustrated by the lack of comprehensive options, I set out to build **WYGIWYH**, an opinionated yet powerful tool that I believe will resonate with like-minded users.
|
||||
|
||||
# Key Features
|
||||
|
||||
@@ -126,6 +127,7 @@ To create the first user, open the container's console using Unraid's UI, by cli
|
||||
|
||||
| variable | type | default | explanation |
|
||||
|-------------------------------|-------------|-----------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| INTERNAL_PORT | int | 8000 | The port on which the app listens on. Defaults to 8000 if not set. |
|
||||
| DJANGO_ALLOWED_HOSTS | string | localhost 127.0.0.1 | A list of space separated domains and IPs representing the host/domain names that WYGIWYH site can serve. [Click here](https://docs.djangoproject.com/en/5.1/ref/settings/#allowed-hosts) for more details |
|
||||
| HTTPS_ENABLED | true\|false | false | Whether to use secure cookies. If this is set to true, the cookie will be marked as “secure”, which means browsers may ensure that the cookie is only sent under an HTTPS connection |
|
||||
| URL | string | http://localhost http://127.0.0.1 | A list of space separated domains and IPs (with the protocol) representing the trusted origins for unsafe requests (e.g. POST). [Click here](https://docs.djangoproject.com/en/5.1/ref/settings/#csrf-trusted-origins ) for more details |
|
||||
@@ -140,9 +142,38 @@ To create the first user, open the container's console using Unraid's UI, by cli
|
||||
| ENABLE_SOFT_DELETE | true\|false | false | Whether to enable transactions soft delete, if enabled, deleted transactions will remain in the database. Useful for imports and avoiding duplicate entries. |
|
||||
| KEEP_DELETED_TRANSACTIONS_FOR | int | 365 | Time in days to keep soft deleted transactions for. If 0, will keep all transactions indefinitely. Only works if ENABLE_SOFT_DELETE is true. |
|
||||
| TASK_WORKERS | int | 1 | How many workers to have for async tasks. One should be enough for most use cases |
|
||||
| DEMO | true\|false | false | If demo mode is enabled. |
|
||||
| ADMIN_EMAIL | string | None | Automatically creates an admin account with this email. Must have `ADMIN_PASSWORD` also set. |
|
||||
| ADMIN_PASSWORD | string | None | Automatically creates an admin account with this password. Must have `ADMIN_EMAIL` also set. |
|
||||
| DEMO | true\|false | false | If demo mode is enabled. |
|
||||
| ADMIN_EMAIL | string | None | Automatically creates an admin account with this email. Must have `ADMIN_PASSWORD` also set. |
|
||||
| ADMIN_PASSWORD | string | None | Automatically creates an admin account with this password. Must have `ADMIN_EMAIL` also set. |
|
||||
| CHECK_FOR_UPDATES | true\|false | true | Check and notify users about new versions. The check is done by doing a single query to Github's API every 12 hours. |
|
||||
| DJANGO_VITE_DEV_MODE | true\|false | false | Enables Vite dev server mode for frontend development. When true, assets are served from Vite's dev server instead of the build manifest. For development only! |
|
||||
| DJANGO_VITE_DEV_SERVER_PORT | int | 5173 | The port where Vite's dev server is running. Only used when DJANGO_VITE_DEV_MODE is true. For development only! |
|
||||
| DJANGO_VITE_DEV_SERVER_HOST | string | localhost | The host where Vite's dev server is running. Only used when DJANGO_VITE_DEV_MODE is true. For development only! |
|
||||
|
||||
## OIDC Configuration
|
||||
|
||||
WYGIWYH supports login via OpenID Connect (OIDC) through `django-allauth`. This allows users to authenticate using an external OIDC provider.
|
||||
|
||||
> [!NOTE]
|
||||
> Currently only OpenID Connect is supported as a provider, open an issue if you need something else.
|
||||
|
||||
To configure OIDC, you need to set the following environment variables:
|
||||
|
||||
| Variable | Description |
|
||||
|----------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `OIDC_CLIENT_NAME` | The name of the provider. will be displayed in the login page. Defaults to `OpenID Connect` |
|
||||
| `OIDC_CLIENT_ID` | The Client ID provided by your OIDC provider. |
|
||||
| `OIDC_CLIENT_SECRET` | The Client Secret provided by your OIDC provider. |
|
||||
| `OIDC_SERVER_URL` | The base URL of your OIDC provider's discovery document or authorization server (e.g., `https://your-provider.com/auth/realms/your-realm`). `django-allauth` will use this to discover the necessary endpoints (authorization, token, userinfo, etc.). |
|
||||
| `OIDC_ALLOW_SIGNUP` | Allow the automatic creation of inexistent accounts on a successfull authentication. Defaults to `true`. |
|
||||
|
||||
**Callback URL (Redirect URI):**
|
||||
|
||||
When configuring your OIDC provider, you will need to provide a callback URL (also known as a Redirect URI). For WYGIWYH, the default callback URL is:
|
||||
|
||||
`https://your.wygiwyh.domain/auth/oidc/<OIDC_CLIENT_NAME>/login/callback/`
|
||||
|
||||
Replace `https://your.wygiwyh.domain` with the actual URL where your WYGIWYH instance is accessible. And `<OIDC_CLIENT_NAME>` with the slugfied value set in OIDC_CLIENT_NAME or the default `openid-connect` if you haven't set this variable.
|
||||
|
||||
# How it works
|
||||
|
||||
@@ -156,6 +187,10 @@ Check out our [Wiki](https://github.com/eitchtee/WYGIWYH/wiki) for more informat
|
||||
> [!NOTE]
|
||||
> Login with your github account
|
||||
|
||||
# MCP Server
|
||||
|
||||
[IZIme07](https://github.com/IZIme07) has kindly created an MCP Server for WYGIWYH that you can self-host. [Check it out at MCP-WYGIWYH](https://github.com/ReNewator/MCP-WYGIWYH)!
|
||||
|
||||
# Caveats and Warnings
|
||||
|
||||
- I'm not an accountant, some terms and even calculations might be wrong. Make sure to open an issue if you see anything that could be improved.
|
||||
|
||||
@@ -11,9 +11,11 @@ https://docs.djangoproject.com/en/5.1/ref/settings/
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from django.utils.text import slugify
|
||||
|
||||
SITE_TITLE = "WYGIWYH"
|
||||
TITLE_SEPARATOR = "::"
|
||||
@@ -42,9 +44,10 @@ INSTALLED_APPS = [
|
||||
"django.contrib.contenttypes",
|
||||
"django.contrib.sessions",
|
||||
"django.contrib.messages",
|
||||
"django.contrib.sites",
|
||||
"whitenoise.runserver_nostatic",
|
||||
"django.contrib.staticfiles",
|
||||
"webpack_boilerplate",
|
||||
"django_vite",
|
||||
"django.contrib.humanize",
|
||||
"django.contrib.postgres",
|
||||
"django_browser_reload",
|
||||
@@ -61,21 +64,28 @@ INSTALLED_APPS = [
|
||||
"apps.transactions.apps.TransactionsConfig",
|
||||
"apps.currencies.apps.CurrenciesConfig",
|
||||
"apps.accounts.apps.AccountsConfig",
|
||||
"apps.common.apps.CommonConfig",
|
||||
"apps.net_worth.apps.NetWorthConfig",
|
||||
"apps.import_app.apps.ImportConfig",
|
||||
"apps.export_app.apps.ExportConfig",
|
||||
"apps.api.apps.ApiConfig",
|
||||
"cachalot",
|
||||
"rest_framework",
|
||||
"rest_framework.authtoken",
|
||||
"drf_spectacular",
|
||||
"django_cotton",
|
||||
"apps.rules.apps.RulesConfig",
|
||||
"apps.calendar_view.apps.CalendarViewConfig",
|
||||
"apps.dca.apps.DcaConfig",
|
||||
"pwa",
|
||||
"allauth",
|
||||
"allauth.account",
|
||||
"allauth.socialaccount",
|
||||
"allauth.socialaccount.providers.openid_connect",
|
||||
"apps.common.apps.CommonConfig",
|
||||
]
|
||||
|
||||
SITE_ID = 1
|
||||
|
||||
MIDDLEWARE = [
|
||||
"django_browser_reload.middleware.BrowserReloadMiddleware",
|
||||
"apps.common.middleware.thread_local.ThreadLocalMiddleware",
|
||||
@@ -91,6 +101,7 @@ MIDDLEWARE = [
|
||||
"django.contrib.messages.middleware.MessageMiddleware",
|
||||
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
||||
"hijack.middleware.HijackUserMiddleware",
|
||||
"allauth.account.middleware.AccountMiddleware",
|
||||
]
|
||||
|
||||
ROOT_URLCONF = "WYGIWYH.urls"
|
||||
@@ -119,12 +130,23 @@ STORAGES = {
|
||||
|
||||
WHITENOISE_MANIFEST_STRICT = False
|
||||
|
||||
|
||||
def immutable_file_test(path, url):
|
||||
# Match vite (rollup)-generated hashes, à la, `some_file-CSliV9zW.js`
|
||||
return re.match(r"^.+[.-][0-9a-zA-Z_-]{8,12}\..+$", url)
|
||||
|
||||
|
||||
WHITENOISE_IMMUTABLE_FILE_TEST = immutable_file_test
|
||||
|
||||
WSGI_APPLICATION = "WYGIWYH.wsgi.application"
|
||||
|
||||
|
||||
# Database
|
||||
# https://docs.djangoproject.com/en/5.1/ref/settings/#databases
|
||||
|
||||
THREADS = int(os.getenv("GUNICORN_THREADS", 1))
|
||||
MAX_POOL_SIZE = THREADS + 1
|
||||
|
||||
DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
@@ -133,6 +155,17 @@ DATABASES = {
|
||||
"PASSWORD": os.getenv("SQL_PASSWORD", "password"),
|
||||
"HOST": os.getenv("SQL_HOST", "localhost"),
|
||||
"PORT": os.getenv("SQL_PORT", "5432"),
|
||||
"CONN_MAX_AGE": 0,
|
||||
"CONN_HEALTH_CHECKS": True,
|
||||
"OPTIONS": {
|
||||
"pool": {
|
||||
"min_size": 1,
|
||||
"max_size": MAX_POOL_SIZE,
|
||||
"timeout": 10,
|
||||
"max_lifetime": 600,
|
||||
"max_idle": 300,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,7 +313,7 @@ STATIC_URL = "static/"
|
||||
STATIC_ROOT = BASE_DIR / "static_files"
|
||||
|
||||
STATICFILES_DIRS = [
|
||||
ROOT_DIR / "frontend/build",
|
||||
ROOT_DIR / "frontend" / "build",
|
||||
BASE_DIR / "static",
|
||||
]
|
||||
|
||||
@@ -296,9 +329,11 @@ CACHES = {
|
||||
}
|
||||
}
|
||||
|
||||
WEBPACK_LOADER = {
|
||||
"MANIFEST_FILE": ROOT_DIR / "frontend/build/manifest.json",
|
||||
}
|
||||
DJANGO_VITE_ASSETS_PATH = STATIC_ROOT
|
||||
DJANGO_VITE_MANIFEST_PATH = DJANGO_VITE_ASSETS_PATH / "manifest.json"
|
||||
DJANGO_VITE_DEV_MODE = os.getenv("DJANGO_VITE_DEV_MODE", "false").lower() == "true"
|
||||
DJANGO_VITE_DEV_SERVER_PORT = int(os.getenv("DJANGO_VITE_DEV_SERVER_PORT", "5173"))
|
||||
DJANGO_VITE_DEV_SERVER_HOST = os.getenv("DJANGO_VITE_DEV_SERVER_HOST", "localhost")
|
||||
|
||||
# Default primary key field type
|
||||
# https://docs.djangoproject.com/en/5.1/ref/settings/#default-auto-field
|
||||
@@ -307,10 +342,49 @@ DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
|
||||
|
||||
LOGIN_REDIRECT_URL = "/"
|
||||
LOGIN_URL = "/login/"
|
||||
LOGOUT_REDIRECT_URL = "/login/"
|
||||
|
||||
# Allauth settings
|
||||
AUTHENTICATION_BACKENDS = [
|
||||
"django.contrib.auth.backends.ModelBackend", # Keep default
|
||||
"allauth.account.auth_backends.AuthenticationBackend",
|
||||
]
|
||||
|
||||
SOCIALACCOUNT_PROVIDERS = {"openid_connect": {"APPS": []}}
|
||||
|
||||
if (
|
||||
os.getenv("OIDC_CLIENT_ID")
|
||||
and os.getenv("OIDC_CLIENT_SECRET")
|
||||
and os.getenv("OIDC_SERVER_URL")
|
||||
):
|
||||
SOCIALACCOUNT_PROVIDERS["openid_connect"]["APPS"].append(
|
||||
{
|
||||
"provider_id": slugify(os.getenv("OIDC_CLIENT_NAME", "OpenID Connect")),
|
||||
"name": os.getenv("OIDC_CLIENT_NAME", "OpenID Connect"),
|
||||
"client_id": os.getenv("OIDC_CLIENT_ID"),
|
||||
"secret": os.getenv("OIDC_CLIENT_SECRET"),
|
||||
"settings": {
|
||||
"server_url": os.getenv("OIDC_SERVER_URL"),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
ACCOUNT_LOGIN_METHODS = {"email"}
|
||||
ACCOUNT_SIGNUP_FIELDS = ["email*", "password1*", "password2*"]
|
||||
ACCOUNT_USER_MODEL_USERNAME_FIELD = None
|
||||
ACCOUNT_EMAIL_VERIFICATION = "none"
|
||||
SOCIALACCOUNT_LOGIN_ON_GET = True
|
||||
SOCIALACCOUNT_ONLY = True
|
||||
SOCIALACCOUNT_AUTO_SIGNUP = os.getenv("OIDC_ALLOW_SIGNUP", "true").lower() == "true"
|
||||
ACCOUNT_ADAPTER = "allauth.account.adapter.DefaultAccountAdapter"
|
||||
SOCIALACCOUNT_ADAPTER = "allauth.socialaccount.adapter.DefaultSocialAccountAdapter"
|
||||
|
||||
# CRISPY FORMS
|
||||
CRISPY_ALLOWED_TEMPLATE_PACKS = ["bootstrap5", "crispy_forms/pure_text"]
|
||||
CRISPY_TEMPLATE_PACK = "bootstrap5"
|
||||
CRISPY_ALLOWED_TEMPLATE_PACKS = [
|
||||
"crispy_forms/pure_text",
|
||||
"crispy-daisyui",
|
||||
]
|
||||
CRISPY_TEMPLATE_PACK = "crispy-daisyui"
|
||||
|
||||
SESSION_EXPIRE_AT_BROWSER_CLOSE = False
|
||||
SESSION_COOKIE_AGE = int(os.getenv("SESSION_EXPIRY_TIME", 2678400)) # 31 days
|
||||
@@ -334,7 +408,7 @@ DEBUG_TOOLBAR_PANELS = [
|
||||
"debug_toolbar.panels.signals.SignalsPanel",
|
||||
"debug_toolbar.panels.redirects.RedirectsPanel",
|
||||
"debug_toolbar.panels.profiling.ProfilingPanel",
|
||||
"cachalot.panels.CachalotPanel",
|
||||
# "cachalot.panels.CachalotPanel",
|
||||
]
|
||||
INTERNAL_IPS = [
|
||||
"127.0.0.1",
|
||||
@@ -360,8 +434,16 @@ REST_FRAMEWORK = {
|
||||
"apps.api.permissions.NotInDemoMode",
|
||||
"rest_framework.permissions.DjangoModelPermissions",
|
||||
],
|
||||
"DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination",
|
||||
"PAGE_SIZE": 10,
|
||||
'DEFAULT_FILTER_BACKENDS': [
|
||||
'django_filters.rest_framework.DjangoFilterBackend',
|
||||
'rest_framework.filters.OrderingFilter',
|
||||
],
|
||||
'DEFAULT_AUTHENTICATION_CLASSES': [
|
||||
'rest_framework.authentication.BasicAuthentication',
|
||||
'rest_framework.authentication.SessionAuthentication',
|
||||
'rest_framework.authentication.TokenAuthentication',
|
||||
],
|
||||
"DEFAULT_PAGINATION_CLASS": "apps.api.custom.pagination.CustomPageNumberPagination",
|
||||
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
|
||||
}
|
||||
|
||||
@@ -442,6 +524,8 @@ else:
|
||||
|
||||
CACHALOT_UNCACHABLE_TABLES = ("django_migrations", "procrastinate_jobs")
|
||||
|
||||
# Procrastinate
|
||||
PROCRASTINATE_ON_APP_READY = "apps.common.procrastinate.on_app_ready"
|
||||
|
||||
# PWA
|
||||
PWA_APP_NAME = SITE_TITLE
|
||||
@@ -490,6 +574,7 @@ PWA_APP_SCREENSHOTS = [
|
||||
PWA_SERVICE_WORKER_PATH = BASE_DIR / "templates" / "pwa" / "serviceworker.js"
|
||||
|
||||
ENABLE_SOFT_DELETE = os.getenv("ENABLE_SOFT_DELETE", "false").lower() == "true"
|
||||
CHECK_FOR_UPDATES = os.getenv("CHECK_FOR_UPDATES", "true").lower() == "true"
|
||||
KEEP_DELETED_TRANSACTIONS_FOR = int(os.getenv("KEEP_DELETED_ENTRIES_FOR", "365"))
|
||||
APP_VERSION = os.getenv("APP_VERSION", "unknown")
|
||||
DEMO = os.getenv("DEMO", "false").lower() == "true"
|
||||
|
||||
@@ -21,6 +21,8 @@ from drf_spectacular.views import (
|
||||
SpectacularAPIView,
|
||||
SpectacularSwaggerView,
|
||||
)
|
||||
from allauth.socialaccount.providers.openid_connect.views import login, callback
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
path("admin/", admin.site.urls),
|
||||
@@ -36,6 +38,13 @@ urlpatterns = [
|
||||
SpectacularSwaggerView.as_view(url_name="schema"),
|
||||
name="swagger-ui",
|
||||
),
|
||||
path("auth/", include("allauth.urls")), # allauth urls
|
||||
# path("auth/oidc/<str:provider_id>/login/", login, name="openid_connect_login"),
|
||||
# path(
|
||||
# "auth/oidc/<str:provider_id>/login/callback/",
|
||||
# callback,
|
||||
# name="openid_connect_callback",
|
||||
# ),
|
||||
path("", include("apps.transactions.urls")),
|
||||
path("", include("apps.common.urls")),
|
||||
path("", include("apps.users.urls")),
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
from crispy_bootstrap5.bootstrap5 import Switch
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.common.fields.forms.dynamic_select import (
|
||||
DynamicModelChoiceField,
|
||||
DynamicModelMultipleChoiceField,
|
||||
)
|
||||
from apps.common.widgets.crispy.daisyui import Switch
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import TransactionCategory, TransactionTag
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Column, Row
|
||||
from crispy_forms.layout import Column, Field, Layout, Row
|
||||
from django import forms
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.accounts.models import AccountGroup
|
||||
from apps.common.fields.forms.dynamic_select import (
|
||||
DynamicModelMultipleChoiceField,
|
||||
DynamicModelChoiceField,
|
||||
)
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.transactions.models import TransactionCategory, TransactionTag
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
|
||||
|
||||
class AccountGroupForm(forms.ModelForm):
|
||||
class Meta:
|
||||
@@ -36,17 +36,13 @@ class AccountGroupForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -79,6 +75,18 @@ class AccountForm(forms.ModelForm):
|
||||
|
||||
self.fields["group"].queryset = AccountGroup.objects.all()
|
||||
|
||||
if self.instance.id:
|
||||
qs = Currency.objects.filter(
|
||||
Q(is_archived=False) | Q(accounts=self.instance.id)
|
||||
).distinct()
|
||||
self.fields["currency"].queryset = qs
|
||||
self.fields["exchange_currency"].queryset = qs
|
||||
|
||||
else:
|
||||
qs = Currency.objects.filter(Q(is_archived=False))
|
||||
self.fields["currency"].queryset = qs
|
||||
self.fields["exchange_currency"].queryset = qs
|
||||
|
||||
self.helper = FormHelper()
|
||||
self.helper.form_tag = False
|
||||
self.helper.form_method = "post"
|
||||
@@ -94,17 +102,13 @@ class AccountForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -142,9 +146,8 @@ class AccountBalanceForm(forms.Form):
|
||||
self.helper.layout = Layout(
|
||||
"new_balance",
|
||||
Row(
|
||||
Column("category", css_class="form-group col-md-6 mb-0"),
|
||||
Column("tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("category"),
|
||||
Column("tags"),
|
||||
),
|
||||
Field("account_id"),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
# Generated by Django 5.2.4 on 2025-07-28 02:15
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('accounts', '0014_alter_account_options_alter_accountgroup_options'),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='account',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='%(class)s_owned', to=settings.AUTH_USER_MODEL, verbose_name='Owner'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='account',
|
||||
name='shared_with',
|
||||
field=models.ManyToManyField(blank=True, related_name='%(class)s_shared', to=settings.AUTH_USER_MODEL, verbose_name='Shared with users'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='account',
|
||||
name='visibility',
|
||||
field=models.CharField(choices=[('private', 'Private'), ('public', 'Public')], default='private', max_length=10, verbose_name='Visibility'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='accountgroup',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='%(class)s_owned', to=settings.AUTH_USER_MODEL, verbose_name='Owner'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='accountgroup',
|
||||
name='shared_with',
|
||||
field=models.ManyToManyField(blank=True, related_name='%(class)s_shared', to=settings.AUTH_USER_MODEL, verbose_name='Shared with users'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='accountgroup',
|
||||
name='visibility',
|
||||
field=models.CharField(choices=[('private', 'Private'), ('public', 'Public')], default='private', max_length=10, verbose_name='Visibility'),
|
||||
),
|
||||
]
|
||||
20
app/apps/accounts/migrations/0016_account_untracked_by.py
Normal file
20
app/apps/accounts/migrations/0016_account_untracked_by.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Generated by Django 5.2.4 on 2025-08-09 05:52
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('accounts', '0015_alter_account_owner_alter_account_shared_with_and_more'),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='account',
|
||||
name='untracked_by',
|
||||
field=models.ManyToManyField(blank=True, related_name='untracked_accounts', to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
]
|
||||
@@ -1,11 +1,11 @@
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.transactions.models import Transaction
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
from apps.common.models import SharedObject, SharedObjectManager
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
class AccountGroup(SharedObject):
|
||||
@@ -62,6 +62,11 @@ class Account(SharedObject):
|
||||
verbose_name=_("Archived"),
|
||||
help_text=_("Archived accounts don't show up nor count towards your net worth"),
|
||||
)
|
||||
untracked_by = models.ManyToManyField(
|
||||
settings.AUTH_USER_MODEL,
|
||||
blank=True,
|
||||
related_name="untracked_accounts",
|
||||
)
|
||||
|
||||
objects = SharedObjectManager()
|
||||
all_objects = models.Manager() # Unfiltered manager
|
||||
@@ -75,6 +80,10 @@ class Account(SharedObject):
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def is_untracked_by(self):
|
||||
user = get_current_user()
|
||||
return self.untracked_by.filter(pk=user.pk).exists()
|
||||
|
||||
def clean(self):
|
||||
super().clean()
|
||||
if self.exchange_currency == self.currency:
|
||||
|
||||
33
app/apps/accounts/services.py
Normal file
33
app/apps/accounts/services.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
def get_account_balance(account: Account, paid_only: bool = True) -> Decimal:
|
||||
"""
|
||||
Calculate account balance (income - expense).
|
||||
|
||||
Args:
|
||||
account: Account instance to calculate balance for.
|
||||
paid_only: If True, only count paid transactions (current balance).
|
||||
If False, count all transactions (projected balance).
|
||||
|
||||
Returns:
|
||||
Decimal: The calculated balance (income - expense).
|
||||
"""
|
||||
filters = {"account": account}
|
||||
if paid_only:
|
||||
filters["is_paid"] = True
|
||||
|
||||
income = Transaction.objects.filter(
|
||||
type=Transaction.Type.INCOME, **filters
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
expense = Transaction.objects.filter(
|
||||
type=Transaction.Type.EXPENSE, **filters
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
return income - expense
|
||||
@@ -1,33 +1,21 @@
|
||||
from django.test import TestCase, Client
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.contrib.auth.models import User
|
||||
from django.db import IntegrityError, models
|
||||
from django.utils import timezone
|
||||
from django.urls import reverse
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.accounts.forms import AccountForm
|
||||
from apps.transactions.models import Transaction, TransactionCategory
|
||||
|
||||
|
||||
class AccountTests(TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
self.owner1 = User.objects.create_user(username='testowner', password='password123')
|
||||
self.client = Client()
|
||||
self.client.login(username='testowner', password='password123')
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.eur = Currency.objects.create(
|
||||
self.exchange_currency = Currency.objects.create(
|
||||
code="EUR", name="Euro", decimal_places=2, prefix="€ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group", owner=self.owner1)
|
||||
self.reconciliation_category = TransactionCategory.objects.create(name='Reconciliation', owner=self.owner1, type='INFO')
|
||||
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
|
||||
def test_account_creation(self):
|
||||
"""Test basic account creation"""
|
||||
@@ -49,262 +37,139 @@ class AccountTests(TestCase):
|
||||
"""Test account creation with exchange currency"""
|
||||
account = Account.objects.create(
|
||||
name="Exchange Account",
|
||||
owner=self.owner1, # Added owner
|
||||
group=self.account_group, # Added group
|
||||
currency=self.currency,
|
||||
exchange_currency=self.eur, # Changed to self.eur
|
||||
exchange_currency=self.exchange_currency,
|
||||
)
|
||||
self.assertEqual(account.exchange_currency, self.eur) # Changed to self.eur
|
||||
|
||||
def test_account_archiving(self):
|
||||
"""Test archiving and unarchiving an account"""
|
||||
account = Account.objects.create(
|
||||
name="Archivable Account",
|
||||
owner=self.owner1, # Added owner
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
is_asset=True, # Assuming default, can be anything for this test
|
||||
is_archived=False,
|
||||
)
|
||||
self.assertFalse(account.is_archived, "Account should initially be unarchived")
|
||||
|
||||
# Archive the account
|
||||
account.is_archived = True
|
||||
account.save()
|
||||
|
||||
archived_account = Account.objects.get(pk=account.pk)
|
||||
self.assertTrue(archived_account.is_archived, "Account should be archived")
|
||||
|
||||
# Unarchive the account
|
||||
archived_account.is_archived = False
|
||||
archived_account.save()
|
||||
|
||||
unarchived_account = Account.objects.get(pk=account.pk)
|
||||
self.assertFalse(unarchived_account.is_archived, "Account should be unarchived")
|
||||
|
||||
def test_account_exchange_currency_cannot_be_same_as_currency(self):
|
||||
"""Test that exchange_currency cannot be the same as currency."""
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
account = Account(
|
||||
name="Same Currency Account",
|
||||
owner=self.owner1, # Added owner
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
exchange_currency=self.currency, # Same as currency
|
||||
)
|
||||
account.full_clean()
|
||||
self.assertIn('exchange_currency', cm.exception.error_dict)
|
||||
# To check for a specific message (optional, might make test brittle):
|
||||
# self.assertTrue(any("cannot be the same as the main currency" in e.message
|
||||
# for e in cm.exception.error_dict['exchange_currency']))
|
||||
|
||||
def test_account_name_unique_per_owner(self):
|
||||
"""Test that account name is unique per owner."""
|
||||
owner1 = User.objects.create_user(username='owner1', password='password123')
|
||||
owner2 = User.objects.create_user(username='owner2', password='password123')
|
||||
|
||||
# Initial account for self.owner1 (owner1 from setUp)
|
||||
Account.objects.create(
|
||||
name="Unique Name Test",
|
||||
owner=self.owner1, # Changed to self.owner1
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
)
|
||||
|
||||
# Attempt to create another account with the same name and self.owner1 - should fail
|
||||
with self.assertRaises(IntegrityError):
|
||||
Account.objects.create(
|
||||
name="Unique Name Test",
|
||||
owner=self.owner1, # Changed to self.owner1
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
)
|
||||
|
||||
# Create account with the same name but for owner2 - should succeed
|
||||
try:
|
||||
Account.objects.create(
|
||||
name="Unique Name Test",
|
||||
owner=owner2, # owner2 is locally defined here, that's fine for this test
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
)
|
||||
except IntegrityError:
|
||||
self.fail("Creating account with same name but different owner failed unexpectedly.")
|
||||
|
||||
# Create account with a different name for self.owner1 - should succeed
|
||||
try:
|
||||
Account.objects.create(
|
||||
name="Another Name Test",
|
||||
owner=self.owner1, # Changed to self.owner1
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
)
|
||||
except IntegrityError:
|
||||
self.fail("Creating account with different name for the same owner failed unexpectedly.")
|
||||
|
||||
def test_account_form_valid_data(self):
|
||||
"""Test AccountForm with valid data."""
|
||||
form_data = {
|
||||
'name': 'Form Test Account',
|
||||
'group': self.account_group.pk,
|
||||
'currency': self.currency.pk,
|
||||
'exchange_currency': self.eur.pk,
|
||||
'is_asset': True,
|
||||
'is_archived': False,
|
||||
'description': 'A valid test account from form.'
|
||||
}
|
||||
form = AccountForm(data=form_data)
|
||||
self.assertTrue(form.is_valid(), form.errors.as_text())
|
||||
|
||||
account = form.save(commit=False)
|
||||
account.owner = self.owner1
|
||||
account.save()
|
||||
|
||||
self.assertEqual(account.name, 'Form Test Account')
|
||||
self.assertEqual(account.owner, self.owner1)
|
||||
self.assertEqual(account.group, self.account_group)
|
||||
self.assertEqual(account.currency, self.currency)
|
||||
self.assertEqual(account.exchange_currency, self.eur)
|
||||
self.assertTrue(account.is_asset)
|
||||
self.assertFalse(account.is_archived)
|
||||
|
||||
def test_account_form_missing_name(self):
|
||||
"""Test AccountForm with missing name."""
|
||||
form_data = {
|
||||
'group': self.account_group.pk,
|
||||
'currency': self.currency.pk,
|
||||
}
|
||||
form = AccountForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('name', form.errors)
|
||||
|
||||
def test_account_form_exchange_currency_same_as_currency(self):
|
||||
"""Test AccountForm where exchange_currency is the same as currency."""
|
||||
form_data = {
|
||||
'name': 'Same Currency Form Account',
|
||||
'group': self.account_group.pk,
|
||||
'currency': self.currency.pk,
|
||||
'exchange_currency': self.currency.pk, # Same as currency
|
||||
}
|
||||
form = AccountForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('exchange_currency', form.errors)
|
||||
self.assertEqual(account.exchange_currency, self.exchange_currency)
|
||||
|
||||
|
||||
class AccountGroupTests(TestCase):
|
||||
class GetAccountBalanceServiceTests(TestCase):
|
||||
"""Tests for the get_account_balance service function"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data for AccountGroup tests."""
|
||||
self.owner1 = User.objects.create_user(username='groupowner1', password='password123')
|
||||
self.owner2 = User.objects.create_user(username='groupowner2', password='password123')
|
||||
|
||||
def test_account_group_creation(self):
|
||||
"""Test basic AccountGroup creation."""
|
||||
group = AccountGroup.objects.create(name="Test Group", owner=self.owner1)
|
||||
self.assertEqual(group.name, "Test Group")
|
||||
self.assertEqual(group.owner, self.owner1)
|
||||
self.assertEqual(str(group), "Test Group") # Assuming __str__ returns the name
|
||||
|
||||
def test_account_group_name_unique_per_owner(self):
|
||||
"""Test that AccountGroup name is unique per owner."""
|
||||
# Initial group for owner1
|
||||
AccountGroup.objects.create(name="Unique Group Name", owner=self.owner1)
|
||||
|
||||
# Attempt to create another group with the same name and owner1 - should fail
|
||||
with self.assertRaises(IntegrityError):
|
||||
AccountGroup.objects.create(name="Unique Group Name", owner=self.owner1)
|
||||
|
||||
# Create group with the same name but for owner2 - should succeed
|
||||
try:
|
||||
AccountGroup.objects.create(name="Unique Group Name", owner=self.owner2)
|
||||
except IntegrityError:
|
||||
self.fail("Creating group with same name but different owner failed unexpectedly.")
|
||||
|
||||
# Create group with a different name for owner1 - should succeed
|
||||
try:
|
||||
AccountGroup.objects.create(name="Another Group Name", owner=self.owner1)
|
||||
except IntegrityError:
|
||||
self.fail("Creating group with different name for the same owner failed unexpectedly.")
|
||||
|
||||
def test_account_reconciliation_creates_transaction(self):
|
||||
"""Test that account_reconciliation view creates a transaction for the difference."""
|
||||
|
||||
# Helper function to get balance
|
||||
def get_balance(account):
|
||||
balance = account.transactions.filter(is_paid=True).aggregate(
|
||||
total_income=models.Sum('amount', filter=models.Q(type=Transaction.Type.INCOME)),
|
||||
total_expense=models.Sum('amount', filter=models.Q(type=Transaction.Type.EXPENSE)),
|
||||
total_transfer_in=models.Sum('amount', filter=models.Q(type=Transaction.Type.TRANSFER, transfer_to_account=account)),
|
||||
total_transfer_out=models.Sum('amount', filter=models.Q(type=Transaction.Type.TRANSFER, account=account))
|
||||
)['total_income'] or Decimal('0.00')
|
||||
balance -= account.transactions.filter(is_paid=True).aggregate(
|
||||
total_expense=models.Sum('amount', filter=models.Q(type=Transaction.Type.EXPENSE))
|
||||
)['total_expense'] or Decimal('0.00')
|
||||
# For transfers, a more complete logic might be needed if transfers are involved in reconciliation scope
|
||||
return balance
|
||||
|
||||
account_usd = Account.objects.create(
|
||||
name="USD Account for Recon",
|
||||
owner=self.owner1,
|
||||
currency=self.currency,
|
||||
group=self.account_group
|
||||
"""Set up test data"""
|
||||
from apps.transactions.models import Transaction
|
||||
self.Transaction = Transaction
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="BRL", name="Brazilian Real", decimal_places=2, prefix="R$ "
|
||||
)
|
||||
account_eur = Account.objects.create(
|
||||
name="EUR Account for Recon",
|
||||
owner=self.owner1,
|
||||
currency=self.eur,
|
||||
group=self.account_group
|
||||
self.account_group = AccountGroup.objects.create(name="Service Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Service Test Account", group=self.account_group, currency=self.currency
|
||||
)
|
||||
|
||||
# Initial transactions
|
||||
Transaction.objects.create(account=account_usd, type=Transaction.Type.INCOME, amount=Decimal('100.00'), date=timezone.localdate(timezone.now()), description='Initial USD', category=self.reconciliation_category, owner=self.owner1, is_paid=True)
|
||||
Transaction.objects.create(account=account_eur, type=Transaction.Type.INCOME, amount=Decimal('200.00'), date=timezone.localdate(timezone.now()), description='Initial EUR', category=self.reconciliation_category, owner=self.owner1, is_paid=True)
|
||||
Transaction.objects.create(account=account_eur, type=Transaction.Type.EXPENSE, amount=Decimal('50.00'), date=timezone.localdate(timezone.now()), description='EUR Expense', category=self.reconciliation_category, owner=self.owner1, is_paid=True)
|
||||
def test_balance_with_no_transactions(self):
|
||||
"""Test balance is 0 when no transactions exist"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
balance = get_account_balance(self.account, paid_only=True)
|
||||
self.assertEqual(balance, Decimal("0"))
|
||||
|
||||
initial_usd_balance = get_balance(account_usd) # Should be 100.00
|
||||
initial_eur_balance = get_balance(account_eur) # Should be 150.00
|
||||
self.assertEqual(initial_usd_balance, Decimal('100.00'))
|
||||
self.assertEqual(initial_eur_balance, Decimal('150.00'))
|
||||
|
||||
initial_transaction_count = Transaction.objects.filter(owner=self.owner1).count() # Should be 3
|
||||
|
||||
formset_data = {
|
||||
'form-TOTAL_FORMS': '2',
|
||||
'form-INITIAL_FORMS': '2', # Based on view logic, it builds initial data for all accounts
|
||||
'form-MAX_NUM_FORMS': '', # Can be empty or a number >= TOTAL_FORMS
|
||||
'form-0-account_id': account_usd.id,
|
||||
'form-0-new_balance': '120.00', # New balance for USD account (implies +20 adjustment)
|
||||
'form-0-category': self.reconciliation_category.id,
|
||||
'form-1-account_id': account_eur.id,
|
||||
'form-1-new_balance': '150.00', # Same as current balance for EUR account (no adjustment)
|
||||
'form-1-category': self.reconciliation_category.id,
|
||||
}
|
||||
|
||||
response = self.client.post(
|
||||
reverse('accounts:account_reconciliation'),
|
||||
data=formset_data,
|
||||
HTTP_HX_REQUEST='true' # Required if view uses @only_htmx
|
||||
def test_current_balance_only_counts_paid(self):
|
||||
"""Test current balance only counts paid transactions"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
# Paid income
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid income",
|
||||
)
|
||||
# Unpaid income (should not count)
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid income",
|
||||
)
|
||||
# Paid expense
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.EXPENSE,
|
||||
amount=Decimal("30.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid expense",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 204, response.content.decode()) # 204 No Content for successful HTMX POST
|
||||
balance = get_account_balance(self.account, paid_only=True)
|
||||
self.assertEqual(balance, Decimal("70.00")) # 100 - 30
|
||||
|
||||
# Check that only one new transaction was created
|
||||
self.assertEqual(Transaction.objects.filter(owner=self.owner1).count(), initial_transaction_count + 1)
|
||||
def test_projected_balance_counts_all(self):
|
||||
"""Test projected balance counts all transactions"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
# Paid income
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid income",
|
||||
)
|
||||
# Unpaid income
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid income",
|
||||
)
|
||||
# Paid expense
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.EXPENSE,
|
||||
amount=Decimal("30.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid expense",
|
||||
)
|
||||
# Unpaid expense
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.EXPENSE,
|
||||
amount=Decimal("20.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid expense",
|
||||
)
|
||||
|
||||
# Get the newly created transaction
|
||||
new_transaction = Transaction.objects.filter(
|
||||
account=account_usd,
|
||||
description="Balance reconciliation"
|
||||
).first()
|
||||
balance = get_account_balance(self.account, paid_only=False)
|
||||
self.assertEqual(balance, Decimal("100.00")) # (100 + 50) - (30 + 20)
|
||||
|
||||
self.assertIsNotNone(new_transaction)
|
||||
self.assertEqual(new_transaction.type, Transaction.Type.INCOME)
|
||||
self.assertEqual(new_transaction.amount, Decimal('20.00'))
|
||||
self.assertEqual(new_transaction.category, self.reconciliation_category)
|
||||
self.assertEqual(new_transaction.owner, self.owner1)
|
||||
self.assertTrue(new_transaction.is_paid)
|
||||
self.assertEqual(new_transaction.date, timezone.localdate(timezone.now()))
|
||||
def test_balance_defaults_to_paid_only(self):
|
||||
"""Test that paid_only defaults to True"""
|
||||
from apps.accounts.services import get_account_balance
|
||||
from decimal import Decimal
|
||||
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid",
|
||||
)
|
||||
self.Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=self.Transaction.Type.INCOME,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 1),
|
||||
description="Unpaid",
|
||||
)
|
||||
|
||||
balance = get_account_balance(self.account) # defaults to paid_only=True
|
||||
self.assertEqual(balance, Decimal("100.00"))
|
||||
|
||||
# Verify final balances
|
||||
self.assertEqual(get_balance(account_usd), Decimal('120.00'))
|
||||
self.assertEqual(get_balance(account_eur), Decimal('150.00'))
|
||||
|
||||
@@ -31,6 +31,11 @@ urlpatterns = [
|
||||
views.account_take_ownership,
|
||||
name="account_take_ownership",
|
||||
),
|
||||
path(
|
||||
"account/<int:pk>/toggle-untracked/",
|
||||
views.account_toggle_untracked,
|
||||
name="account_toggle_untracked",
|
||||
),
|
||||
path("account-groups/", views.account_groups_index, name="account_groups_index"),
|
||||
path("account-groups/list/", views.account_groups_list, name="account_groups_list"),
|
||||
path("account-groups/add/", views.account_group_add, name="account_group_add"),
|
||||
|
||||
@@ -25,7 +25,7 @@ def account_groups_index(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def account_groups_list(request):
|
||||
account_groups = AccountGroup.objects.all().order_by("id")
|
||||
account_groups = AccountGroup.objects.all().order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"account_groups/fragments/list.html",
|
||||
|
||||
@@ -25,7 +25,7 @@ def accounts_index(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def accounts_list(request):
|
||||
accounts = Account.objects.all().order_by("id")
|
||||
accounts = Account.objects.all().order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"accounts/fragments/list.html",
|
||||
@@ -155,6 +155,26 @@ def account_delete(request, pk):
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def account_toggle_untracked(request, pk):
|
||||
account = get_object_or_404(Account, id=pk)
|
||||
if account.is_untracked_by():
|
||||
account.untracked_by.remove(request.user)
|
||||
messages.success(request, _("Account is now tracked"))
|
||||
else:
|
||||
account.untracked_by.add(request.user)
|
||||
messages.success(request, _("Account is now untracked"))
|
||||
|
||||
return HttpResponse(
|
||||
status=204,
|
||||
headers={
|
||||
"HX-Trigger": "updated",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
|
||||
@@ -11,23 +11,13 @@ from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.accounts.forms import AccountBalanceFormSet
|
||||
from apps.accounts.models import Account, Transaction
|
||||
from apps.accounts.services import get_account_balance
|
||||
from apps.common.decorators.htmx import only_htmx
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
def account_reconciliation(request):
|
||||
def get_account_balance(account):
|
||||
income = Transaction.objects.filter(
|
||||
account=account, type=Transaction.Type.INCOME, is_paid=True
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
expense = Transaction.objects.filter(
|
||||
account=account, type=Transaction.Type.EXPENSE, is_paid=True
|
||||
).aggregate(total=models.Sum("amount"))["total"] or Decimal("0")
|
||||
|
||||
return income - expense
|
||||
|
||||
initial_data = [
|
||||
{
|
||||
"account_id": account.id,
|
||||
|
||||
@@ -10,15 +10,19 @@ from apps.transactions.models import (
|
||||
|
||||
@extend_schema_field(
|
||||
{
|
||||
"oneOf": [{"type": "string"}, {"type": "integer"}],
|
||||
"description": "TransactionCategory ID or name. If the name doesn't exist, a new one will be created",
|
||||
"oneOf": [{"type": "string"}, {"type": "integer"}, {"type": "null"}],
|
||||
"description": "TransactionCategory ID or name. If the name doesn't exist, a new one will be created. Can be null if no category is assigned.",
|
||||
}
|
||||
)
|
||||
class TransactionCategoryField(serializers.Field):
|
||||
def to_representation(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
return {"id": value.id, "name": value.name}
|
||||
|
||||
def to_internal_value(self, data):
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, int):
|
||||
try:
|
||||
return TransactionCategory.objects.get(pk=data)
|
||||
|
||||
@@ -2,3 +2,5 @@ from .transactions import *
|
||||
from .accounts import *
|
||||
from .currencies import *
|
||||
from .dca import *
|
||||
from .imports import *
|
||||
|
||||
|
||||
@@ -67,3 +67,12 @@ class AccountSerializer(serializers.ModelSerializer):
|
||||
setattr(instance, attr, value)
|
||||
instance.save()
|
||||
return instance
|
||||
|
||||
|
||||
class AccountBalanceSerializer(serializers.Serializer):
|
||||
"""Serializer for account balance response."""
|
||||
|
||||
current_balance = serializers.DecimalField(max_digits=20, decimal_places=10)
|
||||
projected_balance = serializers.DecimalField(max_digits=20, decimal_places=10)
|
||||
currency = CurrencySerializer()
|
||||
|
||||
|
||||
41
app/apps/api/serializers/imports.py
Normal file
41
app/apps/api/serializers/imports.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
|
||||
|
||||
class ImportProfileSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for listing import profiles."""
|
||||
|
||||
class Meta:
|
||||
model = ImportProfile
|
||||
fields = ["id", "name", "version", "yaml_config"]
|
||||
|
||||
|
||||
class ImportRunSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for listing import runs."""
|
||||
|
||||
class Meta:
|
||||
model = ImportRun
|
||||
fields = [
|
||||
"id",
|
||||
"status",
|
||||
"profile",
|
||||
"file_name",
|
||||
"logs",
|
||||
"processed_rows",
|
||||
"total_rows",
|
||||
"successful_rows",
|
||||
"skipped_rows",
|
||||
"failed_rows",
|
||||
"started_at",
|
||||
"finished_at",
|
||||
]
|
||||
|
||||
|
||||
class ImportFileSerializer(serializers.Serializer):
|
||||
"""Serializer for uploading a file to import using an existing profile."""
|
||||
|
||||
profile_id = serializers.PrimaryKeyRelatedField(
|
||||
queryset=ImportProfile.objects.all(), source="profile"
|
||||
)
|
||||
file = serializers.FileField()
|
||||
@@ -138,6 +138,7 @@ class RecurringTransactionSerializer(serializers.ModelSerializer):
|
||||
def update(self, instance, validated_data):
|
||||
instance = super().update(instance, validated_data)
|
||||
instance.update_unpaid_transactions()
|
||||
instance.generate_upcoming_transactions()
|
||||
return instance
|
||||
|
||||
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth.models import User
|
||||
from rest_framework.test import APIClient
|
||||
from django.urls import reverse
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup # Added AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import TransactionCategory, Transaction
|
||||
from apps.rules.signals import transaction_created # Assuming this is the correct path
|
||||
|
||||
# Default page size for pagination, adjust if your project's default is different
|
||||
DEFAULT_PAGE_SIZE = 10
|
||||
|
||||
class APITestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(username='testuser', email='test@example.com', password='testpassword')
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.currency = Currency.objects.create(code="USD", name="US Dollar Test API", decimal_places=2)
|
||||
# Account model requires an AccountGroup
|
||||
self.account_group = AccountGroup.objects.create(name="API Test Group", owner=self.user)
|
||||
self.account = Account.objects.create(
|
||||
name="Test API Account",
|
||||
currency=self.currency,
|
||||
owner=self.user,
|
||||
group=self.account_group
|
||||
)
|
||||
self.category = TransactionCategory.objects.create(
|
||||
name="Test API Category",
|
||||
owner=self.user,
|
||||
type=TransactionCategory.TransactionType.EXPENSE # Default type, can be adjusted
|
||||
)
|
||||
# Remove the example test if it's no longer needed or update it
|
||||
# self.assertEqual(1 + 1, 2) # from test_example
|
||||
|
||||
def test_transactions_endpoint_authenticated_user(self):
|
||||
# User and client are now set up in self.setUp
|
||||
url = reverse('api:transaction-list') # Using 'api:' namespace
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@patch('apps.rules.signals.transaction_created.send')
|
||||
def test_create_transaction_api_success(self, mock_signal_send):
|
||||
url = reverse('api:transaction-list')
|
||||
data = {
|
||||
'account': self.account.pk, # Changed from account_id to account to match typical DRF serializer field names
|
||||
'type': Transaction.Type.EXPENSE.value, # Use enum value
|
||||
'date': date(2023, 1, 15).isoformat(),
|
||||
'amount': '123.45',
|
||||
'description': 'API Test Expense',
|
||||
'category': self.category.pk,
|
||||
'tags': [],
|
||||
'entities': []
|
||||
}
|
||||
|
||||
initial_transaction_count = Transaction.objects.count()
|
||||
response = self.client.post(url, data, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, 201, response.data) # Print response.data on failure
|
||||
self.assertEqual(Transaction.objects.count(), initial_transaction_count + 1)
|
||||
|
||||
created_transaction = Transaction.objects.latest('id') # Get the latest transaction
|
||||
|
||||
self.assertEqual(created_transaction.description, 'API Test Expense')
|
||||
self.assertEqual(created_transaction.amount, Decimal('123.45'))
|
||||
self.assertEqual(created_transaction.owner, self.user)
|
||||
self.assertEqual(created_transaction.account, self.account)
|
||||
self.assertEqual(created_transaction.category, self.category)
|
||||
|
||||
mock_signal_send.assert_called_once()
|
||||
# Check sender argument of the signal call
|
||||
self.assertEqual(mock_signal_send.call_args.kwargs['sender'], Transaction)
|
||||
self.assertEqual(mock_signal_send.call_args.kwargs['instance'], created_transaction)
|
||||
|
||||
|
||||
def test_create_transaction_api_invalid_data(self):
|
||||
url = reverse('api:transaction-list')
|
||||
data = {
|
||||
'account': self.account.pk,
|
||||
'type': 'INVALID_TYPE', # Invalid type
|
||||
'date': date(2023, 1, 15).isoformat(),
|
||||
'amount': 'not_a_number', # Invalid amount
|
||||
'description': 'API Test Invalid Data',
|
||||
'category': self.category.pk
|
||||
}
|
||||
response = self.client.post(url, data, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertIn('type', response.data)
|
||||
self.assertIn('amount', response.data)
|
||||
|
||||
def test_transaction_list_pagination(self):
|
||||
# Create more transactions than page size (e.g., DEFAULT_PAGE_SIZE + 5)
|
||||
num_to_create = DEFAULT_PAGE_SIZE + 5
|
||||
for i in range(num_to_create):
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=date(2023, 1, 1) + timedelta(days=i),
|
||||
amount=Decimal(f"{10 + i}.00"),
|
||||
description=f"Pag Test Transaction {i+1}",
|
||||
owner=self.user,
|
||||
category=self.category
|
||||
)
|
||||
|
||||
url = reverse('api:transaction-list')
|
||||
response = self.client.get(url)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn('count', response.data)
|
||||
self.assertEqual(response.data['count'], num_to_create)
|
||||
|
||||
self.assertIn('next', response.data)
|
||||
self.assertIsNotNone(response.data['next']) # Assuming count > page size
|
||||
|
||||
self.assertIn('previous', response.data) # Will be None for the first page
|
||||
# self.assertIsNone(response.data['previous']) # For the first page
|
||||
|
||||
self.assertIn('results', response.data)
|
||||
self.assertEqual(len(response.data['results']), DEFAULT_PAGE_SIZE)
|
||||
5
app/apps/api/tests/__init__.py
Normal file
5
app/apps/api/tests/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Import all test classes for Django test discovery
|
||||
from .test_imports import *
|
||||
from .test_accounts import *
|
||||
from .test_data_isolation import *
|
||||
from .test_shared_access import *
|
||||
99
app/apps/api/tests/test_accounts.py
Normal file
99
app/apps/api/tests/test_accounts.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class AccountBalanceAPITests(TestCase):
|
||||
"""Tests for the Account Balance API endpoint"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Test Account", group=self.account_group, currency=self.currency
|
||||
)
|
||||
|
||||
# Create some transactions
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("500.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Paid income",
|
||||
)
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("200.00"),
|
||||
is_paid=False,
|
||||
date=date(2025, 1, 15),
|
||||
description="Unpaid income",
|
||||
)
|
||||
Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 10),
|
||||
description="Paid expense",
|
||||
)
|
||||
|
||||
def test_get_balance_success(self):
|
||||
"""Test successful balance retrieval"""
|
||||
response = self.client.get(f"/api/accounts/{self.account.id}/balance/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertIn("current_balance", response.data)
|
||||
self.assertIn("projected_balance", response.data)
|
||||
self.assertIn("currency", response.data)
|
||||
|
||||
# Current: 500 - 100 = 400
|
||||
self.assertEqual(Decimal(response.data["current_balance"]), Decimal("400.00"))
|
||||
# Projected: (500 + 200) - 100 = 600
|
||||
self.assertEqual(Decimal(response.data["projected_balance"]), Decimal("600.00"))
|
||||
|
||||
# Check currency data
|
||||
self.assertEqual(response.data["currency"]["code"], "USD")
|
||||
|
||||
def test_get_balance_nonexistent_account(self):
|
||||
"""Test balance for non-existent account returns 404"""
|
||||
response = self.client.get("/api/accounts/99999/balance/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_get_balance_unauthenticated(self):
|
||||
"""Test unauthenticated request returns 403"""
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get(
|
||||
f"/api/accounts/{self.account.id}/balance/"
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
719
app/apps/api/tests/test_data_isolation.py
Normal file
719
app/apps/api/tests/test_data_isolation.py
Normal file
@@ -0,0 +1,719 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
TransactionEntity,
|
||||
InstallmentPlan,
|
||||
RecurringTransaction,
|
||||
)
|
||||
|
||||
|
||||
ACCESS_DENIED_CODES = [status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND]
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class AccountDataIsolationTests(TestCase):
|
||||
"""Tests to ensure users cannot access other users' accounts."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with two distinct users."""
|
||||
User = get_user_model()
|
||||
|
||||
# User 1 - the requester
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
# User 2 - owner of data that user1 should NOT access
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
# Shared currency
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's account
|
||||
self.user1_account_group = AccountGroup.all_objects.create(
|
||||
name="User1 Group", owner=self.user1
|
||||
)
|
||||
self.user1_account = Account.all_objects.create(
|
||||
name="User1 Account",
|
||||
group=self.user1_account_group,
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
# User 2's account (private, should be invisible to user1)
|
||||
self.user2_account_group = AccountGroup.all_objects.create(
|
||||
name="User2 Group", owner=self.user2
|
||||
)
|
||||
self.user2_account = Account.all_objects.create(
|
||||
name="User2 Account",
|
||||
group=self.user2_account_group,
|
||||
currency=self.currency,
|
||||
owner=self.user2,
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_accounts_in_list(self):
|
||||
"""GET /api/accounts/ should only return user's own accounts."""
|
||||
response = self.client1.get("/api/accounts/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
# User1 should only see their own account
|
||||
account_ids = [acc["id"] for acc in response.data["results"]]
|
||||
self.assertIn(self.user1_account.id, account_ids)
|
||||
self.assertNotIn(self.user2_account.id, account_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_account_detail(self):
|
||||
"""GET /api/accounts/{id}/ should deny access to other user's account."""
|
||||
response = self.client1.get(f"/api/accounts/{self.user2_account.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_account(self):
|
||||
"""PATCH on other user's account should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/accounts/{self.user2_account.id}/",
|
||||
{"name": "Hacked Account"},
|
||||
)
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
# Verify account name wasn't changed
|
||||
self.user2_account.refresh_from_db()
|
||||
self.assertEqual(self.user2_account.name, "User2 Account")
|
||||
|
||||
def test_user_cannot_delete_other_users_account(self):
|
||||
"""DELETE on other user's account should deny access."""
|
||||
response = self.client1.delete(f"/api/accounts/{self.user2_account.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
# Verify account still exists
|
||||
self.assertTrue(Account.all_objects.filter(id=self.user2_account.id).exists())
|
||||
|
||||
def test_user_cannot_get_balance_of_other_users_account(self):
|
||||
"""Balance action on other user's account should deny access."""
|
||||
response = self.client1.get(f"/api/accounts/{self.user2_account.id}/balance/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_can_access_own_account(self):
|
||||
"""User can access their own account normally."""
|
||||
response = self.client1.get(f"/api/accounts/{self.user1_account.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "User1 Account")
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class AccountGroupDataIsolationTests(TestCase):
|
||||
"""Tests to ensure users cannot access other users' account groups."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with two distinct users."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
# User 1's account group
|
||||
self.user1_group = AccountGroup.all_objects.create(
|
||||
name="User1 Group", owner=self.user1
|
||||
)
|
||||
|
||||
# User 2's account group
|
||||
self.user2_group = AccountGroup.all_objects.create(
|
||||
name="User2 Group", owner=self.user2
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_account_groups(self):
|
||||
"""GET /api/account-groups/ should only return user's own groups."""
|
||||
response = self.client1.get("/api/account-groups/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
group_ids = [grp["id"] for grp in response.data["results"]]
|
||||
self.assertIn(self.user1_group.id, group_ids)
|
||||
self.assertNotIn(self.user2_group.id, group_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_account_group_detail(self):
|
||||
"""GET /api/account-groups/{id}/ should deny access to other user's group."""
|
||||
response = self.client1.get(f"/api/account-groups/{self.user2_group.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_account_group(self):
|
||||
"""PATCH on other user's account group should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/account-groups/{self.user2_group.id}/",
|
||||
{"name": "Hacked Group"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.user2_group.refresh_from_db()
|
||||
self.assertEqual(self.user2_group.name, "User2 Group")
|
||||
|
||||
def test_user_cannot_delete_other_users_account_group(self):
|
||||
"""DELETE on other user's account group should deny access."""
|
||||
response = self.client1.delete(f"/api/account-groups/{self.user2_group.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(
|
||||
AccountGroup.all_objects.filter(id=self.user2_group.id).exists()
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class TransactionDataIsolationTests(TestCase):
|
||||
"""Tests to ensure users cannot access other users' transactions."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with transactions for two distinct users."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's account and transaction
|
||||
self.user1_account = Account.all_objects.create(
|
||||
name="User1 Account", currency=self.currency, owner=self.user1
|
||||
)
|
||||
self.user1_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.user1_account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="User1 Income",
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
# User 2's account and transaction
|
||||
self.user2_account = Account.all_objects.create(
|
||||
name="User2 Account", currency=self.currency, owner=self.user2
|
||||
)
|
||||
self.user2_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.user2_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="User2 Expense",
|
||||
owner=self.user2,
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_transactions_in_list(self):
|
||||
"""GET /api/transactions/ should only return user's own transactions."""
|
||||
response = self.client1.get("/api/transactions/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
transaction_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.user1_transaction.id, transaction_ids)
|
||||
self.assertNotIn(self.user2_transaction.id, transaction_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_transaction_detail(self):
|
||||
"""GET /api/transactions/{id}/ should deny access to other user's transaction."""
|
||||
response = self.client1.get(f"/api/transactions/{self.user2_transaction.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_transaction(self):
|
||||
"""PATCH on other user's transaction should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/transactions/{self.user2_transaction.id}/",
|
||||
{"description": "Hacked Transaction"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.user2_transaction.refresh_from_db()
|
||||
self.assertEqual(self.user2_transaction.description, "User2 Expense")
|
||||
|
||||
def test_user_cannot_delete_other_users_transaction(self):
|
||||
"""DELETE on other user's transaction should deny access."""
|
||||
response = self.client1.delete(
|
||||
f"/api/transactions/{self.user2_transaction.id}/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(
|
||||
Transaction.userless_all_objects.filter(
|
||||
id=self.user2_transaction.id
|
||||
).exists()
|
||||
)
|
||||
|
||||
def test_user_cannot_create_transaction_in_other_users_account(self):
|
||||
"""POST /api/transactions/ with other user's account should fail."""
|
||||
response = self.client1.post(
|
||||
"/api/transactions/",
|
||||
{
|
||||
"account": self.user2_account.id,
|
||||
"type": "IN",
|
||||
"amount": "100.00",
|
||||
"date": "2025-01-15",
|
||||
"description": "Sneaky transaction",
|
||||
},
|
||||
format="json",
|
||||
)
|
||||
|
||||
# Should deny access - 400 (validation error), 403, or 404
|
||||
self.assertIn(
|
||||
response.status_code,
|
||||
ACCESS_DENIED_CODES + [status.HTTP_400_BAD_REQUEST],
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class CategoryTagEntityIsolationTests(TestCase):
|
||||
"""Tests for isolation of categories, tags, and entities between users."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
# User 1's categories, tags, entities
|
||||
self.user1_category = TransactionCategory.all_objects.create(
|
||||
name="User1 Category", owner=self.user1
|
||||
)
|
||||
self.user1_tag = TransactionTag.all_objects.create(
|
||||
name="User1 Tag", owner=self.user1
|
||||
)
|
||||
self.user1_entity = TransactionEntity.all_objects.create(
|
||||
name="User1 Entity", owner=self.user1
|
||||
)
|
||||
|
||||
# User 2's categories, tags, entities
|
||||
self.user2_category = TransactionCategory.all_objects.create(
|
||||
name="User2 Category", owner=self.user2
|
||||
)
|
||||
self.user2_tag = TransactionTag.all_objects.create(
|
||||
name="User2 Tag", owner=self.user2
|
||||
)
|
||||
self.user2_entity = TransactionEntity.all_objects.create(
|
||||
name="User2 Entity", owner=self.user2
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_categories(self):
|
||||
"""GET /api/categories/ should only return user's own categories."""
|
||||
response = self.client1.get("/api/categories/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
category_ids = [c["id"] for c in response.data["results"]]
|
||||
self.assertIn(self.user1_category.id, category_ids)
|
||||
self.assertNotIn(self.user2_category.id, category_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_category_detail(self):
|
||||
"""GET /api/categories/{id}/ should deny access to other user's category."""
|
||||
response = self.client1.get(f"/api/categories/{self.user2_category.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_see_other_users_tags(self):
|
||||
"""GET /api/tags/ should only return user's own tags."""
|
||||
response = self.client1.get("/api/tags/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
tag_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.user1_tag.id, tag_ids)
|
||||
self.assertNotIn(self.user2_tag.id, tag_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_tag_detail(self):
|
||||
"""GET /api/tags/{id}/ should deny access to other user's tag."""
|
||||
response = self.client1.get(f"/api/tags/{self.user2_tag.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_see_other_users_entities(self):
|
||||
"""GET /api/entities/ should only return user's own entities."""
|
||||
response = self.client1.get("/api/entities/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
entity_ids = [e["id"] for e in response.data["results"]]
|
||||
self.assertIn(self.user1_entity.id, entity_ids)
|
||||
self.assertNotIn(self.user2_entity.id, entity_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_entity_detail(self):
|
||||
"""GET /api/entities/{id}/ should deny access to other user's entity."""
|
||||
response = self.client1.get(f"/api/entities/{self.user2_entity.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_category(self):
|
||||
"""PATCH on other user's category should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/categories/{self.user2_category.id}/",
|
||||
{"name": "Hacked Category"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_delete_other_users_tag(self):
|
||||
"""DELETE on other user's tag should deny access."""
|
||||
response = self.client1.delete(f"/api/tags/{self.user2_tag.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(
|
||||
TransactionTag.all_objects.filter(id=self.user2_tag.id).exists()
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class DCADataIsolationTests(TestCase):
|
||||
"""Tests to ensure users cannot access other users' DCA strategies and entries."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
self.currency1 = Currency.objects.create(
|
||||
code="BTC", name="Bitcoin", decimal_places=8, prefix=""
|
||||
)
|
||||
self.currency2 = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's DCA strategy and entry
|
||||
self.user1_strategy = DCAStrategy.all_objects.create(
|
||||
name="User1 BTC Strategy",
|
||||
target_currency=self.currency1,
|
||||
payment_currency=self.currency2,
|
||||
owner=self.user1,
|
||||
)
|
||||
self.user1_entry = DCAEntry.objects.create(
|
||||
strategy=self.user1_strategy,
|
||||
date=date(2025, 1, 1),
|
||||
amount_paid=Decimal("100.00"),
|
||||
amount_received=Decimal("0.001"),
|
||||
)
|
||||
|
||||
# User 2's DCA strategy and entry
|
||||
self.user2_strategy = DCAStrategy.all_objects.create(
|
||||
name="User2 BTC Strategy",
|
||||
target_currency=self.currency1,
|
||||
payment_currency=self.currency2,
|
||||
owner=self.user2,
|
||||
)
|
||||
self.user2_entry = DCAEntry.objects.create(
|
||||
strategy=self.user2_strategy,
|
||||
date=date(2025, 1, 1),
|
||||
amount_paid=Decimal("200.00"),
|
||||
amount_received=Decimal("0.002"),
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_dca_strategies(self):
|
||||
"""GET /api/dca/strategies/ should only return user's own strategies."""
|
||||
response = self.client1.get("/api/dca/strategies/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
strategy_ids = [s["id"] for s in response.data["results"]]
|
||||
self.assertIn(self.user1_strategy.id, strategy_ids)
|
||||
self.assertNotIn(self.user2_strategy.id, strategy_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_dca_strategy_detail(self):
|
||||
"""GET /api/dca/strategies/{id}/ should deny access to other user's strategy."""
|
||||
response = self.client1.get(f"/api/dca/strategies/{self.user2_strategy.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_access_other_users_dca_entries(self):
|
||||
"""GET /api/dca/entries/ filtered by other user's strategy should return empty."""
|
||||
response = self.client1.get(
|
||||
f"/api/dca/entries/?strategy={self.user2_strategy.id}"
|
||||
)
|
||||
|
||||
# Either OK with empty results or error
|
||||
if response.status_code == status.HTTP_200_OK:
|
||||
entry_ids = [e["id"] for e in response.data["results"]]
|
||||
self.assertNotIn(self.user2_entry.id, entry_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_dca_entry_detail(self):
|
||||
"""GET /api/dca/entries/{id}/ should deny access to other user's entry."""
|
||||
response = self.client1.get(f"/api/dca/entries/{self.user2_entry.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_access_other_users_strategy_investment_frequency(self):
|
||||
"""investment_frequency action on other user's strategy should deny access."""
|
||||
response = self.client1.get(
|
||||
f"/api/dca/strategies/{self.user2_strategy.id}/investment_frequency/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_access_other_users_strategy_price_comparison(self):
|
||||
"""price_comparison action on other user's strategy should deny access."""
|
||||
response = self.client1.get(
|
||||
f"/api/dca/strategies/{self.user2_strategy.id}/price_comparison/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_access_other_users_strategy_current_price(self):
|
||||
"""current_price action on other user's strategy should deny access."""
|
||||
response = self.client1.get(
|
||||
f"/api/dca/strategies/{self.user2_strategy.id}/current_price/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_dca_strategy(self):
|
||||
"""PATCH on other user's DCA strategy should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/dca/strategies/{self.user2_strategy.id}/",
|
||||
{"name": "Hacked Strategy"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_delete_other_users_dca_entry(self):
|
||||
"""DELETE on other user's DCA entry should deny access."""
|
||||
response = self.client1.delete(f"/api/dca/entries/{self.user2_entry.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(DCAEntry.objects.filter(id=self.user2_entry.id).exists())
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class InstallmentRecurringIsolationTests(TestCase):
|
||||
"""Tests for isolation of installment plans and recurring transactions."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's account
|
||||
self.user1_account = Account.all_objects.create(
|
||||
name="User1 Account", currency=self.currency, owner=self.user1
|
||||
)
|
||||
|
||||
# User 2's account
|
||||
self.user2_account = Account.all_objects.create(
|
||||
name="User2 Account", currency=self.currency, owner=self.user2
|
||||
)
|
||||
|
||||
# User 1's installment plan
|
||||
self.user1_installment = InstallmentPlan.all_objects.create(
|
||||
account=self.user1_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
description="User1 Installment",
|
||||
number_of_installments=12,
|
||||
start_date=date(2025, 1, 1),
|
||||
installment_amount=Decimal("100.00"),
|
||||
)
|
||||
|
||||
# User 2's installment plan
|
||||
self.user2_installment = InstallmentPlan.all_objects.create(
|
||||
account=self.user2_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
description="User2 Installment",
|
||||
number_of_installments=6,
|
||||
start_date=date(2025, 1, 1),
|
||||
installment_amount=Decimal("200.00"),
|
||||
)
|
||||
|
||||
# User 1's recurring transaction
|
||||
self.user1_recurring = RecurringTransaction.all_objects.create(
|
||||
account=self.user1_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("50.00"),
|
||||
description="User1 Recurring",
|
||||
start_date=date(2025, 1, 1),
|
||||
recurrence_type=RecurringTransaction.RecurrenceType.MONTH,
|
||||
recurrence_interval=1,
|
||||
)
|
||||
|
||||
# User 2's recurring transaction
|
||||
self.user2_recurring = RecurringTransaction.all_objects.create(
|
||||
account=self.user2_account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("1000.00"),
|
||||
description="User2 Recurring",
|
||||
start_date=date(2025, 1, 1),
|
||||
recurrence_type=RecurringTransaction.RecurrenceType.MONTH,
|
||||
recurrence_interval=1,
|
||||
)
|
||||
|
||||
def test_user_cannot_see_other_users_installment_plans(self):
|
||||
"""GET /api/installment-plans/ should only return user's own plans."""
|
||||
response = self.client1.get("/api/installment-plans/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
plan_ids = [p["id"] for p in response.data["results"]]
|
||||
self.assertIn(self.user1_installment.id, plan_ids)
|
||||
self.assertNotIn(self.user2_installment.id, plan_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_installment_plan_detail(self):
|
||||
"""GET /api/installment-plans/{id}/ should deny access to other user's plan."""
|
||||
response = self.client1.get(
|
||||
f"/api/installment-plans/{self.user2_installment.id}/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_see_other_users_recurring_transactions(self):
|
||||
"""GET /api/recurring-transactions/ should only return user's own recurring."""
|
||||
response = self.client1.get("/api/recurring-transactions/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
recurring_ids = [r["id"] for r in response.data["results"]]
|
||||
self.assertIn(self.user1_recurring.id, recurring_ids)
|
||||
self.assertNotIn(self.user2_recurring.id, recurring_ids)
|
||||
|
||||
def test_user_cannot_access_other_users_recurring_transaction_detail(self):
|
||||
"""GET /api/recurring-transactions/{id}/ should deny access to other user's recurring."""
|
||||
response = self.client1.get(
|
||||
f"/api/recurring-transactions/{self.user2_recurring.id}/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_modify_other_users_installment_plan(self):
|
||||
"""PATCH on other user's installment plan should deny access."""
|
||||
response = self.client1.patch(
|
||||
f"/api/installment-plans/{self.user2_installment.id}/",
|
||||
{"description": "Hacked Installment"},
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_cannot_delete_other_users_recurring_transaction(self):
|
||||
"""DELETE on other user's recurring transaction should deny access."""
|
||||
response = self.client1.delete(
|
||||
f"/api/recurring-transactions/{self.user2_recurring.id}/"
|
||||
)
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
self.assertTrue(
|
||||
RecurringTransaction.all_objects.filter(id=self.user2_recurring.id).exists()
|
||||
)
|
||||
404
app/apps/api/tests/test_imports.py
Normal file
404
app/apps/api/tests/test_imports.py
Normal file
@@ -0,0 +1,404 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class ImportAPITests(TestCase):
|
||||
"""Tests for the Import API endpoint"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
# Create a basic import profile with minimal valid YAML config
|
||||
self.profile = ImportProfile.objects.create(
|
||||
name="Test Profile",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_expense
|
||||
is_paid:
|
||||
detection_method: always_paid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
|
||||
@patch("apps.import_app.tasks.process_import.defer")
|
||||
@patch("django.core.files.storage.FileSystemStorage.save")
|
||||
@patch("django.core.files.storage.FileSystemStorage.path")
|
||||
def test_create_import_success(self, mock_path, mock_save, mock_defer):
|
||||
"""Test successful file upload creates ImportRun and queues task"""
|
||||
mock_save.return_value = "test_file.csv"
|
||||
mock_path.return_value = "/usr/src/app/temp/test_file.csv"
|
||||
|
||||
csv_content = b"date,description,amount,account\n2025-01-01,Test,100,Main"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||
self.assertIn("import_run_id", response.data)
|
||||
self.assertEqual(response.data["status"], "queued")
|
||||
|
||||
# Verify ImportRun was created
|
||||
import_run = ImportRun.objects.get(id=response.data["import_run_id"])
|
||||
self.assertEqual(import_run.profile, self.profile)
|
||||
self.assertEqual(import_run.file_name, "test_file.csv")
|
||||
|
||||
# Verify task was deferred
|
||||
mock_defer.assert_called_once_with(
|
||||
import_run_id=import_run.id,
|
||||
file_path="/usr/src/app/temp/test_file.csv",
|
||||
user_id=self.user.id,
|
||||
)
|
||||
|
||||
def test_create_import_missing_profile(self):
|
||||
"""Test request without profile_id returns 400"""
|
||||
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("profile_id", response.data)
|
||||
|
||||
def test_create_import_missing_file(self):
|
||||
"""Test request without file returns 400"""
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("file", response.data)
|
||||
|
||||
def test_create_import_invalid_profile(self):
|
||||
"""Test request with non-existent profile returns 400"""
|
||||
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": 99999, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertIn("profile_id", response.data)
|
||||
|
||||
@patch("apps.import_app.tasks.process_import.defer")
|
||||
@patch("django.core.files.storage.FileSystemStorage.save")
|
||||
@patch("django.core.files.storage.FileSystemStorage.path")
|
||||
def test_create_import_xlsx(self, mock_path, mock_save, mock_defer):
|
||||
"""Test successful XLSX file upload"""
|
||||
mock_save.return_value = "test_file.xlsx"
|
||||
mock_path.return_value = "/usr/src/app/temp/test_file.xlsx"
|
||||
|
||||
# Create a simple XLSX-like content (just for the upload test)
|
||||
xlsx_content = BytesIO(b"PK\x03\x04") # XLSX files start with PK header
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.xlsx",
|
||||
xlsx_content.getvalue(),
|
||||
content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
|
||||
response = self.client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
|
||||
self.assertIn("import_run_id", response.data)
|
||||
|
||||
def test_unauthenticated_request(self):
|
||||
"""Test unauthenticated request returns 403"""
|
||||
unauthenticated_client = APIClient()
|
||||
|
||||
csv_content = b"date,description,amount\n2025-01-01,Test,100"
|
||||
file = SimpleUploadedFile(
|
||||
"test_file.csv", csv_content, content_type="text/csv"
|
||||
)
|
||||
|
||||
response = unauthenticated_client.post(
|
||||
"/api/import/import/",
|
||||
{"profile_id": self.profile.id, "file": file},
|
||||
format="multipart",
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class ImportProfileAPITests(TestCase):
|
||||
"""Tests for the Import Profile API endpoints"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.profile1 = ImportProfile.objects.create(
|
||||
name="Profile 1",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_expense
|
||||
is_paid:
|
||||
detection_method: always_paid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
self.profile2 = ImportProfile.objects.create(
|
||||
name="Profile 2",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_income
|
||||
is_paid:
|
||||
detection_method: always_unpaid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
|
||||
def test_list_profiles(self):
|
||||
"""Test listing all profiles"""
|
||||
response = self.client.get("/api/import/profiles/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 2)
|
||||
self.assertEqual(len(response.data["results"]), 2)
|
||||
|
||||
def test_retrieve_profile(self):
|
||||
"""Test retrieving a specific profile"""
|
||||
response = self.client.get(f"/api/import/profiles/{self.profile1.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["id"], self.profile1.id)
|
||||
self.assertEqual(response.data["name"], "Profile 1")
|
||||
self.assertIn("yaml_config", response.data)
|
||||
|
||||
def test_retrieve_nonexistent_profile(self):
|
||||
"""Test retrieving a non-existent profile returns 404"""
|
||||
response = self.client.get("/api/import/profiles/99999/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_profiles_unauthenticated(self):
|
||||
"""Test unauthenticated request returns 403"""
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get("/api/import/profiles/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class ImportRunAPITests(TestCase):
|
||||
"""Tests for the Import Run API endpoints"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client = APIClient()
|
||||
self.client.force_authenticate(user=self.user)
|
||||
|
||||
self.profile1 = ImportProfile.objects.create(
|
||||
name="Profile 1",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_expense
|
||||
is_paid:
|
||||
detection_method: always_paid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
self.profile2 = ImportProfile.objects.create(
|
||||
name="Profile 2",
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
yaml_config="""
|
||||
file_type: csv
|
||||
date_format: "%Y-%m-%d"
|
||||
column_mapping:
|
||||
date:
|
||||
source: date
|
||||
description:
|
||||
source: description
|
||||
amount:
|
||||
source: amount
|
||||
transaction_type:
|
||||
detection_method: always_income
|
||||
is_paid:
|
||||
detection_method: always_unpaid
|
||||
account:
|
||||
source: account
|
||||
match_field: name
|
||||
""",
|
||||
)
|
||||
|
||||
# Create import runs
|
||||
self.run1 = ImportRun.objects.create(
|
||||
profile=self.profile1,
|
||||
file_name="file1.csv",
|
||||
status=ImportRun.Status.FINISHED,
|
||||
)
|
||||
self.run2 = ImportRun.objects.create(
|
||||
profile=self.profile1,
|
||||
file_name="file2.csv",
|
||||
status=ImportRun.Status.QUEUED,
|
||||
)
|
||||
self.run3 = ImportRun.objects.create(
|
||||
profile=self.profile2,
|
||||
file_name="file3.csv",
|
||||
status=ImportRun.Status.FINISHED,
|
||||
)
|
||||
|
||||
def test_list_all_runs(self):
|
||||
"""Test listing all runs"""
|
||||
response = self.client.get("/api/import/runs/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 3)
|
||||
self.assertEqual(len(response.data["results"]), 3)
|
||||
|
||||
def test_list_runs_by_profile(self):
|
||||
"""Test filtering runs by profile_id"""
|
||||
response = self.client.get(f"/api/import/runs/?profile_id={self.profile1.id}")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 2)
|
||||
for run in response.data["results"]:
|
||||
self.assertEqual(run["profile"], self.profile1.id)
|
||||
|
||||
def test_list_runs_by_other_profile(self):
|
||||
"""Test filtering runs by another profile_id"""
|
||||
response = self.client.get(f"/api/import/runs/?profile_id={self.profile2.id}")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["count"], 1)
|
||||
self.assertEqual(response.data["results"][0]["profile"], self.profile2.id)
|
||||
|
||||
def test_retrieve_run(self):
|
||||
"""Test retrieving a specific run"""
|
||||
response = self.client.get(f"/api/import/runs/{self.run1.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["id"], self.run1.id)
|
||||
self.assertEqual(response.data["file_name"], "file1.csv")
|
||||
self.assertEqual(response.data["status"], "FINISHED")
|
||||
|
||||
def test_retrieve_nonexistent_run(self):
|
||||
"""Test retrieving a non-existent run returns 404"""
|
||||
response = self.client.get("/api/import/runs/99999/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||
|
||||
def test_runs_unauthenticated(self):
|
||||
"""Test unauthenticated request returns 403"""
|
||||
unauthenticated_client = APIClient()
|
||||
response = unauthenticated_client.get("/api/import/runs/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||
587
app/apps/api/tests/test_shared_access.py
Normal file
587
app/apps/api/tests/test_shared_access.py
Normal file
@@ -0,0 +1,587 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
from rest_framework import status
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
TransactionEntity,
|
||||
)
|
||||
|
||||
|
||||
ACCESS_DENIED_CODES = [status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND]
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class SharedAccountAccessTests(TestCase):
|
||||
"""Tests for shared account access via shared_with field."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with shared accounts."""
|
||||
User = get_user_model()
|
||||
|
||||
# User 1 - owner
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
# User 2 - will have shared access
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
# User 3 - no shared access
|
||||
self.user3 = User.objects.create_user(
|
||||
email="user3@test.com", password="testpass123"
|
||||
)
|
||||
self.client3 = APIClient()
|
||||
self.client3.force_authenticate(user=self.user3)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's account shared with user 2
|
||||
self.shared_account = Account.all_objects.create(
|
||||
name="Shared Account",
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
visibility="private",
|
||||
)
|
||||
self.shared_account.shared_with.add(self.user2)
|
||||
|
||||
# User 1's private account (not shared)
|
||||
self.private_account = Account.all_objects.create(
|
||||
name="Private Account",
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
visibility="private",
|
||||
)
|
||||
|
||||
# Transaction in shared account
|
||||
self.shared_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.shared_account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Shared Transaction",
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
# Transaction in private account
|
||||
self.private_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.private_account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
amount=Decimal("50.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Private Transaction",
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
def test_user_can_see_accounts_shared_with_them(self):
|
||||
"""User2 should see the account shared with them."""
|
||||
response = self.client2.get("/api/accounts/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
account_ids = [acc["id"] for acc in response.data["results"]]
|
||||
self.assertIn(self.shared_account.id, account_ids)
|
||||
|
||||
def test_user_cannot_see_accounts_not_shared_with_them(self):
|
||||
"""User2 should NOT see user1's private (non-shared) account."""
|
||||
response = self.client2.get("/api/accounts/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
account_ids = [acc["id"] for acc in response.data["results"]]
|
||||
self.assertNotIn(self.private_account.id, account_ids)
|
||||
|
||||
def test_user_can_access_shared_account_detail(self):
|
||||
"""User2 should be able to access shared account details."""
|
||||
response = self.client2.get(f"/api/accounts/{self.shared_account.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Account")
|
||||
|
||||
def test_user_without_share_cannot_access_shared_account(self):
|
||||
"""User3 should NOT be able to access the shared account."""
|
||||
response = self.client3.get(f"/api/accounts/{self.shared_account.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_can_see_transactions_in_shared_account(self):
|
||||
"""User2 should see transactions in the shared account."""
|
||||
response = self.client2.get("/api/transactions/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
transaction_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.shared_transaction.id, transaction_ids)
|
||||
self.assertNotIn(self.private_transaction.id, transaction_ids)
|
||||
|
||||
def test_user_can_access_transaction_in_shared_account(self):
|
||||
"""User2 should be able to access transaction details in shared account."""
|
||||
response = self.client2.get(f"/api/transactions/{self.shared_transaction.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["description"], "Shared Transaction")
|
||||
|
||||
def test_user_cannot_access_transaction_in_non_shared_account(self):
|
||||
"""User2 should NOT access transactions in user1's private account."""
|
||||
response = self.client2.get(f"/api/transactions/{self.private_transaction.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
|
||||
def test_user_can_get_balance_of_shared_account(self):
|
||||
"""User2 should be able to get balance of shared account."""
|
||||
response = self.client2.get(f"/api/accounts/{self.shared_account.id}/balance/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertIn("current_balance", response.data)
|
||||
|
||||
def test_sharing_works_with_multiple_users(self):
|
||||
"""Account shared with multiple users should be accessible by all."""
|
||||
# Add user3 to shared_with
|
||||
self.shared_account.shared_with.add(self.user3)
|
||||
|
||||
# User2 still has access
|
||||
response2 = self.client2.get(f"/api/accounts/{self.shared_account.id}/")
|
||||
self.assertEqual(response2.status_code, status.HTTP_200_OK)
|
||||
|
||||
# User3 now has access
|
||||
response3 = self.client3.get(f"/api/accounts/{self.shared_account.id}/")
|
||||
self.assertEqual(response3.status_code, status.HTTP_200_OK)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class PublicVisibilityTests(TestCase):
|
||||
"""Tests for public visibility access."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with public accounts."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's public account
|
||||
self.public_account = Account.all_objects.create(
|
||||
name="Public Account",
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
visibility="public",
|
||||
)
|
||||
|
||||
# User 1's private account
|
||||
self.private_account = Account.all_objects.create(
|
||||
name="Private Account",
|
||||
currency=self.currency,
|
||||
owner=self.user1,
|
||||
visibility="private",
|
||||
)
|
||||
|
||||
# Transaction in public account
|
||||
self.public_transaction = Transaction.userless_all_objects.create(
|
||||
account=self.public_account,
|
||||
type=Transaction.Type.INCOME,
|
||||
amount=Decimal("100.00"),
|
||||
is_paid=True,
|
||||
date=date(2025, 1, 1),
|
||||
description="Public Transaction",
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
def test_user_can_see_public_accounts(self):
|
||||
"""User2 should see user1's public account."""
|
||||
response = self.client2.get("/api/accounts/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
account_ids = [acc["id"] for acc in response.data["results"]]
|
||||
self.assertIn(self.public_account.id, account_ids)
|
||||
self.assertNotIn(self.private_account.id, account_ids)
|
||||
|
||||
def test_user_can_access_public_account_detail(self):
|
||||
"""User2 should be able to access public account details."""
|
||||
response = self.client2.get(f"/api/accounts/{self.public_account.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Public Account")
|
||||
|
||||
def test_user_can_see_transactions_in_public_accounts(self):
|
||||
"""User2 should see transactions in public accounts."""
|
||||
response = self.client2.get("/api/transactions/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
transaction_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.public_transaction.id, transaction_ids)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class SharedCategoryTagEntityTests(TestCase):
|
||||
"""Tests for shared categories, tags, and entities."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with shared categories/tags/entities."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
self.user3 = User.objects.create_user(
|
||||
email="user3@test.com", password="testpass123"
|
||||
)
|
||||
self.client3 = APIClient()
|
||||
self.client3.force_authenticate(user=self.user3)
|
||||
|
||||
# User 1's category shared with user 2
|
||||
self.shared_category = TransactionCategory.all_objects.create(
|
||||
name="Shared Category", owner=self.user1
|
||||
)
|
||||
self.shared_category.shared_with.add(self.user2)
|
||||
|
||||
# User 1's private category
|
||||
self.private_category = TransactionCategory.all_objects.create(
|
||||
name="Private Category", owner=self.user1
|
||||
)
|
||||
|
||||
# User 1's public category
|
||||
self.public_category = TransactionCategory.all_objects.create(
|
||||
name="Public Category", owner=self.user1, visibility="public"
|
||||
)
|
||||
|
||||
# User 1's tag shared with user 2
|
||||
self.shared_tag = TransactionTag.all_objects.create(
|
||||
name="Shared Tag", owner=self.user1
|
||||
)
|
||||
self.shared_tag.shared_with.add(self.user2)
|
||||
|
||||
# User 1's entity shared with user 2
|
||||
self.shared_entity = TransactionEntity.all_objects.create(
|
||||
name="Shared Entity", owner=self.user1
|
||||
)
|
||||
self.shared_entity.shared_with.add(self.user2)
|
||||
|
||||
def test_user_can_see_shared_categories(self):
|
||||
"""User2 should see categories shared with them."""
|
||||
response = self.client2.get("/api/categories/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
category_ids = [c["id"] for c in response.data["results"]]
|
||||
self.assertIn(self.shared_category.id, category_ids)
|
||||
self.assertNotIn(self.private_category.id, category_ids)
|
||||
|
||||
def test_user_can_access_shared_category_detail(self):
|
||||
"""User2 should be able to access shared category details."""
|
||||
response = self.client2.get(f"/api/categories/{self.shared_category.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Category")
|
||||
|
||||
def test_user_can_see_public_categories(self):
|
||||
"""User3 should see public categories."""
|
||||
response = self.client3.get("/api/categories/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
category_ids = [c["id"] for c in response.data["results"]]
|
||||
self.assertIn(self.public_category.id, category_ids)
|
||||
|
||||
def test_user_without_share_cannot_see_shared_category(self):
|
||||
"""User3 should NOT see category shared only with user2."""
|
||||
response = self.client3.get("/api/categories/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
category_ids = [c["id"] for c in response.data["results"]]
|
||||
self.assertNotIn(self.shared_category.id, category_ids)
|
||||
|
||||
def test_user_can_see_shared_tags(self):
|
||||
"""User2 should see tags shared with them."""
|
||||
response = self.client2.get("/api/tags/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
tag_ids = [t["id"] for t in response.data["results"]]
|
||||
self.assertIn(self.shared_tag.id, tag_ids)
|
||||
|
||||
def test_user_can_access_shared_tag_detail(self):
|
||||
"""User2 should be able to access shared tag details."""
|
||||
response = self.client2.get(f"/api/tags/{self.shared_tag.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Tag")
|
||||
|
||||
def test_user_can_see_shared_entities(self):
|
||||
"""User2 should see entities shared with them."""
|
||||
response = self.client2.get("/api/entities/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
entity_ids = [e["id"] for e in response.data["results"]]
|
||||
self.assertIn(self.shared_entity.id, entity_ids)
|
||||
|
||||
def test_user_can_access_shared_entity_detail(self):
|
||||
"""User2 should be able to access shared entity details."""
|
||||
response = self.client2.get(f"/api/entities/{self.shared_entity.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Entity")
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class SharedDCAAccessTests(TestCase):
|
||||
"""Tests for shared DCA strategy access."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with shared DCA strategies."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
self.user3 = User.objects.create_user(
|
||||
email="user3@test.com", password="testpass123"
|
||||
)
|
||||
self.client3 = APIClient()
|
||||
self.client3.force_authenticate(user=self.user3)
|
||||
|
||||
self.currency1 = Currency.objects.create(
|
||||
code="BTC", name="Bitcoin", decimal_places=8, prefix=""
|
||||
)
|
||||
self.currency2 = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
|
||||
# User 1's DCA strategy shared with user 2
|
||||
self.shared_strategy = DCAStrategy.all_objects.create(
|
||||
name="Shared BTC Strategy",
|
||||
target_currency=self.currency1,
|
||||
payment_currency=self.currency2,
|
||||
owner=self.user1,
|
||||
)
|
||||
self.shared_strategy.shared_with.add(self.user2)
|
||||
|
||||
# Entry in shared strategy
|
||||
self.shared_entry = DCAEntry.objects.create(
|
||||
strategy=self.shared_strategy,
|
||||
date=date(2025, 1, 1),
|
||||
amount_paid=Decimal("100.00"),
|
||||
amount_received=Decimal("0.001"),
|
||||
)
|
||||
|
||||
# User 1's private strategy
|
||||
self.private_strategy = DCAStrategy.all_objects.create(
|
||||
name="Private BTC Strategy",
|
||||
target_currency=self.currency1,
|
||||
payment_currency=self.currency2,
|
||||
owner=self.user1,
|
||||
)
|
||||
|
||||
def test_user_can_see_shared_dca_strategies(self):
|
||||
"""User2 should see DCA strategies shared with them."""
|
||||
response = self.client2.get("/api/dca/strategies/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
strategy_ids = [s["id"] for s in response.data["results"]]
|
||||
self.assertIn(self.shared_strategy.id, strategy_ids)
|
||||
self.assertNotIn(self.private_strategy.id, strategy_ids)
|
||||
|
||||
def test_user_can_access_shared_dca_strategy_detail(self):
|
||||
"""User2 should be able to access shared strategy details."""
|
||||
response = self.client2.get(f"/api/dca/strategies/{self.shared_strategy.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared BTC Strategy")
|
||||
|
||||
def test_user_without_share_cannot_see_shared_strategy(self):
|
||||
"""User3 should NOT see strategy shared only with user2."""
|
||||
response = self.client3.get("/api/dca/strategies/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
strategy_ids = [s["id"] for s in response.data["results"]]
|
||||
self.assertNotIn(self.shared_strategy.id, strategy_ids)
|
||||
|
||||
def test_user_can_access_shared_strategy_actions(self):
|
||||
"""User2 should be able to access actions on shared strategy."""
|
||||
# investment_frequency
|
||||
response1 = self.client2.get(
|
||||
f"/api/dca/strategies/{self.shared_strategy.id}/investment_frequency/"
|
||||
)
|
||||
self.assertEqual(response1.status_code, status.HTTP_200_OK)
|
||||
|
||||
# price_comparison
|
||||
response2 = self.client2.get(
|
||||
f"/api/dca/strategies/{self.shared_strategy.id}/price_comparison/"
|
||||
)
|
||||
self.assertEqual(response2.status_code, status.HTTP_200_OK)
|
||||
|
||||
# current_price
|
||||
response3 = self.client2.get(
|
||||
f"/api/dca/strategies/{self.shared_strategy.id}/current_price/"
|
||||
)
|
||||
self.assertEqual(response3.status_code, status.HTTP_200_OK)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class SharedAccountGroupTests(TestCase):
|
||||
"""Tests for shared account group access."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data with shared account groups."""
|
||||
User = get_user_model()
|
||||
|
||||
self.user1 = User.objects.create_user(
|
||||
email="user1@test.com", password="testpass123"
|
||||
)
|
||||
self.client1 = APIClient()
|
||||
self.client1.force_authenticate(user=self.user1)
|
||||
|
||||
self.user2 = User.objects.create_user(
|
||||
email="user2@test.com", password="testpass123"
|
||||
)
|
||||
self.client2 = APIClient()
|
||||
self.client2.force_authenticate(user=self.user2)
|
||||
|
||||
self.user3 = User.objects.create_user(
|
||||
email="user3@test.com", password="testpass123"
|
||||
)
|
||||
self.client3 = APIClient()
|
||||
self.client3.force_authenticate(user=self.user3)
|
||||
|
||||
# User 1's account group shared with user 2
|
||||
self.shared_group = AccountGroup.all_objects.create(
|
||||
name="Shared Group", owner=self.user1
|
||||
)
|
||||
self.shared_group.shared_with.add(self.user2)
|
||||
|
||||
# User 1's private account group
|
||||
self.private_group = AccountGroup.all_objects.create(
|
||||
name="Private Group", owner=self.user1
|
||||
)
|
||||
|
||||
# User 1's public account group
|
||||
self.public_group = AccountGroup.all_objects.create(
|
||||
name="Public Group", owner=self.user1, visibility="public"
|
||||
)
|
||||
|
||||
def test_user_can_see_shared_account_groups(self):
|
||||
"""User2 should see account groups shared with them."""
|
||||
response = self.client2.get("/api/account-groups/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
group_ids = [g["id"] for g in response.data["results"]]
|
||||
self.assertIn(self.shared_group.id, group_ids)
|
||||
self.assertNotIn(self.private_group.id, group_ids)
|
||||
|
||||
def test_user_can_access_shared_account_group_detail(self):
|
||||
"""User2 should be able to access shared account group details."""
|
||||
response = self.client2.get(f"/api/account-groups/{self.shared_group.id}/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(response.data["name"], "Shared Group")
|
||||
|
||||
def test_user_can_see_public_account_groups(self):
|
||||
"""User3 should see public account groups."""
|
||||
response = self.client3.get("/api/account-groups/")
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
group_ids = [g["id"] for g in response.data["results"]]
|
||||
self.assertIn(self.public_group.id, group_ids)
|
||||
|
||||
def test_user_without_share_cannot_access_shared_group(self):
|
||||
"""User3 should NOT be able to access shared account group."""
|
||||
response = self.client3.get(f"/api/account-groups/{self.shared_group.id}/")
|
||||
|
||||
self.assertIn(response.status_code, ACCESS_DENIED_CODES)
|
||||
@@ -16,7 +16,11 @@ router.register(r"currencies", views.CurrencyViewSet)
|
||||
router.register(r"exchange-rates", views.ExchangeRateViewSet)
|
||||
router.register(r"dca/strategies", views.DCAStrategyViewSet)
|
||||
router.register(r"dca/entries", views.DCAEntryViewSet)
|
||||
router.register(r"import/profiles", views.ImportProfileViewSet, basename="import-profiles")
|
||||
router.register(r"import/runs", views.ImportRunViewSet, basename="import-runs")
|
||||
router.register(r"import/import", views.ImportViewSet, basename="import-import")
|
||||
|
||||
urlpatterns = [
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
|
||||
|
||||
@@ -2,3 +2,5 @@ from .transactions import *
|
||||
from .accounts import *
|
||||
from .currencies import *
|
||||
from .dca import *
|
||||
from .imports import *
|
||||
|
||||
|
||||
@@ -1,27 +1,79 @@
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework import viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.api.custom.pagination import CustomPageNumberPagination
|
||||
from apps.accounts.models import AccountGroup, Account
|
||||
from apps.api.serializers import AccountGroupSerializer, AccountSerializer
|
||||
from apps.accounts.services import get_account_balance
|
||||
from apps.api.serializers import (
|
||||
AccountGroupSerializer,
|
||||
AccountSerializer,
|
||||
AccountBalanceSerializer,
|
||||
)
|
||||
|
||||
|
||||
class AccountGroupViewSet(viewsets.ModelViewSet):
|
||||
"""ViewSet for managing account groups."""
|
||||
|
||||
queryset = AccountGroup.objects.all()
|
||||
serializer_class = AccountGroupSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return AccountGroup.objects.all().order_by("id")
|
||||
return AccountGroup.objects.all()
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
balance=extend_schema(
|
||||
summary="Get account balance",
|
||||
description="Returns the current and projected balance for the account, along with currency data.",
|
||||
responses={200: AccountBalanceSerializer},
|
||||
),
|
||||
)
|
||||
class AccountViewSet(viewsets.ModelViewSet):
|
||||
"""ViewSet for managing accounts."""
|
||||
|
||||
queryset = Account.objects.all()
|
||||
serializer_class = AccountSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"group": ["exact", "isnull"],
|
||||
"currency": ["exact"],
|
||||
"exchange_currency": ["exact", "isnull"],
|
||||
"is_asset": ["exact"],
|
||||
"is_archived": ["exact"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return (
|
||||
Account.objects.all()
|
||||
.order_by("id")
|
||||
.select_related("group", "currency", "exchange_currency")
|
||||
return Account.objects.all().select_related(
|
||||
"group", "currency", "exchange_currency"
|
||||
)
|
||||
|
||||
@action(detail=True, methods=["get"], permission_classes=[IsAuthenticated])
|
||||
def balance(self, request, pk=None):
|
||||
"""Get current and projected balance for an account."""
|
||||
account = self.get_object()
|
||||
|
||||
current_balance = get_account_balance(account, paid_only=True)
|
||||
projected_balance = get_account_balance(account, paid_only=False)
|
||||
|
||||
serializer = AccountBalanceSerializer(
|
||||
{
|
||||
"current_balance": current_balance,
|
||||
"projected_balance": projected_balance,
|
||||
"currency": account.currency,
|
||||
}
|
||||
)
|
||||
|
||||
return Response(serializer.data)
|
||||
|
||||
@@ -9,8 +9,28 @@ from apps.currencies.models import ExchangeRate
|
||||
class CurrencyViewSet(viewsets.ModelViewSet):
|
||||
queryset = Currency.objects.all()
|
||||
serializer_class = CurrencySerializer
|
||||
filterset_fields = {
|
||||
'name': ['exact', 'icontains'],
|
||||
'code': ['exact', 'icontains'],
|
||||
'decimal_places': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'prefix': ['exact', 'icontains'],
|
||||
'suffix': ['exact', 'icontains'],
|
||||
'exchange_currency': ['exact'],
|
||||
'is_archived': ['exact'],
|
||||
}
|
||||
search_fields = '__all__'
|
||||
ordering_fields = '__all__'
|
||||
|
||||
|
||||
class ExchangeRateViewSet(viewsets.ModelViewSet):
|
||||
queryset = ExchangeRate.objects.all()
|
||||
serializer_class = ExchangeRateSerializer
|
||||
filterset_fields = {
|
||||
'from_currency': ['exact'],
|
||||
'to_currency': ['exact'],
|
||||
'rate': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'date': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'automatic': ['exact'],
|
||||
}
|
||||
search_fields = '__all__'
|
||||
ordering_fields = '__all__'
|
||||
|
||||
@@ -8,6 +8,19 @@ from apps.api.serializers import DCAStrategySerializer, DCAEntrySerializer
|
||||
class DCAStrategyViewSet(viewsets.ModelViewSet):
|
||||
queryset = DCAStrategy.objects.all()
|
||||
serializer_class = DCAStrategySerializer
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"target_currency": ["exact"],
|
||||
"payment_currency": ["exact"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"created_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"updated_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
}
|
||||
search_fields = ["name", "notes"]
|
||||
ordering_fields = "__all__"
|
||||
|
||||
def get_queryset(self):
|
||||
return DCAStrategy.objects.all()
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
def investment_frequency(self, request, pk=None):
|
||||
@@ -32,10 +45,22 @@ class DCAStrategyViewSet(viewsets.ModelViewSet):
|
||||
class DCAEntryViewSet(viewsets.ModelViewSet):
|
||||
queryset = DCAEntry.objects.all()
|
||||
serializer_class = DCAEntrySerializer
|
||||
filterset_fields = {
|
||||
"strategy": ["exact"],
|
||||
"date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"amount_paid": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"amount_received": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"expense_transaction": ["exact", "isnull"],
|
||||
"income_transaction": ["exact", "isnull"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"created_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"updated_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
}
|
||||
search_fields = ["notes"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["-date"]
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = DCAEntry.objects.all()
|
||||
strategy_id = self.request.query_params.get("strategy", None)
|
||||
if strategy_id is not None:
|
||||
queryset = queryset.filter(strategy_id=strategy_id)
|
||||
return queryset
|
||||
# Filter entries by strategies the user has access to
|
||||
accessible_strategies = DCAStrategy.objects.all()
|
||||
return DCAEntry.objects.filter(strategy__in=accessible_strategies)
|
||||
|
||||
147
app/apps/api/views/imports.py
Normal file
147
app/apps/api/views/imports.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from django.core.files.storage import FileSystemStorage
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view, inline_serializer
|
||||
from rest_framework import serializers as drf_serializers
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.parsers import MultiPartParser
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.api.serializers import ImportFileSerializer, ImportProfileSerializer, ImportRunSerializer
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
from apps.import_app.tasks import process_import
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
summary="List import profiles",
|
||||
description="Returns a paginated list of all available import profiles.",
|
||||
),
|
||||
retrieve=extend_schema(
|
||||
summary="Get import profile",
|
||||
description="Returns the details of a specific import profile by ID.",
|
||||
),
|
||||
)
|
||||
class ImportProfileViewSet(viewsets.ReadOnlyModelViewSet):
|
||||
"""ViewSet for listing and retrieving import profiles."""
|
||||
|
||||
queryset = ImportProfile.objects.all()
|
||||
serializer_class = ImportProfileSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
filterset_fields = {
|
||||
'name': ['exact', 'icontains'],
|
||||
'yaml_config': ['exact', 'icontains'],
|
||||
'version': ['exact'],
|
||||
}
|
||||
search_fields = ['name', 'yaml_config']
|
||||
ordering_fields = '__all__'
|
||||
ordering = ['name']
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
list=extend_schema(
|
||||
summary="List import runs",
|
||||
description="Returns a paginated list of import runs. Optionally filter by profile_id.",
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="profile_id",
|
||||
type=int,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Filter runs by profile ID",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
),
|
||||
retrieve=extend_schema(
|
||||
summary="Get import run",
|
||||
description="Returns the details of a specific import run by ID, including status and logs.",
|
||||
),
|
||||
)
|
||||
class ImportRunViewSet(viewsets.ReadOnlyModelViewSet):
|
||||
"""ViewSet for listing and retrieving import runs."""
|
||||
|
||||
queryset = ImportRun.objects.all().order_by("-id")
|
||||
serializer_class = ImportRunSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
filterset_fields = {
|
||||
'status': ['exact'],
|
||||
'profile': ['exact'],
|
||||
'file_name': ['exact', 'icontains'],
|
||||
'logs': ['exact', 'icontains'],
|
||||
'processed_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'total_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'successful_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'skipped_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'failed_rows': ['exact', 'gte', 'lte', 'gt', 'lt'],
|
||||
'started_at': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'],
|
||||
'finished_at': ['exact', 'gte', 'lte', 'gt', 'lt', 'isnull'],
|
||||
}
|
||||
search_fields = ['file_name', 'logs']
|
||||
ordering_fields = '__all__'
|
||||
ordering = ['-id']
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset()
|
||||
profile_id = self.request.query_params.get("profile_id")
|
||||
if profile_id:
|
||||
queryset = queryset.filter(profile_id=profile_id)
|
||||
return queryset
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
create=extend_schema(
|
||||
summary="Import file",
|
||||
description="Upload a CSV or XLSX file to import using an existing import profile. The import is queued and processed asynchronously.",
|
||||
request={
|
||||
"multipart/form-data": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"profile_id": {"type": "integer", "description": "ID of the ImportProfile to use"},
|
||||
"file": {"type": "string", "format": "binary", "description": "CSV or XLSX file to import"},
|
||||
},
|
||||
"required": ["profile_id", "file"],
|
||||
},
|
||||
},
|
||||
responses={
|
||||
202: inline_serializer(
|
||||
name="ImportResponse",
|
||||
fields={
|
||||
"import_run_id": drf_serializers.IntegerField(),
|
||||
"status": drf_serializers.CharField(),
|
||||
},
|
||||
),
|
||||
},
|
||||
),
|
||||
)
|
||||
class ImportViewSet(viewsets.ViewSet):
|
||||
"""ViewSet for importing data via file upload."""
|
||||
|
||||
permission_classes = [IsAuthenticated]
|
||||
parser_classes = [MultiPartParser]
|
||||
|
||||
def create(self, request):
|
||||
serializer = ImportFileSerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
profile = serializer.validated_data["profile"]
|
||||
uploaded_file = serializer.validated_data["file"]
|
||||
|
||||
# Save file to temp location
|
||||
fs = FileSystemStorage(location="/usr/src/app/temp")
|
||||
filename = fs.save(uploaded_file.name, uploaded_file)
|
||||
file_path = fs.path(filename)
|
||||
|
||||
# Create ImportRun record
|
||||
import_run = ImportRun.objects.create(profile=profile, file_name=filename)
|
||||
|
||||
# Queue import task
|
||||
process_import.defer(
|
||||
import_run_id=import_run.id,
|
||||
file_path=file_path,
|
||||
user_id=request.user.id,
|
||||
)
|
||||
|
||||
return Response(
|
||||
{"import_run_id": import_run.id, "status": "queued"},
|
||||
status=status.HTTP_202_ACCEPTED,
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from rest_framework import viewsets
|
||||
|
||||
from apps.api.custom.pagination import CustomPageNumberPagination
|
||||
from apps.api.serializers import (
|
||||
TransactionSerializer,
|
||||
TransactionCategorySerializer,
|
||||
@@ -23,64 +24,151 @@ from apps.rules.signals import transaction_updated, transaction_created
|
||||
class TransactionViewSet(viewsets.ModelViewSet):
|
||||
queryset = Transaction.objects.all()
|
||||
serializer_class = TransactionSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"account": ["exact"],
|
||||
"type": ["exact"],
|
||||
"is_paid": ["exact"],
|
||||
"date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"reference_date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"mute": ["exact"],
|
||||
"amount": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"description": ["exact", "icontains"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"category": ["exact", "isnull"],
|
||||
"installment_plan": ["exact", "isnull"],
|
||||
"installment_id": ["exact", "gte", "lte"],
|
||||
"recurring_transaction": ["exact", "isnull"],
|
||||
"internal_note": ["exact", "icontains"],
|
||||
"internal_id": ["exact"],
|
||||
"deleted": ["exact"],
|
||||
"created_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"updated_at": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"deleted_at": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["description", "notes", "internal_note"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["-id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return Transaction.objects.all()
|
||||
|
||||
def perform_create(self, serializer):
|
||||
instance = serializer.save()
|
||||
transaction_created.send(sender=instance)
|
||||
|
||||
def perform_update(self, serializer):
|
||||
old_data = deepcopy(self.get_object())
|
||||
instance = serializer.save()
|
||||
transaction_updated.send(sender=instance)
|
||||
transaction_updated.send(sender=instance, old_data=old_data)
|
||||
|
||||
def partial_update(self, request, *args, **kwargs):
|
||||
kwargs["partial"] = True
|
||||
return self.update(request, *args, **kwargs)
|
||||
|
||||
def get_queryset(self):
|
||||
return Transaction.objects.all().order_by("-id")
|
||||
|
||||
|
||||
class TransactionCategoryViewSet(viewsets.ModelViewSet):
|
||||
queryset = TransactionCategory.objects.all()
|
||||
serializer_class = TransactionCategorySerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"mute": ["exact"],
|
||||
"active": ["exact"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionCategory.objects.all().order_by("id")
|
||||
return TransactionCategory.objects.all()
|
||||
|
||||
|
||||
class TransactionTagViewSet(viewsets.ModelViewSet):
|
||||
queryset = TransactionTag.objects.all()
|
||||
serializer_class = TransactionTagSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"active": ["exact"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionTag.objects.all().order_by("id")
|
||||
return TransactionTag.objects.all()
|
||||
|
||||
|
||||
class TransactionEntityViewSet(viewsets.ModelViewSet):
|
||||
queryset = TransactionEntity.objects.all()
|
||||
serializer_class = TransactionEntitySerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"name": ["exact", "icontains"],
|
||||
"active": ["exact"],
|
||||
"owner": ["exact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return TransactionEntity.objects.all().order_by("id")
|
||||
return TransactionEntity.objects.all()
|
||||
|
||||
|
||||
class InstallmentPlanViewSet(viewsets.ModelViewSet):
|
||||
queryset = InstallmentPlan.objects.all()
|
||||
serializer_class = InstallmentPlanSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"account": ["exact"],
|
||||
"type": ["exact"],
|
||||
"description": ["exact", "icontains"],
|
||||
"number_of_installments": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"installment_start": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"installment_total_number": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"start_date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"reference_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"end_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"recurrence": ["exact"],
|
||||
"installment_amount": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"category": ["exact", "isnull"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"add_description_to_transaction": ["exact"],
|
||||
"add_notes_to_transaction": ["exact"],
|
||||
}
|
||||
search_fields = ["description", "notes"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["-id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return InstallmentPlan.objects.all().order_by("-id")
|
||||
return InstallmentPlan.objects.all()
|
||||
|
||||
|
||||
class RecurringTransactionViewSet(viewsets.ModelViewSet):
|
||||
queryset = RecurringTransaction.objects.all()
|
||||
serializer_class = RecurringTransactionSerializer
|
||||
pagination_class = CustomPageNumberPagination
|
||||
filterset_fields = {
|
||||
"is_paused": ["exact"],
|
||||
"account": ["exact"],
|
||||
"type": ["exact"],
|
||||
"amount": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"description": ["exact", "icontains"],
|
||||
"category": ["exact", "isnull"],
|
||||
"notes": ["exact", "icontains"],
|
||||
"reference_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"start_date": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"end_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"recurrence_type": ["exact"],
|
||||
"recurrence_interval": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"keep_at_most": ["exact", "gte", "lte", "gt", "lt"],
|
||||
"last_generated_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"last_generated_reference_date": ["exact", "gte", "lte", "gt", "lt", "isnull"],
|
||||
"add_description_to_transaction": ["exact"],
|
||||
"add_notes_to_transaction": ["exact"],
|
||||
}
|
||||
search_fields = ["description", "notes"]
|
||||
ordering_fields = "__all__"
|
||||
ordering = ["-id"]
|
||||
|
||||
def get_queryset(self):
|
||||
return RecurringTransaction.objects.all().order_by("-id")
|
||||
return RecurringTransaction.objects.all()
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
from django.test import TestCase, Client
|
||||
from django.contrib.auth.models import User
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone # Though specific dates are used, good for general test setup
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import TransactionCategory, Transaction
|
||||
# from apps.calendar_view.utils.calendar import get_transactions_by_day # Not directly testing this util here
|
||||
|
||||
class CalendarViewTests(TestCase): # Renamed from CalendarViewTestCase to CalendarViewTests
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(username='testcalendaruser', password='password')
|
||||
self.client = Client()
|
||||
self.client.login(username='testcalendaruser', password='password')
|
||||
|
||||
self.currency_usd = Currency.objects.create(name="CV USD", code="CVUSD", decimal_places=2, prefix="$CV ")
|
||||
self.account_group = AccountGroup.objects.create(name="CV Group", owner=self.user)
|
||||
self.account_usd1 = Account.objects.create(
|
||||
name="CV Account USD 1",
|
||||
currency=self.currency_usd,
|
||||
owner=self.user,
|
||||
group=self.account_group
|
||||
)
|
||||
self.category_cv = TransactionCategory.objects.create(
|
||||
name="CV Cat",
|
||||
owner=self.user,
|
||||
type=TransactionCategory.TransactionType.INFO # Using INFO as a generic type
|
||||
)
|
||||
|
||||
# Transactions for specific dates
|
||||
self.t1 = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_cv,
|
||||
date=date(2023, 3, 5), amount=Decimal("10.00"),
|
||||
type=Transaction.Type.EXPENSE, is_paid=True, description="March 5th Tx"
|
||||
)
|
||||
self.t2 = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_cv,
|
||||
date=date(2023, 3, 10), amount=Decimal("20.00"),
|
||||
type=Transaction.Type.EXPENSE, is_paid=True, description="March 10th Tx"
|
||||
)
|
||||
self.t3 = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_cv,
|
||||
date=date(2023, 4, 5), amount=Decimal("30.00"),
|
||||
type=Transaction.Type.EXPENSE, is_paid=True, description="April 5th Tx"
|
||||
)
|
||||
|
||||
def test_calendar_list_view_context_data(self):
|
||||
# Assumes 'calendar_view:calendar_list' is the correct URL name for the main calendar view
|
||||
# The previous test used 'calendar_view:calendar'. I'll assume 'calendar_list' is the new/correct one.
|
||||
# If the view that shows the grid is named 'calendar', this should be adjusted.
|
||||
# Based on subtask, this is for calendar_list view.
|
||||
url = reverse('calendar_view:calendar_list', kwargs={'month': 3, 'year': 2023})
|
||||
response = self.client.get(url)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn('dates', response.context)
|
||||
|
||||
dates_context = response.context['dates']
|
||||
|
||||
entry_mar5 = next((d for d in dates_context if d['date'] == date(2023, 3, 5)), None)
|
||||
self.assertIsNotNone(entry_mar5, "Date March 5th not found in context.")
|
||||
self.assertIn(self.t1, entry_mar5['transactions'], "Transaction t1 not in March 5th transactions.")
|
||||
|
||||
entry_mar10 = next((d for d in dates_context if d['date'] == date(2023, 3, 10)), None)
|
||||
self.assertIsNotNone(entry_mar10, "Date March 10th not found in context.")
|
||||
self.assertIn(self.t2, entry_mar10['transactions'], "Transaction t2 not in March 10th transactions.")
|
||||
|
||||
for day_data in dates_context:
|
||||
self.assertNotIn(self.t3, day_data['transactions'], f"Transaction t3 (April 5th) found in March {day_data['date']} transactions.")
|
||||
|
||||
def test_calendar_transactions_list_view_specific_day(self):
|
||||
url = reverse('calendar_view:calendar_transactions_list', kwargs={'day': 5, 'month': 3, 'year': 2023})
|
||||
response = self.client.get(url)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn('transactions', response.context)
|
||||
|
||||
transactions_context = response.context['transactions']
|
||||
|
||||
self.assertIn(self.t1, transactions_context, "Transaction t1 (March 5th) not found in context for specific day view.")
|
||||
self.assertNotIn(self.t2, transactions_context, "Transaction t2 (March 10th) found in context for March 5th.")
|
||||
self.assertNotIn(self.t3, transactions_context, "Transaction t3 (April 5th) found in context for March 5th.")
|
||||
self.assertEqual(len(transactions_context), 1)
|
||||
|
||||
def test_calendar_view_authenticated_user_generic_month(self):
|
||||
# This is similar to the old test_calendar_view_authenticated_user.
|
||||
# It tests general access to the main calendar view (which might be 'calendar_list' or 'calendar')
|
||||
# Let's use the 'calendar' name as it was in the old test, assuming it's the main monthly view.
|
||||
# If 'calendar_list' is the actual main monthly view, this might be slightly redundant
|
||||
# with the setup of test_calendar_list_view_context_data but still good for general access check.
|
||||
url = reverse('calendar_view:calendar', args=[2023, 1]) # e.g. Jan 2023
|
||||
response = self.client.get(url)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
# Further context checks could be added here if this view has a different structure than 'calendar_list'
|
||||
self.assertIn('dates', response.context) # Assuming it also provides 'dates'
|
||||
self.assertIn('current_month_date', response.context)
|
||||
self.assertEqual(response.context['current_month_date'], date(2023,1,1))
|
||||
@@ -1,7 +1,26 @@
|
||||
from django.contrib import admin
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
@admin.action(description=_("Make public"))
|
||||
def make_public(modeladmin, request, queryset):
|
||||
queryset.update(visibility="public")
|
||||
|
||||
|
||||
@admin.action(description=_("Make private"))
|
||||
def make_private(modeladmin, request, queryset):
|
||||
queryset.update(visibility="private")
|
||||
|
||||
|
||||
class SharedObjectModelAdmin(admin.ModelAdmin):
|
||||
actions = [make_public, make_private]
|
||||
|
||||
list_display = ("__str__", "visibility", "owner", "get_shared_with")
|
||||
|
||||
@admin.display(description=_("Shared with users"))
|
||||
def get_shared_with(self, obj):
|
||||
return ", ".join([p.email for p in obj.shared_with.all()])
|
||||
|
||||
def get_queryset(self, request):
|
||||
# Use the all_objects manager to show all transactions, including deleted ones
|
||||
return self.model.all_objects.all()
|
||||
|
||||
@@ -1,6 +1,28 @@
|
||||
from django.apps import AppConfig
|
||||
from django.core.cache import cache
|
||||
|
||||
|
||||
class CommonConfig(AppConfig):
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "apps.common"
|
||||
|
||||
def ready(self):
|
||||
from django.contrib import admin
|
||||
from django.contrib.sites.models import Site
|
||||
from allauth.socialaccount.models import (
|
||||
SocialAccount,
|
||||
SocialApp,
|
||||
SocialToken,
|
||||
)
|
||||
|
||||
admin.site.unregister(Site)
|
||||
admin.site.unregister(SocialAccount)
|
||||
admin.site.unregister(SocialApp)
|
||||
admin.site.unregister(SocialToken)
|
||||
|
||||
# Delete the cache for update checks to prevent false-positives when the app is restarted
|
||||
# this will be recreated by the check_for_updates task
|
||||
cache.delete("update_check")
|
||||
|
||||
# Register system checks for required environment variables
|
||||
from apps.common import checks # noqa: F401
|
||||
|
||||
103
app/apps/common/checks.py
Normal file
103
app/apps/common/checks.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Django System Checks for required environment variables.
|
||||
|
||||
This module validates that required environment variables (those without defaults)
|
||||
are present before the application starts.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from django.core.checks import Error, register
|
||||
|
||||
|
||||
# List of environment variables that are required (no default values)
|
||||
# Based on the README.md documentation
|
||||
REQUIRED_ENV_VARS = [
|
||||
("SECRET_KEY", "This is used to provide cryptographic signing."),
|
||||
("SQL_DATABASE", "The name of your postgres database."),
|
||||
]
|
||||
|
||||
# List of environment variables that must be valid integers if set
|
||||
INT_ENV_VARS = [
|
||||
("TASK_WORKERS", "How many workers to have for async tasks."),
|
||||
("SESSION_EXPIRY_TIME", "The age of session cookies, in seconds."),
|
||||
("INTERNAL_PORT", "The port on which the app listens on."),
|
||||
("DJANGO_VITE_DEV_SERVER_PORT", "The port where Vite's dev server is running"),
|
||||
]
|
||||
|
||||
|
||||
@register()
|
||||
def check_required_env_vars(app_configs, **kwargs):
|
||||
"""
|
||||
Check that all required environment variables are set.
|
||||
|
||||
Returns a list of Error objects for any missing required variables.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
for var_name, description in REQUIRED_ENV_VARS:
|
||||
value = os.getenv(var_name)
|
||||
if not value:
|
||||
errors.append(
|
||||
Error(
|
||||
f"Required environment variable '{var_name}' is not set.",
|
||||
hint=f"{description} Please set this variable in your .env file or environment.",
|
||||
id="wygiwyh.E001",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@register()
|
||||
def check_int_env_vars(app_configs, **kwargs):
|
||||
"""
|
||||
Check that environment variables that should be integers are valid.
|
||||
|
||||
Returns a list of Error objects for any invalid integer variables.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
for var_name, description in INT_ENV_VARS:
|
||||
value = os.getenv(var_name)
|
||||
if value is not None:
|
||||
try:
|
||||
int(value)
|
||||
except ValueError:
|
||||
errors.append(
|
||||
Error(
|
||||
f"Environment variable '{var_name}' must be a valid integer, got '{value}'.",
|
||||
hint=f"{description}",
|
||||
id="wygiwyh.E002",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@register()
|
||||
def check_soft_delete_config(app_configs, **kwargs):
|
||||
"""
|
||||
Check that KEEP_DELETED_TRANSACTIONS_FOR is a valid integer when ENABLE_SOFT_DELETE is enabled.
|
||||
|
||||
Returns a list of Error objects if the configuration is invalid.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
enable_soft_delete = os.getenv("ENABLE_SOFT_DELETE", "false").lower() == "true"
|
||||
|
||||
if enable_soft_delete:
|
||||
keep_deleted_for = os.getenv("KEEP_DELETED_TRANSACTIONS_FOR")
|
||||
if keep_deleted_for is not None:
|
||||
try:
|
||||
int(keep_deleted_for)
|
||||
except ValueError:
|
||||
errors.append(
|
||||
Error(
|
||||
f"Environment variable 'KEEP_DELETED_TRANSACTIONS_FOR' must be a valid integer when ENABLE_SOFT_DELETE is enabled, got '{keep_deleted_for}'.",
|
||||
hint="Time in days to keep soft deleted transactions for. Set to 0 to keep all transactions indefinitely.",
|
||||
id="wygiwyh.E003",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
@@ -139,7 +139,6 @@ class DynamicModelMultipleChoiceField(forms.ModelMultipleChoiceField):
|
||||
instance.save()
|
||||
return instance
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise ValidationError(_("Error creating new instance"))
|
||||
|
||||
def clean(self, value):
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from django import forms
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.core.exceptions import ValidationError
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Submit, Div, HTML
|
||||
|
||||
from apps.common.widgets.tom_select import TomSelect, TomSelectMultiple
|
||||
from apps.common.models import SharedObject
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.tom_select import TomSelect, TomSelectMultiple
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import HTML, Div, Field, Layout, Submit
|
||||
from django import forms
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
@@ -39,6 +38,7 @@ class SharedObjectForm(forms.Form):
|
||||
choices=SharedObject.Visibility.choices,
|
||||
required=True,
|
||||
label=_("Visibility"),
|
||||
widget=TomSelect(clear_button=False),
|
||||
help_text=_(
|
||||
"Private: Only shown for the owner and shared users. Only editable by the owner."
|
||||
"<br/>"
|
||||
@@ -48,9 +48,6 @@ class SharedObjectForm(forms.Form):
|
||||
|
||||
class Meta:
|
||||
fields = ["visibility", "shared_with_users"]
|
||||
widgets = {
|
||||
"visibility": TomSelect(clear_button=False),
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Get the current user to filter available sharing options
|
||||
@@ -73,12 +70,10 @@ class SharedObjectForm(forms.Form):
|
||||
self.helper.layout = Layout(
|
||||
Field("owner"),
|
||||
Field("visibility"),
|
||||
HTML("<hr>"),
|
||||
HTML('<hr class="hr my-3">'),
|
||||
Field("shared_with_users"),
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Save"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Save"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -9,5 +9,8 @@ def truncate_decimal(value, decimal_places):
|
||||
:param decimal_places: The number of decimal places to keep
|
||||
:return: Truncated Decimal value
|
||||
"""
|
||||
if isinstance(value, (int, float)):
|
||||
value = Decimal(str(value))
|
||||
|
||||
multiplier = Decimal(10**decimal_places)
|
||||
return (value * multiplier).to_integral_value(rounding=ROUND_DOWN) / multiplier
|
||||
|
||||
@@ -5,7 +5,12 @@ from django.utils.formats import get_format as original_get_format
|
||||
def get_format(format_type=None, lang=None, use_l10n=None):
|
||||
user = get_current_user()
|
||||
|
||||
if user and user.is_authenticated and hasattr(user, "settings"):
|
||||
if (
|
||||
user
|
||||
and user.is_authenticated
|
||||
and hasattr(user, "settings")
|
||||
and use_l10n is not False
|
||||
):
|
||||
user_settings = user.settings
|
||||
if format_type == "THOUSAND_SEPARATOR":
|
||||
number_format = getattr(user_settings, "number_format", None)
|
||||
@@ -13,11 +18,13 @@ def get_format(format_type=None, lang=None, use_l10n=None):
|
||||
return "."
|
||||
elif number_format == "CD":
|
||||
return ","
|
||||
elif number_format == "SD" or number_format == "SC":
|
||||
return " "
|
||||
elif format_type == "DECIMAL_SEPARATOR":
|
||||
number_format = getattr(user_settings, "number_format", None)
|
||||
if number_format == "DC":
|
||||
if number_format == "DC" or number_format == "SC":
|
||||
return ","
|
||||
elif number_format == "CD":
|
||||
elif number_format == "CD" or number_format == "SD":
|
||||
return "."
|
||||
elif format_type == "SHORT_DATE_FORMAT":
|
||||
date_format = getattr(user_settings, "date_format", None)
|
||||
|
||||
@@ -36,12 +36,19 @@ class SharedObject(models.Model):
|
||||
related_name="%(class)s_owned",
|
||||
null=True,
|
||||
blank=True,
|
||||
verbose_name=_("Owner"),
|
||||
)
|
||||
visibility = models.CharField(
|
||||
max_length=10, choices=Visibility.choices, default=Visibility.private
|
||||
max_length=10,
|
||||
choices=Visibility.choices,
|
||||
default=Visibility.private,
|
||||
verbose_name=_("Visibility"),
|
||||
)
|
||||
shared_with = models.ManyToManyField(
|
||||
settings.AUTH_USER_MODEL, related_name="%(class)s_shared", blank=True
|
||||
settings.AUTH_USER_MODEL,
|
||||
related_name="%(class)s_shared",
|
||||
blank=True,
|
||||
verbose_name=_("Shared with users"),
|
||||
)
|
||||
|
||||
# Use as abstract base class
|
||||
@@ -65,6 +72,18 @@ class SharedObject(models.Model):
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
||||
class OwnedObjectManager(models.Manager):
|
||||
def get_queryset(self):
|
||||
"""Return only objects the user can access"""
|
||||
user = get_current_user()
|
||||
base_qs = super().get_queryset()
|
||||
|
||||
if user and user.is_authenticated:
|
||||
return base_qs.filter(Q(owner=user) | Q(owner=None)).distinct()
|
||||
|
||||
return base_qs
|
||||
|
||||
|
||||
class OwnedObject(models.Model):
|
||||
owner = models.ForeignKey(
|
||||
settings.AUTH_USER_MODEL,
|
||||
|
||||
6
app/apps/common/procrastinate.py
Normal file
6
app/apps/common/procrastinate.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import procrastinate
|
||||
|
||||
|
||||
def on_app_ready(app: procrastinate.App):
|
||||
"""This function is ran upon procrastinate initialization."""
|
||||
...
|
||||
@@ -1,25 +1,34 @@
|
||||
import logging
|
||||
from packaging.version import parse as parse_version, InvalidVersion
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
from django.core import management
|
||||
from django.db import DEFAULT_DB_ALIAS
|
||||
from django.core.cache import cache
|
||||
|
||||
from procrastinate import builtin_tasks
|
||||
from procrastinate.contrib.django import app
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.periodic(cron="0 4 * * *")
|
||||
@app.task(queueing_lock="remove_old_jobs", pass_context=True, name="remove_old_jobs")
|
||||
@app.task(
|
||||
lock="remove_old_jobs",
|
||||
queueing_lock="remove_old_jobs",
|
||||
pass_context=True,
|
||||
name="remove_old_jobs",
|
||||
)
|
||||
async def remove_old_jobs(context, timestamp):
|
||||
try:
|
||||
return await builtin_tasks.remove_old_jobs(
|
||||
context,
|
||||
max_hours=744,
|
||||
remove_error=True,
|
||||
remove_failed=True,
|
||||
remove_cancelled=True,
|
||||
remove_aborted=True,
|
||||
)
|
||||
@@ -32,7 +41,11 @@ async def remove_old_jobs(context, timestamp):
|
||||
|
||||
|
||||
@app.periodic(cron="0 6 1 * *")
|
||||
@app.task(queueing_lock="remove_expired_sessions", name="remove_expired_sessions")
|
||||
@app.task(
|
||||
lock="remove_expired_sessions",
|
||||
queueing_lock="remove_expired_sessions",
|
||||
name="remove_expired_sessions",
|
||||
)
|
||||
async def remove_expired_sessions(timestamp=None):
|
||||
"""Cleanup expired sessions by using Django management command."""
|
||||
try:
|
||||
@@ -45,7 +58,7 @@ async def remove_expired_sessions(timestamp=None):
|
||||
|
||||
|
||||
@app.periodic(cron="0 8 * * *")
|
||||
@app.task(name="reset_demo_data")
|
||||
@app.task(lock="reset_demo_data", name="reset_demo_data")
|
||||
def reset_demo_data(timestamp=None):
|
||||
"""
|
||||
Wipes the database and loads fresh demo data if DEMO mode is active.
|
||||
@@ -79,3 +92,47 @@ def reset_demo_data(timestamp=None):
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during daily demo data reset: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@app.periodic(cron="0 */12 * * *") # Every 12 hours
|
||||
@app.task(lock="check_for_updates", name="check_for_updates")
|
||||
def check_for_updates(timestamp=None):
|
||||
if not settings.CHECK_FOR_UPDATES:
|
||||
return "CHECK_FOR_UPDATES is disabled"
|
||||
|
||||
url = "https://api.github.com/repos/eitchtee/WYGIWYH/releases/latest"
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=60)
|
||||
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
|
||||
|
||||
data = response.json()
|
||||
latest_version = data.get("tag_name")
|
||||
|
||||
if latest_version:
|
||||
try:
|
||||
current_v = parse_version(settings.APP_VERSION)
|
||||
except InvalidVersion:
|
||||
current_v = parse_version("0.0.0")
|
||||
try:
|
||||
latest_v = parse_version(latest_version)
|
||||
except InvalidVersion:
|
||||
latest_v = parse_version("0.0.0")
|
||||
|
||||
update_info = {
|
||||
"update_available": False,
|
||||
"current_version": str(current_v),
|
||||
"latest_version": str(latest_v),
|
||||
}
|
||||
|
||||
if latest_v > current_v:
|
||||
update_info["update_available"] = True
|
||||
|
||||
# Cache the entire dictionary
|
||||
cache.set("update_check", update_info, 60 * 60 * 25)
|
||||
logger.info(f"Update check complete. Result: {update_info}")
|
||||
else:
|
||||
logger.warning("Could not find 'tag_name' in GitHub API response.")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to fetch updates from GitHub: {e}")
|
||||
|
||||
17
app/apps/common/templatetags/cache_access.py
Normal file
17
app/apps/common/templatetags/cache_access.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# core/templatetags/update_tags.py
|
||||
from django import template
|
||||
from django.core.cache import cache
|
||||
|
||||
register = template.Library()
|
||||
|
||||
|
||||
@register.simple_tag
|
||||
def get_update_check():
|
||||
"""
|
||||
Retrieves the update status dictionary from the cache.
|
||||
Returns a default dictionary if nothing is found.
|
||||
"""
|
||||
return cache.get("update_check") or {
|
||||
"update_available": False,
|
||||
"latest_version": "N/A",
|
||||
}
|
||||
13
app/apps/common/templatetags/crispy_extra.py
Normal file
13
app/apps/common/templatetags/crispy_extra.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from django import forms, template
|
||||
|
||||
register = template.Library()
|
||||
|
||||
|
||||
@register.filter
|
||||
def is_input(field):
|
||||
return isinstance(field.field.widget, forms.TextInput)
|
||||
|
||||
|
||||
@register.filter
|
||||
def is_textarea(field):
|
||||
return isinstance(field.field.widget, forms.Textarea)
|
||||
@@ -11,7 +11,7 @@ def toast_bg(tags):
|
||||
elif "warning" in tags:
|
||||
return "warning"
|
||||
elif "error" in tags:
|
||||
return "danger"
|
||||
return "error"
|
||||
elif "info" in tags:
|
||||
return "info"
|
||||
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
from django.test import TestCase, RequestFactory
|
||||
from django.template import Template, Context
|
||||
from django.urls import reverse, resolve, NoReverseMatch
|
||||
from django.contrib.auth.models import User
|
||||
from decimal import Decimal # Keep existing imports if they are from other tests
|
||||
from app.apps.common.functions.decimals import truncate_decimal # Keep existing imports
|
||||
|
||||
# Helper to create a dummy request with resolver_match
|
||||
def setup_request_for_view(factory, view_name_or_url, user=None, namespace=None, view_name_for_resolver=None):
|
||||
try:
|
||||
url = reverse(view_name_or_url)
|
||||
except NoReverseMatch:
|
||||
url = view_name_or_url # Assume it's already a URL path
|
||||
|
||||
request = factory.get(url)
|
||||
if user:
|
||||
request.user = user
|
||||
|
||||
try:
|
||||
# For resolver_match, we need to simulate how Django does it.
|
||||
# It needs specific view_name and namespace if applicable.
|
||||
# If view_name_for_resolver is provided, use that for resolving,
|
||||
# otherwise, assume view_name_or_url is the view name for resolver_match.
|
||||
resolver_match_source = view_name_for_resolver if view_name_for_resolver else view_name_or_url
|
||||
|
||||
# If it's a namespaced view name like 'app:view', resolve might handle it directly.
|
||||
# If namespace is separately provided, it means the view_name itself is not namespaced.
|
||||
resolved_match = resolve(url) # Resolve the URL to get func, args, kwargs, etc.
|
||||
|
||||
# Ensure resolver_match has the correct attributes, especially 'view_name' and 'namespace'
|
||||
if hasattr(resolved_match, 'view_name'):
|
||||
if ':' in resolved_match.view_name and not namespace: # e.g. 'app_name:view_name'
|
||||
request.resolver_match = resolved_match
|
||||
elif namespace and resolved_match.namespace == namespace and resolved_match.url_name == resolver_match_source.split(':')[-1]:
|
||||
request.resolver_match = resolved_match
|
||||
elif not namespace and resolved_match.url_name == resolver_match_source:
|
||||
request.resolver_match = resolved_match
|
||||
else: # Fallback or if specific view_name/namespace parts are needed for resolver_match
|
||||
# This part is tricky without knowing the exact structure of resolver_match expected by the tag
|
||||
# Forcing the view_name and namespace if they are explicitly passed.
|
||||
if namespace:
|
||||
resolved_match.namespace = namespace
|
||||
if view_name_for_resolver: # This should be the non-namespaced view name part
|
||||
resolved_match.view_name = f"{namespace}:{view_name_for_resolver.split(':')[-1]}" if namespace else view_name_for_resolver.split(':')[-1]
|
||||
resolved_match.url_name = view_name_for_resolver.split(':')[-1]
|
||||
|
||||
request.resolver_match = resolved_match
|
||||
|
||||
else: # Fallback if resolve() doesn't directly give a full resolver_match object as expected
|
||||
request.resolver_match = None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not resolve URL or set resolver_match for '{view_name_or_url}' (or '{view_name_for_resolver}') for test setup: {e}")
|
||||
request.resolver_match = None
|
||||
return request
|
||||
|
||||
class CommonTestCase(TestCase): # Keep existing test class if other tests depend on it
|
||||
def test_example(self): # Example of an old test
|
||||
self.assertEqual(1 + 1, 2)
|
||||
|
||||
def test_truncate_decimal_function(self): # Example of an old test from problem description
|
||||
test_cases = [
|
||||
(Decimal('123.456'), 0, Decimal('123')),
|
||||
(Decimal('123.456'), 1, Decimal('123.4')),
|
||||
(Decimal('123.456'), 2, Decimal('123.45')),
|
||||
]
|
||||
for value, places, expected in test_cases:
|
||||
with self.subTest(value=value, places=places, expected=expected):
|
||||
self.assertEqual(truncate_decimal(value, places), expected)
|
||||
|
||||
|
||||
class CommonTemplateTagsTests(TestCase):
|
||||
def setUp(self):
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user('testuser', 'password123')
|
||||
|
||||
# Using view names that should exist in a typical Django project with auth
|
||||
# Ensure these URLs are part of your project's urlpatterns for tests to pass.
|
||||
self.view_name_login = 'login' # Typically 'login' or 'account_login'
|
||||
self.namespace_login = None # Often no namespace for basic auth views, or 'account'
|
||||
|
||||
self.view_name_admin = 'admin:index' # Admin index
|
||||
self.namespace_admin = 'admin'
|
||||
|
||||
# Check if these can be reversed, skip tests if not.
|
||||
try:
|
||||
reverse(self.view_name_login)
|
||||
except NoReverseMatch:
|
||||
self.view_name_login = None # Mark as unusable
|
||||
print(f"Warning: Could not reverse '{self.view_name_login}'. Some active_link tests might be skipped.")
|
||||
try:
|
||||
reverse(self.view_name_admin)
|
||||
except NoReverseMatch:
|
||||
self.view_name_admin = None # Mark as unusable
|
||||
print(f"Warning: Could not reverse '{self.view_name_admin}'. Some active_link tests might be skipped.")
|
||||
|
||||
def test_active_link_view_match(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='" + self.view_name_login + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "active")
|
||||
|
||||
def test_active_link_view_no_match(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='non_existent_view_name' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "")
|
||||
|
||||
def test_active_link_view_match_custom_class(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='" + self.view_name_login + "' css_class='custom-active' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "custom-active")
|
||||
|
||||
def test_active_link_view_no_match_inactive_class(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='non_existent_view_name' inactive_class='custom-inactive' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "custom-inactive")
|
||||
|
||||
def test_active_link_namespace_match(self):
|
||||
if not self.view_name_admin: self.skipTest("Admin URL not reversible.")
|
||||
# The view_name_admin is already namespaced 'admin:index'
|
||||
request = setup_request_for_view(self.factory, self.view_name_admin, self.user,
|
||||
namespace=self.namespace_admin, view_name_for_resolver=self.view_name_admin)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_admin}.")
|
||||
# Ensure the resolver_match has the namespace set correctly by setup_request_for_view
|
||||
self.assertEqual(request.resolver_match.namespace, self.namespace_admin, "Namespace not correctly set in resolver_match for test.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link namespaces='" + self.namespace_admin + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "active")
|
||||
|
||||
def test_active_link_multiple_views_one_match(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible.")
|
||||
request = setup_request_for_view(self.factory, self.view_name_login, self.user,
|
||||
namespace=self.namespace_login, view_name_for_resolver=self.view_name_login)
|
||||
if not request.resolver_match: self.skipTest(f"Could not set resolver_match for {self.view_name_login}.")
|
||||
|
||||
template_str = "{% load active_link %} {% active_link views='other_app:other_view||" + self.view_name_login + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "active")
|
||||
|
||||
def test_active_link_no_request_in_context(self):
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible for placeholder view name.")
|
||||
template_str = "{% load active_link %} {% active_link views='" + self.view_name_login + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({})) # Empty context, no 'request'
|
||||
self.assertEqual(rendered.strip(), "")
|
||||
|
||||
def test_active_link_request_without_resolver_match(self):
|
||||
request = self.factory.get('/some_unresolved_url/') # This URL won't resolve
|
||||
request.user = self.user
|
||||
request.resolver_match = None # Explicitly set to None, as resolve() would fail
|
||||
|
||||
if not self.view_name_login: self.skipTest("Login URL not reversible for placeholder view name.")
|
||||
template_str = "{% load active_link %} {% active_link views='" + self.view_name_login + "' %}"
|
||||
template = Template(template_str)
|
||||
rendered = template.render(Context({'request': request}))
|
||||
self.assertEqual(rendered.strip(), "")
|
||||
@@ -91,6 +91,12 @@ def month_year_picker(request):
|
||||
for date in all_months
|
||||
]
|
||||
|
||||
today_url = (
|
||||
reverse(url, kwargs={"month": current_date.month, "year": current_date.year})
|
||||
if url
|
||||
else ""
|
||||
)
|
||||
|
||||
return render(
|
||||
request,
|
||||
"common/fragments/month_year_picker.html",
|
||||
@@ -98,6 +104,7 @@ def month_year_picker(request):
|
||||
"month_year_data": result,
|
||||
"current_month": current_month,
|
||||
"current_year": current_year,
|
||||
"today_url": today_url,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
5
app/apps/common/widgets/crispy/daisyui.py
Normal file
5
app/apps/common/widgets/crispy/daisyui.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from crispy_forms.layout import Field
|
||||
|
||||
|
||||
class Switch(Field):
|
||||
template = "crispy-daisyui/layout/switch.html"
|
||||
@@ -1,15 +1,14 @@
|
||||
import datetime
|
||||
|
||||
from django.forms import widgets
|
||||
from django.utils import formats, translation, dates
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.functions.format import get_format
|
||||
from apps.common.utils.django import (
|
||||
django_to_python_datetime,
|
||||
django_to_airdatepicker_datetime,
|
||||
django_to_airdatepicker_datetime_separated,
|
||||
django_to_python_datetime,
|
||||
)
|
||||
from apps.common.functions.format import get_format
|
||||
from django.forms import widgets
|
||||
from django.utils import dates, formats, translation
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class AirDatePickerInput(widgets.DateInput):
|
||||
@@ -37,7 +36,9 @@ class AirDatePickerInput(widgets.DateInput):
|
||||
def _get_current_language():
|
||||
"""Get current language code in format compatible with AirDatepicker"""
|
||||
lang_code = translation.get_language()
|
||||
# AirDatepicker uses simple language codes
|
||||
# AirDatepicker uses simple language codes, except for pt-br
|
||||
if lang_code.lower() == "pt-br":
|
||||
return "pt-BR"
|
||||
return lang_code.split("-")[0]
|
||||
|
||||
def _get_format(self):
|
||||
@@ -50,6 +51,8 @@ class AirDatePickerInput(widgets.DateInput):
|
||||
def build_attrs(self, base_attrs, extra_attrs=None):
|
||||
attrs = super().build_attrs(base_attrs, extra_attrs)
|
||||
|
||||
attrs["class"] = attrs.get("class", "") + " input"
|
||||
|
||||
attrs["data-now-button-txt"] = _("Today")
|
||||
attrs["data-auto-close"] = str(self.auto_close).lower()
|
||||
attrs["data-clear-button"] = str(self.clear_button).lower()
|
||||
|
||||
@@ -35,8 +35,8 @@ class ArbitraryDecimalDisplayNumberInput(forms.TextInput):
|
||||
self.attrs.update(
|
||||
{
|
||||
"x-data": "",
|
||||
"x-mask:dynamic": f"$money($input, '{get_format('DECIMAL_SEPARATOR')}', "
|
||||
f"'{get_format('THOUSAND_SEPARATOR')}', '30')",
|
||||
"x-mask:dynamic": f"$money($input, '{get_format('DECIMAL_SEPARATOR')}', '{get_format('THOUSAND_SEPARATOR')}', '30')",
|
||||
"x-on:keyup": "if (!['Control', 'Shift', 'Alt', 'Meta'].includes($event.key) && !(($event.ctrlKey || $event.metaKey) && $event.key.toLowerCase() === 'a')) $el.dispatchEvent(new Event('input'))",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from django.forms import widgets, SelectMultiple
|
||||
from django.forms import SelectMultiple, widgets
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
@@ -17,7 +17,7 @@ class TomSelect(widgets.Select):
|
||||
checkboxes=False,
|
||||
group_by=None,
|
||||
*args,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(attrs, *args, **kwargs)
|
||||
self.remove_button = remove_button
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.currencies.exchange_rates.providers import (
|
||||
SynthFinanceProvider,
|
||||
SynthFinanceStockProvider,
|
||||
CoinGeckoFreeProvider,
|
||||
CoinGeckoProProvider,
|
||||
TransitiveRateProvider,
|
||||
)
|
||||
import apps.currencies.exchange_rates.providers as providers
|
||||
from apps.currencies.models import ExchangeRateService, ExchangeRate, Currency
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,11 +11,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Map service types to provider classes
|
||||
PROVIDER_MAPPING = {
|
||||
"synth_finance": SynthFinanceProvider,
|
||||
"synth_finance_stock": SynthFinanceStockProvider,
|
||||
"coingecko_free": CoinGeckoFreeProvider,
|
||||
"coingecko_pro": CoinGeckoProProvider,
|
||||
"transitive": TransitiveRateProvider,
|
||||
"coingecko_free": providers.CoinGeckoFreeProvider,
|
||||
"coingecko_pro": providers.CoinGeckoProProvider,
|
||||
"transitive": providers.TransitiveRateProvider,
|
||||
"frankfurter": providers.FrankfurterProvider,
|
||||
"twelvedata": providers.TwelveDataProvider,
|
||||
"twelvedatamarkets": providers.TwelveDataMarketsProvider,
|
||||
}
|
||||
|
||||
|
||||
@@ -203,25 +197,70 @@ class ExchangeRateFetcher:
|
||||
|
||||
if provider.rates_inverted:
|
||||
# If rates are inverted, we need to swap currencies
|
||||
ExchangeRate.objects.create(
|
||||
from_currency=to_currency,
|
||||
to_currency=from_currency,
|
||||
rate=rate,
|
||||
date=timezone.now(),
|
||||
)
|
||||
if service.singleton:
|
||||
# Try to get the last automatically created exchange rate
|
||||
exchange_rate = (
|
||||
ExchangeRate.objects.filter(
|
||||
automatic=True,
|
||||
from_currency=to_currency,
|
||||
to_currency=from_currency,
|
||||
)
|
||||
.order_by("-date")
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
exchange_rate = None
|
||||
|
||||
if not exchange_rate:
|
||||
ExchangeRate.objects.create(
|
||||
automatic=True,
|
||||
from_currency=to_currency,
|
||||
to_currency=from_currency,
|
||||
rate=rate,
|
||||
date=timezone.now(),
|
||||
)
|
||||
else:
|
||||
exchange_rate.rate = rate
|
||||
exchange_rate.date = timezone.now()
|
||||
exchange_rate.save()
|
||||
|
||||
processed_pairs.add((to_currency.id, from_currency.id))
|
||||
else:
|
||||
# If rates are not inverted, we can use them as is
|
||||
ExchangeRate.objects.create(
|
||||
from_currency=from_currency,
|
||||
to_currency=to_currency,
|
||||
rate=rate,
|
||||
date=timezone.now(),
|
||||
)
|
||||
if service.singleton:
|
||||
# Try to get the last automatically created exchange rate
|
||||
exchange_rate = (
|
||||
ExchangeRate.objects.filter(
|
||||
automatic=True,
|
||||
from_currency=from_currency,
|
||||
to_currency=to_currency,
|
||||
)
|
||||
.order_by("-date")
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
exchange_rate = None
|
||||
|
||||
if not exchange_rate:
|
||||
ExchangeRate.objects.create(
|
||||
automatic=True,
|
||||
from_currency=from_currency,
|
||||
to_currency=to_currency,
|
||||
rate=rate,
|
||||
date=timezone.now(),
|
||||
)
|
||||
else:
|
||||
exchange_rate.rate = rate
|
||||
exchange_rate.date = timezone.now()
|
||||
exchange_rate.save()
|
||||
|
||||
processed_pairs.add((from_currency.id, to_currency.id))
|
||||
|
||||
service.last_fetch = timezone.now()
|
||||
service.failure_count = 0
|
||||
service.save()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching rates for {service.name}: {e}")
|
||||
service.failure_count += 1
|
||||
service.save()
|
||||
|
||||
@@ -13,70 +13,6 @@ from apps.currencies.exchange_rates.base import ExchangeRateProvider
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SynthFinanceProvider(ExchangeRateProvider):
|
||||
"""Implementation for Synth Finance API (synthfinance.com)"""
|
||||
|
||||
BASE_URL = "https://api.synthfinance.com/rates/live"
|
||||
rates_inverted = False # SynthFinance returns non-inverted rates
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
super().__init__(api_key)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})
|
||||
|
||||
def get_rates(
|
||||
self, target_currencies: QuerySet, exchange_currencies: set
|
||||
) -> List[Tuple[Currency, Currency, Decimal]]:
|
||||
results = []
|
||||
currency_groups = {}
|
||||
for currency in target_currencies:
|
||||
if currency.exchange_currency in exchange_currencies:
|
||||
group = currency_groups.setdefault(currency.exchange_currency.code, [])
|
||||
group.append(currency)
|
||||
|
||||
for base_currency, currencies in currency_groups.items():
|
||||
try:
|
||||
to_currencies = ",".join(
|
||||
currency.code
|
||||
for currency in currencies
|
||||
if currency.code != base_currency
|
||||
)
|
||||
response = self.session.get(
|
||||
f"{self.BASE_URL}",
|
||||
params={"from": base_currency, "to": to_currencies},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
rates = data["data"]["rates"]
|
||||
|
||||
for currency in currencies:
|
||||
if currency.code == base_currency:
|
||||
rate = Decimal("1")
|
||||
else:
|
||||
rate = Decimal(str(rates[currency.code]))
|
||||
# Return the rate as is, without inversion
|
||||
results.append((currency.exchange_currency, currency, rate))
|
||||
|
||||
credits_used = data["meta"]["credits_used"]
|
||||
credits_remaining = data["meta"]["credits_remaining"]
|
||||
logger.info(
|
||||
f"Synth Finance API call: {credits_used} credits used, {credits_remaining} remaining"
|
||||
)
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"Error fetching rates from Synth Finance API for base {base_currency}: {e}"
|
||||
)
|
||||
except KeyError as e:
|
||||
logger.error(
|
||||
f"Unexpected response structure from Synth Finance API for base {base_currency}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error processing Synth Finance data for base {base_currency}: {e}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class CoinGeckoFreeProvider(ExchangeRateProvider):
|
||||
"""Implementation for CoinGecko Free API"""
|
||||
|
||||
@@ -152,71 +88,6 @@ class CoinGeckoProProvider(CoinGeckoFreeProvider):
|
||||
self.session.headers.update({"x-cg-pro-api-key": api_key})
|
||||
|
||||
|
||||
class SynthFinanceStockProvider(ExchangeRateProvider):
|
||||
"""Implementation for Synth Finance API Real-Time Prices endpoint (synthfinance.com)"""
|
||||
|
||||
BASE_URL = "https://api.synthfinance.com/tickers"
|
||||
rates_inverted = True
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
super().__init__(api_key)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update(
|
||||
{"Authorization": f"Bearer {self.api_key}", "accept": "application/json"}
|
||||
)
|
||||
|
||||
def get_rates(
|
||||
self, target_currencies: QuerySet, exchange_currencies: set
|
||||
) -> List[Tuple[Currency, Currency, Decimal]]:
|
||||
results = []
|
||||
|
||||
for currency in target_currencies:
|
||||
if currency.exchange_currency not in exchange_currencies:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Same currency has rate of 1
|
||||
if currency.code == currency.exchange_currency.code:
|
||||
rate = Decimal("1")
|
||||
results.append((currency.exchange_currency, currency, rate))
|
||||
continue
|
||||
|
||||
# Fetch real-time price for this ticker
|
||||
response = self.session.get(
|
||||
f"{self.BASE_URL}/{currency.code}/real-time"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# Use fair market value as the rate
|
||||
rate = Decimal(data["data"]["fair_market_value"])
|
||||
results.append((currency.exchange_currency, currency, rate))
|
||||
|
||||
# Log API usage
|
||||
credits_used = data["meta"]["credits_used"]
|
||||
credits_remaining = data["meta"]["credits_remaining"]
|
||||
logger.info(
|
||||
f"Synth Finance API call for {currency.code}: {credits_used} credits used, {credits_remaining} remaining"
|
||||
)
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"Error fetching rate from Synth Finance API for ticker {currency.code}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
except KeyError as e:
|
||||
logger.error(
|
||||
f"Unexpected response structure from Synth Finance API for ticker {currency.code}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error processing Synth Finance data for ticker {currency.code}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class TransitiveRateProvider(ExchangeRateProvider):
|
||||
"""Calculates exchange rates through paths of existing rates"""
|
||||
|
||||
@@ -306,3 +177,329 @@ class TransitiveRateProvider(ExchangeRateProvider):
|
||||
queue.append((neighbor, path + [neighbor], current_rate * rate))
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
class FrankfurterProvider(ExchangeRateProvider):
|
||||
"""Implementation for the Frankfurter API (frankfurter.dev)"""
|
||||
|
||||
BASE_URL = "https://api.frankfurter.dev/v1/latest"
|
||||
rates_inverted = (
|
||||
False # Frankfurter returns non-inverted rates (e.g., 1 EUR = 1.1 USD)
|
||||
)
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""
|
||||
Initializes the provider. The Frankfurter API does not require an API key,
|
||||
so the api_key parameter is ignored.
|
||||
"""
|
||||
super().__init__(api_key)
|
||||
self.session = requests.Session()
|
||||
|
||||
@classmethod
|
||||
def requires_api_key(cls) -> bool:
|
||||
return False
|
||||
|
||||
def get_rates(
|
||||
self, target_currencies: QuerySet, exchange_currencies: set
|
||||
) -> List[Tuple[Currency, Currency, Decimal]]:
|
||||
results = []
|
||||
currency_groups = {}
|
||||
# Group target currencies by their exchange (base) currency to minimize API calls
|
||||
for currency in target_currencies:
|
||||
if currency.exchange_currency in exchange_currencies:
|
||||
group = currency_groups.setdefault(currency.exchange_currency.code, [])
|
||||
group.append(currency)
|
||||
|
||||
# Make one API call for each base currency
|
||||
for base_currency, currencies in currency_groups.items():
|
||||
try:
|
||||
# Create a comma-separated list of target currency codes
|
||||
to_currencies = ",".join(
|
||||
currency.code
|
||||
for currency in currencies
|
||||
if currency.code != base_currency
|
||||
)
|
||||
|
||||
# If there are no target currencies other than the base, skip the API call
|
||||
if not to_currencies:
|
||||
# Handle the case where the only request is for the base rate (e.g., USD to USD)
|
||||
for currency in currencies:
|
||||
if currency.code == base_currency:
|
||||
results.append(
|
||||
(currency.exchange_currency, currency, Decimal("1"))
|
||||
)
|
||||
continue
|
||||
|
||||
response = self.session.get(
|
||||
self.BASE_URL,
|
||||
params={"base": base_currency, "symbols": to_currencies},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
rates = data["rates"]
|
||||
|
||||
# Process the returned rates
|
||||
for currency in currencies:
|
||||
if currency.code == base_currency:
|
||||
# The rate for the base currency to itself is always 1
|
||||
rate = Decimal("1")
|
||||
else:
|
||||
rate = Decimal(str(rates[currency.code]))
|
||||
|
||||
results.append((currency.exchange_currency, currency, rate))
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"Error fetching rates from Frankfurter API for base {base_currency}: {e}"
|
||||
)
|
||||
except KeyError as e:
|
||||
logger.error(
|
||||
f"Unexpected response structure from Frankfurter API for base {base_currency}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error processing Frankfurter data for base {base_currency}: {e}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class TwelveDataProvider(ExchangeRateProvider):
|
||||
"""Implementation for the Twelve Data API (twelvedata.com)"""
|
||||
|
||||
BASE_URL = "https://api.twelvedata.com/exchange_rate"
|
||||
rates_inverted = (
|
||||
False # The API returns direct rates, e.g., for EUR/USD it's 1 EUR = X USD
|
||||
)
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
"""
|
||||
Initializes the provider with an API key and a requests session.
|
||||
"""
|
||||
super().__init__(api_key)
|
||||
self.session = requests.Session()
|
||||
|
||||
@classmethod
|
||||
def requires_api_key(cls) -> bool:
|
||||
"""This provider requires an API key."""
|
||||
return True
|
||||
|
||||
def get_rates(
|
||||
self, target_currencies: QuerySet, exchange_currencies: set
|
||||
) -> List[Tuple[Currency, Currency, Decimal]]:
|
||||
"""
|
||||
Fetches exchange rates from the Twelve Data API for the given currency pairs.
|
||||
|
||||
This provider makes one API call for each requested currency pair.
|
||||
"""
|
||||
results = []
|
||||
|
||||
for target_currency in target_currencies:
|
||||
# Ensure the target currency's exchange currency is one we're interested in
|
||||
if target_currency.exchange_currency not in exchange_currencies:
|
||||
continue
|
||||
|
||||
base_currency = target_currency.exchange_currency
|
||||
|
||||
# The exchange rate for the same currency is always 1
|
||||
if base_currency.code == target_currency.code:
|
||||
rate = Decimal("1")
|
||||
results.append((base_currency, target_currency, rate))
|
||||
continue
|
||||
|
||||
# Construct the symbol in the format "BASE/TARGET", e.g., "EUR/USD"
|
||||
symbol = f"{base_currency.code}/{target_currency.code}"
|
||||
|
||||
try:
|
||||
params = {
|
||||
"symbol": symbol,
|
||||
"apikey": self.api_key,
|
||||
}
|
||||
|
||||
response = self.session.get(self.BASE_URL, params=params)
|
||||
response.raise_for_status() # Raise an HTTPError for bad responses (4xx or 5xx)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# The API may return an error message in a JSON object
|
||||
if "rate" not in data:
|
||||
error_message = data.get("message", "Rate not found in response.")
|
||||
logger.error(
|
||||
f"Could not fetch rate for {symbol} from Twelve Data: {error_message}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert the rate to a Decimal for precision
|
||||
rate = Decimal(str(data["rate"]))
|
||||
results.append((base_currency, target_currency, rate))
|
||||
|
||||
logger.info(f"Successfully fetched rate for {symbol} from Twelve Data.")
|
||||
|
||||
time.sleep(
|
||||
60
|
||||
) # We sleep every pair as to not step over TwelveData's minute limit
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"Error fetching rate from Twelve Data API for symbol {symbol}: {e}"
|
||||
)
|
||||
except KeyError as e:
|
||||
logger.error(
|
||||
f"Unexpected response structure from Twelve Data API for symbol {symbol}: Missing key {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"An unexpected error occurred while processing Twelve Data for {symbol}: {e}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class TwelveDataMarketsProvider(ExchangeRateProvider):
|
||||
"""
|
||||
Provides prices for market instruments (stocks, ETFs, etc.) using the Twelve Data API.
|
||||
|
||||
This provider performs a multi-step process:
|
||||
1. Parses instrument codes which can be symbols, FIGI, CUSIP, or ISIN.
|
||||
2. For CUSIPs, it defaults the currency to USD. For all others, it searches
|
||||
for the instrument to determine its native trading currency.
|
||||
3. Fetches the latest price for the instrument in its native currency.
|
||||
4. Converts the price to the requested target exchange currency.
|
||||
"""
|
||||
|
||||
SYMBOL_SEARCH_URL = "https://api.twelvedata.com/symbol_search"
|
||||
PRICE_URL = "https://api.twelvedata.com/price"
|
||||
EXCHANGE_RATE_URL = "https://api.twelvedata.com/exchange_rate"
|
||||
|
||||
rates_inverted = True
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
super().__init__(api_key)
|
||||
self.session = requests.Session()
|
||||
|
||||
@classmethod
|
||||
def requires_api_key(cls) -> bool:
|
||||
return True
|
||||
|
||||
def _parse_code(self, raw_code: str) -> Tuple[str, str]:
|
||||
"""Parses the raw code to determine its type and value."""
|
||||
if raw_code.startswith("figi:"):
|
||||
return "figi", raw_code.removeprefix("figi:")
|
||||
if raw_code.startswith("cusip:"):
|
||||
return "cusip", raw_code.removeprefix("cusip:")
|
||||
if raw_code.startswith("isin:"):
|
||||
return "isin", raw_code.removeprefix("isin:")
|
||||
return "symbol", raw_code
|
||||
|
||||
def get_rates(
|
||||
self, target_currencies: QuerySet, exchange_currencies: set
|
||||
) -> List[Tuple[Currency, Currency, Decimal]]:
|
||||
results = []
|
||||
|
||||
for asset in target_currencies:
|
||||
if asset.exchange_currency not in exchange_currencies:
|
||||
continue
|
||||
|
||||
code_type, code_value = self._parse_code(asset.code)
|
||||
original_currency_code = None
|
||||
|
||||
try:
|
||||
# Determine the instrument's native currency
|
||||
if code_type == "cusip":
|
||||
# CUSIP codes always default to USD
|
||||
original_currency_code = "USD"
|
||||
logger.info(f"Defaulting CUSIP {code_value} to USD currency.")
|
||||
else:
|
||||
# For all other types, find currency via symbol search
|
||||
search_params = {"symbol": code_value, "apikey": "demo"}
|
||||
search_res = self.session.get(
|
||||
self.SYMBOL_SEARCH_URL, params=search_params
|
||||
)
|
||||
search_res.raise_for_status()
|
||||
search_data = search_res.json()
|
||||
|
||||
if not search_data.get("data"):
|
||||
logger.warning(
|
||||
f"TwelveDataMarkets: Symbol search for '{code_value}' returned no results."
|
||||
)
|
||||
continue
|
||||
|
||||
instrument_data = search_data["data"][0]
|
||||
original_currency_code = instrument_data.get("currency")
|
||||
|
||||
if not original_currency_code:
|
||||
logger.error(
|
||||
f"TwelveDataMarkets: Could not determine original currency for '{code_value}'."
|
||||
)
|
||||
continue
|
||||
|
||||
# Get the instrument's price in its native currency
|
||||
price_params = {code_type: code_value, "apikey": self.api_key}
|
||||
price_res = self.session.get(self.PRICE_URL, params=price_params)
|
||||
price_res.raise_for_status()
|
||||
price_data = price_res.json()
|
||||
|
||||
if "price" not in price_data:
|
||||
error_message = price_data.get(
|
||||
"message", "Price key not found in response"
|
||||
)
|
||||
logger.error(
|
||||
f"TwelveDataMarkets: Could not get price for {code_type} '{code_value}': {error_message}"
|
||||
)
|
||||
continue
|
||||
|
||||
price_in_original_currency = Decimal(price_data["price"])
|
||||
|
||||
# Convert price to the target exchange currency
|
||||
target_exchange_currency = asset.exchange_currency
|
||||
|
||||
if (
|
||||
original_currency_code.upper()
|
||||
== target_exchange_currency.code.upper()
|
||||
):
|
||||
final_price = price_in_original_currency
|
||||
else:
|
||||
rate_symbol = (
|
||||
f"{original_currency_code}/{target_exchange_currency.code}"
|
||||
)
|
||||
rate_params = {"symbol": rate_symbol, "apikey": self.api_key}
|
||||
rate_res = self.session.get(
|
||||
self.EXCHANGE_RATE_URL, params=rate_params
|
||||
)
|
||||
rate_res.raise_for_status()
|
||||
rate_data = rate_res.json()
|
||||
|
||||
if "rate" not in rate_data:
|
||||
error_message = rate_data.get(
|
||||
"message", "Rate key not found in response"
|
||||
)
|
||||
logger.error(
|
||||
f"TwelveDataMarkets: Could not get conversion rate for '{rate_symbol}': {error_message}"
|
||||
)
|
||||
continue
|
||||
|
||||
conversion_rate = Decimal(str(rate_data["rate"]))
|
||||
final_price = price_in_original_currency * conversion_rate
|
||||
|
||||
results.append((target_exchange_currency, asset, final_price))
|
||||
logger.info(
|
||||
f"Successfully processed price for {asset.code} as {final_price} {target_exchange_currency.code}"
|
||||
)
|
||||
|
||||
time.sleep(
|
||||
60
|
||||
) # We sleep every pair as to not step over TwelveData's minute limit
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(
|
||||
f"TwelveDataMarkets: API request failed for {code_value}: {e}"
|
||||
)
|
||||
except (KeyError, IndexError) as e:
|
||||
logger.error(
|
||||
f"TwelveDataMarkets: Error processing API response for {code_value}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"TwelveDataMarkets: An unexpected error occurred for {code_value}: {e}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
from crispy_bootstrap5.bootstrap5 import Switch
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Row, Column
|
||||
from django import forms
|
||||
from django.forms import CharField
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.crispy.daisyui import Switch
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.datepicker import AirDateTimePickerInput
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.currencies.models import Currency, ExchangeRate, ExchangeRateService
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Column, Layout, Row
|
||||
from django import forms
|
||||
from django.forms import CharField
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class CurrencyForm(forms.ModelForm):
|
||||
@@ -26,6 +25,7 @@ class CurrencyForm(forms.ModelForm):
|
||||
"suffix",
|
||||
"code",
|
||||
"exchange_currency",
|
||||
"is_archived",
|
||||
]
|
||||
widgets = {
|
||||
"exchange_currency": TomSelect(),
|
||||
@@ -40,6 +40,7 @@ class CurrencyForm(forms.ModelForm):
|
||||
self.helper.layout = Layout(
|
||||
"code",
|
||||
"name",
|
||||
Switch("is_archived"),
|
||||
"decimal_places",
|
||||
"prefix",
|
||||
"suffix",
|
||||
@@ -49,17 +50,13 @@ class CurrencyForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -87,17 +84,13 @@ class ExchangeRateForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -114,6 +107,7 @@ class ExchangeRateServiceForm(forms.ModelForm):
|
||||
"fetch_interval",
|
||||
"target_currencies",
|
||||
"target_accounts",
|
||||
"singleton",
|
||||
]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -126,10 +120,11 @@ class ExchangeRateServiceForm(forms.ModelForm):
|
||||
"name",
|
||||
"service_type",
|
||||
Switch("is_active"),
|
||||
Switch("singleton"),
|
||||
"api_key",
|
||||
Row(
|
||||
Column("interval_type", css_class="form-group col-md-6"),
|
||||
Column("fetch_interval", css_class="form-group col-md-6"),
|
||||
Column("interval_type"),
|
||||
Column("fetch_interval"),
|
||||
),
|
||||
"target_currencies",
|
||||
"target_accounts",
|
||||
@@ -138,16 +133,12 @@ class ExchangeRateServiceForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
# Generated by Django 5.2.4 on 2025-08-08 02:18
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0014_alter_currency_options'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='exchangerate',
|
||||
name='automatic',
|
||||
field=models.BooleanField(default=False, verbose_name='Automatic'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='exchangerateservice',
|
||||
name='singleton',
|
||||
field=models.BooleanField(default=False, help_text='Create one exchange rate and keep updating it. Avoids database clutter.', verbose_name='Single exchange rate'),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.2.4 on 2025-08-08 02:38
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0015_exchangerate_automatic_exchangerateservice_singleton'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='exchangerate',
|
||||
name='automatic',
|
||||
field=models.BooleanField(default=False, verbose_name='Auto'),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.2.5 on 2025-08-16 22:18
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0016_alter_exchangerate_automatic'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='exchangerateservice',
|
||||
name='service_type',
|
||||
field=models.CharField(choices=[('synth_finance', 'Synth Finance'), ('synth_finance_stock', 'Synth Finance Stock'), ('coingecko_free', 'CoinGecko (Demo/Free)'), ('coingecko_pro', 'CoinGecko (Pro)'), ('transitive', 'Transitive (Calculated from Existing Rates)'), ('frankfurter', 'Frankfurter')], max_length=255, verbose_name='Service Type'),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.2.5 on 2025-08-17 03:54
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0017_alter_exchangerateservice_service_type'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='exchangerateservice',
|
||||
name='service_type',
|
||||
field=models.CharField(choices=[('synth_finance', 'Synth Finance'), ('synth_finance_stock', 'Synth Finance Stock'), ('coingecko_free', 'CoinGecko (Demo/Free)'), ('coingecko_pro', 'CoinGecko (Pro)'), ('transitive', 'Transitive (Calculated from Existing Rates)'), ('frankfurter', 'Frankfurter'), ('twelvedata', 'TwelveData')], max_length=255, verbose_name='Service Type'),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.2.5 on 2025-08-17 06:01
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0018_alter_exchangerateservice_service_type'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='exchangerateservice',
|
||||
name='service_type',
|
||||
field=models.CharField(choices=[('synth_finance', 'Synth Finance'), ('synth_finance_stock', 'Synth Finance Stock'), ('coingecko_free', 'CoinGecko (Demo/Free)'), ('coingecko_pro', 'CoinGecko (Pro)'), ('transitive', 'Transitive (Calculated from Existing Rates)'), ('frankfurter', 'Frankfurter'), ('twelvedata', 'TwelveData'), ('twelvedatamarkets', 'TwelveData Markets')], max_length=255, verbose_name='Service Type'),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,51 @@
|
||||
# Generated by Django 5.2.5 on 2025-08-17 06:25
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
# The new value we are migrating to
|
||||
NEW_SERVICE_TYPE = "frankfurter"
|
||||
# The old values we are deprecating
|
||||
OLD_SERVICE_TYPE_TO_UPDATE = "synth_finance"
|
||||
OLD_SERVICE_TYPE_TO_DELETE = "synth_finance_stock"
|
||||
|
||||
|
||||
def forwards_func(apps, schema_editor):
|
||||
"""
|
||||
Forward migration:
|
||||
- Deletes all ExchangeRateService instances with service_type 'synth_finance_stock'.
|
||||
- Updates all ExchangeRateService instances with service_type 'synth_finance' to 'frankfurter'.
|
||||
"""
|
||||
ExchangeRateService = apps.get_model("currencies", "ExchangeRateService")
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
# 1. Delete the SYNTH_FINANCE_STOCK entries
|
||||
ExchangeRateService.objects.using(db_alias).filter(
|
||||
service_type=OLD_SERVICE_TYPE_TO_DELETE
|
||||
).delete()
|
||||
|
||||
# 2. Update the SYNTH_FINANCE entries to FRANKFURTER
|
||||
ExchangeRateService.objects.using(db_alias).filter(
|
||||
service_type=OLD_SERVICE_TYPE_TO_UPDATE
|
||||
).update(service_type=NEW_SERVICE_TYPE, api_key=None)
|
||||
|
||||
|
||||
def backwards_func(apps, schema_editor):
|
||||
"""
|
||||
Backward migration: This operation is not safely reversible.
|
||||
- We cannot know which 'frankfurter' services were originally 'synth_finance'.
|
||||
- The deleted 'synth_finance_stock' services cannot be recovered.
|
||||
We will leave this function empty to allow migrating backwards without doing anything.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
# Add the previous migration file here
|
||||
("currencies", "0019_alter_exchangerateservice_service_type"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(forwards_func, reverse_code=backwards_func),
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.2.5 on 2025-08-17 06:29
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0020_migrate_synth_finance_services'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='exchangerateservice',
|
||||
name='service_type',
|
||||
field=models.CharField(choices=[('coingecko_free', 'CoinGecko (Demo/Free)'), ('coingecko_pro', 'CoinGecko (Pro)'), ('transitive', 'Transitive (Calculated from Existing Rates)'), ('frankfurter', 'Frankfurter'), ('twelvedata', 'TwelveData'), ('twelvedatamarkets', 'TwelveData Markets')], max_length=255, verbose_name='Service Type'),
|
||||
),
|
||||
]
|
||||
18
app/apps/currencies/migrations/0022_currency_is_archived.py
Normal file
18
app/apps/currencies/migrations/0022_currency_is_archived.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.2.5 on 2025-08-30 00:47
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0021_alter_exchangerateservice_service_type'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='currency',
|
||||
name='is_archived',
|
||||
field=models.BooleanField(default=False, verbose_name='Archived'),
|
||||
),
|
||||
]
|
||||
18
app/apps/currencies/migrations/0023_add_failure_count.py
Normal file
18
app/apps/currencies/migrations/0023_add_failure_count.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 5.2.10 on 2026-01-10 06:08
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('currencies', '0022_currency_is_archived'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='exchangerateservice',
|
||||
name='failure_count',
|
||||
field=models.PositiveIntegerField(default=0),
|
||||
),
|
||||
]
|
||||
@@ -32,6 +32,11 @@ class Currency(models.Model):
|
||||
help_text=_("Default currency for exchange calculations"),
|
||||
)
|
||||
|
||||
is_archived = models.BooleanField(
|
||||
default=False,
|
||||
verbose_name=_("Archived"),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@@ -70,6 +75,8 @@ class ExchangeRate(models.Model):
|
||||
)
|
||||
date = models.DateTimeField(verbose_name=_("Date and Time"))
|
||||
|
||||
automatic = models.BooleanField(verbose_name=_("Auto"), default=False)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Exchange Rate")
|
||||
verbose_name_plural = _("Exchange Rates")
|
||||
@@ -92,11 +99,12 @@ class ExchangeRateService(models.Model):
|
||||
"""Configuration for exchange rate services"""
|
||||
|
||||
class ServiceType(models.TextChoices):
|
||||
SYNTH_FINANCE = "synth_finance", "Synth Finance"
|
||||
SYNTH_FINANCE_STOCK = "synth_finance_stock", "Synth Finance Stock"
|
||||
COINGECKO_FREE = "coingecko_free", "CoinGecko (Demo/Free)"
|
||||
COINGECKO_PRO = "coingecko_pro", "CoinGecko (Pro)"
|
||||
TRANSITIVE = "transitive", "Transitive (Calculated from Existing Rates)"
|
||||
FRANKFURTER = "frankfurter", "Frankfurter"
|
||||
TWELVEDATA = "twelvedata", "TwelveData"
|
||||
TWELVEDATA_MARKETS = "twelvedatamarkets", "TwelveData Markets"
|
||||
|
||||
class IntervalType(models.TextChoices):
|
||||
ON = "on", _("On")
|
||||
@@ -128,6 +136,8 @@ class ExchangeRateService(models.Model):
|
||||
null=True, blank=True, verbose_name=_("Last Successful Fetch")
|
||||
)
|
||||
|
||||
failure_count = models.PositiveIntegerField(default=0)
|
||||
|
||||
target_currencies = models.ManyToManyField(
|
||||
Currency,
|
||||
verbose_name=_("Target Currencies"),
|
||||
@@ -148,6 +158,14 @@ class ExchangeRateService(models.Model):
|
||||
blank=True,
|
||||
)
|
||||
|
||||
singleton = models.BooleanField(
|
||||
verbose_name=_("Single exchange rate"),
|
||||
default=False,
|
||||
help_text=_(
|
||||
"Create one exchange rate and keep updating it. Avoids database clutter."
|
||||
),
|
||||
)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Exchange Rate Service")
|
||||
verbose_name_plural = _("Exchange Rate Services")
|
||||
@@ -221,7 +239,7 @@ class ExchangeRateService(models.Model):
|
||||
hours = self._parse_hour_ranges(self.fetch_interval)
|
||||
# Store in normalized format (optional)
|
||||
self.fetch_interval = ",".join(str(h) for h in sorted(hours))
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
raise ValidationError(
|
||||
{
|
||||
"fetch_interval": _(
|
||||
@@ -232,7 +250,7 @@ class ExchangeRateService(models.Model):
|
||||
)
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ValidationError(
|
||||
{
|
||||
"fetch_interval": _(
|
||||
|
||||
@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.periodic(cron="0 * * * *") # Run every hour
|
||||
@app.task(name="automatic_fetch_exchange_rates")
|
||||
@app.task(lock="automatic_fetch_exchange_rates", name="automatic_fetch_exchange_rates")
|
||||
def automatic_fetch_exchange_rates(timestamp=None):
|
||||
"""Fetch exchange rates for all due services"""
|
||||
fetcher = ExchangeRateFetcher()
|
||||
@@ -19,7 +19,7 @@ def automatic_fetch_exchange_rates(timestamp=None):
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
|
||||
@app.task(name="manual_fetch_exchange_rates")
|
||||
@app.task(lock="manual_fetch_exchange_rates", name="manual_fetch_exchange_rates")
|
||||
def manual_fetch_exchange_rates(timestamp=None):
|
||||
"""Fetch exchange rates for all due services"""
|
||||
fetcher = ExchangeRateFetcher()
|
||||
|
||||
@@ -4,12 +4,8 @@ from django.core.exceptions import ValidationError
|
||||
from django.db import IntegrityError
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
from django.contrib.auth.models import User # Added for ERS owner
|
||||
from datetime import date # Added for CurrencyConversionUtilsTests
|
||||
from apps.currencies.utils.convert import get_exchange_rate, convert # Added convert
|
||||
from unittest.mock import patch # Added patch
|
||||
|
||||
from apps.currencies.models import Currency, ExchangeRate, ExchangeRateService
|
||||
from apps.currencies.models import Currency, ExchangeRate
|
||||
|
||||
|
||||
class CurrencyTests(TestCase):
|
||||
@@ -44,175 +40,12 @@ class CurrencyTests(TestCase):
|
||||
with self.assertRaises(ValidationError):
|
||||
currency.full_clean()
|
||||
|
||||
def test_currency_unique_code(self):
|
||||
"""Test that currency codes must be unique"""
|
||||
Currency.objects.create(code="USD", name="US Dollar", decimal_places=2)
|
||||
with self.assertRaises(IntegrityError):
|
||||
Currency.objects.create(code="USD", name="Another Dollar", decimal_places=2)
|
||||
|
||||
def test_currency_unique_name(self):
|
||||
"""Test that currency names must be unique"""
|
||||
Currency.objects.create(code="USD", name="US Dollar", decimal_places=2)
|
||||
with self.assertRaises(IntegrityError):
|
||||
Currency.objects.create(code="USD2", name="US Dollar", decimal_places=2)
|
||||
|
||||
def test_currency_exchange_currency_cannot_be_self(self):
|
||||
"""Test that a currency's exchange_currency cannot be itself."""
|
||||
currency = Currency.objects.create(
|
||||
code="XYZ", name="Test XYZ", decimal_places=2
|
||||
)
|
||||
currency.exchange_currency = currency # Set exchange_currency to self
|
||||
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
currency.full_clean()
|
||||
|
||||
self.assertIn('exchange_currency', cm.exception.error_dict)
|
||||
# Optionally, check for a specific error message if known:
|
||||
# self.assertTrue(any("cannot be the same as the currency itself" in e.message
|
||||
# for e in cm.exception.error_dict['exchange_currency']))
|
||||
|
||||
|
||||
class ExchangeRateServiceTests(TestCase):
|
||||
def setUp(self):
|
||||
self.owner = User.objects.create_user(username='ers_owner', password='password123')
|
||||
self.base_currency = Currency.objects.create(code="BSC", name="Base Service Coin", decimal_places=2)
|
||||
self.default_ers_params = {
|
||||
'name': "Test ERS",
|
||||
'owner': self.owner,
|
||||
'base_currency': self.base_currency,
|
||||
'provider_class': "dummy.provider.ClassName", # Placeholder
|
||||
}
|
||||
|
||||
def _create_ers_instance(self, interval_type, fetch_interval, **kwargs):
|
||||
params = {**self.default_ers_params, 'interval_type': interval_type, 'fetch_interval': fetch_interval, **kwargs}
|
||||
return ExchangeRateService(**params)
|
||||
|
||||
# Tests for IntervalType.EVERY
|
||||
def test_ers_interval_every_valid_integer(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.EVERY, "12")
|
||||
try:
|
||||
ers.full_clean()
|
||||
except ValidationError:
|
||||
self.fail("ValidationError raised unexpectedly for valid 'EVERY' interval '12'.")
|
||||
|
||||
def test_ers_interval_every_invalid_not_integer(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.EVERY, "abc")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_every_invalid_too_low(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.EVERY, "0")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_every_invalid_too_high(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.EVERY, "25") # Max is 24 for 'EVERY'
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
# Tests for IntervalType.ON (and by extension NOT_ON, as validation logic is shared)
|
||||
def test_ers_interval_on_not_on_valid_single_hour(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "5")
|
||||
try:
|
||||
ers.full_clean() # Should normalize to "5" if not already
|
||||
except ValidationError:
|
||||
self.fail("ValidationError raised unexpectedly for valid 'ON' interval '5'.")
|
||||
self.assertEqual(ers.fetch_interval, "5")
|
||||
|
||||
|
||||
def test_ers_interval_on_not_on_valid_multiple_hours(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "1,8,22")
|
||||
try:
|
||||
ers.full_clean()
|
||||
except ValidationError:
|
||||
self.fail("ValidationError raised unexpectedly for valid 'ON' interval '1,8,22'.")
|
||||
self.assertEqual(ers.fetch_interval, "1,8,22")
|
||||
|
||||
|
||||
def test_ers_interval_on_not_on_valid_range(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "0-4")
|
||||
ers.full_clean() # Should not raise ValidationError
|
||||
self.assertEqual(ers.fetch_interval, "0,1,2,3,4")
|
||||
|
||||
def test_ers_interval_on_not_on_valid_mixed(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "1-3,8,10-12")
|
||||
ers.full_clean() # Should not raise ValidationError
|
||||
self.assertEqual(ers.fetch_interval, "1,2,3,8,10,11,12")
|
||||
|
||||
def test_ers_interval_on_not_on_invalid_char(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "1-3,a")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_on_not_on_invalid_hour_too_high(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "24") # Max is 23 for 'ON' type hours
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_on_not_on_invalid_range_format(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "5-1")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_on_not_on_invalid_range_value_too_high(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "20-24") # 24 is invalid hour
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
def test_ers_interval_on_not_on_empty_interval(self):
|
||||
ers = self._create_ers_instance(ExchangeRateService.IntervalType.ON, "")
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
ers.full_clean()
|
||||
self.assertIn('fetch_interval', cm.exception.error_dict)
|
||||
|
||||
@patch('apps.currencies.exchange_rates.fetcher.PROVIDER_MAPPING')
|
||||
def test_get_provider_valid_service_type(self, mock_provider_mapping):
|
||||
"""Test get_provider returns a configured provider instance for a valid service_type."""
|
||||
|
||||
class MockSynthFinanceProvider:
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
|
||||
# Configure the mock PROVIDER_MAPPING
|
||||
mock_provider_mapping.get.return_value = MockSynthFinanceProvider
|
||||
|
||||
service_instance = self._create_ers_instance(
|
||||
interval_type=ExchangeRateService.IntervalType.EVERY, # Needs some valid interval type
|
||||
fetch_interval="1", # Needs some valid fetch interval
|
||||
service_type=ExchangeRateService.ServiceType.SYNTH_FINANCE,
|
||||
api_key="test_key"
|
||||
)
|
||||
# Ensure the service_type is correctly passed to the mock
|
||||
# The actual get_provider method uses PROVIDER_MAPPING[self.service_type]
|
||||
# So, we should make the mock_provider_mapping behave like a dict for the specific key
|
||||
mock_provider_mapping = {ExchangeRateService.ServiceType.SYNTH_FINANCE: MockSynthFinanceProvider}
|
||||
|
||||
with patch('apps.currencies.exchange_rates.fetcher.PROVIDER_MAPPING', mock_provider_mapping):
|
||||
provider = service_instance.get_provider()
|
||||
|
||||
self.assertIsInstance(provider, MockSynthFinanceProvider)
|
||||
self.assertEqual(provider.key, "test_key")
|
||||
|
||||
@patch('apps.currencies.exchange_rates.fetcher.PROVIDER_MAPPING', {}) # Empty mapping
|
||||
def test_get_provider_invalid_service_type(self, mock_provider_mapping_empty):
|
||||
"""Test get_provider raises KeyError for an invalid or unmapped service_type."""
|
||||
service_instance = self._create_ers_instance(
|
||||
interval_type=ExchangeRateService.IntervalType.EVERY,
|
||||
fetch_interval="1",
|
||||
service_type="UNMAPPED_SERVICE_TYPE", # A type not in the (mocked) mapping
|
||||
api_key="any_key"
|
||||
)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
service_instance.get_provider()
|
||||
|
||||
|
||||
class ExchangeRateTests(TestCase):
|
||||
def setUp(self):
|
||||
@@ -244,169 +77,10 @@ class ExchangeRateTests(TestCase):
|
||||
rate=Decimal("0.85"),
|
||||
date=date,
|
||||
)
|
||||
with self.assertRaises(IntegrityError):
|
||||
with self.assertRaises(Exception): # Could be IntegrityError
|
||||
ExchangeRate.objects.create(
|
||||
from_currency=self.usd,
|
||||
to_currency=self.eur,
|
||||
rate=Decimal("0.86"),
|
||||
date=date,
|
||||
)
|
||||
|
||||
def test_from_and_to_currency_cannot_be_same(self):
|
||||
"""Test that from_currency and to_currency cannot be the same."""
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
rate = ExchangeRate(
|
||||
from_currency=self.usd,
|
||||
to_currency=self.usd, # Same as from_currency
|
||||
rate=Decimal("1.00"),
|
||||
date=timezone.now().date(),
|
||||
)
|
||||
rate.full_clean()
|
||||
|
||||
# Check if the error message is as expected or if the error is associated with a specific field.
|
||||
# The exact key ('to_currency' or '__all__') depends on how the model's clean() method is implemented.
|
||||
# Assuming the validation error is raised with a message like "From and to currency cannot be the same."
|
||||
# and is a non-field error or specifically tied to 'to_currency'.
|
||||
self.assertTrue(
|
||||
'__all__' in cm.exception.error_dict or 'to_currency' in cm.exception.error_dict,
|
||||
"ValidationError should be for '__all__' or 'to_currency'"
|
||||
)
|
||||
# Optionally, check for a specific message if it's consistent:
|
||||
# found_message = False
|
||||
# if '__all__' in cm.exception.error_dict:
|
||||
# found_message = any("cannot be the same" in e.message for e in cm.exception.error_dict['__all__'])
|
||||
# if not found_message and 'to_currency' in cm.exception.error_dict:
|
||||
# found_message = any("cannot be the same" in e.message for e in cm.exception.error_dict['to_currency'])
|
||||
# self.assertTrue(found_message, "Error message about currencies being the same not found.")
|
||||
|
||||
|
||||
class CurrencyConversionUtilsTests(TestCase):
|
||||
def setUp(self):
|
||||
self.usd = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2, prefix="$", suffix="")
|
||||
self.eur = Currency.objects.create(code="EUR", name="Euro", decimal_places=2, prefix="€", suffix="")
|
||||
self.gbp = Currency.objects.create(code="GBP", name="British Pound", decimal_places=2, prefix="£", suffix="")
|
||||
|
||||
# Rates for USD <-> EUR
|
||||
self.usd_eur_rate_10th = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.90"), date=date(2023, 1, 10))
|
||||
self.usd_eur_rate_15th = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.92"), date=date(2023, 1, 15))
|
||||
ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.88"), date=date(2023, 1, 5))
|
||||
|
||||
# Rate for GBP <-> USD (for inverse lookup)
|
||||
self.gbp_usd_rate_10th = ExchangeRate.objects.create(from_currency=self.gbp, to_currency=self.usd, rate=Decimal("1.25"), date=date(2023, 1, 10))
|
||||
|
||||
def test_get_direct_rate_closest_date(self):
|
||||
"""Test fetching a direct rate, ensuring the closest date is chosen."""
|
||||
result = get_exchange_rate(self.usd, self.eur, date(2023, 1, 16))
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.effective_rate, Decimal("0.92"))
|
||||
self.assertEqual(result.original_from_currency, self.usd)
|
||||
self.assertEqual(result.original_to_currency, self.eur)
|
||||
|
||||
def test_get_inverse_rate_closest_date(self):
|
||||
"""Test fetching an inverse rate, ensuring the closest date and correct calculation."""
|
||||
# We are looking for USD to GBP. We have GBP to USD on 2023-01-10.
|
||||
# Target date is 2023-01-12.
|
||||
result = get_exchange_rate(self.usd, self.gbp, date(2023, 1, 12))
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.effective_rate, Decimal("1") / self.gbp_usd_rate_10th.rate)
|
||||
self.assertEqual(result.original_from_currency, self.gbp) # original_from_currency should be GBP
|
||||
self.assertEqual(result.original_to_currency, self.usd) # original_to_currency should be USD
|
||||
|
||||
def test_get_rate_exact_date_preference(self):
|
||||
"""Test that an exact date match is preferred over a closer one."""
|
||||
# Existing rate is on 2023-01-15 (0.92)
|
||||
# Add an exact match for the query date
|
||||
exact_date_rate = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.91"), date=date(2023, 1, 16))
|
||||
|
||||
result = get_exchange_rate(self.usd, self.eur, date(2023, 1, 16))
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.effective_rate, Decimal("0.91"))
|
||||
self.assertEqual(result.original_from_currency, self.usd)
|
||||
self.assertEqual(result.original_to_currency, self.eur)
|
||||
|
||||
def test_get_rate_no_matching_pair(self):
|
||||
"""Test that None is returned if no direct or inverse rate exists between the pair."""
|
||||
# No rates exist for EUR <-> GBP in the setUp
|
||||
result = get_exchange_rate(self.eur, self.gbp, date(2023, 1, 10))
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_get_rate_prefer_direct_over_inverse_same_diff(self):
|
||||
"""Test that a direct rate is preferred over an inverse if date differences are equal."""
|
||||
# We have GBP-USD on 2023-01-10 (self.gbp_usd_rate_10th)
|
||||
# This means an inverse USD-GBP rate is available for 2023-01-10.
|
||||
# Add a direct USD-GBP rate for the same date.
|
||||
direct_usd_gbp_rate = ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.gbp, rate=Decimal("0.80"), date=date(2023, 1, 10))
|
||||
|
||||
result = get_exchange_rate(self.usd, self.gbp, date(2023, 1, 10))
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.effective_rate, Decimal("0.80"))
|
||||
self.assertEqual(result.original_from_currency, self.usd)
|
||||
self.assertEqual(result.original_to_currency, self.gbp)
|
||||
|
||||
# Now test the EUR to USD case from the problem description
|
||||
# Add EUR to USD, rate 1.1, date 2023-01-10
|
||||
eur_usd_direct_rate = ExchangeRate.objects.create(from_currency=self.eur, to_currency=self.usd, rate=Decimal("1.1"), date=date(2023, 1, 10))
|
||||
# We also have USD to EUR on 2023-01-10 (rate 0.90), which would be an inverse match for EUR to USD.
|
||||
|
||||
result_eur_usd = get_exchange_rate(self.eur, self.usd, date(2023, 1, 10))
|
||||
self.assertIsNotNone(result_eur_usd)
|
||||
self.assertEqual(result_eur_usd.effective_rate, Decimal("1.1"))
|
||||
self.assertEqual(result_eur_usd.original_from_currency, self.eur)
|
||||
self.assertEqual(result_eur_usd.original_to_currency, self.usd)
|
||||
|
||||
def test_convert_successful_direct(self):
|
||||
"""Test successful conversion using a direct rate."""
|
||||
# Uses self.usd_eur_rate_15th (0.92) as it's closest to 2023-01-16
|
||||
converted_amount, prefix, suffix, dp = convert(Decimal('100'), self.usd, self.eur, date(2023, 1, 16))
|
||||
self.assertEqual(converted_amount, Decimal('92.00'))
|
||||
self.assertEqual(prefix, self.eur.prefix)
|
||||
self.assertEqual(suffix, self.eur.suffix)
|
||||
self.assertEqual(dp, self.eur.decimal_places)
|
||||
|
||||
def test_convert_successful_inverse(self):
|
||||
"""Test successful conversion using an inverse rate."""
|
||||
# Uses self.gbp_usd_rate_10th (GBP to USD @ 1.25), so USD to GBP is 1/1.25 = 0.8
|
||||
# Target date 2023-01-12, closest is 2023-01-10
|
||||
converted_amount, prefix, suffix, dp = convert(Decimal('100'), self.usd, self.gbp, date(2023, 1, 12))
|
||||
expected_amount = Decimal('100') * (Decimal('1') / self.gbp_usd_rate_10th.rate)
|
||||
self.assertEqual(converted_amount, expected_amount.quantize(Decimal('0.01')))
|
||||
self.assertEqual(prefix, self.gbp.prefix)
|
||||
self.assertEqual(suffix, self.gbp.suffix)
|
||||
self.assertEqual(dp, self.gbp.decimal_places)
|
||||
|
||||
def test_convert_no_rate_found(self):
|
||||
"""Test conversion when no exchange rate is found."""
|
||||
result_tuple = convert(Decimal('100'), self.eur, self.gbp, date(2023, 1, 10))
|
||||
self.assertEqual(result_tuple, (None, None, None, None))
|
||||
|
||||
def test_convert_same_currency(self):
|
||||
"""Test conversion when from_currency and to_currency are the same."""
|
||||
result_tuple = convert(Decimal('100'), self.usd, self.usd, date(2023, 1, 10))
|
||||
self.assertEqual(result_tuple, (None, None, None, None))
|
||||
|
||||
def test_convert_zero_amount(self):
|
||||
"""Test conversion when the amount is zero."""
|
||||
result_tuple = convert(Decimal('0'), self.usd, self.eur, date(2023, 1, 10))
|
||||
self.assertEqual(result_tuple, (None, None, None, None))
|
||||
|
||||
@patch('apps.currencies.utils.convert.timezone')
|
||||
def test_convert_no_date_uses_today(self, mock_timezone):
|
||||
"""Test conversion uses today's date when no date is provided."""
|
||||
# Mock timezone.now().date() to return a specific date
|
||||
mock_today = date(2023, 1, 16)
|
||||
mock_timezone.now.return_value.date.return_value = mock_today
|
||||
|
||||
# This should use self.usd_eur_rate_15th (0.92) as it's closest to mocked "today" (2023-01-16)
|
||||
converted_amount, prefix, suffix, dp = convert(Decimal('100'), self.usd, self.eur)
|
||||
|
||||
self.assertEqual(converted_amount, Decimal('92.00'))
|
||||
self.assertEqual(prefix, self.eur.prefix)
|
||||
self.assertEqual(suffix, self.eur.suffix)
|
||||
self.assertEqual(dp, self.eur.decimal_places)
|
||||
|
||||
# Verify that timezone.now().date() was called (indirectly, by get_exchange_rate)
|
||||
# This specific assertion for get_exchange_rate being called with a specific date
|
||||
# would require patching get_exchange_rate itself, which is more complex.
|
||||
# For now, we rely on the correct outcome given the mocked date.
|
||||
# A more direct way to test date passing is if convert took get_exchange_rate as a dependency.
|
||||
mock_timezone.now.return_value.date.assert_called_once()
|
||||
|
||||
1
app/apps/currencies/tests/__init__.py
Normal file
1
app/apps/currencies/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package for currencies app
|
||||
109
app/apps/currencies/tests/test_automatic_exchange_rates.py
Normal file
109
app/apps/currencies/tests/test_automatic_exchange_rates.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from django.test import TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.currencies.models import Currency, ExchangeRateService
|
||||
from apps.currencies.exchange_rates.fetcher import ExchangeRateFetcher
|
||||
|
||||
|
||||
class ExchangeRateServiceFailureTrackingTests(TestCase):
|
||||
"""Tests for the failure count tracking functionality."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
self.usd = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.eur = Currency.objects.create(
|
||||
code="EUR", name="Euro", decimal_places=2, prefix="€ "
|
||||
)
|
||||
self.eur.exchange_currency = self.usd
|
||||
self.eur.save()
|
||||
|
||||
self.service = ExchangeRateService.objects.create(
|
||||
name="Test Service",
|
||||
service_type=ExchangeRateService.ServiceType.FRANKFURTER,
|
||||
is_active=True,
|
||||
)
|
||||
self.service.target_currencies.add(self.eur)
|
||||
|
||||
def test_failure_count_increments_on_provider_error(self):
|
||||
"""Test that failure_count increments when provider raises an exception."""
|
||||
self.assertEqual(self.service.failure_count, 0)
|
||||
|
||||
with patch.object(
|
||||
self.service, "get_provider", side_effect=Exception("API Error")
|
||||
):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 1)
|
||||
|
||||
def test_failure_count_resets_on_success(self):
|
||||
"""Test that failure_count resets to 0 on successful fetch."""
|
||||
# Set initial failure count
|
||||
self.service.failure_count = 5
|
||||
self.service.save()
|
||||
|
||||
# Mock a successful provider
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.requires_api_key.return_value = False
|
||||
mock_provider.get_rates.return_value = [(self.usd, self.eur, Decimal("0.85"))]
|
||||
mock_provider.rates_inverted = False
|
||||
|
||||
with patch.object(self.service, "get_provider", return_value=mock_provider):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 0)
|
||||
|
||||
def test_failure_count_accumulates_across_fetches(self):
|
||||
"""Test that failure_count accumulates with consecutive failures."""
|
||||
self.assertEqual(self.service.failure_count, 0)
|
||||
|
||||
with patch.object(
|
||||
self.service, "get_provider", side_effect=Exception("API Error")
|
||||
):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 1)
|
||||
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 2)
|
||||
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
self.service.refresh_from_db()
|
||||
self.assertEqual(self.service.failure_count, 3)
|
||||
|
||||
def test_last_fetch_not_updated_on_failure(self):
|
||||
"""Test that last_fetch is NOT updated when a failure occurs."""
|
||||
original_last_fetch = self.service.last_fetch
|
||||
self.assertIsNone(original_last_fetch)
|
||||
|
||||
with patch.object(
|
||||
self.service, "get_provider", side_effect=Exception("API Error")
|
||||
):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
|
||||
self.service.refresh_from_db()
|
||||
self.assertIsNone(self.service.last_fetch)
|
||||
self.assertEqual(self.service.failure_count, 1)
|
||||
|
||||
def test_last_fetch_updated_on_success(self):
|
||||
"""Test that last_fetch IS updated when fetch succeeds."""
|
||||
self.assertIsNone(self.service.last_fetch)
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.requires_api_key.return_value = False
|
||||
mock_provider.get_rates.return_value = [(self.usd, self.eur, Decimal("0.85"))]
|
||||
mock_provider.rates_inverted = False
|
||||
|
||||
with patch.object(self.service, "get_provider", return_value=mock_provider):
|
||||
ExchangeRateFetcher._fetch_service_rates(self.service)
|
||||
|
||||
self.service.refresh_from_db()
|
||||
self.assertIsNotNone(self.service.last_fetch)
|
||||
self.assertEqual(self.service.failure_count, 0)
|
||||
@@ -23,7 +23,7 @@ def currencies_index(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def currencies_list(request):
|
||||
currencies = Currency.objects.all().order_by("id")
|
||||
currencies = Currency.objects.all().order_by("name")
|
||||
return render(
|
||||
request,
|
||||
"currencies/fragments/list.html",
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
from crispy_bootstrap5.bootstrap5 import Switch, BS5Accordion
|
||||
from crispy_forms.bootstrap import FormActions, AccordionGroup
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Row, Column, HTML
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.datepicker import AirDatePickerInput
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.common.widgets.tom_select import TransactionSelect
|
||||
from apps.transactions.models import Transaction, TransactionTag, TransactionCategory
|
||||
from apps.common.fields.forms.dynamic_select import (
|
||||
DynamicModelChoiceField,
|
||||
DynamicModelMultipleChoiceField,
|
||||
)
|
||||
from apps.common.widgets.crispy.daisyui import Switch
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.common.widgets.datepicker import AirDatePickerInput
|
||||
from apps.common.widgets.decimal import ArbitraryDecimalDisplayNumberInput
|
||||
from apps.common.widgets.tom_select import TomSelect, TransactionSelect
|
||||
from apps.dca.models import DCAEntry, DCAStrategy
|
||||
from apps.transactions.models import Transaction, TransactionCategory, TransactionTag
|
||||
from crispy_forms.bootstrap import AccordionGroup, FormActions, Accordion
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import HTML, Column, Layout, Row
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class DCAStrategyForm(forms.ModelForm):
|
||||
@@ -36,8 +34,8 @@ class DCAStrategyForm(forms.ModelForm):
|
||||
self.helper.layout = Layout(
|
||||
"name",
|
||||
Row(
|
||||
Column("payment_currency", css_class="form-group col-md-6"),
|
||||
Column("target_currency", css_class="form-group col-md-6"),
|
||||
Column("payment_currency"),
|
||||
Column("target_currency"),
|
||||
),
|
||||
"notes",
|
||||
)
|
||||
@@ -45,17 +43,13 @@ class DCAStrategyForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -155,11 +149,11 @@ class DCAEntryForm(forms.ModelForm):
|
||||
self.helper.layout = Layout(
|
||||
"date",
|
||||
Row(
|
||||
Column("amount_paid", css_class="form-group col-md-6"),
|
||||
Column("amount_received", css_class="form-group col-md-6"),
|
||||
Column("amount_paid"),
|
||||
Column("amount_received"),
|
||||
),
|
||||
"notes",
|
||||
BS5Accordion(
|
||||
Accordion(
|
||||
AccordionGroup(
|
||||
_("Create transaction"),
|
||||
Switch("create_transaction"),
|
||||
@@ -168,19 +162,11 @@ class DCAEntryForm(forms.ModelForm):
|
||||
Row(
|
||||
Column(
|
||||
"from_account",
|
||||
css_class="form-group",
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
"from_category",
|
||||
css_class="form-group col-md-6 mb-0",
|
||||
),
|
||||
Column(
|
||||
"from_tags", css_class="form-group col-md-6 mb-0"
|
||||
),
|
||||
css_class="form-row",
|
||||
Column("from_category"),
|
||||
Column("from_tags"),
|
||||
),
|
||||
),
|
||||
css_class="p-1 mx-1 my-3 border rounded-3",
|
||||
@@ -192,14 +178,10 @@ class DCAEntryForm(forms.ModelForm):
|
||||
"to_account",
|
||||
css_class="form-group",
|
||||
),
|
||||
css_class="form-row",
|
||||
),
|
||||
Row(
|
||||
Column(
|
||||
"to_category", css_class="form-group col-md-6 mb-0"
|
||||
),
|
||||
Column("to_tags", css_class="form-group col-md-6 mb-0"),
|
||||
css_class="form-row",
|
||||
Column("to_category"),
|
||||
Column("to_tags"),
|
||||
),
|
||||
),
|
||||
css_class="p-1 mx-1 my-3 border rounded-3",
|
||||
@@ -220,17 +202,13 @@ class DCAEntryForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
# Generated by Django 5.2.4 on 2025-07-28 02:15
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('dca', '0003_dcastrategy_owner_dcastrategy_shared_with_and_more'),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='dcastrategy',
|
||||
name='owner',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='%(class)s_owned', to=settings.AUTH_USER_MODEL, verbose_name='Owner'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='dcastrategy',
|
||||
name='shared_with',
|
||||
field=models.ManyToManyField(blank=True, related_name='%(class)s_shared', to=settings.AUTH_USER_MODEL, verbose_name='Shared with users'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='dcastrategy',
|
||||
name='visibility',
|
||||
field=models.CharField(choices=[('private', 'Private'), ('public', 'Public')], default='private', max_length=10, verbose_name='Visibility'),
|
||||
),
|
||||
]
|
||||
@@ -1,344 +1,3 @@
|
||||
from django.test import TestCase, Client
|
||||
from django.contrib.auth.models import User
|
||||
from django.urls import reverse
|
||||
from django.forms import NON_FIELD_ERRORS
|
||||
from apps.currencies.models import Currency
|
||||
from apps.dca.models import DCAStrategy, DCAEntry
|
||||
from apps.dca.forms import DCAStrategyForm, DCAEntryForm # Added DCAEntryForm
|
||||
from apps.accounts.models import Account, AccountGroup # Added Account models
|
||||
from apps.transactions.models import TransactionCategory, Transaction # Added Transaction models
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
from unittest.mock import patch
|
||||
from django.test import TestCase
|
||||
|
||||
class DCATests(TestCase):
|
||||
def setUp(self):
|
||||
self.owner = User.objects.create_user(username='testowner', password='password123')
|
||||
self.client = Client()
|
||||
self.client.login(username='testowner', password='password123')
|
||||
|
||||
self.payment_curr = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2)
|
||||
self.target_curr = Currency.objects.create(code="BTC", name="Bitcoin", decimal_places=8)
|
||||
|
||||
# AccountGroup for accounts
|
||||
self.account_group = AccountGroup.objects.create(name="DCA Test Group", owner=self.owner)
|
||||
|
||||
# Accounts for transactions
|
||||
self.account1 = Account.objects.create(
|
||||
name="Payment Account USD",
|
||||
owner=self.owner,
|
||||
currency=self.payment_curr,
|
||||
group=self.account_group
|
||||
)
|
||||
self.account2 = Account.objects.create(
|
||||
name="Target Account BTC",
|
||||
owner=self.owner,
|
||||
currency=self.target_curr,
|
||||
group=self.account_group
|
||||
)
|
||||
|
||||
# TransactionCategory for transactions
|
||||
# Using INFO type as it's generic. TRANSFER might imply specific paired transaction logic not relevant here.
|
||||
self.category1 = TransactionCategory.objects.create(
|
||||
name="DCA Category",
|
||||
owner=self.owner,
|
||||
type=TransactionCategory.TransactionType.INFO
|
||||
)
|
||||
|
||||
|
||||
self.strategy1 = DCAStrategy.objects.create(
|
||||
name="Test Strategy 1",
|
||||
owner=self.owner,
|
||||
payment_currency=self.payment_curr,
|
||||
target_currency=self.target_curr
|
||||
)
|
||||
|
||||
self.entries1 = [
|
||||
DCAEntry.objects.create(
|
||||
strategy=self.strategy1,
|
||||
date=date(2023, 1, 1),
|
||||
amount_paid=Decimal('100.00'),
|
||||
amount_received=Decimal('0.010')
|
||||
),
|
||||
DCAEntry.objects.create(
|
||||
strategy=self.strategy1,
|
||||
date=date(2023, 2, 1),
|
||||
amount_paid=Decimal('150.00'),
|
||||
amount_received=Decimal('0.012')
|
||||
),
|
||||
DCAEntry.objects.create(
|
||||
strategy=self.strategy1,
|
||||
date=date(2023, 3, 1),
|
||||
amount_paid=Decimal('120.00'),
|
||||
amount_received=Decimal('0.008')
|
||||
)
|
||||
]
|
||||
|
||||
def test_strategy_index_view_authenticated_user(self):
|
||||
# Uses self.client and self.owner from setUp
|
||||
response = self.client.get(reverse('dca:dca_strategy_index'))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_strategy_totals_and_average_price(self):
|
||||
self.assertEqual(self.strategy1.total_entries(), 3)
|
||||
self.assertEqual(self.strategy1.total_invested(), Decimal('370.00')) # 100 + 150 + 120
|
||||
self.assertEqual(self.strategy1.total_received(), Decimal('0.030')) # 0.01 + 0.012 + 0.008
|
||||
|
||||
expected_avg_price = Decimal('370.00') / Decimal('0.030')
|
||||
# Match precision of the model method if it's specific, e.g. quantize
|
||||
# For now, direct comparison. The model might return a Decimal that needs specific quantizing.
|
||||
self.assertEqual(self.strategy1.average_entry_price(), expected_avg_price)
|
||||
|
||||
def test_strategy_average_price_no_received(self):
|
||||
strategy2 = DCAStrategy.objects.create(
|
||||
name="Test Strategy 2",
|
||||
owner=self.owner,
|
||||
payment_currency=self.payment_curr,
|
||||
target_currency=self.target_curr
|
||||
)
|
||||
DCAEntry.objects.create(
|
||||
strategy=strategy2,
|
||||
date=date(2023, 4, 1),
|
||||
amount_paid=Decimal('100.00'),
|
||||
amount_received=Decimal('0') # Total received is zero
|
||||
)
|
||||
self.assertEqual(strategy2.total_received(), Decimal('0'))
|
||||
self.assertEqual(strategy2.average_entry_price(), Decimal('0'))
|
||||
|
||||
@patch('apps.dca.models.convert')
|
||||
def test_dca_entry_value_and_pl(self, mock_convert):
|
||||
entry = self.entries1[0] # amount_paid=100, amount_received=0.010
|
||||
|
||||
# Simulate current price: 1 target_curr = 20,000 payment_curr
|
||||
# So, 0.010 target_curr should be 0.010 * 20000 = 200 payment_curr
|
||||
simulated_converted_value = entry.amount_received * Decimal('20000')
|
||||
mock_convert.return_value = (
|
||||
simulated_converted_value,
|
||||
self.payment_curr.prefix,
|
||||
self.payment_curr.suffix,
|
||||
self.payment_curr.decimal_places
|
||||
)
|
||||
|
||||
current_val = entry.current_value()
|
||||
self.assertEqual(current_val, Decimal('200.00'))
|
||||
|
||||
# Profit/Loss = current_value - amount_paid = 200 - 100 = 100
|
||||
self.assertEqual(entry.profit_loss(), Decimal('100.00'))
|
||||
|
||||
# P/L % = (profit_loss / amount_paid) * 100 = (100 / 100) * 100 = 100
|
||||
self.assertEqual(entry.profit_loss_percentage(), Decimal('100.00'))
|
||||
|
||||
# Check that convert was called correctly by current_value()
|
||||
# current_value calls convert(self.amount_received, self.strategy.target_currency, self.strategy.payment_currency)
|
||||
# The date argument defaults to None if not passed, which is the case here.
|
||||
mock_convert.assert_called_once_with(
|
||||
entry.amount_received,
|
||||
self.strategy1.target_currency,
|
||||
self.strategy1.payment_currency,
|
||||
None # Date argument is optional and defaults to None
|
||||
)
|
||||
|
||||
@patch('apps.dca.models.convert')
|
||||
def test_dca_strategy_value_and_pl(self, mock_convert):
|
||||
|
||||
def side_effect_func(amount_to_convert, from_currency, to_currency, date=None):
|
||||
if from_currency == self.target_curr and to_currency == self.payment_curr:
|
||||
# Simulate current price: 1 target_curr = 20,000 payment_curr
|
||||
converted_value = amount_to_convert * Decimal('20000')
|
||||
return (converted_value, self.payment_curr.prefix, self.payment_curr.suffix, self.payment_curr.decimal_places)
|
||||
# Fallback for any other unexpected calls, though not expected in this test
|
||||
return (Decimal('0'), '', '', 2)
|
||||
|
||||
mock_convert.side_effect = side_effect_func
|
||||
|
||||
# strategy1 entries:
|
||||
# 1: paid 100, received 0.010. Current value = 0.010 * 20000 = 200
|
||||
# 2: paid 150, received 0.012. Current value = 0.012 * 20000 = 240
|
||||
# 3: paid 120, received 0.008. Current value = 0.008 * 20000 = 160
|
||||
# Total current value = 200 + 240 + 160 = 600
|
||||
self.assertEqual(self.strategy1.current_total_value(), Decimal('600.00'))
|
||||
|
||||
# Total invested = 100 + 150 + 120 = 370
|
||||
# Total profit/loss = current_total_value - total_invested = 600 - 370 = 230
|
||||
self.assertEqual(self.strategy1.total_profit_loss(), Decimal('230.00'))
|
||||
|
||||
# Total P/L % = (total_profit_loss / total_invested) * 100
|
||||
# (230 / 370) * 100 = 62.162162...
|
||||
expected_pl_percentage = (Decimal('230.00') / Decimal('370.00')) * Decimal('100')
|
||||
self.assertAlmostEqual(self.strategy1.total_profit_loss_percentage(), expected_pl_percentage, places=2)
|
||||
|
||||
def test_dca_strategy_form_valid_data(self):
|
||||
form_data = {
|
||||
'name': 'Form Test Strategy',
|
||||
'target_currency': self.target_curr.pk,
|
||||
'payment_currency': self.payment_curr.pk
|
||||
}
|
||||
form = DCAStrategyForm(data=form_data)
|
||||
self.assertTrue(form.is_valid(), form.errors.as_text())
|
||||
|
||||
strategy = form.save(commit=False)
|
||||
strategy.owner = self.owner
|
||||
strategy.save()
|
||||
|
||||
self.assertEqual(strategy.name, 'Form Test Strategy')
|
||||
self.assertEqual(strategy.owner, self.owner)
|
||||
self.assertEqual(strategy.target_currency, self.target_curr)
|
||||
self.assertEqual(strategy.payment_currency, self.payment_curr)
|
||||
|
||||
def test_dca_strategy_form_missing_name(self):
|
||||
form_data = {
|
||||
'target_currency': self.target_curr.pk,
|
||||
'payment_currency': self.payment_curr.pk
|
||||
}
|
||||
form = DCAStrategyForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('name', form.errors)
|
||||
|
||||
def test_dca_strategy_form_missing_target_currency(self):
|
||||
form_data = {
|
||||
'name': 'Form Test Missing Target',
|
||||
'payment_currency': self.payment_curr.pk
|
||||
}
|
||||
form = DCAStrategyForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('target_currency', form.errors)
|
||||
|
||||
# Tests for DCAEntryForm clean method
|
||||
def test_dca_entry_form_clean_create_transaction_missing_accounts(self):
|
||||
data = {
|
||||
'date': date(2023, 1, 1),
|
||||
'amount_paid': Decimal('100.00'),
|
||||
'amount_received': Decimal('0.01'),
|
||||
'create_transaction': True,
|
||||
# from_account and to_account are missing
|
||||
}
|
||||
form = DCAEntryForm(data=data, strategy=self.strategy1, owner=self.owner)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('from_account', form.errors)
|
||||
self.assertIn('to_account', form.errors)
|
||||
|
||||
def test_dca_entry_form_clean_create_transaction_same_accounts(self):
|
||||
data = {
|
||||
'date': date(2023, 1, 1),
|
||||
'amount_paid': Decimal('100.00'),
|
||||
'amount_received': Decimal('0.01'),
|
||||
'create_transaction': True,
|
||||
'from_account': self.account1.pk,
|
||||
'to_account': self.account1.pk, # Same as from_account
|
||||
'from_category': self.category1.pk,
|
||||
'to_category': self.category1.pk,
|
||||
}
|
||||
form = DCAEntryForm(data=data, strategy=self.strategy1, owner=self.owner)
|
||||
self.assertFalse(form.is_valid())
|
||||
# Check for non-field error or specific field error based on form implementation
|
||||
self.assertTrue(NON_FIELD_ERRORS in form.errors or 'to_account' in form.errors)
|
||||
if NON_FIELD_ERRORS in form.errors:
|
||||
self.assertTrue(any("From and To accounts must be different" in error for error in form.errors[NON_FIELD_ERRORS]))
|
||||
|
||||
|
||||
# Tests for DCAEntryForm save method
|
||||
def test_dca_entry_form_save_create_transactions(self):
|
||||
data = {
|
||||
'date': date(2023, 5, 1),
|
||||
'amount_paid': Decimal('200.00'),
|
||||
'amount_received': Decimal('0.025'),
|
||||
'create_transaction': True,
|
||||
'from_account': self.account1.pk,
|
||||
'to_account': self.account2.pk,
|
||||
'from_category': self.category1.pk,
|
||||
'to_category': self.category1.pk,
|
||||
'description': 'Test DCA entry transaction creation'
|
||||
}
|
||||
form = DCAEntryForm(data=data, strategy=self.strategy1, owner=self.owner)
|
||||
|
||||
if not form.is_valid():
|
||||
print(form.errors.as_json()) # Print errors if form is invalid
|
||||
self.assertTrue(form.is_valid())
|
||||
|
||||
entry = form.save()
|
||||
|
||||
self.assertIsNotNone(entry.pk)
|
||||
self.assertEqual(entry.strategy, self.strategy1)
|
||||
self.assertIsNotNone(entry.expense_transaction)
|
||||
self.assertIsNotNone(entry.income_transaction)
|
||||
|
||||
# Check expense transaction
|
||||
expense_tx = entry.expense_transaction
|
||||
self.assertEqual(expense_tx.account, self.account1)
|
||||
self.assertEqual(expense_tx.type, Transaction.Type.EXPENSE)
|
||||
self.assertEqual(expense_tx.amount, data['amount_paid'])
|
||||
self.assertEqual(expense_tx.category, self.category1)
|
||||
self.assertEqual(expense_tx.owner, self.owner)
|
||||
self.assertEqual(expense_tx.date, data['date'])
|
||||
self.assertIn(str(entry.id)[:8], expense_tx.description) # Check if part of entry ID is in description
|
||||
|
||||
# Check income transaction
|
||||
income_tx = entry.income_transaction
|
||||
self.assertEqual(income_tx.account, self.account2)
|
||||
self.assertEqual(income_tx.type, Transaction.Type.INCOME)
|
||||
self.assertEqual(income_tx.amount, data['amount_received'])
|
||||
self.assertEqual(income_tx.category, self.category1)
|
||||
self.assertEqual(income_tx.owner, self.owner)
|
||||
self.assertEqual(income_tx.date, data['date'])
|
||||
self.assertIn(str(entry.id)[:8], income_tx.description)
|
||||
|
||||
|
||||
def test_dca_entry_form_save_update_linked_transactions(self):
|
||||
# 1. Create an initial DCAEntry with linked transactions
|
||||
initial_data = {
|
||||
'date': date(2023, 6, 1),
|
||||
'amount_paid': Decimal('50.00'),
|
||||
'amount_received': Decimal('0.005'),
|
||||
'create_transaction': True,
|
||||
'from_account': self.account1.pk,
|
||||
'to_account': self.account2.pk,
|
||||
'from_category': self.category1.pk,
|
||||
'to_category': self.category1.pk,
|
||||
}
|
||||
initial_form = DCAEntryForm(data=initial_data, strategy=self.strategy1, owner=self.owner)
|
||||
self.assertTrue(initial_form.is_valid(), initial_form.errors.as_json())
|
||||
initial_entry = initial_form.save()
|
||||
|
||||
self.assertIsNotNone(initial_entry.expense_transaction)
|
||||
self.assertIsNotNone(initial_entry.income_transaction)
|
||||
|
||||
# 2. Data for updating the form
|
||||
update_data = {
|
||||
'date': initial_entry.date, # Keep date same or change, as needed
|
||||
'amount_paid': Decimal('55.00'), # New value
|
||||
'amount_received': Decimal('0.006'), # New value
|
||||
# 'create_transaction': False, # Or not present, form should not create new if instance has linked tx
|
||||
'from_account': initial_entry.expense_transaction.account.pk, # Keep same accounts
|
||||
'to_account': initial_entry.income_transaction.account.pk,
|
||||
'from_category': initial_entry.expense_transaction.category.pk,
|
||||
'to_category': initial_entry.income_transaction.category.pk,
|
||||
}
|
||||
|
||||
# When create_transaction is not checked (or False), it means we are not creating *new* transactions,
|
||||
# but if the instance already has linked transactions, they *should* be updated.
|
||||
# The form's save method should handle this.
|
||||
|
||||
update_form = DCAEntryForm(data=update_data, instance=initial_entry, strategy=initial_entry.strategy, owner=self.owner)
|
||||
|
||||
if not update_form.is_valid():
|
||||
print(update_form.errors.as_json()) # Print errors if form is invalid
|
||||
self.assertTrue(update_form.is_valid())
|
||||
|
||||
updated_entry = update_form.save()
|
||||
|
||||
# Refresh from DB to ensure changes are saved and reflected
|
||||
updated_entry.refresh_from_db()
|
||||
if updated_entry.expense_transaction: # Check if it exists before trying to refresh
|
||||
updated_entry.expense_transaction.refresh_from_db()
|
||||
if updated_entry.income_transaction: # Check if it exists before trying to refresh
|
||||
updated_entry.income_transaction.refresh_from_db()
|
||||
|
||||
|
||||
self.assertEqual(updated_entry.amount_paid, Decimal('55.00'))
|
||||
self.assertEqual(updated_entry.amount_received, Decimal('0.006'))
|
||||
|
||||
self.assertIsNotNone(updated_entry.expense_transaction, "Expense transaction should still be linked.")
|
||||
self.assertEqual(updated_entry.expense_transaction.amount, Decimal('55.00'))
|
||||
|
||||
self.assertIsNotNone(updated_entry.income_transaction, "Income transaction should still be linked.")
|
||||
self.assertEqual(updated_entry.income_transaction.amount, Decimal('0.006'))
|
||||
# Create your tests here.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# apps/dca_tracker/views.py
|
||||
from django.contrib import messages
|
||||
from django.contrib.auth.decorators import login_required
|
||||
from django.db.models import Sum, Avg
|
||||
@@ -23,7 +22,7 @@ def strategy_index(request):
|
||||
@only_htmx
|
||||
@login_required
|
||||
def strategy_list(request):
|
||||
strategies = DCAStrategy.objects.all().order_by("created_at")
|
||||
strategies = DCAStrategy.objects.all().order_by("name")
|
||||
return render(
|
||||
request, "dca/fragments/strategy/list.html", {"strategies": strategies}
|
||||
)
|
||||
@@ -234,7 +233,7 @@ def strategy_entry_add(request, strategy_id):
|
||||
if request.method == "POST":
|
||||
form = DCAEntryForm(request.POST, strategy=strategy)
|
||||
if form.is_valid():
|
||||
entry = form.save()
|
||||
form.save()
|
||||
messages.success(request, _("Entry added successfully"))
|
||||
|
||||
return HttpResponse(
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, HTML
|
||||
from crispy_forms.layout import HTML, Layout
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
|
||||
|
||||
class ExportForm(forms.Form):
|
||||
users = forms.BooleanField(
|
||||
@@ -115,9 +114,7 @@ class ExportForm(forms.Form):
|
||||
"dca",
|
||||
"import_profiles",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Export"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Export"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -162,7 +159,7 @@ class RestoreForm(forms.Form):
|
||||
self.helper.form_method = "post"
|
||||
self.helper.layout = Layout(
|
||||
"zip_file",
|
||||
HTML("<hr />"),
|
||||
HTML('<hr class="hr my-3"/>'),
|
||||
"users",
|
||||
"accounts",
|
||||
"currencies",
|
||||
@@ -181,9 +178,7 @@ class RestoreForm(forms.Form):
|
||||
"dca_entries",
|
||||
"import_profiles",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Restore"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Restore"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from import_export import fields, resources
|
||||
from import_export.widgets import ForeignKeyWidget
|
||||
|
||||
from apps.accounts.models import Account
|
||||
from apps.export_app.widgets.foreign_key import AutoCreateForeignKeyWidget
|
||||
from apps.export_app.widgets.foreign_key import (
|
||||
AllObjectsForeignKeyWidget,
|
||||
AutoCreateForeignKeyWidget,
|
||||
)
|
||||
from apps.export_app.widgets.many_to_many import AutoCreateManyToManyWidget
|
||||
from apps.export_app.widgets.string import EmptyStringToNoneField
|
||||
from apps.transactions.models import (
|
||||
@@ -20,7 +22,7 @@ class TransactionResource(resources.ModelResource):
|
||||
account = fields.Field(
|
||||
attribute="account",
|
||||
column_name="account",
|
||||
widget=ForeignKeyWidget(Account, "name"),
|
||||
widget=AllObjectsForeignKeyWidget(Account, "name"),
|
||||
)
|
||||
|
||||
category = fields.Field(
|
||||
@@ -86,7 +88,7 @@ class RecurringTransactionResource(resources.ModelResource):
|
||||
account = fields.Field(
|
||||
attribute="account",
|
||||
column_name="account",
|
||||
widget=ForeignKeyWidget(Account, "name"),
|
||||
widget=AllObjectsForeignKeyWidget(Account, "name"),
|
||||
)
|
||||
|
||||
category = fields.Field(
|
||||
@@ -119,12 +121,16 @@ class RecurringTransactionResource(resources.ModelResource):
|
||||
def get_queryset(self):
|
||||
return RecurringTransaction.all_objects.all()
|
||||
|
||||
def dehydrate_account_owner(self, obj):
|
||||
"""Export the account's owner ID for proper import matching."""
|
||||
return obj.account.owner_id if obj.account else None
|
||||
|
||||
|
||||
class InstallmentPlanResource(resources.ModelResource):
|
||||
account = fields.Field(
|
||||
attribute="account",
|
||||
column_name="account",
|
||||
widget=ForeignKeyWidget(Account, "name"),
|
||||
widget=AllObjectsForeignKeyWidget(Account, "name"),
|
||||
)
|
||||
|
||||
category = fields.Field(
|
||||
@@ -156,3 +162,7 @@ class InstallmentPlanResource(resources.ModelResource):
|
||||
|
||||
def get_queryset(self):
|
||||
return InstallmentPlan.all_objects.all()
|
||||
|
||||
def dehydrate_account_owner(self, obj):
|
||||
"""Export the account's owner ID for proper import matching."""
|
||||
return obj.account.owner_id if obj.account else None
|
||||
|
||||
@@ -1,164 +1,3 @@
|
||||
from django.test import TestCase, Client
|
||||
from django.contrib.auth.models import User
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from unittest.mock import patch, MagicMock
|
||||
from io import BytesIO
|
||||
import zipfile # Added for zip file creation
|
||||
from django.core.files.uploadedfile import InMemoryUploadedFile # Added for file upload testing
|
||||
from django.test import TestCase
|
||||
|
||||
# Dataset from tablib is not directly imported, its behavior will be mocked.
|
||||
# Resource classes are also mocked by path string.
|
||||
|
||||
from apps.export_app.forms import ExportForm, RestoreForm # Added RestoreForm
|
||||
|
||||
|
||||
class ExportAppTests(TestCase):
|
||||
def setUp(self):
|
||||
self.superuser = User.objects.create_superuser(
|
||||
username='super',
|
||||
email='super@example.com',
|
||||
password='password'
|
||||
)
|
||||
self.client = Client()
|
||||
self.client.login(username='super', password='password')
|
||||
|
||||
@patch('apps.export_app.views.UserResource')
|
||||
def test_export_form_single_selection_csv_response(self, mock_UserResource):
|
||||
# Configure the mock UserResource
|
||||
mock_user_resource_instance = mock_UserResource.return_value
|
||||
|
||||
# Mock the export() method's return value (which is a Dataset object)
|
||||
# Then, mock the 'csv' attribute of this Dataset object
|
||||
mock_dataset = MagicMock() # Using MagicMock for the dataset
|
||||
mock_dataset.csv = "user_id,username\n1,testuser"
|
||||
mock_user_resource_instance.export.return_value = mock_dataset
|
||||
|
||||
post_data = {'users': True} # Other fields default to False or their initial values
|
||||
|
||||
response = self.client.post(reverse('export_app:export_form'), data=post_data)
|
||||
|
||||
mock_user_resource_instance.export.assert_called_once()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response['Content-Type'], 'text/csv')
|
||||
self.assertIn("attachment; filename=", response['Content-Disposition'])
|
||||
self.assertIn(".csv", response['Content-Disposition'])
|
||||
# Check if the filename contains 'users'
|
||||
self.assertIn("users_export_", response['Content-Disposition'].lower())
|
||||
self.assertEqual(response.content.decode(), "user_id,username\n1,testuser")
|
||||
|
||||
@patch('apps.export_app.views.AccountResource') # Mock AccountResource first
|
||||
@patch('apps.export_app.views.UserResource') # Then UserResource
|
||||
def test_export_form_multiple_selections_zip_response(self, mock_UserResource, mock_AccountResource):
|
||||
# Configure UserResource mock
|
||||
mock_user_instance = mock_UserResource.return_value
|
||||
mock_user_dataset = MagicMock()
|
||||
mock_user_dataset.csv = "user_data_here"
|
||||
mock_user_instance.export.return_value = mock_user_dataset
|
||||
|
||||
# Configure AccountResource mock
|
||||
mock_account_instance = mock_AccountResource.return_value
|
||||
mock_account_dataset = MagicMock()
|
||||
mock_account_dataset.csv = "account_data_here"
|
||||
mock_account_instance.export.return_value = mock_account_dataset
|
||||
|
||||
post_data = {
|
||||
'users': True,
|
||||
'accounts': True
|
||||
# other fields default to False or their initial values
|
||||
}
|
||||
|
||||
response = self.client.post(reverse('export_app:export_form'), data=post_data)
|
||||
|
||||
mock_user_instance.export.assert_called_once()
|
||||
mock_account_instance.export.assert_called_once()
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response['Content-Type'], 'application/zip')
|
||||
self.assertIn("attachment; filename=", response['Content-Disposition'])
|
||||
self.assertIn(".zip", response['Content-Disposition'])
|
||||
# Add zip file content check if possible and required later
|
||||
|
||||
def test_export_form_no_selection(self):
|
||||
# Get all field names from ExportForm and set them to False
|
||||
# This ensures that if new export options are added, this test still tries to unselect them.
|
||||
form_fields = ExportForm.base_fields.keys()
|
||||
post_data = {field: False for field in form_fields}
|
||||
|
||||
response = self.client.post(reverse('export_app:export_form'), data=post_data)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
# The expected message is "You have to select at least one export"
|
||||
# This message is translatable, so using _() for comparison if the view returns translated string.
|
||||
# The view returns HttpResponse(_("You have to select at least one export"))
|
||||
self.assertEqual(response.content.decode('utf-8'), _("You have to select at least one export"))
|
||||
|
||||
# Placeholder for zip content check, if to be implemented
|
||||
# import zipfile
|
||||
# def test_zip_contents(self):
|
||||
# # ... (setup response with zip data) ...
|
||||
# with zipfile.ZipFile(BytesIO(response.content), 'r') as zipf:
|
||||
# self.assertIn('users.csv', zipf.namelist())
|
||||
# self.assertIn('accounts.csv', zipf.namelist())
|
||||
# user_csv_content = zipf.read('users.csv').decode()
|
||||
# self.assertEqual(user_csv_content, "user_data_here")
|
||||
# account_csv_content = zipf.read('accounts.csv').decode()
|
||||
# self.assertEqual(account_csv_content, "account_data_here")
|
||||
|
||||
@patch('apps.export_app.views.process_imports')
|
||||
def test_import_form_valid_zip_calls_process_imports(self, mock_process_imports):
|
||||
# Create a mock ZIP file content
|
||||
zip_content_buffer = BytesIO()
|
||||
with zipfile.ZipFile(zip_content_buffer, 'w') as zf:
|
||||
zf.writestr('dummy.csv', 'some,data')
|
||||
zip_content_buffer.seek(0)
|
||||
|
||||
# Create an InMemoryUploadedFile instance
|
||||
mock_zip_file = InMemoryUploadedFile(
|
||||
zip_content_buffer,
|
||||
'zip_file', # field_name
|
||||
'test_export.zip', # file_name
|
||||
'application/zip', # content_type
|
||||
zip_content_buffer.getbuffer().nbytes, # size
|
||||
None # charset
|
||||
)
|
||||
|
||||
post_data = {'zip_file': mock_zip_file}
|
||||
url = reverse('export_app:restore_form')
|
||||
|
||||
response = self.client.post(url, data=post_data, format='multipart')
|
||||
|
||||
mock_process_imports.assert_called_once()
|
||||
# Check the second argument passed to process_imports (the form's cleaned_data['zip_file'])
|
||||
# The first argument (args[0]) is the request object.
|
||||
# The second argument (args[1]) is the form instance.
|
||||
# We need to check the 'zip_file' attribute of the cleaned_data of the form instance.
|
||||
# However, it's simpler to check the UploadedFile object directly if that's what process_imports receives.
|
||||
# Based on the task: "The second argument to process_imports is form.cleaned_data['zip_file']"
|
||||
# This means that process_imports is called as process_imports(request, form.cleaned_data['zip_file'], ...)
|
||||
# Let's assume process_imports signature is process_imports(request, file_obj, ...)
|
||||
# So, call_args[0][1] would be the file_obj.
|
||||
|
||||
# Actually, the view calls process_imports(request, form)
|
||||
# So, we check form.cleaned_data['zip_file'] on the passed form instance
|
||||
called_form_instance = mock_process_imports.call_args[0][1] # The form instance
|
||||
self.assertEqual(called_form_instance.cleaned_data['zip_file'], mock_zip_file)
|
||||
|
||||
self.assertEqual(response.status_code, 204)
|
||||
# The HX-Trigger header might have multiple values, ensure both are present
|
||||
self.assertIn("hide_offcanvas", response.headers['HX-Trigger'])
|
||||
self.assertIn("updated", response.headers['HX-Trigger'])
|
||||
|
||||
|
||||
def test_import_form_no_file_selected(self):
|
||||
post_data = {} # No file selected
|
||||
url = reverse('export_app:restore_form')
|
||||
|
||||
response = self.client.post(url, data=post_data)
|
||||
|
||||
self.assertEqual(response.status_code, 200) # Form re-rendered with errors
|
||||
# Check that the specific error message from RestoreForm.clean() is present
|
||||
expected_error_message = _("Please upload either a ZIP file or at least one CSV file")
|
||||
self.assertContains(response, expected_error_message)
|
||||
# Also check for the HX-Trigger which is always set
|
||||
self.assertIn("updated", response.headers['HX-Trigger'])
|
||||
# Create your tests here.
|
||||
|
||||
@@ -1,6 +1,60 @@
|
||||
from import_export.widgets import ForeignKeyWidget
|
||||
|
||||
|
||||
class AllObjectsForeignKeyWidget(ForeignKeyWidget):
|
||||
"""
|
||||
ForeignKeyWidget that uses 'all_objects' manager for lookups,
|
||||
bypassing user-filtered managers like SharedObjectManager.
|
||||
Also filters by owner if available in the row data.
|
||||
"""
|
||||
|
||||
def get_queryset(self, value, row, *args, **kwargs):
|
||||
# Use all_objects manager if available, otherwise fall back to default
|
||||
if hasattr(self.model, "all_objects"):
|
||||
qs = self.model.all_objects.all()
|
||||
# Filter by owner if the row has an owner field and the model has owner
|
||||
if row:
|
||||
# Check for direct owner field first
|
||||
owner_id = row.get("owner") if "owner" in row else None
|
||||
# Fall back to account_owner for models like InstallmentPlan
|
||||
if not owner_id and "account_owner" in row:
|
||||
owner_id = row.get("account_owner")
|
||||
# If still no owner, try to get it from the existing record's account
|
||||
# This handles backward compatibility with older exports
|
||||
if not owner_id and "id" in row and row.get("id"):
|
||||
try:
|
||||
# Try to find the existing record and get owner from its account
|
||||
from apps.transactions.models import (
|
||||
InstallmentPlan,
|
||||
RecurringTransaction,
|
||||
)
|
||||
|
||||
record_id = row.get("id")
|
||||
# Try to find the existing InstallmentPlan or RecurringTransaction
|
||||
for model_class in [InstallmentPlan, RecurringTransaction]:
|
||||
try:
|
||||
existing = model_class.all_objects.get(id=record_id)
|
||||
if existing.account:
|
||||
owner_id = existing.account.owner_id
|
||||
break
|
||||
except model_class.DoesNotExist:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
# Final fallback: use the current logged-in user
|
||||
# This handles restoring to a fresh database with older exports
|
||||
if not owner_id:
|
||||
from apps.common.middleware.thread_local import get_current_user
|
||||
|
||||
user = get_current_user()
|
||||
if user and user.is_authenticated:
|
||||
owner_id = user.id
|
||||
if owner_id:
|
||||
qs = qs.filter(owner_id=owner_id)
|
||||
return qs
|
||||
return super().get_queryset(value, row, *args, **kwargs)
|
||||
|
||||
|
||||
class AutoCreateForeignKeyWidget(ForeignKeyWidget):
|
||||
def clean(self, value, row=None, *args, **kwargs):
|
||||
if value:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
from apps.import_app.models import ImportProfile
|
||||
from crispy_forms.bootstrap import FormActions
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import (
|
||||
@@ -6,9 +8,6 @@ from crispy_forms.layout import (
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.import_app.models import ImportProfile
|
||||
from apps.common.widgets.crispy.submit import NoClassSubmit
|
||||
|
||||
|
||||
class ImportProfileForm(forms.ModelForm):
|
||||
class Meta:
|
||||
@@ -30,17 +29,13 @@ class ImportProfileForm(forms.ModelForm):
|
||||
if self.instance and self.instance.pk:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Update"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Update"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.helper.layout.append(
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Add"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Add"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -57,8 +52,6 @@ class ImportRunFileUploadForm(forms.Form):
|
||||
self.helper.layout = Layout(
|
||||
"file",
|
||||
FormActions(
|
||||
NoClassSubmit(
|
||||
"submit", _("Import"), css_class="btn btn-outline-primary w-100"
|
||||
),
|
||||
NoClassSubmit("submit", _("Import"), css_class="btn btn-primary"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -459,12 +459,13 @@ class ImportService:
|
||||
# Build query conditions for each field in the rule
|
||||
for field in rule.fields:
|
||||
if field in transaction_data:
|
||||
if rule.match_type == "strict":
|
||||
query = query.filter(**{field: transaction_data[field]})
|
||||
else: # lax matching
|
||||
query = query.filter(
|
||||
**{f"{field}__iexact": transaction_data[field]}
|
||||
)
|
||||
value = transaction_data[field]
|
||||
# Use __iexact only for string fields; non-string types
|
||||
# (date, Decimal, bool, int, etc.) don't support UPPER()
|
||||
if rule.match_type == "strict" or not isinstance(value, str):
|
||||
query = query.filter(**{field: value})
|
||||
else: # lax matching for strings only
|
||||
query = query.filter(**{f"{field}__iexact": value})
|
||||
|
||||
# If we found any matching transaction, it's a duplicate
|
||||
if query.exists():
|
||||
@@ -475,11 +476,27 @@ class ImportService:
|
||||
def _coerce_type(
|
||||
self, value: str, mapping: version_1.ColumnMapping
|
||||
) -> Union[str, int, bool, Decimal, datetime, list, None]:
|
||||
coerce_to = mapping.coerce_to
|
||||
|
||||
# Handle detection methods that don't require a source value
|
||||
if coerce_to == "transaction_type" and isinstance(
|
||||
mapping, version_1.TransactionTypeMapping
|
||||
):
|
||||
if mapping.detection_method == "always_income":
|
||||
return Transaction.Type.INCOME
|
||||
elif mapping.detection_method == "always_expense":
|
||||
return Transaction.Type.EXPENSE
|
||||
elif coerce_to == "is_paid" and isinstance(
|
||||
mapping, version_1.TransactionIsPaidMapping
|
||||
):
|
||||
if mapping.detection_method == "always_paid":
|
||||
return True
|
||||
elif mapping.detection_method == "always_unpaid":
|
||||
return False
|
||||
|
||||
if not value:
|
||||
return None
|
||||
|
||||
coerce_to = mapping.coerce_to
|
||||
|
||||
return self._coerce_single_type(value, coerce_to, mapping)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,424 +0,0 @@
|
||||
from django.test import TestCase
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import IntegrityError
|
||||
import yaml
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
from django.test import TestCase
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import IntegrityError
|
||||
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
from apps.import_app.forms import ImportProfileForm
|
||||
from apps.import_app.services.v1 import ImportService
|
||||
from apps.import_app.schemas import version_1
|
||||
from apps.transactions.models import Transaction # For Transaction.Type
|
||||
from unittest.mock import patch
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
|
||||
class ImportProfileTests(TestCase):
|
||||
|
||||
def test_import_profile_valid_yaml_v1(self):
|
||||
valid_yaml_config = """
|
||||
settings:
|
||||
file_type: csv
|
||||
delimiter: ','
|
||||
encoding: utf-8
|
||||
skip_lines: 0
|
||||
trigger_transaction_rules: true
|
||||
importing: transactions
|
||||
mapping:
|
||||
date:
|
||||
target: date
|
||||
source: Transaction Date
|
||||
format: '%Y-%m-%d'
|
||||
amount:
|
||||
target: amount
|
||||
source: Amount
|
||||
description:
|
||||
target: description
|
||||
source: Narrative
|
||||
account:
|
||||
target: account
|
||||
source: Account Name
|
||||
type: name
|
||||
type:
|
||||
target: type
|
||||
source: Credit Debit
|
||||
detection_method: sign # Assumes positive is income, negative is expense
|
||||
is_paid:
|
||||
target: is_paid
|
||||
detection_method: always_paid
|
||||
deduplication: []
|
||||
"""
|
||||
profile = ImportProfile(
|
||||
name="Test Valid Profile V1",
|
||||
yaml_config=valid_yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
try:
|
||||
profile.full_clean()
|
||||
except ValidationError as e:
|
||||
self.fail(f"Valid YAML config raised ValidationError: {e.error_dict}")
|
||||
|
||||
# Optional: Save and retrieve
|
||||
profile.save()
|
||||
retrieved_profile = ImportProfile.objects.get(pk=profile.pk)
|
||||
self.assertIsNotNone(retrieved_profile)
|
||||
self.assertEqual(retrieved_profile.name, "Test Valid Profile V1")
|
||||
|
||||
def test_import_profile_invalid_yaml_syntax_v1(self):
|
||||
invalid_yaml = "settings: { file_type: csv, delimiter: ','" # Malformed YAML
|
||||
profile = ImportProfile(
|
||||
name="Test Invalid Syntax V1",
|
||||
yaml_config=invalid_yaml,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
profile.full_clean()
|
||||
|
||||
self.assertIn('yaml_config', cm.exception.error_dict)
|
||||
self.assertTrue(any("YAML" in error.message.lower() or "syntax" in error.message.lower() for error in cm.exception.error_dict['yaml_config']))
|
||||
|
||||
def test_import_profile_schema_validation_error_v1(self):
|
||||
schema_error_yaml = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
mapping:
|
||||
date: # Missing 'format' which is required for TransactionDateMapping
|
||||
target: date
|
||||
source: Transaction Date
|
||||
"""
|
||||
profile = ImportProfile(
|
||||
name="Test Schema Error V1",
|
||||
yaml_config=schema_error_yaml,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
profile.full_clean()
|
||||
|
||||
self.assertIn('yaml_config', cm.exception.error_dict)
|
||||
# Pydantic errors usually mention the field and "field required" or similar
|
||||
self.assertTrue(any("format" in error.message.lower() and "field required" in error.message.lower()
|
||||
for error in cm.exception.error_dict['yaml_config']),
|
||||
f"Error messages: {[e.message for e in cm.exception.error_dict['yaml_config']]}")
|
||||
|
||||
|
||||
def test_import_profile_custom_validate_mappings_error_v1(self):
|
||||
custom_validate_yaml = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions # Importing transactions
|
||||
mapping:
|
||||
account_name: # This is an AccountNameMapping, not suitable for 'transactions' importing setting
|
||||
target: account_name
|
||||
source: AccName
|
||||
"""
|
||||
profile = ImportProfile(
|
||||
name="Test Custom Validate Error V1",
|
||||
yaml_config=custom_validate_yaml,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
profile.full_clean()
|
||||
|
||||
self.assertIn('yaml_config', cm.exception.error_dict)
|
||||
# Check for the specific message raised by custom_validate_mappings
|
||||
# The message is "Mapping type AccountNameMapping not allowed for importing 'transactions'."
|
||||
self.assertTrue(any("mapping type accountnamemapping not allowed for importing 'transactions'" in error.message.lower()
|
||||
for error in cm.exception.error_dict['yaml_config']),
|
||||
f"Error messages: {[e.message for e in cm.exception.error_dict['yaml_config']]}")
|
||||
|
||||
|
||||
def test_import_profile_name_unique(self):
|
||||
valid_yaml_config = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
mapping:
|
||||
date:
|
||||
target: date
|
||||
source: Date
|
||||
format: '%Y-%m-%d'
|
||||
""" # Minimal valid YAML for this test
|
||||
|
||||
ImportProfile.objects.create(
|
||||
name="Unique Name Test",
|
||||
yaml_config=valid_yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
|
||||
profile2 = ImportProfile(
|
||||
name="Unique Name Test", # Same name
|
||||
yaml_config=valid_yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
|
||||
# full_clean should catch this because of the unique constraint on the model field.
|
||||
# Django's Model.full_clean() calls Model.validate_unique().
|
||||
with self.assertRaises(ValidationError) as cm:
|
||||
profile2.full_clean()
|
||||
|
||||
self.assertIn('name', cm.exception.error_dict)
|
||||
self.assertTrue(any("already exists" in error.message.lower() for error in cm.exception.error_dict['name']))
|
||||
|
||||
# As a fallback, or for more direct DB constraint testing, also test IntegrityError on save if full_clean didn't catch it.
|
||||
# This will only be reached if the full_clean() above somehow passes.
|
||||
# try:
|
||||
# profile2.save()
|
||||
# except IntegrityError:
|
||||
# pass # Expected if full_clean didn't catch it
|
||||
# else:
|
||||
# if 'name' not in cm.exception.error_dict: # If full_clean passed and save also passed
|
||||
# self.fail("IntegrityError not raised for duplicate name on save(), and full_clean() didn't catch it.")
|
||||
|
||||
def test_import_profile_form_valid_data(self):
|
||||
valid_yaml_config = """
|
||||
settings:
|
||||
file_type: csv
|
||||
delimiter: ','
|
||||
encoding: utf-8
|
||||
skip_lines: 0
|
||||
trigger_transaction_rules: true
|
||||
importing: transactions
|
||||
mapping:
|
||||
date:
|
||||
target: date
|
||||
source: Transaction Date
|
||||
format: '%Y-%m-%d'
|
||||
amount:
|
||||
target: amount
|
||||
source: Amount
|
||||
description:
|
||||
target: description
|
||||
source: Narrative
|
||||
account:
|
||||
target: account
|
||||
source: Account Name
|
||||
type: name
|
||||
type:
|
||||
target: type
|
||||
source: Credit Debit
|
||||
detection_method: sign
|
||||
is_paid:
|
||||
target: is_paid
|
||||
detection_method: always_paid
|
||||
deduplication: []
|
||||
"""
|
||||
form_data = {
|
||||
'name': 'Form Test Valid',
|
||||
'yaml_config': valid_yaml_config,
|
||||
'version': ImportProfile.Versions.VERSION_1
|
||||
}
|
||||
form = ImportProfileForm(data=form_data)
|
||||
self.assertTrue(form.is_valid(), f"Form errors: {form.errors.as_json()}")
|
||||
|
||||
profile = form.save()
|
||||
self.assertIsNotNone(profile.pk)
|
||||
self.assertEqual(profile.name, 'Form Test Valid')
|
||||
# YAMLField might re-serialize the YAML, so direct string comparison might be brittle
|
||||
# if spacing/ordering changes. However, for now, let's assume it's stored as provided or close enough.
|
||||
# A more robust check would be to load both YAMLs and compare the resulting dicts.
|
||||
self.assertEqual(profile.yaml_config.strip(), valid_yaml_config.strip())
|
||||
self.assertEqual(profile.version, ImportProfile.Versions.VERSION_1)
|
||||
|
||||
def test_import_profile_form_invalid_yaml(self):
|
||||
# Using a YAML that causes a schema validation error (missing 'format' for date mapping)
|
||||
invalid_yaml_for_form = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
mapping:
|
||||
date:
|
||||
target: date
|
||||
source: Transaction Date
|
||||
"""
|
||||
form_data = {
|
||||
'name': 'Form Test Invalid',
|
||||
'yaml_config': invalid_yaml_for_form,
|
||||
'version': ImportProfile.Versions.VERSION_1
|
||||
}
|
||||
form = ImportProfileForm(data=form_data)
|
||||
self.assertFalse(form.is_valid())
|
||||
self.assertIn('yaml_config', form.errors)
|
||||
# Check for a message indicating schema validation failure
|
||||
self.assertTrue(any("field required" in error.lower() for error in form.errors['yaml_config']))
|
||||
|
||||
|
||||
class ImportServiceTests(TestCase):
|
||||
# ... (existing setUp and other test methods from previous task) ...
|
||||
def setUp(self):
|
||||
minimal_yaml_config = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
mapping:
|
||||
description:
|
||||
target: description
|
||||
source: Desc
|
||||
"""
|
||||
self.profile = ImportProfile.objects.create(
|
||||
name="Test Service Profile",
|
||||
yaml_config=minimal_yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1
|
||||
)
|
||||
self.import_run = ImportRun.objects.create(
|
||||
profile=self.profile,
|
||||
status=ImportRun.Status.PENDING
|
||||
)
|
||||
# self.service is initialized in each test to allow specific mapping_config
|
||||
# or to re-initialize if service state changes (though it shouldn't for these private methods)
|
||||
|
||||
# Tests for _transform_value
|
||||
def test_transform_value_replace(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.ColumnMapping(target="description", source="Desc") # Basic mapping
|
||||
mapping_config.transformations = [
|
||||
version_1.ReplaceTransformationRule(type="replace", pattern="old", replacement="new")
|
||||
]
|
||||
transformed_value = service._transform_value("this is old text", mapping_config)
|
||||
self.assertEqual(transformed_value, "this is new text")
|
||||
|
||||
def test_transform_value_date_format(self):
|
||||
service = ImportService(self.import_run)
|
||||
# DateFormatTransformationRule is typically part of a DateMapping, but testing transform directly
|
||||
mapping_config = version_1.TransactionDateMapping(target="date", source="Date", format="%d/%m/%Y") # format is for final coercion
|
||||
mapping_config.transformations = [
|
||||
version_1.DateFormatTransformationRule(type="date_format", original_format="%Y-%m-%d", new_format="%d/%m/%Y")
|
||||
]
|
||||
transformed_value = service._transform_value("2023-01-15", mapping_config)
|
||||
self.assertEqual(transformed_value, "15/01/2023")
|
||||
|
||||
def test_transform_value_regex_replace(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.ColumnMapping(target="description", source="Desc")
|
||||
mapping_config.transformations = [
|
||||
version_1.ReplaceTransformationRule(type="regex", pattern=r"\\d+", replacement="NUM")
|
||||
]
|
||||
transformed_value = service._transform_value("abc123xyz456", mapping_config)
|
||||
self.assertEqual(transformed_value, "abcNUMxyzNUM")
|
||||
|
||||
# Tests for _coerce_type
|
||||
def test_coerce_type_string_to_decimal(self):
|
||||
service = ImportService(self.import_run)
|
||||
# TransactionAmountMapping has coerce_to="positive_decimal" by default
|
||||
mapping_config = version_1.TransactionAmountMapping(target="amount", source="Amt")
|
||||
|
||||
coerced = service._coerce_type("123.45", mapping_config)
|
||||
self.assertEqual(coerced, Decimal("123.45"))
|
||||
|
||||
coerced_neg = service._coerce_type("-123.45", mapping_config)
|
||||
self.assertEqual(coerced_neg, Decimal("123.45")) # positive_decimal behavior
|
||||
|
||||
# Test with coerce_to="decimal"
|
||||
mapping_config_decimal = version_1.TransactionAmountMapping(target="amount", source="Amt", coerce_to="decimal")
|
||||
coerced_neg_decimal = service._coerce_type("-123.45", mapping_config_decimal)
|
||||
self.assertEqual(coerced_neg_decimal, Decimal("-123.45"))
|
||||
|
||||
|
||||
def test_coerce_type_string_to_date(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.TransactionDateMapping(target="date", source="Dt", format="%Y-%m-%d")
|
||||
coerced = service._coerce_type("2023-01-15", mapping_config)
|
||||
self.assertEqual(coerced, date(2023, 1, 15))
|
||||
|
||||
def test_coerce_type_string_to_transaction_type_sign(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.TransactionTypeMapping(target="type", source="TType", detection_method="sign")
|
||||
|
||||
self.assertEqual(service._coerce_type("100.00", mapping_config), Transaction.Type.INCOME)
|
||||
self.assertEqual(service._coerce_type("-100.00", mapping_config), Transaction.Type.EXPENSE)
|
||||
self.assertEqual(service._coerce_type("0.00", mapping_config), Transaction.Type.EXPENSE) # Sign detection treats 0 as expense
|
||||
self.assertEqual(service._coerce_type("+200", mapping_config), Transaction.Type.INCOME)
|
||||
|
||||
def test_coerce_type_string_to_transaction_type_keywords(self):
|
||||
service = ImportService(self.import_run)
|
||||
mapping_config = version_1.TransactionTypeMapping(
|
||||
target="type",
|
||||
source="TType",
|
||||
detection_method="keywords",
|
||||
income_keywords=["credit", "dep"],
|
||||
expense_keywords=["debit", "wdrl"]
|
||||
)
|
||||
self.assertEqual(service._coerce_type("Monthly Credit", mapping_config), Transaction.Type.INCOME)
|
||||
self.assertEqual(service._coerce_type("ATM WDRL", mapping_config), Transaction.Type.EXPENSE)
|
||||
self.assertIsNone(service._coerce_type("Unknown Type", mapping_config)) # No keyword match
|
||||
|
||||
@patch('apps.import_app.services.v1.os.remove')
|
||||
def test_process_file_simple_csv_transactions(self, mock_os_remove):
|
||||
simple_transactions_yaml = """
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
delimiter: ','
|
||||
skip_lines: 0
|
||||
mapping:
|
||||
date: {target: date, source: Date, format: '%Y-%m-%d'}
|
||||
amount: {target: amount, source: Amount}
|
||||
description: {target: description, source: Description}
|
||||
type: {target: type, source: Type, detection_method: always_income}
|
||||
account: {target: account, source: AccountName, type: name}
|
||||
"""
|
||||
self.profile.yaml_config = simple_transactions_yaml
|
||||
self.profile.save()
|
||||
self.import_run.refresh_from_db() # Ensure import_run has the latest profile reference if needed
|
||||
|
||||
csv_content = "Date,Amount,Description,Type,AccountName\n2023-01-01,100.00,Test Deposit,INCOME,TestAcc"
|
||||
|
||||
temp_file_path = None
|
||||
try:
|
||||
# Ensure TEMP_DIR exists if ImportService relies on it being pre-existing
|
||||
# For NamedTemporaryFile, dir just needs to be a valid directory path.
|
||||
# If ImportService.TEMP_DIR is a class variable pointing to a specific path,
|
||||
# it should be created or mocked if it doesn't exist by default.
|
||||
# For simplicity, let's assume it exists or tempfile handles it gracefully.
|
||||
# If ImportService.TEMP_DIR is not guaranteed, use default temp dir.
|
||||
temp_dir = getattr(ImportService, 'TEMP_DIR', None)
|
||||
if temp_dir and not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w+', delete=False, dir=temp_dir, suffix='.csv', encoding='utf-8') as tmp_file:
|
||||
tmp_file.write(csv_content)
|
||||
temp_file_path = tmp_file.name
|
||||
|
||||
self.addCleanup(lambda: os.remove(temp_file_path) if temp_file_path and os.path.exists(temp_file_path) else None)
|
||||
|
||||
service = ImportService(self.import_run)
|
||||
|
||||
with patch.object(service, '_create_transaction') as mock_create_transaction:
|
||||
service.process_file(temp_file_path)
|
||||
|
||||
self.import_run.refresh_from_db() # Refresh to get updated status and counts
|
||||
self.assertEqual(self.import_run.status, ImportRun.Status.FINISHED)
|
||||
self.assertEqual(self.import_run.total_rows, 1)
|
||||
self.assertEqual(self.import_run.successful_rows, 1)
|
||||
|
||||
mock_create_transaction.assert_called_once()
|
||||
|
||||
# The first argument to _create_transaction is the row_data dictionary
|
||||
args_dict = mock_create_transaction.call_args[0][0]
|
||||
|
||||
self.assertEqual(args_dict['date'], date(2023, 1, 1))
|
||||
self.assertEqual(args_dict['amount'], Decimal('100.00'))
|
||||
self.assertEqual(args_dict['description'], "Test Deposit")
|
||||
self.assertEqual(args_dict['type'], Transaction.Type.INCOME)
|
||||
|
||||
# Account 'TestAcc' does not exist, so _map_row should resolve 'account' to None.
|
||||
# This assumes the default behavior of AccountMapping(type='name') when an account is not found
|
||||
# and creation of new accounts from mapping is not enabled/implemented in _map_row for this test.
|
||||
self.assertIsNone(args_dict.get('account'),
|
||||
"Account should be None as 'TestAcc' is not created in this test setup.")
|
||||
|
||||
mock_os_remove.assert_called_once_with(temp_file_path)
|
||||
|
||||
finally:
|
||||
# This cleanup is now handled by self.addCleanup, but kept for safety if addCleanup fails early.
|
||||
if temp_file_path and os.path.exists(temp_file_path) and not mock_os_remove.called:
|
||||
# If mock_os_remove was not called (e.g., an error before service.process_file finished),
|
||||
# we might need to manually clean up if addCleanup didn't register or run.
|
||||
# However, addCleanup is generally robust.
|
||||
pass
|
||||
275
app/apps/import_app/tests/test_import_service_v1.py
Normal file
275
app/apps/import_app/tests/test_import_service_v1.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Tests for ImportService v1, specifically for deduplication logic.
|
||||
|
||||
These tests verify that the _check_duplicate_transaction method handles
|
||||
different field types correctly, particularly ensuring that __iexact
|
||||
is only used for string fields (not dates, decimals, etc.).
|
||||
"""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.import_app.models import ImportProfile, ImportRun
|
||||
from apps.import_app.services.v1 import ImportService
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
class DeduplicationTests(TestCase):
|
||||
"""Tests for transaction deduplication during import."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data."""
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Test Account", group=self.account_group, currency=self.currency
|
||||
)
|
||||
|
||||
# Create an existing transaction for deduplication tests
|
||||
self.existing_transaction = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
date=date(2024, 1, 15),
|
||||
amount=Decimal("100.00"),
|
||||
description="Existing Transaction",
|
||||
internal_id="ABC123",
|
||||
)
|
||||
|
||||
def _create_import_service_with_deduplication(
|
||||
self, fields: list[str], match_type: str = "lax"
|
||||
) -> ImportService:
|
||||
"""Helper to create an ImportService with specific deduplication rules."""
|
||||
yaml_config = f"""
|
||||
settings:
|
||||
file_type: csv
|
||||
importing: transactions
|
||||
trigger_transaction_rules: false
|
||||
mapping:
|
||||
date_field:
|
||||
source: date
|
||||
target: date
|
||||
format: "%Y-%m-%d"
|
||||
amount_field:
|
||||
source: amount
|
||||
target: amount
|
||||
description_field:
|
||||
source: description
|
||||
target: description
|
||||
account_field:
|
||||
source: account
|
||||
target: account
|
||||
type: id
|
||||
deduplication:
|
||||
- type: compare
|
||||
fields: {fields}
|
||||
match_type: {match_type}
|
||||
"""
|
||||
profile = ImportProfile.objects.create(
|
||||
name=f"Test Profile {match_type} {'_'.join(fields)}",
|
||||
yaml_config=yaml_config,
|
||||
version=ImportProfile.Versions.VERSION_1,
|
||||
)
|
||||
import_run = ImportRun.objects.create(
|
||||
profile=profile,
|
||||
file_name="test.csv",
|
||||
)
|
||||
return ImportService(import_run)
|
||||
|
||||
def test_deduplication_with_date_field_strict_match(self):
|
||||
"""Test that date fields work with strict matching."""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date"], match_type="strict"
|
||||
)
|
||||
|
||||
# Should find duplicate when date matches
|
||||
is_duplicate = service._check_duplicate_transaction({"date": date(2024, 1, 15)})
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when date differs
|
||||
is_duplicate = service._check_duplicate_transaction({"date": date(2024, 2, 20)})
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_date_field_lax_match(self):
|
||||
"""
|
||||
Test that date fields use strict matching even when match_type is 'lax'.
|
||||
|
||||
This is the fix for the UPPER(date) PostgreSQL error. Date fields
|
||||
cannot use __iexact, so they should fall back to strict matching.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate when date matches (using strict comparison)
|
||||
is_duplicate = service._check_duplicate_transaction({"date": date(2024, 1, 15)})
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when date differs
|
||||
is_duplicate = service._check_duplicate_transaction({"date": date(2024, 2, 20)})
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_amount_field_lax_match(self):
|
||||
"""
|
||||
Test that Decimal fields use strict matching even when match_type is 'lax'.
|
||||
|
||||
Decimal fields cannot use __iexact, so they should fall back to strict matching.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["amount"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate when amount matches
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"amount": Decimal("100.00")}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when amount differs
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"amount": Decimal("200.00")}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_string_field_lax_match(self):
|
||||
"""
|
||||
Test that string fields use case-insensitive matching with match_type 'lax'.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["description"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate with case-insensitive match
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "EXISTING TRANSACTION"}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should find duplicate with exact case match
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "Existing Transaction"}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when description differs
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "Different Transaction"}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_string_field_strict_match(self):
|
||||
"""
|
||||
Test that string fields use case-sensitive matching with match_type 'strict'.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["description"], match_type="strict"
|
||||
)
|
||||
|
||||
# Should NOT find duplicate with different case (strict matching)
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "EXISTING TRANSACTION"}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
# Should find duplicate with exact case match
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"description": "Existing Transaction"}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
def test_deduplication_with_multiple_fields_mixed_types(self):
|
||||
"""
|
||||
Test deduplication with multiple fields of different types.
|
||||
|
||||
Verifies that string fields use __iexact while non-string fields
|
||||
use strict matching, all in the same deduplication rule.
|
||||
"""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date", "amount", "description"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate when all fields match (with case-insensitive description)
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
"amount": Decimal("100.00"),
|
||||
"description": "existing transaction", # lowercase should match
|
||||
}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should NOT find duplicate when date differs
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 2, 20),
|
||||
"amount": Decimal("100.00"),
|
||||
"description": "existing transaction",
|
||||
}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
# Should NOT find duplicate when amount differs
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
"amount": Decimal("999.99"),
|
||||
"description": "existing transaction",
|
||||
}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_internal_id_lax_match(self):
|
||||
"""Test deduplication with internal_id field using lax matching."""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["internal_id"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should find duplicate with case-insensitive match
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{"internal_id": "abc123"} # lowercase should match ABC123
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should find duplicate with exact match
|
||||
is_duplicate = service._check_duplicate_transaction({"internal_id": "ABC123"})
|
||||
self.assertTrue(is_duplicate)
|
||||
|
||||
# Should not find duplicate when internal_id differs
|
||||
is_duplicate = service._check_duplicate_transaction({"internal_id": "XYZ789"})
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_no_duplicate_when_no_transactions_exist(self):
|
||||
"""Test that no duplicate is found when there are no matching transactions."""
|
||||
# Hard delete to bypass signals that require user context
|
||||
self.existing_transaction.hard_delete()
|
||||
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date", "amount"], match_type="lax"
|
||||
)
|
||||
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
"amount": Decimal("100.00"),
|
||||
}
|
||||
)
|
||||
self.assertFalse(is_duplicate)
|
||||
|
||||
def test_deduplication_with_missing_field_in_data(self):
|
||||
"""Test that missing fields in transaction_data are handled gracefully."""
|
||||
service = self._create_import_service_with_deduplication(
|
||||
fields=["date", "nonexistent_field"], match_type="lax"
|
||||
)
|
||||
|
||||
# Should still work, only checking the fields that exist
|
||||
is_duplicate = service._check_duplicate_transaction(
|
||||
{
|
||||
"date": date(2024, 1, 15),
|
||||
}
|
||||
)
|
||||
self.assertTrue(is_duplicate)
|
||||
@@ -1,15 +1,14 @@
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Layout, Field, Row, Column
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from apps.common.widgets.datepicker import (
|
||||
AirDatePickerInput,
|
||||
AirMonthYearPickerInput,
|
||||
AirYearPickerInput,
|
||||
AirDatePickerInput,
|
||||
)
|
||||
from apps.transactions.models import TransactionCategory
|
||||
from apps.common.widgets.tom_select import TomSelect
|
||||
from apps.transactions.models import TransactionCategory
|
||||
from crispy_forms.helper import FormHelper
|
||||
from crispy_forms.layout import Column, Field, Layout, Row
|
||||
from django import forms
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class SingleMonthForm(forms.Form):
|
||||
@@ -59,8 +58,8 @@ class MonthRangeForm(forms.Form):
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("month_from", css_class="form-group col-md-6"),
|
||||
Column("month_to", css_class="form-group col-md-6"),
|
||||
Column("month_from"),
|
||||
Column("month_to"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -82,8 +81,8 @@ class YearRangeForm(forms.Form):
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("year_from", css_class="form-group col-md-6"),
|
||||
Column("year_to", css_class="form-group col-md-6"),
|
||||
Column("year_from"),
|
||||
Column("year_to"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -105,8 +104,8 @@ class DateRangeForm(forms.Form):
|
||||
|
||||
self.helper.layout = Layout(
|
||||
Row(
|
||||
Column("date_from", css_class="form-group col-md-6"),
|
||||
Column("date_to", css_class="form-group col-md-6"),
|
||||
Column("date_from"),
|
||||
Column("date_to"),
|
||||
css_class="mb-0",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,303 +1,3 @@
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth.models import User
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from decimal import Decimal
|
||||
from datetime import date, timedelta
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import TransactionCategory, Transaction
|
||||
from apps.insights.utils.category_explorer import get_category_sums_by_account, get_category_sums_by_currency
|
||||
from apps.insights.utils.sankey import generate_sankey_data_by_account
|
||||
|
||||
class InsightsUtilsTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(username='testinsightsuser', password='password')
|
||||
|
||||
self.currency_usd = Currency.objects.create(code="USD", name="US Dollar", decimal_places=2)
|
||||
self.currency_eur = Currency.objects.create(code="EUR", name="Euro", decimal_places=2)
|
||||
|
||||
# It's good practice to have an AccountGroup for accounts
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group", owner=self.user)
|
||||
|
||||
self.category_food = TransactionCategory.objects.create(name="Food", owner=self.user, type=TransactionCategory.TransactionType.EXPENSE)
|
||||
self.category_salary = TransactionCategory.objects.create(name="Salary", owner=self.user, type=TransactionCategory.TransactionType.INCOME)
|
||||
|
||||
self.account_usd_1 = Account.objects.create(name="USD Account 1", owner=self.user, currency=self.currency_usd, group=self.account_group)
|
||||
self.account_usd_2 = Account.objects.create(name="USD Account 2", owner=self.user, currency=self.currency_usd, group=self.account_group)
|
||||
self.account_eur_1 = Account.objects.create(name="EUR Account 1", owner=self.user, currency=self.currency_eur, group=self.account_group)
|
||||
|
||||
today = date.today()
|
||||
|
||||
# T1: Acc USD1, Food, Expense 50 (paid)
|
||||
Transaction.objects.create(
|
||||
description="Groceries USD1 Food Paid", account=self.account_usd_1, category=self.category_food,
|
||||
type=Transaction.Type.EXPENSE, amount=Decimal('50.00'), date=today, is_paid=True, owner=self.user
|
||||
)
|
||||
# T2: Acc USD1, Food, Expense 20 (unpaid/projected)
|
||||
Transaction.objects.create(
|
||||
description="Restaurant USD1 Food Unpaid", account=self.account_usd_1, category=self.category_food,
|
||||
type=Transaction.Type.EXPENSE, amount=Decimal('20.00'), date=today, is_paid=False, owner=self.user
|
||||
)
|
||||
# T3: Acc USD2, Food, Expense 30 (paid)
|
||||
Transaction.objects.create(
|
||||
description="Snacks USD2 Food Paid", account=self.account_usd_2, category=self.category_food,
|
||||
type=Transaction.Type.EXPENSE, amount=Decimal('30.00'), date=today, is_paid=True, owner=self.user
|
||||
)
|
||||
# T4: Acc USD1, Salary, Income 1000 (paid)
|
||||
Transaction.objects.create(
|
||||
description="Salary USD1 Paid", account=self.account_usd_1, category=self.category_salary,
|
||||
type=Transaction.Type.INCOME, amount=Decimal('1000.00'), date=today, is_paid=True, owner=self.user
|
||||
)
|
||||
# T5: Acc EUR1, Food, Expense 40 (paid, different currency)
|
||||
Transaction.objects.create(
|
||||
description="Groceries EUR1 Food Paid", account=self.account_eur_1, category=self.category_food,
|
||||
type=Transaction.Type.EXPENSE, amount=Decimal('40.00'), date=today, is_paid=True, owner=self.user
|
||||
)
|
||||
# T6: Acc USD2, Salary, Income 200 (unpaid/projected)
|
||||
Transaction.objects.create(
|
||||
description="Bonus USD2 Salary Unpaid", account=self.account_usd_2, category=self.category_salary,
|
||||
type=Transaction.Type.INCOME, amount=Decimal('200.00'), date=today, is_paid=False, owner=self.user
|
||||
)
|
||||
|
||||
def test_get_category_sums_by_account_for_food(self):
|
||||
qs = Transaction.objects.filter(owner=self.user) # Filter by user for safety in shared DB environments
|
||||
result = get_category_sums_by_account(qs, category=self.category_food)
|
||||
|
||||
expected_labels = sorted([self.account_eur_1.name, self.account_usd_1.name, self.account_usd_2.name])
|
||||
self.assertEqual(result['labels'], expected_labels)
|
||||
|
||||
# Expected data structure: {account_name: {'current_income': D('0'), ...}, ...}
|
||||
# Then the util function transforms this.
|
||||
# Let's map labels to their expected index for easier assertion
|
||||
label_to_idx = {name: i for i, name in enumerate(expected_labels)}
|
||||
|
||||
# Initialize expected data arrays based on sorted labels length
|
||||
num_labels = len(expected_labels)
|
||||
expected_current_income = [Decimal('0.00')] * num_labels
|
||||
expected_current_expenses = [Decimal('0.00')] * num_labels
|
||||
expected_projected_income = [Decimal('0.00')] * num_labels
|
||||
expected_projected_expenses = [Decimal('0.00')] * num_labels
|
||||
|
||||
# Populate expected data based on transactions for FOOD category
|
||||
# T1: Acc USD1, Food, Expense 50 (paid) -> account_usd_1, current_expenses = -50
|
||||
expected_current_expenses[label_to_idx[self.account_usd_1.name]] = Decimal('-50.00')
|
||||
# T2: Acc USD1, Food, Expense 20 (unpaid/projected) -> account_usd_1, projected_expenses = -20
|
||||
expected_projected_expenses[label_to_idx[self.account_usd_1.name]] = Decimal('-20.00')
|
||||
# T3: Acc USD2, Food, Expense 30 (paid) -> account_usd_2, current_expenses = -30
|
||||
expected_current_expenses[label_to_idx[self.account_usd_2.name]] = Decimal('-30.00')
|
||||
# T5: Acc EUR1, Food, Expense 40 (paid) -> account_eur_1, current_expenses = -40
|
||||
expected_current_expenses[label_to_idx[self.account_eur_1.name]] = Decimal('-40.00')
|
||||
|
||||
self.assertEqual(result['datasets'][0]['data'], [float(x) for x in expected_current_income]) # Current Income
|
||||
self.assertEqual(result['datasets'][1]['data'], [float(x) for x in expected_current_expenses]) # Current Expenses
|
||||
self.assertEqual(result['datasets'][2]['data'], [float(x) for x in expected_projected_income]) # Projected Income
|
||||
self.assertEqual(result['datasets'][3]['data'], [float(x) for x in expected_projected_expenses]) # Projected Expenses
|
||||
|
||||
self.assertEqual(result['datasets'][0]['label'], "Current Income")
|
||||
self.assertEqual(result['datasets'][1]['label'], "Current Expenses")
|
||||
self.assertEqual(result['datasets'][2]['label'], "Projected Income")
|
||||
self.assertEqual(result['datasets'][3]['label'], "Projected Expenses")
|
||||
|
||||
def test_generate_sankey_data_by_account(self):
|
||||
qs = Transaction.objects.filter(owner=self.user)
|
||||
result = generate_sankey_data_by_account(qs)
|
||||
|
||||
nodes = result['nodes']
|
||||
flows = result['flows']
|
||||
|
||||
# Helper to find a node by a unique part of its ID
|
||||
def find_node_by_id_part(id_part):
|
||||
found_nodes = [n for n in nodes if id_part in n['id']]
|
||||
self.assertEqual(len(found_nodes), 1, f"Node with ID part '{id_part}' not found or not unique. Found: {found_nodes}")
|
||||
return found_nodes[0]
|
||||
|
||||
# Helper to find a flow by unique parts of its source and target node IDs
|
||||
def find_flow_by_node_id_parts(from_id_part, to_id_part):
|
||||
found_flows = [
|
||||
f for f in flows
|
||||
if from_id_part in f['from_node'] and to_id_part in f['to_node']
|
||||
]
|
||||
self.assertEqual(len(found_flows), 1, f"Flow from '{from_id_part}' to '{to_id_part}' not found or not unique. Found: {found_flows}")
|
||||
return found_flows[0]
|
||||
|
||||
# Calculate total volumes by currency (sum of absolute amounts of ALL transactions)
|
||||
total_volume_usd = sum(abs(t.amount) for t in qs if t.account.currency == self.currency_usd) # 50+20+30+1000+200 = 1300
|
||||
total_volume_eur = sum(abs(t.amount) for t in qs if t.account.currency == self.currency_eur) # 40
|
||||
|
||||
self.assertEqual(total_volume_usd, Decimal('1300.00'))
|
||||
self.assertEqual(total_volume_eur, Decimal('40.00'))
|
||||
|
||||
# --- Assertions for Account USD 1 ---
|
||||
acc_usd_1_id_part = f"_{self.account_usd_1.id}"
|
||||
|
||||
node_income_salary_usd1 = find_node_by_id_part(f"income_{self.category_salary.name.lower()}{acc_usd_1_id_part}")
|
||||
self.assertEqual(node_income_salary_usd1['name'], self.category_salary.name)
|
||||
|
||||
node_account_usd1 = find_node_by_id_part(f"account_{self.account_usd_1.name.lower().replace(' ', '_')}{acc_usd_1_id_part}")
|
||||
self.assertEqual(node_account_usd1['name'], self.account_usd_1.name)
|
||||
|
||||
node_expense_food_usd1 = find_node_by_id_part(f"expense_{self.category_food.name.lower()}{acc_usd_1_id_part}")
|
||||
self.assertEqual(node_expense_food_usd1['name'], self.category_food.name)
|
||||
|
||||
node_saved_usd1 = find_node_by_id_part(f"savings_saved{acc_usd_1_id_part}")
|
||||
self.assertEqual(node_saved_usd1['name'], _("Saved"))
|
||||
|
||||
# Flow 1: Salary (T4) to account_usd_1
|
||||
flow_salary_to_usd1 = find_flow_by_node_id_parts(node_income_salary_usd1['id'], node_account_usd1['id'])
|
||||
self.assertEqual(flow_salary_to_usd1['original_amount'], 1000.0)
|
||||
self.assertEqual(flow_salary_to_usd1['currency']['code'], self.currency_usd.code)
|
||||
self.assertAlmostEqual(flow_salary_to_usd1['percentage'], (1000.0 / float(total_volume_usd)) * 100, places=2)
|
||||
self.assertAlmostEqual(flow_salary_to_usd1['flow'], (1000.0 / float(total_volume_usd)), places=4)
|
||||
|
||||
# Flow 2: account_usd_1 to Food (T1)
|
||||
flow_usd1_to_food = find_flow_by_node_id_parts(node_account_usd1['id'], node_expense_food_usd1['id'])
|
||||
self.assertEqual(flow_usd1_to_food['original_amount'], 50.0) # T1 is 50
|
||||
self.assertEqual(flow_usd1_to_food['currency']['code'], self.currency_usd.code)
|
||||
self.assertAlmostEqual(flow_usd1_to_food['percentage'], (50.0 / float(total_volume_usd)) * 100, places=2)
|
||||
|
||||
# Flow 3: account_usd_1 to Saved
|
||||
# Net paid for account_usd_1: 1000 (T4 income) - 50 (T1 expense) = 950
|
||||
flow_usd1_to_saved = find_flow_by_node_id_parts(node_account_usd1['id'], node_saved_usd1['id'])
|
||||
self.assertEqual(flow_usd1_to_saved['original_amount'], 950.0)
|
||||
self.assertEqual(flow_usd1_to_saved['currency']['code'], self.currency_usd.code)
|
||||
self.assertAlmostEqual(flow_usd1_to_saved['percentage'], (950.0 / float(total_volume_usd)) * 100, places=2)
|
||||
|
||||
# --- Assertions for Account USD 2 ---
|
||||
acc_usd_2_id_part = f"_{self.account_usd_2.id}"
|
||||
node_account_usd2 = find_node_by_id_part(f"account_{self.account_usd_2.name.lower().replace(' ', '_')}{acc_usd_2_id_part}")
|
||||
node_expense_food_usd2 = find_node_by_id_part(f"expense_{self.category_food.name.lower()}{acc_usd_2_id_part}")
|
||||
# T6 (Salary for USD2) is unpaid, so no income node/flow for it.
|
||||
# Net paid for account_usd_2 is -30 (T3 expense). So no "Saved" node.
|
||||
|
||||
# Flow: account_usd_2 to Food (T3)
|
||||
flow_usd2_to_food = find_flow_by_node_id_parts(node_account_usd2['id'], node_expense_food_usd2['id'])
|
||||
self.assertEqual(flow_usd2_to_food['original_amount'], 30.0) # T3 is 30
|
||||
self.assertEqual(flow_usd2_to_food['currency']['code'], self.currency_usd.code)
|
||||
self.assertAlmostEqual(flow_usd2_to_food['percentage'], (30.0 / float(total_volume_usd)) * 100, places=2)
|
||||
|
||||
# Check no "Saved" node for account_usd_2
|
||||
saved_nodes_usd2 = [n for n in nodes if f"savings_saved{acc_usd_2_id_part}" in n['id']]
|
||||
self.assertEqual(len(saved_nodes_usd2), 0, "Should be no 'Saved' node for account_usd_2 as net is negative.")
|
||||
|
||||
# --- Assertions for Account EUR 1 ---
|
||||
acc_eur_1_id_part = f"_{self.account_eur_1.id}"
|
||||
node_account_eur1 = find_node_by_id_part(f"account_{self.account_eur_1.name.lower().replace(' ', '_')}{acc_eur_1_id_part}")
|
||||
node_expense_food_eur1 = find_node_by_id_part(f"expense_{self.category_food.name.lower()}{acc_eur_1_id_part}")
|
||||
# Net paid for account_eur_1 is -40 (T5 expense). No "Saved" node.
|
||||
|
||||
# Flow: account_eur_1 to Food (T5)
|
||||
flow_eur1_to_food = find_flow_by_node_id_parts(node_account_eur1['id'], node_expense_food_eur1['id'])
|
||||
self.assertEqual(flow_eur1_to_food['original_amount'], 40.0) # T5 is 40
|
||||
self.assertEqual(flow_eur1_to_food['currency']['code'], self.currency_eur.code)
|
||||
self.assertAlmostEqual(flow_eur1_to_food['percentage'], (40.0 / float(total_volume_eur)) * 100, places=2) # (40/40)*100 = 100%
|
||||
|
||||
# Check no "Saved" node for account_eur_1
|
||||
saved_nodes_eur1 = [n for n in nodes if f"savings_saved{acc_eur_1_id_part}" in n['id']]
|
||||
self.assertEqual(len(saved_nodes_eur1), 0, "Should be no 'Saved' node for account_eur_1 as net is negative.")
|
||||
|
||||
def test_get_category_sums_by_currency_for_food(self):
|
||||
qs = Transaction.objects.filter(owner=self.user)
|
||||
result = get_category_sums_by_currency(qs, category=self.category_food)
|
||||
|
||||
expected_labels = sorted([self.currency_eur.name, self.currency_usd.name])
|
||||
self.assertEqual(result['labels'], expected_labels)
|
||||
|
||||
label_to_idx = {name: i for i, name in enumerate(expected_labels)}
|
||||
num_labels = len(expected_labels)
|
||||
|
||||
expected_current_income = [Decimal('0.00')] * num_labels
|
||||
expected_current_expenses = [Decimal('0.00')] * num_labels
|
||||
expected_projected_income = [Decimal('0.00')] * num_labels
|
||||
expected_projected_expenses = [Decimal('0.00')] * num_labels
|
||||
|
||||
# Food Transactions:
|
||||
# T1: USD Account 1, Food, Expense 50 (paid)
|
||||
# T2: USD Account 1, Food, Expense 20 (unpaid/projected)
|
||||
# T3: USD Account 2, Food, Expense 30 (paid)
|
||||
# T5: EUR Account 1, Food, Expense 40 (paid)
|
||||
|
||||
# Current Expenses:
|
||||
expected_current_expenses[label_to_idx[self.currency_eur.name]] = Decimal('-40.00') # T5
|
||||
expected_current_expenses[label_to_idx[self.currency_usd.name]] = Decimal('-50.00') + Decimal('-30.00') # T1 + T3
|
||||
|
||||
# Projected Expenses:
|
||||
expected_projected_expenses[label_to_idx[self.currency_usd.name]] = Decimal('-20.00') # T2
|
||||
|
||||
self.assertEqual(result['datasets'][0]['data'], [float(x) for x in expected_current_income])
|
||||
self.assertEqual(result['datasets'][1]['data'], [float(x) for x in expected_current_expenses])
|
||||
self.assertEqual(result['datasets'][2]['data'], [float(x) for x in expected_projected_income])
|
||||
self.assertEqual(result['datasets'][3]['data'], [float(x) for x in expected_projected_expenses])
|
||||
|
||||
self.assertEqual(result['datasets'][0]['label'], "Current Income")
|
||||
self.assertEqual(result['datasets'][1]['label'], "Current Expenses")
|
||||
self.assertEqual(result['datasets'][2]['label'], "Projected Income")
|
||||
self.assertEqual(result['datasets'][3]['label'], "Projected Expenses")
|
||||
|
||||
def test_get_category_sums_by_currency_for_salary(self):
|
||||
qs = Transaction.objects.filter(owner=self.user)
|
||||
result = get_category_sums_by_currency(qs, category=self.category_salary)
|
||||
|
||||
# Salary Transactions:
|
||||
# T4: USD Account 1, Salary, Income 1000 (paid)
|
||||
# T6: USD Account 2, Salary, Income 200 (unpaid/projected)
|
||||
# All are USD
|
||||
expected_labels = [self.currency_usd.name] # Only USD has salary transactions
|
||||
self.assertEqual(result['labels'], expected_labels)
|
||||
|
||||
label_to_idx = {name: i for i, name in enumerate(expected_labels)}
|
||||
num_labels = len(expected_labels)
|
||||
|
||||
expected_current_income = [Decimal('0.00')] * num_labels
|
||||
expected_current_expenses = [Decimal('0.00')] * num_labels
|
||||
expected_projected_income = [Decimal('0.00')] * num_labels
|
||||
expected_projected_expenses = [Decimal('0.00')] * num_labels
|
||||
|
||||
# Current Income:
|
||||
expected_current_income[label_to_idx[self.currency_usd.name]] = Decimal('1000.00') # T4
|
||||
|
||||
# Projected Income:
|
||||
expected_projected_income[label_to_idx[self.currency_usd.name]] = Decimal('200.00') # T6
|
||||
|
||||
self.assertEqual(result['datasets'][0]['data'], [float(x) for x in expected_current_income])
|
||||
self.assertEqual(result['datasets'][1]['data'], [float(x) for x in expected_current_expenses])
|
||||
self.assertEqual(result['datasets'][2]['data'], [float(x) for x in expected_projected_income])
|
||||
self.assertEqual(result['datasets'][3]['data'], [float(x) for x in expected_projected_expenses])
|
||||
|
||||
self.assertEqual(result['datasets'][0]['label'], "Current Income")
|
||||
self.assertEqual(result['datasets'][1]['label'], "Current Expenses")
|
||||
self.assertEqual(result['datasets'][2]['label'], "Projected Income")
|
||||
self.assertEqual(result['datasets'][3]['label'], "Projected Expenses")
|
||||
|
||||
|
||||
def test_get_category_sums_by_account_for_salary(self):
|
||||
qs = Transaction.objects.filter(owner=self.user)
|
||||
result = get_category_sums_by_account(qs, category=self.category_salary)
|
||||
|
||||
# Only accounts with salary transactions should appear
|
||||
expected_labels = sorted([self.account_usd_1.name, self.account_usd_2.name])
|
||||
self.assertEqual(result['labels'], expected_labels)
|
||||
|
||||
label_to_idx = {name: i for i, name in enumerate(expected_labels)}
|
||||
num_labels = len(expected_labels)
|
||||
|
||||
expected_current_income = [Decimal('0.00')] * num_labels
|
||||
expected_current_expenses = [Decimal('0.00')] * num_labels
|
||||
expected_projected_income = [Decimal('0.00')] * num_labels
|
||||
expected_projected_expenses = [Decimal('0.00')] * num_labels
|
||||
|
||||
# Populate expected data based on transactions for SALARY category
|
||||
# T4: Acc USD1, Salary, Income 1000 (paid) -> account_usd_1, current_income = 1000
|
||||
expected_current_income[label_to_idx[self.account_usd_1.name]] = Decimal('1000.00')
|
||||
# T6: Acc USD2, Salary, Income 200 (unpaid/projected) -> account_usd_2, projected_income = 200
|
||||
expected_projected_income[label_to_idx[self.account_usd_2.name]] = Decimal('200.00')
|
||||
|
||||
self.assertEqual(result['datasets'][0]['data'], [float(x) for x in expected_current_income])
|
||||
self.assertEqual(result['datasets'][1]['data'], [float(x) for x in expected_current_expenses])
|
||||
self.assertEqual(result['datasets'][2]['data'], [float(x) for x in expected_projected_income])
|
||||
self.assertEqual(result['datasets'][3]['data'], [float(x) for x in expected_projected_expenses])
|
||||
|
||||
self.assertEqual(result['datasets'][0]['label'], "Current Income")
|
||||
self.assertEqual(result['datasets'][1]['label'], "Current Expenses")
|
||||
self.assertEqual(result['datasets'][2]['label'], "Projected Income")
|
||||
self.assertEqual(result['datasets'][3]['label'], "Projected Expenses")
|
||||
# Create your tests here.
|
||||
|
||||
@@ -49,4 +49,14 @@ urlpatterns = [
|
||||
views.emergency_fund,
|
||||
name="insights_emergency_fund",
|
||||
),
|
||||
path(
|
||||
"insights/year-by-year/",
|
||||
views.year_by_year,
|
||||
name="insights_year_by_year",
|
||||
),
|
||||
path(
|
||||
"insights/month-by-month/",
|
||||
views.month_by_month,
|
||||
name="insights_month_by_month",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -9,8 +9,13 @@ from apps.currencies.models import Currency
|
||||
from apps.currencies.utils.convert import convert
|
||||
|
||||
|
||||
def get_categories_totals(transactions_queryset, ignore_empty=False):
|
||||
# First get the category totals as before
|
||||
def get_categories_totals(
|
||||
transactions_queryset, ignore_empty=False, show_entities=False
|
||||
):
|
||||
# Step 1: Aggregate transaction data by category and currency.
|
||||
# This query calculates the total current and projected income/expense for each
|
||||
# category by grouping transactions and summing up their amounts based on their
|
||||
# type (income/expense) and payment status (paid/unpaid).
|
||||
category_currency_metrics = (
|
||||
transactions_queryset.values(
|
||||
"category",
|
||||
@@ -74,7 +79,10 @@ def get_categories_totals(transactions_queryset, ignore_empty=False):
|
||||
.order_by("category__name")
|
||||
)
|
||||
|
||||
# Get tag totals within each category with currency details
|
||||
# Step 2: Aggregate transaction data by tag, category, and currency.
|
||||
# This is similar to the category metrics but adds tags to the grouping,
|
||||
# allowing for a breakdown of totals by tag within each category. It also
|
||||
# handles untagged transactions, where the 'tags' field is None.
|
||||
tag_metrics = transactions_queryset.values(
|
||||
"category",
|
||||
"tags",
|
||||
@@ -129,10 +137,12 @@ def get_categories_totals(transactions_queryset, ignore_empty=False):
|
||||
),
|
||||
)
|
||||
|
||||
# Process the results to structure by category
|
||||
# Step 3: Initialize the main dictionary to structure the final results.
|
||||
# The data will be organized hierarchically: category -> currency -> tags -> entities.
|
||||
result = {}
|
||||
|
||||
# Process category totals first
|
||||
# Step 4: Process the aggregated category metrics to build the initial result structure.
|
||||
# This loop iterates through each category's metrics and populates the `result` dict.
|
||||
for metric in category_currency_metrics:
|
||||
# Skip empty categories if ignore_empty is True
|
||||
if ignore_empty and all(
|
||||
@@ -183,7 +193,7 @@ def get_categories_totals(transactions_queryset, ignore_empty=False):
|
||||
"total_final": total_final,
|
||||
}
|
||||
|
||||
# Add exchanged values if exchange_currency exists
|
||||
# Step 4a: Handle currency conversion for category totals if an exchange currency is defined.
|
||||
if metric["account__currency__exchange_currency"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
@@ -222,7 +232,7 @@ def get_categories_totals(transactions_queryset, ignore_empty=False):
|
||||
|
||||
result[category_id]["currencies"][currency_id] = currency_data
|
||||
|
||||
# Process tag totals and add them to the result, including untagged
|
||||
# Step 5: Process the aggregated tag metrics and integrate them into the result structure.
|
||||
for tag_metric in tag_metrics:
|
||||
category_id = tag_metric["category"]
|
||||
tag_id = tag_metric["tags"] # Will be None for untagged transactions
|
||||
@@ -240,6 +250,7 @@ def get_categories_totals(transactions_queryset, ignore_empty=False):
|
||||
result[category_id]["tags"][tag_key] = {
|
||||
"name": tag_name,
|
||||
"currencies": {},
|
||||
"entities": {},
|
||||
}
|
||||
|
||||
currency_id = tag_metric["account__currency"]
|
||||
@@ -278,7 +289,7 @@ def get_categories_totals(transactions_queryset, ignore_empty=False):
|
||||
"total_final": tag_total_final,
|
||||
}
|
||||
|
||||
# Add exchange currency support for tags
|
||||
# Step 5a: Handle currency conversion for tag totals.
|
||||
if tag_metric["account__currency__exchange_currency"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
@@ -319,4 +330,175 @@ def get_categories_totals(transactions_queryset, ignore_empty=False):
|
||||
currency_id
|
||||
] = tag_currency_data
|
||||
|
||||
# Step 6: If requested, aggregate and process entity-level data.
|
||||
if show_entities:
|
||||
entity_metrics = transactions_queryset.values(
|
||||
"category",
|
||||
"tags",
|
||||
"entities",
|
||||
"entities__name",
|
||||
"account__currency",
|
||||
"account__currency__code",
|
||||
"account__currency__name",
|
||||
"account__currency__decimal_places",
|
||||
"account__currency__prefix",
|
||||
"account__currency__suffix",
|
||||
"account__currency__exchange_currency",
|
||||
).annotate(
|
||||
expense_current=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(
|
||||
type=Transaction.Type.EXPENSE, is_paid=True, then="amount"
|
||||
),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
expense_projected=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(
|
||||
type=Transaction.Type.EXPENSE, is_paid=False, then="amount"
|
||||
),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
income_current=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(type=Transaction.Type.INCOME, is_paid=True, then="amount"),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
income_projected=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(
|
||||
type=Transaction.Type.INCOME, is_paid=False, then="amount"
|
||||
),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
)
|
||||
|
||||
for entity_metric in entity_metrics:
|
||||
category_id = entity_metric["category"]
|
||||
tag_id = entity_metric["tags"]
|
||||
entity_id = entity_metric["entities"]
|
||||
|
||||
if category_id in result:
|
||||
tag_key = tag_id if tag_id is not None else "untagged"
|
||||
if tag_key in result[category_id]["tags"]:
|
||||
entity_key = entity_id if entity_id is not None else "no_entity"
|
||||
entity_name = (
|
||||
entity_metric["entities__name"]
|
||||
if entity_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if "entities" not in result[category_id]["tags"][tag_key]:
|
||||
result[category_id]["tags"][tag_key]["entities"] = {}
|
||||
|
||||
if (
|
||||
entity_key
|
||||
not in result[category_id]["tags"][tag_key]["entities"]
|
||||
):
|
||||
result[category_id]["tags"][tag_key]["entities"][entity_key] = {
|
||||
"name": entity_name,
|
||||
"currencies": {},
|
||||
}
|
||||
|
||||
currency_id = entity_metric["account__currency"]
|
||||
|
||||
entity_total_current = (
|
||||
entity_metric["income_current"]
|
||||
- entity_metric["expense_current"]
|
||||
)
|
||||
entity_total_projected = (
|
||||
entity_metric["income_projected"]
|
||||
- entity_metric["expense_projected"]
|
||||
)
|
||||
entity_total_income = (
|
||||
entity_metric["income_current"]
|
||||
+ entity_metric["income_projected"]
|
||||
)
|
||||
entity_total_expense = (
|
||||
entity_metric["expense_current"]
|
||||
+ entity_metric["expense_projected"]
|
||||
)
|
||||
entity_total_final = entity_total_current + entity_total_projected
|
||||
|
||||
entity_currency_data = {
|
||||
"currency": {
|
||||
"code": entity_metric["account__currency__code"],
|
||||
"name": entity_metric["account__currency__name"],
|
||||
"decimal_places": entity_metric[
|
||||
"account__currency__decimal_places"
|
||||
],
|
||||
"prefix": entity_metric["account__currency__prefix"],
|
||||
"suffix": entity_metric["account__currency__suffix"],
|
||||
},
|
||||
"expense_current": entity_metric["expense_current"],
|
||||
"expense_projected": entity_metric["expense_projected"],
|
||||
"total_expense": entity_total_expense,
|
||||
"income_current": entity_metric["income_current"],
|
||||
"income_projected": entity_metric["income_projected"],
|
||||
"total_income": entity_total_income,
|
||||
"total_current": entity_total_current,
|
||||
"total_projected": entity_total_projected,
|
||||
"total_final": entity_total_final,
|
||||
}
|
||||
|
||||
if entity_metric["account__currency__exchange_currency"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=entity_metric["account__currency__exchange_currency"]
|
||||
)
|
||||
|
||||
exchanged = {}
|
||||
for field in [
|
||||
"expense_current",
|
||||
"expense_projected",
|
||||
"income_current",
|
||||
"income_projected",
|
||||
"total_income",
|
||||
"total_expense",
|
||||
"total_current",
|
||||
"total_projected",
|
||||
"total_final",
|
||||
]:
|
||||
amount, prefix, suffix, decimal_places = convert(
|
||||
amount=entity_currency_data[field],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if amount is not None:
|
||||
exchanged[field] = amount
|
||||
if "currency" not in exchanged:
|
||||
exchanged["currency"] = {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
}
|
||||
if exchanged:
|
||||
entity_currency_data["exchanged"] = exchanged
|
||||
|
||||
result[category_id]["tags"][tag_key]["entities"][entity_key][
|
||||
"currencies"
|
||||
][currency_id] = entity_currency_data
|
||||
|
||||
return result
|
||||
|
||||
316
app/apps/insights/utils/month_by_month.py
Normal file
316
app/apps/insights/utils/month_by_month.py
Normal file
@@ -0,0 +1,316 @@
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import Sum, Case, When, Value
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.currencies.models import Currency
|
||||
from apps.currencies.utils.convert import convert
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
def get_month_by_month_data(year=None, group_by="categories"):
|
||||
"""
|
||||
Aggregate transaction totals by month for a specific year, grouped by categories, tags, or entities.
|
||||
|
||||
Args:
|
||||
year: The year to filter transactions (defaults to current year)
|
||||
group_by: One of "categories", "tags", or "entities"
|
||||
|
||||
Returns:
|
||||
{
|
||||
"year": 2025,
|
||||
"available_years": [2025, 2024, ...],
|
||||
"months": [1, 2, 3, ..., 12],
|
||||
"items": {
|
||||
item_id: {
|
||||
"name": "Item Name",
|
||||
"month_totals": {
|
||||
1: {"currencies": {...}},
|
||||
...
|
||||
},
|
||||
"total": {"currencies": {...}}
|
||||
},
|
||||
...
|
||||
},
|
||||
"month_totals": {...},
|
||||
"grand_total": {"currencies": {...}}
|
||||
}
|
||||
"""
|
||||
if year is None:
|
||||
year = timezone.localdate(timezone.now()).year
|
||||
|
||||
# Base queryset - all paid transactions, non-muted
|
||||
transactions = Transaction.objects.filter(
|
||||
is_paid=True,
|
||||
account__is_archived=False,
|
||||
).exclude(account__currency__is_archived=True)
|
||||
|
||||
# Get available years for the selector
|
||||
available_years = list(
|
||||
transactions.values_list("reference_date__year", flat=True)
|
||||
.distinct()
|
||||
.order_by("-reference_date__year")
|
||||
)
|
||||
|
||||
# Filter by the selected year
|
||||
transactions = transactions.filter(reference_date__year=year)
|
||||
|
||||
# Define grouping fields based on group_by parameter
|
||||
if group_by == "tags":
|
||||
group_field = "tags"
|
||||
name_field = "tags__name"
|
||||
elif group_by == "entities":
|
||||
group_field = "entities"
|
||||
name_field = "entities__name"
|
||||
else: # Default to categories
|
||||
group_field = "category"
|
||||
name_field = "category__name"
|
||||
|
||||
# Months 1-12
|
||||
months = list(range(1, 13))
|
||||
|
||||
if not available_years:
|
||||
return {
|
||||
"year": year,
|
||||
"available_years": [],
|
||||
"months": months,
|
||||
"items": {},
|
||||
"month_totals": {},
|
||||
"grand_total": {"currencies": {}},
|
||||
}
|
||||
|
||||
# Aggregate by group, month, and currency
|
||||
metrics = (
|
||||
transactions.values(
|
||||
group_field,
|
||||
name_field,
|
||||
"reference_date__month",
|
||||
"account__currency",
|
||||
"account__currency__code",
|
||||
"account__currency__name",
|
||||
"account__currency__decimal_places",
|
||||
"account__currency__prefix",
|
||||
"account__currency__suffix",
|
||||
"account__currency__exchange_currency",
|
||||
)
|
||||
.annotate(
|
||||
expense_total=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(type=Transaction.Type.EXPENSE, then="amount"),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
income_total=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(type=Transaction.Type.INCOME, then="amount"),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
)
|
||||
.order_by(name_field, "reference_date__month")
|
||||
)
|
||||
|
||||
# Build result structure
|
||||
result = {
|
||||
"year": year,
|
||||
"available_years": available_years,
|
||||
"months": months,
|
||||
"items": OrderedDict(),
|
||||
"month_totals": {},
|
||||
"grand_total": {"currencies": {}},
|
||||
}
|
||||
|
||||
# Store currency info for later use in totals
|
||||
currency_info = {}
|
||||
|
||||
for metric in metrics:
|
||||
item_id = metric[group_field]
|
||||
item_name = metric[name_field]
|
||||
month = metric["reference_date__month"]
|
||||
currency_id = metric["account__currency"]
|
||||
|
||||
# Use a consistent key for None (uncategorized/untagged/no entity)
|
||||
item_key = item_id if item_id is not None else "__none__"
|
||||
|
||||
if item_key not in result["items"]:
|
||||
result["items"][item_key] = {
|
||||
"name": item_name,
|
||||
"month_totals": {},
|
||||
"total": {"currencies": {}},
|
||||
}
|
||||
|
||||
if month not in result["items"][item_key]["month_totals"]:
|
||||
result["items"][item_key]["month_totals"][month] = {"currencies": {}}
|
||||
|
||||
# Calculate final total (income - expense)
|
||||
final_total = metric["income_total"] - metric["expense_total"]
|
||||
|
||||
# Store currency info for totals calculation
|
||||
if currency_id not in currency_info:
|
||||
currency_info[currency_id] = {
|
||||
"code": metric["account__currency__code"],
|
||||
"name": metric["account__currency__name"],
|
||||
"decimal_places": metric["account__currency__decimal_places"],
|
||||
"prefix": metric["account__currency__prefix"],
|
||||
"suffix": metric["account__currency__suffix"],
|
||||
"exchange_currency_id": metric["account__currency__exchange_currency"],
|
||||
}
|
||||
|
||||
currency_data = {
|
||||
"currency": {
|
||||
"code": metric["account__currency__code"],
|
||||
"name": metric["account__currency__name"],
|
||||
"decimal_places": metric["account__currency__decimal_places"],
|
||||
"prefix": metric["account__currency__prefix"],
|
||||
"suffix": metric["account__currency__suffix"],
|
||||
},
|
||||
"final_total": final_total,
|
||||
"income_total": metric["income_total"],
|
||||
"expense_total": metric["expense_total"],
|
||||
}
|
||||
|
||||
# Handle currency conversion if exchange currency is set
|
||||
if metric["account__currency__exchange_currency"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=metric["account__currency__exchange_currency"]
|
||||
)
|
||||
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=final_total,
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
|
||||
if converted_amount is not None:
|
||||
currency_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
result["items"][item_key]["month_totals"][month]["currencies"][currency_id] = (
|
||||
currency_data
|
||||
)
|
||||
|
||||
# Accumulate item total (across all months for this item)
|
||||
if currency_id not in result["items"][item_key]["total"]["currencies"]:
|
||||
result["items"][item_key]["total"]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["items"][item_key]["total"]["currencies"][currency_id][
|
||||
"final_total"
|
||||
] += final_total
|
||||
|
||||
# Accumulate month total (across all items for this month)
|
||||
if month not in result["month_totals"]:
|
||||
result["month_totals"][month] = {"currencies": {}}
|
||||
if currency_id not in result["month_totals"][month]["currencies"]:
|
||||
result["month_totals"][month]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["month_totals"][month]["currencies"][currency_id]["final_total"] += (
|
||||
final_total
|
||||
)
|
||||
|
||||
# Accumulate grand total
|
||||
if currency_id not in result["grand_total"]["currencies"]:
|
||||
result["grand_total"]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["grand_total"]["currencies"][currency_id]["final_total"] += final_total
|
||||
|
||||
# Add currency conversion for item totals
|
||||
for item_key, item_data in result["items"].items():
|
||||
for currency_id, total_data in item_data["total"]["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
# Add currency conversion for month totals
|
||||
for month, month_data in result["month_totals"].items():
|
||||
for currency_id, total_data in month_data["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
# Add currency conversion for grand total
|
||||
for currency_id, total_data in result["grand_total"]["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
@@ -13,7 +13,9 @@ from apps.insights.forms import (
|
||||
)
|
||||
|
||||
|
||||
def get_transactions(request, include_unpaid=True, include_silent=False):
|
||||
def get_transactions(
|
||||
request, include_unpaid=True, include_silent=False, include_untracked_accounts=False
|
||||
):
|
||||
transactions = Transaction.objects.all()
|
||||
|
||||
filter_type = request.GET.get("type", None)
|
||||
@@ -91,6 +93,15 @@ def get_transactions(request, include_unpaid=True, include_silent=False):
|
||||
transactions = transactions.filter(is_paid=True)
|
||||
|
||||
if not include_silent:
|
||||
transactions = transactions.exclude(Q(category__mute=True) & ~Q(category=None))
|
||||
transactions = transactions.exclude(
|
||||
Q(Q(category__mute=True) & ~Q(category=None)) | Q(mute=True)
|
||||
)
|
||||
|
||||
if not include_untracked_accounts:
|
||||
transactions = transactions.exclude(
|
||||
account__in=request.user.untracked_accounts.all()
|
||||
)
|
||||
|
||||
transactions = transactions.exclude(account__currency__is_archived=True)
|
||||
|
||||
return transactions
|
||||
|
||||
303
app/apps/insights/utils/year_by_year.py
Normal file
303
app/apps/insights/utils/year_by_year.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import Sum, Case, When, Value
|
||||
from django.db.models.functions import Coalesce
|
||||
|
||||
from apps.currencies.models import Currency
|
||||
from apps.currencies.utils.convert import convert
|
||||
from apps.transactions.models import Transaction
|
||||
|
||||
|
||||
def get_year_by_year_data(group_by="categories"):
|
||||
"""
|
||||
Aggregate transaction totals by year for categories, tags, or entities.
|
||||
|
||||
Args:
|
||||
group_by: One of "categories", "tags", or "entities"
|
||||
|
||||
Returns:
|
||||
{
|
||||
"years": [2025, 2024, ...], # Sorted descending
|
||||
"items": {
|
||||
item_id: {
|
||||
"name": "Item Name",
|
||||
"year_totals": {
|
||||
2025: {"currencies": {...}},
|
||||
...
|
||||
},
|
||||
"total": {"currencies": {...}} # Sum across all years
|
||||
},
|
||||
...
|
||||
},
|
||||
"year_totals": { # Sum across all items for each year
|
||||
2025: {"currencies": {...}},
|
||||
...
|
||||
},
|
||||
"grand_total": {"currencies": {...}} # Sum of everything
|
||||
}
|
||||
"""
|
||||
# Base queryset - all paid transactions, non-muted
|
||||
transactions = Transaction.objects.filter(
|
||||
is_paid=True,
|
||||
account__is_archived=False,
|
||||
).exclude(account__currency__is_archived=True)
|
||||
|
||||
# Define grouping fields based on group_by parameter
|
||||
if group_by == "tags":
|
||||
group_field = "tags"
|
||||
name_field = "tags__name"
|
||||
elif group_by == "entities":
|
||||
group_field = "entities"
|
||||
name_field = "entities__name"
|
||||
else: # Default to categories
|
||||
group_field = "category"
|
||||
name_field = "category__name"
|
||||
|
||||
# Get all unique years with transactions
|
||||
years = (
|
||||
transactions.values_list("reference_date__year", flat=True)
|
||||
.distinct()
|
||||
.order_by("-reference_date__year")
|
||||
)
|
||||
years = list(years)
|
||||
|
||||
if not years:
|
||||
return {
|
||||
"years": [],
|
||||
"items": {},
|
||||
"year_totals": {},
|
||||
"grand_total": {"currencies": {}},
|
||||
}
|
||||
|
||||
# Aggregate by group, year, and currency
|
||||
metrics = (
|
||||
transactions.values(
|
||||
group_field,
|
||||
name_field,
|
||||
"reference_date__year",
|
||||
"account__currency",
|
||||
"account__currency__code",
|
||||
"account__currency__name",
|
||||
"account__currency__decimal_places",
|
||||
"account__currency__prefix",
|
||||
"account__currency__suffix",
|
||||
"account__currency__exchange_currency",
|
||||
)
|
||||
.annotate(
|
||||
expense_total=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(type=Transaction.Type.EXPENSE, then="amount"),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
income_total=Coalesce(
|
||||
Sum(
|
||||
Case(
|
||||
When(type=Transaction.Type.INCOME, then="amount"),
|
||||
default=Value(0),
|
||||
output_field=models.DecimalField(),
|
||||
)
|
||||
),
|
||||
Decimal("0"),
|
||||
),
|
||||
)
|
||||
.order_by(name_field, "-reference_date__year")
|
||||
)
|
||||
|
||||
# Build result structure
|
||||
result = {
|
||||
"years": years,
|
||||
"items": OrderedDict(),
|
||||
"year_totals": {}, # Totals per year across all items
|
||||
"grand_total": {"currencies": {}}, # Grand total across everything
|
||||
}
|
||||
|
||||
# Store currency info for later use in totals
|
||||
currency_info = {}
|
||||
|
||||
for metric in metrics:
|
||||
item_id = metric[group_field]
|
||||
item_name = metric[name_field]
|
||||
year = metric["reference_date__year"]
|
||||
currency_id = metric["account__currency"]
|
||||
|
||||
# Use a consistent key for None (uncategorized/untagged/no entity)
|
||||
item_key = item_id if item_id is not None else "__none__"
|
||||
|
||||
if item_key not in result["items"]:
|
||||
result["items"][item_key] = {
|
||||
"name": item_name,
|
||||
"year_totals": {},
|
||||
"total": {"currencies": {}}, # Total for this item across all years
|
||||
}
|
||||
|
||||
if year not in result["items"][item_key]["year_totals"]:
|
||||
result["items"][item_key]["year_totals"][year] = {"currencies": {}}
|
||||
|
||||
# Calculate final total (income - expense)
|
||||
final_total = metric["income_total"] - metric["expense_total"]
|
||||
|
||||
# Store currency info for totals calculation
|
||||
if currency_id not in currency_info:
|
||||
currency_info[currency_id] = {
|
||||
"code": metric["account__currency__code"],
|
||||
"name": metric["account__currency__name"],
|
||||
"decimal_places": metric["account__currency__decimal_places"],
|
||||
"prefix": metric["account__currency__prefix"],
|
||||
"suffix": metric["account__currency__suffix"],
|
||||
"exchange_currency_id": metric["account__currency__exchange_currency"],
|
||||
}
|
||||
|
||||
currency_data = {
|
||||
"currency": {
|
||||
"code": metric["account__currency__code"],
|
||||
"name": metric["account__currency__name"],
|
||||
"decimal_places": metric["account__currency__decimal_places"],
|
||||
"prefix": metric["account__currency__prefix"],
|
||||
"suffix": metric["account__currency__suffix"],
|
||||
},
|
||||
"final_total": final_total,
|
||||
"income_total": metric["income_total"],
|
||||
"expense_total": metric["expense_total"],
|
||||
}
|
||||
|
||||
# Handle currency conversion if exchange currency is set
|
||||
if metric["account__currency__exchange_currency"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=metric["account__currency__exchange_currency"]
|
||||
)
|
||||
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=final_total,
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
|
||||
if converted_amount is not None:
|
||||
currency_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
result["items"][item_key]["year_totals"][year]["currencies"][currency_id] = (
|
||||
currency_data
|
||||
)
|
||||
|
||||
# Accumulate item total (across all years for this item)
|
||||
if currency_id not in result["items"][item_key]["total"]["currencies"]:
|
||||
result["items"][item_key]["total"]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["items"][item_key]["total"]["currencies"][currency_id][
|
||||
"final_total"
|
||||
] += final_total
|
||||
|
||||
# Accumulate year total (across all items for this year)
|
||||
if year not in result["year_totals"]:
|
||||
result["year_totals"][year] = {"currencies": {}}
|
||||
if currency_id not in result["year_totals"][year]["currencies"]:
|
||||
result["year_totals"][year]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["year_totals"][year]["currencies"][currency_id]["final_total"] += (
|
||||
final_total
|
||||
)
|
||||
|
||||
# Accumulate grand total
|
||||
if currency_id not in result["grand_total"]["currencies"]:
|
||||
result["grand_total"]["currencies"][currency_id] = {
|
||||
"currency": currency_data["currency"].copy(),
|
||||
"final_total": Decimal("0"),
|
||||
}
|
||||
result["grand_total"]["currencies"][currency_id]["final_total"] += final_total
|
||||
|
||||
# Add currency conversion for item totals
|
||||
for item_key, item_data in result["items"].items():
|
||||
for currency_id, total_data in item_data["total"]["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
# Add currency conversion for year totals
|
||||
for year, year_data in result["year_totals"].items():
|
||||
for currency_id, total_data in year_data["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
# Add currency conversion for grand total
|
||||
for currency_id, total_data in result["grand_total"]["currencies"].items():
|
||||
if currency_info[currency_id]["exchange_currency_id"]:
|
||||
from_currency = Currency.objects.get(id=currency_id)
|
||||
exchange_currency = Currency.objects.get(
|
||||
id=currency_info[currency_id]["exchange_currency_id"]
|
||||
)
|
||||
converted_amount, prefix, suffix, decimal_places = convert(
|
||||
amount=total_data["final_total"],
|
||||
from_currency=from_currency,
|
||||
to_currency=exchange_currency,
|
||||
)
|
||||
if converted_amount is not None:
|
||||
total_data["exchanged"] = {
|
||||
"final_total": converted_amount,
|
||||
"currency": {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"decimal_places": decimal_places,
|
||||
"code": exchange_currency.code,
|
||||
"name": exchange_currency.name,
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
@@ -26,6 +26,8 @@ from apps.insights.utils.sankey import (
|
||||
generate_sankey_data_by_currency,
|
||||
)
|
||||
from apps.insights.utils.transactions import get_transactions
|
||||
from apps.insights.utils.year_by_year import get_year_by_year_data
|
||||
from apps.insights.utils.month_by_month import get_month_by_month_data
|
||||
from apps.transactions.models import TransactionCategory, Transaction
|
||||
from apps.transactions.utils.calculations import calculate_currency_totals
|
||||
|
||||
@@ -74,7 +76,9 @@ def index(request):
|
||||
def sankey_by_account(request):
|
||||
# Get filtered transactions
|
||||
|
||||
transactions = get_transactions(request)
|
||||
transactions = get_transactions(
|
||||
request, include_untracked_accounts=True, include_silent=True
|
||||
)
|
||||
|
||||
# Generate Sankey data
|
||||
sankey_data = generate_sankey_data_by_account(transactions)
|
||||
@@ -91,7 +95,9 @@ def sankey_by_account(request):
|
||||
@require_http_methods(["GET"])
|
||||
def sankey_by_currency(request):
|
||||
# Get filtered transactions
|
||||
transactions = get_transactions(request)
|
||||
transactions = get_transactions(
|
||||
request, include_silent=True, include_untracked_accounts=True
|
||||
)
|
||||
|
||||
# Generate Sankey data
|
||||
sankey_data = generate_sankey_data_by_currency(transactions)
|
||||
@@ -180,6 +186,14 @@ def category_overview(request):
|
||||
else:
|
||||
show_tags = request.session.get("insights_category_explorer_show_tags", True)
|
||||
|
||||
if "show_entities" in request.GET:
|
||||
show_entities = request.GET["show_entities"] == "on"
|
||||
request.session["insights_category_explorer_show_entities"] = show_entities
|
||||
else:
|
||||
show_entities = request.session.get(
|
||||
"insights_category_explorer_show_entities", False
|
||||
)
|
||||
|
||||
if "showing" in request.GET:
|
||||
showing = request.GET["showing"]
|
||||
request.session["insights_category_explorer_showing"] = showing
|
||||
@@ -190,7 +204,9 @@ def category_overview(request):
|
||||
transactions = get_transactions(request, include_silent=True)
|
||||
|
||||
total_table = get_categories_totals(
|
||||
transactions_queryset=transactions, ignore_empty=False
|
||||
transactions_queryset=transactions,
|
||||
ignore_empty=False,
|
||||
show_entities=show_entities,
|
||||
)
|
||||
|
||||
return render(
|
||||
@@ -200,6 +216,7 @@ def category_overview(request):
|
||||
"total_table": total_table,
|
||||
"view_type": view_type,
|
||||
"show_tags": show_tags,
|
||||
"show_entities": show_entities,
|
||||
"showing": showing,
|
||||
},
|
||||
)
|
||||
@@ -239,10 +256,14 @@ def late_transactions(request):
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def emergency_fund(request):
|
||||
transactions_currency_queryset = Transaction.objects.filter(
|
||||
is_paid=True, account__is_archived=False, account__is_asset=False
|
||||
).order_by(
|
||||
"account__currency__name",
|
||||
transactions_currency_queryset = (
|
||||
Transaction.objects.filter(
|
||||
is_paid=True, account__is_archived=False, account__is_asset=False
|
||||
)
|
||||
.exclude(account__in=request.user.untracked_accounts.all())
|
||||
.order_by(
|
||||
"account__currency__name",
|
||||
)
|
||||
)
|
||||
currency_net_worth = calculate_currency_totals(
|
||||
transactions_queryset=transactions_currency_queryset, ignore_empty=False
|
||||
@@ -260,7 +281,9 @@ def emergency_fund(request):
|
||||
reference_date__gte=start_date,
|
||||
reference_date__lte=end_date,
|
||||
category__mute=False,
|
||||
mute=False,
|
||||
)
|
||||
.exclude(account__in=request.user.untracked_accounts.all())
|
||||
.values("reference_date", "account__currency")
|
||||
.annotate(monthly_total=Sum("amount"))
|
||||
)
|
||||
@@ -285,3 +308,71 @@ def emergency_fund(request):
|
||||
"insights/fragments/emergency_fund.html",
|
||||
{"data": currency_net_worth},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def year_by_year(request):
|
||||
if "group_by" in request.GET:
|
||||
group_by = request.GET["group_by"]
|
||||
request.session["insights_year_by_year_group_by"] = group_by
|
||||
else:
|
||||
group_by = request.session.get("insights_year_by_year_group_by", "categories")
|
||||
|
||||
# Validate group_by value
|
||||
if group_by not in ("categories", "tags", "entities"):
|
||||
group_by = "categories"
|
||||
|
||||
data = get_year_by_year_data(group_by=group_by)
|
||||
|
||||
return render(
|
||||
request,
|
||||
"insights/fragments/year_by_year.html",
|
||||
{
|
||||
"data": data,
|
||||
"group_by": group_by,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@only_htmx
|
||||
@login_required
|
||||
@require_http_methods(["GET"])
|
||||
def month_by_month(request):
|
||||
# Handle year selection
|
||||
if "year" in request.GET:
|
||||
try:
|
||||
year = int(request.GET["year"])
|
||||
request.session["insights_month_by_month_year"] = year
|
||||
except (ValueError, TypeError):
|
||||
year = request.session.get(
|
||||
"insights_month_by_month_year", timezone.localdate(timezone.now()).year
|
||||
)
|
||||
else:
|
||||
year = request.session.get(
|
||||
"insights_month_by_month_year", timezone.localdate(timezone.now()).year
|
||||
)
|
||||
|
||||
# Handle group_by selection
|
||||
if "group_by" in request.GET:
|
||||
group_by = request.GET["group_by"]
|
||||
request.session["insights_month_by_month_group_by"] = group_by
|
||||
else:
|
||||
group_by = request.session.get("insights_month_by_month_group_by", "categories")
|
||||
|
||||
# Validate group_by value
|
||||
if group_by not in ("categories", "tags", "entities"):
|
||||
group_by = "categories"
|
||||
|
||||
data = get_month_by_month_data(year=year, group_by=group_by)
|
||||
|
||||
return render(
|
||||
request,
|
||||
"insights/fragments/month_by_month.html",
|
||||
{
|
||||
"data": data,
|
||||
"group_by": group_by,
|
||||
"selected_year": year,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,165 +1,3 @@
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth.models import User
|
||||
from django.utils import timezone
|
||||
from unittest.mock import patch
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
from django.test import Client # Added
|
||||
from django.urls import reverse # Added
|
||||
|
||||
from apps.currencies.models import Currency, ExchangeRate
|
||||
from apps.mini_tools.utils.exchange_rate_map import get_currency_exchange_map
|
||||
|
||||
class MiniToolsUtilsTests(TestCase):
|
||||
def setUp(self):
|
||||
# User is not strictly necessary for this utility but good practice for test setup
|
||||
self.user = User.objects.create_user(username='testuser', password='password')
|
||||
|
||||
self.usd = Currency.objects.create(name="US Dollar", code="USD", decimal_places=2, prefix="$")
|
||||
self.eur = Currency.objects.create(name="Euro", code="EUR", decimal_places=2, prefix="€")
|
||||
self.gbp = Currency.objects.create(name="British Pound", code="GBP", decimal_places=2, prefix="£")
|
||||
|
||||
# USD -> EUR rates
|
||||
# Rate for 2023-01-10 (will be processed last for USD->EUR due to ordering)
|
||||
ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.90"), date=date(2023, 1, 10))
|
||||
# Rate for 2023-01-15 (closer to target_date 2023-01-16, processed first for USD->EUR)
|
||||
ExchangeRate.objects.create(from_currency=self.usd, to_currency=self.eur, rate=Decimal("0.92"), date=date(2023, 1, 15))
|
||||
|
||||
# GBP -> USD rate
|
||||
self.gbp_usd_rate = ExchangeRate.objects.create(from_currency=self.gbp, to_currency=self.usd, rate=Decimal("1.25"), date=date(2023, 1, 12))
|
||||
|
||||
def test_get_currency_exchange_map_structure_and_rates(self):
|
||||
target_date = date(2023, 1, 16)
|
||||
rate_map = get_currency_exchange_map(date=target_date)
|
||||
|
||||
# Assert USD in map
|
||||
self.assertIn("US Dollar", rate_map)
|
||||
usd_data = rate_map["US Dollar"]
|
||||
self.assertEqual(usd_data["decimal_places"], 2)
|
||||
self.assertEqual(usd_data["prefix"], "$")
|
||||
self.assertIn("rates", usd_data)
|
||||
|
||||
# USD -> EUR: Expecting rate from 2023-01-10 (0.90)
|
||||
# Query order: (USD,EUR,2023-01-15), (USD,EUR,2023-01-10)
|
||||
# Loop overwrite means the last one processed (0.90) sticks.
|
||||
self.assertIn("Euro", usd_data["rates"])
|
||||
self.assertEqual(usd_data["rates"]["Euro"]["rate"], Decimal("0.90"))
|
||||
|
||||
# USD -> GBP: Inverse of GBP->USD rate from 2023-01-12 (1.25)
|
||||
# Query for GBP->USD, date 2023-01-12, diff 4 days.
|
||||
self.assertIn("British Pound", usd_data["rates"])
|
||||
self.assertEqual(usd_data["rates"]["British Pound"]["rate"], Decimal("1") / self.gbp_usd_rate.rate)
|
||||
|
||||
# Assert EUR in map
|
||||
self.assertIn("Euro", rate_map)
|
||||
eur_data = rate_map["Euro"]
|
||||
self.assertEqual(eur_data["decimal_places"], 2)
|
||||
self.assertEqual(eur_data["prefix"], "€")
|
||||
self.assertIn("rates", eur_data)
|
||||
|
||||
# EUR -> USD: Inverse of USD->EUR rate from 2023-01-10 (0.90)
|
||||
self.assertIn("US Dollar", eur_data["rates"])
|
||||
self.assertEqual(eur_data["rates"]["US Dollar"]["rate"], Decimal("1") / Decimal("0.90"))
|
||||
|
||||
# Assert GBP in map
|
||||
self.assertIn("British Pound", rate_map)
|
||||
gbp_data = rate_map["British Pound"]
|
||||
self.assertEqual(gbp_data["decimal_places"], 2)
|
||||
self.assertEqual(gbp_data["prefix"], "£")
|
||||
self.assertIn("rates", gbp_data)
|
||||
|
||||
# GBP -> USD: Direct rate from 2023-01-12 (1.25)
|
||||
self.assertIn("US Dollar", gbp_data["rates"])
|
||||
self.assertEqual(gbp_data["rates"]["US Dollar"]["rate"], self.gbp_usd_rate.rate)
|
||||
|
||||
@patch('apps.mini_tools.utils.exchange_rate_map.timezone')
|
||||
def test_get_currency_exchange_map_uses_today_if_no_date(self, mock_django_timezone):
|
||||
# Mock timezone.localtime().date() to return a specific date
|
||||
mock_today = date(2023, 1, 16)
|
||||
mock_django_timezone.localtime.return_value.date.return_value = mock_today
|
||||
|
||||
rate_map = get_currency_exchange_map() # No date argument, should use mocked "today"
|
||||
|
||||
# Re-assert one key rate to confirm the mocked date was used.
|
||||
# Based on test_get_currency_exchange_map_structure_and_rates, with target_date 2023-01-16,
|
||||
# USD -> EUR should be 0.90.
|
||||
self.assertIn("US Dollar", rate_map)
|
||||
self.assertIn("Euro", rate_map["US Dollar"]["rates"])
|
||||
self.assertEqual(rate_map["US Dollar"]["rates"]["Euro"]["rate"], Decimal("0.90"))
|
||||
|
||||
# Verify that timezone.localtime().date() was called
|
||||
mock_django_timezone.localtime.return_value.date.assert_called_once()
|
||||
|
||||
|
||||
class MiniToolsViewTests(TestCase):
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(username='viewtestuser', password='password')
|
||||
self.client = Client()
|
||||
self.client.login(username='viewtestuser', password='password')
|
||||
|
||||
self.usd = Currency.objects.create(name="US Dollar Test", code="USDTEST", decimal_places=2, prefix="$T ")
|
||||
self.eur = Currency.objects.create(name="Euro Test", code="EURTEST", decimal_places=2, prefix="€T ")
|
||||
|
||||
@patch('apps.mini_tools.views.convert')
|
||||
def test_currency_converter_convert_view_successful(self, mock_convert):
|
||||
mock_convert.return_value = (Decimal("85.00"), "€T ", "", 2) # prefix, suffix, dp
|
||||
|
||||
get_params = {
|
||||
'from_value': "100",
|
||||
'from_currency': self.usd.id,
|
||||
'to_currency': self.eur.id
|
||||
}
|
||||
response = self.client.get(reverse('mini_tools:currency_converter_convert'), data=get_params)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
mock_convert.assert_called_once()
|
||||
args, kwargs = mock_convert.call_args
|
||||
|
||||
# The view calls: convert(amount=amount_decimal, from_currency=from_currency_obj, to_currency=to_currency_obj)
|
||||
# So, these are keyword arguments.
|
||||
self.assertEqual(kwargs['amount'], Decimal('100'))
|
||||
self.assertEqual(kwargs['from_currency'], self.usd)
|
||||
self.assertEqual(kwargs['to_currency'], self.eur)
|
||||
|
||||
self.assertEqual(response.context['converted_amount'], Decimal("85.00"))
|
||||
self.assertEqual(response.context['prefix'], "€T ")
|
||||
self.assertEqual(response.context['suffix'], "")
|
||||
self.assertEqual(response.context['decimal_places'], 2)
|
||||
self.assertEqual(response.context['from_value'], "100") # Check original value passed through
|
||||
self.assertEqual(response.context['from_currency_selected'], str(self.usd.id))
|
||||
self.assertEqual(response.context['to_currency_selected'], str(self.eur.id))
|
||||
|
||||
|
||||
@patch('apps.mini_tools.views.convert')
|
||||
def test_currency_converter_convert_view_missing_params(self, mock_convert):
|
||||
get_params = {
|
||||
'from_value': "100",
|
||||
'from_currency': self.usd.id
|
||||
# 'to_currency' is missing
|
||||
}
|
||||
response = self.client.get(reverse('mini_tools:currency_converter_convert'), data=get_params)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
mock_convert.assert_not_called()
|
||||
self.assertIsNone(response.context.get('converted_amount')) # Use .get() for safety if key might be absent
|
||||
self.assertEqual(response.context['from_value'], "100")
|
||||
self.assertEqual(response.context['from_currency_selected'], str(self.usd.id))
|
||||
self.assertIsNone(response.context.get('to_currency_selected'))
|
||||
|
||||
|
||||
@patch('apps.mini_tools.views.convert')
|
||||
def test_currency_converter_convert_view_invalid_currency_id(self, mock_convert):
|
||||
get_params = {
|
||||
'from_value': "100",
|
||||
'from_currency': self.usd.id,
|
||||
'to_currency': 999 # Non-existent currency ID
|
||||
}
|
||||
response = self.client.get(reverse('mini_tools:currency_converter_convert'), data=get_params)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
mock_convert.assert_not_called()
|
||||
self.assertIsNone(response.context.get('converted_amount'))
|
||||
self.assertEqual(response.context['from_value'], "100")
|
||||
self.assertEqual(response.context['from_currency_selected'], str(self.usd.id))
|
||||
self.assertEqual(response.context['to_currency_selected'], '999') # View passes invalid ID to context
|
||||
# Create your tests here.
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
from django.test import TestCase, Client
|
||||
from django.contrib.auth.models import User
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone # Though specific dates are used, good for general test setup
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import TransactionCategory, TransactionTag, Transaction
|
||||
|
||||
class MonthlyOverviewViewTests(TestCase): # Renamed from MonthlyOverviewTestCase
|
||||
def setUp(self):
|
||||
self.user = User.objects.create_user(username='testmonthlyuser', password='password')
|
||||
self.client = Client()
|
||||
self.client.login(username='testmonthlyuser', password='password')
|
||||
|
||||
self.currency_usd = Currency.objects.create(name="MO USD", code="MOUSD", decimal_places=2, prefix="$MO ")
|
||||
self.account_group = AccountGroup.objects.create(name="MO Group", owner=self.user)
|
||||
self.account_usd1 = Account.objects.create(
|
||||
name="MO Account USD 1",
|
||||
currency=self.currency_usd,
|
||||
owner=self.user,
|
||||
group=self.account_group
|
||||
)
|
||||
self.category_food = TransactionCategory.objects.create(
|
||||
name="MO Food",
|
||||
owner=self.user,
|
||||
type=TransactionCategory.TransactionType.EXPENSE
|
||||
)
|
||||
self.category_salary = TransactionCategory.objects.create(
|
||||
name="MO Salary",
|
||||
owner=self.user,
|
||||
type=TransactionCategory.TransactionType.INCOME
|
||||
)
|
||||
self.tag_urgent = TransactionTag.objects.create(name="Urgent", owner=self.user)
|
||||
|
||||
# Transactions for March 2023
|
||||
self.t_food1 = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_food,
|
||||
date=date(2023, 3, 5), amount=Decimal("50.00"),
|
||||
type=Transaction.Type.EXPENSE, description="Groceries March", is_paid=True
|
||||
)
|
||||
self.t_food1.tags.add(self.tag_urgent)
|
||||
|
||||
self.t_food2 = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_food,
|
||||
date=date(2023, 3, 10), amount=Decimal("25.00"),
|
||||
type=Transaction.Type.EXPENSE, description="Lunch March", is_paid=True
|
||||
)
|
||||
self.t_salary1 = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_salary,
|
||||
date=date(2023, 3, 1), amount=Decimal("1000.00"),
|
||||
type=Transaction.Type.INCOME, description="March Salary", is_paid=True
|
||||
)
|
||||
# Transaction for April 2023
|
||||
self.t_april_food = Transaction.objects.create(
|
||||
owner=self.user, account=self.account_usd1, category=self.category_food,
|
||||
date=date(2023, 4, 5), amount=Decimal("30.00"),
|
||||
type=Transaction.Type.EXPENSE, description="April Groceries", is_paid=True
|
||||
)
|
||||
# URL for the main overview page for March 2023, used in the adapted test
|
||||
self.url_main_overview_march = reverse('monthly_overview:monthly_overview', kwargs={'month': 3, 'year': 2023})
|
||||
|
||||
|
||||
def test_transactions_list_no_filters(self):
|
||||
url = reverse('monthly_overview:monthly_transactions_list', kwargs={'month': 3, 'year': 2023})
|
||||
response = self.client.get(url, HTTP_HX_REQUEST='true')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
context_txns = response.context['transactions']
|
||||
self.assertIn(self.t_food1, context_txns)
|
||||
self.assertIn(self.t_food2, context_txns)
|
||||
self.assertIn(self.t_salary1, context_txns)
|
||||
self.assertNotIn(self.t_april_food, context_txns)
|
||||
self.assertEqual(len(context_txns), 3)
|
||||
|
||||
def test_transactions_list_filter_by_description(self):
|
||||
url = reverse('monthly_overview:monthly_transactions_list', kwargs={'month': 3, 'year': 2023})
|
||||
response = self.client.get(url + "?description=Groceries", HTTP_HX_REQUEST='true') # Filter for "Groceries March"
|
||||
self.assertEqual(response.status_code, 200)
|
||||
context_txns = response.context['transactions']
|
||||
self.assertIn(self.t_food1, context_txns)
|
||||
self.assertNotIn(self.t_food2, context_txns)
|
||||
self.assertNotIn(self.t_salary1, context_txns)
|
||||
self.assertEqual(len(context_txns), 1)
|
||||
|
||||
def test_transactions_list_filter_by_type_income(self):
|
||||
url = reverse('monthly_overview:monthly_transactions_list', kwargs={'month': 3, 'year': 2023})
|
||||
response = self.client.get(url + "?type=IN", HTTP_HX_REQUEST='true')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
context_txns = response.context['transactions']
|
||||
self.assertIn(self.t_salary1, context_txns)
|
||||
self.assertEqual(len(context_txns), 1)
|
||||
|
||||
def test_transactions_list_filter_by_tag(self):
|
||||
url = reverse('monthly_overview:monthly_transactions_list', kwargs={'month': 3, 'year': 2023})
|
||||
response = self.client.get(url + f"?tags={self.tag_urgent.name}", HTTP_HX_REQUEST='true')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
context_txns = response.context['transactions']
|
||||
self.assertIn(self.t_food1, context_txns)
|
||||
self.assertEqual(len(context_txns), 1)
|
||||
|
||||
def test_transactions_list_filter_by_category(self):
|
||||
url = reverse('monthly_overview:monthly_transactions_list', kwargs={'month': 3, 'year': 2023})
|
||||
response = self.client.get(url + f"?category={self.category_food.name}", HTTP_HX_REQUEST='true')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
context_txns = response.context['transactions']
|
||||
self.assertIn(self.t_food1, context_txns)
|
||||
self.assertIn(self.t_food2, context_txns)
|
||||
self.assertEqual(len(context_txns), 2)
|
||||
|
||||
def test_transactions_list_ordering_amount_desc(self):
|
||||
url = reverse('monthly_overview:monthly_transactions_list', kwargs={'month': 3, 'year': 2023})
|
||||
response = self.client.get(url + "?order=-amount", HTTP_HX_REQUEST='true')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
context_txns = list(response.context['transactions'])
|
||||
self.assertEqual(context_txns[0], self.t_salary1) # Amount 1000 (INCOME)
|
||||
self.assertEqual(context_txns[1], self.t_food1) # Amount 50 (EXPENSE)
|
||||
self.assertEqual(context_txns[2], self.t_food2) # Amount 25 (EXPENSE)
|
||||
|
||||
def test_monthly_overview_main_view_authenticated_user(self):
|
||||
# This test checks general access and basic context for the main monthly overview page.
|
||||
response = self.client.get(self.url_main_overview_march)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn('current_month_date', response.context)
|
||||
self.assertEqual(response.context['current_month_date'], date(2023,3,1))
|
||||
# Check for other expected context variables if necessary for this main view.
|
||||
# For example, if it also lists transactions or summaries directly in its initial context.
|
||||
self.assertIn('transactions_by_day', response.context) # Assuming this is part of the main view context as well
|
||||
self.assertIn('total_income_current_month', response.context)
|
||||
self.assertIn('total_expenses_current_month', response.context)
|
||||
331
app/apps/monthly_overview/tests/test_summary.py
Normal file
331
app/apps/monthly_overview/tests/test_summary.py
Normal file
@@ -0,0 +1,331 @@
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase, override_settings
|
||||
|
||||
from apps.accounts.models import Account, AccountGroup
|
||||
from apps.currencies.models import Currency
|
||||
from apps.transactions.models import (
|
||||
Transaction,
|
||||
TransactionCategory,
|
||||
TransactionTag,
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
STORAGES={
|
||||
"default": {"BACKEND": "django.core.files.storage.FileSystemStorage"},
|
||||
"staticfiles": {
|
||||
"BACKEND": "django.contrib.staticfiles.storage.StaticFilesStorage"
|
||||
},
|
||||
},
|
||||
WHITENOISE_AUTOREFRESH=True,
|
||||
)
|
||||
class MonthlySummaryFilterBehaviorTests(TestCase):
|
||||
"""Tests for monthly summary views filter behavior.
|
||||
|
||||
These tests verify that:
|
||||
1. Views work correctly without any filters
|
||||
2. Views work correctly with filters applied
|
||||
3. The filter detection logic properly uses different querysets
|
||||
4. Calculated values reflect the applied filters
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data"""
|
||||
User = get_user_model()
|
||||
self.user = User.objects.create_user(
|
||||
email="testuser@test.com", password="testpass123"
|
||||
)
|
||||
self.client.login(username="testuser@test.com", password="testpass123")
|
||||
|
||||
self.currency = Currency.objects.create(
|
||||
code="USD", name="US Dollar", decimal_places=2, prefix="$ "
|
||||
)
|
||||
self.account_group = AccountGroup.objects.create(name="Test Group")
|
||||
self.account = Account.objects.create(
|
||||
name="Test Account",
|
||||
group=self.account_group,
|
||||
currency=self.currency,
|
||||
is_asset=False,
|
||||
)
|
||||
self.category = TransactionCategory.objects.create(
|
||||
name="Test Category", owner=self.user
|
||||
)
|
||||
self.tag = TransactionTag.objects.create(name="TestTag", owner=self.user)
|
||||
|
||||
# Create test transactions for December 2025
|
||||
# Income: 1000 (paid)
|
||||
self.income_transaction = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.INCOME,
|
||||
is_paid=True,
|
||||
date=date(2025, 12, 10),
|
||||
reference_date=date(2025, 12, 1),
|
||||
amount=Decimal("1000.00"),
|
||||
description="December Income",
|
||||
owner=self.user,
|
||||
)
|
||||
|
||||
# Expense: 200 (paid)
|
||||
self.expense_transaction = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
is_paid=True,
|
||||
date=date(2025, 12, 15),
|
||||
reference_date=date(2025, 12, 1),
|
||||
amount=Decimal("200.00"),
|
||||
description="December Expense",
|
||||
category=self.category,
|
||||
owner=self.user,
|
||||
)
|
||||
self.expense_transaction.tags.add(self.tag)
|
||||
|
||||
# Expense: 150 (projected/unpaid)
|
||||
self.projected_expense = Transaction.objects.create(
|
||||
account=self.account,
|
||||
type=Transaction.Type.EXPENSE,
|
||||
is_paid=False,
|
||||
date=date(2025, 12, 20),
|
||||
reference_date=date(2025, 12, 1),
|
||||
amount=Decimal("150.00"),
|
||||
description="Projected Expense",
|
||||
owner=self.user,
|
||||
)
|
||||
|
||||
def _get_currency_data(self, context_dict):
|
||||
"""Helper to extract data for our test currency from context dict.
|
||||
|
||||
The context dict is keyed by currency ID, so we need to find
|
||||
the entry for our currency.
|
||||
"""
|
||||
if not context_dict:
|
||||
return None
|
||||
for currency_id, data in context_dict.items():
|
||||
if data.get("currency", {}).get("code") == "USD":
|
||||
return data
|
||||
return None
|
||||
|
||||
# --- monthly_summary view tests ---
|
||||
|
||||
def test_monthly_summary_no_filter_returns_200(self):
|
||||
"""Test that monthly_summary returns 200 without filters"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_monthly_summary_no_filter_includes_all_transactions(self):
|
||||
"""Without filters, summary should include all transactions"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should have the income: 1000
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["income_current"], Decimal("1000.00"))
|
||||
|
||||
# expense_current should have paid expense: 200
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_current"], Decimal("200.00"))
|
||||
|
||||
# expense_projected should have unpaid expense: 150
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_projected"], Decimal("150.00"))
|
||||
|
||||
def test_monthly_summary_type_filter_only_income(self):
|
||||
"""With type=IN filter, summary should only include income"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?type=IN",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should still have 1000
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["income_current"], Decimal("1000.00"))
|
||||
|
||||
# expense_current should be empty/zero (filtered out)
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_current", 0), Decimal("0"))
|
||||
|
||||
# expense_projected should be empty/zero (filtered out)
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_projected", 0), Decimal("0"))
|
||||
|
||||
def test_monthly_summary_type_filter_only_expenses(self):
|
||||
"""With type=EX filter, summary should only include expenses"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?type=EX",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should be empty/zero (filtered out)
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("income_current", 0), Decimal("0"))
|
||||
|
||||
# expense_current should have 200
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_current"], Decimal("200.00"))
|
||||
|
||||
# expense_projected should have 150
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_projected"], Decimal("150.00"))
|
||||
|
||||
def test_monthly_summary_is_paid_filter_only_paid(self):
|
||||
"""With is_paid=1 filter, summary should only include paid transactions"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?is_paid=1",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should have 1000 (paid)
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["income_current"], Decimal("1000.00"))
|
||||
|
||||
# expense_current should have 200 (paid)
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_current"], Decimal("200.00"))
|
||||
|
||||
# expense_projected should be empty/zero (filtered out - unpaid)
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_projected", 0), Decimal("0"))
|
||||
|
||||
def test_monthly_summary_is_paid_filter_only_unpaid(self):
|
||||
"""With is_paid=0 filter, summary should only include unpaid transactions"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?is_paid=0",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# income_current should be empty/zero (filtered out - paid)
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("income_current", 0), Decimal("0"))
|
||||
|
||||
# expense_current should be empty/zero (filtered out - paid)
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_current", 0), Decimal("0"))
|
||||
|
||||
# expense_projected should have 150 (unpaid)
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_projected"], Decimal("150.00"))
|
||||
|
||||
def test_monthly_summary_description_filter(self):
|
||||
"""With description filter, summary should only include matching transactions"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?description=Income",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# Only income matches "Income" description
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["income_current"], Decimal("1000.00"))
|
||||
|
||||
# Expenses should be filtered out
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("expense_current", 0), Decimal("0"))
|
||||
|
||||
def test_monthly_summary_amount_filter(self):
|
||||
"""With amount filter, summary should only include transactions in range"""
|
||||
# Filter to only get transactions between 100 and 250 (should get 200 and 150)
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/?from_amount=100&to_amount=250",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
context = response.context
|
||||
|
||||
# Income (1000) should be filtered out
|
||||
income_current = context.get("income_current", {})
|
||||
usd_data = self._get_currency_data(income_current)
|
||||
if usd_data:
|
||||
self.assertEqual(usd_data.get("income_current", 0), Decimal("0"))
|
||||
|
||||
# expense_current should have 200
|
||||
expense_current = context.get("expense_current", {})
|
||||
usd_data = self._get_currency_data(expense_current)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_current"], Decimal("200.00"))
|
||||
|
||||
# expense_projected should have 150
|
||||
expense_projected = context.get("expense_projected", {})
|
||||
usd_data = self._get_currency_data(expense_projected)
|
||||
self.assertIsNotNone(usd_data)
|
||||
self.assertEqual(usd_data["expense_projected"], Decimal("150.00"))
|
||||
|
||||
# --- monthly_account_summary view tests ---
|
||||
|
||||
def test_monthly_account_summary_no_filter_returns_200(self):
|
||||
"""Test that monthly_account_summary returns 200 without filters"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/accounts/",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_monthly_account_summary_with_filter_returns_200(self):
|
||||
"""Test that monthly_account_summary returns 200 with filter"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/accounts/?type=IN",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# --- monthly_currency_summary view tests ---
|
||||
|
||||
def test_monthly_currency_summary_no_filter_returns_200(self):
|
||||
"""Test that monthly_currency_summary returns 200 without filters"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/currencies/",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_monthly_currency_summary_with_filter_returns_200(self):
|
||||
"""Test that monthly_currency_summary returns 200 with filter"""
|
||||
response = self.client.get(
|
||||
"/monthly/12/2025/summary/currencies/?type=EX",
|
||||
HTTP_HX_REQUEST="true",
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user