First commit

This commit is contained in:
kyy
2025-11-05 15:20:21 +09:00
commit fe8601ac63
22 changed files with 3414 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
# Cache directories
.cache/
__pycache__/

39
Dockerfile Normal file
View File

@@ -0,0 +1,39 @@
# PyTorch 2.6.0 + CUDA 12.6 + cuDNN9
FROM pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
# 기본 환경 변수 설정
ENV DEBIAN_FRONTEND=noninteractive \
HF_HOME=/workspace/.cache/huggingface \
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} \
PIP_DISABLE_PIP_VERSION_CHECK=1 \
PYTHONUNBUFFERED=1 \
TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
TORCH_CUDA_ARCH_LIST="8.0"
WORKDIR /workspace
# 필수 빌드 도구 설치
RUN apt-get update && apt-get install -y --no-install-recommends \
git build-essential ninja-build \
&& rm -rf /var/lib/apt/lists/*
# pip 업그레이드
RUN python -m pip install -U pip setuptools wheel
# 기존 라이브러리 제거 및 특정 버전 재설치
RUN pip uninstall -y vllm torch torchvision torchaudio triton flash-attn || true
RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
# 프로젝트 의존성 설치
COPY requirements.txt /tmp/requirements.txt
RUN pip install -r /tmp/requirements.txt
# vLLM 특정 버전 설치
RUN pip install vllm==0.8.5
# FlashAttention 소스에서 빌드하여 설치
RUN pip cache purge && \
pip install --no-cache-dir --no-build-isolation --no-binary=flash-attn flash-attn==2.7.3
WORKDIR /workspace

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

163
README.md Normal file
View File

@@ -0,0 +1,163 @@
# DeepSeek-OCR
## 소개
DeepSeek-OCR은 DeepSeek AI에서 개발한 OCR(광학 문자 인식)을 활용하여 컨텍스트 압축을 수행하는 혁신적인 시스템입니다. 이 시스템은 방대한 텍스트 정보를 처리하는 방식을 혁신적으로 바꾸는 것을 목표로 합니다. 텍스트가 많은 문서를 이미지로 취급하여 시각적 표현을 통해 정보를 압축함으로써, 기존 텍스트 인코딩 방식보다 훨씬 적은 토큰으로 텍스트를 표현할 수 있습니다.
## 주요 기능
- **높은 압축률**: 10개의 텍스트 토큰을 단일 비전 토큰으로 압축할 수 있습니다.
- **높은 정확도**: 10배 압축률에서 97%의 OCR 정밀도를 달성합니다.
- **빠른 처리 속도**: 단일 A100-40G GPU에서 하루에 200,000페이지 이상을 처리할 수 있습니다.
- **구조화된 콘텐츠 처리**: 표, 수식, 다이어그램과 같은 구조화된 콘텐츠를 효과적으로 분석합니다.
## 프로젝트 구조
```
.
├── docker-compose.yml
├── Dockerfile.hf
├── Dockerfile.vllm
├── LICENSE
├── README.md
├── requirements.txt
└── DeepSeek-OCR-master
├── DeepSeek-OCR-hf
│ └── run_dpsk_ocr.py
└── DeepSeek-OCR-vllm
├── config.py
├── deepseek_ocr.py
├── run_dpsk_ocr_eval_batch.py
├── run_dpsk_ocr_image.py
├── run_dpsk_ocr_pdf.py
├── deepencoder
│ ├── build_linear.py
│ ├── clip_sdpa.py
│ └── sam_vary_sdpa.py
└── process
├── image_process.py
└── ngram_norepeat.py
```
## 파일 설명
### 최상위 디렉토리
- `docker-compose.yml`: Docker 서비스를 정의하고 관리합니다 (`hf``vllm` 버전).
- `Dockerfile.hf`: Hugging Face 버전의 Docker 이미지를 빌드하기 위한 설정 파일입니다.
- `Dockerfile.vllm`: vLLM 버전의 Docker 이미지를 빌드하기 위한 설정 파일입니다.
- `requirements.txt`: 프로젝트에 필요한 Python 패키지 목록입니다.
### `DeepSeek-OCR-master/DeepSeek-OCR-hf`
- `run_dpsk_ocr.py`: Hugging Face `transformers` 라이브러리를 사용하여 단일 이미지 또는 PDF 파일에 대해 OCR을 실행하는 스크립트입니다. 스크립트 내에서 직접 경로와 같은 변수를 수정해야 합니다.
### `DeepSeek-OCR-master/DeepSeek-OCR-vllm`
- `config.py`: vLLM 버전 스크립트의 모든 설정을 관리하는 파일입니다. 모델, 입/출력 경로, 프롬프트, 처리 옵션 등을 여기서 지정합니다.
- `deepseek_ocr.py`: vLLM 프레임워크와 호환되도록 DeepSeek-OCR 모델의 아키텍처를 정의하는 핵심 파일입니다.
- `run_dpsk_ocr_image.py`: `config.py` 설정에 따라 단일 이미지 파일에 대해 OCR을 수행합니다.
- `run_dpsk_ocr_pdf.py`: `config.py` 설정에 따라 단일 PDF 파일에 대해 OCR을 수행합니다.
- `run_dpsk_ocr_eval_batch.py`: `config.py`에 지정된 디렉토리 내의 모든 이미지에 대해 일괄적으로 OCR을 수행하고 결과를 저장하여 평가에 사용됩니다.
- `deepencoder/`: SAM, CLIP 등과 같은 비전 인코더 모델을 빌드하는 스크립트가 포함된 디렉토리입니다.
- `process/`: 이미지 전처리 및 n-gram 반복 방지 로직과 같은 유틸리티 스크립트가 포함된 디렉토리입니다.
## 설치
1. 저장소를 복제합니다.
```bash
git clone https://gitea.hmac.kr/kyy/deepseek_ocr.git
cd deepseek_ocr
```
2. 필요한 패키지를 설치합니다.
```bash
pip install -r requirements.txt
```
## 사용법
### Hugging Face 버전
1. `DeepSeek-OCR-master/DeepSeek-OCR-hf/run_dpsk_ocr.py` 파일을 열고 `image_file` 및 `output_path` 변수를 수정합니다.
2. 스크립트를 실행합니다: `python DeepSeek-OCR-master/DeepSeek-OCR-hf/run_dpsk_ocr.py`
### vLLM 버전
1. `DeepSeek-OCR-master/DeepSeek-OCR-vllm/config.py` 파일을 열고 `INPUT_PATH`, `OUTPUT_PATH` 및 기타 설정을 수정합니다.
2. 실행할 스크립트를 선택하여 실행합니다.
- 이미지: `python DeepSeek-OCR-master/DeepSeek-OCR-vllm/run_dpsk_ocr_image.py`
- PDF: `python DeepSeek-OCR-master/DeepSeek-OCR-vllm/run_dpsk_ocr_pdf.py`
## Docker 사용법
Docker를 사용하면 모든 종속성이 포함된 격리된 환경에서 스크립트를 쉽게 실행할 수 있습니다.
### 1. (선택 사항) 입/출력 디렉토리 준비
호스트 머신에 입력 파일을 저장하고 출력 결과를 받을 디렉토리를 만듭니다.
```bash
mkdir -p ./data/input
mkdir -p ./data/output
# 입력 파일을 ./data/input 에 복사합니다.
cp /path/to/your/file.pdf ./data/input/
```
### 2. Docker Compose 설정 수정
`docker-compose.yml` 파일을 열고, 방금 만든 디렉토리를 컨테이너에 마운트하도록 `volumes` 섹션을 수정합니다. 이렇게 하면 컨테이너가 호스트의 파일에 접근할 수 있습니다.
```yaml
services:
deepseek_ocr_vllm:
build:
context: .
dockerfile: Dockerfile.vllm
# ... (기타 설정)
volumes:
- ./DeepSeek-OCR-master/DeepSeek-OCR-vllm:/workspace
- ./data/input:/workspace/input # 입력 디렉토리 마운트
- ./data/output:/workspace/output # 출력 디렉토리 마운트
# ... (기타 설정)
```
### 3. Docker 컨테이너 빌드 및 시작
`docker-compose.yml`이 있는 프로젝트의 최상위 디렉토리에서 다음 명령을 실행합니다.
```bash
# vLLM 서비스 빌드 및 시작
docker-compose build
docker-compose up -d
```
Hugging Face 버전을 사용하려면 `docker-compose.yml`에서 `deepseek_ocr_hf` 서비스의 주석을 해제하면 됩니다.
### 4. 컨테이너에서 스크립트 실행
1. 실행 중인 컨테이너에 접속합니다.
```bash
docker exec -it deepseek_ocr_vllm /bin/bash
```
2. 컨테이너 내부에서 `config.py`를 수정하여 마운트된 디렉토리의 파일을 가리키도록 합니다.
```python
# /workspace/config.py
INPUT_PATH = '/workspace/input/file.pdf' # 마운트된 입력 파일
OUTPUT_PATH = '/workspace/output/' # 마운트된 출력 디렉토리
```
3. 원하는 스크립트를 실행합니다.
```bash
python run_dpsk_ocr_pdf.py
```
결과는 호스트 머신의 `./data/output` 디렉토리에 저장됩니다.
## 라이선스
이 프로젝트는 [LICENSE](LICENSE) 파일에 명시된 라이선스를 따릅니다.

22
api.py Normal file
View File

@@ -0,0 +1,22 @@
import logging
from fastapi import FastAPI
from router import deepseek_router
logging.basicConfig(
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
app = FastAPI(
title="DeepSeek OCR Service",
description="OCR Gateway로부터 파일 요청을 받아 텍스트를 추출하는 API",
docs_url="/docs",
)
@app.get("/health/API", include_in_schema=False)
async def health_check():
return {"status": "API ok"}
app.include_router(deepseek_router)

0
config/__init__.py Normal file
View File

12
config/env_setup.py Normal file
View File

@@ -0,0 +1,12 @@
import os
import torch
def setup_environment():
"""
OCR 모델 실행에 필요한 환경 변수를 설정합니다.
"""
if torch.version.cuda == "11.8":
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ["VLLM_USE_V1"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

42
config/model_settings.py Normal file
View File

@@ -0,0 +1,42 @@
# TODO: change modes
# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False
# Gundam: base_size = 1024, image_size = 640, crop_mode = True
BASE_SIZE = 1024
IMAGE_SIZE = 640
CROP_MODE = True
MIN_CROPS = 2
MAX_CROPS = 6 # max:9; If your GPU memory is small, it is recommended to set it to 6.
MAX_CONCURRENCY = 100 # If you have limited GPU memory, lower the concurrency count.
NUM_WORKERS = 64 # image pre-process (resize/padding) workers
PRINT_NUM_VIS_TOKENS = False
SKIP_REPEAT = False
MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path
# TODO: change INPUT_PATH
# .pdf: run_dpsk_ocr_pdf.py;
# .jpg, .png, .jpeg: run_dpsk_ocr_image.py;
# Omnidocbench images path: run_dpsk_ocr_eval_batch.py
INPUT_PATH = "/workspace/input"
OUTPUT_PATH = "/workspace/output"
# PROMPT = f"{PROMPT_TEXT.strip()}"
PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
# PROMPT = '<image>\nFree OCR.'
# TODO commonly used prompts
# document: <image>\n<|grounding|>Convert the document to markdown.
# other image: <image>\n<|grounding|>OCR this image.
# without layouts: <idmage>\nFree OCR.
# figures in document: <image>\nParse the figure.
# general: <image>\nDescribe this image in detail.
# rec: <image>\nLocate <|ref|>xxxx<|/ref|> in the image.
# .......
from transformers import AutoTokenizer
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

21
docker-compose.yml Normal file
View File

@@ -0,0 +1,21 @@
services:
deepseek_ocr_vllm:
build:
context: .
dockerfile: Dockerfile
image: deepseek-ocr-vllm:cu126
container_name: deepseek_ocr_vllm
working_dir: /workspace
volumes:
- ./:/workspace
gpus: all
shm_size: "8g"
ipc: "host"
environment:
- HF_HOME=/workspace/.cache/huggingface
- CUDA_HOME=/usr/local/cuda
- LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH}
- PIP_DISABLE_PIP_VERSION_CHECK=1
- PYTHONUNBUFFERED=1
tty: true
entrypoint: ["/bin/bash"]

13
requirements.txt Normal file
View File

@@ -0,0 +1,13 @@
transformers
tokenizers
PyMuPDF
img2pdf
einops
easydict
addict
Pillow
numpy
matplotlib
fastapi
uvicorn[standard]
python-multipart

3
router/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .deepseek_router import router as deepseek_router
__all__ = ["deepseek_router"]

41
router/deepseek_router.py Normal file
View File

@@ -0,0 +1,41 @@
import logging
from fastapi import APIRouter, File, HTTPException, UploadFile
from services.ocr_engine import process_document
router = APIRouter(prefix="/ocr", tags=["OCR"])
logger = logging.getLogger(__name__)
@router.post("", description="요청된 파일에서 Deepseek OCR을 수행하고 텍스트를 추출합니다.")
async def perform_ocr(document: UploadFile = File(..., description="OCR을 수행할 PDF 또는 이미지 파일")):
"""
클라이언트로부터 받은 파일을 OCR 엔진에 전달하고, 추출된 텍스트를 반환합니다.
- **document**: `multipart/form-data` 형식으로 전송된 파일.
"""
logger.info(f"'{document.filename}' 파일에 대한 OCR 요청 수신 (Content-Type: {document.content_type})")
try:
file_content = await document.read()
if not file_content:
raise HTTPException(status_code=400, detail="파일 내용이 비어있습니다.")
# 모든 처리 로직을 OCR 엔진에 위임
result = process_document(
file_bytes=file_content,
content_type=document.content_type,
filename=document.filename,
)
return result
except (ValueError, TypeError) as e:
# 엔진에서 발생한 예상된 오류 (e.g., 지원하지 않는 파일 형식)
logger.warning(f"요청 처리 중 오류 발생: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
# 예상치 못한 서버 내부 오류
logger.exception(f"OCR 처리 중 예상치 못한 오류 발생: {e}")
raise HTTPException(status_code=500, detail=f"서버 내부 오류가 발생했습니다: {e}")
finally:
await document.close()

View File

View File

@@ -0,0 +1,174 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
import copy
class MlpProjector(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
if cfg.projector_type == "identity":
modules = nn.Identity()
elif cfg.projector_type == "linear":
modules = nn.Linear(cfg.input_dim, cfg.n_embed)
elif cfg.projector_type == "mlp_gelu":
mlp_depth = cfg.get("depth", 1)
modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
mlp_ratio = cfg.get("mlp_ratio", 1)
modules = [
nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
]
for _ in range(1, mlp_depth - 1):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "downsample_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
mlp_ratio = cfg.get("mlp_ratio", 1)
modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
for _ in range(1, mlp_depth - 1):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
channel_div = cfg.get("channel_div", 0.5)
self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "low_high_split_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
modules = nn.Sequential(*modules)
self.high_layers = nn.Sequential(*modules)
self.low_layers = copy.deepcopy(modules)
else:
raise ValueError(f"Unknown projector type: {cfg.projector_type}")
if cfg.get("token_pooling", False):
self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
if cfg.get("conv_fusion_high_low_features", False):
self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
self.layers = modules
def forward(self, x):
if self.cfg.get("token_pooling", False):
batch_size, wxh, channels = x.shape
w = h = int(wxh**0.5)
x = x.view(batch_size, w, h, channels)
x = x.permute(0, 3, 1, 2)
# import ipdb; ipdb.set_trace()
patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
batch_size, channels, h_patches, w_patches, _, _ = patches.size()
# 在通道维度上拼接
patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
# 通过线性层
patches = patches.permute(0, 2, 1, 3).contiguous()
patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
x = self.token_pooling_layer(patches)
if self.cfg.get("conv_fusion_high_low_features", False):
x = self.fusion_layer(x[:, 0]) + x[:, 1]
if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
high_x, low_x = x[0], x[1]
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.concat([high_x, low_x], dim=-1)
if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
high_x = x[...,:self.cfg.input_dim[0]]
low_x = x[...,self.cfg.input_dim[0]:]
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.concat([high_x, low_x], dim=-1)
if self.cfg.projector_type == 'low_high_split_mlp_gelu':
high_x, low_x = x[0], x[1]
high_x = self.high_layers(high_x)
low_x = self.low_layers(low_x)
x = torch.concat([high_x, low_x], dim=-1)
return x
if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
bs, hw, input_dim = x.shape
h = w = int((hw) ** 0.5)
"""compute padding"""
if h % self.cfg.downsample_ratio:
pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
else:
pad = 0
x = x.reshape(bs, h, w, input_dim)
if pad > 0:
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
"""4 to 1 concat"""
x = x.permute(0, 3, 1, 2) # B, C, H, W
x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4
x = x.permute(0, 2, 1)
return self.layers(x)
@staticmethod
def get_flops_per_sample(cfg):
if cfg.projector_type == "linear":
fwd = 2 * cfg.input_dim * cfg.n_embed
elif "mlp_gelu" in cfg.projector_type :
mlp_depth = cfg.get("depth", 1)
downsample_ratio = cfg.get("downsample_ratio", 1)
input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
input_dim = input_dim * downsample_ratio * downsample_ratio
fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
else:
fwd = 0
return fwd * 3

View File

@@ -0,0 +1,504 @@
from contextlib import nullcontext
import math
from typing import Optional, Tuple
# from megatron.model import LayerNorm
from easydict import EasyDict as adict
import torch
from torch.nn import functional as F
from torch import nn
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# from optimus import flash_attn_func
# from megatron.core import tensor_parallel
# from megatron.core import parallel_state as mpu
# from megatron.core.utils import make_viewless_tensor, divide
# from megatron.model.fused_rms_norm import RMSNorm
# from megatron.model.transformer import (
# FlashSelfAttention,
# NoopTransformerLayer,
# _cfg_to_kwargs,
# )
# from megatron.model.enums import AttnMaskType, AttnType
# from megatron.model.fused_softmax import FusedScaleMaskSoftmax
# from megatron.model.utils import attention_mask_func
# from megatron.model.module import MegatronModule
# try:
# from einops import rearrange
# except ImportError:
# rearrange = None
# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# try:
# # flash attention 2.x
# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# except ImportError:
# try:
# # flash attention 1.x
# from flash_attn.flash_attn_interface import flash_attn_unpadded_func
# except ImportError:
# flash_attn_unpadded_func = None
# try:
# from flash_attn.flash_attn_interface import flash_attn_unpadded_relative_attention_bias_func
# except ImportError:
# flash_attn_unpadded_relative_attention_bias_func = None
# try:
# from flash_attn.flash_attn_interface import mask_flash_attn_unpadded_func
# except ImportError:
# mask_flash_attn_unpadded_func = None
class LayerNormfp32(torch.nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
# print(tgt_size)
# print(abs_pos.shape)
# exit()
dim = abs_pos.size(-1)
# print(dim)
abs_pos_new = abs_pos.squeeze(0)
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
old_pos_embed = old_pos_embed.view(1, src_size, src_size, dim).permute(0, 3, 1,
2).contiguous()
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode='bicubic',
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
return vision_pos_embed
else:
return abs_pos
@torch.jit.script
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
super().__init__()
self.embed_dim = hidden_size
self.image_size = image_size
self.patch_size = patch_size
self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = torch.nn.Conv2d(
in_channels=num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids", torch.arange(self.num_positions).expand((1, -1))
)
def forward(self, pixel_values, patch_embeds):
batch_size = pixel_values.shape[0]
# patch_embeds = self.patch_embedding(
# pixel_values
# ) # shape = [*, width, grid, grid]
if patch_embeds is not None:
patch_embeds = patch_embeds
# print(patch_embeds.shape)
else:
patch_embeds = self.patch_embedding(pixel_values)
# print(111111)
# shape = [*, width, grid, grid]
# patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
# x = torch.cat([cls_token, x], dim=1)
embeddings = embeddings + get_abs_pos(self.position_embedding(self.position_ids), embeddings.size(1))
# embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class NoTPFeedForward(nn.Module):
def __init__(
self,
cfg,
dim: int,
hidden_dim: int,
):
super().__init__()
self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)
def forward(self, x):
output = self.fc2(quick_gelu(self.fc1(x)))
return output
# from optimus.flash_attn_interface import flash_attn_qkvpacked_func
# class NoTPAttention(nn.Module):
# def __init__(self, cfg):
# super().__init__()
# self.num_heads = cfg.num_attention_heads
# self.n_local_heads = cfg.num_attention_heads
# self.head_dim = cfg.hidden_size // cfg.num_attention_heads
# self.max_seq_len = cfg.seq_length
# self.use_flash_attention = cfg.use_flash_attn
# self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
# self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
# # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
# self.attn_drop = cfg.attention_dropout
# def forward(
# self,
# x: torch.Tensor,
# ):
# bsz, seqlen, _ = x.shape
# xqkv = self.qkv_proj(x)
# xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
# if self.use_flash_attention:
# output = flash_attn_qkvpacked_func(xqkv)
# output = output.view(bsz, seqlen, -1)
# else:
# xq, xk, xv = torch.split(xqkv, 1, dim=2)
# xq = xq.squeeze(2)
# xk = xk.squeeze(2)
# xv = xv.squeeze(2)
# # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# # B, num_head, S, head_size)
# xq = xq.permute(0, 2, 1, 3)
# xk = xk.permute(0, 2, 1, 3)
# xv = xv.permute(0, 2, 1, 3)
# output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
# utput = output.permute(0, 2, 1, 3).view(bsz, seqlen, -1)
# output = self.out_proj(output)
# return output
# from optimus.flash_attn_interface import flash_attn_qkvpacked_func
class NoTPAttention(torch.nn.Module):
def __init__(self, cfg):
super().__init__()
self.num_heads = cfg.num_attention_heads
self.n_local_heads = cfg.num_attention_heads
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.max_seq_len = cfg.seq_length
self.use_flash_attention = cfg.use_flash_attn
self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
# self.core_attention = CoreAttention(cfg, AttnType.self_attn)
self.attn_drop = cfg.attention_dropout
def forward(
self,
x: torch.Tensor,
):
bsz, seqlen, _ = x.shape
xqkv = self.qkv_proj(x)
xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
if self.use_flash_attention:
output = flash_attn_qkvpacked_func(xqkv)
output = output.view(bsz, seqlen, -1)
# xq, xk, xv = torch.split(xqkv, 1, dim=2)
# xq = xq.squeeze(2)
# xk = xk.squeeze(2)
# xv = xv.squeeze(2)
# # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# # B, num_head, S, head_size)
# xq = xq.permute(0, 2, 1, 3)
# xk = xk.permute(0, 2, 1, 3)
# xv = xv.permute(0, 2, 1, 3)
# # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
# output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
# output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
# output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
else:
# output = flash_attn_qkvpacked_func(xqkv)
xq, xk, xv = torch.split(xqkv, 1, dim=2)
xq = xq.squeeze(2)
xk = xk.squeeze(2)
xv = xv.squeeze(2)
# xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# B, num_head, S, head_size)
xq = xq.permute(0, 2, 1, 3)
xk = xk.permute(0, 2, 1, 3)
xv = xv.permute(0, 2, 1, 3)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
output = self.out_proj(output)
return output
class NoTPTransformerBlock(nn.Module):
def __init__(self, cfg, layer_id: int, multiple_of=256):
super().__init__()
self.n_heads = cfg.num_attention_heads
self.dim = cfg.hidden_size
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.self_attn = NoTPAttention(cfg)
self.mlp = NoTPFeedForward(
cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
)
self.layer_id = layer_id
self.layer_norm1 = torch.nn.LayerNorm(
cfg.hidden_size, eps=cfg.layernorm_epsilon
)
self.layer_norm2 = torch.nn.LayerNorm(
cfg.hidden_size, eps=cfg.layernorm_epsilon
)
def forward(self, x: torch.Tensor):
residual = self.self_attn.forward(self.layer_norm1(x))
h = x + residual
out = h + self.mlp.forward(self.layer_norm2(h))
return out
class NoTPTransformer(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
# self.recompute_list = self.cfg.get("recompute_list", [])
self.num_layers = cfg.num_layers # _get_num_layers(cfg)
self.layers = torch.nn.ModuleList()
for layer_id in range(self.num_layers):
self.layers.append(
NoTPTransformerBlock(
cfg,
layer_id + 1,
)
)
def forward(
self,
hidden_states,
):
for lid, layer in enumerate(self.layers):
# if lid in self.recompute_list:
# def custom(layer_id):
# def custom_forward(*args, **kwargs):
# x_ = self.layers[layer_id](*args, **kwargs)
# return x_
# return custom_forward
# assert hidden_states.requires_grad == True, logger.warning(
# "When using recalculation, the input must have grad fn"
# )
# hidden_states = tensor_parallel.checkpoint(
# custom(lid),
# False,
# hidden_states.contiguous()
# )
# else:
hidden_states = layer(hidden_states)
return hidden_states
# from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter
class VitModel(nn.Module):
def __init__(
self,
cfg,
freeze_embed=False,
freeze_pre_norm=False
) -> None:
super().__init__()
self.embeddings = CLIPVisionEmbeddings(hidden_size=cfg.hidden_size, image_size=cfg.image_size, patch_size=cfg.patch_size)
if freeze_embed:
for name, param in self.embeddings.named_parameters():
param.requires_grad = False
self.transformer = NoTPTransformer(cfg=cfg)
if cfg.get("fp32norm", False):
logger.info("Load fp32 layernorm for ViT.")
self.pre_layrnorm = LayerNormfp32(
cfg.hidden_size,
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
)
else:
self.pre_layrnorm = torch.nn.LayerNorm(
cfg.hidden_size,
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
)
# self.pre_layrnorm = RMSNorm(
# cfg.hidden_size,
# eps=cfg.get("pre_layernorm_epsilon", 1e-5),
# sequence_parallel=False,
# use_fp32=True,
# use_optimus=True,
# )
if freeze_pre_norm:
for name, param in self.pre_layrnorm.named_parameters():
param.requires_grad = False
for p in self.parameters():
p.micro_dp = True
def set_input_tensor(self, input_tensor):
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.transformer.set_input_tensor(input_tensor[0])
def __str__(self) -> str:
return "open_clip"
def forward(
self,
x,
patch_embeds
):
x = self.embeddings(x, patch_embeds)
hidden_states = self.pre_layrnorm(x)
# hidden_states, dis = local_dp_scatter(hidden_states)
output = self.transformer(hidden_states)
# output = local_dp_reduce(output, dis)
return output
vit_model_cfg = adict(
num_layers=24,
hidden_size=1024,
num_heads = 16,
num_attention_heads=16,
ffn_hidden_size=4096,
seq_length=256,
max_position_embeddings=256,
use_flash_attn=False,
understand_projector_stride=2,
hidden_dropout = 0.0,
attention_dropout = 0.0,
no_persist_layer_norm = False,
layernorm_epsilon = 1e-5,
pre_layernorm_epsilon = 1e-5,
image_size = 224,
patch_size = 14,
recompute_list = []
)
def build_clip_l():
return VitModel(
cfg=vit_model_cfg,
freeze_embed=False,
freeze_pre_norm=False,
)
if __name__ == '__main__':
from mmgpt.model.vision_encoder.sam_b import build_sam_vit_b
vit_model_cfg = adict(
num_layers=24,
hidden_size=1024,
num_attention_heads=16,
ffn_hidden_size=4096,
seq_length=256,
max_position_embeddings=256,
use_flash_attn=False,
understand_projector_stride=2,
hidden_dropout = 0.0,
attention_dropout = 0.0,
no_persist_layer_norm = False,
layernorm_epsilon = 1e-5,
pre_layernorm_epsilon = 1e-5,
image_size = 224,
patch_size = 14,
recompute_list = []
)
sam_model = build_sam_vit_b()
vision_model = VitModel(
cfg=vit_model_cfg,
freeze_embed=False,
freeze_pre_norm=False,
)
# model = VitModel(1344)
# x = torch.zeros(2, 3, 224, 224)
x = torch.zeros(2, 3, 1024, 1024)
with torch.no_grad():
# y = vision_model(x)
patch_embed = sam_model(x)
print(patch_embed.shape)
y = vision_model(x, patch_embed)
print(y.shape)
image_feature = torch.add(y[:, 1:], patch_embed.flatten(2).permute(0, 2, 1))
print(image_feature.shape)

View File

@@ -0,0 +1,528 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from functools import partial
from flash_attn import flash_attn_qkvpacked_func
# from .common import LayerNorm2d, MLPBlock
# from mmgpt.model.vision_encoder.flash_4 import _attention_rel_h_rel_w
def get_abs_pos(abs_pos, tgt_size):
dtype = abs_pos.dtype
src_size = abs_pos.size(1)
if src_size != tgt_size:
old_pos_embed = abs_pos.permute(0, 3, 1, 2)
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode='bicubic',
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
return new_pos_embed
else:
return abs_pos
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
# x = x + self.pos_embed
x = x + get_abs_pos(self.pos_embed, x.size(1))
for blk in self.blocks:
x = blk(x)
neck_output = self.neck(x.permute(0, 3, 1, 2))
conv2_output = self.net_2(neck_output)
# print(f"conv2_output shape: {conv2_output.shape}")
conv3_output = self.net_3(conv2_output)
return conv3_output
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
rel_h, rel_w = None, None
if self.use_rel_pos:
rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
q = q.view(B, self.num_heads, H * W, -1)
k = k.view(B, self.num_heads, H * W, -1)
v = v.view(B, self.num_heads, H * W, -1)
if self.use_rel_pos:
rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
# x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
# qkv = torch.stack([q, k, v], dim=1).transpose(1, 3).reshape(B, H * W, 3, self.num_heads, -1)
# x = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False).transpose(1, 2)
x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
dtype = rel_pos.dtype
rel_pos = rel_pos.to(torch.float32)
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
).to(dtype)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
return rel_h, rel_w
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
if checkpoint is not None:
# with open(checkpoint, "rb") as f:
state_dict = torch.load(checkpoint)
# print(state_dict.keys())
# for key in state_dict:
# image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)
# ocr-anyting
# image_encoder.load_state_dict(state_dict, strict=True)
# tob
image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
print(checkpoint)
return image_encoder

614
services/deepseek_ocr.py Normal file
View File

@@ -0,0 +1,614 @@
"""Inference-only Deepseek-OCR model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Set, Tuple, Union
import torch
import torch.nn as nn
from addict import Dict
# import time
from config import BASE_SIZE, CROP_MODE, IMAGE_SIZE, PRINT_NUM_VIS_TOKENS, PROMPT
from deepencoder.build_linear import MlpProjector
from deepencoder.clip_sdpa import build_clip_l
from deepencoder.sam_vary_sdpa import build_sam_vit_b
from process.image_process import DeepseekOCRProcessor, count_tiles
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.model_executor import SamplingMetadata
# from vllm.utils import is_list_of
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
SupportsPP,
)
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
merge_multimodal_embeddings,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargs,
NestedTensors,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
# The image token id may be various
_IMAGE_TOKEN = "<image>"
class DeepseekOCRProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(DeepseekVLV2Config)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_num_image_tokens(
self, *, image_width: int, image_height: int, cropping: bool = True
) -> int:
hf_processor = self.get_hf_processor()
# image_size = hf_processor.image_size
# patch_size = hf_processor.patch_size
# downsample_ratio = hf_processor.downsample_ratio
image_size = IMAGE_SIZE
base_size = BASE_SIZE
patch_size = 16
downsample_ratio = 4
if CROP_MODE:
if image_width <= 640 and image_height <= 640:
crop_ratio = [1, 1]
else:
# images_crop_raw, crop_ratio = hf_processor.dynamic_preprocess(image)
# find the closest aspect ratio to the target
crop_ratio = count_tiles(
image_width, image_height, image_size=IMAGE_SIZE
)
# print('===========')
# print('crop_ratio ', crop_ratio)
# print('============')
num_width_tiles, num_height_tiles = crop_ratio
else:
num_width_tiles = num_height_tiles = 1
h = w = math.ceil((base_size // patch_size) / downsample_ratio)
h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)
global_views_tokens = h * (w + 1)
if num_width_tiles > 1 or num_height_tiles > 1:
local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)
else:
local_views_tokens = 0
return global_views_tokens + local_views_tokens + 1
def get_image_size_with_most_features(self) -> ImageSize:
if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
return ImageSize(width=1024 * 2, height=1024 * 2)
return ImageSize(width=640 * 2, height=640 * 2)
class DeepseekOCRDummyInputsBuilder(BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
max_image_size = self.info.get_image_size_with_most_features()
if "<image>" in PROMPT:
return {
"image": DeepseekOCRProcessor().tokenize_with_images(
images=self._get_dummy_images(
width=max_image_size.width,
height=max_image_size.height,
num_images=num_images,
),
bos=True,
eos=True,
cropping=CROP_MODE,
)
}
else:
return {"image": []}
class DeepseekOCRMultiModalProcessor(
BaseMultiModalProcessor[DeepseekOCRProcessingInfo]
):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# print(mm_data)
if mm_data:
processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data),
mm_kwargs,
)
else:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(
prompt, add_special_tokens=True, return_tensors="pt"
)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
images_spatial_crop=MultiModalFieldConfig.batched("image"),
# image_embeds=MultiModalFieldConfig.batched("image2"),
images_crop=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = hf_processor.image_token_id
assert isinstance(image_token_id, int)
def get_replacement_deepseek_vl2(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems)
)
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
width = images[0][-1][0][0]
height = images[0][-1][0][1]
num_image_tokens = self.info.get_num_image_tokens(
image_width=width,
image_height=height,
# flag = True,
cropping=CROP_MODE,
)
return [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_deepseek_vl2,
)
]
def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
@MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor,
info=DeepseekOCRProcessingInfo,
dummy_inputs=DeepseekOCRDummyInputsBuilder,
)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language.": "language_model.",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: DeepseekVLV2Config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
# config.model_type ='deepseek_vl_v2'
self.config = config
self.multimodal_config = multimodal_config
self.vision_config = config.vision_config
self.projector_config = config.projector_config
self.text_config = config.text_config
model_config = vllm_config.model_config
tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
self.sam_model = build_sam_vit_b()
self.vision_model = build_clip_l()
n_embed = 1280
self.projector = MlpProjector(
Dict(projector_type="linear", input_dim=2048, n_embed=n_embed)
)
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
# self.sam_model = torch.compile(self.sam_model, mode="reduce-overhead")
# self.vision_model = torch.compile(self.vision_model, mode="reduce-overhead")
# self.projector = torch.compile(self.projector, mode="max-autotune")
# special token for image token sequence format
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
if self.tile_tag == "2D":
# <|view_separator|>, <|\n|>
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
else:
raise ValueError(
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
)
if self.text_config.topk_method == "noaux_tc":
architectures = ["DeepseekV3ForCausalLM"]
elif not self.text_config.use_mla:
architectures = ["DeepseekForCausalLM"]
else:
architectures = ["DeepseekV2ForCausalLM"]
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.text_config,
prefix=maybe_prefix(prefix, "language"),
architectures=architectures,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None)
if pixel_values is None or torch.sum(pixel_values).item() == 0:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image sizes. "
f"Got type: {type(images_spatial_crop)}"
)
if not isinstance(images_crop, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image crop. " f"Got type: {type(images_crop)}"
)
return [pixel_values, images_crop, images_spatial_crop]
raise AssertionError("This line should be unreachable.")
def _pixel_values_to_embedding(
self,
pixel_values: torch.Tensor,
images_crop: torch.Tensor,
images_spatial_crop: torch.Tensor,
) -> NestedTensors:
# Pixel_values (global view): [n_image, batch_size, 3, height, width]
# images_spatial_crop: [n_image, batch_size, [num_tiles_w, num_tiles_h]]
# images_crop (local view): [n_image, batch_size, num_pathes, 3, h, w]
# split the pixel and image_crop, all batch_size = 1
images_in_this_batch = []
# print(type(images_crop))
# print(pixel_values.shape)
with torch.no_grad():
for jdx in range(images_spatial_crop.size(0)):
# with torch.set_grad_enabled(False):
patches = images_crop[jdx][0].to(torch.bfloat16) # batch_size = 1
image_ori = pixel_values[jdx]
crop_shape = images_spatial_crop[jdx][0]
if torch.sum(patches).item() != 0: # if all values = 0, no crop
# P, C, H, W = patches.shape
# crop_flag = 1
local_features_1 = self.sam_model(patches)
# TODO del patches
# torch.compiler.cudagraph_mark_step_begin()
local_features_2 = self.vision_model(patches, local_features_1)
local_features = torch.cat(
(
local_features_2[:, 1:],
local_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
local_features = self.projector(local_features)
global_features_1 = self.sam_model(image_ori)
global_features_2 = self.vision_model(image_ori, global_features_1)
global_features = torch.cat(
(
global_features_2[:, 1:],
global_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
global_features = self.projector(global_features)
if PRINT_NUM_VIS_TOKENS:
print("=====================")
print("BASE: ", global_features.shape)
print("PATCHES: ", local_features.shape)
print("=====================")
_, hw, n_dim = global_features.shape
h = w = int(hw**0.5)
_2, hw2, n_dim2 = local_features.shape
h2 = w2 = int(hw2**0.5)
width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
global_features = global_features.view(h, w, n_dim)
global_features = torch.cat(
[
global_features,
self.image_newline[None, None, :].expand(h, 1, n_dim),
],
dim=1,
)
global_features = global_features.view(-1, n_dim)
local_features = (
local_features.view(
height_crop_num, width_crop_num, h2, w2, n_dim2
)
.permute(0, 2, 1, 3, 4)
.reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
)
local_features = torch.cat(
[
local_features,
self.image_newline[None, None, :].expand(
height_crop_num * h2, 1, n_dim2
),
],
dim=1,
)
local_features = local_features.view(-1, n_dim2)
global_local_features = torch.cat(
[local_features, global_features, self.view_seperator[None, :]],
dim=0,
)
else:
global_features_1 = self.sam_model(image_ori)
global_features_2 = self.vision_model(image_ori, global_features_1)
global_features = torch.cat(
(
global_features_2[:, 1:],
global_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
global_features = self.projector(global_features)
if PRINT_NUM_VIS_TOKENS:
print("=====================")
print("BASE: ", global_features.shape)
print("NO PATCHES")
print("=====================")
_, hw, n_dim = global_features.shape
h = w = int(hw**0.5)
global_features = global_features.view(h, w, n_dim)
global_features = torch.cat(
[
global_features,
self.image_newline[None, None, :].expand(h, 1, n_dim),
],
dim=1,
)
global_features = global_features.view(-1, n_dim)
global_local_features = torch.cat(
[global_features, self.view_seperator[None, :]], dim=0
)
images_in_this_batch.append(global_local_features)
return images_in_this_batch
def _process_image_input(self, image_input) -> torch.Tensor:
# image_input: [pixel_values, images_crop, images_spatial_crop]
pixel_values = image_input[0].to(torch.bfloat16)
# print(image_input[1][0].shape)
# print(type(image_input[1]))
# exit()
# images_crop = image_input[1].to(torch.bfloat16)
images_crop = image_input[1]
# images_crop = image_input[1]
images_spatial_crop = image_input[2].to(dtype=torch.long)
# local_start = time.time()
vision_features = self._pixel_values_to_embedding(
pixel_values=pixel_values,
images_crop=images_crop,
images_spatial_crop=images_spatial_crop,
)
# local_total_time = time.time() - local_start
# print('encoder_time: ', local_total_time)
# exit()
return vision_features
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object
) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, self.image_token_id
)
# print(len(multimodal_embeddings))
# print(input_ids.shape)
# print(type(inputs_embeds))
# print(inputs_embeds.shape)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings)
input_ids = None
hidden_states = self.language_model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
processed_weights = []
for name, tensor in weights:
if (
"sam_model" in name
or "vision_model" in name
or "projector" in name
or "image_newline" in name
or "view_seperator" in name
):
new_name = name.replace("model.", "", 1)
else:
new_name = "language." + name
processed_weights.append((new_name, tensor))
loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(
processed_weights, mapper=self.hf_to_vllm_mapper
)
return autoloaded_weights

View File

View File

@@ -0,0 +1,502 @@
import math
from typing import List, Tuple
import torch
import torchvision.transforms as T
from PIL import Image, ImageOps
from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin
from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, MIN_CROPS, MAX_CROPS, PROMPT, TOKENIZER
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
return best_ratio
def count_tiles(orig_width, orig_height, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
# print(target_ratios)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
return target_aspect_ratio
def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
# print(target_ratios)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# print(target_aspect_ratio)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
class ImageTransform:
def __init__(self,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True):
self.mean = mean
self.std = std
self.normalize = normalize
transform_pipelines = [T.ToTensor()]
if normalize:
transform_pipelines.append(T.Normalize(mean, std))
self.transform = T.Compose(transform_pipelines)
def __call__(self, pil_img: Image.Image):
x = self.transform(pil_img)
return x
class DeepseekOCRProcessor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
def __init__(
self,
tokenizer: LlamaTokenizerFast = TOKENIZER,
candidate_resolutions: Tuple[Tuple[int, int]] = [[1024, 1024]],
patch_size: int = 16,
downsample_ratio: int = 4,
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
image_token: str = "<image>",
pad_token: str = "<▁pad▁>",
add_special_token: bool = False,
sft_format: str = "deepseek",
mask_prompt: bool = True,
ignore_id: int = -100,
**kwargs,
):
# self.candidate_resolutions = candidate_resolutions # placeholder no use
self.image_size = IMAGE_SIZE
self.base_size = BASE_SIZE
# self.patch_size = patch_size
self.patch_size = 16
self.image_mean = image_mean
self.image_std = image_std
self.normalize = normalize
# self.downsample_ratio = downsample_ratio
self.downsample_ratio = 4
self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize)
self.tokenizer = tokenizer
# self.tokenizer = add_special_token(tokenizer)
self.tokenizer.padding_side = 'left' # must set thispadding side with make a difference in batch inference
# add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
if self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({'pad_token': pad_token})
# add image token
# image_token_id = self.tokenizer.vocab.get(image_token)
# if image_token_id is None:
# special_tokens = [image_token]
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token_id = self.tokenizer.vocab.get(image_token)
# add five special tokens for grounding-related tasks
# <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
# special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>']
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# special_tokens = ['<image>','<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>', '<td>', '</td>', '<tr>', '</tr>']
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# self.tokenizer.add_special_tokens(special_tokens_dict)
# # add special tokens for SFT data
# special_tokens = ["<|User|>", "<|Assistant|>"]
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token = image_token
self.pad_token = pad_token
self.add_special_token = add_special_token
self.sft_format = sft_format
self.mask_prompt = mask_prompt
self.ignore_id = ignore_id
super().__init__(
tokenizer,
**kwargs,
)
# def select_best_resolution(self, image_size):
# # used for cropping
# original_width, original_height = image_size
# best_fit = None
# max_effective_resolution = 0
# min_wasted_resolution = float("inf")
# for width, height in self.candidate_resolutions:
# scale = min(width / original_width, height / original_height)
# downscaled_width, downscaled_height = int(
# original_width * scale), int(original_height * scale)
# effective_resolution = min(downscaled_width * downscaled_height,
# original_width * original_height)
# wasted_resolution = (width * height) - effective_resolution
# if effective_resolution > max_effective_resolution or (
# effective_resolution == max_effective_resolution
# and wasted_resolution < min_wasted_resolution):
# max_effective_resolution = effective_resolution
# min_wasted_resolution = wasted_resolution
# best_fit = (width, height)
# return best_fit
@property
def bos_id(self):
return self.tokenizer.bos_token_id
@property
def eos_id(self):
return self.tokenizer.eos_token_id
@property
def pad_id(self):
return self.tokenizer.pad_token_id
def encode(self, text: str, bos: bool = True, eos: bool = False):
t = self.tokenizer.encode(text, add_special_tokens=False)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int], **kwargs) -> str:
return self.tokenizer.decode(t, **kwargs)
def process_one(
self,
prompt: str,
images: List,
inference_mode: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
system_prompt (str): the system prompt;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- target_ids (torch.LongTensor): [N + image tokens]
- pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
assert (prompt is not None and images is not None
), "prompt and images must be used at the same time."
sft_format = prompt
input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, _ = images[0]
return {
"input_ids": input_ids,
"pixel_values": pixel_values,
"images_crop": images_crop,
"images_seq_mask": images_seq_mask,
"images_spatial_crop": images_spatial_crop,
"num_image_tokens": num_image_tokens,
}
# prepare = BatchFeature(
# data=dict(
# input_ids=input_ids,
# pixel_values=pixel_values,
# images_crop = images_crop,
# images_seq_mask=images_seq_mask,
# images_spatial_crop=images_spatial_crop,
# num_image_tokens=num_image_tokens,
# ),
# tensor_type="pt",
# )
# return prepare
def __call__(
self,
*,
prompt: str,
images: List,
inference_mode: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
images (List[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
prepare = self.process_one(
prompt=prompt,
images=images,
inference_mode=inference_mode,
)
return prepare
def tokenize_with_images(
self,
# conversation: str,
images: List[Image.Image],
bos: bool = True,
eos: bool = True,
cropping: bool = True,
):
"""Tokenize text with <image> tags."""
# print(conversation)
conversation = PROMPT
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = [], [], [], []
image_shapes = []
num_image_tokens = []
tokenized_str = []
# print('image: ', len(images))
for text_sep, image in zip(text_splits, images):
"""encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""select best resolution for anyres"""
# if cropping:
# best_width, best_height = self.select_best_resolution(image.size)
# else:
# best_width, best_height = self.image_size, self.image_size
image_shapes.append(image.size)
if image.size[0] <= 640 and image.size[1] <= 640:
crop_ratio = [1, 1]
else:
if cropping:
# print('image-size: ', image.size)
# best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
# print('image ', image.size)
# print('open_size:', image.size)
images_crop_raw, crop_ratio = dynamic_preprocess(image, image_size=IMAGE_SIZE)
# print('crop_ratio: ', crop_ratio)
else:
# best_width, best_height = self.image_size, self.image_size
crop_ratio = [1, 1]
# print(image.size, (best_width, best_height)) # check the select_best_resolutions func
# print(crop_ratio)
"""process the global view"""
# if cropping
if self.image_size <= 640 and not cropping:
# print('directly resize')
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad(image, (self.base_size, self.base_size),
color=tuple(int(x * 255) for x in self.image_transform.mean))
images_list.append(self.image_transform(global_view))
"""record height / width crop num"""
# width_crop_num, height_crop_num = best_width // self.image_size, best_height // self.image_size
num_width_tiles, num_height_tiles = crop_ratio
images_spatial_crop.append([num_width_tiles, num_height_tiles])
if num_width_tiles > 1 or num_height_tiles > 1:
"""process the local views"""
# local_view = ImageOps.pad(image, (best_width, best_height),
# color=tuple(int(x * 255) for x in self.image_transform.mean))
# for i in range(0, best_height, self.image_size):
# for j in range(0, best_width, self.image_size):
# images_crop_list.append(
# self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
for i in range(len(images_crop_raw)):
images_crop_list.append(self.image_transform(images_crop_raw[i]))
# """process the global view"""
# global_view = ImageOps.pad(image, (self.image_size, self.image_size),
# color=tuple(int(x * 255) for x in self.image_transform.mean))
# images_list.append(self.image_transform(global_view))
# """process the local views"""
# local_view = ImageOps.pad(image, (best_width, best_height),
# color=tuple(int(x * 255) for x in self.image_transform.mean))
# for i in range(0, best_height, self.image_size):
# for j in range(0, best_width, self.image_size):
# images_list.append(
# self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
# """add image tokens"""
"""add image tokens"""
num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
tokenized_image = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += ([self.image_token_id] * (num_queries * num_width_tiles) + [self.image_token_id]) * (
num_queries * num_height_tiles)
tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image))
"""process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""add the bos and eos tokens"""
if bos:
tokenized_str = [self.bos_id] + tokenized_str
images_seq_mask = [False] + images_seq_mask
if eos:
tokenized_str = tokenized_str + [self.eos_id]
images_seq_mask = images_seq_mask + [False]
assert len(tokenized_str) == len(
images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
masked_tokenized_str = []
for token_index in tokenized_str:
if token_index != self.image_token_id:
masked_tokenized_str.append(token_index)
else:
masked_tokenized_str.append(self.ignore_id)
assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \
(f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal")
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) |
(input_ids == self.image_token_id)] = self.ignore_id
input_ids[input_ids < 0] = self.pad_id
inference_mode = True
if inference_mode:
# Remove the ending eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
else:
images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
input_ids = input_ids.unsqueeze(0)
return [[input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, image_shapes]]
AutoProcessor.register("DeepseekVLV2Processor", DeepseekOCRProcessor)

View File

@@ -0,0 +1,40 @@
import torch
from transformers import LogitsProcessor
from transformers.generation.logits_process import _calc_banned_ngram_tokens
from typing import List, Set
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
def __init__(self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
if not isinstance(window_size, int) or window_size <= 0:
raise ValueError(f"`window_size` has to be a strictly positive integer, but is {window_size}")
self.ngram_size = ngram_size
self.window_size = window_size
self.whitelist_token_ids = whitelist_token_ids or set()
def __call__(self, input_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
if len(input_ids) < self.ngram_size:
return scores
current_prefix = tuple(input_ids[-(self.ngram_size - 1):])
search_start = max(0, len(input_ids) - self.window_size)
search_end = len(input_ids) - self.ngram_size + 1
banned_tokens = set()
for i in range(search_start, search_end):
ngram = tuple(input_ids[i:i + self.ngram_size])
if ngram[:-1] == current_prefix:
banned_tokens.add(ngram[-1])
banned_tokens = banned_tokens - self.whitelist_token_ids
if banned_tokens:
scores = scores.clone()
for token in banned_tokens:
scores[token] = -float("inf")
return scores

View File

@@ -0,0 +1,320 @@
import asyncio
import os
import re
from config import env_setup
env_setup.setup_environment()
import time
import numpy as np
from config.model_settings import CROP_MODE, INPUT_PATH, MODEL_PATH, OUTPUT_PATH, PROMPT
from PIL import Image, ImageDraw, ImageFont, ImageOps
from process.image_process import DeepseekOCRProcessor
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from tqdm import tqdm
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.registry import ModelRegistry
from deepseek_ocr import DeepseekOCRForCausalLM
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
def load_image(image_path):
try:
image = Image.open(image_path)
corrected_image = ImageOps.exif_transpose(image)
return corrected_image
except Exception as e:
print(f"error: {e}")
try:
return Image.open(image_path)
except:
return None
def re_match(text):
pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
matches = re.findall(pattern, text, re.DOTALL)
mathes_image = []
mathes_other = []
for a_match in matches:
if "<|ref|>image<|/ref|>" in a_match[0]:
mathes_image.append(a_match[0])
else:
mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
except Exception as e:
print(e)
return None
return (label_type, cor_list)
def draw_bounding_boxes(image, refs):
image_width, image_height = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay)
# except IOError:
font = ImageFont.load_default()
img_idx = 0
for i, ref in enumerate(refs):
try:
result = extract_coordinates_and_label(ref, image_width, image_height)
if result:
label_type, points_list = result
color = (
np.random.randint(0, 200),
np.random.randint(0, 200),
np.random.randint(0, 255),
)
color_a = color + (20,)
for points in points_list:
x1, y1, x2, y2 = points
x1 = int(x1 / 999 * image_width)
y1 = int(y1 / 999 * image_height)
x2 = int(x2 / 999 * image_width)
y2 = int(y2 / 999 * image_height)
if label_type == "image":
try:
cropped = image.crop((x1, y1, x2, y2))
cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg")
except Exception as e:
print(e)
pass
img_idx += 1
try:
if label_type == "title":
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
else:
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
text_x = x1
text_y = max(0, y1 - 15)
text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
draw.rectangle(
[text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30),
)
draw.text((text_x, text_y), label_type, font=font, fill=color)
except:
pass
except:
continue
img_draw.paste(overlay, (0, 0), overlay)
return img_draw
def process_image_with_refs(image, ref_texts):
result_image = draw_bounding_boxes(image, ref_texts)
return result_image
async def stream_generate(image=None, prompt=""):
engine_args = AsyncEngineArgs(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
block_size=256,
max_model_len=8192,
enforce_eager=False,
trust_remote_code=True,
tensor_parallel_size=1,
gpu_memory_utilization=0.75,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
logits_processors = [
NoRepeatNGramLogitsProcessor(
ngram_size=30, window_size=90, whitelist_token_ids={128821, 128822}
)
] # whitelist: <td>, </td>
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
logits_processors=logits_processors,
skip_special_tokens=False,
# ignore_eos=False,
)
request_id = f"request-{int(time.time())}"
printed_length = 0
if image and "<image>" in prompt:
request = {"prompt": prompt, "multi_modal_data": {"image": image}}
elif prompt:
request = {"prompt": prompt}
else:
assert False, "prompt is none!!!"
async for request_output in engine.generate(request, sampling_params, request_id):
if request_output.outputs:
full_text = request_output.outputs[0].text
new_text = full_text[printed_length:]
print(new_text, end="", flush=True)
printed_length = len(full_text)
final_output = full_text
print("\n")
return final_output
if __name__ == "__main__":
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
image = load_image(INPUT_PATH).convert("RGB")
if "<image>" in PROMPT:
image_features = DeepseekOCRProcessor().tokenize_with_images(
images=[image], bos=True, eos=True, cropping=CROP_MODE
)
else:
image_features = ""
prompt = PROMPT
result_out = asyncio.run(stream_generate(image_features, prompt))
save_results = 1
if save_results and "<image>" in prompt:
print("=" * 15 + "save results:" + "=" * 15)
image_draw = image.copy()
outputs = result_out
with open(f"{OUTPUT_PATH}/result_ori.mmd", "w", encoding="utf-8") as afile:
afile.write(outputs)
matches_ref, matches_images, mathes_other = re_match(outputs)
# print(matches_ref)
result = process_image_with_refs(image_draw, matches_ref)
for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
outputs = outputs.replace(
a_match_image, "![](images/" + str(idx) + ".jpg)\n"
)
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
outputs = (
outputs.replace(a_match_other, "")
.replace("\\coloneqq", ":=")
.replace("\\eqqcolon", "=:")
)
# if 'structural formula' in conversation[0]['content']:
# outputs = '<smiles>' + outputs + '</smiles>'
with open(f"{OUTPUT_PATH}/result.mmd", "w", encoding="utf-8") as afile:
afile.write(outputs)
if "line_type" in outputs:
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
lines = eval(outputs)["Line"]["line"]
line_type = eval(outputs)["Line"]["line_type"]
# print(lines)
endpoints = eval(outputs)["Line"]["line_endpoint"]
fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
ax.set_xlim(-15, 15)
ax.set_ylim(-15, 15)
for idx, line in enumerate(lines):
try:
p0 = eval(line.split(" -- ")[0])
p1 = eval(line.split(" -- ")[-1])
if line_type[idx] == "--":
ax.plot(
[p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
)
else:
ax.plot(
[p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color="k"
)
ax.scatter(p0[0], p0[1], s=5, color="k")
ax.scatter(p1[0], p1[1], s=5, color="k")
except:
pass
for endpoint in endpoints:
label = endpoint.split(": ")[0]
(x, y) = eval(endpoint.split(": ")[1])
ax.annotate(
label,
(x, y),
xytext=(1, 1),
textcoords="offset points",
fontsize=5,
fontweight="light",
)
try:
if "Circle" in eval(outputs).keys():
circle_centers = eval(outputs)["Circle"]["circle_center"]
radius = eval(outputs)["Circle"]["radius"]
for center, r in zip(circle_centers, radius):
center = eval(center.split(": ")[1])
circle = Circle(
center,
radius=r,
fill=False,
edgecolor="black",
linewidth=0.8,
)
ax.add_patch(circle)
except:
pass
plt.savefig(f"{OUTPUT_PATH}/geo.jpg")
plt.close()
result.save(f"{OUTPUT_PATH}/result_with_boxes.jpg")

View File

@@ -0,0 +1,352 @@
import io
import os
import re
from concurrent.futures import ThreadPoolExecutor
from config import env_setup
env_setup.setup_environment()
import fitz
import img2pdf
import numpy as np
from config.model_settings import (
CROP_MODE,
INPUT_PATH,
MAX_CONCURRENCY,
MODEL_PATH,
NUM_WORKERS,
OUTPUT_PATH,
PROMPT,
SKIP_REPEAT,
)
from PIL import Image, ImageDraw, ImageFont
from process.image_process import DeepseekOCRProcessor
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from deepseek_ocr import DeepseekOCRForCausalLM
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
llm = LLM(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
block_size=256,
enforce_eager=False,
trust_remote_code=True,
max_model_len=8192,
swap_space=0,
max_num_seqs=MAX_CONCURRENCY,
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
disable_mm_preprocessor_cache=True,
)
logits_processors = [
NoRepeatNGramLogitsProcessor(
ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
)
] # window for fastwhitelist_token_ids: <td>,</td>
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
logits_processors=logits_processors,
skip_special_tokens=False,
include_stop_str_in_output=True,
)
class Colors:
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
RESET = "\033[0m"
def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
"""
pdf2images
"""
images = []
pdf_document = fitz.open(pdf_path)
zoom = dpi / 72.0
matrix = fitz.Matrix(zoom, zoom)
for page_num in range(pdf_document.page_count):
page = pdf_document[page_num]
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
Image.MAX_IMAGE_PIXELS = None
if image_format.upper() == "PNG":
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
else:
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
if img.mode in ("RGBA", "LA"):
background = Image.new("RGB", img.size, (255, 255, 255))
background.paste(
img, mask=img.split()[-1] if img.mode == "RGBA" else None
)
img = background
images.append(img)
pdf_document.close()
return images
def pil_to_pdf_img2pdf(pil_images, output_path):
if not pil_images:
return
image_bytes_list = []
for img in pil_images:
if img.mode != "RGB":
img = img.convert("RGB")
img_buffer = io.BytesIO()
img.save(img_buffer, format="JPEG", quality=95)
img_bytes = img_buffer.getvalue()
image_bytes_list.append(img_bytes)
try:
pdf_bytes = img2pdf.convert(image_bytes_list)
with open(output_path, "wb") as f:
f.write(pdf_bytes)
except Exception as e:
print(f"error: {e}")
def re_match(text):
pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
matches = re.findall(pattern, text, re.DOTALL)
mathes_image = []
mathes_other = []
for a_match in matches:
if "<|ref|>image<|/ref|>" in a_match[0]:
mathes_image.append(a_match[0])
else:
mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
except Exception as e:
print(e)
return None
return (label_type, cor_list)
def draw_bounding_boxes(image, refs, jdx):
image_width, image_height = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay)
# except IOError:
font = ImageFont.load_default()
img_idx = 0
for i, ref in enumerate(refs):
try:
result = extract_coordinates_and_label(ref, image_width, image_height)
if result:
label_type, points_list = result
color = (
np.random.randint(0, 200),
np.random.randint(0, 200),
np.random.randint(0, 255),
)
color_a = color + (20,)
for points in points_list:
x1, y1, x2, y2 = points
x1 = int(x1 / 999 * image_width)
y1 = int(y1 / 999 * image_height)
x2 = int(x2 / 999 * image_width)
y2 = int(y2 / 999 * image_height)
if label_type == "image":
try:
cropped = image.crop((x1, y1, x2, y2))
cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
except Exception as e:
print(e)
pass
img_idx += 1
try:
if label_type == "title":
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
else:
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
draw2.rectangle(
[x1, y1, x2, y2],
fill=color_a,
outline=(0, 0, 0, 0),
width=1,
)
text_x = x1
text_y = max(0, y1 - 15)
text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
draw.rectangle(
[text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30),
)
draw.text((text_x, text_y), label_type, font=font, fill=color)
except:
pass
except:
continue
img_draw.paste(overlay, (0, 0), overlay)
return img_draw
def process_image_with_refs(image, ref_texts, jdx):
result_image = draw_bounding_boxes(image, ref_texts, jdx)
return result_image
def process_single_image(image):
"""single image"""
prompt_in = prompt
cache_item = {
"prompt": prompt_in,
"multi_modal_data": {
"image": DeepseekOCRProcessor().tokenize_with_images(
images=[image], bos=True, eos=True, cropping=CROP_MODE
)
},
}
return cache_item
if __name__ == "__main__":
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(f"{OUTPUT_PATH}/images", exist_ok=True)
print(f"{Colors.RED}PDF loading .....{Colors.RESET}")
images = pdf_to_images_high_quality(INPUT_PATH)
prompt = PROMPT
# batch_inputs = []
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
batch_inputs = list(
tqdm(
executor.map(process_single_image, images),
total=len(images),
desc="Pre-processed images",
)
)
# for image in tqdm(images):
# prompt_in = prompt
# cache_list = [
# {
# "prompt": prompt_in,
# "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
# }
# ]
# batch_inputs.extend(cache_list)
outputs_list = llm.generate(batch_inputs, sampling_params=sampling_params)
output_path = OUTPUT_PATH
os.makedirs(output_path, exist_ok=True)
mmd_det_path = (
output_path + "/" + INPUT_PATH.split("/")[-1].replace(".pdf", "_det.mmd")
)
mmd_path = output_path + "/" + INPUT_PATH.split("/")[-1].replace("pdf", "mmd")
pdf_out_path = (
output_path + "/" + INPUT_PATH.split("/")[-1].replace(".pdf", "_layouts.pdf")
)
contents_det = ""
contents = ""
draw_images = []
jdx = 0
for output, img in zip(outputs_list, images):
content = output.outputs[0].text
if "<end▁of▁sentence>" in content: # repeat no eos
content = content.replace("<end▁of▁sentence>", "")
else:
if SKIP_REPEAT:
continue
page_num = "\n<--- Page Split --->"
contents_det += content + f"\n{page_num}\n"
image_draw = img.copy()
matches_ref, matches_images, mathes_other = re_match(content)
# print(matches_ref)
result_image = process_image_with_refs(image_draw, matches_ref, jdx)
draw_images.append(result_image)
for idx, a_match_image in enumerate(matches_images):
content = content.replace(
a_match_image, "![](images/" + str(jdx) + "_" + str(idx) + ".jpg)\n"
)
for idx, a_match_other in enumerate(mathes_other):
content = (
content.replace(a_match_other, "")
.replace("\\coloneqq", ":=")
.replace("\\eqqcolon", "=:")
.replace("\n\n\n\n", "\n\n")
.replace("\n\n\n", "\n\n")
)
contents += content + f"\n{page_num}\n"
jdx += 1
with open(mmd_det_path, "w", encoding="utf-8") as afile:
afile.write(contents_det)
with open(mmd_path, "w", encoding="utf-8") as afile:
afile.write(contents)
pil_to_pdf_img2pdf(draw_images, pdf_out_path)