#!/bin/bash
# claude-mem-sync — Synchronize claude-mem observations between machines
#
# Usage:
#   claude-mem-sync push <remote-host>    # local → remote
#   claude-mem-sync pull <remote-host>    # remote → local
#   claude-mem-sync sync <remote-host>    # bidirectional (push + pull)
#   claude-mem-sync status <remote-host>  # compare counts
#
# Prerequisites:
#   - SSH access to remote host (key-based auth recommended)
#   - Python 3 on both machines
#   - claude-mem installed on both machines (~/.claude-mem/claude-mem.db)
#
# Environment variables:
#   CLAUDE_MEM_DB         Local database path (default: ~/.claude-mem/claude-mem.db)
#   CLAUDE_MEM_REMOTE_DB  Remote database path (default: ~/.claude-mem/claude-mem.db)

# Strict mode: abort on a failing command (-e), on use of an unset variable (-u),
# and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Local database path; CLAUDE_MEM_DB overrides the default under the local $HOME.
LOCAL_DB="${CLAUDE_MEM_DB:-$HOME/.claude-mem/claude-mem.db}"
# Both positional arguments are mandatory: ${N:?message} prints the message to
# stderr and aborts the script when argument N is missing or empty.
COMMAND="${1:?Usage: claude-mem-sync <push|pull|sync|status> <remote-host>}"
REMOTE_HOST="${2:?Missing remote host. Usage: claude-mem-sync $COMMAND <remote-host>}"
# Remote database path. The backslash keeps a literal, unexpanded "$HOME" in the
# default so that it is expanded by the shell on the remote host (the value is
# embedded in the double-quoted commands passed to ssh), not by this local shell.
REMOTE_DB="${CLAUDE_MEM_REMOTE_DB:-\$HOME/.claude-mem/claude-mem.db}"
# Private scratch directory for the JSON export/import files.
# mktemp -d creates it atomically with an unpredictable name and owner-only
# permissions, so another local user can neither pre-create nor read it (a fixed
# /tmp/<name>-<pid> path made with mkdir -p allows both).
TMPDIR="$(mktemp -d "${TMPDIR:-/tmp}/claude-mem-sync.XXXXXX")" || {
    echo "claude-mem-sync: cannot create temporary directory" >&2
    exit 1
}

# Single quotes defer expansion until the trap fires and keep the path quoted,
# so a directory name containing spaces or glob characters is removed intact.
trap 'rm -rf -- "$TMPDIR"' EXIT

# Column lists for the observations and session_summaries tables.
# Each list is spliced verbatim into the SELECT and INSERT statements and is
# also split on "," to produce the JSON keys of every exported row, so export
# and import must use the same list. Import de-duplicates on created_at plus
# title (observations) or request (session_summaries), so those names must stay
# in the respective list.
OBS_COLS="memory_session_id,project,text,type,title,subtitle,facts,narrative,concepts,files_read,files_modified,prompt_number,discovery_tokens,created_at,created_at_epoch"
SUM_COLS="memory_session_id,project,request,investigated,learned,completed,next_steps,files_read,files_edited,notes,prompt_number,discovery_tokens,created_at,created_at_epoch"

#######################################
# Dump the observations and session_summaries tables of a SQLite database to a
# JSON file of the form {"observations": [...], "summaries": [...]}, where each
# row is an object keyed by the column names and rows are ordered by created_at.
# Globals:   OBS_COLS, SUM_COLS (read)
# Arguments: $1 - path to the SQLite database
#            $2 - path of the JSON file to write
# Outputs:   "<n> obs, <m> sums exported" on stderr
# Returns:   non-zero if the database is missing or the export fails
#######################################
export_obs() {
    local db="$1" output="$2"
    # sqlite3.connect() would silently create an empty database for a wrong
    # path; fail early with a clear message instead.
    if [[ ! -f "$db" ]]; then
        echo "claude-mem-sync: database not found: $db" >&2
        return 1
    fi
    # Paths and column lists travel as argv and the quoted heredoc delimiter
    # disables shell expansion, so a path containing quotes or other special
    # characters can never alter the Python source.
    python3 - "$db" "$output" "$OBS_COLS" "$SUM_COLS" <<'PY'
import contextlib
import json
import sqlite3
import sys

db_path, out_path, obs_csv, sum_csv = sys.argv[1:5]


def fetch(cur, csv, table):
    """Return every row of `table` as a dict keyed by the column names in `csv`."""
    cols = csv.split(',')
    cur.execute('SELECT ' + csv + ' FROM ' + table + ' ORDER BY created_at')
    return [dict(zip(cols, r)) for r in cur.fetchall()]


with contextlib.closing(sqlite3.connect(db_path)) as conn:
    cur = conn.cursor()
    rows = fetch(cur, obs_csv, 'observations')
    sums = fetch(cur, sum_csv, 'session_summaries')

# The with-block closes (and therefore flushes) the file even if dumping fails.
with open(out_path, 'w') as fh:
    json.dump({'observations': rows, 'summaries': sums}, fh)
print(f'{len(rows)} obs, {len(sums)} sums exported', file=sys.stderr)
PY
}

#######################################
# Merge a JSON export ({"observations": [...], "summaries": [...]}) into a
# SQLite database. An observation is inserted only when its (created_at, title)
# pair is not already in the table; a summary only when its
# (created_at, request) pair is not. Everything is committed in one transaction.
# Globals:   OBS_COLS, SUM_COLS (read)
# Arguments: $1 - path to the SQLite database to update
#            $2 - path of the JSON file to import
# Outputs:   "<n> new obs, <m> new sums imported" on stderr
# Returns:   non-zero if the database is missing or the import fails
#######################################
import_obs() {
    local db="$1" input="$2"
    # sqlite3.connect() would silently create an empty database for a wrong
    # path; fail early with a clear message instead.
    if [[ ! -f "$db" ]]; then
        echo "claude-mem-sync: database not found: $db" >&2
        return 1
    fi
    # Paths and column lists travel as argv and the quoted heredoc delimiter
    # disables shell expansion, so a path containing quotes or other special
    # characters can never alter the Python source.
    python3 - "$db" "$input" "$OBS_COLS" "$SUM_COLS" <<'PY'
import contextlib
import json
import sqlite3
import sys

db_path, in_path, obs_csv, sum_csv = sys.argv[1:5]

with open(in_path) as fh:
    data = json.load(fh)


def merge(cur, table, csv, key_col, rows):
    """Insert the rows whose (created_at, key_col) pair is new; return how many."""
    cols = csv.split(',')
    cur.execute('SELECT created_at, ' + key_col + ' FROM ' + table)
    # Snapshot taken before inserting: rows that share a key inside one import
    # file are all kept, exactly as they exist on the exporting side.
    existing = set((r[0], r[1]) for r in cur.fetchall())
    placeholders = ','.join(['?'] * len(cols))
    sql = 'INSERT INTO ' + table + ' (' + csv + ') VALUES (' + placeholders + ')'
    added = 0
    for row in rows:
        if (row['created_at'], row[key_col]) not in existing:
            cur.execute(sql, tuple(row[k] for k in cols))
            added += 1
    return added


with contextlib.closing(sqlite3.connect(db_path)) as conn:
    cur = conn.cursor()
    oi = merge(cur, 'observations', obs_csv, 'title', data['observations'])
    si = merge(cur, 'session_summaries', sum_csv, 'request', data['summaries'])
    conn.commit()
print(f'{oi} new obs, {si} new sums imported', file=sys.stderr)
PY
}

#######################################
# Print a one-line summary of a SQLite database: the number of rows in
# observations and session_summaries, and the largest observations.created_at.
# Arguments: $1 - path to the SQLite database
# Outputs:   "<n> obs, <m> sums (last: <created_at>)" on stdout
# Returns:   non-zero if the database is missing or a query fails
#######################################
count_db() {
    local db="$1"
    # sqlite3.connect() would silently create an empty database for a wrong
    # path; fail early with a clear message instead.
    if [[ ! -f "$db" ]]; then
        echo "claude-mem-sync: database not found: $db" >&2
        return 1
    fi
    # The path travels as argv and the quoted heredoc delimiter disables shell
    # expansion, so a path containing quotes can never alter the Python source.
    python3 - "$db" <<'PY'
import contextlib
import sqlite3
import sys

with contextlib.closing(sqlite3.connect(sys.argv[1])) as conn:
    cur = conn.cursor()
    cur.execute('SELECT COUNT(*) FROM observations')
    obs = cur.fetchone()[0]
    cur.execute('SELECT COUNT(*) FROM session_summaries')
    sums = cur.fetchone()[0]
    cur.execute('SELECT MAX(created_at) FROM observations')
    # MAX() yields NULL on an empty table; show a placeholder instead.
    last = cur.fetchone()[0] or 'empty'
# Only the first 19 characters are shown (presumably trims a timestamp to whole
# seconds; the stored format is not visible here).
print(f'{obs} obs, {sums} sums (last: {last[:19]})')
PY
}

# Dispatch on the requested sub-command.
#
# The remote halves run as inline Python over ssh. Each script is wrapped in
# escaped double quotes so that the shell on the remote host expands the "$HOME"
# carried in REMOTE_DB; OBS_COLS and SUM_COLS are expanded locally before the
# command is sent.
#
# JSON is streamed through ssh's stdin/stdout rather than staged under a fixed
# file name in the remote /tmp, so nothing is left behind on the remote host and
# concurrent runs cannot overwrite each other's data.
case "$COMMAND" in
    push)
        echo "=== Push: local → $REMOTE_HOST ==="
        export_obs "$LOCAL_DB" "$TMPDIR/export.json"
        # Run the import on the remote host; the export arrives on its stdin.
        ssh "$REMOTE_HOST" "python3 -c \"
import sqlite3, json, sys
conn = sqlite3.connect('$REMOTE_DB')
cur = conn.cursor()
cur.execute('SELECT created_at, title FROM observations')
existing = set((r[0],r[1]) for r in cur.fetchall())
cur.execute('SELECT created_at, request FROM session_summaries')
existing_s = set((r[0],r[1]) for r in cur.fetchall())
data = json.load(sys.stdin)
obs_cols = '$OBS_COLS'.split(',')
sum_cols = '$SUM_COLS'.split(',')
obs_ph = ','.join(['?'] * len(obs_cols))
sum_ph = ','.join(['?'] * len(sum_cols))
oi, si = 0, 0
for o in data['observations']:
    if (o['created_at'], o['title']) not in existing:
        cur.execute(f'INSERT INTO observations ($OBS_COLS) VALUES ({obs_ph})', tuple(o[k] for k in obs_cols))
        oi += 1
for s in data['summaries']:
    if (s['created_at'], s['request']) not in existing_s:
        cur.execute(f'INSERT INTO session_summaries ($SUM_COLS) VALUES ({sum_ph})', tuple(s[k] for k in sum_cols))
        si += 1
conn.commit()
print(f'Remote: {oi} new obs, {si} new sums imported', file=sys.stderr)
conn.close()
\"" < "$TMPDIR/export.json"
        ;;
    pull)
        echo "=== Pull: $REMOTE_HOST → local ==="
        # Run the export on the remote host. Its stdout carries the JSON, so the
        # row-count line goes to stderr (which ssh relays to the local stderr).
        ssh "$REMOTE_HOST" "python3 -c \"
import sqlite3, json, sys
conn = sqlite3.connect('$REMOTE_DB')
cur = conn.cursor()
cur.execute('SELECT $OBS_COLS FROM observations ORDER BY created_at')
cols = '$OBS_COLS'.split(',')
obs = [dict(zip(cols, r)) for r in cur.fetchall()]
cur.execute('SELECT $SUM_COLS FROM session_summaries ORDER BY created_at')
cols2 = '$SUM_COLS'.split(',')
sums = [dict(zip(cols2, r)) for r in cur.fetchall()]
json.dump({'observations': obs, 'summaries': sums}, sys.stdout)
print(f'{len(obs)} obs, {len(sums)} sums exported', file=sys.stderr)
conn.close()
\"" > "$TMPDIR/import.json"
        import_obs "$LOCAL_DB" "$TMPDIR/import.json"
        ;;
    sync)
        echo "=== Bidirectional sync with $REMOTE_HOST ==="
        # Re-invoke this script for each phase; under set -e a failing phase
        # aborts the remaining ones.
        "$0" push "$REMOTE_HOST"
        "$0" pull "$REMOTE_HOST"
        "$0" status "$REMOTE_HOST"
        ;;
    status)
        echo "=== Local ==="
        count_db "$LOCAL_DB"
        echo "=== Remote ($REMOTE_HOST) ==="
        # Same summary as count_db, computed on the remote host.
        ssh "$REMOTE_HOST" "python3 -c \"
import sqlite3
conn = sqlite3.connect('$REMOTE_DB')
cur = conn.cursor()
cur.execute('SELECT COUNT(*) FROM observations')
obs = cur.fetchone()[0]
cur.execute('SELECT COUNT(*) FROM session_summaries')
sums = cur.fetchone()[0]
cur.execute('SELECT MAX(created_at) FROM observations')
last = cur.fetchone()[0] or 'empty'
print(f'{obs} obs, {sums} sums (last: {last[:19]})')
conn.close()
\""
        ;;
    *)
        # Unknown sub-command: usage is a diagnostic, so it belongs on stderr.
        echo "Usage: claude-mem-sync <push|pull|sync|status> <remote-host>" >&2
        exit 1
        ;;
esac
