123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
- import cv2
- import numpy as np
- from typing import Union
- from flask import jsonify
- from paddleocr import PaddleOCR
- from random import randint, seed
- from time import localtime, strftime
- from concurrent.futures import ThreadPoolExecutor
- from paddleocr.tools.infer.utility import draw_box_txt_fine
# Names exported by a star-import of this module.
__all__ = [
    "Args",
    "Response",
    "rand_str",
    "current_time",
    "get_ext_name",
    "is_image_ext",
    "recognize",
    "draw_img",
    "json_all",
    "str_include",
    "rec_multi",
    "save_img",
    "rotate",
]

# Alphabet that rand_str() samples its characters from.
__StrBase = "qwertyuioplkjhgfdsazxcvbnm1234567890ZXCVBNMLKJHGFDSAQWERTYUIOP"
# Highest valid index into __StrBase (randint's upper bound is inclusive).
__StrBaseLen = len(__StrBase) - 1

# Lower-case extensions that is_image_ext() accepts.
__AcceptExtNames = [
    "jpg", "jpeg", "bmp", "png", "rgb", "tif", "tiff", "gif", "pdf",
]

# Number of OCR engines built at import time; _rec() selects one by
# ``index % __EngineNum``.
__EngineNum = 2
__Engines = [
    PaddleOCR(
        use_gpu=False,
        enable_mkldnn=True,
        det_model_dir="models/det/",
        rec_model_dir="models/rec/",
        cls_model_dir="models/cls/",
        use_angle_cls=True,
    )
    for _ in range(__EngineNum)
]
class Args:
    """Attribute-style container of OCR inference options.

    Every option listed in ``__init__`` is first set to its default value;
    keyword arguments given to the constructor then override those defaults
    or add extra options. Options live in the instance ``__dict__`` and are
    read and written with ordinary attribute syntax (``args.use_gpu``).

    Reading an option that was never set raises ``AttributeError``, so
    ``hasattr`` and ``getattr(args, name, default)`` behave normally.

    The class relies on the default attribute protocol on purpose: an
    overridden ``__getattribute__`` that reads ``self.__dict__`` re-enters
    itself and recurses without bound, which made the class impossible to
    instantiate.
    """

    def __init__(self, **kwargs):
        """Populate the defaults, then apply the caller's overrides.

        Args:
            **kwargs: Option names mapped to the values that should replace
                (or extend) the defaults below.
        """
        self.__update(
            use_gpu=False, use_xpu=False, use_npu=False, ir_optim=True, use_tensorrt=False,
            min_subgraph_size=15, precision="fp32", gpu_mem=500, image_dir=None, page_num=0,
            det_algorithm="DB", det_model_dir="models/det/", det_limit_side_len=960, det_limit_type="max",
            det_box_type="quad", det_db_thresh=0.3, det_db_box_thresh=0.6, det_db_unclip_ratio=1.5,
            max_batch_size=10, use_dilation=False, det_db_score_mode="fast", det_east_score_thresh=0.8,
            det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_sast_score_thresh=0.5,
            det_sast_nms_thresh=0.2, det_pse_thresh=0, det_pse_box_thresh=0.85, det_pse_min_area=16,
            det_pse_scale=1, scales=[8, 16, 32], alpha=1.0, beta=1.0, fourier_degree=5,
            rec_algorithm="SVTR_LCNet", rec_model_dir="models/rec/", rec_image_inverse=True,
            rec_image_shape="3, 48, 320", rec_batch_num=6, max_text_length=25,
            rec_char_dict_path="venv/lib/site-packages/paddleocr/ppocr/utils/ppocr_keys_v1.txt",
            use_space_char=True, vis_font_path="static/simfang.ttf", drop_score=0.5,
            e2e_algorithm="PGNet", e2e_model_dir=None, e2e_limit_side_len=768, e2e_limit_type="max",
            e2e_pgnet_score_thresh=0.5, e2e_char_dict_path="./ppocr/utils/ic15_dict.txt",
            e2e_pgnet_valid_set="totaltext", e2e_pgnet_mode="fast", use_angle_cls=True,
            cls_model_dir="models/cls/", cls_image_shape="3, 48, 192", label_list=["0", "180"],
            cls_batch_num=6, cls_thresh=0.9, enable_mkldnn=True, cpu_threads=10, use_pdserving=False,
            sr_model_dir=None, sr_image_shape="3, 32, 128", sr_batch_num=1,
            draw_img_save_dir="static/rec_res/", save_crop_res=False, crop_res_save_dir="./output",
            use_mp=False, benchmark=False, save_log_path="./log_output/",
            show_log=True, use_onnx=False, output="./output", table_max_len=488, table_algorithm="TableAttn",
            table_model_dir=None, merge_no_span_structure=True, table_char_dict_path=None,
            layout_model_dir=None, layout_dict_path=None, layout_score_threshold=0.5,
            layout_nms_threshold=0.5, kie_algorithm="LayoutXLM", ser_model_dir=None, re_model_dir=None,
            use_visual_backbone=True, ser_dict_path="../train_data/XFUND/class_list_xfun.txt",
            ocr_order_method=None, mode="structure", image_orientation=False, layout=True, table=True,
            ocr=True, recovery=False, use_pdf2docx_api=False, lang="ch", det=True, rec=True, type="ocr",
            ocr_version="PP-OCRv3", structure_version="PP-StructureV2"
        )
        # Applied second so that caller-supplied values win over the defaults.
        self.__update(**kwargs)

    def __update(self, **kwargs):
        """Store every ``name=value`` pair as an instance attribute."""
        # Iterating a dict directly yields only its keys; update() copies the
        # key/value pairs, which is what the option table needs.
        self.__dict__.update(kwargs)
def rand_str(size: "int" = 8) -> "str":
    """Return ``size`` characters drawn at random from the module alphabet.

    Characters are picked with the global ``random`` generator, one
    ``randint`` call per character.
    """
    picked = []
    for _ in range(size):
        position = randint(0, __StrBaseLen)
        picked.append(__StrBase[position])
    return "".join(picked)
def current_time() -> "str":
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    now = localtime()
    stamp = strftime("%Y-%m-%d_%H-%M-%S", now)
    return stamp
def get_ext_name(name: "str") -> "str":
    """Return the lower-cased text after the last ``.`` in ``name``.

    When ``name`` contains no dot, the whole name is returned lower-cased.
    """
    # rpartition yields ("", "", name) when there is no dot, so the tail is
    # the full name in that case.
    _, _, tail = name.rpartition(".")
    return tail.lower()
def is_image_ext(ext: "str") -> bool:
    """Tell whether ``ext`` is one of the accepted file extensions.

    The comparison is exact, so ``ext`` must already be lower-case and
    without a leading dot to match.
    """
    accepted = ext in __AcceptExtNames
    return accepted
def Response(message: "str" = None, data=None):
    """Build the JSON reply for a request.

    With no ``message`` the reply is a success carrying the stock success
    text; any other ``message`` produces a failure reply with that text.
    ``data`` is passed through unchanged in both cases.
    """
    succeeded = message is None
    text = "操作成功" if succeeded else message
    return jsonify(success=succeeded, message=text, data=data)
def _rec(img, which: "int" = 0) -> "tuple[list, tuple]":
    """Run OCR on one image and return ``(result, image_shape)``.

    ``which`` selects the engine; it is wrapped with a modulo so any
    non-negative index maps onto the available engines.
    """
    engine = __Engines[which % __EngineNum]
    # Only the first element of the engine output is kept -- presumably the
    # result for the single image passed in; confirm against PaddleOCR docs.
    first_result = engine.ocr(img)[0]
    return first_result, img.shape
def rec_multi(images: "list[np.ndarray]") -> "list[tuple]":  # list[_rec]
    """Run OCR on several images concurrently.

    Args:
        images: Decoded images; image ``i`` is handed to engine
            ``i % __EngineNum``.

    Returns:
        One ``_rec`` result per input image, in input order.

    Raises:
        Whatever ``_rec`` raised for the first failing image.
    """
    # The context manager shuts the pool down when the call finishes, so a
    # new set of worker threads is not left behind on every request.
    # NOTE(review): with more images than engines, two tasks mapped to the
    # same engine can overlap -- confirm the engine tolerates concurrent use.
    with ThreadPoolExecutor(__EngineNum) as pool:
        tasks = [pool.submit(_rec, one, i) for i, one in enumerate(images)]
        return [task.result() for task in tasks]
def recognize(content: "bytes") -> "tuple[list, tuple]":
    """Decode an encoded image and run OCR on it with the default engine.

    Args:
        content: Raw bytes of an encoded image file.

    Returns:
        The ``(result, image_shape)`` pair produced by ``_rec``.
    """
    # np.frombuffer is the documented replacement for the deprecated binary
    # mode of np.fromstring; it views the bytes without copying them.
    raw = np.frombuffer(content, np.uint8)
    # Flag 1 requests a 3-channel colour decode.
    # NOTE(review): imdecode yields None for undecodable data, which then
    # fails inside _rec -- callers may want to validate uploads first.
    img = cv2.imdecode(raw, 1)  # noqa
    return _rec(img)
def draw_img(shape: "tuple", data: "list[dict]", path: "str", drop: "float" = 0.5):
    """Render recognised text boxes onto a white canvas and save it.

    Args:
        shape: Shape of the canvas array; ``shape[0]`` and ``shape[1]`` are
            used as height and width.
        data: Recognition entries, each with the keys ``"rate"``, ``"pos"``
            and ``"word"``.
        path: Destination file for the rendered image.
        drop: Entries whose ``"rate"`` is below this value are skipped.
    """
    from random import Random

    # A private generator keeps the box colours reproducible (same sequence
    # as seeding with 0) without reseeding the process-wide generator that
    # rand_str() relies on for unpredictable names.
    rng = Random(0)
    canvas = np.ones(shape, dtype=np.uint8) * 255
    for one in data:
        if one["rate"] < drop:
            continue
        color = (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))
        # Presumably returns an image of the given (width, height) with the
        # word drawn inside the box -- confirm against the PaddleOCR helper.
        text = draw_box_txt_fine(
            (shape[1], shape[0]), one["pos"], one["word"], font_path="static/ppocr/simfang.ttf"
        )
        pts = np.array(one["pos"], np.int32).reshape((-1, 1, 2))
        cv2.polylines(text, [pts], True, color, 1)  # noqa
        # The canvas starts all-255, so AND-ing accumulates each layer's
        # non-white pixels onto it.
        canvas = cv2.bitwise_and(canvas, text)  # noqa
    cv2.imwrite(path, np.array(canvas))  # noqa
def json_all(data: "dict or list") -> "bool":
    """Check that a nested list/dict structure contains no empty string.

    List items and dict values are inspected recursively; values that are
    neither strings nor lists/dicts are ignored. Dict keys are not checked.

    Raises:
        TypeError: If ``data`` itself is neither a list nor a dict.
    """
    if isinstance(data, list):
        children = data
    elif isinstance(data, dict):
        children = data.values()
    else:
        raise TypeError(f"except node type are: [list, dict], but got a {type(data)} instead.")

    for child in children:
        if isinstance(child, str):
            if not child:
                return False
        elif isinstance(child, (list, dict)) and not json_all(child):
            return False
    return True
def str_include(str_long: "str", str_short: "str") -> "bool":
    """Tell whether every character of ``str_short`` occurs in ``str_long``.

    Order and repetition are not considered; an empty ``str_short`` always
    yields ``True``.
    """
    return all(ch in str_long for ch in str_short)
def save_img(filename: "str", content: "Union[bytes, np.ndarray]"):
    """Write an image to ``filename``.

    A NumPy array is encoded and written by OpenCV, and the value returned
    by ``cv2.imwrite`` is passed back. Any other content is written to the
    file verbatim in binary mode, and ``None`` is returned.
    """
    if not isinstance(content, np.ndarray):
        with open(filename, "wb") as sink:
            sink.write(content)
        return None
    return cv2.imwrite(filename, content)  # noqa
def rotate(content: "bytes") -> "list[np.ndarray]":
    """Decode an encoded image and return it in four orientations.

    Args:
        content: Raw bytes of an encoded image file.

    Returns:
        The decoded image followed by the results of ``np.rot90`` applied
        one, two and three times.
    """
    # np.frombuffer is the documented replacement for the deprecated binary
    # mode of np.fromstring; it views the bytes without copying them.
    raw = np.frombuffer(content, np.uint8)
    # Flag 1 requests a 3-channel colour decode.
    img = cv2.imdecode(raw, 1)  # noqa
    return [img, np.rot90(img), np.rot90(img, 2), np.rot90(img, 3)]
|