Compare commits
3 Commits
fix/issue-
...
fix/issue-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a461005015 | ||
| b0311bc96f | |||
|
|
a280217254 |
@@ -1,36 +1,219 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from typing import List, Dict, Any, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from .auth import get_current_user
|
||||
from ..core.database import get_db
|
||||
from ..core.config import get_settings
|
||||
from ..db.schemas import BacktestCreate, BacktestResponse
|
||||
from ..db.models import Bot, Backtest, Signal, User
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
running_backtests: Dict[str, Any] = {}
|
||||
executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
@router.post("/bots/{bot_id}/backtest", response_model=BacktestResponse)
|
||||
def start_backtest(bot_id: str, config: BacktestCreate, db: Session = Depends(get_db)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented"
|
||||
|
||||
def run_backtest_sync(
|
||||
backtest_id: str, db_url: str, bot_id: str, config: Dict[str, Any]
|
||||
):
|
||||
import asyncio
|
||||
from ..services.backtest.engine import BacktestEngine
|
||||
from ..core.database import SessionLocal
|
||||
|
||||
async def _run():
|
||||
engine = BacktestEngine(config)
|
||||
engine.run_id = backtest_id
|
||||
running_backtests[backtest_id] = engine
|
||||
try:
|
||||
results = await engine.run()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
backtest = db.query(Backtest).filter(Backtest.id == backtest_id).first()
|
||||
if backtest:
|
||||
backtest.status = engine.status
|
||||
backtest.ended_at = datetime.utcnow()
|
||||
backtest.result = results
|
||||
db.commit()
|
||||
|
||||
for signal in engine.signals:
|
||||
db_signal = Signal(
|
||||
id=signal["id"],
|
||||
bot_id=signal["bot_id"],
|
||||
run_id=signal["run_id"],
|
||||
signal_type=signal["signal_type"],
|
||||
token=signal["token"],
|
||||
price=signal["price"],
|
||||
confidence=signal.get("confidence"),
|
||||
reasoning=signal.get("reasoning"),
|
||||
executed=signal.get("executed", False),
|
||||
created_at=signal["created_at"],
|
||||
)
|
||||
db.add(db_signal)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
finally:
|
||||
if backtest_id in running_backtests:
|
||||
del running_backtests[backtest_id]
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
@router.post(
|
||||
"/bots/{bot_id}/backtest",
|
||||
response_model=BacktestResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def start_backtest(
|
||||
bot_id: str,
|
||||
config: BacktestCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
bot = db.query(Bot).filter(Bot.id == bot_id).first()
|
||||
if not bot:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Bot not found"
|
||||
)
|
||||
if bot.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized"
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
backtest_id = str(uuid.uuid4())
|
||||
|
||||
backtest_config = {
|
||||
"bot_id": bot_id,
|
||||
"token": config.token,
|
||||
"chain": config.chain,
|
||||
"timeframe": config.timeframe,
|
||||
"start_date": config.start_date,
|
||||
"end_date": config.end_date,
|
||||
"strategy_config": bot.strategy_config,
|
||||
"ave_api_key": settings.AVE_API_KEY,
|
||||
"ave_api_plan": settings.AVE_API_PLAN,
|
||||
"initial_balance": 10000.0,
|
||||
}
|
||||
|
||||
backtest = Backtest(
|
||||
id=backtest_id,
|
||||
bot_id=bot_id,
|
||||
started_at=datetime.utcnow(),
|
||||
status="running",
|
||||
config={
|
||||
"token": config.token,
|
||||
"chain": config.chain,
|
||||
"timeframe": config.timeframe,
|
||||
"start_date": config.start_date,
|
||||
"end_date": config.end_date,
|
||||
},
|
||||
)
|
||||
db.add(backtest)
|
||||
db.commit()
|
||||
db.refresh(backtest)
|
||||
|
||||
db_url = str(settings.DATABASE_URL)
|
||||
background_tasks.add_task(
|
||||
run_backtest_sync, backtest_id, db_url, bot_id, backtest_config
|
||||
)
|
||||
|
||||
return backtest
|
||||
|
||||
|
||||
@router.get("/bots/{bot_id}/backtest/{run_id}", response_model=BacktestResponse)
|
||||
def get_backtest(bot_id: str, run_id: str, db: Session = Depends(get_db)):
|
||||
def get_backtest(
|
||||
bot_id: str,
|
||||
run_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
bot = db.query(Bot).filter(Bot.id == bot_id).first()
|
||||
if not bot:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Bot not found"
|
||||
)
|
||||
if bot.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized"
|
||||
)
|
||||
|
||||
backtest = (
|
||||
db.query(Backtest)
|
||||
.filter(Backtest.id == run_id, Backtest.bot_id == bot_id)
|
||||
.first()
|
||||
)
|
||||
if not backtest:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Backtest not found"
|
||||
)
|
||||
|
||||
return backtest
|
||||
|
||||
|
||||
@router.get("/bots/{bot_id}/backtests", response_model=List[BacktestResponse])
|
||||
def list_backtests(bot_id: str, db: Session = Depends(get_db)):
|
||||
def list_backtests(
|
||||
bot_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
bot = db.query(Bot).filter(Bot.id == bot_id).first()
|
||||
if not bot:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Bot not found"
|
||||
)
|
||||
if bot.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized"
|
||||
)
|
||||
|
||||
backtests = (
|
||||
db.query(Backtest)
|
||||
.filter(Backtest.bot_id == bot_id)
|
||||
.order_by(Backtest.started_at.desc())
|
||||
.all()
|
||||
)
|
||||
return backtests
|
||||
|
||||
|
||||
@router.post("/bots/{bot_id}/backtest/{run_id}/stop")
|
||||
def stop_backtest(bot_id: str, run_id: str, db: Session = Depends(get_db)):
|
||||
def stop_backtest(
|
||||
bot_id: str,
|
||||
run_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
bot = db.query(Bot).filter(Bot.id == bot_id).first()
|
||||
if not bot:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Bot not found"
|
||||
)
|
||||
if bot.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized"
|
||||
)
|
||||
|
||||
backtest = (
|
||||
db.query(Backtest)
|
||||
.filter(Backtest.id == run_id, Backtest.bot_id == bot_id)
|
||||
.first()
|
||||
)
|
||||
if not backtest:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Backtest not found"
|
||||
)
|
||||
|
||||
if run_id in running_backtests:
|
||||
engine = running_backtests[run_id]
|
||||
asyncio.create_task(engine.stop())
|
||||
backtest.status = "stopped"
|
||||
backtest.ended_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
return {"status": "stopping", "run_id": run_id}
|
||||
|
||||
@@ -4,14 +4,18 @@ from typing import List, Annotated
|
||||
|
||||
from .auth import get_current_user
|
||||
from ..core.database import get_db
|
||||
from ..core.config import get_settings
|
||||
from ..db.schemas import (
|
||||
BotCreate,
|
||||
BotUpdate,
|
||||
BotResponse,
|
||||
BotConversationCreate,
|
||||
BotConversationResponse,
|
||||
BotChatRequest,
|
||||
BotChatResponse,
|
||||
)
|
||||
from ..db.models import Bot, BotConversation, User
|
||||
from ..services.ai_agent.crew import get_trading_crew
|
||||
|
||||
router = APIRouter()
|
||||
MAX_BOTS_PER_USER = 3
|
||||
@@ -154,10 +158,10 @@ def delete_bot(
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.post("/{bot_id}/chat", response_model=BotConversationResponse)
|
||||
@router.post("/{bot_id}/chat", response_model=BotChatResponse)
|
||||
def chat(
|
||||
bot_id: str,
|
||||
message: BotConversationCreate,
|
||||
request: BotChatRequest,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
@@ -173,15 +177,75 @@ def chat(
|
||||
detail="Not authorized to chat with this bot",
|
||||
)
|
||||
|
||||
conversation_history = (
|
||||
db.query(BotConversation)
|
||||
.filter(BotConversation.bot_id == bot_id)
|
||||
.order_by(BotConversation.created_at)
|
||||
.all()
|
||||
)
|
||||
history_for_crew = [
|
||||
{"role": conv.role, "content": conv.content}
|
||||
for conv in conversation_history[-10:]
|
||||
]
|
||||
|
||||
user_message = request.message
|
||||
if request.strategy_config:
|
||||
crew = get_trading_crew()
|
||||
result = crew.chat(user_message, history_for_crew)
|
||||
|
||||
assistant_content = result.get("response", "I couldn't process your request.")
|
||||
if result.get("success") and result.get("strategy_config"):
|
||||
bot.strategy_config = result["strategy_config"]
|
||||
db.commit()
|
||||
|
||||
db_conversation = BotConversation(
|
||||
bot_id=bot_id,
|
||||
role=message.role,
|
||||
content=message.content,
|
||||
role="user",
|
||||
content=user_message,
|
||||
)
|
||||
db.add(db_conversation)
|
||||
|
||||
db_assistant = BotConversation(
|
||||
bot_id=bot_id,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
)
|
||||
db.add(db_assistant)
|
||||
db.commit()
|
||||
db.refresh(db_conversation)
|
||||
return db_conversation
|
||||
db.refresh(db_assistant)
|
||||
|
||||
return BotChatResponse(
|
||||
response=assistant_content,
|
||||
strategy_config=result.get("strategy_config"),
|
||||
success=result.get("success", False),
|
||||
)
|
||||
else:
|
||||
crew = get_trading_crew()
|
||||
result = crew.chat(user_message, history_for_crew)
|
||||
|
||||
assistant_content = result.get("response", "I couldn't process your request.")
|
||||
|
||||
db_conversation = BotConversation(
|
||||
bot_id=bot_id,
|
||||
role="user",
|
||||
content=user_message,
|
||||
)
|
||||
db.add(db_conversation)
|
||||
|
||||
db_assistant = BotConversation(
|
||||
bot_id=bot_id,
|
||||
role="assistant",
|
||||
content=assistant_content,
|
||||
)
|
||||
db.add(db_assistant)
|
||||
db.commit()
|
||||
db.refresh(db_assistant)
|
||||
|
||||
return BotChatResponse(
|
||||
response=assistant_content,
|
||||
strategy_config=result.get("strategy_config"),
|
||||
success=result.get("success", False),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{bot_id}/history", response_model=List[BotConversationResponse])
|
||||
|
||||
@@ -118,6 +118,17 @@ class BotConversationResponse(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class BotChatRequest(BaseModel):
|
||||
message: str
|
||||
strategy_config: Optional[bool] = False
|
||||
|
||||
|
||||
class BotChatResponse(BaseModel):
|
||||
response: str
|
||||
strategy_config: Optional[dict] = None
|
||||
success: bool = False
|
||||
|
||||
|
||||
class SignalResponse(BaseModel):
|
||||
id: str
|
||||
bot_id: str
|
||||
|
||||
@@ -1,15 +1,247 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Any
|
||||
from crewai import Agent, Task, Crew
|
||||
from .llm_connector import MiniMaxConnector, MiniMaxLLM
|
||||
from ..core.config import get_settings
|
||||
|
||||
|
||||
class CrewAgent:
|
||||
def __init__(self, role: str, goal: str, backstory: str):
|
||||
self.role = role
|
||||
self.goal = goal
|
||||
self.backstory = backstory
|
||||
class StrategyValidator:
|
||||
SUPPORTED_CONDITIONS = ["price_drop", "price_rise", "volume_spike", "price_level"]
|
||||
SUPPORTED_ACTIONS = ["buy", "sell", "notify"]
|
||||
|
||||
def execute_task(self, task: str) -> str:
|
||||
raise NotImplementedError("CrewAI agent not yet implemented")
|
||||
def validate(self, strategy_config: dict) -> tuple[bool, list[str]]:
|
||||
errors = []
|
||||
|
||||
if "conditions" not in strategy_config:
|
||||
errors.append("Missing 'conditions' in strategy config")
|
||||
return False, errors
|
||||
|
||||
if not isinstance(strategy_config["conditions"], list):
|
||||
errors.append("'conditions' must be a list")
|
||||
return False, errors
|
||||
|
||||
if len(strategy_config["conditions"]) == 0:
|
||||
errors.append("At least one condition is required")
|
||||
return False, errors
|
||||
|
||||
for i, condition in enumerate(strategy_config["conditions"]):
|
||||
if "type" not in condition:
|
||||
errors.append(f"Condition {i}: missing 'type'")
|
||||
continue
|
||||
|
||||
cond_type = condition.get("type")
|
||||
if cond_type not in self.SUPPORTED_CONDITIONS:
|
||||
errors.append(f"Condition {i}: unsupported type '{cond_type}'")
|
||||
continue
|
||||
|
||||
params = condition.get("params", {})
|
||||
if cond_type in ["price_drop", "price_rise", "volume_spike"]:
|
||||
if "token" not in params:
|
||||
errors.append(f"Condition {i}: missing 'token'")
|
||||
if "threshold_percent" not in params:
|
||||
errors.append(f"Condition {i}: missing 'threshold_percent'")
|
||||
elif not isinstance(params["threshold_percent"], (int, float)):
|
||||
errors.append(
|
||||
f"Condition {i}: 'threshold_percent' must be a number"
|
||||
)
|
||||
elif params["threshold_percent"] <= 0:
|
||||
errors.append(
|
||||
f"Condition {i}: 'threshold_percent' must be positive"
|
||||
)
|
||||
|
||||
elif cond_type == "price_level":
|
||||
if "token" not in params:
|
||||
errors.append(f"Condition {i}: missing 'token'")
|
||||
if "price" not in params:
|
||||
errors.append(f"Condition {i}: missing 'price'")
|
||||
if "direction" not in params:
|
||||
errors.append(f"Condition {i}: missing 'direction'")
|
||||
elif params["direction"] not in ["above", "below"]:
|
||||
errors.append(
|
||||
f"Condition {i}: direction must be 'above' or 'below'"
|
||||
)
|
||||
|
||||
if "actions" in strategy_config:
|
||||
if not isinstance(strategy_config["actions"], list):
|
||||
errors.append("'actions' must be a list")
|
||||
else:
|
||||
for i, action in enumerate(strategy_config["actions"]):
|
||||
if "type" not in action:
|
||||
errors.append(f"Action {i}: missing 'type'")
|
||||
elif action["type"] not in self.SUPPORTED_ACTIONS:
|
||||
errors.append(
|
||||
f"Action {i}: unsupported type '{action['type']}'"
|
||||
)
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
def get_trading_crew():
|
||||
raise NotImplementedError("Trading crew not yet implemented")
|
||||
class StrategyExplainer:
|
||||
def explain(self, strategy_config: dict) -> str:
|
||||
explanations = []
|
||||
|
||||
if "conditions" in strategy_config:
|
||||
cond_list = strategy_config["conditions"]
|
||||
if cond_list:
|
||||
explanations.append("This strategy will trigger when:")
|
||||
for cond in cond_list:
|
||||
cond_type = cond.get("type")
|
||||
params = cond.get("params", {})
|
||||
token = params.get("token", "the token")
|
||||
|
||||
if cond_type == "price_drop":
|
||||
pct = params.get("threshold_percent", 0)
|
||||
explanations.append(f" - {token} price drops by {pct}%")
|
||||
elif cond_type == "price_rise":
|
||||
pct = params.get("threshold_percent", 0)
|
||||
explanations.append(f" - {token} price rises by {pct}%")
|
||||
elif cond_type == "volume_spike":
|
||||
pct = params.get("threshold_percent", 0)
|
||||
explanations.append(
|
||||
f" - {token} trading volume increases by {pct}%"
|
||||
)
|
||||
elif cond_type == "price_level":
|
||||
price = params.get("price", 0)
|
||||
direction = params.get("direction", "unknown")
|
||||
explanations.append(
|
||||
f" - {token} price crosses {direction} ${price}"
|
||||
)
|
||||
|
||||
if "actions" in strategy_config:
|
||||
actions = strategy_config.get("actions", [])
|
||||
if actions:
|
||||
explanations.append("\nWhen triggered, the strategy will:")
|
||||
for action in actions:
|
||||
action_type = action.get("type")
|
||||
if action_type == "buy":
|
||||
explanations.append(" - Buy the token")
|
||||
elif action_type == "sell":
|
||||
explanations.append(" - Sell the token")
|
||||
elif action_type == "notify":
|
||||
explanations.append(" - Send a notification")
|
||||
|
||||
if not explanations:
|
||||
explanations.append("Strategy configuration is empty or invalid.")
|
||||
|
||||
return "\n".join(explanations)
|
||||
|
||||
|
||||
def create_trading_designer_agent(
|
||||
api_key: str, model: str = "MiniMax-Text-01"
|
||||
) -> Agent:
|
||||
connector = MiniMaxConnector(api_key=api_key, model=model)
|
||||
|
||||
system_prompt = """You are a Trading Strategy Designer AI. Your role is to parse user requests
|
||||
for trading strategies into structured JSON configuration.
|
||||
|
||||
Supported conditions (MVP):
|
||||
- price_drop: Triggers when a token's price drops by a specified percentage
|
||||
- price_rise: Triggers when a token's price rises by a specified percentage
|
||||
- volume_spike: Triggers when trading volume increases by a specified percentage
|
||||
- price_level: Triggers when price crosses above or below a specified level
|
||||
|
||||
Always ask clarifying questions if the user's request is ambiguous.
|
||||
Output strategy_config in valid JSON format only when you have all required information.
|
||||
"""
|
||||
|
||||
return Agent(
|
||||
role="Trading Strategy Designer",
|
||||
goal="Convert natural language trading requests into precise strategy configurations",
|
||||
backstory=system_prompt,
|
||||
llm=MiniMaxLLM(api_key=api_key, model=model),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
def create_strategy_validator_agent(
|
||||
api_key: str, model: str = "MiniMax-Text-01"
|
||||
) -> Agent:
|
||||
return Agent(
|
||||
role="Strategy Validator",
|
||||
goal="Validate trading strategy configurations for feasibility and identify potential issues",
|
||||
backstory="""You are a meticulous strategy validator with expertise in trading systems.
|
||||
You check that all required parameters are present, values are reasonable, and the
|
||||
strategy makes logical sense. You never approve strategies with missing or invalid data.""",
|
||||
llm=MiniMaxLLM(api_key=api_key, model=model),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
def create_strategy_explainer_agent(
|
||||
api_key: str, model: str = "MiniMax-Text-01"
|
||||
) -> Agent:
|
||||
return Agent(
|
||||
role="Strategy Explainer",
|
||||
goal="Generate clear, user-friendly explanations of trading strategies",
|
||||
backstory="""You are a patient trading strategy explainer. You translate complex
|
||||
strategy configurations into easy-to-understand language. You help users understand
|
||||
exactly what their strategies will do when triggered.""",
|
||||
llm=MiniMaxLLM(api_key=api_key, model=model),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
class TradingCrew:
|
||||
def __init__(self, api_key: str, model: str = "MiniMax-Text-01"):
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.validator = StrategyValidator()
|
||||
self.explainer = StrategyExplainer()
|
||||
self.connector = MiniMaxConnector(api_key=api_key, model=model)
|
||||
|
||||
def parse_strategy(
|
||||
self, user_message: str, conversation_history: list[dict] = None
|
||||
) -> dict:
|
||||
strategy_config = self.connector.parse_strategy(
|
||||
user_message, conversation_history
|
||||
)
|
||||
|
||||
if "error" in strategy_config:
|
||||
return strategy_config
|
||||
|
||||
is_valid, errors = self.validator.validate(strategy_config)
|
||||
if not is_valid:
|
||||
return {
|
||||
"error": "Strategy validation failed",
|
||||
"validation_errors": errors,
|
||||
"partial_config": strategy_config,
|
||||
}
|
||||
|
||||
return strategy_config
|
||||
|
||||
def explain_strategy(self, strategy_config: dict) -> str:
|
||||
return self.explainer.explain(strategy_config)
|
||||
|
||||
def chat(self, user_message: str, conversation_history: list[dict] = None) -> dict:
|
||||
strategy_config = self.parse_strategy(user_message, conversation_history)
|
||||
|
||||
if "error" in strategy_config:
|
||||
explanation = f"I had trouble understanding your strategy: {strategy_config.get('error', 'Unknown error')}"
|
||||
if "validation_errors" in strategy_config:
|
||||
explanation += "\n\nValidation issues:"
|
||||
for err in strategy_config["validation_errors"]:
|
||||
explanation += f"\n - {err}"
|
||||
return {
|
||||
"response": explanation,
|
||||
"strategy_config": strategy_config.get("partial_config"),
|
||||
"success": False,
|
||||
}
|
||||
|
||||
explanation = self.explain_strategy(strategy_config)
|
||||
return {
|
||||
"response": f"I've configured your strategy:\n\n{explanation}",
|
||||
"strategy_config": strategy_config,
|
||||
"success": True,
|
||||
}
|
||||
|
||||
|
||||
def get_trading_crew(
|
||||
api_key: Optional[str] = None, model: Optional[str] = None
|
||||
) -> TradingCrew:
|
||||
if api_key is None:
|
||||
settings = get_settings()
|
||||
api_key = settings.MINIMAX_API_KEY
|
||||
if model is None:
|
||||
settings = get_settings()
|
||||
model = settings.MINIMAX_MODEL
|
||||
|
||||
return TradingCrew(api_key=api_key, model=model)
|
||||
|
||||
@@ -1,13 +1,108 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Dict, Any
|
||||
import httpx
|
||||
from crewai import LLM
|
||||
|
||||
|
||||
class LLMConnector:
|
||||
class MiniMaxLLM(LLM):
|
||||
def __init__(self, api_key: str, model: str = "MiniMax-Text-01", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.base_url = "https://api.minimax.chat/v1"
|
||||
|
||||
def _call(self, messages: List[Dict[str, str]], **kwargs) -> str:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"max_tokens": kwargs.get("max_tokens", 2048),
|
||||
}
|
||||
with httpx.Client(timeout=60.0) as client:
|
||||
response = client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
def call(self, messages: List[Dict[str, str]], **kwargs) -> str:
|
||||
return self._call(messages, **kwargs)
|
||||
|
||||
|
||||
class MiniMaxConnector:
|
||||
def __init__(self, api_key: str, model: str = "MiniMax-Text-01"):
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
|
||||
def chat(self, messages: list[dict], **kwargs):
|
||||
raise NotImplementedError("LLM integration not yet implemented")
|
||||
def chat(self, messages: list[dict], **kwargs) -> str:
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict):
|
||||
formatted_messages.append(
|
||||
{
|
||||
"role": msg.get("role", "user"),
|
||||
"content": msg.get("content", str(msg)),
|
||||
}
|
||||
)
|
||||
else:
|
||||
formatted_messages.append({"role": "user", "content": str(msg)})
|
||||
|
||||
def parse_strategy(self, user_message: str) -> dict:
|
||||
raise NotImplementedError("Strategy parsing not yet implemented")
|
||||
llm = MiniMaxLLM(api_key=self.api_key, model=self.model)
|
||||
return llm.call(formatted_messages, **kwargs)
|
||||
|
||||
def parse_strategy(
|
||||
self, user_message: str, conversation_history: list[dict] = None
|
||||
) -> dict:
|
||||
system_prompt = """You are a trading strategy designer. Parse the user's natural language request into a JSON strategy_config object.
|
||||
|
||||
Supported conditions (MVP):
|
||||
- price_drop: Token price drops by X% (requires: token, threshold_percent)
|
||||
- price_rise: Token price rises by X% (requires: token, threshold_percent)
|
||||
- volume_spike: Trading volume increases X% (requires: token, threshold_percent)
|
||||
- price_level: Price crosses above/below X (requires: token, price, direction)
|
||||
|
||||
Output ONLY valid JSON with this schema:
|
||||
{
|
||||
"conditions": [
|
||||
{
|
||||
"type": "price_drop|price_rise|volume_spike|price_level",
|
||||
"params": {
|
||||
"token": "TOKEN_SYMBOL",
|
||||
"threshold_percent": number, // for price_drop, price_rise, volume_spike
|
||||
"price": number, // for price_level
|
||||
"direction": "above|below" // for price_level
|
||||
}
|
||||
}
|
||||
],
|
||||
"actions": [
|
||||
{
|
||||
"type": "buy|sell|notify",
|
||||
"params": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
If the user wants a condition not in the supported list, ask for clarification.
|
||||
"""
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
if conversation_history:
|
||||
for msg in conversation_history:
|
||||
messages.append(
|
||||
{"role": msg.get("role", "user"), "content": msg.get("content", "")}
|
||||
)
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
response = self.chat(messages)
|
||||
try:
|
||||
import json
|
||||
|
||||
result = json.loads(response)
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
return {"error": "Failed to parse strategy", "raw_response": response}
|
||||
|
||||
70
src/backend/app/services/backtest/ave_client.py
Normal file
70
src/backend/app/services/backtest/ave_client.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import httpx
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class AveCloudClient:
|
||||
BASE_URL = "https://prod.ave-api.com"
|
||||
|
||||
def __init__(self, api_key: str, plan: str = "free"):
|
||||
self.api_key = api_key
|
||||
self.plan = plan
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
return {"X-API-KEY": self.api_key}
|
||||
|
||||
async def get_klines(
|
||||
self,
|
||||
token_id: str,
|
||||
interval: str = "1h",
|
||||
limit: int = 100,
|
||||
start_time: Optional[int] = None,
|
||||
end_time: Optional[int] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
url = f"{self.BASE_URL}/v2/klines/token/{token_id}"
|
||||
params = {"interval": interval, "limit": limit}
|
||||
if start_time:
|
||||
params["start_time"] = start_time
|
||||
if end_time:
|
||||
params["end_time"] = end_time
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
url, headers=self._headers(), params=params, timeout=30.0
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if data.get("status") == 200:
|
||||
return data.get("data", [])
|
||||
raise Exception(f"Failed to fetch klines: {data}")
|
||||
|
||||
async def get_token_price(self, token_id: str) -> Optional[Dict[str, Any]]:
|
||||
url = f"{self.BASE_URL}/v2/tokens/price"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=self._headers(),
|
||||
json={"token_ids": [token_id]},
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if data.get("status") == 200:
|
||||
prices = data.get("data", {})
|
||||
return prices.get(token_id)
|
||||
return None
|
||||
|
||||
async def get_batch_prices(self, token_ids: List[str]) -> Dict[str, Dict[str, Any]]:
|
||||
url = f"{self.BASE_URL}/v2/tokens/price"
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
headers=self._headers(),
|
||||
json={"token_ids": token_ids},
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if data.get("status") == 200:
|
||||
return data.get("data", {})
|
||||
return {}
|
||||
@@ -1,15 +1,324 @@
|
||||
from typing import Optional, Dict, Any
|
||||
import uuid
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from .ave_client import AveCloudClient
|
||||
|
||||
|
||||
class BacktestEngine:
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self.run_id = str(uuid.uuid4())
|
||||
self.status = "pending"
|
||||
self.results: Optional[Dict[str, Any]] = None
|
||||
self.signals: List[Dict[str, Any]] = []
|
||||
self.ave_client = AveCloudClient(
|
||||
api_key=config.get("ave_api_key", ""),
|
||||
plan=config.get("ave_api_plan", "free"),
|
||||
)
|
||||
self.bot_id = config.get("bot_id")
|
||||
self.strategy_config = config.get("strategy_config", {})
|
||||
self.conditions = self.strategy_config.get("conditions", [])
|
||||
self.actions = self.strategy_config.get("actions", [])
|
||||
self.initial_balance = config.get("initial_balance", 10000.0)
|
||||
self.current_balance = self.initial_balance
|
||||
self.position = 0.0
|
||||
self.position_token = ""
|
||||
self.trades: List[Dict[str, Any]] = []
|
||||
self.running = False
|
||||
|
||||
async def run(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError("Backtest engine not yet implemented")
|
||||
self.running = True
|
||||
self.status = "running"
|
||||
started_at = datetime.utcnow()
|
||||
|
||||
try:
|
||||
token = self.config.get("token", "")
|
||||
chain = self.config.get("chain", "bsc")
|
||||
timeframe = self.config.get("timeframe", "1h")
|
||||
start_date = self.config.get("start_date", "")
|
||||
end_date = self.config.get("end_date", "")
|
||||
|
||||
token_id = (
|
||||
f"{token}-{chain}"
|
||||
if token and not token.endswith(f"-{chain}")
|
||||
else token
|
||||
)
|
||||
|
||||
if not token_id or token_id == f"-{chain}":
|
||||
raise ValueError("Token ID is required")
|
||||
|
||||
start_ts = None
|
||||
end_ts = None
|
||||
if start_date:
|
||||
start_ts = int(
|
||||
datetime.fromisoformat(
|
||||
start_date.replace("Z", "+00:00")
|
||||
).timestamp()
|
||||
* 1000
|
||||
)
|
||||
if end_date:
|
||||
end_ts = int(
|
||||
datetime.fromisoformat(end_date.replace("Z", "+00:00")).timestamp()
|
||||
* 1000
|
||||
)
|
||||
|
||||
klines = await self.ave_client.get_klines(
|
||||
token_id=token_id,
|
||||
interval=timeframe,
|
||||
limit=1000,
|
||||
start_time=start_ts,
|
||||
end_time=end_ts,
|
||||
)
|
||||
|
||||
if not klines:
|
||||
self.status = "failed"
|
||||
self.results = {"error": "No kline data available"}
|
||||
return self.results
|
||||
|
||||
await self._process_klines(klines)
|
||||
self._calculate_metrics()
|
||||
self.status = "completed"
|
||||
|
||||
except Exception as e:
|
||||
self.status = "failed"
|
||||
self.results = {"error": str(e)}
|
||||
|
||||
ended_at = datetime.utcnow()
|
||||
self.results = self.results or {}
|
||||
self.results["started_at"] = started_at
|
||||
self.results["ended_at"] = ended_at
|
||||
self.results["duration_seconds"] = (ended_at - started_at).total_seconds()
|
||||
|
||||
return self.results
|
||||
|
||||
async def _process_klines(self, klines: List[Dict[str, Any]]):
|
||||
for i, kline in enumerate(klines):
|
||||
if not self.running:
|
||||
break
|
||||
|
||||
price = float(kline.get("close", 0))
|
||||
if price <= 0:
|
||||
continue
|
||||
|
||||
timestamp = kline.get("timestamp", 0)
|
||||
|
||||
for condition in self.conditions:
|
||||
if self._check_condition(condition, klines, i, price):
|
||||
await self._execute_actions(price, timestamp, condition)
|
||||
break
|
||||
|
||||
def _check_condition(
|
||||
self,
|
||||
condition: Dict[str, Any],
|
||||
klines: List[Dict[str, Any]],
|
||||
current_idx: int,
|
||||
current_price: float,
|
||||
) -> bool:
|
||||
cond_type = condition.get("type", "")
|
||||
threshold = condition.get("threshold", 0)
|
||||
timeframe = condition.get("timeframe", "1h")
|
||||
price_level = condition.get("price")
|
||||
direction = condition.get("direction", "above")
|
||||
|
||||
if cond_type == "price_drop":
|
||||
if current_idx == 0:
|
||||
return False
|
||||
prev_price = float(klines[current_idx - 1].get("close", 0))
|
||||
if prev_price <= 0:
|
||||
return False
|
||||
drop_pct = ((prev_price - current_price) / prev_price) * 100
|
||||
return drop_pct >= threshold
|
||||
|
||||
elif cond_type == "price_rise":
|
||||
if current_idx == 0:
|
||||
return False
|
||||
prev_price = float(klines[current_idx - 1].get("close", 0))
|
||||
if prev_price <= 0:
|
||||
return False
|
||||
rise_pct = ((current_price - prev_price) / prev_price) * 100
|
||||
return rise_pct >= threshold
|
||||
|
||||
elif cond_type == "volume_spike":
|
||||
if current_idx == 0:
|
||||
return False
|
||||
prev_volume = float(klines[current_idx - 1].get("volume", 0))
|
||||
current_volume = float(kline.get("volume", 0))
|
||||
if prev_volume <= 0:
|
||||
return False
|
||||
volume_increase = ((current_volume - prev_volume) / prev_volume) * 100
|
||||
return volume_increase >= threshold
|
||||
|
||||
elif cond_type == "price_level":
|
||||
if price_level is None:
|
||||
return False
|
||||
if direction == "above":
|
||||
return current_price > price_level
|
||||
else:
|
||||
return current_price < price_level
|
||||
|
||||
return False
|
||||
|
||||
async def _execute_actions(
|
||||
self, price: float, timestamp: int, matched_condition: Dict[str, Any]
|
||||
):
|
||||
token = matched_condition.get("token", self.config.get("token", ""))
|
||||
|
||||
for action in self.actions:
|
||||
action_type = action.get("type", "")
|
||||
amount_percent = action.get("amount_percent", 10)
|
||||
amount = self.current_balance * (amount_percent / 100)
|
||||
|
||||
if action_type == "buy" and self.current_balance >= amount:
|
||||
self.position += amount / price
|
||||
self.current_balance -= amount
|
||||
self.position_token = token
|
||||
self.trades.append(
|
||||
{
|
||||
"type": "buy",
|
||||
"token": token,
|
||||
"price": price,
|
||||
"amount": amount,
|
||||
"quantity": amount / price,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
)
|
||||
self.signals.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"bot_id": self.bot_id,
|
||||
"run_id": self.run_id,
|
||||
"signal_type": "buy",
|
||||
"token": token,
|
||||
"price": price,
|
||||
"confidence": 0.8,
|
||||
"reasoning": f"Condition {matched_condition.get('type')} triggered buy",
|
||||
"executed": False,
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
)
|
||||
|
||||
elif action_type == "sell" and self.position > 0:
|
||||
sell_amount = self.position * price
|
||||
self.current_balance += sell_amount
|
||||
self.trades.append(
|
||||
{
|
||||
"type": "sell",
|
||||
"token": self.position_token,
|
||||
"price": price,
|
||||
"amount": sell_amount,
|
||||
"quantity": self.position,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
)
|
||||
self.position = 0
|
||||
self.signals.append(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"bot_id": self.bot_id,
|
||||
"run_id": self.run_id,
|
||||
"signal_type": "sell",
|
||||
"token": self.position_token,
|
||||
"price": price,
|
||||
"confidence": 0.8,
|
||||
"reasoning": f"Condition {matched_condition.get('type')} triggered sell",
|
||||
"executed": False,
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
)
|
||||
|
||||
def _calculate_metrics(self):
|
||||
final_balance = self.current_balance + (
|
||||
self.position * self.trades[-1]["price"]
|
||||
if self.trades and self.position > 0
|
||||
else 0
|
||||
)
|
||||
total_return = (
|
||||
(final_balance - self.initial_balance) / self.initial_balance
|
||||
) * 100
|
||||
|
||||
buy_trades = [t for t in self.trades if t["type"] == "buy"]
|
||||
sell_trades = [t for t in self.trades if t["type"] == "sell"]
|
||||
total_trades = len(buy_trades) + len(sell_trades)
|
||||
|
||||
winning_trades = 0
|
||||
for i, trade in enumerate(sell_trades):
|
||||
if i < len(buy_trades):
|
||||
buy_price = buy_trades[i]["price"]
|
||||
sell_price = trade["price"]
|
||||
if sell_price > buy_price:
|
||||
winning_trades += 1
|
||||
|
||||
win_rate = (winning_trades / len(sell_trades) * 100) if sell_trades else 0
|
||||
|
||||
portfolio_values = []
|
||||
running_balance = self.initial_balance
|
||||
running_position = 0.0
|
||||
current_token = ""
|
||||
last_price = 0.0
|
||||
|
||||
for trade in self.trades:
|
||||
if trade["type"] == "buy":
|
||||
running_position = trade["quantity"]
|
||||
running_balance = trade["amount"]
|
||||
current_token = trade["token"]
|
||||
last_price = trade["price"]
|
||||
else:
|
||||
running_balance = trade["amount"]
|
||||
running_position = 0
|
||||
last_price = trade["price"]
|
||||
|
||||
portfolio_value = running_balance + (running_position * last_price)
|
||||
portfolio_values.append(portfolio_value)
|
||||
|
||||
max_value = self.initial_balance
|
||||
max_drawdown = 0.0
|
||||
for value in portfolio_values:
|
||||
if value > max_value:
|
||||
max_value = value
|
||||
drawdown = ((max_value - value) / max_value) * 100
|
||||
if drawdown > max_drawdown:
|
||||
max_drawdown = drawdown
|
||||
|
||||
sharpe_ratio = 0.0
|
||||
if len(portfolio_values) > 1:
|
||||
returns = []
|
||||
for i in range(1, len(portfolio_values)):
|
||||
ret = (
|
||||
portfolio_values[i] - portfolio_values[i - 1]
|
||||
) / portfolio_values[i - 1]
|
||||
returns.append(ret)
|
||||
if returns:
|
||||
avg_return = sum(returns) / len(returns)
|
||||
variance = sum((r - avg_return) ** 2 for r in returns) / len(returns)
|
||||
std_dev = variance**0.5
|
||||
if std_dev > 0:
|
||||
sharpe_ratio = avg_return / std_dev
|
||||
|
||||
buy_signals = len(buy_trades)
|
||||
sell_signals = len(sell_trades)
|
||||
|
||||
self.results = {
|
||||
"total_return": round(total_return, 2),
|
||||
"win_rate": round(win_rate, 2),
|
||||
"total_trades": total_trades,
|
||||
"buy_signals": buy_signals,
|
||||
"sell_signals": sell_signals,
|
||||
"max_drawdown": round(max_drawdown, 2),
|
||||
"sharpe_ratio": round(sharpe_ratio, 2),
|
||||
"final_balance": round(final_balance, 2),
|
||||
"signals": self.signals,
|
||||
}
|
||||
|
||||
async def stop(self):
|
||||
raise NotImplementedError("Backtest stop not yet implemented")
|
||||
self.running = False
|
||||
self.status = "stopped"
|
||||
self._calculate_metrics()
|
||||
|
||||
def get_results(self) -> Dict[str, Any]:
|
||||
raise NotImplementedError("Backtest results not yet implemented")
|
||||
return {
|
||||
"id": self.run_id,
|
||||
"status": self.status,
|
||||
"results": self.results,
|
||||
"signals": self.signals,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user