Portfolio state (cash balance, position, etc.) is now saved to the database during simulation, so it persists and is shown in the frontend.
323 lines
10 KiB
Python
323 lines
10 KiB
Python
import uuid
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime
|
|
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
|
from sqlalchemy.orm import Session
|
|
from typing import List, Dict, Any, Optional
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from .auth import get_current_user
|
|
from ..core.database import get_db
|
|
from ..core.config import get_settings
|
|
from ..db.schemas import SimulationCreate, SimulationResponse
|
|
from ..db.models import Bot, Simulation, Signal, User
|
|
from ..services.ave.client import AveCloudClient
|
|
|
|
# Router that the simulation endpoints in this module are registered on.
router = APIRouter()

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Worker pool created at import time; nothing visible in this part of the file
# submits work to it.
executor = ThreadPoolExecutor(max_workers=4)

# Engines of in-flight simulations keyed by simulation id. The background
# runner adds an entry when a run starts; the runner and the stop/start
# endpoints remove it.
running_simulations: Dict[str, Any] = {}
|
|
|
|
|
|
def run_simulation_sync(
    simulation_id: str, db_url: str, bot_id: str, config: Dict[str, Any]
):
    """Run one simulation to completion on its own event loop.

    Builds a ``SimulateEngine`` from ``config``, registers it in
    ``running_simulations`` under ``simulation_id`` and drives ``engine.run()``
    through ``asyncio.run``. While the engine reports itself as running, its
    state (status, signals, klines, trade log, portfolio) is written to the
    matching ``Simulation`` row about every five seconds, and again when the
    run ends for any reason.

    Args:
        simulation_id: Primary key of the ``Simulation`` row to update; also
            assigned to ``engine.run_id``.
        db_url: Accepted for call compatibility; sessions are created from
            ``SessionLocal`` and this value is not read.
        bot_id: Accepted for call compatibility; not read here.
        config: Engine configuration handed to ``SimulateEngine``.
    """
    import time

    from ..services.simulate.engine import SimulateEngine
    from ..core.database import SessionLocal

    async def _run():
        engine = SimulateEngine(config)
        engine.run_id = simulation_id
        running_simulations[simulation_id] = engine

        def serialize_signal(s):
            """Return a copy of signal ``s`` whose ``created_at`` is JSON-safe."""
            created = s.get("created_at")
            # datetime-like values are stored as ISO strings; anything else
            # (including a missing value, i.e. None) is stored unchanged.
            if hasattr(created, "isoformat"):
                created = created.isoformat()
            return {**s, "created_at": created}

        def save_progress():
            """Save current progress to database."""
            db = SessionLocal()
            try:
                simulation = (
                    db.query(Simulation).filter(Simulation.id == simulation_id).first()
                )
                if simulation:
                    simulation.status = engine.status
                    simulation.signals = [serialize_signal(s) for s in engine.signals]
                    # Only the fields the chart needs are persisted per kline.
                    simulation.klines = [
                        {"time": k.get("time"), "close": k.get("close")}
                        for k in engine.klines
                    ]
                    simulation.trade_log = engine.trade_log
                    # Save portfolio data
                    simulation.portfolio = {
                        "initial_balance": engine.config.get("initial_balance", 10000),
                        "current_balance": engine.current_balance,
                        "position": engine.position,
                        "position_token": engine.position_token,
                        "entry_price": engine.entry_price,
                        "current_price": engine.last_close,
                    }
                    db.commit()
            finally:
                db.close()

        async def run_with_progress_save():
            """Run simulation and save progress periodically."""
            # A monotonic clock keeps the interval correct even if the system
            # wall clock is adjusted while the simulation is running.
            last_save_time = time.monotonic()
            save_interval = 5  # Save every 5 seconds

            while engine.running and engine.status == "running":
                await asyncio.sleep(1)  # Check every second

                current_time = time.monotonic()
                if current_time - last_save_time >= save_interval:
                    save_progress()
                    last_save_time = current_time

            # Final save when done
            save_progress()

        try:
            # Run both simulation and progress saving concurrently
            await asyncio.gather(engine.run(), run_with_progress_save())
        finally:
            try:
                # Save final state
                save_progress()
            finally:
                # Always unregister the engine, even when the final save fails.
                # pop() tolerates the entry having already been removed by the
                # stop/start endpoints, which delete from the same dict.
                running_simulations.pop(simulation_id, None)

    asyncio.run(_run())
|
|
|
|
|
|
@router.post(
    "/bots/{bot_id}/simulate",
    response_model=SimulationResponse,
    status_code=status.HTTP_201_CREATED,
)
async def start_simulation(
    bot_id: str,
    config: SimulationCreate,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create a simulation for a bot and schedule it to run in the background.

    Any simulation of the same bot that is still marked ``"running"`` is
    stopped first. A new ``Simulation`` row is then created, the last hour of
    klines is fetched so the chart can be shown immediately (best effort), and
    ``run_simulation_sync`` is queued as a background task.

    Args:
        bot_id: Id of the bot to simulate.
        config: Token, chain and kline interval for the run.
        background_tasks: Task queue the simulation runner is added to.
        current_user: Authenticated user; must own the bot.
        db: Database session.

    Returns:
        The newly created ``Simulation`` row.

    Raises:
        HTTPException: 404 if the bot does not exist, 403 if it belongs to
            another user.
    """
    import time

    bot = db.query(Bot).filter(Bot.id == bot_id).first()
    if not bot:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Bot not found"
        )
    if bot.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized"
        )

    # Check if there's already a running simulation for this bot
    existing_simulation = (
        db.query(Simulation)
        .filter(Simulation.bot_id == bot_id, Simulation.status == "running")
        .first()
    )
    if existing_simulation:
        # Stop the existing simulation first. pop() removes the engine in one
        # step, so a runner that unregisters itself at the same moment cannot
        # cause a KeyError here.
        existing_engine = running_simulations.pop(existing_simulation.id, None)
        if existing_engine is not None:
            existing_engine.stop()
        # The row is marked stopped even when no engine is held in memory.
        existing_simulation.status = "stopped"
        db.commit()

    settings = get_settings()
    simulation_id = str(uuid.uuid4())

    # Create AVE client for klines fetching
    ave_client = AveCloudClient(
        api_key=settings.AVE_API_KEY,
        plan=settings.AVE_API_PLAN,
    )

    simulation_config = {
        "bot_id": bot_id,
        "token": config.token,
        "chain": config.chain,
        "kline_interval": config.kline_interval,
        "auto_execute": False,  # Always paper trade
        "strategy_config": bot.strategy_config,
        "ave_api_key": settings.AVE_API_KEY,
        "ave_api_plan": settings.AVE_API_PLAN,
    }

    simulation = Simulation(
        id=simulation_id,
        bot_id=bot_id,
        started_at=datetime.utcnow(),
        status="running",
        config={
            "token": config.token,
            "chain": config.chain,
            "kline_interval": config.kline_interval,
        },
        signals=[],
        klines=[],
    )
    db.add(simulation)
    db.commit()
    db.refresh(simulation)

    # Fetch klines SYNCHRONOUSLY so user can see chart immediately
    try:
        token_id = f"{config.token}-{config.chain}"

        # Calculate time range (last 1 hour), in milliseconds
        end_time = int(time.time() * 1000)
        start_time = end_time - (60 * 60 * 1000)  # 1 hour ago

        klines_data = await ave_client.get_klines(
            token_id,
            interval=config.kline_interval,
            start_time=start_time,
            end_time=end_time,
            limit=500,
        )
        klines_for_chart = [
            {"time": k.get("time"), "close": k.get("close")}
            for k in sorted(klines_data, key=lambda x: x.get("time", 0))
        ]
        # Update simulation with klines
        simulation.klines = klines_for_chart
        db.commit()
        db.refresh(simulation)
        logger.info(f"Fetched {len(klines_for_chart)} klines for simulation {simulation_id}")
    except Exception as e:
        # Best effort: the simulation still starts without the initial chart.
        # Roll back so a failed commit does not leave the session unusable
        # when the row is read again for the response.
        db.rollback()
        logger.exception("Failed to fetch klines: %s", e)

    # Run simulation in background for signal processing
    background_tasks.add_task(
        run_simulation_sync, simulation_id, str(settings.DATABASE_URL), bot_id, simulation_config
    )

    return simulation
|
|
|
|
|
|
@router.get("/bots/{bot_id}/simulate/{run_id}", response_model=SimulationResponse)
def get_simulation(
    bot_id: str,
    run_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return one simulation of a bot, with live signals while it is running.

    Args:
        bot_id: Id of the bot the simulation belongs to.
        run_id: Id of the simulation.
        current_user: Authenticated user; must own the bot.
        db: Database session.

    Returns:
        The ``Simulation`` row. If an engine for ``run_id`` is registered in
        ``running_simulations``, its current signals replace the stored ones
        on the returned object (nothing is committed here).

    Raises:
        HTTPException: 404 if the bot or the simulation does not exist, 403 if
            the bot belongs to another user.
    """
    bot = db.query(Bot).filter(Bot.id == bot_id).first()
    if not bot:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Bot not found"
        )
    if bot.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized"
        )

    simulation = (
        db.query(Simulation)
        .filter(Simulation.id == run_id, Simulation.bot_id == bot_id)
        .first()
    )
    if not simulation:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Simulation not found"
        )

    # Single lookup: the runner removes its entry when the run ends, so a
    # separate membership test followed by indexing could raise KeyError.
    engine = running_simulations.get(run_id)
    if engine is not None:
        simulation.signals = engine.get_signals()

    return simulation
|
|
|
|
|
|
@router.get("/bots/{bot_id}/simulations", response_model=List[SimulationResponse])
def list_simulations(
    bot_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List all simulations of a bot, newest first.

    Args:
        bot_id: Id of the bot whose simulations are listed.
        current_user: Authenticated user; must own the bot.
        db: Database session.

    Returns:
        The bot's ``Simulation`` rows ordered by ``started_at`` descending.
        For each row with an engine registered in ``running_simulations``, the
        engine's current signals (and klines, when the engine exposes them)
        replace the stored values on the returned object.

    Raises:
        HTTPException: 404 if the bot does not exist, 403 if it belongs to
            another user.
    """
    bot = db.query(Bot).filter(Bot.id == bot_id).first()
    if not bot:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Bot not found"
        )
    if bot.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized"
        )

    simulations = (
        db.query(Simulation)
        .filter(Simulation.bot_id == bot_id)
        .order_by(Simulation.started_at.desc())
        .all()
    )

    for sim in simulations:
        # Single lookup: an engine can be unregistered between a membership
        # test and the indexing that follows it.
        engine = running_simulations.get(sim.id)
        if engine is None:
            continue
        sim.signals = engine.get_signals()
        # Include klines from running engine for chart display
        if hasattr(engine, 'klines'):
            sim.klines = [{"time": k.get("time"), "close": k.get("close")} for k in engine.klines]

    return simulations
|
|
|
|
|
|
@router.post("/bots/{bot_id}/simulate/{run_id}/stop")
def stop_simulation(
    bot_id: str,
    run_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Stop a simulation and mark its row as stopped.

    Args:
        bot_id: Id of the bot the simulation belongs to.
        run_id: Id of the simulation to stop.
        current_user: Authenticated user; must own the bot.
        db: Database session.

    Returns:
        ``{"status": "stopped", "run_id": run_id}``.

    Raises:
        HTTPException: 404 if the bot or the simulation does not exist, 403 if
            the bot belongs to another user.
    """
    bot = db.query(Bot).filter(Bot.id == bot_id).first()
    if not bot:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Bot not found"
        )
    if bot.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized"
        )

    simulation = (
        db.query(Simulation)
        .filter(Simulation.id == run_id, Simulation.bot_id == bot_id)
        .first()
    )
    if not simulation:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Simulation not found"
        )

    # Always update status to stopped, even if engine is not in memory
    simulation.status = "stopped"

    # Try to stop the engine if it's still in memory. pop() removes it in one
    # step, so a runner unregistering itself at the same time cannot make the
    # removal raise KeyError.
    engine = running_simulations.pop(run_id, None)
    if engine is not None:
        engine.stop()

    db.commit()

    return {"status": "stopped", "run_id": run_id}
|