-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_embeddings.py
More file actions
74 lines (54 loc) · 2.07 KB
/
Copy pathcreate_embeddings.py
File metadata and controls
74 lines (54 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import List
import chromadb # type: ignore
from chromadb.utils import embedding_functions # type: ignore
from chromadb.api.models import Collection # type: ignore
from dotenv import load_dotenv
import argparse
from get_data import get_mentor_sentences
# Load environment variables via python-dotenv (presumably from a local .env
# file) before any embedding or client objects are constructed below.
load_dotenv()
# Directory name intended for an on-disk Chroma database.
# NOTE(review): no active code in the visible part of this file reads this
# constant -- confirm whether persistence is still wanted.
PERSIST_DIRECTORY = '.db'
# Embedding function built once at import time and passed to both collection
# creation and collection lookup, so both paths use the same
# Sentence-Transformers model ("all-mpnet-base-v2").
EMBD_FNC = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name="all-mpnet-base-v2")
def get_chroma_client() -> chromadb.Client:
    """Build and return a Chroma client using the library's default settings.

    NOTE(review): an earlier persistent configuration (``duckdb+parquet``
    stored under ``PERSIST_DIRECTORY``) is disabled, so this function does not
    read ``PERSIST_DIRECTORY``. Presumably the default client keeps its data
    in memory only -- confirm against the installed chromadb version.

    Returns:
        The object produced by calling ``chromadb.Client()`` with no arguments.
    """
    return chromadb.Client()
def create_mentor_embeddings(
        mentor_sentences: List[dict],
        collection_name: str = 'mentors') -> Collection:
    """Create a Chroma collection and add every mentor record to it.

    Args:
        mentor_sentences: Records to index. Each record is read through the
            keys ``'id'``, ``'sentence'`` and ``'metadata'``.
        collection_name: Name given to the newly created collection.

    Returns:
        The collection that was created, after the records have been added.

    Raises:
        KeyError: If a record lacks one of the three keys listed above.
    """
    # The collection is created before the records are unpacked, so a
    # malformed record surfaces only after the collection already exists.
    # "hnsw:space": "cosine" is collection metadata; per Chroma's conventions
    # it presumably selects cosine distance for the index -- verify in docs.
    mentor_collection = get_chroma_client().create_collection(
        name=collection_name,
        embedding_function=EMBD_FNC,
        metadata={"hnsw:space": "cosine"},
    )

    # Three parallel lists, one entry per record, in input order.
    record_ids = [record['id'] for record in mentor_sentences]
    record_texts = [record['sentence'] for record in mentor_sentences]
    record_metadata = [record['metadata'] for record in mentor_sentences]

    mentor_collection.add(
        ids=record_ids,
        documents=record_texts,
        metadatas=record_metadata,
    )
    return mentor_collection
def get_mentor_embeddings(collection_name: str = 'mentors') -> Collection:
    """Fetch an existing collection by name from a freshly built client.

    Args:
        collection_name: Name of the collection to look up.

    Returns:
        The collection returned by the client's ``get_collection`` call, bound
        to the module-level ``EMBD_FNC`` embedding function.
    """
    return get_chroma_client().get_collection(
        name=collection_name,
        embedding_function=EMBD_FNC,
    )
if __name__ == '__main__':
    # Script entry point: read the target collection name from the command
    # line, load the mentor sentences, and index them into that collection.
    cli = argparse.ArgumentParser(description='Create mentor embeddings')
    cli.add_argument(
        '--collection-name',
        default='mentors',
        help='Name of the collection',
    )
    args = cli.parse_args()

    mentor_sentences = get_mentor_sentences()

    # Report what is about to be indexed before the embedding work starts.
    status_line = 'Creating collection {} for {} sentences'.format(
        args.collection_name, len(mentor_sentences))
    print(status_line)

    collection = create_mentor_embeddings(
        mentor_sentences, collection_name=args.collection_name)