/**
|
|
* Level 3b: Context Management
|
|
*
|
|
* Features:
|
|
* 1. Context pruning - Remove old messages when too long
|
|
* 2. Context compression - Summarize old messages
|
|
* 3. Token estimation
|
|
* 4. Configurable limits
|
|
*/
|
|
|
|
import { Agent, type AgentTool, type AgentMessage, type AgentEvent } from "@mariozechner/pi-agent-core";
|
|
import { registerBuiltInApiProviders } from "@mariozechner/pi-ai";
|
|
import type { Model } from "@mariozechner/pi-ai";
|
|
import * as fs from "fs";
|
|
import * as path from "path";
|
|
import { exec } from "child_process";
|
|
|
|
// ============== CONFIG ==============
|
|
const OPENROUTER_API_KEY = process.env.OPENROUTER_API_KEY || "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";
|
|
process.env.OPENROUTER_API_KEY = OPENROUTER_API_KEY;
|
|
|
|
const model: Model<"openai-responses"> = {
|
|
id: "stepfun/step-3.5-flash:free",
|
|
name: "Step-3.5 Flash (Free)",
|
|
api: "openai-responses",
|
|
provider: "openrouter",
|
|
baseUrl: "https://openrouter.ai/api/v1",
|
|
reasoning: false,
|
|
input: ["text"],
|
|
output: ["text"],
|
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
|
contextWindow: 128000,
|
|
maxTokens: 8192,
|
|
};
|
|
|
|
// ============== CONTEXT MANAGER ==============
|
|
// Tuning knobs for ContextManager. All fields are optional; defaults are
// applied in the ContextManager constructor.
interface ContextConfig {
	maxTokens?: number; // Hard cap — transform() force-prunes above this (default 100000)
	pruneThreshold?: number; // When to start pruning (default 80000)
	keepRecent?: number; // How many recent messages to always keep (default 10)
	compressionEnabled?: boolean; // Route overflow through compress() first (default false)
}
|
|
|
|
// AgentMessage extended with an optional cached token-count estimate.
// NOTE(review): _tokens is never read or written anywhere in this file —
// presumably reserved for caching estimateTokens() results; confirm before
// relying on it or removing it.
interface MessageWithTokens extends AgentMessage {
	_tokens?: number;
}
|
|
|
|
class ContextManager {
|
|
private maxTokens: number;
|
|
private pruneThreshold: number;
|
|
private keepRecent: number;
|
|
private compressionEnabled: boolean;
|
|
|
|
// Stats
|
|
private pruneCount = 0;
|
|
private compressCount = 0;
|
|
|
|
constructor(config: ContextConfig = {}) {
|
|
this.maxTokens = config.maxTokens || 100000; // Default 100k
|
|
this.pruneThreshold = config.pruneThreshold || 80000; // Start pruning at 80k
|
|
this.keepRecent = config.keepRecent || 10; // Keep last 10 messages
|
|
this.compressionEnabled = config.compressionEnabled || false;
|
|
}
|
|
|
|
// Estimate tokens (rough approximation: 1 token ≈ 4 characters)
|
|
estimateTokens(message: AgentMessage): number {
|
|
const msg = message as any;
|
|
let text = "";
|
|
|
|
if (typeof msg.content === "string") {
|
|
text = msg.content;
|
|
} else if (Array.isArray(msg.content)) {
|
|
for (const block of msg.content) {
|
|
if (block.type === "text") {
|
|
text += block.text || "";
|
|
}
|
|
}
|
|
}
|
|
|
|
// Rough estimate: 1 token ≈ 4 characters
|
|
return Math.ceil(text.length / 4);
|
|
}
|
|
|
|
// Calculate total tokens in messages
|
|
calculateTotalTokens(messages: AgentMessage[]): number {
|
|
return messages.reduce((sum, msg) => sum + this.estimateTokens(msg), 0);
|
|
}
|
|
|
|
// Prune old messages
|
|
prune(messages: AgentMessage[]): AgentMessage[] {
|
|
const total = this.calculateTotalTokens(messages);
|
|
|
|
if (total < this.pruneThreshold) {
|
|
return messages; // No pruning needed
|
|
}
|
|
|
|
console.log(`✂️ Pruning context: ${total} tokens > ${this.pruneThreshold} threshold`);
|
|
|
|
// Keep system prompt (first message) if it's a system message
|
|
let result: AgentMessage[] = [];
|
|
if (messages.length > 0 && (messages[0] as any).role === "system") {
|
|
result.push(messages[0]);
|
|
}
|
|
|
|
// Keep recent messages
|
|
const recent = messages.slice(-this.keepRecent);
|
|
result = result.concat(recent);
|
|
|
|
// Add summary placeholder if we removed middle messages
|
|
const removed = messages.length - result.length;
|
|
if (removed > 1) {
|
|
const summaryMsg: AgentMessage = {
|
|
role: "user",
|
|
content: [{ type: "text", text: `[Context: ${removed} older messages removed for brevity]` }],
|
|
timestamp: Date.now(),
|
|
};
|
|
result.splice(1, 0, summaryMsg); // Insert after system prompt
|
|
}
|
|
|
|
const newTotal = this.calculateTotalTokens(result);
|
|
this.pruneCount++;
|
|
|
|
console.log(`✂️ Pruned: ${messages.length} → ${result.length} messages`);
|
|
console.log(`✂️ Tokens: ${total} → ${newTotal}`);
|
|
console.log(`✂️ (Total prunes: ${this.pruneCount})`);
|
|
|
|
return result;
|
|
}
|
|
|
|
// Compress messages (placeholder - would need LLM for real compression)
|
|
compress(messages: AgentMessage[]): AgentMessage[] {
|
|
// This is a simplified version - real compression would use an LLM
|
|
console.log(`📦 Compression requested (${messages.length} messages)`);
|
|
|
|
// For now, just prune
|
|
this.compressCount++;
|
|
return this.prune(messages);
|
|
}
|
|
|
|
// Transform context - call this before sending to LLM
|
|
transform(messages: AgentMessage[]): AgentMessage[] {
|
|
const total = this.calculateTotalTokens(messages);
|
|
|
|
if (total > this.maxTokens) {
|
|
console.log(`⚠️ Context overflow: ${total} > ${this.maxTokens}, forcing prune`);
|
|
return this.prune(messages);
|
|
}
|
|
|
|
if (total > this.pruneThreshold && this.compressionEnabled) {
|
|
return this.compress(messages);
|
|
}
|
|
|
|
if (total > this.pruneThreshold) {
|
|
return this.prune(messages);
|
|
}
|
|
|
|
return messages;
|
|
}
|
|
|
|
getStats() {
|
|
return {
|
|
maxTokens: this.maxTokens,
|
|
pruneThreshold: this.pruneThreshold,
|
|
keepRecent: this.keepRecent,
|
|
compressionEnabled: this.compressionEnabled,
|
|
pruneCount: this.pruneCount,
|
|
compressCount: this.compressCount,
|
|
};
|
|
}
|
|
}
|
|
|
|
// ============== TOOLS ==============
|
|
function createTools(cwd: string = process.cwd()): AgentTool[] {
|
|
return [
|
|
{
|
|
name: "bash",
|
|
label: "Run Command",
|
|
description: "Run a shell command",
|
|
parameters: {
|
|
type: "object",
|
|
properties: {
|
|
command: { type: "string", description: "Command to run" },
|
|
},
|
|
required: ["command"],
|
|
} as const,
|
|
execute: async (toolCallId: string, params: { command: string }) => {
|
|
return new Promise((resolve) => {
|
|
exec(params.command, { cwd }, (error, stdout, stderr) => {
|
|
if (error) {
|
|
resolve({
|
|
content: [{ type: "text", text: stderr || error.message }],
|
|
details: { command: params.command, exitCode: error.code },
|
|
isError: true,
|
|
});
|
|
} else {
|
|
resolve({
|
|
content: [{ type: "text", text: stdout }],
|
|
details: { command: params.command, exitCode: 0 },
|
|
});
|
|
}
|
|
});
|
|
});
|
|
},
|
|
},
|
|
];
|
|
}
|
|
|
|
// ============== SHADOW WITH CONTEXT ==============
|
|
class ShadowWithContext {
|
|
private agent: Agent;
|
|
private contextManager: ContextManager;
|
|
public id: string;
|
|
public messageCount = 0;
|
|
|
|
constructor(id: string, worktreePath: string, contextConfig?: ContextConfig) {
|
|
this.id = id;
|
|
this.contextManager = new ContextManager(contextConfig);
|
|
|
|
this.agent = new Agent({
|
|
initialState: {
|
|
systemPrompt: "You are a helpful coding assistant. Be concise.",
|
|
model: model,
|
|
tools: createTools(worktreePath) as any,
|
|
messages: [],
|
|
},
|
|
convertToLlm: (messages: AgentMessage[]) => {
|
|
// Transform context before sending to LLM
|
|
const transformed = this.contextManager.transform(messages);
|
|
|
|
return transformed
|
|
.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult")
|
|
.map((m) => ({
|
|
role: m.role,
|
|
content: m.content,
|
|
}));
|
|
},
|
|
});
|
|
|
|
this.agent.subscribe((event) => {
|
|
if (event.type === "message_end") {
|
|
this.messageCount++;
|
|
}
|
|
});
|
|
}
|
|
|
|
async run(message: string): Promise<void> {
|
|
const msg: AgentMessage = {
|
|
role: "user",
|
|
content: [{ type: "text", text: message }],
|
|
timestamp: Date.now(),
|
|
};
|
|
|
|
await this.agent.prompt(msg);
|
|
}
|
|
|
|
getContextStats() {
|
|
return {
|
|
messageCount: this.messageCount,
|
|
contextManager: this.contextManager.getStats(),
|
|
};
|
|
}
|
|
}
|
|
|
|
// ============== TEST ==============
|
|
async function testContextPruning() {
|
|
console.log("\n" + "=".repeat(50));
|
|
console.log("TEST: Context Pruning");
|
|
console.log("=".repeat(50));
|
|
|
|
// Create shadow with aggressive pruning (for testing)
|
|
const shadow = new ShadowWithContext("test-1", "/tmp", {
|
|
maxTokens: 5000,
|
|
pruneThreshold: 2000,
|
|
keepRecent: 3,
|
|
compressionEnabled: false,
|
|
});
|
|
|
|
console.log("Context config:", shadow.getContextStats().contextManager);
|
|
|
|
// Simulate many messages to trigger pruning
|
|
const longText = "This is a test message with some content. ".repeat(50);
|
|
|
|
console.log("\n📝 Adding messages to trigger pruning...\n");
|
|
|
|
for (let i = 0; i < 15; i++) {
|
|
const msg: AgentMessage = {
|
|
role: "user",
|
|
content: [{ type: "text", text: `Message ${i}: ${longText}` }],
|
|
timestamp: Date.now(),
|
|
};
|
|
|
|
// Manually trigger context transform
|
|
const messages = Array(15).fill(null).map((_, j) => ({
|
|
role: j % 2 === 0 ? "user" as const : "assistant" as const,
|
|
content: [{ type: "text" as const, text: `Message ${j}: ${longText}` }],
|
|
timestamp: Date.now(),
|
|
}));
|
|
|
|
const transformed = (shadow as any).contextManager.transform(messages);
|
|
|
|
if (transformed.length < messages.length) {
|
|
console.log(`📊 After message ${i}: ${messages.length} → ${transformed.length} messages`);
|
|
}
|
|
}
|
|
|
|
console.log("\n📊 Final stats:", shadow.getContextStats());
|
|
}
|
|
|
|
async function testActualAgent() {
|
|
console.log("\n" + "=".repeat(50));
|
|
console.log("TEST: Actual Agent with Context Management");
|
|
console.log("=".repeat(50));
|
|
|
|
// Create with normal settings
|
|
const shadow = new ShadowWithContext("test-2", "/tmp", {
|
|
maxTokens: 50000,
|
|
pruneThreshold: 30000,
|
|
keepRecent: 10,
|
|
});
|
|
|
|
console.log("\n🚀 Running agent with context management...\n");
|
|
|
|
// Run multiple turns to build up context
|
|
await shadow.run("Say hello and run 'echo Hello 1'");
|
|
console.log("📊 After turn 1:", shadow.getContextStats());
|
|
|
|
await shadow.run("Say hi and run 'echo Hello 2'");
|
|
console.log("📊 After turn 2:", shadow.getContextStats());
|
|
|
|
await shadow.run("Run 'echo Hello 3'");
|
|
console.log("📊 After turn 3:", shadow.getContextStats());
|
|
|
|
await shadow.run("Run 'echo Hello 4'");
|
|
console.log("📊 After turn 4:", shadow.getContextStats());
|
|
|
|
await shadow.run("Run 'echo Hello 5'");
|
|
console.log("📊 After turn 5:", shadow.getContextStats());
|
|
|
|
console.log("\n✅ Agent test complete!");
|
|
console.log("📊 Final stats:", shadow.getContextStats());
|
|
}
|
|
|
|
// ============== MAIN ==============
|
|
// Entry point: register API providers, then run the synthetic pruning test
// followed by the live-agent test.
async function main() {
	console.log("🧪 Level 3b: Context Management\n");

	// NOTE(review): presumably providers must be registered before any agent
	// issues a request — confirm against pi-ai docs.
	registerBuiltInApiProviders();

	await testContextPruning();
	await testActualAgent();

	console.log("\n✅ All tests complete!");
}

// Log top-level failures instead of dying on an unhandled rejection.
main().catch(console.error);
|