Initial commit: import from sinmb79/Gov-chat-bot
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
app/providers/__init__.py
@@ -0,0 +1,42 @@
from app.providers.llm import LLMProvider, NullLLMProvider
from app.providers.embedding import EmbeddingProvider, NotImplementedEmbeddingProvider
from app.providers.vectordb import VectorDBProvider

# Global flag tracking whether the embedding model has been warmed up
_embedding_warmed_up = False


def get_llm_provider(config: dict) -> LLMProvider:
    provider = config.get("LLM_PROVIDER", "none")
    if provider == "none":
        return NullLLMProvider()
    if provider == "anthropic":
        from app.providers.llm_anthropic import AnthropicLLMProvider
        return AnthropicLLMProvider(
            api_key=config.get("ANTHROPIC_API_KEY", ""),
            model=config.get("LLM_MODEL", "claude-haiku-4-5-20251001"),
        )
    if provider == "openai":
        # OpenAILLMProvider is defined in llm_anthropic.py alongside the Anthropic provider
        from app.providers.llm_anthropic import OpenAILLMProvider
        return OpenAILLMProvider(
            api_key=config.get("OPENAI_API_KEY", ""),
            model=config.get("LLM_MODEL", "gpt-4o-mini"),
        )
    raise ValueError(f"Unknown LLM provider: {provider}")


def get_embedding_provider(config: dict) -> EmbeddingProvider:
    provider = config.get("EMBEDDING_PROVIDER", "none")
    if provider == "local":
        from app.providers.local_embedding import LocalEmbeddingProvider
        model = config.get("EMBEDDING_MODEL", "jhgan/ko-sroberta-multitask")
        return LocalEmbeddingProvider(model_name=model)
    return NotImplementedEmbeddingProvider()


def get_vectordb_provider(config: dict) -> VectorDBProvider:
    from app.providers.chroma import ChromaVectorDBProvider
    return ChromaVectorDBProvider(
        host=config.get("CHROMA_HOST", "chromadb"),
        port=int(config.get("CHROMA_PORT", 8000)),
    )
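The three factories above let the application wire all of its providers from a single config dict. A minimal usage sketch; the settings values below are illustrative, not part of this commit:

from app.providers import get_llm_provider, get_embedding_provider, get_vectordb_provider

# Hypothetical settings; in the real app these would come from the environment.
settings = {
    "LLM_PROVIDER": "anthropic",
    "ANTHROPIC_API_KEY": "sk-...",
    "EMBEDDING_PROVIDER": "local",
    "CHROMA_HOST": "chromadb",
    "CHROMA_PORT": 8000,
}

llm = get_llm_provider(settings)              # AnthropicLLMProvider
embedder = get_embedding_provider(settings)   # LocalEmbeddingProvider
vectordb = get_vectordb_provider(settings)    # ChromaVectorDBProvider

# An unrecognized LLM_PROVIDER raises ValueError at startup instead of failing mid-request.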
app/providers/base.py
@@ -0,0 +1,10 @@
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SearchResult:
    text: str
    doc_id: str
    score: float
    metadata: dict[str, Any] = field(default_factory=dict)
app/providers/chroma.py
@@ -0,0 +1,91 @@
from app.providers.base import SearchResult
from app.providers.vectordb import VectorDBProvider


class ChromaVectorDBProvider(VectorDBProvider):
    """
    ChromaDB-backed vector search.
    Collection name = tenant_{tenant_id} (per-tenant isolation).
    """

    def __init__(self, host: str = "chromadb", port: int = 8000):
        self.host = host
        self.port = port
        self._client = None

    def _get_client(self):
        if self._client is None:
            import chromadb
            self._client = chromadb.HttpClient(host=self.host, port=self.port)
        return self._client

    def _collection_name(self, tenant_id: str) -> str:
        return f"tenant_{tenant_id}"

    async def upsert(
        self,
        tenant_id: str,
        doc_id: str,
        chunks: list[str],
        embeddings: list[list[float]],
        metadatas: list[dict],
    ) -> int:
        client = self._get_client()
        collection = client.get_or_create_collection(self._collection_name(tenant_id))
        ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
        collection.upsert(ids=ids, documents=chunks, embeddings=embeddings, metadatas=metadatas)
        return len(chunks)

    async def search(
        self,
        tenant_id: str,
        query_vec: list[float],
        top_k: int = 3,
        threshold: float = 0.70,
    ) -> list[SearchResult]:
        client = self._get_client()
        collection_name = self._collection_name(tenant_id)
        try:
            collection = client.get_collection(collection_name)
        except Exception:
            return []

        count = collection.count()
        if count == 0:
            return []  # guard: querying an empty collection with n_results=0 would error

        results = collection.query(
            query_embeddings=[query_vec],
            n_results=min(top_k, count),
            include=["documents", "metadatas", "distances"],
        )

        search_results = []
        if not results["ids"] or not results["ids"][0]:
            return []

        for i, doc_id in enumerate(results["ids"][0]):
            # Chroma distance = 1 - cosine_similarity (lower means more similar)
            distance = results["distances"][0][i]
            score = 1.0 - distance  # convert back to cosine similarity
            if score >= threshold:
                search_results.append(
                    SearchResult(
                        text=results["documents"][0][i],
                        doc_id=doc_id,
                        score=score,
                        metadata=results["metadatas"][0][i] or {},
                    )
                )
        return search_results

    async def delete(self, tenant_id: str, doc_id: str) -> None:
        client = self._get_client()
        collection_name = self._collection_name(tenant_id)
        try:
            collection = client.get_collection(collection_name)
            # Delete every chunk whose ID starts with this doc_id
            all_ids = collection.get()["ids"]
            ids_to_delete = [id_ for id_ in all_ids if id_.startswith(f"{doc_id}_")]
            if ids_to_delete:
                collection.delete(ids=ids_to_delete)
        except Exception:
            pass  # missing collection or client error: deletion is best-effort
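A round-trip sketch of the provider above, assuming a Chroma server is reachable on localhost; the tenant and document IDs are made up for illustration:

import asyncio

from app.providers.chroma import ChromaVectorDBProvider


async def demo() -> None:
    db = ChromaVectorDBProvider(host="localhost", port=8000)

    # Two chunks of one document; the 768-dim vectors stand in for real embeddings.
    await db.upsert(
        tenant_id="demo",
        doc_id="doc1",
        chunks=["Office hours are 09:00-18:00.", "Passport renewal takes five days."],
        embeddings=[[0.1] * 768, [0.2] * 768],
        metadatas=[{"page": 1}, {"page": 2}],
    )

    for hit in await db.search(tenant_id="demo", query_vec=[0.1] * 768, top_k=3):
        print(hit.doc_id, round(hit.score, 3), hit.text)

    await db.delete(tenant_id="demo", doc_id="doc1")  # removes doc1_0 and doc1_1


asyncio.run(demo())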
app/providers/embedding.py
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod


class EmbeddingProvider(ABC):
    @abstractmethod
    async def embed(self, texts: list[str]) -> list[list[float]]:
        ...

    @abstractmethod
    async def warmup(self) -> None:
        ...

    @property
    @abstractmethod
    def dimension(self) -> int:
        ...


class NotImplementedEmbeddingProvider(EmbeddingProvider):
    """Placeholder; replaced by LocalEmbeddingProvider in Phase 1."""

    async def embed(self, texts: list[str]) -> list:
        raise NotImplementedError("Embedding provider not configured. Set EMBEDDING_PROVIDER.")

    async def warmup(self) -> None:
        pass  # no-op: passes without raising

    @property
    def dimension(self) -> int:
        return 768
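Because the interface is only three members, a deterministic fake is easy to build for tests. A sketch, not part of this commit:

import hashlib

from app.providers.embedding import EmbeddingProvider


class FakeEmbeddingProvider(EmbeddingProvider):
    """Test double: maps each text to a repeatable pseudo-vector."""

    async def embed(self, texts: list[str]) -> list[list[float]]:
        vectors = []
        for text in texts:
            digest = hashlib.sha256(text.encode("utf-8")).digest()
            # Tile the 32 digest bytes across all 768 dimensions.
            vectors.append([digest[i % 32] / 255.0 for i in range(768)])
        return vectors

    async def warmup(self) -> None:
        pass  # nothing to load

    @property
    def dimension(self) -> int:
        return 768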
app/providers/llm.py
@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
from typing import Optional


class LLMProvider(ABC):
    @abstractmethod
    async def generate(
        self,
        system_prompt: str,
        user_message: str,
        context_chunks: list,
        max_tokens: int = 512,
    ) -> Optional[str]:
        """Return None on failure. Must never raise."""
        ...


class NullLLMProvider(LLMProvider):
    """Default when LLM_PROVIDER=none."""

    async def generate(
        self,
        system_prompt: str = "",
        user_message: str = "",
        context_chunks: Optional[list] = None,
        max_tokens: int = 512,
    ) -> None:
        return None
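The "return None, never raise" contract means call sites branch on the result rather than wrapping every call in try/except. A sketch of the intended pattern; the fallback message is illustrative:

from app.providers.llm import LLMProvider


async def answer_question(llm: LLMProvider, question: str, chunks: list[str]) -> str:
    answer = await llm.generate(
        system_prompt="...",
        user_message=question,
        context_chunks=chunks,
    )
    if answer is None:
        # NullLLMProvider, empty context, a missing API key, or an API error all land here.
        return "Please contact a staff member for this question."
    return answer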
app/providers/llm_anthropic.py
@@ -0,0 +1,82 @@
"""
Anthropic Claude LLM provider.
Called only when supporting context (context_chunks) exists.
Hallucination guard: returns None when there is no supporting context.
"""
from typing import Optional

from app.providers.llm import LLMProvider

SYSTEM_PROMPT_TEMPLATE = """You are the AI guide assistant for {tenant_name}.
Answer strictly and only from the reference documents below.
Never guess or generate content that the references do not support.

Reference documents:
{context}

Rules:
1. For anything not covered by the reference documents, direct the user to contact a staff member.
2. Keep answers concise and clear (three sentences or fewer).
3. Explain technical terms in plain language.
"""


class AnthropicLLMProvider(LLMProvider):
    def __init__(self, api_key: str, model: str = "claude-haiku-4-5-20251001"):
        self.api_key = api_key
        self.model = model

    async def generate(
        self,
        system_prompt: str,
        user_message: str,
        context_chunks: list,
        max_tokens: int = 512,
    ) -> Optional[str]:
        """Return None when there is no context; return None on any error."""
        if not context_chunks:
            return None  # hallucination guard: never call the LLM without supporting context

        try:
            import anthropic
            client = anthropic.AsyncAnthropic(api_key=self.api_key)
            message = await client.messages.create(
                model=self.model,
                max_tokens=max_tokens,
                system=system_prompt,
                messages=[{"role": "user", "content": user_message}],
            )
            return message.content[0].text if message.content else None
        except Exception:
            return None  # None on failure; the caller falls back to Tier D


class OpenAILLMProvider(LLMProvider):
    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        self.api_key = api_key
        self.model = model

    async def generate(
        self,
        system_prompt: str,
        user_message: str,
        context_chunks: list,
        max_tokens: int = 512,
    ) -> Optional[str]:
        if not context_chunks:
            return None

        try:
            import openai
            client = openai.AsyncOpenAI(api_key=self.api_key)
            response = await client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_message},
                ],
            )
            return response.choices[0].message.content
        except Exception:
            return None
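SYSTEM_PROMPT_TEMPLATE takes two fields, tenant_name and context. A sketch of how a caller might render it and invoke the provider; the tenant name, chunks, and key are made up:

# Inside an async function:
chunks = ["Passport renewal takes about five business days."]
system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
    tenant_name="Example City Hall",
    context="\n\n".join(chunks),
)

provider = AnthropicLLMProvider(api_key="sk-...", model="claude-haiku-4-5-20251001")
answer = await provider.generate(
    system_prompt=system_prompt,
    user_message="How long does passport renewal take?",
    context_chunks=chunks,  # an empty list would short-circuit to None
)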
app/providers/local_embedding.py
@@ -0,0 +1,32 @@
from __future__ import annotations

from app.providers.embedding import EmbeddingProvider

import app.providers as providers_module


class LocalEmbeddingProvider(EmbeddingProvider):
    """
    Local embeddings based on jhgan/ko-sroberta-multitask.
    Requires the sentence-transformers package.
    """

    def __init__(self, model_name: str = "jhgan/ko-sroberta-multitask"):
        self.model_name = model_name
        self._model = None

    async def warmup(self) -> None:
        """Load the model. Runs once, on first use."""
        from sentence_transformers import SentenceTransformer
        self._model = SentenceTransformer(self.model_name)
        providers_module._embedding_warmed_up = True

    async def embed(self, texts: list[str]) -> list[list[float]]:
        if self._model is None:
            await self.warmup()
        embeddings = self._model.encode(texts, convert_to_numpy=True)
        return embeddings.tolist()

    @property
    def dimension(self) -> int:
        return 768
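embed() lazily triggers warmup(), so calling warmup() explicitly at startup just moves the model-load cost out of the first request. A usage sketch:

# Inside an async startup hook:
provider = LocalEmbeddingProvider()
await provider.warmup()  # loads the sentence-transformers model once
vectors = await provider.embed(["여권 재발급은 얼마나 걸리나요?"])
assert len(vectors[0]) == provider.dimension  # 768 for ko-sroberta-multitask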
app/providers/vectordb.py
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod

from app.providers.base import SearchResult


class VectorDBProvider(ABC):
    @abstractmethod
    async def upsert(
        self,
        tenant_id: str,
        doc_id: str,
        chunks: list[str],
        embeddings: list[list[float]],
        metadatas: list[dict],
    ) -> int:
        ...

    @abstractmethod
    async def search(
        self,
        tenant_id: str,
        query_vec: list[float],
        top_k: int = 3,
        threshold: float = 0.70,
    ) -> list[SearchResult]:
        ...

    @abstractmethod
    async def delete(self, tenant_id: str, doc_id: str) -> None:
        ...
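As with the embedding interface, an in-memory double makes this contract testable without a running Chroma server. A minimal sketch, not part of this commit:

import math

from app.providers.base import SearchResult
from app.providers.vectordb import VectorDBProvider


class InMemoryVectorDBProvider(VectorDBProvider):
    """Test double: stores (chunk_id, vector, text, metadata) rows per tenant."""

    def __init__(self) -> None:
        self._rows: dict[str, list[tuple[str, list[float], str, dict]]] = {}

    async def upsert(self, tenant_id, doc_id, chunks, embeddings, metadatas) -> int:
        rows = self._rows.setdefault(tenant_id, [])
        for i, (text, vec, meta) in enumerate(zip(chunks, embeddings, metadatas)):
            rows.append((f"{doc_id}_{i}", vec, text, meta))  # append-only; fine for tests
        return len(chunks)

    async def search(self, tenant_id, query_vec, top_k=3, threshold=0.70):
        def cosine(a: list[float], b: list[float]) -> float:
            dot = sum(x * y for x, y in zip(a, b))
            norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(x * x for x in b))
            return dot / norm if norm else 0.0

        hits = [
            SearchResult(text=text, doc_id=cid, score=cosine(query_vec, vec), metadata=meta)
            for cid, vec, text, meta in self._rows.get(tenant_id, [])
        ]
        hits = [h for h in hits if h.score >= threshold]
        return sorted(hits, key=lambda h: h.score, reverse=True)[:top_k]

    async def delete(self, tenant_id, doc_id) -> None:
        rows = self._rows.get(tenant_id, [])
        self._rows[tenant_id] = [r for r in rows if not r[0].startswith(f"{doc_id}_")]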