Skip to content

Commit aea3273

Browse files
committed
fix: refactor validate_dates function (avoid writing duplicate code)
1 parent d7281c2 commit aea3273

File tree

4 files changed

+24
-63
lines changed

4 files changed

+24
-63
lines changed

app/dependencies.py

+18-15
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from app.db import Session, crud, get_db
55
from config import SUDOERS
66
from fastapi import Depends, HTTPException
7-
from datetime import datetime, timezone
7+
from datetime import datetime, timezone, timedelta
88
from app.utils.jwt import get_subscription_payload
99

1010

@@ -36,20 +36,23 @@ def get_dbnode(node_id: int, db: Session = Depends(get_db)):
3636
return dbnode
3737

3838

39-
def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[str, datetime]]) -> bool:
39+
def validate_dates(start: Optional[Union[str, datetime]], end: Optional[Union[str, datetime]]) -> (datetime, datetime):
4040
"""Validate if start and end dates are correct and if end is after start."""
4141
try:
4242
if start:
43-
start_date = start if isinstance(start, datetime) else datetime.fromisoformat(start)
43+
start_date = start if isinstance(start, datetime) else datetime.fromisoformat(start).astimezone(timezone.utc)
4444
else:
45-
start_date = None
45+
start_date = datetime.now(timezone.utc) - timedelta(days=30)
4646
if end:
47-
end_date = end if isinstance(end, datetime) else datetime.fromisoformat(end)
47+
end_date = end if isinstance(end, datetime) else datetime.fromisoformat(end).astimezone(timezone.utc)
4848
if start_date and end_date < start_date:
49-
return False
50-
return True
49+
raise HTTPException(status_code=400, detail="Start date must be before end date")
50+
else:
51+
end_date = datetime.now(timezone.utc)
52+
53+
return start_date, end_date
5154
except ValueError:
52-
return False
55+
raise HTTPException(status_code=400, detail="Invalid date range or format")
5356

5457

5558
def get_user_template(template_id: int, db: Session = Depends(get_db)):
@@ -61,8 +64,8 @@ def get_user_template(template_id: int, db: Session = Depends(get_db)):
6164

6265

6366
def get_validated_sub(
64-
token: str,
65-
db: Session = Depends(get_db)
67+
token: str,
68+
db: Session = Depends(get_db)
6669
) -> UserResponse:
6770
sub = get_subscription_payload(token)
6871
if not sub:
@@ -79,9 +82,9 @@ def get_validated_sub(
7982

8083

8184
def get_validated_user(
82-
username: str,
83-
admin: Admin = Depends(Admin.get_current),
84-
db: Session = Depends(get_db)
85+
username: str,
86+
admin: Admin = Depends(Admin.get_current),
87+
db: Session = Depends(get_db)
8588
) -> UserResponse:
8689
dbuser = crud.get_user(db, username)
8790
if not dbuser:
@@ -93,8 +96,8 @@ def get_validated_user(
9396
return dbuser
9497

9598

96-
def get_expired_users_list(db: Session, admin: Admin, expired_after: Optional[datetime] = None, expired_before: Optional[datetime] = None):
97-
99+
def get_expired_users_list(db: Session, admin: Admin, expired_after: Optional[datetime] = None,
100+
expired_before: Optional[datetime] = None):
98101
expired_before = expired_before or datetime.now(timezone.utc)
99102
expired_after = expired_after or datetime.min.replace(tzinfo=timezone.utc)
100103

app/routers/node.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -198,17 +198,7 @@ def get_usage(
198198
_: Admin = Depends(Admin.check_sudo_admin)
199199
):
200200
"""Retrieve usage statistics for nodes within a specified date range."""
201-
if not validate_dates(start, end):
202-
raise HTTPException(status_code=400, detail="Invalid date range or format")
203-
204-
if not start:
205-
start = datetime.now(timezone.utc) - timedelta(days=30)
206-
else:
207-
start = datetime.fromisoformat(start).astimezone(timezone.utc)
208-
if not end:
209-
end = datetime.now(timezone.utc)
210-
else:
211-
end = datetime.fromisoformat(end).astimezone(timezone.utc)
201+
start, end = validate_dates(start, end)
212202

213203
usages = crud.get_nodes_usage(db, start, end)
214204

app/routers/subscription.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -144,17 +144,7 @@ def user_get_usage(
144144
db: Session = Depends(get_db)
145145
):
146146
"""Fetches the usage statistics for the user within a specified date range."""
147-
if not validate_dates(start, end):
148-
raise HTTPException(status_code=400, detail="Invalid date range or format")
149-
150-
if not start:
151-
start = datetime.now(timezone.utc) - timedelta(days=30)
152-
else:
153-
start = datetime.fromisoformat(start).astimezone(timezone.utc)
154-
if not end:
155-
end = datetime.now(timezone.utc)
156-
else:
157-
end = datetime.fromisoformat(end).astimezone(timezone.utc)
147+
start, end = validate_dates(start, end)
158148

159149
usages = crud.get_user_usages(db, dbuser, start, end)
160150

app/routers/user.py

+4-26
Original file line numberDiff line numberDiff line change
@@ -265,17 +265,7 @@ def get_user_usage(
265265
db: Session = Depends(get_db)
266266
):
267267
"""Get users usage"""
268-
if not validate_dates(start, end):
269-
raise HTTPException(status_code=400, detail="Invalid date range or format")
270-
271-
if not start:
272-
start = datetime.now(timezone.utc) - timedelta(days=30)
273-
else:
274-
start = datetime.fromisoformat(start).astimezone(timezone.utc)
275-
if not end:
276-
end = datetime.now(timezone.utc)
277-
else:
278-
end = datetime.fromisoformat(end).astimezone(timezone.utc)
268+
start, end = validate_dates(start, end)
279269

280270
usages = crud.get_user_usages(db, dbuser, start, end)
281271

@@ -291,17 +281,7 @@ def get_users_usage(
291281
admin: Admin = Depends(Admin.get_current)
292282
):
293283
"""Get all users usage"""
294-
if not validate_dates(start, end):
295-
raise HTTPException(status_code=400, detail="Invalid date range or format")
296-
297-
if not start:
298-
start = datetime.now(timezone.utc) - timedelta(days=30)
299-
else:
300-
start = datetime.fromisoformat(start).astimezone(timezone.utc)
301-
if not end:
302-
end = datetime.now(timezone.utc)
303-
else:
304-
end = datetime.fromisoformat(end).astimezone(timezone.utc)
284+
start, end = validate_dates(start, end)
305285

306286
usages = crud.get_all_users_usages(
307287
db=db,
@@ -350,8 +330,7 @@ def get_expired_users(
350330
- If both are omitted, returns all expired users
351331
"""
352332

353-
if not validate_dates(expired_after, expired_before):
354-
raise HTTPException(status_code=400, detail="Invalid date range or format")
333+
expired_after, expired_before = validate_dates(expired_after, expired_before)
355334

356335
expired_users = get_expired_users_list(db, admin, expired_after, expired_before)
357336
return [u.username for u in expired_users]
@@ -372,8 +351,7 @@ def delete_expired_users(
372351
- **expired_before** UTC datetime (optional)
373352
- At least one of expired_after or expired_before must be provided
374353
"""
375-
if not validate_dates(expired_after, expired_before, allow_both_none=False):
376-
raise HTTPException(status_code=400, detail="Invalid date range or format")
354+
expired_after, expired_before = validate_dates(expired_after, expired_before)
377355

378356
expired_users = get_expired_users_list(db, admin, expired_after, expired_before)
379357
removed_users = [u.username for u in expired_users]

0 commit comments

Comments
 (0)