deepseek_ocr/test/test.py
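"""Batch OCR driver for DeepSeek-OCR served through vLLM.

Scans config.INPUT_PATH for PDFs and common image formats, runs the model on
every page, and writes per-file JSON transcripts plus layout-annotated
PDFs/JPEGs (and cropped figure images) into config.OUTPUT_PATH.
"""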
import io
import json
import os
import re
import time
import config.model_settings as config
import fitz
import img2pdf
import numpy as np
from config.env_setup import setup_environment
from PIL import Image, ImageDraw, ImageFont, ImageOps
from services.deepseek_ocr import DeepseekOCRForCausalLM
from services.process.image_process import DeepseekOCRProcessor
from services.process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry

setup_environment()
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
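
# NOTE: config.model_settings (imported above as `config`) is expected to
# provide at least the attributes referenced in this script:
#   MODEL_PATH       - path or HF repo id of the DeepSeek-OCR checkpoint
#   INPUT_PATH       - directory scanned for input PDFs/images
#   OUTPUT_PATH      - directory for JSON results, annotated pages and crops
#   PROMPT           - prompt string sent with every page/image
#   CROP_MODE        - `cropping` flag for DeepseekOCRProcessor.tokenize_with_images
#   SKIP_REPEAT      - drop pages whose output does not end cleanly (see process_pdf)
#   MAX_CONCURRENCY  - passed to vLLM as max_num_seqs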

class Colors:
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    RESET = "\033[0m"

# --- PDF/Image Processing Functions (from run_dpsk_ocr_*.py) ---

def pdf_to_images_high_quality(pdf_path, dpi=144):
    images = []
    pdf_document = fitz.open(pdf_path)
    zoom = dpi / 72.0
    matrix = fitz.Matrix(zoom, zoom)
    for page_num in range(pdf_document.page_count):
        page = pdf_document[page_num]
        pixmap = page.get_pixmap(matrix=matrix, alpha=False)
        Image.MAX_IMAGE_PIXELS = None
        img_data = pixmap.tobytes("png")
        img = Image.open(io.BytesIO(img_data))
        if img.mode in ("RGBA", "LA"):
            background = Image.new("RGB", img.size, (255, 255, 255))
            background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
            img = background
        images.append(img)
    pdf_document.close()
    return images

def pil_to_pdf_img2pdf(pil_images, output_path):
    if not pil_images:
        return
    image_bytes_list = []
    for img in pil_images:
        if img.mode != "RGB":
            img = img.convert("RGB")
        img_buffer = io.BytesIO()
        img.save(img_buffer, format="JPEG", quality=95)
        image_bytes_list.append(img_buffer.getvalue())
    try:
        pdf_bytes = img2pdf.convert(image_bytes_list)
        with open(output_path, "wb") as f:
            f.write(pdf_bytes)
    except Exception as e:
        print(f"Error creating PDF: {e}")

def re_match(text):
    pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
    matches = re.findall(pattern, text, re.DOTALL)
    mathes_image = [m[0] for m in matches if "<|ref|>image<|/ref|>" in m[0]]
    mathes_other = [m[0] for m in matches if "<|ref|>image<|/ref|>" not in m[0]]
    return matches, mathes_image, mathes_other
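
# The model's grounded output wraps each layout element in markup of the form
#   <|ref|>LABEL<|/ref|><|det|>[[x1, y1, x2, y2], ...]<|/det|>
# for example (values illustrative):
#   <|ref|>title<|/ref|><|det|>[[120, 33, 880, 71]]<|/det|>
# Coordinates are normalized to 0-999 and rescaled to pixels in
# draw_bounding_boxes. The label set depends on the prompt; this script treats
# "image" and "title" specially.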

def extract_coordinates_and_label(ref_text, image_width, image_height):
    try:
        label_type = ref_text[1]
        cor_list = eval(ref_text[2])
        return (label_type, cor_list)
    except Exception as e:
        print(f"Error extracting coordinates: {e}")
        return None

def draw_bounding_boxes(image, refs, jdx=None):
    image_width, image_height = image.size
    img_draw = image.copy()
    draw = ImageDraw.Draw(img_draw)
    overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
    draw2 = ImageDraw.Draw(overlay)
    font = ImageFont.load_default()
    img_idx = 0
    for i, ref in enumerate(refs):
        result = extract_coordinates_and_label(ref, image_width, image_height)
        if not result:
            continue
        label_type, points_list = result
        color = (
            np.random.randint(0, 200),
            np.random.randint(0, 200),
            np.random.randint(0, 255),
        )
        color_a = color + (20,)  # same color with low alpha for the fill overlay
        for points in points_list:
            # Model coordinates are normalized to 0-999; rescale to pixels
            # (even indices are x, odd indices are y).
            x1, y1, x2, y2 = [
                int(p / 999 * (image_width if i % 2 == 0 else image_height))
                for i, p in enumerate(points)
            ]
            if label_type == "image":
                # Crop embedded figures so the markdown can reference them.
                try:
                    cropped = image.crop((x1, y1, x2, y2))
                    img_filename = (
                        f"{jdx}_{img_idx}.jpg" if jdx is not None else f"{img_idx}.jpg"
                    )
                    cropped.save(
                        os.path.join(config.OUTPUT_PATH, "images", img_filename)
                    )
                    img_idx += 1
                except Exception as e:
                    print(f"Error cropping image: {e}")
            width = 4 if label_type == "title" else 2
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            draw2.rectangle(
                [x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1
            )
            text_x, text_y = x1, max(0, y1 - 15)
            text_bbox = draw.textbbox((0, 0), label_type, font=font)
            text_width, text_height = (
                text_bbox[2] - text_bbox[0],
                text_bbox[3] - text_bbox[1],
            )
            draw.rectangle(
                [text_x, text_y, text_x + text_width, text_y + text_height],
                fill=(255, 255, 255, 30),
            )
            draw.text((text_x, text_y), label_type, font=font, fill=color)
    img_draw.paste(overlay, (0, 0), overlay)
    return img_draw

def process_image_with_refs(image, ref_texts, jdx=None):
    return draw_bounding_boxes(image, ref_texts, jdx)

def load_image(image_path):
    try:
        image = Image.open(image_path).convert("RGB")
        return ImageOps.exif_transpose(image)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None

# --- Main OCR Processing Logic ---
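
# Both processors below write one JSON file per input to
# {OUTPUT_PATH}/result/{name}.json, shaped roughly like (values illustrative):
#   {
#     "filename": "doc.pdf",
#     "model": {"ocr_model": "deepseek-ocr"},
#     "time": {"duration_sec": "12.34", "started_at": ..., "ended_at": ...},
#     "parsed": "<markdown with ![](images/N_M.jpg) links to cropped figures>"
#   }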

def process_pdf(llm, sampling_params, pdf_path):
    print(f"{Colors.GREEN}Processing PDF: {pdf_path}{Colors.RESET}")
    base_name = os.path.basename(pdf_path)
    file_name_without_ext = os.path.splitext(base_name)[0]
    images = pdf_to_images_high_quality(pdf_path)
    if not images:
        print(
            f"{Colors.YELLOW}Could not extract images from {pdf_path}. Skipping.{Colors.RESET}"
        )
        return
    batch_inputs = []
    processor = DeepseekOCRProcessor()
    for image in tqdm(images, desc="Pre-processing PDF pages"):
        batch_inputs.append(
            {
                "prompt": config.PROMPT,
                "multi_modal_data": {
                    "image": processor.tokenize_with_images(
                        images=[image], bos=True, eos=True, cropping=config.CROP_MODE
                    )
                },
            }
        )
    start_time = time.time()
    outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)
    end_time = time.time()
    contents_det = ""  # raw transcript, still containing grounding markup
    contents = ""
    draw_images = []
    for i, (output, img) in enumerate(zip(outputs_list, images)):
        content = output.outputs[0].text
        # A page that ends with the model's EOS token finished normally; strip
        # it. Otherwise the output was likely cut off (e.g. runaway repetition),
        # so skip the page entirely when SKIP_REPEAT is set.
        if "<｜end▁of▁sentence｜>" in content:
            content = content.replace("<｜end▁of▁sentence｜>", "")
        elif config.SKIP_REPEAT:
            continue
        page_num_separator = "\n<--- Page Split --->\n"
        contents_det += content + page_num_separator
        matches_ref, matches_images, mathes_other = re_match(content)
        result_image = process_image_with_refs(img.copy(), matches_ref, jdx=i)
        draw_images.append(result_image)
        # Replace image refs with links to the crops saved by draw_bounding_boxes,
        # then drop the remaining grounding markup.
        for idx, match in enumerate(matches_images):
            content = content.replace(match, f"![](images/{i}_{idx}.jpg)\n")
        for match in mathes_other:
            content = (
                content.replace(match, "")
                .replace("\\coloneqq", ":=")
                .replace("\\eqqcolon", "=:")
                .replace("\n\n\n", "\n\n")
            )
        contents += content + page_num_separator
    # Save results
    result_json_path = os.path.join(
        f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
    )
    result_pdf_path = os.path.join(
        config.OUTPUT_PATH, f"{file_name_without_ext}_layouts.pdf"
    )
    duration = end_time - start_time
    output_data = {
        "filename": base_name,
        "model": {"ocr_model": "deepseek-ocr"},
        "time": {
            "duration_sec": f"{duration:.2f}",
            "started_at": start_time,
            "ended_at": end_time,
        },
        "parsed": contents,
    }
    with open(result_json_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)
    pil_to_pdf_img2pdf(draw_images, result_pdf_path)
    print(
        f"{Colors.GREEN}Finished processing {pdf_path}. Results saved in {config.OUTPUT_PATH}{Colors.RESET}"
    )

def process_image(llm, sampling_params, image_path):
    print(f"{Colors.GREEN}Processing Image: {image_path}{Colors.RESET}")
    base_name = os.path.basename(image_path)
    file_name_without_ext = os.path.splitext(base_name)[0]
    image = load_image(image_path)
    if image is None:
        return
    processor = DeepseekOCRProcessor()
    image_features = processor.tokenize_with_images(
        images=[image], bos=True, eos=True, cropping=config.CROP_MODE
    )
    request = {
        "prompt": config.PROMPT,
        "multi_modal_data": {"image": image_features},
    }
    start_time = time.time()
    outputs = llm.generate([request], sampling_params)
    end_time = time.time()
    result_out = outputs[0].outputs[0].text
    print(result_out)
    # Save results
    result_json_path = os.path.join(
        f"{config.OUTPUT_PATH}/result", f"{file_name_without_ext}.json"
    )
    result_image_path = os.path.join(
        config.OUTPUT_PATH, f"{file_name_without_ext}_result_with_boxes.jpg"
    )
    matches_ref, matches_images, mathes_other = re_match(result_out)
    result_image = process_image_with_refs(image.copy(), matches_ref)
    processed_text = result_out
    for idx, match in enumerate(matches_images):
        processed_text = processed_text.replace(match, f"![](images/{idx}.jpg)\n")
    for match in mathes_other:
        processed_text = (
            processed_text.replace(match, "")
            .replace("\\coloneqq", ":=")
            .replace("\\eqqcolon", "=:")
            .replace("\n\n\n", "\n\n")
        )
    duration = end_time - start_time
    output_data = {
        "filename": base_name,
        "model": {"ocr_model": "deepseek-ocr"},
        "time": {
            "duration_sec": f"{duration:.2f}",
            "started_at": start_time,
            "ended_at": end_time,
        },
        "parsed": processed_text,
    }
    with open(result_json_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=4)
    result_image.save(result_image_path)
    print(
        f"{Colors.GREEN}Finished processing {image_path}. Results saved in {config.OUTPUT_PATH}{Colors.RESET}"
    )

def main():
    # --- Model Initialization ---
    print(f"{Colors.BLUE}Initializing model...{Colors.RESET}")
    llm = LLM(
        model=config.MODEL_PATH,
        hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
        block_size=256,
        enforce_eager=False,
        trust_remote_code=True,
        max_model_len=8192,
        swap_space=0,
        max_num_seqs=config.MAX_CONCURRENCY,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.9,
        disable_mm_preprocessor_cache=True,
    )
    logits_processors = [
        NoRepeatNGramLogitsProcessor(
            ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
        )
    ]
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=8192,
        logits_processors=logits_processors,
        skip_special_tokens=False,
        include_stop_str_in_output=True,
    )
    print(f"{Colors.BLUE}Model initialized successfully.{Colors.RESET}")
    # --- File Processing ---
    input_dir = config.INPUT_PATH
    output_dir = config.OUTPUT_PATH
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "result"), exist_ok=True)
    if not os.path.isdir(input_dir):
        print(
            f"{Colors.RED}Error: Input directory not found at '{input_dir}'{Colors.RESET}"
        )
        return
    print(f"Scanning for files in '{input_dir}'...")
    for filename in sorted(os.listdir(input_dir)):
        input_path = os.path.join(input_dir, filename)
        if not os.path.isfile(input_path):
            continue
        file_extension = os.path.splitext(filename)[1].lower()
        try:
            if file_extension == ".pdf":
                process_pdf(llm, sampling_params, input_path)
            elif file_extension in [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"]:
                process_image(llm, sampling_params, input_path)
            else:
                print(
                    f"{Colors.YELLOW}Skipping unsupported file type: {filename}{Colors.RESET}"
                )
        except Exception as e:
            print(
                f"{Colors.RED}An error occurred while processing {filename}: {e}{Colors.RESET}"
            )

if __name__ == "__main__":
    main()