인퍼런스 레이어 구축
This commit is contained in:
161
model_services/deepseek_ocr/run_dpsk_ocr_eval_batch.py
Normal file
161
model_services/deepseek_ocr/run_dpsk_ocr_eval_batch.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import os
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
if torch.version.cuda == '11.8':
|
||||
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
|
||||
os.environ['VLLM_USE_V1'] = '0'
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
||||
|
||||
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, MAX_CONCURRENCY, CROP_MODE, NUM_WORKERS
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import glob
|
||||
from PIL import Image
|
||||
from deepseek_ocr import DeepseekOCRForCausalLM
|
||||
|
||||
from vllm.model_executor.models.registry import ModelRegistry
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
|
||||
from process.image_process import DeepseekOCRProcessor
|
||||
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
|
||||
|
||||
|
||||
llm = LLM(
|
||||
model=MODEL_PATH,
|
||||
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
|
||||
block_size=256,
|
||||
enforce_eager=False,
|
||||
trust_remote_code=True,
|
||||
max_model_len=8192,
|
||||
swap_space=0,
|
||||
max_num_seqs = MAX_CONCURRENCY,
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.9,
|
||||
)
|
||||
|
||||
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=40, window_size=90, whitelist_token_ids= {128821, 128822})] #window for fast;whitelist_token_ids: <td>,</td>
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=8192,
|
||||
logits_processors=logits_processors,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
|
||||
class Colors:
|
||||
RED = '\033[31m'
|
||||
GREEN = '\033[32m'
|
||||
YELLOW = '\033[33m'
|
||||
BLUE = '\033[34m'
|
||||
RESET = '\033[0m'
|
||||
|
||||
def clean_formula(text):
|
||||
|
||||
formula_pattern = r'\\\[(.*?)\\\]'
|
||||
|
||||
def process_formula(match):
|
||||
formula = match.group(1)
|
||||
|
||||
formula = re.sub(r'\\quad\s*\([^)]*\)', '', formula)
|
||||
|
||||
formula = formula.strip()
|
||||
|
||||
return r'\[' + formula + r'\]'
|
||||
|
||||
cleaned_text = re.sub(formula_pattern, process_formula, text)
|
||||
|
||||
return cleaned_text
|
||||
|
||||
def re_match(text):
|
||||
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
|
||||
# mathes_image = []
|
||||
mathes_other = []
|
||||
for a_match in matches:
|
||||
mathes_other.append(a_match[0])
|
||||
return matches, mathes_other
|
||||
|
||||
def process_single_image(image):
|
||||
"""single image"""
|
||||
prompt_in = prompt
|
||||
cache_item = {
|
||||
"prompt": prompt_in,
|
||||
"multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
|
||||
}
|
||||
return cache_item
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# INPUT_PATH = OmniDocBench images path
|
||||
|
||||
os.makedirs(OUTPUT_PATH, exist_ok=True)
|
||||
|
||||
# print('image processing until processing prompts.....')
|
||||
|
||||
print(f'{Colors.RED}glob images.....{Colors.RESET}')
|
||||
|
||||
images_path = glob.glob(f'{INPUT_PATH}/*')
|
||||
|
||||
images = []
|
||||
|
||||
for image_path in images_path:
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
images.append(image)
|
||||
|
||||
prompt = PROMPT
|
||||
|
||||
# batch_inputs = []
|
||||
|
||||
|
||||
# for image in tqdm(images):
|
||||
|
||||
# prompt_in = prompt
|
||||
# cache_list = [
|
||||
# {
|
||||
# "prompt": prompt_in,
|
||||
# "multi_modal_data": {"image": Image.open(image).convert('RGB')},
|
||||
# }
|
||||
# ]
|
||||
# batch_inputs.extend(cache_list)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
|
||||
batch_inputs = list(tqdm(
|
||||
executor.map(process_single_image, images),
|
||||
total=len(images),
|
||||
desc="Pre-processed images"
|
||||
))
|
||||
|
||||
|
||||
|
||||
|
||||
outputs_list = llm.generate(
|
||||
batch_inputs,
|
||||
sampling_params=sampling_params
|
||||
)
|
||||
|
||||
|
||||
output_path = OUTPUT_PATH
|
||||
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
for output, image in zip(outputs_list, images_path):
|
||||
|
||||
content = output.outputs[0].text
|
||||
mmd_det_path = output_path + image.split('/')[-1].replace('.jpg', '_det.md')
|
||||
|
||||
with open(mmd_det_path, 'w', encoding='utf-8') as afile:
|
||||
afile.write(content)
|
||||
|
||||
content = clean_formula(content)
|
||||
matches_ref, mathes_other = re_match(content)
|
||||
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
|
||||
content = content.replace(a_match_other, '').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n').replace('<center>', '').replace('</center>', '')
|
||||
|
||||
mmd_path = output_path + image.split('/')[-1].replace('.jpg', '.md')
|
||||
|
||||
with open(mmd_path, 'w', encoding='utf-8') as afile:
|
||||
afile.write(content)
|
||||
Reference in New Issue
Block a user