util.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import cv2
  2. import numpy as np
  3. from typing import Union
  4. from flask import jsonify
  5. from paddleocr import PaddleOCR
  6. from random import randint, seed
  7. from time import localtime, strftime
  8. from concurrent.futures import ThreadPoolExecutor
  9. from paddleocr.tools.infer.utility import draw_box_txt_fine
  10. __all__ = [
  11. "Args", "Response", "rand_str", "current_time", "get_ext_name", "is_image_ext", "recognize", "draw_img",
  12. "json_all", "str_include", "rec_multi", "save_img", "rotate"
  13. ]
  14. __StrBase = "qwertyuioplkjhgfdsazxcvbnm1234567890ZXCVBNMLKJHGFDSAQWERTYUIOP"
  15. __StrBaseLen = len(__StrBase) - 1
  16. __AcceptExtNames = ["jpg", "jpeg", "bmp", "png", "rgb", "tif", "tiff", "gif", "pdf"]
  17. __EngineNum = 2
  18. __Engines = [PaddleOCR(
  19. use_gpu=False,
  20. enable_mkldnn=True,
  21. det_model_dir="models/det/",
  22. rec_model_dir="models/rec/",
  23. cls_model_dir="models/cls/",
  24. use_angle_cls=True
  25. ) for _ in range(__EngineNum)]
  26. class Args:
  27. def __init__(self, **kwargs):
  28. self.__update(
  29. use_gpu=False, use_xpu=False, use_npu=False, ir_optim=True, use_tensorrt=False,
  30. min_subgraph_size=15, precision="fp32", gpu_mem=500, image_dir=None, page_num=0,
  31. det_algorithm="DB", det_model_dir="models/det/", det_limit_side_len=960, det_limit_type="max",
  32. det_box_type="quad", det_db_thresh=0.3, det_db_box_thresh=0.6, det_db_unclip_ratio=1.5,
  33. max_batch_size=10, use_dilation=False, det_db_score_mode="fast", det_east_score_thresh=0.8,
  34. det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_sast_score_thresh=0.5,
  35. det_sast_nms_thresh=0.2, det_pse_thresh=0, det_pse_box_thresh=0.85, det_pse_min_area=16,
  36. det_pse_scale=1, scales=[8, 16, 32], alpha=1.0, beta=1.0, fourier_degree=5,
  37. rec_algorithm="SVTR_LCNet", rec_model_dir="models/rec/", rec_image_inverse=True,
  38. rec_image_shape="3, 48, 320", rec_batch_num=6, max_text_length=25,
  39. rec_char_dict_path="venv/lib/site-packages/paddleocr/ppocr/utils/ppocr_keys_v1.txt",
  40. use_space_char=True, vis_font_path="static/simfang.ttf", drop_score=0.5,
  41. e2e_algorithm="PGNet", e2e_model_dir=None, e2e_limit_side_len=768, e2e_limit_type="max",
  42. e2e_pgnet_score_thresh=0.5, e2e_char_dict_path="./ppocr/utils/ic15_dict.txt",
  43. e2e_pgnet_valid_set="totaltext", e2e_pgnet_mode="fast", use_angle_cls=True,
  44. cls_model_dir="models/cls/", cls_image_shape="3, 48, 192", label_list=["0", "180"],
  45. cls_batch_num=6, cls_thresh=0.9, enable_mkldnn=True, cpu_threads=10, use_pdserving=False,
  46. sr_model_dir=None, sr_image_shape="3, 32, 128", sr_batch_num=1,
  47. draw_img_save_dir="static/rec_res/", save_crop_res=False, crop_res_save_dir="./output",
  48. use_mp=False, benchmark=False, save_log_path="./log_output/",
  49. show_log=True, use_onnx=False, output="./output", table_max_len=488, table_algorithm="TableAttn",
  50. table_model_dir=None, merge_no_span_structure=True, table_char_dict_path=None,
  51. layout_model_dir=None, layout_dict_path=None, layout_score_threshold=0.5,
  52. layout_nms_threshold=0.5, kie_algorithm="LayoutXLM", ser_model_dir=None, re_model_dir=None,
  53. use_visual_backbone=True, ser_dict_path="../train_data/XFUND/class_list_xfun.txt",
  54. ocr_order_method=None, mode="structure", image_orientation=False, layout=True, table=True,
  55. ocr=True, recovery=False, use_pdf2docx_api=False, lang="ch", det=True, rec=True, type="ocr",
  56. ocr_version="PP-OCRv3", structure_version="PP-StructureV2"
  57. )
  58. self.__update(**kwargs)
  59. def __update(self, **kwargs):
  60. for k, v in kwargs:
  61. self.__dict__[k] = v
  62. def __setattr__(self, key: "str", value):
  63. self.__dict__[key] = value
  64. def __getattribute__(self, key: "str"):
  65. assert key in self.__dict__.keys()
  66. return self.__dict__[key]
  67. def rand_str(size: "int" = 8) -> "str":
  68. return "".join([__StrBase[randint(0, __StrBaseLen)] for _ in range(size)])
  69. def current_time() -> "str":
  70. return strftime("%Y-%m-%d_%H-%M-%S", localtime())
  71. def get_ext_name(name: "str") -> "str":
  72. return name.split(".")[-1].lower()
  73. def is_image_ext(ext: "str") -> bool:
  74. return ext in __AcceptExtNames
  75. def Response(message: "str" = None, data=None):
  76. if message is None:
  77. return jsonify(success=True, message="操作成功", data=data)
  78. return jsonify(success=False, message=message, data=data)
  79. def _rec(img, which: "int" = 0) -> "tuple[list, tuple]":
  80. return __Engines[which % __EngineNum].ocr(img)[0], img.shape
  81. def rec_multi(images: "list[np.ndarray]") -> "list[tuple]": # list[_rec]
  82. pool = ThreadPoolExecutor(__EngineNum)
  83. tasks = [pool.submit(_rec, one, i) for i, one in enumerate(images)]
  84. return [task.result() for task in tasks]
  85. def recognize(content: "str") -> "tuple[list, tuple]":
  86. img = cv2.imdecode(np.fromstring(content, np.uint8), 1) # noqa
  87. return _rec(img)
  88. def draw_img(shape: "tuple", data: "list[dict]", path: "str", drop: "float" = 0.5):
  89. img = np.ones(shape, dtype=np.uint8) * 255
  90. seed(0)
  91. for one in data:
  92. if one["rate"] < drop:
  93. continue
  94. color = (randint(0, 255), randint(0, 255), randint(0, 255))
  95. text = draw_box_txt_fine((shape[1], shape[0]), one["pos"], one["word"], font_path="static/ppocr/simfang.ttf")
  96. pts = np.array(one["pos"], np.int32).reshape((-1, 1, 2))
  97. cv2.polylines(text, [pts], True, color, 1) # noqa
  98. img = cv2.bitwise_and(img, text) # noqa
  99. cv2.imwrite(path, np.array(img)) # noqa
  100. def json_all(data: "dict or list") -> "bool":
  101. if isinstance(data, list):
  102. for item in data:
  103. if isinstance(item, str) and not item:
  104. return False
  105. elif isinstance(item, (list, dict)) and not json_all(item):
  106. return False
  107. return True
  108. elif isinstance(data, dict):
  109. for value in data.values():
  110. if isinstance(value, str) and not value:
  111. return False
  112. elif isinstance(value, (list, dict)) and not json_all(value):
  113. return False
  114. return True
  115. raise TypeError(f"except node type are: [list, dict], but got a {type(data)} instead.")
  116. def str_include(str_long: "str", str_short: "str") -> "bool":
  117. for it in str_short:
  118. if it not in str_long:
  119. return False
  120. return True
  121. def save_img(filename: "str", content: "Union[bytes, np.ndarray]"):
  122. if isinstance(content, np.ndarray):
  123. return cv2.imwrite(filename, content) # noqa
  124. with open(filename, "wb") as fp:
  125. fp.write(content)
  126. fp.close()
  127. def rotate(content: "str") -> "list[np.ndarray]":
  128. img = cv2.imdecode(np.fromstring(content, np.uint8), 1) # noqa
  129. return [img, np.rot90(img), np.rot90(img, 2), np.rot90(img, 3)]