Deepseek-OCR 환경 설정
This commit is contained in:
@@ -1,112 +0,0 @@
|
||||
import os
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from PIL import Image
|
||||
|
||||
# vLLM 및 모델 관련 import
|
||||
from vllm import AsyncLLMEngine, SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.model_executor.models.registry import ModelRegistry
|
||||
|
||||
# DeepSeek-OCR 관련 로컬 import
|
||||
from deepseek_ocr import DeepseekOCRForCausalLM
|
||||
from process.image_process import DeepseekOCRProcessor
|
||||
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
|
||||
|
||||
# --- Configuration ---
|
||||
# Docker 환경에서는 환경 변수를 사용하거나, Dockerfile에서 모델을 다운로드하는 것이 좋습니다.
|
||||
# 여기서는 config.py의 기본값을 사용하되, 환경 변수로 재정의할 수 있도록 합니다.
|
||||
MODEL_PATH = os.environ.get("MODEL_PATH", "deepseek-ai/deepseek-vl-7b-base")
|
||||
# 참고: 실제 `config.py`는 로컬 경로를 사용하므로, 허깅페이스 모델 ID로 대체합니다.
|
||||
# 이 모델을 사용하려면 인터넷 연결이 필요하며, 처음 실행 시 다운로드됩니다.
|
||||
|
||||
# --- Model Initialization ---
|
||||
|
||||
# 1. 커스텀 모델 등록
|
||||
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
|
||||
|
||||
# 2. vLLM 엔진 설정
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=MODEL_PATH,
|
||||
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
|
||||
max_model_len=8192,
|
||||
enforce_eager=False,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=1, # 단일 GPU 사용
|
||||
gpu_memory_utilization=0.90, # GPU 메모리 사용률
|
||||
)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
# 3. Deepseek OCR 프로세서 초기화
|
||||
processor = DeepseekOCRProcessor()
|
||||
|
||||
# 4. FastAPI 앱 초기화
|
||||
app = FastAPI()
|
||||
|
||||
# --- Pydantic Models ---
|
||||
class InferenceRequest(BaseModel):
|
||||
# Base64로 인코딩된 이미지 문자열
|
||||
base64_image: str
|
||||
|
||||
class InferenceResponse(BaseModel):
|
||||
text: str
|
||||
|
||||
# --- API Endpoints ---
|
||||
|
||||
@app.get("/")
|
||||
def health_check():
|
||||
return {"status": "DeepSeek-OCR service is running"}
|
||||
|
||||
@app.post("/process", response_model=InferenceResponse)
|
||||
async def process_image(request: InferenceRequest):
|
||||
"""
|
||||
Base64 인코딩된 이미지를 받아 OCR 추론을 수행합니다.
|
||||
"""
|
||||
try:
|
||||
# 1. Base64 이미지 디코딩
|
||||
image_data = base64.b64decode(request.base64_image)
|
||||
image = Image.open(io.BytesIO(image_data)).convert('RGB')
|
||||
|
||||
# 2. 이미지 전처리
|
||||
prompt = "<image>"
|
||||
image_features = processor.tokenize_with_images(
|
||||
images=[image],
|
||||
bos=True,
|
||||
eos=True,
|
||||
cropping=False # CROP_MODE 기본값 사용
|
||||
)
|
||||
|
||||
# 3. 샘플링 파라미터 설정 (기존 스크립트 참조)
|
||||
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=30, window_size=90, whitelist_token_ids={128821, 128822})]
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=8192,
|
||||
logits_processors=logits_processors,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
|
||||
# 4. vLLM으로 추론 실행
|
||||
request_id = f"dpsk-request-{int(time.time())}"
|
||||
vllm_request = {
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {"image": image_features}
|
||||
}
|
||||
|
||||
final_output = None
|
||||
async for request_output in engine.generate(vllm_request, sampling_params, request_id):
|
||||
# 스트리밍 결과의 마지막 최종본을 사용
|
||||
final_output = request_output
|
||||
|
||||
if final_output and final_output.outputs:
|
||||
generated_text = final_output.outputs[0].text
|
||||
return InferenceResponse(text=generated_text)
|
||||
else:
|
||||
raise Exception("Model generated no output.")
|
||||
|
||||
except Exception as e:
|
||||
# 실제 운영 환경에서는 로깅을 추가하는 것이 좋습니다.
|
||||
return {"error": f"An error occurred: {str(e)}"}, 500
|
||||
Reference in New Issue
Block a user