프로젝트 분리 이동

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
minsung
2026-05-20 14:28:27 +09:00
commit ccba1266b5
24 changed files with 7900 additions and 0 deletions

219
tools/sam3_receipt_ocr.py Normal file
View File

@@ -0,0 +1,219 @@
import argparse
import sys
import os
import time
import json
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
# Add server to path so we can import sam3 locally
server_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server"))
models_path = os.path.join(server_path, "app", "models")
if server_path not in sys.path:
sys.path.insert(0, server_path)
if models_path not in sys.path:
sys.path.insert(0, models_path)
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# PaddleOCR will be imported inside the function to avoid errors if not installed
def run_paddleocr(image_np):
# Disable PIR API and oneDNN to avoid NotImplementedError on PaddlePaddle 3.x Windows
import os as _os
_os.environ["FLAGS_use_mkldnn"] = "0"
_os.environ["PADDLE_WITH_MKLDNN"] = "0"
_os.environ["FLAGS_enable_pir_api"] = "0" # Disable the new PIR API which causes this error
from paddleocr import PaddleOCR
# Force use_gpu=False if oneDNN is failing on CPU, or let it detect.
# On Windows, sometimes CPU + oneDNN is the default and it fails.
ocr = PaddleOCR(use_textline_orientation=True, lang='korean', use_gpu=torch.cuda.is_available())
result = ocr.ocr(image_np, cls=True) # Use .ocr() instead of .predict() for better compatibility
return result
def build_point_grid(n_per_side: int) -> np.ndarray:
offset = 1.0 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
pts_x, pts_y = np.meshgrid(points_one_side, points_one_side)
grid = np.stack([pts_x.flatten(), pts_y.flatten()], axis=1)
return grid
def mask_iou(mask1, mask2):
inter = np.logical_and(mask1, mask2).sum()
union = np.logical_or(mask1, mask2).sum()
if union == 0: return 0
return inter / union
def get_receipt_mask(image_bgr, model_path, points_per_side=32):
print("Loading SAM3 Model for receipt detection...")
device = "cuda" if torch.cuda.is_available() else "cpu"
bpe_path = os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz")
model = build_sam3_image_model(
bpe_path=bpe_path,
device=device,
checkpoint_path=model_path,
)
processor = Sam3Processor(model, confidence_threshold=0.7, device=device)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image_rgb)
state = processor.set_image(pil_image)
grid_points = build_point_grid(points_per_side)
masks = []
scores = []
print(f"Sampling {len(grid_points)} points...")
for i, (nx, ny) in enumerate(grid_points):
processor.reset_all_prompts(state)
state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
if "masks" in state and len(state["masks"]) > 0:
best_idx = torch.argmax(state["scores"])
mask = state["masks"][best_idx].cpu().numpy()
score = state["scores"][best_idx].item()
if score > 0.8:
masks.append(mask)
scores.append(score)
if not masks:
return None
# Pick the largest mask that covers a significant area (receipt is usually big)
areas = [m.sum() for m in masks]
# Simple heuristic: largest mask
best_mask_idx = np.argmax(areas)
return masks[best_mask_idx]
def crop_and_warp(image, mask):
# Ensure mask is 2D
mask = np.squeeze(mask)
if mask.ndim != 2:
print(f"Warning: Mask has unexpected dimensions {mask.shape}, trying to flatten...")
if mask.ndim == 3: mask = mask[0]
# Find contours
mask_uint8 = (mask > 0).astype(np.uint8) * 255
if np.sum(mask_uint8) == 0:
print("Warning: Mask is empty.")
return image
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
return image
cnt = max(contours, key=cv2.contourArea)
# Get bounding box
x, y, w, h = cv2.boundingRect(cnt)
# Try to find 4 corners for perspective transform
epsilon = 0.02 * cv2.arcLength(cnt, True)
approx = cv2.approxPolyDP(cnt, epsilon, True)
if len(approx) == 4:
print("Found 4 corners, applying perspective transform...")
pts = approx.reshape(4, 2)
# Sort points: top-left, top-right, bottom-right, bottom-left
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
diff = np.diff(pts, axis=1)
rect[1] = pts[np.argmin(diff)]
rect[3] = pts[np.argmax(diff)]
(tl, tr, br, bl) = rect
widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
maxWidth = max(int(widthA), int(widthB))
heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
maxHeight = max(int(heightA), int(heightB))
dst = np.array([
[0, 0],
[maxWidth - 1, 0],
[maxWidth - 1, maxHeight - 1],
[0, maxHeight - 1]], dtype="float32")
M = cv2.getPerspectiveTransform(rect, dst)
warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
return warped
else:
print("Could not find 4 clear corners, just cropping to bounding box.")
# Create a mask image to black out background
masked_img = cv2.bitwise_and(image, image, mask=mask_uint8)
crop = masked_img[y:y+h, x:x+w]
return crop
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True, help="Input image path")
parser.add_argument("--output_dir", default="output/ocr", help="Output directory")
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# Read image
buf = np.fromfile(args.input, dtype=np.uint8)
image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
if image is None:
print(f"Failed to read {args.input}")
return
model_path = os.path.join(server_path, "sam3.pt")
# 1. SAM3 Masking
mask = get_receipt_mask(image, model_path)
if mask is None:
print("No receipt-like object found.")
return
# 2. Cropping / Warping
processed_img = crop_and_warp(image, mask)
# Save processed image for debugging
processed_path = os.path.join(args.output_dir, "processed_receipt.jpg")
cv2.imwrite(processed_path, processed_img)
print(f"Saved processed image to {processed_path}")
# 3. PaddleOCR
print("Running PaddleOCR...")
ocr_results = run_paddleocr(processed_img)
# 4. Save results
output_json = os.path.join(args.output_dir, "ocr_results.json")
with open(output_json, "w", encoding="utf-8") as f:
json.dump(ocr_results, f, ensure_ascii=False, indent=2)
print(f"OCR results saved to {output_json}")
# Print summary
if ocr_results:
print("\n--- OCR Extracted Text ---")
for page in ocr_results:
if page is None:
continue
# New v3.x format: list of dicts with 'rec_text' key
if isinstance(page, dict):
text = page.get('rec_text', '')
score = page.get('rec_score', 0)
if text:
print(f"{text} (conf: {score:.2f})")
elif isinstance(page, list):
for line in page:
if line and isinstance(line, list) and len(line) >= 2:
print(line[1][0])
print("--------------------------")
if __name__ == "__main__":
main()