Files
railway-client/tools/sam3_receipt_ocr.py
minsung ccba1266b5 프로젝트 분리 이동
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-20 14:28:27 +09:00

220 lines
7.8 KiB
Python

import argparse
import sys
import os
import time
import json
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
# Make the sibling X-AnyLabeling-Server checkout (and its app/models directory)
# importable so the locally vendored `sam3` package can be imported below.
server_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server")
)
models_path = os.path.join(server_path, "app", "models")
# Each missing entry is pushed to the front of sys.path in turn, so models_path
# ends up ahead of server_path when both are added.
for _extra_path in (server_path, models_path):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
def run_paddleocr(image_np):
    """Run Korean PaddleOCR on ``image_np`` and return its raw result.

    ``paddleocr`` is imported inside this function so the rest of the script
    can still be imported when PaddleOCR is not installed.

    Args:
        image_np: image array handed straight to ``PaddleOCR.ocr`` (``main``
            passes the BGR receipt crop produced by OpenCV).

    Returns:
        The unmodified return value of ``PaddleOCR.ocr``; its structure depends
        on the installed PaddleOCR version.
    """
    # Disable oneDNN (MKLDNN) and the new PIR API before importing paddleocr;
    # the author hit NotImplementedError with them on PaddlePaddle 3.x / Windows.
    # NOTE(review): presumably these flags only take effect if paddle has not
    # already been imported in this process.
    os.environ.update({
        "FLAGS_use_mkldnn": "0",
        "PADDLE_WITH_MKLDNN": "0",
        "FLAGS_enable_pir_api": "0",
    })
    from paddleocr import PaddleOCR

    # NOTE(review): this mixes a PaddleOCR 3.x argument (use_textline_orientation)
    # with 2.x ones (use_gpu, ocr(..., cls=True)); confirm against the pinned
    # paddleocr version.
    engine = PaddleOCR(
        use_textline_orientation=True,
        lang='korean',
        use_gpu=torch.cuda.is_available(),
    )
    return engine.ocr(image_np, cls=True)
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return a ``(n_per_side**2, 2)`` array of cell-centre points in the unit square.

    The unit square is divided into ``n_per_side`` equal cells along each axis
    and one (x, y) point is placed at the centre of every cell. Rows are ordered
    with x varying fastest (row-major over y, then x).
    """
    half_step = 1.0 / (2 * n_per_side)
    centres = np.linspace(half_step, 1 - half_step, n_per_side)
    # x cycles through all centres for each y; y repeats each centre n times.
    xs = np.tile(centres, n_per_side)
    ys = np.repeat(centres, n_per_side)
    return np.column_stack((xs, ys))
def mask_iou(mask1, mask2):
    """Return the intersection-over-union of two masks.

    Non-zero / truthy elements count as foreground. When neither mask has any
    foreground the IoU is undefined and ``0`` is returned.
    """
    union_px = np.logical_or(mask1, mask2).sum()
    if not union_px:
        return 0
    overlap_px = np.logical_and(mask1, mask2).sum()
    return overlap_px / union_px
def _load_sam3_processor(model_path, device, confidence_threshold):
    """Build the SAM3 image model from ``model_path`` and wrap it in a ``Sam3Processor``."""
    bpe_path = os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz")
    model = build_sam3_image_model(
        bpe_path=bpe_path,
        device=device,
        checkpoint_path=model_path,
    )
    return Sam3Processor(model, confidence_threshold=confidence_threshold, device=device)


def get_receipt_mask(image_bgr, model_path, points_per_side=32, *,
                     score_threshold=0.8, confidence_threshold=0.7):
    """Locate the receipt in ``image_bgr`` by prompting SAM3 with a grid of points.

    For every point of a ``points_per_side x points_per_side`` grid, a single
    point prompt (``label=True``, presumably "foreground") is issued. The
    highest-scoring mask of that prompt is kept if its score is strictly
    greater than ``score_threshold``. Among the kept masks, the one with the
    largest ``sum()`` (its area, for a binary mask) is returned, on the
    heuristic that the receipt is the biggest object in the photo. Ties go to
    the earliest grid point.

    Args:
        image_bgr: OpenCV-style BGR image; converted to RGB before inference.
        model_path: path of the SAM3 checkpoint to load.
        points_per_side: grid resolution; ``points_per_side**2`` prompts are run.
            The points are normalized to (0, 1). NOTE(review): confirm that
            ``Sam3Processor.add_point_prompt`` expects normalized coordinates.
        score_threshold: a prompt's best mask must score above this to be kept.
        confidence_threshold: forwarded to ``Sam3Processor``.

    Returns:
        The selected mask as a NumPy array, or ``None`` if no prompt produced a
        mask above ``score_threshold``. The mask may carry extra leading
        singleton dimensions (TODO confirm the processor's mask shape).
    """
    print("Loading SAM3 Model for receipt detection...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = _load_sam3_processor(model_path, device, confidence_threshold)

    pil_image = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))

    # Track only the running best. Keeping every accepted full-resolution mask
    # (up to points_per_side**2 of them) can exhaust memory on large photos.
    best_mask = None
    best_area = None
    # Pure inference: skip autograd bookkeeping across the many forward passes.
    # This is harmless if the processor already disables gradients itself.
    with torch.no_grad():
        state = processor.set_image(pil_image)
        grid_points = build_point_grid(points_per_side)
        print(f"Sampling {len(grid_points)} points...")
        for nx, ny in grid_points:
            processor.reset_all_prompts(state)
            state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
            if "masks" not in state or len(state["masks"]) == 0:
                continue
            top = torch.argmax(state["scores"])
            score = state["scores"][top].item()
            # Written as "not >" so NaN scores are rejected exactly as before.
            if not score > score_threshold:
                continue
            mask = state["masks"][top].cpu().numpy()
            area = mask.sum()
            # Strict ">" keeps the earliest mask on ties, matching np.argmax.
            if best_area is None or area > best_area:
                best_mask, best_area = mask, area
    return best_mask
def _order_corners(pts):
    """Order four (x, y) points as top-left, top-right, bottom-right, bottom-left (float32)."""
    pts = np.asarray(pts).reshape(4, 2)
    rect = np.zeros((4, 2), dtype="float32")
    coord_sum = pts.sum(axis=1)          # x + y: min at top-left, max at bottom-right
    rect[0] = pts[np.argmin(coord_sum)]
    rect[2] = pts[np.argmax(coord_sum)]
    coord_diff = np.diff(pts, axis=1)    # y - x: min at top-right, max at bottom-left
    rect[1] = pts[np.argmin(coord_diff)]
    rect[3] = pts[np.argmax(coord_diff)]
    return rect


def _four_point_warp(image, rect):
    """Perspective-warp quadrilateral ``rect`` (tl, tr, br, bl) of ``image`` to an upright rectangle.

    Returns ``None`` for a degenerate quadrilateral: repeated corners, or an
    output side shorter than 2 px. In those cases the transform would be
    singular or the output image empty.
    """
    if len(np.unique(rect, axis=0)) < 4:
        return None
    tl, tr, br, bl = rect
    max_width = max(int(np.sqrt(((br - bl) ** 2).sum())),
                    int(np.sqrt(((tr - tl) ** 2).sum())))
    max_height = max(int(np.sqrt(((tr - br) ** 2).sum())),
                     int(np.sqrt(((tl - bl) ** 2).sum())))
    if max_width < 2 or max_height < 2:
        return None
    dst = np.array(
        [[0, 0],
         [max_width - 1, 0],
         [max_width - 1, max_height - 1],
         [0, max_height - 1]],
        dtype="float32",
    )
    transform = cv2.getPerspectiveTransform(rect, dst)
    return cv2.warpPerspective(image, transform, (max_width, max_height))


def crop_and_warp(image, mask):
    """Cut the region covered by ``mask`` out of ``image``.

    If the largest outer contour of the mask simplifies (``approxPolyDP`` with
    2% of its perimeter) to exactly four corners, that quadrilateral is
    perspective-warped to an upright rectangle. Otherwise, or when the
    quadrilateral is degenerate, background pixels are zeroed with the mask and
    the result is cropped to the contour's bounding box.

    Args:
        image: image array (``main`` passes the BGR image decoded by OpenCV).
        mask: mask array; values > 0 are foreground. Singleton dimensions are
            squeezed away; any remaining extra leading axes are reduced by taking
            their first element.

    Returns:
        The warped or cropped image, or ``image`` itself when the mask is empty,
        not reducible to 2-D, or yields no contour.
    """
    mask = np.squeeze(mask)
    if mask.ndim != 2:
        print(f"Warning: Mask has unexpected dimensions {mask.shape}, trying to flatten...")
        # Keep the first mask along every extra leading axis; fewer than two
        # dimensions cannot be turned into an image mask at all.
        while mask.ndim > 2:
            mask = mask[0]
        if mask.ndim < 2:
            print("Warning: Mask is not 2-D, returning the image unchanged.")
            return image

    mask_uint8 = (mask > 0).astype(np.uint8) * 255
    if not mask_uint8.any():
        print("Warning: Mask is empty.")
        return image

    img_h, img_w = image.shape[:2]
    if mask_uint8.shape != (img_h, img_w):
        # A mask at another resolution would break bitwise_and and put the
        # contour in the wrong coordinates. This assumes it spans the whole image.
        mask_uint8 = cv2.resize(mask_uint8, (img_w, img_h), interpolation=cv2.INTER_NEAREST)

    # [-2] selects the contour list on both OpenCV 3.x (3 return values) and 4.x (2).
    contours = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return image
    cnt = max(contours, key=cv2.contourArea)

    epsilon = 0.02 * cv2.arcLength(cnt, True)
    approx = cv2.approxPolyDP(cnt, epsilon, True)
    if len(approx) == 4:
        print("Found 4 corners, applying perspective transform...")
        warped = _four_point_warp(image, _order_corners(approx))
        if warped is not None:
            return warped
        print("Detected corners are degenerate, falling back to bounding-box crop.")
    else:
        print("Could not find 4 clear corners, just cropping to bounding box.")

    x, y, w, h = cv2.boundingRect(cnt)
    # Black out the background before cropping to the bounding box.
    masked_img = cv2.bitwise_and(image, image, mask=mask_uint8)
    return masked_img[y:y + h, x:x + w]
def _to_jsonable(obj):
    """``json`` fallback encoder for values it cannot serialize natively.

    NumPy arrays and scalars, which the OCR result may contain, become plain
    lists and numbers. Anything else is stringified as a deliberate last resort,
    so the debug dump always completes.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return str(obj)


def _print_ocr_summary(ocr_results):
    """Print the recognized text from ``ocr_results`` between separator lines.

    Handles dict pages (PaddleOCR 3.x style, with ``rec_texts``/``rec_scores``
    lists or single ``rec_text``/``rec_score`` values). Also handles list pages
    of ``[box, (text, score)]`` lines (2.x style). ``None`` pages are skipped.
    """
    if not ocr_results:
        return
    print("\n--- OCR Extracted Text ---")
    for page in ocr_results:
        if page is None:
            continue
        if isinstance(page, dict):
            if "rec_texts" in page:
                raw_texts = page.get("rec_texts")
                raw_scores = page.get("rec_scores")
                texts = [] if raw_texts is None else list(raw_texts)
                scores = [] if raw_scores is None else list(raw_scores)
                pairs = [(text, scores[i] if i < len(scores) else 0)
                         for i, text in enumerate(texts)]
            else:
                pairs = [(page.get('rec_text', ''), page.get('rec_score', 0))]
            for text, score in pairs:
                if text:
                    print(f"{text} (conf: {float(score):.2f})")
        elif isinstance(page, list):
            for line in page:
                if line and isinstance(line, list) and len(line) >= 2:
                    print(line[1][0])
    print("--------------------------")


def main(argv=None):
    """Command-line entry point: SAM3 receipt segmentation followed by PaddleOCR.

    Steps:
    1. Read the input image.
    2. Find the receipt mask with SAM3 and crop/warp the receipt.
    3. Save the processed image.
    4. Run PaddleOCR on it.
    5. Write the raw OCR result to ``ocr_results.json`` and print the recognized text.

    Args:
        argv: argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Returns:
        0 on success; 1 when the input cannot be read/decoded or no receipt is found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="Input image path")
    parser.add_argument("--output_dir", default="output/ocr", help="Output directory")
    parser.add_argument("--model", default=None,
                        help="SAM3 checkpoint (default: sam3.pt in X-AnyLabeling-Server)")
    parser.add_argument("--points_per_side", type=int, default=32,
                        help="SAM3 prompt grid resolution (points_per_side**2 prompts)")
    args = parser.parse_args(argv)
    os.makedirs(args.output_dir, exist_ok=True)

    # Read through np.fromfile + imdecode instead of cv2.imread so that
    # non-ASCII (e.g. Korean) paths also work on Windows.
    try:
        buf = np.fromfile(args.input, dtype=np.uint8)
    except OSError as err:
        print(f"Failed to read {args.input}: {err}")
        return 1
    # imdecode rejects an empty buffer, so treat an empty file as unreadable.
    image = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image is None:
        print(f"Failed to read {args.input}")
        return 1

    model_path = args.model or os.path.join(server_path, "sam3.pt")

    # 1. SAM3 masking
    mask = get_receipt_mask(image, model_path, args.points_per_side)
    if mask is None:
        print("No receipt-like object found.")
        return 1

    # 2. Cropping / warping
    processed_img = crop_and_warp(image, mask)

    # Save processed image for debugging. imencode + tofile mirrors the
    # unicode-safe read above (cv2.imwrite cannot open non-ASCII paths on Windows).
    processed_path = os.path.join(args.output_dir, "processed_receipt.jpg")
    ok, encoded = cv2.imencode(".jpg", processed_img)
    if ok:
        encoded.tofile(processed_path)
        print(f"Saved processed image to {processed_path}")
    else:
        print(f"Failed to encode processed image for {processed_path}")

    # 3. PaddleOCR
    print("Running PaddleOCR...")
    ocr_results = run_paddleocr(processed_img)

    # 4. Save results. Serialize completely before opening the file so an
    # encoding problem cannot leave a truncated JSON file behind.
    output_json = os.path.join(args.output_dir, "ocr_results.json")
    payload = json.dumps(ocr_results, ensure_ascii=False, indent=2, default=_to_jsonable)
    with open(output_json, "w", encoding="utf-8") as f:
        f.write(payload)
    print(f"OCR results saved to {output_json}")

    _print_ocr_summary(ocr_results)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so callers and shells can detect failures.
    sys.exit(main())