섹션 조회 추가
This commit is contained in:
307
workspace/ocr_eval_app.py
Normal file
307
workspace/ocr_eval_app.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# ocr_eval_app.py
|
||||
import base64
|
||||
import json
|
||||
import difflib
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import io
|
||||
|
||||
import streamlit as st
|
||||
from ocr_eval_engine import OCREvaluator
|
||||
|
||||
# --- 상수 ---
|
||||
SESSION_BASE_PATH = Path(__file__).parent / "shared_sessions"
|
||||
EDIT_KEY = "parsed"
|
||||
|
||||
# --- 헬퍼 함수 ---
|
||||
|
||||
def get_evaluable_sessions():
|
||||
""" "shared_sessions"에서 'groundtruth' 폴더를 포함하는 세션 목록을 가져옵니다. """
|
||||
if not SESSION_BASE_PATH.exists():
|
||||
return []
|
||||
evaluable = []
|
||||
for d in SESSION_BASE_PATH.iterdir():
|
||||
if d.is_dir() and (d / "groundtruth").is_dir():
|
||||
evaluable.append(d.name)
|
||||
return sorted(evaluable)
|
||||
|
||||
def get_session_path(seed):
|
||||
return SESSION_BASE_PATH / seed
|
||||
|
||||
def display_pdf(file_path):
|
||||
"""PDF 파일을 iframe으로 표시합니다."""
|
||||
bytes_data = file_path.read_bytes()
|
||||
base64_pdf = base64.b64encode(bytes_data).decode("utf-8")
|
||||
st.markdown(
|
||||
f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="100%" height="800" type="application/pdf"></iframe>',
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
def generate_hyp_html(ref: str, hyp: str) -> str:
|
||||
"""
|
||||
difflib.SequenceMatcher를 사용하여 가설(hyp) 텍스트의 오류를 시각화하는 HTML을 생성합니다.
|
||||
"""
|
||||
matcher = difflib.SequenceMatcher(None, ref, hyp)
|
||||
html_out = ""
|
||||
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
|
||||
hyp_chunk = hyp[j1:j2]
|
||||
hyp_chunk_display = hyp_chunk.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
hyp_chunk_display = hyp_chunk_display.replace("\n", "<br>").replace(" ", " ")
|
||||
|
||||
if tag == 'equal':
|
||||
html_out += f'<span style="color: black;">{hyp_chunk_display}</span>'
|
||||
elif tag == 'replace':
|
||||
html_out += f'<span style="color: red; background-color: #ffdddd; font-weight: bold;">{hyp_chunk_display}</span>'
|
||||
elif tag == 'insert':
|
||||
html_out += f'<span style="color: green; background-color: #ddffdd; font-weight: bold;">{hyp_chunk_display}</span>'
|
||||
|
||||
return f'<div style="font-family: monospace; border: 1px solid #ddd; padding: 10px; border-radius: 5px; white-space: normal; word-break: break-all; line-height: 1.6;">{html_out}</div>'
|
||||
|
||||
def match_evaluation_files(seed):
|
||||
"""
|
||||
세션 폴더 내에서 평가에 필요한 파일들의 목록을 찾아서 매칭합니다.
|
||||
"""
|
||||
session_path = get_session_path(seed)
|
||||
doc_path = session_path / "docs"
|
||||
gt_path = session_path / "groundtruth"
|
||||
paddle_path = session_path / "jsons_paddle_ocr"
|
||||
upstage_path = session_path / "jsons_upstage"
|
||||
|
||||
if not all([p.is_dir() for p in [doc_path, gt_path, paddle_path, upstage_path]]):
|
||||
return None
|
||||
|
||||
gt_files = {f.stem for f in gt_path.glob("*.json")}
|
||||
doc_map = {f.stem: f for f in doc_path.iterdir()}
|
||||
paddle_map = {f.stem: f for f in paddle_path.glob("*.json")}
|
||||
upstage_map = {f.stem: f for f in upstage_path.glob("*.json")}
|
||||
|
||||
matched = {}
|
||||
for stem in sorted(list(gt_files)):
|
||||
if stem in doc_map and stem in paddle_map and stem in upstage_map:
|
||||
matched[stem] = {
|
||||
"doc_file": doc_map[stem],
|
||||
"gt_file": gt_path / f"{stem}.json",
|
||||
"paddle_file": paddle_map[stem],
|
||||
"upstage_file": upstage_map[stem],
|
||||
}
|
||||
return matched
|
||||
|
||||
def display_evaluation_for_file(files):
|
||||
"""선택된 파일에 대한 평가 결과를 표시합니다."""
|
||||
st.header("📊 성능 평가 결과")
|
||||
|
||||
try:
|
||||
with open(files["gt_file"], "r", encoding="utf-8") as f:
|
||||
gt_data = json.load(f)
|
||||
with open(files["paddle_file"], "r", encoding="utf-8") as f:
|
||||
paddle_data = json.load(f)
|
||||
with open(files["upstage_file"], "r", encoding="utf-8") as f:
|
||||
upstage_data = json.load(f)
|
||||
|
||||
gt_text = (gt_data[0] if isinstance(gt_data, list) else gt_data).get(EDIT_KEY, "")
|
||||
paddle_text = (paddle_data[0] if isinstance(paddle_data, list) else paddle_data).get(EDIT_KEY, "")
|
||||
upstage_text = (upstage_data[0] if isinstance(upstage_data, list) else upstage_data).get(EDIT_KEY, "")
|
||||
|
||||
if not gt_text:
|
||||
st.warning("정답 텍스트가 비어있어 평가할 수 없습니다.")
|
||||
return
|
||||
|
||||
evaluator = OCREvaluator(gt_text)
|
||||
paddle_results = evaluator.evaluate(paddle_text)
|
||||
upstage_results = evaluator.evaluate(upstage_text)
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.markdown("#### Model 1: Paddle OCR")
|
||||
m_col1, m_col2 = st.columns(2)
|
||||
m_col1.metric("엄격한 WER", f"{paddle_results['strict_wer']:.2%}")
|
||||
m_col2.metric("엄격한 CER", f"{paddle_results['strict_cer']:.2%}")
|
||||
m_col1.metric("유연한 WER", f"{paddle_results['flexible_wer']:.2%}")
|
||||
m_col2.metric("유연한 CER", f"{paddle_results['flexible_cer']:.2%}")
|
||||
|
||||
with col2:
|
||||
st.markdown("#### Model 2: Upstage OCR")
|
||||
m_col1, m_col2 = st.columns(2)
|
||||
m_col1.metric("엄격한 WER", f"{upstage_results['strict_wer']:.2%}")
|
||||
m_col2.metric("엄격한 CER", f"{upstage_results['strict_cer']:.2%}")
|
||||
m_col1.metric("유연한 WER", f"{upstage_results['flexible_wer']:.2%}")
|
||||
m_col2.metric("유연한 CER", f"{upstage_results['flexible_cer']:.2%}")
|
||||
|
||||
with st.expander("상세 텍스트 비교", expanded=True):
|
||||
st.markdown("""
|
||||
<style>.legend{display:flex;align-items:center;margin-bottom:10px;}.legend-box{width:20px;height:20px;margin-right:10px;border:1px solid #ccc;}</style>
|
||||
<b>범례 (Legend)</b>
|
||||
<div class="legend"><div class="legend-box" style="background-color:white;"></div><span>일치하는 텍스트</span></div>
|
||||
<div class="legend"><div class="legend-box" style="background-color:#ddffdd;"></div><span><b>삽입된 텍스트</b> (정답에 없음)</span></div>
|
||||
<div class="legend"><div class="legend-box" style="background-color:#ffdddd;"></div><span><b>치환된 텍스트</b> (정답과 다름)</span></div>
|
||||
""", unsafe_allow_html=True)
|
||||
st.markdown("---")
|
||||
|
||||
text_col1, text_col2, text_col3 = st.columns(3)
|
||||
with text_col1:
|
||||
st.text("정답 (Ground Truth)")
|
||||
st.code(gt_text, language=None)
|
||||
with text_col2:
|
||||
st.text("Paddle OCR")
|
||||
html = generate_hyp_html(gt_text, paddle_text)
|
||||
st.markdown(html, unsafe_allow_html=True)
|
||||
with text_col3:
|
||||
st.text("Upstage OCR")
|
||||
html = generate_hyp_html(gt_text, upstage_text)
|
||||
st.markdown(html, unsafe_allow_html=True)
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"성능 평가 중 오류 발생: {e}")
|
||||
|
||||
@st.cache_data
|
||||
def generate_all_results_df(_matched_files):
|
||||
"""세션의 모든 파일에 대한 평가 결과를 집계하여 DataFrame으로 반환합니다."""
|
||||
all_results = []
|
||||
for basename, files in _matched_files.items():
|
||||
try:
|
||||
with open(files["gt_file"], "r", encoding="utf-8") as f:
|
||||
gt_data = json.load(f)
|
||||
with open(files["paddle_file"], "r", encoding="utf-8") as f:
|
||||
paddle_data = json.load(f)
|
||||
with open(files["upstage_file"], "r", encoding="utf-8") as f:
|
||||
upstage_data = json.load(f)
|
||||
|
||||
gt_text = (gt_data[0] if isinstance(gt_data, list) else gt_data).get(EDIT_KEY, "")
|
||||
if not gt_text:
|
||||
continue
|
||||
|
||||
evaluator = OCREvaluator(gt_text)
|
||||
|
||||
# Paddle 모델 평가
|
||||
paddle_text = (paddle_data[0] if isinstance(paddle_data, list) else paddle_data).get(EDIT_KEY, "")
|
||||
paddle_results = evaluator.evaluate(paddle_text)
|
||||
paddle_results['model'] = 'paddle_ocr'
|
||||
paddle_results['file'] = basename
|
||||
all_results.append(paddle_results)
|
||||
|
||||
# Upstage 모델 평가
|
||||
upstage_text = (upstage_data[0] if isinstance(upstage_data, list) else upstage_data).get(EDIT_KEY, "")
|
||||
upstage_results = evaluator.evaluate(upstage_text)
|
||||
upstage_results['model'] = 'upstage_ocr'
|
||||
upstage_results['file'] = basename
|
||||
all_results.append(upstage_results)
|
||||
except Exception:
|
||||
# 오류가 있는 파일은 건너뜀
|
||||
continue
|
||||
|
||||
df = pd.DataFrame(all_results)
|
||||
# 컬럼 순서 재정렬
|
||||
ordered_cols = ['file', 'model', 'strict_wer', 'strict_cer', 'flexible_wer', 'flexible_cer', 'word_hits', 'word_substitutions', 'word_deletions', 'word_insertions', 'char_hits', 'char_substitutions', 'char_deletions', 'char_insertions']
|
||||
return df[ordered_cols]
|
||||
|
||||
# --- 콜백 함수 ---
|
||||
def handle_nav_button(direction, total_files):
|
||||
if direction == "prev" and st.session_state.eval_current_index > 0:
|
||||
st.session_state.eval_current_index -= 1
|
||||
elif direction == "next" and st.session_state.eval_current_index < total_files - 1:
|
||||
st.session_state.eval_current_index += 1
|
||||
|
||||
def handle_selectbox_change():
|
||||
st.session_state.eval_current_index = st.session_state.eval_selectbox_key
|
||||
|
||||
# --- 메인 UI 로직 ---
|
||||
def main():
|
||||
st.set_page_config(layout="wide", page_title="OCR 성능 평가 도구")
|
||||
st.title("OCR 성능 평가 도구")
|
||||
|
||||
if "eval_current_index" not in st.session_state:
|
||||
st.session_state.eval_current_index = 0
|
||||
|
||||
st.sidebar.header("세션 선택")
|
||||
sessions = get_evaluable_sessions()
|
||||
|
||||
if not sessions:
|
||||
st.info("평가 가능한 세션이 없습니다. 먼저 '정답셋 생성 도구'에서 정답셋을 생성해주세요.")
|
||||
return
|
||||
|
||||
seed = st.sidebar.selectbox("평가할 세션을 선택하세요.", sessions)
|
||||
|
||||
if not seed:
|
||||
st.info("사이드바에서 평가할 세션을 선택하세요.")
|
||||
return
|
||||
|
||||
matched_files = match_evaluation_files(seed)
|
||||
|
||||
if matched_files is None:
|
||||
st.error(f"'{seed}'에 해당하는 세션을 찾을 수 없거나, 필요한 폴더(docs, groundtruth 등)가 없습니다.")
|
||||
return
|
||||
if not matched_files:
|
||||
st.warning("해당 세션에 평가할 파일(정답셋이 생성된 파일)이 없습니다.")
|
||||
return
|
||||
|
||||
sorted_basenames = sorted(list(matched_files.keys()))
|
||||
|
||||
if st.session_state.eval_current_index >= len(sorted_basenames):
|
||||
st.session_state.eval_current_index = 0
|
||||
|
||||
st.sidebar.header("파일 선택")
|
||||
st.sidebar.selectbox(
|
||||
"평가할 파일을 선택하세요.",
|
||||
options=range(len(sorted_basenames)),
|
||||
format_func=lambda x: f"{x+1}. {sorted_basenames[x]}",
|
||||
index=st.session_state.eval_current_index,
|
||||
key="eval_selectbox_key",
|
||||
on_change=handle_selectbox_change,
|
||||
)
|
||||
|
||||
st.sidebar.header("보기 옵션")
|
||||
hide_document = st.sidebar.checkbox("원본 문서 숨기기", value=False)
|
||||
|
||||
st.sidebar.header("내보내기")
|
||||
results_df = generate_all_results_df(matched_files)
|
||||
if not results_df.empty:
|
||||
csv = results_df.to_csv(index=False).encode('utf-8')
|
||||
st.sidebar.download_button(
|
||||
label="전체 결과 CSV 다운로드",
|
||||
data=csv,
|
||||
file_name=f"evaluation_results_{seed}.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
else:
|
||||
st.sidebar.write("다운로드할 결과가 없습니다.")
|
||||
|
||||
|
||||
current_basename = sorted_basenames[st.session_state.eval_current_index]
|
||||
|
||||
nav_cols = st.columns([1, 5, 1])
|
||||
nav_cols[0].button(
|
||||
"◀ 이전",
|
||||
on_click=handle_nav_button,
|
||||
args=("prev", len(sorted_basenames)),
|
||||
use_container_width=True,
|
||||
)
|
||||
nav_cols[1].markdown(
|
||||
f"<h4 style='text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;'>{current_basename} ({st.session_state.eval_current_index + 1}/{len(sorted_basenames)})</h4>",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
nav_cols[2].button(
|
||||
"다음 ▶",
|
||||
on_click=handle_nav_button,
|
||||
args=("next", len(sorted_basenames)),
|
||||
use_container_width=True,
|
||||
)
|
||||
st.markdown("---")
|
||||
|
||||
files_to_evaluate = matched_files[current_basename]
|
||||
|
||||
if hide_document:
|
||||
display_evaluation_for_file(files_to_evaluate)
|
||||
else:
|
||||
col1, col2 = st.columns([1, 1])
|
||||
with col1:
|
||||
st.header("📄 원본 문서")
|
||||
doc_file = files_to_evaluate["doc_file"]
|
||||
if doc_file.suffix.lower() == ".pdf":
|
||||
display_pdf(doc_file)
|
||||
else:
|
||||
st.image(str(doc_file), use_container_width=True)
|
||||
with col2:
|
||||
display_evaluation_for_file(files_to_evaluate)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user