Compare commits
4 Commits
d14e3ea64a
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 732e7c8cc0 | |||
| 7749492ae7 | |||
| 8a6f2ae2d8 | |||
| f5ab36737a |
29
.env
Normal file
29
.env
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# --------------------------------------------------------------------------
|
||||||
|
# vLLM Engine Configuration
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
# 이 파일의 주석(#)을 제거하고 값을 수정하여 기본 엔진 설정을 재정의할 수 있습니다.
|
||||||
|
# 설정 가능한 변수 목록은 config/engine_settings.py 파일을 참고하세요.
|
||||||
|
|
||||||
|
# Hugging Face 모델 경로 또는 로컬 경로
|
||||||
|
# MODEL_PATH="deepseek-ai/DeepSeek-OCR"
|
||||||
|
|
||||||
|
# 텐서 병렬 처리 크기 (Multi-GPU 환경에서 사용)
|
||||||
|
# TENSOR_PARALLEL_SIZE=1
|
||||||
|
|
||||||
|
# 최대 GPU 메모리 사용률 (0.0 ~ 1.0)
|
||||||
|
# GPU_MEMORY_UTILIZATION=0.15
|
||||||
|
|
||||||
|
# KV 캐시 블록 크기
|
||||||
|
# BLOCK_SIZE=256
|
||||||
|
|
||||||
|
# 최대 모델 길이
|
||||||
|
# MAX_MODEL_LEN=8192
|
||||||
|
|
||||||
|
# Eager 모드 강제 실행 여부 (True / False)
|
||||||
|
# ENFORCE_EAGER=False
|
||||||
|
|
||||||
|
# 원격 코드 신뢰 여부 (True / False)
|
||||||
|
# TRUST_REMOTE_CODE=True
|
||||||
|
|
||||||
|
# 사용자 정의 모델 아키텍처 (쉼표로 구분)
|
||||||
|
# ARCHITECTURES="DeepseekOCRForCausalLM"
|
||||||
35
config/engine_settings.py
Normal file
35
config/engine_settings.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def _str_to_bool(value: str) -> bool:
|
||||||
|
"""환경 변수(문자열)를 boolean 값으로 변환합니다."""
|
||||||
|
return value.lower() in ("true", "1", "t")
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
# vLLM Engine Configuration
|
||||||
|
# .env 파일에 동일한 이름의 환경 변수를 설정하여 아래 기본값을 재정의할 수 있습니다.
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# 사용자 정의 모델 아키텍처
|
||||||
|
# 여러 개일 경우 쉼표로 구분: "Arch1,Arch2"
|
||||||
|
_architectures_str = os.getenv("ARCHITECTURES", "DeepseekOCRForCausalLM")
|
||||||
|
ARCHITECTURES = [arch.strip() for arch in _architectures_str.split(",")]
|
||||||
|
|
||||||
|
# KV 캐시 블록 크기
|
||||||
|
BLOCK_SIZE = int(os.getenv("BLOCK_SIZE", "256"))
|
||||||
|
|
||||||
|
# 최대 모델 길이
|
||||||
|
MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
|
||||||
|
|
||||||
|
# Eager 모드 강제 실행 여부
|
||||||
|
ENFORCE_EAGER = _str_to_bool(os.getenv("ENFORCE_EAGER", "False"))
|
||||||
|
|
||||||
|
# 원격 코드 신뢰 여부
|
||||||
|
TRUST_REMOTE_CODE = _str_to_bool(os.getenv("TRUST_REMOTE_CODE", "True"))
|
||||||
|
|
||||||
|
# 텐서 병렬 처리 크기
|
||||||
|
TENSOR_PARALLEL_SIZE = int(os.getenv("TENSOR_PARALLEL_SIZE", "1"))
|
||||||
|
|
||||||
|
# GPU 메모리 사용률
|
||||||
|
GPU_MEMORY_UTILIZATION = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.15"))
|
||||||
@@ -11,3 +11,4 @@ matplotlib
|
|||||||
fastapi
|
fastapi
|
||||||
uvicorn[standard]
|
uvicorn[standard]
|
||||||
python-multipart
|
python-multipart
|
||||||
|
python-dotenv
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ async def perform_ocr(
|
|||||||
raise HTTPException(status_code=400, detail="파일 내용이 비어있습니다.")
|
raise HTTPException(status_code=400, detail="파일 내용이 비어있습니다.")
|
||||||
|
|
||||||
# 모든 처리 로직을 OCR 엔진에 위임
|
# 모든 처리 로직을 OCR 엔진에 위임
|
||||||
result = process_document(
|
result = await process_document(
|
||||||
file_bytes=file_content,
|
file_bytes=file_content,
|
||||||
content_type=document.content_type,
|
content_type=document.content_type,
|
||||||
filename=document.filename,
|
filename=document.filename,
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
import fitz
|
import fitz
|
||||||
from config.model_settings import CROP_MODE, MODEL_PATH, PROMPT
|
from config import engine_settings, model_settings
|
||||||
from PIL import Image
|
from PIL import Image, ImageOps
|
||||||
from vllm import AsyncLLMEngine, SamplingParams
|
from vllm import AsyncLLMEngine, SamplingParams
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.model_executor.models.registry import ModelRegistry
|
from vllm.model_executor.models.registry import ModelRegistry
|
||||||
@@ -29,14 +30,14 @@ async def init_engine():
|
|||||||
return
|
return
|
||||||
|
|
||||||
engine_args = AsyncEngineArgs(
|
engine_args = AsyncEngineArgs(
|
||||||
model=MODEL_PATH,
|
model=model_settings.MODEL_PATH,
|
||||||
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
|
hf_overrides={"architectures": engine_settings.ARCHITECTURES},
|
||||||
block_size=256,
|
block_size=engine_settings.BLOCK_SIZE,
|
||||||
max_model_len=8192,
|
max_model_len=engine_settings.MAX_MODEL_LEN,
|
||||||
enforce_eager=False,
|
enforce_eager=engine_settings.ENFORCE_EAGER,
|
||||||
trust_remote_code=True,
|
trust_remote_code=engine_settings.TRUST_REMOTE_CODE,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=engine_settings.TENSOR_PARALLEL_SIZE,
|
||||||
gpu_memory_utilization=0.75,
|
gpu_memory_utilization=engine_settings.GPU_MEMORY_UTILIZATION,
|
||||||
)
|
)
|
||||||
_engine = AsyncLLMEngine.from_engine_args(engine_args)
|
_engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
|
|
||||||
@@ -52,34 +53,89 @@ sampling_params = SamplingParams(
|
|||||||
processor = DeepseekOCRProcessor()
|
processor = DeepseekOCRProcessor()
|
||||||
|
|
||||||
# --------------------------------------------------------------------------
|
# --------------------------------------------------------------------------
|
||||||
# 2. 핵심 처리 함수
|
# 2. 후처리 함수 (공식 코드 기반)
|
||||||
# --------------------------------------------------------------------------
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def _process_single_image(image: Image.Image) -> str:
|
def _postprocess_text(text: str, page_num: int = 0) -> str:
|
||||||
"""단일 PIL 이미지를 받아 OCR을 수행하고 텍스트를 반환합니다."""
|
"""
|
||||||
|
모델의 원본 출력에서 태그를 제거/변경하고 텍스트를 정리합니다.
|
||||||
|
(test/test.py의 후처리 로직 기반으로 수정)
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 1. 정규식으로 모든 ref/det 태그 블록을 찾음
|
||||||
|
# 패턴은 (전체 매치, ref 내용, det 내용)을 캡처
|
||||||
|
pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
||||||
|
# 2. 전체 매치된 문자열을 이미지 태그와 기타 태그로 분리
|
||||||
|
matches_images = []
|
||||||
|
matches_other = []
|
||||||
|
for match_tuple in matches:
|
||||||
|
full_match_str = match_tuple[0] # 전체 매치된 부분
|
||||||
|
ref_content = match_tuple[1] # <|ref|> 안의 내용
|
||||||
|
|
||||||
|
if "image" in ref_content:
|
||||||
|
matches_images.append(full_match_str)
|
||||||
|
else:
|
||||||
|
matches_other.append(full_match_str)
|
||||||
|
|
||||||
|
processed_text = text
|
||||||
|
|
||||||
|
# 3. 이미지 태그는 마크다운 링크로 대체
|
||||||
|
for idx, img_tag in enumerate(matches_images):
|
||||||
|
img_link = f"\n"
|
||||||
|
processed_text = processed_text.replace(img_tag, img_link)
|
||||||
|
|
||||||
|
# 4. 이미지가 아닌 다른 모든 태그는 제거
|
||||||
|
for other_tag in matches_other:
|
||||||
|
processed_text = processed_text.replace(other_tag, "")
|
||||||
|
|
||||||
|
# 5. 특수 문자, 불필요한 토큰 및 추가 공백 정리
|
||||||
|
processed_text = (
|
||||||
|
processed_text.replace("<|end of sentence|>", "")
|
||||||
|
.replace("\\coloneqq", ":=")
|
||||||
|
.replace("\\eqqcolon", "=:")
|
||||||
|
.replace("\n\n\n", "\n\n")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
return processed_text
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
# 3. 핵심 처리 함수
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_single_image(image: Image.Image, page_num: int = 0) -> str:
|
||||||
|
"""단일 PIL 이미지를 받아 OCR을 수행하고 후처리된 텍스트를 반환합니다."""
|
||||||
if _engine is None:
|
if _engine is None:
|
||||||
raise RuntimeError("vLLM engine not initialized yet")
|
raise RuntimeError("vLLM engine not initialized yet")
|
||||||
if "<image>" not in PROMPT:
|
if "<image>" not in model_settings.PROMPT:
|
||||||
raise ValueError("프롬프트에 '<image>' 토큰이 없어 OCR을 수행할 수 없습니다.")
|
raise ValueError("프롬프트에 '<image>' 토큰이 없어 OCR을 수행할 수 없습니다.")
|
||||||
|
|
||||||
image_features = processor.tokenize_with_images(
|
image_features = processor.tokenize_with_images(
|
||||||
images=[image], bos=True, eos=True, cropping=CROP_MODE
|
images=[image], bos=True, eos=True, cropping=model_settings.CROP_MODE
|
||||||
)
|
)
|
||||||
|
|
||||||
request = {"prompt": PROMPT, "multi_modal_data": {"image": image_features}}
|
request = {"prompt": model_settings.PROMPT, "multi_modal_data": {"image": image_features}}
|
||||||
request_id = f"request-{asyncio.get_running_loop().time()}"
|
request_id = f"request-{asyncio.get_running_loop().time()}"
|
||||||
|
|
||||||
final_output = ""
|
raw_output = ""
|
||||||
async for request_output in _engine.generate(request, sampling_params, request_id):
|
async for request_output in _engine.generate(request, sampling_params, request_id):
|
||||||
if request_output.outputs:
|
if request_output.outputs:
|
||||||
final_output = request_output.outputs[0].text
|
raw_output = request_output.outputs[0].text
|
||||||
|
|
||||||
return final_output
|
# 후처리 적용 (페이지 번호 전달)
|
||||||
|
clean_text = _postprocess_text(raw_output, page_num)
|
||||||
|
return clean_text
|
||||||
|
|
||||||
|
|
||||||
def _pdf_to_images(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
|
def _pdf_to_images_high_quality(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
|
||||||
"""PDF 바이트를 받아 페이지별 PIL 이미지 리스트를 반환합니다."""
|
"""PDF 바이트를 받아 페이지별 고품질 PIL 이미지 리스트를 반환합니다."""
|
||||||
images = []
|
images = []
|
||||||
pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
|
pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
|
||||||
zoom = dpi / 72.0
|
zoom = dpi / 72.0
|
||||||
@@ -88,8 +144,14 @@ def _pdf_to_images(pdf_bytes: bytes, dpi=144) -> list[Image.Image]:
|
|||||||
for page_num in range(pdf_document.page_count):
|
for page_num in range(pdf_document.page_count):
|
||||||
page = pdf_document[page_num]
|
page = pdf_document[page_num]
|
||||||
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
|
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
|
||||||
|
Image.MAX_IMAGE_PIXELS = None
|
||||||
img_data = pixmap.tobytes("png")
|
img_data = pixmap.tobytes("png")
|
||||||
img = Image.open(io.BytesIO(img_data))
|
img = Image.open(io.BytesIO(img_data))
|
||||||
|
|
||||||
|
if img.mode in ("RGBA", "LA"):
|
||||||
|
background = Image.new("RGB", img.size, (255, 255, 255))
|
||||||
|
background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
|
||||||
|
img = background
|
||||||
images.append(img)
|
images.append(img)
|
||||||
|
|
||||||
pdf_document.close()
|
pdf_document.close()
|
||||||
@@ -100,26 +162,38 @@ async def process_document(file_bytes: bytes, content_type: str, filename: str)
|
|||||||
"""
|
"""
|
||||||
업로드된 파일(이미지 또는 PDF)을 처리하여 OCR 결과를 반환합니다.
|
업로드된 파일(이미지 또는 PDF)을 처리하여 OCR 결과를 반환합니다.
|
||||||
"""
|
"""
|
||||||
if content_type.startswith("image/"):
|
# Content-Type이 generic할 경우, 파일 확장자로 타입을 유추
|
||||||
|
inferred_content_type = content_type
|
||||||
|
if content_type == "application/octet-stream":
|
||||||
|
if filename.lower().endswith(".pdf"):
|
||||||
|
inferred_content_type = "application/pdf"
|
||||||
|
elif any(
|
||||||
|
filename.lower().endswith(ext)
|
||||||
|
for ext in [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]
|
||||||
|
):
|
||||||
|
inferred_content_type = "image/jpeg" # 구체적인 타입은 중요하지 않음
|
||||||
|
|
||||||
|
if inferred_content_type.startswith("image/"):
|
||||||
try:
|
try:
|
||||||
image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
|
image = Image.open(io.BytesIO(file_bytes))
|
||||||
|
image = ImageOps.exif_transpose(image).convert("RGB")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"이미지 파일을 여는 데 실패했습니다: {e}")
|
raise ValueError(f"이미지 파일을 여는 데 실패했습니다: {e}")
|
||||||
|
|
||||||
result_text = await _process_single_image(image)
|
# 단일 이미지는 페이지 번호를 0으로 간주
|
||||||
|
result_text = await _process_single_image(image, page_num=0)
|
||||||
return {"filename": filename, "text": result_text}
|
return {"filename": filename, "text": result_text}
|
||||||
|
|
||||||
elif content_type == "application/pdf":
|
elif inferred_content_type == "application/pdf":
|
||||||
try:
|
try:
|
||||||
images = _pdf_to_images(file_bytes)
|
images = _pdf_to_images_high_quality(file_bytes)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"PDF 파일을 처리하는 데 실패했습니다: {e}")
|
raise ValueError(f"PDF 파일을 처리하는 데 실패했습니다: {e}")
|
||||||
|
|
||||||
# 각 페이지를 비동기적으로 처리
|
# 각 페이지를 비동기적으로 처리 (페이지 번호 전달)
|
||||||
tasks = [_process_single_image(img) for img in images]
|
tasks = [_process_single_image(img, page_num=i) for i, img in enumerate(images)]
|
||||||
page_results = await asyncio.gather(*tasks)
|
page_results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# 페이지 구분자를 넣어 전체 텍스트 합치기
|
|
||||||
full_text = "\n<--- Page Split --->\n".join(page_results)
|
full_text = "\n<--- Page Split --->\n".join(page_results)
|
||||||
return {"filename": filename, "text": full_text, "page_count": len(images)}
|
return {"filename": filename, "text": full_text, "page_count": len(images)}
|
||||||
|
|
||||||
|
|||||||
12
test/test.py
12
test/test.py
@@ -228,10 +228,10 @@ def process_pdf(llm, sampling_params, pdf_path):
|
|||||||
contents += content + page_num_separator
|
contents += content + page_num_separator
|
||||||
|
|
||||||
# Save results
|
# Save results
|
||||||
json_path = os.path.join(
|
result_json_path = os.path.join(
|
||||||
f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
|
f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
|
||||||
)
|
)
|
||||||
pdf_out_path = os.path.join(
|
result_pdf_path = os.path.join(
|
||||||
config.OUTPUT_PATH, f"{file_name_without_ext}_layouts.pdf"
|
config.OUTPUT_PATH, f"{file_name_without_ext}_layouts.pdf"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -247,10 +247,10 @@ def process_pdf(llm, sampling_params, pdf_path):
|
|||||||
"parsed": contents,
|
"parsed": contents,
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
with open(result_json_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(output_data, f, ensure_ascii=False, indent=4)
|
json.dump(output_data, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
pil_to_pdf_img2pdf(draw_images, pdf_out_path)
|
pil_to_pdf_img2pdf(draw_images, result_pdf_path)
|
||||||
print(
|
print(
|
||||||
f"{Colors.GREEN}Finished processing {pdf_path}. Results saved in {config.OUTPUT_PATH}{Colors.RESET}"
|
f"{Colors.GREEN}Finished processing {pdf_path}. Results saved in {config.OUTPUT_PATH}{Colors.RESET}"
|
||||||
)
|
)
|
||||||
@@ -283,7 +283,7 @@ def process_image(llm, sampling_params, image_path):
|
|||||||
print(result_out)
|
print(result_out)
|
||||||
|
|
||||||
# Save results
|
# Save results
|
||||||
output_json_path = os.path.join(
|
result_json_path = os.path.join(
|
||||||
f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
|
f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
|
||||||
)
|
)
|
||||||
result_image_path = os.path.join(
|
result_image_path = os.path.join(
|
||||||
@@ -316,7 +316,7 @@ def process_image(llm, sampling_params, image_path):
|
|||||||
"parsed": processed_text,
|
"parsed": processed_text,
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(output_json_path, "w", encoding="utf-8") as f:
|
with open(result_json_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(output_data, f, ensure_ascii=False, indent=4)
|
json.dump(output_data, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
result_image.save(result_image_path)
|
result_image.save(result_image_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user