Files
claude-mem/src/services/worker/GeminiAgent.ts
T
bigphoot f837a9eb77 feat: multi-turn conversations and Claude fallback for Gemini provider
Major improvements to Gemini provider:

**Shared Conversation History**
- Add ConversationMessage interface for provider-agnostic history
- Both Claude and Gemini agents read/write shared conversationHistory
- Context persists across provider switches via claudeSessionId linkage

**Multi-Turn Gemini API**
- Replace stateless single-query with full conversation context
- queryGeminiMultiTurn() sends entire history for coherent responses
- Maps 'assistant' role to 'model' for Gemini API compatibility

**Automatic Fallback to Claude**
- Detect rate limits (429), server errors (5xx), network failures
- Fall back to Claude SDK when Gemini API fails
- Reset 'processing' messages to 'pending' before fallback

**Mid-Session Provider Switching**
- Track currentProvider on ActiveSession
- Provider changes take effect after current generator finishes
- Avoids race conditions from aborting active generators

Files changed:
- worker-types.ts: Add ConversationMessage, currentProvider tracking
- GeminiAgent.ts: Multi-turn queries, fallback logic
- SDKAgent.ts: Capture messages to shared history
- SessionManager.ts: Initialize new session fields
- SessionRoutes.ts: Provider selection and switching logic
- worker-service.ts: Wire up fallback agent dependency

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-24 11:03:14 -08:00

492 lines
17 KiB
TypeScript

/**
* GeminiAgent: Gemini-based observation extraction
*
* Alternative to SDKAgent that uses Google's Gemini API directly
* for extracting observations from tool usage.
*
* Responsibility:
* - Call Gemini REST API for observation extraction
* - Parse XML responses (same format as Claude)
* - Sync to database and Chroma
*/
import path from 'path';
import { homedir } from 'os';
import { DatabaseManager } from './DatabaseManager.js';
import { SessionManager } from './SessionManager.js';
import { logger } from '../../utils/logger.js';
import { parseObservations, parseSummary } from '../../sdk/parser.js';
import { buildInitPrompt, buildObservationPrompt, buildSummaryPrompt, buildContinuationPrompt } from '../../sdk/prompts.js';
import { SettingsDefaultsManager } from '../../shared/SettingsDefaultsManager.js';
import { USER_SETTINGS_PATH } from '../../shared/paths.js';
import type { ActiveSession, PendingMessage, ConversationMessage } from '../worker-types.js';
import { ModeManager } from '../domain/ModeManager.js';
/**
 * Base URL for Gemini REST model endpoints. Requests append `/<model>:generateContent`.
 */
const GEMINI_API_URL = 'https://generativelanguage.googleapis.com/v1beta/models';

/**
 * Gemini model identifiers known to this module.
 * NOTE(review): the configured model string from settings is cast to this type without validation.
 */
export type GeminiModel =
  | 'gemini-2.0-flash-exp'
  | 'gemini-1.5-flash'
  | 'gemini-1.5-pro';
/**
 * The subset of a Gemini `generateContent` response body that this module reads:
 * candidate text parts and token-usage counters. Every field is optional because
 * the API may omit any of them.
 */
type GeminiResponse = {
  candidates?: {
    content?: {
      parts?: { text?: string }[];
    };
  }[];
  usageMetadata?: {
    promptTokenCount?: number;
    candidatesTokenCount?: number;
    totalTokenCount?: number;
  };
};
/**
 * A single turn in Gemini's `contents` array.
 * Gemini calls the assistant side "model", so the only roles are "user" and "model".
 */
type GeminiContent = {
  role: 'user' | 'model';
  parts: { text: string }[];
};
/**
 * Structural contract for the agent that takes over a session when Gemini fails.
 * Declared locally instead of importing the concrete agent class — presumably to
 * avoid a circular import (see setFallbackAgent).
 */
interface FallbackAgent {
  startSession(session: ActiveSession, worker?: any): Promise<void>;
}
/**
 * Extracts observations and summaries for a session by driving Google's Gemini
 * `generateContent` REST endpoint with the session's full conversation history,
 * then persisting the parsed results (SQLite store, Chroma sync, SSE broadcast).
 *
 * When a Gemini call fails with a transient error (HTTP 429, any 5xx, or a
 * network-level failure) and a fallback agent has been registered, the session
 * is handed over to that agent instead of failing.
 */
export class GeminiAgent {
  private dbManager: DatabaseManager;
  private sessionManager: SessionManager;
  private fallbackAgent: FallbackAgent | null = null;

  constructor(dbManager: DatabaseManager, sessionManager: SessionManager) {
    this.dbManager = dbManager;
    this.sessionManager = sessionManager;
  }

  /**
   * Set the fallback agent (Claude SDK) used when the Gemini API fails.
   * Must be set after construction to avoid a circular dependency.
   */
  setFallbackAgent(agent: FallbackAgent): void {
    this.fallbackAgent = agent;
  }

  /**
   * Decide whether an error should hand the session over to the fallback agent.
   *
   * HTTP failures raised by queryGeminiMultiTurn carry a numeric `status`. They fall
   * back on 429 (rate limit) or any 5xx. We classify by status rather than by
   * substring-matching the message, because the message also embeds the response
   * body, which can contain arbitrary digits. Errors without a status fall back only
   * when they look like network-level failures.
   */
  private shouldFallbackToClaude(error: any): boolean {
    const status: unknown = error?.status;
    if (typeof status === 'number') {
      return status === 429 || (status >= 500 && status <= 599);
    }
    const message: string = typeof error?.message === 'string' ? error.message : '';
    return (
      message.includes('ECONNREFUSED') ||
      message.includes('ETIMEDOUT') ||
      message.includes('fetch failed')
    );
  }

  /**
   * Run the Gemini agent for a session.
   *
   * Sends the init prompt (first prompt) or a continuation prompt. It then consumes
   * the session's pending-message iterator, turning each `observation` or `summarize`
   * message into one multi-turn Gemini query. The session is marked completed once
   * the iterator is exhausted.
   *
   * @throws AbortError unchanged. Transient Gemini failures are delegated to the
   *   fallback agent when one is set; any other error is logged and rethrown.
   */
  async startSession(session: ActiveSession, worker?: any): Promise<void> {
    try {
      const { apiKey, model } = this.getGeminiConfig();
      if (!apiKey) {
        throw new Error('Gemini API key not configured. Set CLAUDE_MEM_GEMINI_API_KEY in settings or GEMINI_API_KEY environment variable.');
      }

      const mode = ModeManager.getInstance().getActiveMode();

      // The first prompt of a session gets the full init prompt; later prompts continue it.
      const initPrompt = session.lastPromptNumber === 1
        ? buildInitPrompt(session.project, session.claudeSessionId, session.userPrompt, mode)
        : buildContinuationPrompt(session.userPrompt, session.lastPromptNumber, session.claudeSessionId, mode);
      await this.runGeminiTurn(session, initPrompt, apiKey, model, worker);

      for await (const message of this.sessionManager.getMessageIterator(session.sessionDbId)) {
        if (message.type === 'observation') {
          if (message.prompt_number !== undefined) {
            session.lastPromptNumber = message.prompt_number;
          }
          const obsPrompt = buildObservationPrompt({
            id: 0,
            tool_name: message.tool_name!,
            tool_input: JSON.stringify(message.tool_input),
            tool_output: JSON.stringify(message.tool_response),
            created_at_epoch: Date.now(),
            cwd: message.cwd
          });
          await this.runGeminiTurn(session, obsPrompt, apiKey, model, worker);
        } else if (message.type === 'summarize') {
          const summaryPrompt = buildSummaryPrompt({
            id: session.sessionDbId,
            sdk_session_id: session.sdkSessionId,
            project: session.project,
            user_prompt: session.userPrompt,
            last_user_message: message.last_user_message || '',
            last_assistant_message: message.last_assistant_message || ''
          }, mode);
          await this.runGeminiTurn(session, summaryPrompt, apiKey, model, worker);
        }
      }

      const sessionDuration = Date.now() - session.startTime;
      logger.success('SDK', 'Gemini agent completed', {
        sessionId: session.sessionDbId,
        duration: `${(sessionDuration / 1000).toFixed(1)}s`,
        historyLength: session.conversationHistory.length
      });
      this.dbManager.getSessionStore().markSessionCompleted(session.sessionDbId);
    } catch (error: any) {
      if (error.name === 'AbortError') {
        logger.warn('SDK', 'Gemini agent aborted', { sessionId: session.sessionDbId });
        throw error;
      }

      if (this.shouldFallbackToClaude(error) && this.fallbackAgent) {
        logger.warn('SDK', 'Gemini API failed, falling back to Claude SDK', {
          sessionDbId: session.sessionDbId,
          error: error.message,
          historyLength: session.conversationHistory.length
        });

        // Reset 'processing' messages back to 'pending' so the fallback agent can retry
        // anything Gemini was in the middle of handling.
        // NOTE(review): 0 resets ALL processing messages. If the store is shared across
        // sessions, this may also reset other sessions' in-flight messages — consider a
        // session-scoped reset.
        const pendingStore = this.sessionManager.getPendingMessageStore();
        const resetCount = pendingStore.resetStuckMessages(0);
        if (resetCount > 0) {
          logger.info('SDK', 'Reset processing messages for fallback', {
            sessionDbId: session.sessionDbId,
            resetCount
          });
        }

        // The fallback agent continues the same session object, including its shared conversationHistory.
        return this.fallbackAgent.startSession(session, worker);
      }

      logger.failure('SDK', 'Gemini agent error', { sessionDbId: session.sessionDbId }, error);
      throw error;
    }
  }

  /**
   * Execute one conversational turn.
   *
   * Appends `prompt` as a user message and queries Gemini with the whole history. If
   * the reply is non-empty, it records the reply, updates the session's token counters
   * and persists the parsed observations/summary. An empty reply leaves only the user
   * message in history and skips processing.
   */
  private async runGeminiTurn(
    session: ActiveSession,
    prompt: string,
    apiKey: string,
    model: GeminiModel,
    worker: any | undefined
  ): Promise<void> {
    session.conversationHistory.push({ role: 'user', content: prompt });
    const response = await this.queryGeminiMultiTurn(session.conversationHistory, apiKey, model);
    if (!response.content) {
      return;
    }
    session.conversationHistory.push({ role: 'assistant', content: response.content });

    const tokensUsed = response.tokensUsed || 0;
    if (response.inputTokens !== undefined && response.outputTokens !== undefined) {
      // Prefer the exact per-direction counts reported in usageMetadata.
      session.cumulativeInputTokens += response.inputTokens;
      session.cumulativeOutputTokens += response.outputTokens;
    } else {
      // No breakdown reported: fall back to a rough 70/30 split of the total.
      session.cumulativeInputTokens += Math.floor(tokensUsed * 0.7);
      session.cumulativeOutputTokens += Math.floor(tokensUsed * 0.3);
    }

    await this.processGeminiResponse(session, response.content, worker, tokensUsed);
  }

  /**
   * Convert the shared ConversationMessage history into Gemini's `contents` format.
   * Gemini names the assistant role 'model'; every other role is sent as 'user'.
   */
  private conversationToGeminiContents(history: ConversationMessage[]): GeminiContent[] {
    return history.map(msg => ({
      role: msg.role === 'assistant' ? 'model' : 'user',
      parts: [{ text: msg.content }]
    }));
  }

  /**
   * Query Gemini's REST API with the full conversation history (multi-turn).
   *
   * @returns the concatenated text of the first candidate ('' when there is none),
   *   the total token count, and the prompt/candidate token counts when reported.
   * @throws Error with a numeric `status` property when the HTTP response is not OK;
   *   fetch's own errors (network failures) propagate unchanged.
   */
  private async queryGeminiMultiTurn(
    history: ConversationMessage[],
    apiKey: string,
    model: GeminiModel
  ): Promise<{ content: string; tokensUsed?: number; inputTokens?: number; outputTokens?: number }> {
    const contents = this.conversationToGeminiContents(history);
    const totalChars = history.reduce((sum, m) => sum + m.content.length, 0);

    logger.debug('SDK', `Querying Gemini multi-turn (${model})`, {
      turns: history.length,
      totalChars
    });

    // The key goes in the x-goog-api-key header instead of the query string, so it never
    // appears in request URLs (proxy/access logs, errors that echo the URL).
    const url = `${GEMINI_API_URL}/${model}:generateContent`;
    const response = await fetch(url, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'x-goog-api-key': apiKey,
      },
      body: JSON.stringify({
        contents,
        generationConfig: {
          temperature: 0.3, // Lower temperature for structured extraction
          maxOutputTokens: 4096,
        },
      }),
    });

    if (!response.ok) {
      const body = await response.text();
      // Attach the status so shouldFallbackToClaude can classify without parsing the message.
      throw Object.assign(new Error(`Gemini API error: ${response.status} - ${body}`), {
        status: response.status
      });
    }

    const data = await response.json() as GeminiResponse;
    // A candidate may split its text across several parts; join them all.
    const parts = data.candidates?.[0]?.content?.parts ?? [];
    const content = parts.map(part => part.text ?? '').join('');
    if (!content) {
      logger.warn('SDK', 'Empty response from Gemini');
      return { content: '' };
    }

    const usage = data.usageMetadata;
    return {
      content,
      tokensUsed: usage?.totalTokenCount,
      inputTokens: usage?.promptTokenCount,
      outputTokens: usage?.candidatesTokenCount
    };
  }

  /**
   * Persist a Gemini reply. Observations and a summary are parsed with the same XML
   * parsers used for Claude output. Each one is stored in the session store, synced to
   * Chroma (fire-and-forget; failures only logged) and broadcast to SSE clients when
   * the worker exposes `sseBroadcaster`. Finally, the session's in-flight pending
   * messages are marked processed.
   */
  private async processGeminiResponse(
    session: ActiveSession,
    text: string,
    worker: any | undefined,
    discoveryTokens: number
  ): Promise<void> {
    const observations = parseObservations(text, session.claudeSessionId);

    for (const obs of observations) {
      const { id: obsId, createdAtEpoch } = this.dbManager.getSessionStore().storeObservation(
        session.claudeSessionId,
        session.project,
        obs,
        session.lastPromptNumber,
        discoveryTokens
      );

      logger.info('SDK', 'Gemini observation saved', {
        sessionId: session.sessionDbId,
        obsId,
        type: obs.type,
        title: obs.title || '(untitled)'
      });

      this.dbManager.getChromaSync().syncObservation(
        obsId,
        session.claudeSessionId,
        session.project,
        obs,
        session.lastPromptNumber,
        createdAtEpoch,
        discoveryTokens
      ).catch(err => {
        logger.warn('SDK', 'Gemini chroma sync failed', { obsId }, err);
      });

      if (worker && worker.sseBroadcaster) {
        worker.sseBroadcaster.broadcast({
          type: 'new_observation',
          observation: {
            id: obsId,
            sdk_session_id: session.sdkSessionId,
            session_id: session.claudeSessionId,
            type: obs.type,
            title: obs.title,
            subtitle: obs.subtitle,
            text: null,
            narrative: obs.narrative || null,
            facts: JSON.stringify(obs.facts || []),
            concepts: JSON.stringify(obs.concepts || []),
            files_read: JSON.stringify(obs.files_read || []),
            files_modified: JSON.stringify(obs.files_modified || []),
            project: session.project,
            prompt_number: session.lastPromptNumber,
            created_at_epoch: createdAtEpoch
          }
        });
      }
    }

    const summary = parseSummary(text, session.sessionDbId);
    if (summary) {
      // storeSummary takes strings, so nullable fields become '' (notes passed through as-is).
      const summaryForStore = {
        request: summary.request || '',
        investigated: summary.investigated || '',
        learned: summary.learned || '',
        completed: summary.completed || '',
        next_steps: summary.next_steps || '',
        notes: summary.notes
      };

      const { id: summaryId, createdAtEpoch } = this.dbManager.getSessionStore().storeSummary(
        session.claudeSessionId,
        session.project,
        summaryForStore,
        session.lastPromptNumber,
        discoveryTokens
      );

      logger.info('SDK', 'Gemini summary saved', {
        sessionId: session.sessionDbId,
        summaryId,
        request: summary.request || '(no request)'
      });

      this.dbManager.getChromaSync().syncSummary(
        summaryId,
        session.claudeSessionId,
        session.project,
        summaryForStore,
        session.lastPromptNumber,
        createdAtEpoch,
        discoveryTokens
      ).catch(err => {
        logger.warn('SDK', 'Gemini chroma sync failed', { summaryId }, err);
      });

      if (worker && worker.sseBroadcaster) {
        worker.sseBroadcaster.broadcast({
          type: 'new_summary',
          summary: {
            id: summaryId,
            session_id: session.claudeSessionId,
            request: summary.request,
            investigated: summary.investigated,
            learned: summary.learned,
            completed: summary.completed,
            next_steps: summary.next_steps,
            notes: summary.notes,
            project: session.project,
            prompt_number: session.lastPromptNumber,
            created_at_epoch: createdAtEpoch
          }
        });
      }
    }

    await this.markMessagesProcessed(session, worker);
  }

  /**
   * Mark every id in `session.pendingProcessingIds` as processed, clear the set, and
   * prune old processed rows (cleanupProcessed(100)). Afterwards, notify the worker's
   * processing-status broadcaster if it has one.
   */
  private async markMessagesProcessed(session: ActiveSession, worker: any | undefined): Promise<void> {
    const pendingMessageStore = this.sessionManager.getPendingMessageStore();

    if (session.pendingProcessingIds.size > 0) {
      for (const messageId of session.pendingProcessingIds) {
        pendingMessageStore.markProcessed(messageId);
      }
      logger.debug('SDK', 'Gemini messages marked as processed', {
        sessionId: session.sessionDbId,
        count: session.pendingProcessingIds.size
      });
      session.pendingProcessingIds.clear();

      const deletedCount = pendingMessageStore.cleanupProcessed(100);
      if (deletedCount > 0) {
        logger.debug('SDK', 'Gemini cleaned up old processed messages', { deletedCount });
      }
    }

    if (worker && typeof worker.broadcastProcessingStatus === 'function') {
      worker.broadcastProcessingStatus();
    }
  }

  /**
   * Read the Gemini API key and model.
   *
   * The key comes from settings (CLAUDE_MEM_GEMINI_API_KEY), falling back to the
   * GEMINI_API_KEY env var, then ''. The model comes from CLAUDE_MEM_GEMINI_MODEL and
   * defaults to 'gemini-2.0-flash-exp'; it is not validated against GeminiModel.
   * NOTE(review): the file imports USER_SETTINGS_PATH but builds the settings path by
   * hand — confirm they agree and consider using the shared constant.
   */
  private getGeminiConfig(): { apiKey: string; model: GeminiModel } {
    const settingsPath = path.join(homedir(), '.claude-mem', 'settings.json');
    const settings = SettingsDefaultsManager.loadFromFile(settingsPath);

    const apiKey = settings.CLAUDE_MEM_GEMINI_API_KEY || process.env.GEMINI_API_KEY || '';
    const model = (settings.CLAUDE_MEM_GEMINI_MODEL || 'gemini-2.0-flash-exp') as GeminiModel;

    return { apiKey, model };
  }
}
/**
 * Report whether a Gemini API key is configured — either CLAUDE_MEM_GEMINI_API_KEY
 * in ~/.claude-mem/settings.json or the GEMINI_API_KEY environment variable.
 */
export function isGeminiAvailable(): boolean {
  const { CLAUDE_MEM_GEMINI_API_KEY: configuredKey } = SettingsDefaultsManager.loadFromFile(
    path.join(homedir(), '.claude-mem', 'settings.json')
  );
  const envKey = process.env.GEMINI_API_KEY;
  return Boolean(configuredKey || envKey);
}
/**
 * Report whether the CLAUDE_MEM_PROVIDER setting in ~/.claude-mem/settings.json is exactly 'gemini'.
 */
export function isGeminiSelected(): boolean {
  const settingsFile = path.join(homedir(), '.claude-mem', 'settings.json');
  const { CLAUDE_MEM_PROVIDER: provider } = SettingsDefaultsManager.loadFromFile(settingsFile);
  return provider === 'gemini';
}