-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_chain.py
More file actions
165 lines (126 loc) · 4.99 KB
/
Copy pathllm_chain.py
File metadata and controls
165 lines (126 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
LLM chain and RAG pipeline.
Handles integration with local language models and RAG query processing.
"""
import logging
from typing import Optional
from langchain_community.llms import Ollama as OllamaLLM
from langchain_community.llms import LlamaCpp
from langchain_core.prompts import PromptTemplate
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from config import LLM_MODEL_NAME, LLM_MODEL_PATH, TEMPERATURE, MAX_TOKENS
# Module-level logger, named after this module via __name__; used by the
# classes below for progress and failure messages.
logger = logging.getLogger(__name__)
class LLMManager:
    """Create and hand out the language model instance.

    Two backends are supported: a model served by Ollama (the default) and a
    model file loaded through ``LlamaCpp``. The model is created lazily on the
    first :meth:`get_llm` call unless :meth:`load_model` is called earlier.
    """

    # Context window (in tokens) used when the caller does not pass one.
    DEFAULT_CONTEXT_WINDOW = 2048

    def __init__(
        self,
        model_name: str = LLM_MODEL_NAME,
        use_local_path: bool = False,
        *,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
    ):
        """
        Initialize the LLM manager without loading any model yet.

        Args:
            model_name: Name of the model to request from Ollama. The
                local-file backend does not read it; that backend takes its
                path from ``load_local_model``'s argument or ``LLM_MODEL_PATH``.
            use_local_path: If True, ``load_model`` loads a local model file
                instead of going through Ollama.
            context_window: Context window size passed to whichever backend is
                loaded (``num_ctx`` for Ollama, ``n_ctx`` for LlamaCpp).

        Raises:
            ValueError: If ``context_window`` is smaller than 1.
        """
        if context_window < 1:
            raise ValueError(f"context_window must be >= 1, got {context_window}")
        self.model_name = model_name
        self.llm = None  # populated by one of the load_* methods
        self.use_local_path = use_local_path
        self.context_window = context_window

    def load_ollama_model(self) -> None:
        """Load model using Ollama (requires Ollama server running locally)."""
        logger.info("Loading Ollama model: %s", self.model_name)
        self.llm = OllamaLLM(
            model=self.model_name,
            temperature=TEMPERATURE,
            num_ctx=self.context_window,
            # Apply the same generation cap the local backend gets via
            # max_tokens, so both backends honour MAX_TOKENS.
            num_predict=MAX_TOKENS,
        )
        logger.info("Ollama model loaded: %s", self.model_name)

    def load_local_model(self, model_path: Optional[str] = None) -> None:
        """
        Load a local model file through LlamaCpp.

        Args:
            model_path: Path to the model file (.gguf). When omitted,
                ``LLM_MODEL_PATH`` from the config is used.
        """
        if model_path is None:
            model_path = str(LLM_MODEL_PATH)
        logger.info("Loading local model from: %s", model_path)
        # Stream generated tokens to stdout as they are produced.
        callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
        self.llm = LlamaCpp(
            model_path=model_path,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            n_ctx=self.context_window,
            callback_manager=callback_manager,
            verbose=False,
        )
        logger.info("Local model loaded from: %s", model_path)

    def load_model(self) -> None:
        """Load the backend selected by ``use_local_path``."""
        if self.use_local_path:
            self.load_local_model()
        else:
            self.load_ollama_model()

    def get_llm(self):
        """Return the LLM instance, loading it first if that has not happened yet."""
        if self.llm is None:
            self.load_model()
        return self.llm
class RAGChain:
    """Manages the RAG (Retrieval-Augmented Generation) chain.

    A query retrieves documents through ``retriever``, formats them into a
    prompt, and sends that prompt to the LLM obtained from ``llm_manager``.
    """

    # Prompt skeleton; ``{context}`` and ``{question}`` are filled in per query.
    _PROMPT_TEMPLATE = (
        "You are a helpful assistant. Use the following context to answer the question.\n"
        "If you don't know the answer based on the context, "
        "say \"I don't know based on the provided context.\"\n"
        "Context:\n"
        "{context}\n"
        "Question: {question}\n"
        "Answer:"
    )

    def __init__(self, llm_manager: LLMManager, retriever):
        """
        Initialize the RAG chain.

        Args:
            llm_manager: LLMManager instance; its ``get_llm()`` is called on
                each query.
            retriever: Vector store retriever; must expose ``invoke(question)``
                returning a sliceable sequence of documents with a
                ``page_content`` attribute.
        """
        self.llm_manager = llm_manager
        self.retriever = retriever
        self.qa_chain = None  # placeholder; nothing in this class assigns it
        self._initialize_chain()

    def _initialize_chain(self) -> None:
        """Log chain start-up; no chain object is actually constructed here."""
        logger.info("Initializing RAG chain...")
        # The prompt is assembled on the fly in query_with_context.
        logger.info("RAG chain initialized successfully")

    def query(self, question: str, num_docs: int = 3) -> dict:
        """
        Process a query through the RAG pipeline.

        Args:
            question: User's question
            num_docs: Maximum number of retrieved documents to use as context

        Returns:
            The dictionary produced by ``query_with_context``.
        """
        return self.query_with_context(question, num_docs=num_docs)

    def query_with_context(self, question: str, num_docs: int = 3) -> dict:
        """
        Process a query with explicit document retrieval.

        Args:
            question: User's question
            num_docs: Maximum number of retrieved documents to keep. ``None``
                keeps everything the retriever returns.

        Returns:
            Dictionary with three keys: ``"answer"`` (the LLM output, or a
            fallback string containing the raw context if the LLM call
            failed), ``"context_documents"`` (the retrieved documents that
            were used) and ``"context"`` (their formatted text).

        Raises:
            ValueError: If ``num_docs`` is negative.
        """
        # A negative value would slice from the end and silently drop the
        # last documents instead of limiting the count.
        if num_docs is not None and num_docs < 0:
            raise ValueError(f"num_docs must be non-negative, got {num_docs}")

        logger.info(
            "Processing query with %s context documents: %s", num_docs, question
        )

        # Retrieve relevant documents and cap them at num_docs.
        relevant_docs = self.retriever.invoke(question)[:num_docs]

        # Number the documents from 1 so the prompt reads naturally.
        context = "\n\n".join(
            f"Document {position}:\n{doc.page_content}"
            for position, doc in enumerate(relevant_docs, start=1)
        )
        prompt = self._PROMPT_TEMPLATE.format(context=context, question=question)

        try:
            answer = self.llm_manager.get_llm().invoke(prompt)
        except Exception as exc:
            # Deliberate best-effort: if the model cannot be loaded or called,
            # return the retrieved context instead of failing the query.
            logger.warning("LLM error (Ollama not running?): %s", exc)
            answer = (
                f"[Retrieved {len(relevant_docs)} relevant documents - LLM unavailable]\n\n"
                + context
            )

        return {
            "answer": answer,
            "context_documents": relevant_docs,
            "context": context,
        }