diff --git a/backend/app/api/daily_reports.py b/backend/app/api/daily_reports.py index 0e1ff80..ed43d4f 100644 --- a/backend/app/api/daily_reports.py +++ b/backend/app/api/daily_reports.py @@ -1,6 +1,7 @@ import uuid from datetime import date from fastapi import APIRouter, HTTPException, status +from fastapi.responses import Response from sqlalchemy import select from app.deps import CurrentUser, DB from app.models.daily_report import DailyReport, InputSource @@ -9,6 +10,7 @@ from app.schemas.daily_report import ( DailyReportCreate, DailyReportUpdate, DailyReportGenerateRequest, DailyReportResponse ) from app.services.daily_report_gen import generate_work_content +from app.services.pdf_service import generate_daily_report_pdf router = APIRouter(prefix="/projects/{project_id}/daily-reports", tags=["작업일보"]) @@ -97,6 +99,20 @@ async def update_report(project_id: uuid.UUID, report_id: uuid.UUID, data: Daily return report +@router.get("/{report_id}/pdf") +async def download_report_pdf(project_id: uuid.UUID, report_id: uuid.UUID, db: DB, current_user: CurrentUser): + """작업일보 PDF 다운로드""" + r = await db.execute(select(DailyReport).where(DailyReport.id == report_id, DailyReport.project_id == project_id)) + report = r.scalar_one_or_none() + if not report: + raise HTTPException(status_code=404, detail="일보를 찾을 수 없습니다") + project = await _get_project_or_404(project_id, db) + pdf_bytes = generate_daily_report_pdf(report, project) + filename = f"daily_report_{report.report_date}.pdf" + return Response(content=pdf_bytes, media_type="application/pdf", + headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename}"}) + + @router.delete("/{report_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_report(project_id: uuid.UUID, report_id: uuid.UUID, db: DB, current_user: CurrentUser): result = await db.execute(select(DailyReport).where(DailyReport.id == report_id, DailyReport.project_id == project_id)) diff --git a/backend/app/api/inspections.py b/backend/app/api/inspections.py index 6373336..43c541a 100644 --- a/backend/app/api/inspections.py +++ b/backend/app/api/inspections.py @@ -1,11 +1,13 @@ import uuid from fastapi import APIRouter, HTTPException, status +from fastapi.responses import Response from sqlalchemy import select from app.deps import CurrentUser, DB from app.models.inspection import InspectionRequest from app.models.project import Project, WBSItem from app.schemas.inspection import InspectionCreate, InspectionUpdate, InspectionGenerateRequest, InspectionResponse from app.services.inspection_gen import generate_checklist +from app.services.pdf_service import generate_inspection_pdf router = APIRouter(prefix="/projects/{project_id}/inspections", tags=["검측요청서"]) @@ -96,6 +98,20 @@ async def update_inspection(project_id: uuid.UUID, inspection_id: uuid.UUID, dat return insp +@router.get("/{inspection_id}/pdf") +async def download_inspection_pdf(project_id: uuid.UUID, inspection_id: uuid.UUID, db: DB, current_user: CurrentUser): + """검측요청서 PDF 다운로드""" + r = await db.execute(select(InspectionRequest).where(InspectionRequest.id == inspection_id, InspectionRequest.project_id == project_id)) + insp = r.scalar_one_or_none() + if not insp: + raise HTTPException(status_code=404, detail="검측요청서를 찾을 수 없습니다") + project = await _get_project_or_404(project_id, db) + pdf_bytes = generate_inspection_pdf(insp, project) + filename = f"inspection_{insp.requested_date}_{insp.inspection_type}.pdf" + return Response(content=pdf_bytes, media_type="application/pdf", + headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename}"}) + + @router.delete("/{inspection_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_inspection(project_id: uuid.UUID, inspection_id: uuid.UUID, db: DB, current_user: CurrentUser): result = await db.execute(select(InspectionRequest).where(InspectionRequest.id == inspection_id, InspectionRequest.project_id == project_id)) diff --git a/backend/app/api/quality.py b/backend/app/api/quality.py new file mode 100644 index 0000000..bc49c60 --- /dev/null +++ b/backend/app/api/quality.py @@ -0,0 +1,137 @@ +import uuid +from fastapi import APIRouter, HTTPException, status +from sqlalchemy import select, func +from app.deps import CurrentUser, DB +from app.models.quality import QualityTest, QualityResult +from app.models.project import Project +from app.schemas.quality import QualityTestCreate, QualityTestUpdate, QualityTestResponse + +router = APIRouter(prefix="/projects/{project_id}/quality", tags=["품질시험"]) + + +async def _get_project_or_404(project_id: uuid.UUID, db: DB) -> Project: + result = await db.execute(select(Project).where(Project.id == project_id)) + p = result.scalar_one_or_none() + if not p: + raise HTTPException(status_code=404, detail="프로젝트를 찾을 수 없습니다") + return p + + +@router.get("", response_model=list[QualityTestResponse]) +async def list_quality_tests( + project_id: uuid.UUID, + db: DB, + current_user: CurrentUser, + test_type: str | None = None, + result: QualityResult | None = None, +): + query = select(QualityTest).where(QualityTest.project_id == project_id) + if test_type: + query = query.where(QualityTest.test_type == test_type) + if result: + query = query.where(QualityTest.result == result) + query = query.order_by(QualityTest.test_date.desc()) + rows = await db.execute(query) + return rows.scalars().all() + + +@router.post("", response_model=QualityTestResponse, status_code=status.HTTP_201_CREATED) +async def create_quality_test( + project_id: uuid.UUID, + data: QualityTestCreate, + db: DB, + current_user: CurrentUser, +): + await _get_project_or_404(project_id, db) + test = QualityTest(**data.model_dump(), project_id=project_id) + db.add(test) + await db.commit() + await db.refresh(test) + return test + + +@router.get("/summary") +async def quality_summary(project_id: uuid.UUID, db: DB, current_user: CurrentUser): + """프로젝트 품질시험 합격률 요약""" + total_q = await db.execute( + select(func.count()).where(QualityTest.project_id == project_id) + ) + total = total_q.scalar() or 0 + + pass_q = await db.execute( + select(func.count()).where( + QualityTest.project_id == project_id, + QualityTest.result == QualityResult.PASS, + ) + ) + passed = pass_q.scalar() or 0 + + return { + "total": total, + "passed": passed, + "failed": total - passed, + "pass_rate": round(passed / total * 100, 1) if total > 0 else None, + } + + +@router.get("/{test_id}", response_model=QualityTestResponse) +async def get_quality_test( + project_id: uuid.UUID, test_id: uuid.UUID, db: DB, current_user: CurrentUser +): + result = await db.execute( + select(QualityTest).where( + QualityTest.id == test_id, QualityTest.project_id == project_id + ) + ) + test = result.scalar_one_or_none() + if not test: + raise HTTPException(status_code=404, detail="품질시험 기록을 찾을 수 없습니다") + return test + + +@router.put("/{test_id}", response_model=QualityTestResponse) +async def update_quality_test( + project_id: uuid.UUID, + test_id: uuid.UUID, + data: QualityTestUpdate, + db: DB, + current_user: CurrentUser, +): + result = await db.execute( + select(QualityTest).where( + QualityTest.id == test_id, QualityTest.project_id == project_id + ) + ) + test = result.scalar_one_or_none() + if not test: + raise HTTPException(status_code=404, detail="품질시험 기록을 찾을 수 없습니다") + + update_data = data.model_dump(exclude_none=True) + + # 측정값/기준값 변경 시 합격 여부 재계산 + new_measured = update_data.get("measured_value", test.measured_value) + new_design = update_data.get("design_value", test.design_value) + if "result" not in update_data and new_design is not None: + update_data["result"] = QualityResult.PASS if new_measured >= new_design else QualityResult.FAIL + + for field, value in update_data.items(): + setattr(test, field, value) + await db.commit() + await db.refresh(test) + return test + + +@router.delete("/{test_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_quality_test( + project_id: uuid.UUID, test_id: uuid.UUID, db: DB, current_user: CurrentUser +): + result = await db.execute( + select(QualityTest).where( + QualityTest.id == test_id, QualityTest.project_id == project_id + ) + ) + test = result.scalar_one_or_none() + if not test: + raise HTTPException(status_code=404, detail="품질시험 기록을 찾을 수 없습니다") + await db.delete(test) + await db.commit() diff --git a/backend/app/api/reports.py b/backend/app/api/reports.py index 37784fe..548829f 100644 --- a/backend/app/api/reports.py +++ b/backend/app/api/reports.py @@ -1,6 +1,7 @@ import uuid from datetime import date from fastapi import APIRouter, HTTPException, status +from fastapi.responses import Response from sqlalchemy import select, func from app.deps import CurrentUser, DB from app.models.report import Report, ReportType @@ -9,6 +10,7 @@ from app.models.weather import WeatherAlert from app.models.project import Project from app.schemas.report import ReportGenerateRequest, ReportResponse from app.services.report_gen import generate_weekly_report, generate_monthly_report +from app.services.pdf_service import generate_report_pdf router = APIRouter(prefix="/projects/{project_id}/reports", tags=["공정보고서"]) @@ -134,6 +136,20 @@ async def get_report(project_id: uuid.UUID, report_id: uuid.UUID, db: DB, curren return report +@router.get("/{report_id}/pdf") +async def download_report_pdf(project_id: uuid.UUID, report_id: uuid.UUID, db: DB, current_user: CurrentUser): + """공정보고서 PDF 다운로드""" + r = await db.execute(select(Report).where(Report.id == report_id, Report.project_id == project_id)) + report = r.scalar_one_or_none() + if not report: + raise HTTPException(status_code=404, detail="보고서를 찾을 수 없습니다") + project = await _get_project_or_404(project_id, db) + pdf_bytes = generate_report_pdf(report, project) + filename = f"report_{report.report_type.value}_{report.period_start}.pdf" + return Response(content=pdf_bytes, media_type="application/pdf", + headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename}"}) + + @router.delete("/{report_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_report(project_id: uuid.UUID, report_id: uuid.UUID, db: DB, current_user: CurrentUser): result = await db.execute(select(Report).where(Report.id == report_id, Report.project_id == project_id)) diff --git a/backend/app/main.py b/backend/app/main.py index f43ff16..f145c55 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -2,14 +2,17 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager from app.config import settings -from app.api import auth, projects, tasks, daily_reports, reports, inspections, weather, rag, kakao, permits, settings as settings_router +from app.api import auth, projects, tasks, daily_reports, reports, inspections, weather, rag, kakao, permits, quality, settings as settings_router +from app.services.scheduler import start_scheduler, stop_scheduler @asynccontextmanager async def lifespan(app: FastAPI): - # Startup: seed default work types, check DB connection + # Startup + start_scheduler() yield - # Shutdown: cleanup resources + # Shutdown + stop_scheduler() def create_app() -> FastAPI: @@ -41,6 +44,7 @@ def create_app() -> FastAPI: app.include_router(rag.router, prefix=api_prefix) app.include_router(kakao.router, prefix=api_prefix) app.include_router(permits.router, prefix=api_prefix) + app.include_router(quality.router, prefix=api_prefix) app.include_router(settings_router.router, prefix=api_prefix) @app.get("/health") diff --git a/backend/app/schemas/quality.py b/backend/app/schemas/quality.py new file mode 100644 index 0000000..0ca171a --- /dev/null +++ b/backend/app/schemas/quality.py @@ -0,0 +1,57 @@ +import uuid +from datetime import date +from pydantic import BaseModel, model_validator +from app.models.quality import QualityResult + + +class QualityTestCreate(BaseModel): + wbs_item_id: uuid.UUID | None = None + test_type: str # compression_strength, slump, compaction, etc. + test_date: date + location_detail: str | None = None + design_value: float | None = None + measured_value: float + unit: str + result: QualityResult | None = None # auto-calculated if design_value provided + lab_name: str | None = None + report_number: str | None = None + notes: str | None = None + + @model_validator(mode="after") + def auto_result(self) -> "QualityTestCreate": + if self.result is None: + if self.design_value is not None: + self.result = QualityResult.PASS if self.measured_value >= self.design_value else QualityResult.FAIL + else: + self.result = QualityResult.PASS + return self + + +class QualityTestUpdate(BaseModel): + test_date: date | None = None + location_detail: str | None = None + design_value: float | None = None + measured_value: float | None = None + unit: str | None = None + result: QualityResult | None = None + lab_name: str | None = None + report_number: str | None = None + notes: str | None = None + + +class QualityTestResponse(BaseModel): + id: uuid.UUID + project_id: uuid.UUID + wbs_item_id: uuid.UUID | None + test_type: str + test_date: date + location_detail: str | None + design_value: float | None + measured_value: float + unit: str + result: QualityResult + lab_name: str | None + report_number: str | None + notes: str | None + + model_config = {"from_attributes": True} diff --git a/backend/app/services/pdf_service.py b/backend/app/services/pdf_service.py new file mode 100644 index 0000000..df0246b --- /dev/null +++ b/backend/app/services/pdf_service.py @@ -0,0 +1,78 @@ +""" +PDF 생성 서비스 — WeasyPrint + Jinja2 +""" +from datetime import datetime +from pathlib import Path +from jinja2 import Environment, FileSystemLoader + +TEMPLATES_DIR = Path(__file__).parent.parent / "templates" +_jinja_env = Environment(loader=FileSystemLoader(str(TEMPLATES_DIR)), autoescape=True) + +INSPECTION_TYPE_LABELS = { + "rebar": "철근 배근 검측", + "formwork": "거푸집 검측", + "concrete": "콘크리트 타설 검측", + "pipe_burial": "관 매설 검측", + "compaction": "다짐 검측", + "waterproofing": "방수 검측", + "finishing": "마감 검측", +} + +REPORT_TYPE_LABELS = { + "weekly": "주간", + "monthly": "월간", +} + +REPORT_STATUS_LABELS = { + "draft": "초안", + "reviewed": "검토완료", + "submitted": "제출완료", +} + + +def _render_html(template_name: str, **context) -> str: + template = _jinja_env.get_template(template_name) + return template.render(now=datetime.now().strftime("%Y-%m-%d %H:%M"), **context) + + +def _html_to_pdf(html: str) -> bytes: + try: + from weasyprint import HTML + return HTML(string=html).write_pdf() + except ImportError: + raise RuntimeError( + "WeasyPrint가 설치되지 않았습니다. `pip install weasyprint` 실행 후 재시도하세요." + ) + + +def generate_daily_report_pdf(report, project) -> bytes: + html = _render_html("daily_report.html", report=report, project=project) + return _html_to_pdf(html) + + +def generate_inspection_pdf(inspection, project) -> bytes: + type_label = INSPECTION_TYPE_LABELS.get(inspection.inspection_type, inspection.inspection_type) + html = _render_html( + "inspection.html", + inspection=inspection, + project=project, + inspection_type_label=type_label, + ) + return _html_to_pdf(html) + + +def generate_report_pdf(report, project) -> bytes: + type_label = REPORT_TYPE_LABELS.get(report.report_type.value, report.report_type.value) + status_label = REPORT_STATUS_LABELS.get(report.status.value, report.status.value) + period_label = f"{report.period_start} ~ {report.period_end} ({type_label})" + html = _render_html( + "report.html", + report=report, + project=project, + report_type_label=type_label, + status_label=status_label, + period_label=period_label, + content_json=report.content_json, + ai_draft_text=report.ai_draft_text, + ) + return _html_to_pdf(html) diff --git a/backend/app/services/scheduler.py b/backend/app/services/scheduler.py new file mode 100644 index 0000000..dd03dd4 --- /dev/null +++ b/backend/app/services/scheduler.py @@ -0,0 +1,143 @@ +""" +APScheduler 날씨 자동 수집 배치 +- 3시간마다 활성 프로젝트의 날씨 데이터를 수집 +- 수집 후 날씨 경보 평가 +""" +import logging +from datetime import date, datetime + +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.interval import IntervalTrigger +from sqlalchemy import select + +from app.core.database import AsyncSessionLocal +from app.models.project import Project +from app.models.weather import WeatherData, WeatherAlert +from app.services.weather_service import fetch_short_term_forecast, evaluate_alerts + +logger = logging.getLogger(__name__) + +_scheduler: AsyncIOScheduler | None = None + + +async def _collect_weather_for_all_projects(): + """활성 프로젝트 전체의 날씨를 수집하고 경보를 평가합니다.""" + async with AsyncSessionLocal() as db: + result = await db.execute( + select(Project).where(Project.status == "active") + ) + projects = result.scalars().all() + + if not projects: + return + + logger.info(f"날씨 수집 시작: {len(projects)}개 프로젝트") + + for project in projects: + # 프로젝트에 KMA 격자 좌표가 없으면 스킵 + if not project.kma_nx or not project.kma_ny: + continue + + try: + forecasts = await fetch_short_term_forecast(project.kma_nx, project.kma_ny) + + today = date.today().isoformat() + today_forecasts = [f for f in forecasts if f.get("date") == today] + + if today_forecasts: + # 오늘 날씨 데이터 upsert (중복 저장 방지) + existing = await db.execute( + select(WeatherData).where( + WeatherData.project_id == project.id, + WeatherData.forecast_date == date.today(), + ) + ) + weather_row = existing.scalar_one_or_none() + + if not weather_row: + # 최고/최저 기온 계산 + temps = [f.get("temperature") for f in today_forecasts if f.get("temperature") is not None] + precips = [f.get("precipitation_mm", 0) or 0 for f in today_forecasts] + wind_speeds = [f.get("wind_speed", 0) or 0 for f in today_forecasts] + + weather_row = WeatherData( + project_id=project.id, + forecast_date=date.today(), + temperature_max=max(temps) if temps else None, + temperature_min=min(temps) if temps else None, + precipitation_mm=sum(precips), + wind_speed_max=max(wind_speeds) if wind_speeds else None, + sky_condition=today_forecasts[0].get("sky_condition"), + raw_forecast=today_forecasts, + ) + db.add(weather_row) + await db.flush() + + # 날씨 경보 평가 (활성 태스크 기반) + from app.models.task import Task + tasks_result = await db.execute( + select(Task).where( + Task.project_id == project.id, + Task.status.in_(["not_started", "in_progress"]), + ) + ) + tasks = tasks_result.scalars().all() + + # 오늘 이미 생성된 경보 확인 + existing_alerts = await db.execute( + select(WeatherAlert).where( + WeatherAlert.project_id == project.id, + WeatherAlert.alert_date == date.today(), + ) + ) + already_alerted = existing_alerts.scalars().all() + alerted_types = {a.alert_type for a in already_alerted} + + new_alerts = evaluate_alerts( + forecasts=today_forecasts, + tasks=tasks, + existing_alert_types=alerted_types, + ) + + for alert_data in new_alerts: + alert = WeatherAlert( + project_id=project.id, + alert_date=date.today(), + **alert_data, + ) + db.add(alert) + + except Exception as e: + logger.error(f"프로젝트 {project.id} 날씨 수집 실패: {e}") + continue + + await db.commit() + logger.info("날씨 수집 완료") + + +def start_scheduler(): + """FastAPI 시작 시 스케줄러를 초기화하고 시작합니다.""" + global _scheduler + _scheduler = AsyncIOScheduler(timezone="Asia/Seoul") + + # 3시간마다 날씨 수집 + _scheduler.add_job( + _collect_weather_for_all_projects, + trigger=IntervalTrigger(hours=3), + id="weather_collect", + name="날씨 데이터 자동 수집", + replace_existing=True, + misfire_grace_time=300, # 5분 내 누락 허용 + ) + + _scheduler.start() + logger.info("APScheduler 시작: 날씨 수집 3시간 주기") + return _scheduler + + +def stop_scheduler(): + """FastAPI 종료 시 스케줄러를 중지합니다.""" + global _scheduler + if _scheduler and _scheduler.running: + _scheduler.shutdown(wait=False) + logger.info("APScheduler 종료") diff --git a/backend/app/templates/daily_report.html b/backend/app/templates/daily_report.html new file mode 100644 index 0000000..a7c0eaf --- /dev/null +++ b/backend/app/templates/daily_report.html @@ -0,0 +1,98 @@ + + + + + + + +

작 업 일 보

+
{{ project.name }}
+ + + + + + + + + + + + + +
공사명{{ project.name }}
일자{{ report.report_date }}날씨{{ report.weather_summary or '-' }}
기온 (최고/최저){{ report.temperature_high }}°C / {{ report.temperature_low }}°C상태 + {% if report.status.value == 'confirmed' %} + 확인완료 + {% else %} + 초안 + {% endif %} +
+ +
▶ 투입 인원
+ {% if report.workers_count %} + + + {% for key in report.workers_count %}{% endfor %} + + + + {% set total = namespace(n=0) %} + {% for key, val in report.workers_count.items() %} + + {% set total.n = total.n + val %} + {% endfor %} + + +
{{ key }}합계
{{ val }}명{{ total.n }}명
+ {% else %} +
투입 인원 정보 없음
+ {% endif %} + + {% if report.equipment_list %} +
▶ 투입 장비
+ + + {% for eq in report.equipment_list %} + + + + + + + {% endfor %} +
장비명규격수량비고
{{ eq.get('type', '-') }}{{ eq.get('spec', '-') }}{{ eq.get('count', 1) }}대{{ eq.get('notes', '') }}
+ {% endif %} + +
▶ 작업 내용
+ + +
{{ report.work_content or '-' }}
+ + {% if report.issues %} +
▶ 특이사항 / 문제점
+ + +
{{ report.issues }}
+ {% endif %} + + + + diff --git a/backend/app/templates/inspection.html b/backend/app/templates/inspection.html new file mode 100644 index 0000000..c39ebe5 --- /dev/null +++ b/backend/app/templates/inspection.html @@ -0,0 +1,91 @@ + + + + + + + +

검 측 요 청 서

+ + + + + + + + + + + + + + + + {% if inspection.inspector_name %} + + {% endif %} +
공사명{{ project.name }}
검측 항목{{ inspection_type_label }}요청일{{ inspection.requested_date }}
위치 / 부위{{ inspection.location_detail or '-' }}결과 + {% if inspection.result %} + {% if inspection.result.value == 'pass' %}합격 + {% elif inspection.result.value == 'fail' %}불합격 + {% else %}조건부합격{% endif %} + {% else %}-{% endif %} +
검측자{{ inspection.inspector_name }}
+ + {% if inspection.checklist_items %} +
▶ 검측 체크리스트
+ + + {% for item in inspection.checklist_items %} + + + + + + + {% endfor %} +
No.검측 항목기준값확인
{{ loop.index }}{{ item.get('item', item) if item is mapping else item }}{{ item.get('standard', '') if item is mapping else '' }}
+ {% endif %} + + {% if inspection.notes %} +
▶ 특이사항
+
{{ inspection.notes }}
+ {% endif %} + +
+
+
현장대리인
+
+
(인)
+
+
+
감독관
+
+
(인)
+
+
+ + + + diff --git a/backend/app/templates/report.html b/backend/app/templates/report.html new file mode 100644 index 0000000..86d5879 --- /dev/null +++ b/backend/app/templates/report.html @@ -0,0 +1,58 @@ + + + + + + + +

{{ report_type_label }} 공정보고서

+
{{ project.name }}  |  {{ period_label }}
+ + + + + + + +
공사명{{ project.name }}
보고 기간{{ period_label }}상태{{ status_label }}
+ + {% if content_json %} + {% if content_json.get('work_summary') %} +
▶ 주요 작업 내용
+
{{ content_json.work_summary }}
+ {% endif %} + + {% if content_json.get('overall_progress') is not none %} +
▶ 공정률
+
종합 공정률{{ content_json.overall_progress }}%
+ {% endif %} + + {% if content_json.get('issues') %} +
▶ 문제점 및 조치사항
+
{{ content_json.issues }}
+ {% endif %} + + {% if content_json.get('next_plan') %} +
▶ 다음 기간 예정 작업
+
{{ content_json.next_plan }}
+ {% endif %} + {% elif ai_draft_text %} +
▶ AI 작성 보고서 초안
+
{{ ai_draft_text }}
+ {% endif %} + + + + diff --git a/backend/scripts/__init__.py b/backend/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/scripts/seed_rag.py b/backend/scripts/seed_rag.py new file mode 100644 index 0000000..e4dd24d --- /dev/null +++ b/backend/scripts/seed_rag.py @@ -0,0 +1,273 @@ +""" +RAG 시드 스크립트 +법규/시방서 PDF 또는 텍스트 파일을 pgvector에 색인합니다. + +사용법: + python scripts/seed_rag.py --file "경로/파일명.pdf" --title "KCS 14 20 10" --type kcs + python scripts/seed_rag.py --file "경로/파일명.txt" --title "건설안전관리법" --type law + python scripts/seed_rag.py --list # 색인된 소스 목록 출력 + python scripts/seed_rag.py --delete # 소스 및 청크 삭제 + +지원 파일 형식: PDF, TXT, MD +""" +import argparse +import asyncio +import os +import sys +import uuid +from pathlib import Path + +# 프로젝트 루트를 sys.path에 추가 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import httpx +from sqlalchemy import select, text, delete +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker + +from app.config import settings +from app.models.rag import RagSource, RagChunk, RagSourceType + +# ─── 텍스트 추출 ──────────────────────────────────────────────────────────────── + +def extract_text_from_pdf(filepath: str) -> str: + try: + import pdfplumber + text_parts = [] + with pdfplumber.open(filepath) as pdf: + for page in pdf.pages: + t = page.extract_text() + if t: + text_parts.append(t) + return "\n".join(text_parts) + except ImportError: + # fallback: pypdf + try: + from pypdf import PdfReader + reader = PdfReader(filepath) + return "\n".join(page.extract_text() or "" for page in reader.pages) + except ImportError: + raise RuntimeError( + "PDF 읽기 라이브러리가 없습니다.\n" + "설치: pip install pdfplumber 또는 pip install pypdf" + ) + + +def extract_text(filepath: str) -> str: + ext = Path(filepath).suffix.lower() + if ext == ".pdf": + return extract_text_from_pdf(filepath) + elif ext in (".txt", ".md"): + with open(filepath, encoding="utf-8") as f: + return f.read() + else: + raise ValueError(f"지원하지 않는 파일 형식: {ext} (pdf, txt, md만 가능)") + + +# ─── 텍스트 청킹 ──────────────────────────────────────────────────────────────── + +def split_chunks(text: str, chunk_size: int = 800, overlap: int = 100) -> list[str]: + """ + 단락 단위로 먼저 분리하고, chunk_size 초과 시 슬라이딩 윈도우로 분할. + overlap: 앞 청크 마지막 n 글자를 다음 청크 앞에 붙임 (문맥 유지). + """ + paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] + chunks: list[str] = [] + current = "" + + for para in paragraphs: + if len(current) + len(para) < chunk_size: + current = (current + "\n\n" + para).strip() + else: + if current: + chunks.append(current) + # 긴 단락은 슬라이딩 윈도우 + if len(para) > chunk_size: + for start in range(0, len(para), chunk_size - overlap): + chunks.append(para[start : start + chunk_size]) + else: + current = para + + if current: + chunks.append(current) + + return [c for c in chunks if len(c) > 50] # 너무 짧은 청크 제거 + + +# ─── 임베딩 ───────────────────────────────────────────────────────────────────── + +async def embed_batch(texts: list[str]) -> list[list[float]]: + """배치 임베딩 (Voyage AI 또는 OpenAI)""" + if settings.VOYAGE_API_KEY: + return await _embed_voyage_batch(texts) + elif settings.OPENAI_API_KEY: + return await _embed_openai_batch(texts) + else: + raise ValueError("VOYAGE_API_KEY 또는 OPENAI_API_KEY를 .env에 설정하세요.") + + +async def _embed_voyage_batch(texts: list[str]) -> list[list[float]]: + # Voyage AI는 배치 최대 128개 + BATCH = 128 + results = [] + async with httpx.AsyncClient(timeout=60.0) as client: + for i in range(0, len(texts), BATCH): + batch = texts[i : i + BATCH] + resp = await client.post( + "https://api.voyageai.com/v1/embeddings", + headers={"Authorization": f"Bearer {settings.VOYAGE_API_KEY}"}, + json={"model": settings.EMBEDDING_MODEL, "input": batch}, + ) + resp.raise_for_status() + data = resp.json()["data"] + results.extend(item["embedding"] for item in sorted(data, key=lambda x: x["index"])) + return results + + +async def _embed_openai_batch(texts: list[str]) -> list[list[float]]: + BATCH = 100 + results = [] + async with httpx.AsyncClient(timeout=60.0) as client: + for i in range(0, len(texts), BATCH): + batch = texts[i : i + BATCH] + resp = await client.post( + "https://api.openai.com/v1/embeddings", + headers={"Authorization": f"Bearer {settings.OPENAI_API_KEY}"}, + json={"model": "text-embedding-3-small", "input": batch}, + ) + resp.raise_for_status() + data = resp.json()["data"] + results.extend(item["embedding"] for item in sorted(data, key=lambda x: x["index"])) + return results + + +# ─── DB 작업 ───────────────────────────────────────────────────────────────────── + +async def get_session() -> AsyncSession: + engine = create_async_engine(settings.DATABASE_URL, echo=False) + factory = async_sessionmaker(engine, expire_on_commit=False) + return factory() + + +async def seed(filepath: str, title: str, source_type: str, chunk_size: int, overlap: int): + print(f"\n[1/4] 파일 읽기: {filepath}") + raw_text = extract_text(filepath) + print(f" 추출된 텍스트: {len(raw_text):,}자") + + print(f"[2/4] 청크 분할 (크기={chunk_size}, 겹침={overlap})") + chunks = split_chunks(raw_text, chunk_size, overlap) + print(f" 청크 수: {len(chunks)}개") + + print(f"[3/4] 임베딩 생성 중...") + embeddings = await embed_batch([c for c in chunks]) + print(f" 임베딩 완료: {len(embeddings)}개") + + print(f"[4/4] DB 저장 중...") + async with await get_session() as session: + # RagSource 생성 + source = RagSource( + title=title, + source_type=RagSourceType(source_type), + ) + session.add(source) + await session.flush() # source.id 확보 + + # RagChunk 배치 저장 + dim = settings.EMBEDDING_DIMENSIONS + for idx, (content, emb) in enumerate(zip(chunks, embeddings)): + chunk = RagChunk( + source_id=source.id, + chunk_index=idx, + content=content, + metadata_={"chunk_index": idx, "source_title": title}, + ) + session.add(chunk) + await session.flush() # chunk.id 확보 + + # pgvector 직접 업데이트 (SQLAlchemy ORM이 VECTOR 타입을 직접 지원 안 함) + emb_str = "[" + ",".join(str(x) for x in emb) + "]" + await session.execute( + text("UPDATE rag_chunks SET embedding = :emb WHERE id = :id"), + {"emb": emb_str, "id": chunk.id}, + ) + + await session.commit() + print(f"\n완료! source_id={source.id}") + print(f" 제목: {title}") + print(f" 타입: {source_type}") + print(f" 청크: {len(chunks)}개 저장됨") + + +async def list_sources(): + async with await get_session() as session: + result = await session.execute( + select(RagSource).order_by(RagSource.created_at.desc()) + ) + sources = result.scalars().all() + if not sources: + print("색인된 소스가 없습니다.") + return + print(f"\n{'ID':<38} {'타입':<12} {'제목'}") + print("-" * 80) + for s in sources: + chunks_q = await session.execute( + text("SELECT COUNT(*) FROM rag_chunks WHERE source_id = :id"), + {"id": s.id}, + ) + count = chunks_q.scalar() + print(f"{str(s.id):<38} {s.source_type.value:<12} {s.title} ({count}청크)") + + +async def delete_source(source_id: str): + async with await get_session() as session: + sid = uuid.UUID(source_id) + result = await session.execute(select(RagSource).where(RagSource.id == sid)) + source = result.scalar_one_or_none() + if not source: + print(f"소스를 찾을 수 없습니다: {source_id}") + return + await session.execute(delete(RagChunk).where(RagChunk.source_id == sid)) + await session.delete(source) + await session.commit() + print(f"삭제 완료: {source.title} ({source_id})") + + +# ─── CLI ───────────────────────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser( + description="CONAI RAG 시드 스크립트 — 법규/시방서를 pgvector에 색인", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument("--file", help="임베딩할 파일 경로 (pdf/txt/md)") + parser.add_argument("--title", help="문서 제목 (예: KCS 14 20 10 콘크리트 시방서)") + parser.add_argument( + "--type", + choices=["kcs", "law", "regulation", "guideline"], + default="kcs", + help="소스 타입: kcs(시방서), law(법령), regulation(규정), guideline(지침)", + ) + parser.add_argument("--chunk-size", type=int, default=800, help="청크 최대 글자 수 (기본: 800)") + parser.add_argument("--overlap", type=int, default=100, help="청크 겹침 글자 수 (기본: 100)") + parser.add_argument("--list", action="store_true", help="색인된 소스 목록 출력") + parser.add_argument("--delete", metavar="SOURCE_ID", help="소스 ID로 삭제") + + args = parser.parse_args() + + if args.list: + asyncio.run(list_sources()) + elif args.delete: + asyncio.run(delete_source(args.delete)) + elif args.file: + if not args.title: + parser.error("--title 이 필요합니다") + if not os.path.exists(args.file): + print(f"파일을 찾을 수 없습니다: {args.file}") + sys.exit(1) + asyncio.run(seed(args.file, args.title, args.type, args.chunk_size, args.overlap)) + else: + parser.print_help() + + +if __name__ == "__main__": + main()