import argparse
import sys
import os
import time
import json
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image

# Add server to path so we can import sam3 locally
server_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server")
)
models_path = os.path.join(server_path, "app", "models")
if server_path not in sys.path:
    sys.path.insert(0, server_path)
if models_path not in sys.path:
    sys.path.insert(0, models_path)

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor


# PaddleOCR will be imported inside the function to avoid errors if not installed
def run_paddleocr(image_np):
    """Run PaddleOCR (Korean model) on an image array and return its raw result.

    Args:
        image_np: Image array handed unchanged to ``PaddleOCR.ocr``.

    Returns:
        Whatever ``PaddleOCR.ocr`` returns; the structure depends on the
        installed PaddleOCR version.
    """
    # Disable PIR API and oneDNN to avoid NotImplementedError on PaddlePaddle 3.x Windows.
    # These must be set before paddleocr (and therefore paddle) is imported.
    os.environ["FLAGS_use_mkldnn"] = "0"
    os.environ["PADDLE_WITH_MKLDNN"] = "0"
    os.environ["FLAGS_enable_pir_api"] = "0"  # Disable the new PIR API which causes this error

    from paddleocr import PaddleOCR

    # Force use_gpu=False if oneDNN is failing on CPU, or let it detect.
    # On Windows, sometimes CPU + oneDNN is the default and it fails.
    # NOTE(review): `use_textline_orientation` looks like a PaddleOCR 3.x name
    # while `use_gpu` / `cls` look like 2.x names — confirm against the
    # installed PaddleOCR version.
    ocr = PaddleOCR(
        use_textline_orientation=True,
        lang='korean',
        use_gpu=torch.cuda.is_available(),
    )
    # Use .ocr() instead of .predict() for better compatibility
    result = ocr.ocr(image_np, cls=True)
    return result


def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return an ``(n_per_side**2, 2)`` array of evenly spaced (x, y) points.

    Coordinates are normalized to the open interval (0, 1); each point sits at
    the centre of its grid cell, so no point lies on the border.
    """
    offset = 1.0 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    pts_x, pts_y = np.meshgrid(points_one_side, points_one_side)
    grid = np.stack([pts_x.flatten(), pts_y.flatten()], axis=1)
    return grid


def mask_iou(mask1, mask2):
    """Return the intersection-over-union of two masks (0.0 if both are empty)."""
    inter = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 0.0
    return inter / union


def get_receipt_mask(image_bgr, model_path, points_per_side=32,
                     confidence_threshold=0.7, score_threshold=0.8):
    """Find the largest confidently segmented object in an image using SAM3.

    The image is prompted with a regular grid of positive points. For each
    point the highest-scoring mask is considered, and the largest such mask
    whose score exceeds ``score_threshold`` is returned.

    Args:
        image_bgr: Image in BGR channel order (converted to RGB internally).
        model_path: Path to the SAM3 checkpoint.
        points_per_side: Grid resolution; ``points_per_side**2`` prompts are run.
        confidence_threshold: Value passed to ``Sam3Processor``.
        score_threshold: Minimum per-prompt score for a mask to be considered.

    Returns:
        The selected mask as a NumPy array, or ``None`` if no mask qualified.
    """
    print("Loading SAM3 Model for receipt detection...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bpe_path = os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz")

    model = build_sam3_image_model(
        bpe_path=bpe_path,
        device=device,
        checkpoint_path=model_path,
    )
    processor = Sam3Processor(
        model, confidence_threshold=confidence_threshold, device=device
    )

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(image_rgb)
    state = processor.set_image(pil_image)

    grid_points = build_point_grid(points_per_side)
    print(f"Sampling {len(grid_points)} points...")

    # Only the largest qualifying mask is ever needed, so keep a running best
    # instead of holding every full-resolution mask in memory at once.
    # Simple heuristic: largest mask (receipt is usually big).
    best_mask = None
    best_area = -1
    for nx, ny in grid_points:
        processor.reset_all_prompts(state)
        state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)
        if "masks" not in state or len(state["masks"]) == 0:
            continue

        best_idx = torch.argmax(state["scores"])
        score = state["scores"][best_idx].item()
        if score <= score_threshold:
            continue

        # Copy to host memory only for masks that passed the score filter.
        mask = state["masks"][best_idx].cpu().numpy()
        area = mask.sum()
        # Strict '>' keeps the first of equally large masks.
        if area > best_area:
            best_area = area
            best_mask = mask

    return best_mask


def _order_corners(pts):
    """Order four (x, y) points as top-left, top-right, bottom-right, bottom-left."""
    rect = np.zeros((4, 2), dtype="float32")
    # Top-left has the smallest x + y, bottom-right the largest.
    s = pts.sum(axis=1)
    rect[0] = pts[np.argmin(s)]
    rect[2] = pts[np.argmax(s)]
    # Top-right has the smallest y - x, bottom-left the largest.
    diff = np.diff(pts, axis=1)
    rect[1] = pts[np.argmin(diff)]
    rect[3] = pts[np.argmax(diff)]
    return rect


def _warp_quad(image, rect):
    """Perspective-warp ``image`` so the ordered quad ``rect`` fills the output."""
    (tl, tr, br, bl) = rect

    width_a = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    width_b = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    max_width = max(int(width_a), int(width_b))

    height_a = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    height_b = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    max_height = max(int(height_a), int(height_b))

    dst = np.array([
        [0, 0],
        [max_width - 1, 0],
        [max_width - 1, max_height - 1],
        [0, max_height - 1]], dtype="float32")

    M = cv2.getPerspectiveTransform(rect, dst)
    return cv2.warpPerspective(image, M, (max_width, max_height))


def crop_and_warp(image, mask):
    """Cut the masked object out of ``image``, rectifying it when possible.

    If the largest contour of the mask simplifies to a quadrilateral, the
    image is perspective-warped so that quadrilateral fills the output.
    Otherwise the background is blacked out and the image is cropped to the
    contour's bounding box.

    Args:
        image: Source image array.
        mask: Mask array; values > 0 are foreground. Singleton dimensions are
            squeezed, and for a 3D mask the first slice is used.

    Returns:
        The warped or cropped image, or ``image`` unchanged if the mask is
        unusable (empty, not reducible to 2D, or without contours).
    """
    # Ensure mask is 2D
    mask = np.squeeze(mask)
    if mask.ndim != 2:
        print(f"Warning: Mask has unexpected dimensions {mask.shape}, trying to flatten...")
        if mask.ndim == 3:
            mask = mask[0]
    if mask.ndim != 2:
        # Contour extraction needs a 2D mask; fall back like the other bad-mask cases.
        print("Warning: Mask could not be reduced to 2D.")
        return image

    mask_uint8 = (mask > 0).astype(np.uint8) * 255
    if not mask_uint8.any():
        print("Warning: Mask is empty.")
        return image

    # Contour coordinates are applied to `image`, so both must share one size.
    if mask_uint8.shape != image.shape[:2]:
        mask_uint8 = cv2.resize(
            mask_uint8,
            (image.shape[1], image.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )

    # Find contours
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return image
    cnt = max(contours, key=cv2.contourArea)

    # Try to find 4 corners for perspective transform
    epsilon = 0.02 * cv2.arcLength(cnt, True)
    approx = cv2.approxPolyDP(cnt, epsilon, True)

    if len(approx) == 4:
        # Sort points: top-left, top-right, bottom-right, bottom-left
        rect = _order_corners(approx.reshape(4, 2))
        # For strongly rotated (diamond-like) quads the sum/diff ordering can
        # pick the same vertex twice, which makes the transform degenerate.
        if len({tuple(corner) for corner in rect.tolist()}) == 4:
            print("Found 4 corners, applying perspective transform...")
            return _warp_quad(image, rect)
        print("Corner ordering is ambiguous, just cropping to bounding box.")
    else:
        print("Could not find 4 clear corners, just cropping to bounding box.")

    # Create a mask image to black out background
    x, y, w, h = cv2.boundingRect(cnt)
    masked_img = cv2.bitwise_and(image, image, mask=mask_uint8)
    return masked_img[y:y+h, x:x+w]


def _json_default(obj):
    """Convert NumPy arrays and scalars for ``json.dump``; reject anything else."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _print_ocr_summary(ocr_results):
    """Print the recognized text from either supported OCR result layout."""
    print("\n--- OCR Extracted Text ---")
    for page in ocr_results:
        if page is None:
            continue
        if isinstance(page, dict):
            # presumably PaddleOCR 3.x stores parallel 'rec_texts' /
            # 'rec_scores' sequences per page — confirm against the
            # installed version.
            texts = page.get('rec_texts')
            if texts is not None:
                scores = page.get('rec_scores')
                if scores is None or len(scores) != len(texts):
                    scores = [0] * len(texts)
                for text, score in zip(texts, scores):
                    if text:
                        print(f"{text} (conf: {float(score):.2f})")
                continue
            # New v3.x format: list of dicts with 'rec_text' key
            text = page.get('rec_text', '')
            score = page.get('rec_score', 0)
            if text:
                print(f"{text} (conf: {score:.2f})")
        elif isinstance(page, list):
            # Legacy layout: each line is [box, (text, score)].
            for line in page:
                if line and isinstance(line, list) and len(line) >= 2:
                    print(line[1][0])
    print("--------------------------")


def main():
    """Command-line entry point: segment a receipt, rectify it, OCR it, save results.

    Writes ``processed_receipt.jpg`` and ``ocr_results.json`` into
    ``--output_dir`` and prints the recognized text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="Input image path")
    parser.add_argument("--output_dir", default="output/ocr", help="Output directory")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Read image via np.fromfile + imdecode rather than cv2.imread.
    try:
        buf = np.fromfile(args.input, dtype=np.uint8)
    except OSError as exc:
        print(f"Failed to read {args.input}: {exc}")
        return
    # imdecode raises on an empty buffer, so treat an empty file as unreadable.
    image = cv2.imdecode(buf, cv2.IMREAD_COLOR) if buf.size else None
    if image is None:
        print(f"Failed to read {args.input}")
        return

    model_path = os.path.join(server_path, "sam3.pt")

    # 1. SAM3 Masking
    mask = get_receipt_mask(image, model_path)
    if mask is None:
        print("No receipt-like object found.")
        return

    # 2. Cropping / Warping
    processed_img = crop_and_warp(image, mask)

    # Save processed image for debugging. Encode in memory and write with
    # tofile to mirror how the input is read, and check that encoding worked.
    processed_path = os.path.join(args.output_dir, "processed_receipt.jpg")
    encoded_ok, encoded = cv2.imencode(".jpg", processed_img)
    if encoded_ok:
        encoded.tofile(processed_path)
        print(f"Saved processed image to {processed_path}")
    else:
        print(f"Failed to save processed image to {processed_path}")

    # 3. PaddleOCR
    print("Running PaddleOCR...")
    ocr_results = run_paddleocr(processed_img)

    # 4. Save results
    output_json = os.path.join(args.output_dir, "ocr_results.json")
    with open(output_json, "w", encoding="utf-8") as f:
        # OCR results may contain NumPy arrays/scalars, which json cannot
        # serialize without a default handler.
        json.dump(ocr_results, f, ensure_ascii=False, indent=2, default=_json_default)
    print(f"OCR results saved to {output_json}")

    # Print summary
    if ocr_results:
        _print_ocr_summary(ocr_results)


if __name__ == "__main__":
    main()