-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathforce_index.py
More file actions
101 lines (82 loc) · 3.58 KB
/
Copy pathforce_index.py
File metadata and controls
101 lines (82 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import hashlib
import json
import requests
import lancedb
import pandas as pd
import time
from datetime import datetime
import google.generativeai as genai
from dotenv import load_dotenv
from concurrent.futures import ThreadPoolExecutor
# Load variables from a local .env file (if present) into os.environ.
load_dotenv()

# Config
# Secrets are read from the environment (or the .env file loaded above).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TANA_TOKEN = os.getenv("TANA_TOKEN")
# Endpoint, workspace and storage location can be overridden through the
# environment so the script runs on other machines/workspaces; the defaults
# are the values this script was originally written against.
TANA_URL = os.getenv("TANA_URL", "http://127.0.0.1:8262/mcp")
WORKSPACE_ID = os.getenv("TANA_WORKSPACE_ID", "--D3QJHnLgSk")
LANCE_DB_PATH = os.getenv(
    "LANCE_DB_PATH",
    "/Users/krshirkoohi/Documents/AI Workspace/projects/MCP Servers/tana-embeddings/vector_store",
)
TABLE_NAME = os.getenv("LANCE_TABLE_NAME", "tana_nodes")
genai.configure(api_key=GOOGLE_API_KEY)
EMBED_MODEL = "models/gemini-embedding-001"
def get_node_hash(n):
    """Fingerprint a node's text content.

    The node's ``name`` and ``description`` (a missing key counts as an empty
    string) are joined with ``|`` and the SHA-256 hex digest of the encoded
    result is returned.
    """
    fields = (n.get('name', ''), n.get('description', ''))
    fingerprint = "|".join(str(value) for value in fields)
    return hashlib.sha256(fingerprint.encode()).hexdigest()
def call_tana(method, params, timeout=60):
    """Send one JSON-RPC 2.0 request to the Tana MCP endpoint at ``TANA_URL``.

    Args:
        method: JSON-RPC method name (callers in this file pass ``"tools/call"``).
        params: JSON-RPC params, e.g. ``{"name": <tool>, "arguments": {...}}``.
        timeout: Seconds to wait for the HTTP response before ``requests``
            raises ``requests.Timeout``, so a hung server cannot stall the run.

    Returns:
        The decoded JSON body of the response.
    """
    headers = {
        "Authorization": f"Bearer {TANA_TOKEN}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    # Use the caller-supplied method instead of a hard-coded one.
    payload = {"jsonrpc": "2.0", "method": method, "params": params, "id": 1}
    response = requests.post(TANA_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
def get_all_nodes(limit=100, max_workers=5):
    """Fetch every node that carries at least one workspace tag, keyed by node id.

    Lists the tags of ``WORKSPACE_ID`` through the Tana MCP ``list_tags`` tool,
    then runs ``search_nodes`` (``hasType`` = tag id) for each tag concurrently
    and merges the results. A node carrying several tags is stored once.

    Args:
        limit: Maximum number of nodes requested per tag. A tag that returns
            ``limit`` or more nodes may have been truncated; a warning is printed.
        max_workers: Number of concurrent ``search_nodes`` requests.

    Returns:
        dict mapping node id -> node dict as returned by ``search_nodes``.
    """
    print("Fetching all tags...")
    res = call_tana("tools/call", {"name": "list_tags", "arguments": {"workspaceId": WORKSPACE_ID}})
    tags = json.loads(res['result']['content'][0]['text'])
    print(f"Scanning {len(tags)} tags for nodes...")

    def fetch_tag(tag):
        # Best-effort: a failing tag is reported and skipped rather than
        # aborting the whole scan (but no longer swallowed silently).
        try:
            r = call_tana(
                "tools/call",
                {"name": "search_nodes", "arguments": {"query": {"hasType": tag['id']}, "limit": limit}},
            )
            nodes = json.loads(r['result']['content'][0]['text'])
        except Exception as e:
            print(f"Warning: search_nodes failed for tag {tag!r}: {e}")
            return []
        if len(nodes) >= limit:
            print(f"Warning: tag {tag!r} returned {len(nodes)} nodes (limit {limit}); results may be truncated.")
        return nodes

    all_nodes = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for nodes in executor.map(fetch_tag, tags):
            for n in nodes:
                all_nodes[n['id']] = n
    return all_nodes
def _embed_batch(texts, max_retries, retry_wait):
    """Return one Gemini embedding vector per string in ``texts``.

    Sleeps ``retry_wait`` seconds and retries whenever the error message
    contains "429" (rate limit), at most ``max_retries`` times. Any other
    error, or a 429 after the retries are used up, is re-raised.
    """
    attempt = 0
    while True:
        try:
            res = genai.embed_content(model=EMBED_MODEL, content=texts, task_type="retrieval_document")
        except Exception as e:
            if "429" not in str(e) or attempt >= max_retries:
                raise
            attempt += 1
            print(f"Hit burst limit. Waiting {retry_wait}s... (retry {attempt}/{max_retries})")
            time.sleep(retry_wait)
            continue
        # google-generativeai documents the result key as "embedding" (a list of
        # vectors when `content` is a list); "embeddings" is accepted as a fallback.
        return res["embedding"] if "embedding" in res else res["embeddings"]


def _build_rows(batch, vectors):
    """Pair each node in ``batch`` with its vector as a row for the LanceDB table."""
    if len(vectors) != len(batch):
        # zip() would silently drop the unmatched tail; fail the batch instead.
        raise ValueError(f"Expected {len(batch)} embeddings, got {len(vectors)}")
    return [
        {
            "vector": vector,
            "id": node['id'],
            "name": node.get('name', ''),
            "description": node.get('description', ''),
            "hash": get_node_hash(node),
            "last_updated": datetime.now().isoformat(),
        }
        for node, vector in zip(batch, vectors)
    ]


def main(batch_size=100, max_retries=10, retry_wait=30):
    """Embed every Tana node that is not yet in the LanceDB table and append it.

    Nodes are embedded ``batch_size`` at a time. A batch that fails is reported
    and skipped; its nodes stay absent from the table, so a later run picks
    them up again. The final message only claims success when no batch failed.

    Args:
        batch_size: Number of nodes sent per embedding request.
        max_retries: Rate-limit (429) retries allowed per batch.
        retry_wait: Seconds to sleep before each rate-limit retry.
    """
    db = lancedb.connect(LANCE_DB_PATH)
    table = db.open_table(TABLE_NAME)
    print("Pre-loading existing index...")
    existing_ids = set(table.to_pandas()['id'].tolist())
    nodes_dict = get_all_nodes()
    to_index = [n for nid, n in nodes_dict.items() if nid not in existing_ids]
    print(f"Found {len(to_index)} nodes that need indexing.")
    if not to_index:
        print("Everything already indexed. Done.")
        return

    print(f"Starting High-Speed Embedding in batches of {batch_size}...")
    failed = 0
    for start in range(0, len(to_index), batch_size):
        batch = to_index[start:start + batch_size]
        texts = [f"{n.get('name','')}\n{n.get('description','')}" for n in batch]
        try:
            vectors = _embed_batch(texts, max_retries, retry_wait)
            table.add(_build_rows(batch, vectors))
        except Exception as e:
            failed += len(batch)
            print(f"Error: {e}")
            continue
        print(f"Indexed {start + len(batch)}/{len(to_index)}...")

    # count_rows() avoids materialising every stored vector just to count them.
    total = table.count_rows()
    if failed:
        print(f"INCOMPLETE. {failed}/{len(to_index)} nodes failed to index; re-run to retry them. Total nodes: {total}")
    else:
        print(f"SUCCESS. Entire workspace indexed. Total nodes: {total}")
if __name__ == "__main__":
    # Script entry point: run one incremental indexing pass.
    main()