Initial commit: import from sinmb79/Gov-chat-bot
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,91 @@
|
||||
from typing import Optional
|
||||
|
||||
from app.providers.base import SearchResult
|
||||
from app.providers.vectordb import VectorDBProvider
|
||||
|
||||
|
||||
class ChromaVectorDBProvider(VectorDBProvider):
|
||||
"""
|
||||
ChromaDB 기반 벡터 검색.
|
||||
컬렉션명 = tenant_{tenant_id} (테넌트 격리)
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "chromadb", port: int = 8000):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._client = None
|
||||
|
||||
def _get_client(self):
|
||||
if self._client is None:
|
||||
import chromadb
|
||||
self._client = chromadb.HttpClient(host=self.host, port=self.port)
|
||||
return self._client
|
||||
|
||||
def _collection_name(self, tenant_id: str) -> str:
|
||||
return f"tenant_{tenant_id}"
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
tenant_id: str,
|
||||
doc_id: str,
|
||||
chunks: list[str],
|
||||
embeddings: list[list[float]],
|
||||
metadatas: list[dict],
|
||||
) -> int:
|
||||
client = self._get_client()
|
||||
collection = client.get_or_create_collection(self._collection_name(tenant_id))
|
||||
ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
|
||||
collection.upsert(ids=ids, documents=chunks, embeddings=embeddings, metadatas=metadatas)
|
||||
return len(chunks)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query_vec: list[float],
|
||||
top_k: int = 3,
|
||||
threshold: float = 0.70,
|
||||
) -> list[SearchResult]:
|
||||
client = self._get_client()
|
||||
collection_name = self._collection_name(tenant_id)
|
||||
try:
|
||||
collection = client.get_collection(collection_name)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
results = collection.query(
|
||||
query_embeddings=[query_vec],
|
||||
n_results=min(top_k, collection.count()),
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
|
||||
search_results = []
|
||||
if not results["ids"] or not results["ids"][0]:
|
||||
return []
|
||||
|
||||
for i, doc_id in enumerate(results["ids"][0]):
|
||||
# Chroma distances: 1 - cosine_similarity (낮을수록 유사)
|
||||
distance = results["distances"][0][i]
|
||||
score = 1.0 - distance # cosine similarity로 변환
|
||||
if score >= threshold:
|
||||
search_results.append(
|
||||
SearchResult(
|
||||
text=results["documents"][0][i],
|
||||
doc_id=doc_id,
|
||||
score=score,
|
||||
metadata=results["metadatas"][0][i] or {},
|
||||
)
|
||||
)
|
||||
return search_results
|
||||
|
||||
async def delete(self, tenant_id: str, doc_id: str) -> None:
|
||||
client = self._get_client()
|
||||
collection_name = self._collection_name(tenant_id)
|
||||
try:
|
||||
collection = client.get_collection(collection_name)
|
||||
# doc_id로 시작하는 모든 청크 삭제
|
||||
all_ids = collection.get()["ids"]
|
||||
ids_to_delete = [id_ for id_ in all_ids if id_.startswith(f"{doc_id}_")]
|
||||
if ids_to_delete:
|
||||
collection.delete(ids=ids_to_delete)
|
||||
except Exception:
|
||||
pass
|
||||
Reference in New Issue
Block a user