|
| 1 | +"""Hebbian connection matrix with spreading activation. |
| 2 | +
|
| 3 | +Edges are undirected: edge (a, b) is stored canonically with the lower id |
| 4 | +first to avoid duplicate rows. Weight accumulates over repeated strengthen() |
| 5 | +calls; decay_all() reduces all weights (floored at 0); garbage_collect() |
| 6 | +removes weak edges to keep the graph compact. |
| 7 | +
|
| 8 | +Spreading activation is a bounded BFS that propagates seed activation |
| 9 | +through the graph, attenuating by (weight * decay_per_hop) at each hop. |
| 10 | +Multi-path arrivals take the max (not sum) — prevents an activation |
| 11 | +runaway on densely connected graphs. |
| 12 | +
|
| 13 | +Design per spec Section 4.1 (brain/memory/hebbian.py) and OG's F32/F33 |
| 14 | +Hebbian work. |
| 15 | +""" |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import sqlite3 |
| 20 | +from collections.abc import Iterable |
| 21 | +from pathlib import Path |
| 22 | + |
| 23 | + |
class HebbianMatrix:
    """SQLite-backed sparse weighted graph between memory ids.

    Each undirected edge is one row of the ``hebbian_edges`` table, keyed by
    the pair ``(memory_a, memory_b)``. Writers canonicalise the pair before
    touching the table so that (a, b) and (b, a) share a row.

    Instances own a single ``sqlite3`` connection. They can be used as a
    context manager, which closes the connection on exit.
    """

    _SCHEMA = """
    CREATE TABLE IF NOT EXISTS hebbian_edges (
        memory_a TEXT NOT NULL,
        memory_b TEXT NOT NULL,
        weight REAL NOT NULL DEFAULT 0.0,
        last_strengthened_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
        PRIMARY KEY (memory_a, memory_b)
    );

    CREATE INDEX IF NOT EXISTS idx_hebbian_a ON hebbian_edges(memory_a);
    CREATE INDEX IF NOT EXISTS idx_hebbian_b ON hebbian_edges(memory_b);
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open (or create) the database at `db_path` and ensure the schema."""
        self._conn = sqlite3.connect(str(db_path))
        try:
            self._conn.executescript(self._SCHEMA)
            self._conn.commit()
        except sqlite3.Error:
            # Do not leak the handle when the file is not a usable database.
            self._conn.close()
            raise

    def __enter__(self) -> HebbianMatrix:
        """Return self so the matrix can be used in a ``with`` statement."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: object,
    ) -> None:
        """Close the connection; exceptions from the body are not suppressed."""
        self.close()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()

    def strengthen(self, a: str, b: str, delta: float = 0.1) -> None:
        """Add `delta` to the weight of edge (a, b). Creates the edge if new.

        Self-edges (``a == b``) are not tracked and are silently ignored.

        `delta` must be strictly positive — the module contract is that
        weights are non-negative. Callers that want to weaken an edge use
        `decay_all` or `garbage_collect`. A zero or negative delta raises
        ValueError.
        """
        if a == b:
            return  # self-edges not tracked
        if delta <= 0.0:
            raise ValueError(f"delta must be positive, got {delta!r}")
        lo, hi = _canonical(a, b)
        # `with connection` commits on success and rolls back on error, so a
        # failed statement never leaves a transaction open.
        with self._conn:
            # Upsert: a new edge starts at `delta`; an existing one
            # accumulates it and has its timestamp refreshed.
            self._conn.execute(
                """
                INSERT INTO hebbian_edges (memory_a, memory_b, weight)
                VALUES (?, ?, ?)
                ON CONFLICT(memory_a, memory_b)
                DO UPDATE SET weight = weight + excluded.weight,
                              last_strengthened_at = CURRENT_TIMESTAMP
                """,
                (lo, hi, delta),
            )

    def weight(self, a: str, b: str) -> float:
        """Return the weight of edge (a, b). Zero if no edge or if a == b."""
        if a == b:
            return 0.0
        lo, hi = _canonical(a, b)
        row = self._conn.execute(
            "SELECT weight FROM hebbian_edges WHERE memory_a = ? AND memory_b = ?",
            (lo, hi),
        ).fetchone()
        return float(row[0]) if row else 0.0

    def neighbors(self, memory_id: str) -> list[tuple[str, float]]:
        """Return [(other_id, weight), ...] for every edge touching memory_id.

        An edge may store `memory_id` in either column, so both columns are
        searched and the opposite endpoint is reported.
        """
        rows = self._conn.execute(
            """
            SELECT memory_b, weight FROM hebbian_edges WHERE memory_a = ?
            UNION ALL
            SELECT memory_a, weight FROM hebbian_edges WHERE memory_b = ?
            """,
            (memory_id, memory_id),
        ).fetchall()
        return [(other, float(weight)) for other, weight in rows]

    def decay_all(self, rate: float) -> None:
        """Subtract `rate` from every weight, floored at 0.

        `rate` must be non-negative. A negative rate would inflate every
        weight in a single scheduled batch — silent corruption for
        dream/heartbeat cycles. ValueError guards the sign.
        """
        if rate < 0.0:
            raise ValueError(f"decay rate must be non-negative, got {rate!r}")
        with self._conn:
            # Two-argument MAX is SQLite's scalar max, used here as the floor.
            self._conn.execute(
                "UPDATE hebbian_edges SET weight = MAX(weight - ?, 0.0)", (rate,)
            )

    def garbage_collect(self, threshold: float = 0.01) -> int:
        """Remove edges with weight < threshold. Returns the count removed."""
        with self._conn:
            cursor = self._conn.execute(
                "DELETE FROM hebbian_edges WHERE weight < ?", (threshold,)
            )
        return cursor.rowcount

    def spreading_activation(
        self,
        seed_ids: Iterable[str],
        depth: int = 2,
        decay_per_hop: float = 0.5,
    ) -> dict[str, float]:
        """BFS spreading activation from seed_ids, returning activation by id.

        Seed nodes start with activation 1.0. Each hop multiplies the source
        activation by (edge_weight * decay_per_hop) to produce the
        neighbour's activation. Multi-path arrivals take the max (not sum),
        so propagation can only raise a node's activation, never lower it.

        Every hop propagates from the activation levels the frontier held
        when the hop started. Updates made during a hop are only spread by
        the next hop, so the result does not depend on set iteration order
        and no value travels further than `depth` edges from a seed.

        NOTE: edge weights are unbounded (`strengthen` accumulates), so
        edge_weight * decay_per_hop can exceed 1 and amplify rather than
        attenuate; `depth` is what bounds the traversal.

        Returns a dict {memory_id: activation}. A `depth` of zero or less
        returns just the seeds.
        """
        activation: dict[str, float] = dict.fromkeys(seed_ids, 1.0)

        frontier = set(activation)
        for _ in range(depth):
            if not frontier:
                break
            # Freeze the levels this hop propagates from; `activation` is
            # mutated below and must not feed back into the same hop.
            sources = {node: activation[node] for node in frontier}
            next_frontier: set[str] = set()
            for node, level in sources.items():
                for neighbour, weight in self.neighbors(node):
                    propagated = level * weight * decay_per_hop
                    if propagated > activation.get(neighbour, 0.0):
                        activation[neighbour] = propagated
                        next_frontier.add(neighbour)
            frontier = next_frontier
        return activation
| 146 | + |
| 147 | + |
| 148 | +def _canonical(a: str, b: str) -> tuple[str, str]: |
| 149 | + """Sort the pair so edge (a, b) and (b, a) hash to the same row.""" |
| 150 | + return (a, b) if a <= b else (b, a) |
0 commit comments