Files
C.E.L_Slide_test2/src/block_search.py
kyeongmin 29f56187c0 Phase P~S 전체 작업물: 검증 스크립트, 블록 템플릿, 설계 문서, 코드 수정
포함 내용:
- Phase P/Q/R/S 설계 문서 (IMPROVEMENT-PHASE-*.md)
- 영역별 검증 스크립트 (scripts/verify_*.py, test_*.py)
- 블록 템플릿 추가 (cards, emphasis 변형)
- 코드 수정: block_search, content_editor, design_director, slide_measurer
- catalog.yaml 블록 목록 업데이트
- CLAUDE.md, PROGRESS.md, README.md 업데이트

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 08:38:06 +09:00

248 lines
7.9 KiB
Python

"""P2-A: FAISS 기반 블록 검색 모듈.
catalog.yaml 46개 블록 중 콘텐츠에 적합한 후보를 검색하여 반환한다.
디자인 팀장(Step B)의 프롬프트에 전체 catalog 대신 관련 블록만 전달하기 위함.
사용법:
from src.block_search import search_blocks_for_topics
candidates = search_blocks_for_topics(topics, top_k=8)
# → catalog.yaml 형식의 문자열 (팀장 프롬프트에 삽입)
fallback: 인덱스 없거나 검색 실패 시 → catalog.yaml 전문 반환 (기존 방식)
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
import numpy as np
logger = logging.getLogger(__name__)
PROJECT_ROOT = Path(__file__).parent.parent
INDEX_PATH = PROJECT_ROOT / "data" / "block_index.faiss"
META_PATH = PROJECT_ROOT / "data" / "block_metadata.json"
CATALOG_PATH = PROJECT_ROOT / "templates" / "catalog.yaml"
# Kei persona와 동일 모델
EMBEDDING_MODEL = "BAAI/bge-m3"
# 카테고리 목록 (최소 1개 보장용)
ALL_CATEGORIES = ["headers", "cards", "tables", "visuals", "emphasis", "media"]
# Lazy load
_index = None
_metadata: list[dict] | None = None
_model = None
_loaded = False
def _ensure_loaded() -> bool:
"""인덱스 + 모델을 lazy load한다. 성공 시 True."""
global _index, _metadata, _model, _loaded
if _loaded:
return _index is not None
_loaded = True # 재시도 방지
if not INDEX_PATH.exists() or not META_PATH.exists():
logger.warning(
f"블록 인덱스 없음: {INDEX_PATH}. "
f"scripts/build_block_index.py를 실행하세요. "
f"catalog 전문 fallback 사용."
)
return False
try:
import faiss
from sentence_transformers import SentenceTransformer
logger.info(f"블록 인덱스 로딩: {INDEX_PATH}")
_index = faiss.read_index(str(INDEX_PATH))
with open(META_PATH, encoding="utf-8") as f:
_metadata = json.load(f)
logger.info(f"임베딩 모델 로딩: {EMBEDDING_MODEL} (CPU)")
_model = SentenceTransformer(EMBEDDING_MODEL, device="cpu")
logger.info(
f"블록 검색 준비 완료: {_index.ntotal}개 벡터, "
f"{len(_metadata)}개 메타데이터"
)
return True
except Exception as e:
logger.warning(f"블록 인덱스 로드 실패: {e}. catalog 전문 fallback.")
_index = None
_metadata = None
_model = None
return False
def search_blocks(query: str, top_k: int = 8) -> list[dict]:
"""단일 쿼리로 관련 블록을 검색한다.
Args:
query: 검색 쿼리 (꼭지 제목+요약+역할)
top_k: 반환할 최대 블록 수
Returns:
관련 블록 메타데이터 목록 (score 포함)
"""
if not _ensure_loaded():
return []
q_embedding = _model.encode(
[query],
normalize_embeddings=True,
)
q_embedding = np.array(q_embedding, dtype=np.float32)
scores, indices = _index.search(q_embedding, min(top_k, _index.ntotal))
results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0 or idx >= len(_metadata):
continue
block = dict(_metadata[idx])
block["search_score"] = float(score)
results.append(block)
return results
def search_blocks_for_topics(
topics: list[dict],
top_k_per_topic: int = 3,
total_max: int = 10,
) -> str:
"""여러 꼭지에 대해 검색하고, 중복 제거 + 카테고리 보장 후 문자열로 반환.
Args:
topics: 1단계 실장의 꼭지 분석 결과
top_k_per_topic: 꼭지당 검색 수
total_max: 최종 반환 최대 수
Returns:
catalog.yaml 형식의 문자열 (팀장 프롬프트에 삽입)
검색 실패 시 catalog.yaml 전문 반환 (fallback)
"""
if not _ensure_loaded():
return _fallback_full_catalog()
# 1. 꼭지별 검색
all_results: dict[str, dict] = {} # id → block (중복 제거)
for topic in topics:
query = _build_query(topic)
results = search_blocks(query, top_k=top_k_per_topic)
for block in results:
bid = block["id"]
if bid not in all_results:
all_results[bid] = block
else:
# 이미 있으면 더 높은 점수로 업데이트
if block["search_score"] > all_results[bid]["search_score"]:
all_results[bid] = block
# 2. 카테고리별 최소 1개 보장
found_categories = {b.get("category", "") for b in all_results.values()}
missing_categories = set(ALL_CATEGORIES) - found_categories
if missing_categories and _metadata:
for block in _metadata:
cat = block.get("category", "")
if cat in missing_categories:
if block["id"] not in all_results:
block_copy = dict(block)
block_copy["search_score"] = 0.0 # 카테고리 보장용
all_results[block["id"]] = block_copy
missing_categories.discard(cat)
if not missing_categories:
break
# 3. 점수순 정렬 + 최대 개수 제한
sorted_blocks = sorted(
all_results.values(),
key=lambda b: b.get("search_score", 0),
reverse=True,
)[:total_max]
# 4. 프롬프트용 문자열 생성
return _format_for_prompt(sorted_blocks)
def search_candidates_per_topic(
topics: list[dict],
top_k: int = 2,
) -> dict[int, list[dict]]:
"""Phase P: 각 topic별 FAISS 상위 후보를 반환한다.
Args:
topics: 1단계 꼭지 분석 결과
top_k: topic당 반환할 후보 수
Returns:
{topic_id: [블록 메타데이터 목록]} — 각 topic별 상위 top_k개
"""
if not _ensure_loaded():
return {}
result: dict[int, list[dict]] = {}
for topic in topics:
tid = topic.get("id")
if tid is None:
continue
query = _build_query(topic)
candidates = search_blocks(query, top_k=top_k + 2) # 여유분 확보 (중복 제거용)
result[tid] = candidates[:top_k]
logger.info(
f"[Phase P] topic별 FAISS 후보: "
+ ", ".join(f"t{tid}={[c['id'] for c in cs]}" for tid, cs in result.items())
)
return result
def _build_query(topic: dict) -> str:
"""꼭지 정보에서 검색 쿼리를 생성한다. (Phase M: 역할+관계+표현 추가)"""
parts = [
topic.get("title", ""),
topic.get("summary", ""),
f"역할: {topic.get('role', 'flow')}",
f"레이어: {topic.get('layer', 'core')}",
]
# Phase M: purpose, relation_type, expression_hint 추가
if topic.get("purpose"):
parts.append(f"목적: {topic['purpose']}")
if topic.get("relation_type"):
parts.append(f"관계: {topic['relation_type']}")
if topic.get("expression_hint"):
parts.append(f"표현: {topic['expression_hint']}")
if topic.get("content_type"):
parts.append(f"콘텐츠: {topic['content_type']}")
return ". ".join(p for p in parts if p)
def _format_for_prompt(blocks: list[dict]) -> str:
"""블록 목록을 팀장 프롬프트에 삽입할 문자열로 변환."""
lines = [f"# 관련 블록 후보 ({len(blocks)}개)\n"]
for block in blocks:
lines.append(f"- **{block['id']}** ({block.get('name', '')})")
lines.append(f" 시각: {block.get('visual', '')}")
lines.append(f" 사용: {block.get('when', '').strip()}")
lines.append(f" 금지: {block.get('not_for', '').strip()}")
lines.append(f" 높이: {block.get('height_cost', 'medium')}")
lines.append("")
return "\n".join(lines)
def _fallback_full_catalog() -> str:
"""검색 실패 시 catalog.yaml 전문을 반환한다 (기존 방식)."""
if CATALOG_PATH.exists():
return CATALOG_PATH.read_text(encoding="utf-8")
return "사용 가능한 블록 없음."