220 lines
7.8 KiB
Python
220 lines
7.8 KiB
Python
import argparse
|
|
import sys
|
|
import os
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
|
|
# Make the sibling X-AnyLabeling-Server checkout importable: the server root
# first, then its app/models directory (which ends up at the front of
# sys.path), so the local ``sam3`` package can be imported below.
server_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "X-AnyLabeling-Server")
)
models_path = os.path.join(server_path, "app", "models")
for _extra_path in (server_path, models_path):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path
|
|
|
|
from sam3.model_builder import build_sam3_image_model
|
|
from sam3.model.sam3_image_processor import Sam3Processor
|
|
|
|
# PaddleOCR will be imported inside the function to avoid errors if not installed
|
|
def run_paddleocr(image_np):
|
|
# Disable PIR API and oneDNN to avoid NotImplementedError on PaddlePaddle 3.x Windows
|
|
import os as _os
|
|
_os.environ["FLAGS_use_mkldnn"] = "0"
|
|
_os.environ["PADDLE_WITH_MKLDNN"] = "0"
|
|
_os.environ["FLAGS_enable_pir_api"] = "0" # Disable the new PIR API which causes this error
|
|
|
|
from paddleocr import PaddleOCR
|
|
# Force use_gpu=False if oneDNN is failing on CPU, or let it detect.
|
|
# On Windows, sometimes CPU + oneDNN is the default and it fails.
|
|
ocr = PaddleOCR(use_textline_orientation=True, lang='korean', use_gpu=torch.cuda.is_available())
|
|
result = ocr.ocr(image_np, cls=True) # Use .ocr() instead of .predict() for better compatibility
|
|
return result
|
|
|
|
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return an ``(n_per_side ** 2, 2)`` array of normalized (x, y) points.

    The points are the cell centres of an ``n_per_side`` x ``n_per_side`` grid
    over the unit square, so every coordinate lies strictly inside (0, 1).
    Rows are ordered with x varying fastest, then y.
    """
    half_step = 1.0 / (2 * n_per_side)
    axis = np.linspace(half_step, 1 - half_step, n_per_side)
    # tile -> x cycles through the axis values; repeat -> y holds each value
    # for a full row. This is the same ordering as flattening meshgrid(axis, axis).
    xs = np.tile(axis, n_per_side)
    ys = np.repeat(axis, n_per_side)
    return np.column_stack((xs, ys))
|
|
|
|
def mask_iou(mask1, mask2):
    """Return the intersection-over-union of two boolean-like masks.

    The masks are combined element-wise with NumPy logical operations, so they
    must be broadcast-compatible. Returns ``0`` when the union is empty.
    """
    union_area = np.logical_or(mask1, mask2).sum()
    if not union_area:
        return 0
    intersection_area = np.logical_and(mask1, mask2).sum()
    return intersection_area / union_area
|
|
|
|
def get_receipt_mask(image_bgr, model_path, points_per_side=32, score_threshold=0.8):
    """Find the most likely receipt mask in *image_bgr* using SAM3 point prompts.

    Loads a SAM3 image model and prompts it once per point of a regular
    ``points_per_side`` x ``points_per_side`` grid (see ``build_point_grid``).
    For each prompt, the highest-scoring mask is kept if its score exceeds
    *score_threshold*. Among the kept masks, the one with the largest ``sum()``
    (its area, for a binary mask) is returned, on the heuristic that the
    receipt is usually the biggest object in the photo.

    Args:
        image_bgr: OpenCV-style BGR image array.
        model_path: path to the SAM3 checkpoint.
        points_per_side: grid resolution; ``points_per_side ** 2`` prompts run.
        score_threshold: a mask must score strictly above this to be kept.

    Returns:
        The selected mask as a NumPy array, or ``None`` if no prompt produced
        a mask scoring above *score_threshold*.
    """
    print("Loading SAM3 Model for receipt detection...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bpe_path = os.path.join(server_path, "bpe_simple_vocab_16e6.txt.gz")

    model = build_sam3_image_model(
        bpe_path=bpe_path,
        device=device,
        checkpoint_path=model_path,
    )
    processor = Sam3Processor(model, confidence_threshold=0.7, device=device)

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    state = processor.set_image(Image.fromarray(image_rgb))

    grid_points = build_point_grid(points_per_side)
    print(f"Sampling {len(grid_points)} points...")

    # Keep only the running best instead of storing every accepted mask: with
    # the default 32x32 grid that could be up to 1024 full-size masks.
    best_mask = None
    best_area = None
    for nx, ny in grid_points:
        processor.reset_all_prompts(state)
        state = processor.add_point_prompt(point=[nx, ny], label=True, state=state)

        if "masks" not in state or len(state["masks"]) == 0:
            continue

        best_idx = torch.argmax(state["scores"])
        score = state["scores"][best_idx].item()
        if score > score_threshold:
            # Copy to host memory only for masks that pass the threshold.
            mask = state["masks"][best_idx].cpu().numpy()
            area = mask.sum()
            # Strict ">" keeps the first of equally large masks, like np.argmax.
            if best_area is None or area > best_area:
                best_mask, best_area = mask, area

    return best_mask
|
|
|
|
def crop_and_warp(image, mask):
|
|
# Ensure mask is 2D
|
|
mask = np.squeeze(mask)
|
|
if mask.ndim != 2:
|
|
print(f"Warning: Mask has unexpected dimensions {mask.shape}, trying to flatten...")
|
|
if mask.ndim == 3: mask = mask[0]
|
|
|
|
# Find contours
|
|
mask_uint8 = (mask > 0).astype(np.uint8) * 255
|
|
if np.sum(mask_uint8) == 0:
|
|
print("Warning: Mask is empty.")
|
|
return image
|
|
|
|
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
if not contours:
|
|
return image
|
|
|
|
cnt = max(contours, key=cv2.contourArea)
|
|
|
|
# Get bounding box
|
|
x, y, w, h = cv2.boundingRect(cnt)
|
|
|
|
# Try to find 4 corners for perspective transform
|
|
epsilon = 0.02 * cv2.arcLength(cnt, True)
|
|
approx = cv2.approxPolyDP(cnt, epsilon, True)
|
|
|
|
if len(approx) == 4:
|
|
print("Found 4 corners, applying perspective transform...")
|
|
pts = approx.reshape(4, 2)
|
|
# Sort points: top-left, top-right, bottom-right, bottom-left
|
|
rect = np.zeros((4, 2), dtype="float32")
|
|
s = pts.sum(axis=1)
|
|
rect[0] = pts[np.argmin(s)]
|
|
rect[2] = pts[np.argmax(s)]
|
|
diff = np.diff(pts, axis=1)
|
|
rect[1] = pts[np.argmin(diff)]
|
|
rect[3] = pts[np.argmax(diff)]
|
|
|
|
(tl, tr, br, bl) = rect
|
|
widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
|
|
widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
|
|
maxWidth = max(int(widthA), int(widthB))
|
|
|
|
heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
|
|
heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
|
|
maxHeight = max(int(heightA), int(heightB))
|
|
|
|
dst = np.array([
|
|
[0, 0],
|
|
[maxWidth - 1, 0],
|
|
[maxWidth - 1, maxHeight - 1],
|
|
[0, maxHeight - 1]], dtype="float32")
|
|
|
|
M = cv2.getPerspectiveTransform(rect, dst)
|
|
warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
|
|
return warped
|
|
else:
|
|
print("Could not find 4 clear corners, just cropping to bounding box.")
|
|
# Create a mask image to black out background
|
|
masked_img = cv2.bitwise_and(image, image, mask=mask_uint8)
|
|
crop = masked_img[y:y+h, x:x+w]
|
|
return crop
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--input", required=True, help="Input image path")
|
|
parser.add_argument("--output_dir", default="output/ocr", help="Output directory")
|
|
args = parser.parse_args()
|
|
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
# Read image
|
|
buf = np.fromfile(args.input, dtype=np.uint8)
|
|
image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
|
|
if image is None:
|
|
print(f"Failed to read {args.input}")
|
|
return
|
|
|
|
model_path = os.path.join(server_path, "sam3.pt")
|
|
|
|
# 1. SAM3 Masking
|
|
mask = get_receipt_mask(image, model_path)
|
|
if mask is None:
|
|
print("No receipt-like object found.")
|
|
return
|
|
|
|
# 2. Cropping / Warping
|
|
processed_img = crop_and_warp(image, mask)
|
|
|
|
# Save processed image for debugging
|
|
processed_path = os.path.join(args.output_dir, "processed_receipt.jpg")
|
|
cv2.imwrite(processed_path, processed_img)
|
|
print(f"Saved processed image to {processed_path}")
|
|
|
|
# 3. PaddleOCR
|
|
print("Running PaddleOCR...")
|
|
ocr_results = run_paddleocr(processed_img)
|
|
|
|
# 4. Save results
|
|
output_json = os.path.join(args.output_dir, "ocr_results.json")
|
|
with open(output_json, "w", encoding="utf-8") as f:
|
|
json.dump(ocr_results, f, ensure_ascii=False, indent=2)
|
|
|
|
print(f"OCR results saved to {output_json}")
|
|
|
|
# Print summary
|
|
if ocr_results:
|
|
print("\n--- OCR Extracted Text ---")
|
|
for page in ocr_results:
|
|
if page is None:
|
|
continue
|
|
# New v3.x format: list of dicts with 'rec_text' key
|
|
if isinstance(page, dict):
|
|
text = page.get('rec_text', '')
|
|
score = page.get('rec_score', 0)
|
|
if text:
|
|
print(f"{text} (conf: {score:.2f})")
|
|
elif isinstance(page, list):
|
|
for line in page:
|
|
if line and isinstance(line, list) and len(line) >= 2:
|
|
print(line[1][0])
|
|
print("--------------------------")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|