diff --git a/src/backend/app/api/auth.py b/src/backend/app/api/auth.py index afcc7eb..8e72dbe 100644 --- a/src/backend/app/api/auth.py +++ b/src/backend/app/api/auth.py @@ -1,7 +1,6 @@ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from sqlalchemy.orm import Session -from datetime import timedelta from typing import Annotated from ..core.database import get_db @@ -12,41 +11,133 @@ from ..core.security import ( verify_token, ) from ..core.config import get_settings -from ..db.schemas import UserCreate, UserResponse, Token +from ..core.limiter import limiter +from ..db.schemas import ( + UserCreate, + UserResponse, + Token, + UserSettings, + UserSettingsUpdate, +) +from ..db.models import User router = APIRouter() settings = get_settings() -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login") + +TOKEN_BLACKLIST = set() -@router.post("/register", response_model=UserResponse) +def get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db) +) -> User: + if token in TOKEN_BLACKLIST: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token has been revoked", + ) + payload = verify_token(token) + if payload is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + ) + user_id = payload.get("sub") + if user_id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token payload", + ) + user = db.query(User).filter(User.id == user_id).first() + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found", + ) + return user + + +@router.post( + "/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED +) def register(user: UserCreate, db: Session = Depends(get_db)): - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented" + existing_user = db.query(User).filter(User.email == user.email).first() + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Email already registered", + ) + hashed_password = get_password_hash(user.password) + db_user = User( + email=user.email, + password_hash=hashed_password, ) + db.add(db_user) + db.commit() + db.refresh(db_user) + return db_user -@router.post("/login") +@router.post("/login", response_model=Token) +@limiter.limit("5/minute") def login( + request: Request, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Session = Depends(get_db), ): - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented" - ) + user = db.query(User).filter(User.email == form_data.username).first() + if not user or not verify_password(form_data.password, user.password_hash): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect email or password", + ) + access_token = create_access_token(data={"sub": user.id}) + return Token(access_token=access_token, token_type="bearer") @router.post("/logout") -def logout(token: Annotated[str, Depends(oauth2_scheme)]): - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented" - ) +def logout( + current_user: Annotated[User, Depends(get_current_user)], + token: Annotated[str, Depends(oauth2_scheme)], +): + TOKEN_BLACKLIST.add(token) + return {"message": "Successfully logged out"} @router.get("/me", response_model=UserResponse) def get_me( - token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db) + current_user: Annotated[User, Depends(get_current_user)], ): - raise HTTPException( - status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented" - ) + return current_user + + +@router.get("/settings", response_model=UserSettings) +def get_settings_endpoint( + current_user: Annotated[User, Depends(get_current_user)], +): + return UserSettings(email=current_user.email) + + +@router.patch("/settings", response_model=UserSettings) +def update_settings( + current_user: Annotated[User, Depends(get_current_user)], + settings_update: UserSettingsUpdate, + db: Session = Depends(get_db), +): + if settings_update.email: + existing = ( + db.query(User) + .filter(User.email == settings_update.email, User.id != current_user.id) + .first() + ) + if existing: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Email already in use", + ) + current_user.email = settings_update.email + if settings_update.password: + current_user.password_hash = get_password_hash(settings_update.password) + db.commit() + db.refresh(current_user) + return UserSettings(email=current_user.email) diff --git a/src/backend/app/core/limiter.py b/src/backend/app/core/limiter.py new file mode 100644 index 0000000..38404a8 --- /dev/null +++ b/src/backend/app/core/limiter.py @@ -0,0 +1,4 @@ +from slowapi import Limiter +from slowapi.util import get_remote_address + +limiter = Limiter(key_func=get_remote_address) diff --git a/src/backend/app/db/schemas.py b/src/backend/app/db/schemas.py index b564d59..7ef7ad0 100644 --- a/src/backend/app/db/schemas.py +++ b/src/backend/app/db/schemas.py @@ -23,6 +23,15 @@ class Token(BaseModel): token_type: str +class UserSettings(BaseModel): + email: EmailStr + + +class UserSettingsUpdate(BaseModel): + email: Optional[EmailStr] = None + password: Optional[str] = None + + class BotCreate(BaseModel): name: str description: Optional[str] = None diff --git a/src/backend/app/main.py b/src/backend/app/main.py index d0c56fa..f0549b3 100644 --- a/src/backend/app/main.py +++ b/src/backend/app/main.py @@ -1,6 +1,9 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from slowapi import Limiter +from slowapi.util import get_remote_address from .api import auth, bots, backtest, simulate, config +from .core.limiter import limiter app = FastAPI( title="Randebu Trading Bot API", @@ -8,6 +11,8 @@ app = FastAPI( version="0.1.0", ) +app.state.limiter = limiter + app.add_middleware( CORSMiddleware, allow_origins=["*"], diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index 409bb67..dc9f84b 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -10,3 +10,4 @@ crewai>=0.1.0 anthropic>=0.18.0 httpx>=0.26.0 python-multipart>=0.0.6 +slowapi>=0.1.9