From eed8723f5191d8db9101815360e05f8bf5486524 Mon Sep 17 00:00:00 2001 From: chan Date: Wed, 3 Sep 2025 10:29:02 +0900 Subject: [PATCH] =?UTF-8?q?=EC=84=B9=EC=85=98=20=EC=A1=B0=ED=9A=8C=20?= =?UTF-8?q?=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docker-compose.yml | 18 +- dockerfile | 17 +- requirements.txt | 8 +- workspace/app.py | 78 +++++++-- workspace/ocr_eval_app.py | 307 +++++++++++++++++++++++++++++++++++ workspace/ocr_eval_engine.py | 92 +++++++++++ 6 files changed, 493 insertions(+), 27 deletions(-) create mode 100644 workspace/ocr_eval_app.py create mode 100644 workspace/ocr_eval_engine.py diff --git a/docker-compose.yml b/docker-compose.yml index 44d41e1..1c8bb2d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,11 +4,23 @@ version: '3.8' services: ui: image: ocr-comparison-ui - volumes: - - ./workspace:/workspace build: context: . - dockerfile: dockerfile # Dockerfile의 상대 경로를 직접 지정 + dockerfile: dockerfile + volumes: + - ./workspace:/workspace ports: - "8501:8501" + command: streamlit run app.py --server.port=8501 + + eval_ui: + image: ocr-comparison-ui + build: + context: . + dockerfile: dockerfile + volumes: + - ./workspace:/workspace + ports: + - "8601:8601" + command: streamlit run ocr_eval_app.py --server.port=8601 diff --git a/dockerfile b/dockerfile index 69f0ac4..b4255ab 100644 --- a/dockerfile +++ b/dockerfile @@ -1,17 +1,18 @@ -# Dockerfile - FROM python:3.10-slim +RUN pip install uv + WORKDIR /workspace COPY requirements.txt . +RUN uv pip install --system --no-cache -r requirements.txt + COPY workspace/ . -RUN pip install --no-cache-dir -r requirements.txt -COPY workspace/app.py . - -COPY . . - +# 서비스가 사용할 포트를 명시합니다 (문서화 목적). EXPOSE 8501 +EXPOSE 8502 -CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"] \ No newline at end of file +# 컨테이너의 기본 실행 명령을 정의합니다. +# 이 명령은 docker-compose.yml 파일의 command에 의해 각 서비스별로 재정의됩니다. +CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"] diff --git a/requirements.txt b/requirements.txt index a29bc44..55bd14f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,10 @@ streamlit requests python-dotenv -pandas \ No newline at end of file +pandas +fuzzywuzzy==0.18.0 +jiwer==4.0.0 +# levenshtein and python-levenshtein are replaced by rapidfuzz for better compatibility +rapidfuzz==3.13.0 +markupsafe==3.0.2 +werkzeug==3.1.3 diff --git a/workspace/app.py b/workspace/app.py index 45edfd6..9f8b34e 100644 --- a/workspace/app.py +++ b/workspace/app.py @@ -44,6 +44,13 @@ def save_completed_file(seed, basename): json.dump(list(completed_set), f, indent=2) +def get_existing_sessions(): + """ "shared_sessions" 디렉토리에서 기존 세션 목록을 가져옵니다. """ + if not SESSION_BASE_PATH.exists(): + return [] + return sorted([d.name for d in SESSION_BASE_PATH.iterdir() if d.is_dir()]) + + # --- 헬퍼 함수 --- @@ -196,26 +203,65 @@ def main(): st.session_state.current_index = 0 SESSION_BASE_PATH.mkdir(parents=True, exist_ok=True) - matched_files = None + st.sidebar.info("화면을 넓게 보려면 오른쪽 위 화살표를 누르세요 <<") + st.sidebar.markdown("---") + + st.sidebar.header("세션 선택") + existing_sessions = get_existing_sessions() + session_options = ["새 세션 생성"] + existing_sessions + + current_seed_from_url = st.query_params.get("seed") + + # URL에 시드가 없으면 "새 세션 생성"을 기본값으로 + if not current_seed_from_url: + current_selection = "새 세션 생성" + # URL의 시드가 존재하지 않는 세션이면 경고 후 "새 세션 생성"으로 + elif current_seed_from_url not in existing_sessions: + st.sidebar.warning(f"URL의 시드 '{current_seed_from_url}'에 해당하는 세션을 찾을 수 없습니다.") + current_selection = "새 세션 생성" + # 잘못된 시드는 URL에서 제거 + if "seed" in st.query_params: + del st.query_params["seed"] + else: + current_selection = current_seed_from_url + + selected_session = st.sidebar.selectbox( + "작업할 세션을 선택하세요.", + session_options, + index=session_options.index(current_selection), + key="session_selector", + ) + + # 사용자가 선택을 변경하면 URL을 업데이트하고 앱을 다시 실행 + if selected_session != current_selection: + if selected_session == "새 세션 생성": + if "seed" in st.query_params: + del st.query_params["seed"] + else: + st.query_params["seed"] = selected_session + st.rerun() + + # --- 이후 로직은 URL의 'seed' 쿼리 파라미터를 기반으로 동작 --- url_seed = st.query_params.get("seed") + matched_files = None + completed_files = set() if url_seed: completed_files = load_completed_files(url_seed) files = load_files_from_session(url_seed) if files[0] is not None: - st.success(f"'{url_seed}' 시드에서 파일을 불러왔습니다.") + st.success(f"'{url_seed}' 세션에서 파일을 불러왔습니다.") matched_files = match_files_3_way(*files) else: + # 이 경우는 위에서 처리되었지만, 안전장치로 남겨둠 st.error(f"'{url_seed}'에 해당하는 세션을 찾을 수 없습니다.") - else: - completed_files = set() + if "seed" in st.query_params: + del st.query_params["seed"] + st.rerun() - # --- 사이드바 --- - st.sidebar.info("화면을 넓게 보려면 오른쪽 위 화살표를 누르세요 <<") - st.sidebar.markdown("---") - - st.sidebar.header("파일 업로드") - if not matched_files: + # 파일 업로드 UI는 새 세션 생성 시에만 표시 + if not url_seed: + st.sidebar.header("새 세션 생성 (파일 업로드)") docs = st.sidebar.file_uploader( "1. 원본 문서", accept_multiple_files=True, type=["png", "jpg", "pdf"] ) @@ -238,12 +284,14 @@ def main(): st.sidebar.info("URL을 복사하여 다른 사람과 세션을 공유하세요.") st.sidebar.text_input("공유 시드", url_seed, disabled=True) + if not url_seed: + st.info("새로운 세션을 생성하려면 사이드바에서 모든 종류의 파일을 업로드하세요.") + return + if not matched_files: - st.info("모든 종류의 파일을 업로드하고 세션을 생성하세요.") - if matched_files is not None and not matched_files: - st.warning( - "파일 이름(확장자 제외)이 동일한 '문서-paddle_ocr-upstage' 세트를 찾을 수 없습니다." - ) + st.warning( + "파일 이름(확장자 제외)이 동일한 '문서-paddle_ocr-upstage' 세트를 찾을 수 없습니다." + ) return st.sidebar.header("파일 탐색") diff --git a/workspace/ocr_eval_app.py b/workspace/ocr_eval_app.py new file mode 100644 index 0000000..4b69a5d --- /dev/null +++ b/workspace/ocr_eval_app.py @@ -0,0 +1,307 @@ +# ocr_eval_app.py +import base64 +import json +import difflib +from pathlib import Path +import pandas as pd +import io + +import streamlit as st +from ocr_eval_engine import OCREvaluator + +# --- 상수 --- +SESSION_BASE_PATH = Path(__file__).parent / "shared_sessions" +EDIT_KEY = "parsed" + +# --- 헬퍼 함수 --- + +def get_evaluable_sessions(): + """ "shared_sessions"에서 'groundtruth' 폴더를 포함하는 세션 목록을 가져옵니다. """ + if not SESSION_BASE_PATH.exists(): + return [] + evaluable = [] + for d in SESSION_BASE_PATH.iterdir(): + if d.is_dir() and (d / "groundtruth").is_dir(): + evaluable.append(d.name) + return sorted(evaluable) + +def get_session_path(seed): + return SESSION_BASE_PATH / seed + +def display_pdf(file_path): + """PDF 파일을 iframe으로 표시합니다.""" + bytes_data = file_path.read_bytes() + base64_pdf = base64.b64encode(bytes_data).decode("utf-8") + st.markdown( + f'', + unsafe_allow_html=True, + ) + +def generate_hyp_html(ref: str, hyp: str) -> str: + """ + difflib.SequenceMatcher를 사용하여 가설(hyp) 텍스트의 오류를 시각화하는 HTML을 생성합니다. + """ + matcher = difflib.SequenceMatcher(None, ref, hyp) + html_out = "" + for tag, i1, i2, j1, j2 in matcher.get_opcodes(): + hyp_chunk = hyp[j1:j2] + hyp_chunk_display = hyp_chunk.replace("&", "&").replace("<", "<").replace(">", ">") + hyp_chunk_display = hyp_chunk_display.replace("\n", "
").replace(" ", " ") + + if tag == 'equal': + html_out += f'{hyp_chunk_display}' + elif tag == 'replace': + html_out += f'{hyp_chunk_display}' + elif tag == 'insert': + html_out += f'{hyp_chunk_display}' + + return f'
{html_out}
' + +def match_evaluation_files(seed): + """ + 세션 폴더 내에서 평가에 필요한 파일들의 목록을 찾아서 매칭합니다. + """ + session_path = get_session_path(seed) + doc_path = session_path / "docs" + gt_path = session_path / "groundtruth" + paddle_path = session_path / "jsons_paddle_ocr" + upstage_path = session_path / "jsons_upstage" + + if not all([p.is_dir() for p in [doc_path, gt_path, paddle_path, upstage_path]]): + return None + + gt_files = {f.stem for f in gt_path.glob("*.json")} + doc_map = {f.stem: f for f in doc_path.iterdir()} + paddle_map = {f.stem: f for f in paddle_path.glob("*.json")} + upstage_map = {f.stem: f for f in upstage_path.glob("*.json")} + + matched = {} + for stem in sorted(list(gt_files)): + if stem in doc_map and stem in paddle_map and stem in upstage_map: + matched[stem] = { + "doc_file": doc_map[stem], + "gt_file": gt_path / f"{stem}.json", + "paddle_file": paddle_map[stem], + "upstage_file": upstage_map[stem], + } + return matched + +def display_evaluation_for_file(files): + """선택된 파일에 대한 평가 결과를 표시합니다.""" + st.header("📊 성능 평가 결과") + + try: + with open(files["gt_file"], "r", encoding="utf-8") as f: + gt_data = json.load(f) + with open(files["paddle_file"], "r", encoding="utf-8") as f: + paddle_data = json.load(f) + with open(files["upstage_file"], "r", encoding="utf-8") as f: + upstage_data = json.load(f) + + gt_text = (gt_data[0] if isinstance(gt_data, list) else gt_data).get(EDIT_KEY, "") + paddle_text = (paddle_data[0] if isinstance(paddle_data, list) else paddle_data).get(EDIT_KEY, "") + upstage_text = (upstage_data[0] if isinstance(upstage_data, list) else upstage_data).get(EDIT_KEY, "") + + if not gt_text: + st.warning("정답 텍스트가 비어있어 평가할 수 없습니다.") + return + + evaluator = OCREvaluator(gt_text) + paddle_results = evaluator.evaluate(paddle_text) + upstage_results = evaluator.evaluate(upstage_text) + + col1, col2 = st.columns(2) + with col1: + st.markdown("#### Model 1: Paddle OCR") + m_col1, m_col2 = st.columns(2) + m_col1.metric("엄격한 WER", f"{paddle_results['strict_wer']:.2%}") + m_col2.metric("엄격한 CER", f"{paddle_results['strict_cer']:.2%}") + m_col1.metric("유연한 WER", f"{paddle_results['flexible_wer']:.2%}") + m_col2.metric("유연한 CER", f"{paddle_results['flexible_cer']:.2%}") + + with col2: + st.markdown("#### Model 2: Upstage OCR") + m_col1, m_col2 = st.columns(2) + m_col1.metric("엄격한 WER", f"{upstage_results['strict_wer']:.2%}") + m_col2.metric("엄격한 CER", f"{upstage_results['strict_cer']:.2%}") + m_col1.metric("유연한 WER", f"{upstage_results['flexible_wer']:.2%}") + m_col2.metric("유연한 CER", f"{upstage_results['flexible_cer']:.2%}") + + with st.expander("상세 텍스트 비교", expanded=True): + st.markdown(""" + + 범례 (Legend) +
일치하는 텍스트
+
삽입된 텍스트 (정답에 없음)
+
치환된 텍스트 (정답과 다름)
+ """, unsafe_allow_html=True) + st.markdown("---") + + text_col1, text_col2, text_col3 = st.columns(3) + with text_col1: + st.text("정답 (Ground Truth)") + st.code(gt_text, language=None) + with text_col2: + st.text("Paddle OCR") + html = generate_hyp_html(gt_text, paddle_text) + st.markdown(html, unsafe_allow_html=True) + with text_col3: + st.text("Upstage OCR") + html = generate_hyp_html(gt_text, upstage_text) + st.markdown(html, unsafe_allow_html=True) + + except Exception as e: + st.error(f"성능 평가 중 오류 발생: {e}") + +@st.cache_data +def generate_all_results_df(_matched_files): + """세션의 모든 파일에 대한 평가 결과를 집계하여 DataFrame으로 반환합니다.""" + all_results = [] + for basename, files in _matched_files.items(): + try: + with open(files["gt_file"], "r", encoding="utf-8") as f: + gt_data = json.load(f) + with open(files["paddle_file"], "r", encoding="utf-8") as f: + paddle_data = json.load(f) + with open(files["upstage_file"], "r", encoding="utf-8") as f: + upstage_data = json.load(f) + + gt_text = (gt_data[0] if isinstance(gt_data, list) else gt_data).get(EDIT_KEY, "") + if not gt_text: + continue + + evaluator = OCREvaluator(gt_text) + + # Paddle 모델 평가 + paddle_text = (paddle_data[0] if isinstance(paddle_data, list) else paddle_data).get(EDIT_KEY, "") + paddle_results = evaluator.evaluate(paddle_text) + paddle_results['model'] = 'paddle_ocr' + paddle_results['file'] = basename + all_results.append(paddle_results) + + # Upstage 모델 평가 + upstage_text = (upstage_data[0] if isinstance(upstage_data, list) else upstage_data).get(EDIT_KEY, "") + upstage_results = evaluator.evaluate(upstage_text) + upstage_results['model'] = 'upstage_ocr' + upstage_results['file'] = basename + all_results.append(upstage_results) + except Exception: + # 오류가 있는 파일은 건너뜀 + continue + + df = pd.DataFrame(all_results) + # 컬럼 순서 재정렬 + ordered_cols = ['file', 'model', 'strict_wer', 'strict_cer', 'flexible_wer', 'flexible_cer', 'word_hits', 'word_substitutions', 'word_deletions', 'word_insertions', 'char_hits', 'char_substitutions', 'char_deletions', 'char_insertions'] + return df[ordered_cols] + +# --- 콜백 함수 --- +def handle_nav_button(direction, total_files): + if direction == "prev" and st.session_state.eval_current_index > 0: + st.session_state.eval_current_index -= 1 + elif direction == "next" and st.session_state.eval_current_index < total_files - 1: + st.session_state.eval_current_index += 1 + +def handle_selectbox_change(): + st.session_state.eval_current_index = st.session_state.eval_selectbox_key + +# --- 메인 UI 로직 --- +def main(): + st.set_page_config(layout="wide", page_title="OCR 성능 평가 도구") + st.title("OCR 성능 평가 도구") + + if "eval_current_index" not in st.session_state: + st.session_state.eval_current_index = 0 + + st.sidebar.header("세션 선택") + sessions = get_evaluable_sessions() + + if not sessions: + st.info("평가 가능한 세션이 없습니다. 먼저 '정답셋 생성 도구'에서 정답셋을 생성해주세요.") + return + + seed = st.sidebar.selectbox("평가할 세션을 선택하세요.", sessions) + + if not seed: + st.info("사이드바에서 평가할 세션을 선택하세요.") + return + + matched_files = match_evaluation_files(seed) + + if matched_files is None: + st.error(f"'{seed}'에 해당하는 세션을 찾을 수 없거나, 필요한 폴더(docs, groundtruth 등)가 없습니다.") + return + if not matched_files: + st.warning("해당 세션에 평가할 파일(정답셋이 생성된 파일)이 없습니다.") + return + + sorted_basenames = sorted(list(matched_files.keys())) + + if st.session_state.eval_current_index >= len(sorted_basenames): + st.session_state.eval_current_index = 0 + + st.sidebar.header("파일 선택") + st.sidebar.selectbox( + "평가할 파일을 선택하세요.", + options=range(len(sorted_basenames)), + format_func=lambda x: f"{x+1}. {sorted_basenames[x]}", + index=st.session_state.eval_current_index, + key="eval_selectbox_key", + on_change=handle_selectbox_change, + ) + + st.sidebar.header("보기 옵션") + hide_document = st.sidebar.checkbox("원본 문서 숨기기", value=False) + + st.sidebar.header("내보내기") + results_df = generate_all_results_df(matched_files) + if not results_df.empty: + csv = results_df.to_csv(index=False).encode('utf-8') + st.sidebar.download_button( + label="전체 결과 CSV 다운로드", + data=csv, + file_name=f"evaluation_results_{seed}.csv", + mime="text/csv", + ) + else: + st.sidebar.write("다운로드할 결과가 없습니다.") + + + current_basename = sorted_basenames[st.session_state.eval_current_index] + + nav_cols = st.columns([1, 5, 1]) + nav_cols[0].button( + "◀ 이전", + on_click=handle_nav_button, + args=("prev", len(sorted_basenames)), + use_container_width=True, + ) + nav_cols[1].markdown( + f"

{current_basename} ({st.session_state.eval_current_index + 1}/{len(sorted_basenames)})

", + unsafe_allow_html=True, + ) + nav_cols[2].button( + "다음 ▶", + on_click=handle_nav_button, + args=("next", len(sorted_basenames)), + use_container_width=True, + ) + st.markdown("---") + + files_to_evaluate = matched_files[current_basename] + + if hide_document: + display_evaluation_for_file(files_to_evaluate) + else: + col1, col2 = st.columns([1, 1]) + with col1: + st.header("📄 원본 문서") + doc_file = files_to_evaluate["doc_file"] + if doc_file.suffix.lower() == ".pdf": + display_pdf(doc_file) + else: + st.image(str(doc_file), use_container_width=True) + with col2: + display_evaluation_for_file(files_to_evaluate) + +if __name__ == "__main__": + main() diff --git a/workspace/ocr_eval_engine.py b/workspace/ocr_eval_engine.py new file mode 100644 index 0000000..1cce2ab --- /dev/null +++ b/workspace/ocr_eval_engine.py @@ -0,0 +1,92 @@ +import jiwer +from fuzzywuzzy import fuzz + + +class OCREvaluator: + """ + 정답(GT) 텍스트와 하나 이상의 예측(Hypothesis) 텍스트를 비교하여 + 다양한 문자 오류율(CER) 및 단어 오류율(WER) 지표를 계산하는 클래스. + """ + + def __init__(self, ground_truth_text: str): + """ + 평가기 인스턴스를 초기화합니다. + + :param ground_truth_text: 비교의 기준이 되는 정답 텍스트. + """ + self.ground_truth = ground_truth_text + + def evaluate(self, hypothesis_text: str) -> dict: + """ + 주어진 예측 텍스트에 대한 모든 평가 지표를 계산합니다. + + :param hypothesis_text: 평가할 OCR 예측 텍스트. + :return: 평가 결과를 담은 딕셔너리. + """ + cer_results = self._calculate_strict_cer(self.ground_truth, hypothesis_text) + wer_results = self._calculate_strict_wer(self.ground_truth, hypothesis_text) + flexible_cer = self._calculate_flexible_cer(self.ground_truth, hypothesis_text) + flexible_wer = self._calculate_flexible_wer(self.ground_truth, hypothesis_text) + + results = { + # Strict CER + "strict_cer": cer_results["cer"], + "char_substitutions": cer_results["S"], + "char_deletions": cer_results["D"], + "char_insertions": cer_results["I"], + "char_hits": cer_results["H"], + # Strict WER + "strict_wer": wer_results["wer"], + "word_substitutions": wer_results["S"], + "word_deletions": wer_results["D"], + "word_insertions": wer_results["I"], + "word_hits": wer_results["H"], + # Flexible Metrics + "flexible_cer": flexible_cer, + "flexible_wer": flexible_wer, + } + return results + + def _calculate_strict_cer(self, ref: str, hyp: str) -> dict: + """ + jiwer를 사용하여 엄격한 순서의 CER을 계산합니다. + """ + if not ref: + return {"cer": 1.0 if hyp else 0.0, "S": 0, "D": 0, "I": len(hyp), "H": 0} + output = jiwer.process_characters(ref, hyp) + return { + "cer": output.cer, + "S": output.substitutions, + "D": output.deletions, + "I": output.insertions, + "H": output.hits, + } + + def _calculate_strict_wer(self, ref: str, hyp: str) -> dict: + """ + jiwer를 사용하여 엄격한 순서의 WER을 계산합니다. + """ + if not ref: + return {"wer": 1.0 if hyp else 0.0, "S": 0, "D": 0, "I": len(hyp.split()), "H": 0} + output = jiwer.process_words(ref, hyp) + return { + "wer": output.wer, + "S": output.substitutions, + "D": output.deletions, + "I": output.insertions, + "H": output.hits, + } + + def _calculate_flexible_cer(self, ref: str, hyp: str) -> float: + """ + fuzzywuzzy의 token_sort_ratio를 사용하여 순서에 유연한 CER을 계산합니다. + """ + similarity_ratio = fuzz.token_sort_ratio(ref, hyp) + return (100 - similarity_ratio) / 100.0 + + def _calculate_flexible_wer(self, ref: str, hyp: str) -> float: + """ + fuzzywuzzy의 token_set_ratio를 사용하여 순서에 유연한 WER을 계산합니다. + """ + similarity_ratio = fuzz.token_set_ratio(ref, hyp) + return (100 - similarity_ratio) / 100.0 \ No newline at end of file