Commit 61750ac

zizhaof and claude committed
perf: reuse sentence embeddings for chunk vectors
On an Oracle Free Tier 4-vCPU ARM box with no GPU, bge-m3 encode is the single biggest cost of attachment upload. Measured in-container on a synthetic 7 KB, 150-sentence text: semantic chunking and chunk storage each needed their own embed_texts pass (~8.5s each), about 17s of wall time in total, which dominated the user-visible extract latency.

Fold the per-chunk embed into the sentence-level pass:

- New chunk_and_embed() does the same sentence-boundary detection as chunk_text_semantic, but returns the resulting (chunks, embeddings) pair. Per-chunk embeddings are the L2-renormalized sum of their constituent sentence vectors. This is equivalent to mean-pool + renormalize, because the 1/n factor cancels under renormalization; since bge-m3 sentence embeddings are already unit-norm, each sentence also contributes with equal weight.
- process_attachment uses chunk_and_embed instead of chunk_text_semantic + a second embed_texts(chunks) call.
- chunk_text_semantic stays as a thin text-only wrapper (kept for other call sites and test compat).
- Fallback path: when the single sentence-embed pass fails, fall back to _chunk_fixed and re-embed the fixed chunks once (same cost shape as the old double-pass, acceptable for the error path only).

Also set TOKENIZERS_PARALLELISM=true in the backend compose env so the HF tokenizer can run encode's tokenize step in parallel (tiny speedup, zero risk: we don't fork post-import).

Expected staging wall-time: extract ~17s → ~9s, dominated by the now-single sentence embed pass. Will verify on staging.

Tests

- Updated process_attachment tests to mock chunk_and_embed.
- Updated chunk_text_semantic fallback test (fallback path now re-embeds).
- Added TestChunkAndEmbed covering: aligned (chunks, embeddings) lengths, unit-norm of pooled chunk vectors, re-embed fallback on failure, empty-input returns ([], []).

250 pytest tests passed (246 existing + 4 new).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
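The pooling equivalence stated in the commit message can be checked numerically. What follows is a minimal sketch, not the code from this commit: `l2_normalize`, `pool_sum`, and `pool_mean` are illustrative names, and the three toy vectors are stand-ins for bge-m3 sentence embeddings (assumed unit-length, as the message says).

```python
import math


def l2_normalize(v):
    """Scale v to unit length; an all-zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def pool_sum(vectors):
    """Sum the vectors component-wise, then L2-normalize."""
    dim = len(vectors[0])
    acc = [0.0] * dim
    for v in vectors:
        for j in range(dim):
            acc[j] += v[j]
    return l2_normalize(acc)


def pool_mean(vectors):
    """Average the vectors component-wise, then L2-normalize."""
    n = len(vectors)
    dim = len(vectors[0])
    mean = [sum(v[j] for v in vectors) / n for j in range(dim)]
    return l2_normalize(mean)


# Toy unit-length "sentence embeddings" (hypothetical values).
sentence_vecs = [[1.0, 0.0, 0.0], [0.8, 0.6, 0.0], [0.0, 0.6, 0.8]]
by_sum = pool_sum(sentence_vecs)
by_mean = pool_mean(sentence_vecs)

# The two results agree up to floating-point rounding.
max_diff = max(abs(a - b) for a, b in zip(by_sum, by_mean))
print(max_diff < 1e-12)
```

The mean differs from the sum only by the scalar 1/n, and normalization removes any positive scalar, so the two poolings give the same direction for any input vectors.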
1 parent 4912ebc commit 61750ac

3 files changed

Lines changed: 200 additions & 60 deletions

backend/services/attachment_processor.py

Lines changed: 114 additions & 52 deletions
@@ -13,6 +13,7 @@
 import base64
 import io
 import logging
+import math
 import re
 from typing import Optional

@@ -217,67 +218,131 @@ def _split_sentences(text: str) -> list[str]:
     return sentences


+def _group_sentences_into_chunks(
+    sentences: list[str],
+    sent_embs: list[list[float]],
+) -> tuple[list[str], list[list[int]]]:
+    """
+    Walk sentences with precomputed embeddings, cutting on semantic jumps or size overflow.
+    Returns (chunk_texts, chunk_sentence_indices).
+    """
+    chunks: list[str] = []
+    chunk_idx: list[list[int]] = []
+    current_sents: list[str] = [sentences[0]]
+    current_idx: list[int] = [0]
+    current_len: int = len(sentences[0])
+
+    for i in range(1, len(sentences)):
+        sent = sentences[i]
+        sent_len = len(sent)
+
+        # Dot product equals cosine similarity because vectors are L2-normalized
+        sim: float = sum(a * b for a, b in zip(sent_embs[i - 1], sent_embs[i]))
+
+        # Break on semantic jump or size overflow (only if current chunk meets MIN_CHUNK_CHARS)
+        should_break = (
+            sim < SEMANTIC_THRESHOLD or current_len + sent_len > MAX_CHUNK_CHARS
+        ) and current_len >= MIN_CHUNK_CHARS
+
+        if should_break:
+            chunks.append("".join(current_sents))
+            chunk_idx.append(current_idx)
+            current_sents = [sent]
+            current_idx = [i]
+            current_len = sent_len
+        else:
+            current_sents.append(sent)
+            current_idx.append(i)
+            current_len += sent_len
+
+    if current_sents:
+        tail = "".join(current_sents)
+        # Merge a too-short tail chunk into the previous one to avoid tiny orphan chunks
+        if chunks and len(tail) < MIN_CHUNK_CHARS:
+            chunks[-1] += tail
+            chunk_idx[-1].extend(current_idx)
+        else:
+            chunks.append(tail)
+            chunk_idx.append(current_idx)
+
+    return chunks, chunk_idx
+
+
+def _pool_chunk_embedding(sent_embs: list[list[float]], idxs: list[int]) -> list[float]:
+    """
+    Derive a chunk embedding by L2-normalizing the sum of its sentence embeddings.
+    Equivalent to mean-pool + renormalize since bge-m3 sentence vectors are unit-length.
+    """
+    if not idxs:
+        raise ValueError("empty sentence index list")
+    dim = len(sent_embs[idxs[0]])
+    acc = [0.0] * dim
+    for i in idxs:
+        v = sent_embs[i]
+        for j in range(dim):
+            acc[j] += v[j]
+    norm = math.sqrt(sum(x * x for x in acc)) or 1.0
+    return [x / norm for x in acc]
+
+
 async def chunk_text_semantic(text: str) -> list[str]:
     """
-    Semantic chunking based on embedding cosine similarity.
+    Semantic chunking based on embedding cosine similarity; returns text chunks only.
+    Thin wrapper around chunk_and_embed for call sites that don't need the embeddings.
+    """
+    chunks, _ = await chunk_and_embed(text)
+    return chunks

-    Flow:
-      1. Split by sentence boundaries -> sentences
-      2. Batch-embed all sentences once (bge-m3 is normalized; dot product = cosine similarity)
-      3. When adjacent-sentence similarity < SEMANTIC_THRESHOLD, treat as a semantic break and cut
-      4. Merge sentences into chunks while keeping length under MAX_CHUNK_CHARS

-    Falls back to fixed-size chunking on any error.
+async def chunk_and_embed(text: str) -> tuple[list[str], list[list[float]]]:
+    """
+    Semantic chunking + per-chunk embeddings in a single embed pass.
+
+    Flow:
+      1. Split into sentences.
+      2. Batch-embed every sentence once (bge-m3 is L2-normalized; dot product = cosine sim).
+      3. Group adjacent sentences into chunks, breaking on semantic jumps or size overflow.
+      4. Derive each chunk's embedding by summing + L2-renormalizing the sentence vectors
+         that compose it. Saves a full second embed pass that process_attachment used to
+         do on the joined chunk text — embedding is ~half the extract wall time on CPU.
+
+    Returns (chunks, chunk_embeddings) with aligned lengths.
+    Falls back to fixed-size chunking + a re-embed on any error.
     """
     from services.embedding_service import embed_texts

     sentences = _split_sentences(text)
     if not sentences:
-        return []
+        return [], []
     if len(sentences) == 1:
-        return [sentences[0]] if len(sentences[0]) >= MIN_CHUNK_CHARS else []
+        sent = sentences[0]
+        if len(sent) < MIN_CHUNK_CHARS:
+            return [], []
+        embs = await embed_texts([sent])
+        return [sent], list(embs)

     try:
-        # Single batch embed call — avoids repeated executor dispatch overhead
-        embeddings = await embed_texts(sentences)
-
-        chunks: list[str] = []
-        current: list[str] = [sentences[0]]
-        current_len: int = len(sentences[0])
-
-        for i in range(1, len(sentences)):
-            sent = sentences[i]
-            sent_len = len(sent)
-
-            # Dot product equals cosine similarity because vectors are L2-normalized
-            sim: float = sum(a * b for a, b in zip(embeddings[i - 1], embeddings[i]))
-
-            # Break on semantic jump or size overflow (only if current chunk meets MIN_CHUNK_CHARS)
-            should_break = (
-                sim < SEMANTIC_THRESHOLD or current_len + sent_len > MAX_CHUNK_CHARS
-            ) and current_len >= MIN_CHUNK_CHARS
-
-            if should_break:
-                chunks.append("".join(current))
-                current = [sent]
-                current_len = sent_len
-            else:
-                current.append(sent)
-                current_len += sent_len
-
-        if current:
-            tail = "".join(current)
-            # Merge a too-short tail chunk into the previous one to avoid tiny orphan chunks
-            if chunks and len(tail) < MIN_CHUNK_CHARS:
-                chunks[-1] += tail
-            else:
-                chunks.append(tail)
-
-        return [c for c in chunks if c.strip()]
+        sent_embs = await embed_texts(sentences)
+
+        chunks, chunk_idx = _group_sentences_into_chunks(sentences, sent_embs)
+
+        # Drop empty-after-strip chunks and their embeddings in lockstep
+        kept: list[tuple[str, list[float]]] = []
+        for text_chunk, idxs in zip(chunks, chunk_idx):
+            if text_chunk.strip() and idxs:
+                kept.append((text_chunk, _pool_chunk_embedding(sent_embs, idxs)))
+        if not kept:
+            return [], []
+        kept_chunks, kept_embs = zip(*kept)
+        return list(kept_chunks), list(kept_embs)

     except Exception as e:
-        logger.warning("语义切分失败,fallback 到固定切分 / Semantic chunking failed, falling back: %s", e)
-        return _chunk_fixed(text)
+        logger.warning("Semantic chunk+embed failed, falling back to fixed chunking + re-embed: %s", e)
+        fallback_chunks = _chunk_fixed(text)
+        if not fallback_chunks:
+            return [], []
+        fallback_embs = await embed_texts(fallback_chunks)
+        return fallback_chunks, list(fallback_embs)


 def _chunk_fixed(text: str) -> list[str]:
@@ -372,14 +437,11 @@ async def process_attachment(
         logger.info("附件内联模式 / Attachment inline: %s (%d chars)", filename, len(text))
         return {"chunk_count": 0, "inline_text": text.strip()}

-    # Long text: semantic chunk embed → store
-    chunks = await chunk_text_semantic(text)
+    # Long text: semantic chunk + embed in a single pass → store
+    chunks, embeddings = await chunk_and_embed(text)
     if not chunks:
         return {"chunk_count": 0, "inline_text": None}

-    from services.embedding_service import embed_texts
-    embeddings = await embed_texts(chunks)
-
     await _store_chunks(session_id, filename, chunks, embeddings)
     logger.info("附件 RAG 模式 / Attachment RAG: %s (%d chunks, session=%s)",
                 filename, len(chunks), session_id)
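
As a rough illustration of the break rule used by _group_sentences_into_chunks, the sketch below runs a simplified version on four toy sentences. It is not the code from this commit: `group_indices` is a hypothetical name, the threshold values are made up for the toy input (the real SEMANTIC_THRESHOLD, MIN_CHUNK_CHARS, and MAX_CHUNK_CHARS are defined elsewhere in attachment_processor.py and do not appear in this diff), and the short-tail merge is left out for brevity.

```python
def group_indices(sentences, sent_embs, threshold, min_chars, max_chars):
    """Return one list of sentence indices per chunk (tail merge omitted)."""
    groups = []
    current = [0]
    current_len = len(sentences[0])
    for i in range(1, len(sentences)):
        # Dot product of adjacent unit vectors is their cosine similarity.
        sim = sum(a * b for a, b in zip(sent_embs[i - 1], sent_embs[i]))
        overflow = current_len + len(sentences[i]) > max_chars
        # Cut only once the running chunk has reached the minimum size.
        if (sim < threshold or overflow) and current_len >= min_chars:
            groups.append(current)
            current = [i]
            current_len = len(sentences[i])
        else:
            current.append(i)
            current_len += len(sentences[i])
    groups.append(current)
    return groups


# Hypothetical 2-d unit vectors: two "dog" sentences, then two "quantum" ones.
dog = [1.0, 0.0]
qc = [0.0, 1.0]
sentences = ["Dogs bark. ", "Dogs run. ", "Qubits spin. ", "Qubits entangle. "]
groups = group_indices(sentences, [dog, dog, qc, qc],
                       threshold=0.5, min_chars=10, max_chars=100)
print(groups)  # [[0, 1], [2, 3]]
```

Each index list is what a pooling step would consume to build one chunk vector, which is why the grouping helper in the diff returns indices alongside the chunk text.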

backend/tests/test_attachment_processor.py

Lines changed: 83 additions & 8 deletions
@@ -118,18 +118,92 @@ async def test_max_chunk_size_triggers_break(self):

     @pytest.mark.asyncio
     async def test_embed_failure_falls_back_to_fixed(self):
-        """Falls back to fixed chunking on embed failure."""
-        from services.attachment_processor import chunk_text_semantic, CHUNK_SIZE
+        """Falls back to fixed chunking on semantic-pass embed failure."""
+        from services.attachment_processor import chunk_text_semantic

-        # Text needs sentence punctuation so _split_sentences yields >1 sentence and embed_texts is called
         single_sent = "这是一个用于测试的句子,内容并不重要只是用来触发分块逻辑。"
         text = single_sent * 20  # Repeated 20 times; each ends with a period, totaling 20 sentences
-        with patch("services.embedding_service.embed_texts", new=AsyncMock(side_effect=Exception("model error"))):
+
+        # First call (sentence embed for semantic-boundary detection) blows up;
+        # fallback path does a second embed on the fixed chunks, which succeeds.
+        calls: list[int] = []
+        async def fake_embed(texts):
+            calls.append(len(texts))
+            if len(calls) == 1:
+                raise Exception("model error")
+            return [[0.1] * 1024 for _ in texts]
+
+        with patch("services.embedding_service.embed_texts", new=fake_embed):
             result = await chunk_text_semantic(text)

         assert len(result) > 1  # Fallback fixed-split still produces multiple chunks


+# ── chunk_and_embed ───────────────────────────────────────────────────
+
+class TestChunkAndEmbed:
+    @pytest.mark.asyncio
+    async def test_returns_aligned_lists(self):
+        """chunks and embeddings are same length and per-chunk."""
+        from services.attachment_processor import chunk_and_embed
+        text = (
+            "First sentence about dogs. Second about dogs too. "
+            "Totally unrelated topic: quantum computing basics. "
+            "More on quantum. Yet more on quantum theory."
+        ) * 10
+        # Two clusters of sentences — embeddings differ per cluster to trigger a break
+        dog_vec = [1.0, 0.0] + [0.0] * 1022
+        qc_vec = [0.0, 1.0] + [0.0] * 1022
+        sent_embs = [dog_vec, dog_vec, qc_vec, qc_vec, qc_vec] * 10
+        with patch("services.embedding_service.embed_texts",
+                   new=AsyncMock(return_value=sent_embs)):
+            chunks, embs = await chunk_and_embed(text)
+        assert len(chunks) == len(embs) > 0
+
+    @pytest.mark.asyncio
+    async def test_chunk_embeddings_are_unit_norm(self):
+        """Pooled chunk embeddings are L2-normalized (unit length)."""
+        from services.attachment_processor import chunk_and_embed
+        # Two similar sentences → one chunk covering both. Vectors unit-norm but not identical.
+        text = "Alpha statement. Beta statement."
+        v1 = [1.0, 0.0] + [0.0] * 1022
+        v2 = [0.8, 0.6] + [0.0] * 1022  # unit length
+        with patch("services.embedding_service.embed_texts",
+                   new=AsyncMock(return_value=[v1, v2])):
+            chunks, embs = await chunk_and_embed(text)
+        assert len(embs) == 1
+        norm_sq = sum(x * x for x in embs[0])
+        assert abs(norm_sq - 1.0) < 1e-6
+
+    @pytest.mark.asyncio
+    async def test_fallback_reembeds_fixed_chunks(self):
+        """On semantic-embed failure, falls back to fixed chunks + a re-embed call."""
+        from services.attachment_processor import chunk_and_embed
+        text = "Sentence one. Sentence two. Sentence three." * 200
+
+        call_log: list[str] = []
+
+        async def fake_embed(texts):
+            # First call (sentence embed during semantic path) fails;
+            # second call (fallback chunk embed) succeeds and returns one vec per chunk.
+            call_log.append("call")
+            if len(call_log) == 1:
+                raise RuntimeError("embed model exploded")
+            return [[0.1] * 1024 for _ in texts]
+
+        with patch("services.embedding_service.embed_texts", new=fake_embed):
+            chunks, embs = await chunk_and_embed(text)
+        assert len(chunks) == len(embs) > 0
+        assert len(call_log) == 2  # Sentence pass failed, fallback pass succeeded
+
+    @pytest.mark.asyncio
+    async def test_empty_text_returns_empty(self):
+        """Empty / whitespace-only text returns ([], [])."""
+        from services.attachment_processor import chunk_and_embed
+        assert await chunk_and_embed("") == ([], [])
+        assert await chunk_and_embed("   \n  ") == ([], [])
+
+
 # ── _chunk_fixed ──────────────────────────────────────────────────────

 class TestChunkFixed:
@@ -355,9 +429,9 @@ async def test_long_text_goes_to_rag(self):
         fake_embeddings = [[0.1] * 1024] * 3

         with patch("services.attachment_processor.extract_text", new=AsyncMock(return_value=long_text)), \
-             patch("services.attachment_processor.chunk_text_semantic", new=AsyncMock(return_value=fake_chunks)), \
-             patch("services.attachment_processor._store_chunks", new=AsyncMock()), \
-             patch("services.embedding_service.embed_texts", new=AsyncMock(return_value=fake_embeddings)):
+             patch("services.attachment_processor.chunk_and_embed",
+                   new=AsyncMock(return_value=(fake_chunks, fake_embeddings))), \
+             patch("services.attachment_processor._store_chunks", new=AsyncMock()):
             result = await process_attachment("session-1", "doc.pdf", b"pdf bytes")

         assert result["chunk_count"] == len(fake_chunks)
@@ -371,7 +445,8 @@ async def test_no_chunks_returns_failure(self):
         long_text = "X" * (INLINE_THRESHOLD + 100)

         with patch("services.attachment_processor.extract_text", new=AsyncMock(return_value=long_text)), \
-             patch("services.attachment_processor.chunk_text_semantic", new=AsyncMock(return_value=[])):
+             patch("services.attachment_processor.chunk_and_embed",
+                   new=AsyncMock(return_value=([], []))):
             result = await process_attachment("session-1", "file.txt", b"content")

         assert result == {"chunk_count": 0, "inline_text": None}
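
The fail-once-then-succeed stub that the fallback tests rely on can be tried in isolation. The sketch below is an assumption-laden miniature, not the project's code: `make_flaky_embed` and `chunk_with_fallback` are hypothetical names, the 4-dimensional vectors replace the 1024-dimensional ones in the tests, and the "semantic" and "fixed" chunk lists are passed in directly instead of being computed.

```python
import asyncio


def make_flaky_embed(dim=4):
    """Build an async embed stub that raises on its first call only."""
    calls = []

    async def fake_embed(texts):
        calls.append(len(texts))
        if len(calls) == 1:
            raise RuntimeError("embed model exploded")
        return [[0.1] * dim for _ in texts]

    return fake_embed, calls


async def chunk_with_fallback(sentences, fixed_chunks, embed):
    """Try the sentence-level pass; on any error, embed the fixed chunks instead."""
    try:
        return sentences, await embed(sentences)
    except Exception:
        # Error path: one extra embed call over the fixed-size chunks.
        return fixed_chunks, await embed(fixed_chunks)


fake_embed, calls = make_flaky_embed()
chunks, embs = asyncio.run(
    chunk_with_fallback(["a.", "b.", "c."], ["a.b.", "c."], fake_embed)
)
print(calls)  # [3, 2]
```

Recording the batch size of each call, as the updated chunk_text_semantic test does, makes it possible to assert both that the fallback ran and that it embedded the fixed chunks, not the sentences.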

docker-compose.yml

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@ services:
     env_file: ./backend/.env
     environment:
       - SEARXNG_URL=${SEARXNG_INTERNAL}
+      # Let HuggingFace tokenizers fork for parallel encode (small speedup during
+      # bge-m3 encode's tokenize step; safe since we don't fork post-import).
+      - TOKENIZERS_PARALLELISM=true
     volumes:
       - ./logs:/app/logs  # 日志持久化到宿主机 / persist logs to host
       - hf-cache:/root/.cache/huggingface  # 共享 HuggingFace 模型缓存(prod/staging 复用同一份 bge-m3) / shared HuggingFace model cache (prod/staging reuse the same bge-m3 copy)
