util.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
import secrets
from random import Random, randint, seed
from time import localtime, strftime

import cv2
import numpy as np
from flask import jsonify
from paddleocr import PaddleOCR
from paddleocr.tools.infer.utility import draw_box_txt_fine
  8. __all__ = [
  9. "Args", "Response", "rand_str", "current_time", "get_ext_name", "is_image_ext", "recognize", "draw_img"
  10. ]
  11. __StrBase = "qwertyuioplkjhgfdsazxcvbnm1234567890ZXCVBNMLKJHGFDSAQWERTYUIOP"
  12. __StrBaseLen = len(__StrBase) - 1
  13. __AcceptExtNames = ["jpg", "jpeg", "bmp", "png", "rgb", "tif", "tiff", "gif", "pdf"]
  14. __OcrEngine = PaddleOCR(
  15. use_gpu=True,
  16. det_model_dir="models/det/",
  17. rec_model_dir="models/rec/",
  18. cls_model_dir="models/cls/",
  19. use_angle_cls=True,
  20. use_space_char=True
  21. )
  22. class Args:
  23. def __init__(self, **kwargs):
  24. self.__update(
  25. use_gpu=False, use_xpu=False, use_npu=False, ir_optim=True, use_tensorrt=False,
  26. min_subgraph_size=15, precision="fp32", gpu_mem=500, image_dir=None, page_num=0,
  27. det_algorithm="DB", det_model_dir="models/det/", det_limit_side_len=960, det_limit_type="max",
  28. det_box_type="quad", det_db_thresh=0.3, det_db_box_thresh=0.6, det_db_unclip_ratio=1.5,
  29. max_batch_size=10, use_dilation=False, det_db_score_mode="fast", det_east_score_thresh=0.8,
  30. det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_sast_score_thresh=0.5,
  31. det_sast_nms_thresh=0.2, det_pse_thresh=0, det_pse_box_thresh=0.85, det_pse_min_area=16,
  32. det_pse_scale=1, scales=[8, 16, 32], alpha=1.0, beta=1.0, fourier_degree=5,
  33. rec_algorithm="SVTR_LCNet", rec_model_dir="models/rec/", rec_image_inverse=True,
  34. rec_image_shape="3, 48, 320", rec_batch_num=6, max_text_length=25,
  35. rec_char_dict_path="venv/lib/site-packages/paddleocr/ppocr/utils/ppocr_keys_v1.txt",
  36. use_space_char=True, vis_font_path="static/simfang.ttf", drop_score=0.5,
  37. e2e_algorithm="PGNet", e2e_model_dir=None, e2e_limit_side_len=768, e2e_limit_type="max",
  38. e2e_pgnet_score_thresh=0.5, e2e_char_dict_path="./ppocr/utils/ic15_dict.txt",
  39. e2e_pgnet_valid_set="totaltext", e2e_pgnet_mode="fast", use_angle_cls=True,
  40. cls_model_dir="models/cls/", cls_image_shape="3, 48, 192", label_list=["0", "180"],
  41. cls_batch_num=6, cls_thresh=0.9, enable_mkldnn=True, cpu_threads=10, use_pdserving=False,
  42. sr_model_dir=None, sr_image_shape="3, 32, 128", sr_batch_num=1,
  43. draw_img_save_dir="static/rec_res/", save_crop_res=False, crop_res_save_dir="./output",
  44. use_mp=False, benchmark=False, save_log_path="./log_output/",
  45. show_log=True, use_onnx=False, output="./output", table_max_len=488, table_algorithm="TableAttn",
  46. table_model_dir=None, merge_no_span_structure=True, table_char_dict_path=None,
  47. layout_model_dir=None, layout_dict_path=None, layout_score_threshold=0.5,
  48. layout_nms_threshold=0.5, kie_algorithm="LayoutXLM", ser_model_dir=None, re_model_dir=None,
  49. use_visual_backbone=True, ser_dict_path="../train_data/XFUND/class_list_xfun.txt",
  50. ocr_order_method=None, mode="structure", image_orientation=False, layout=True, table=True,
  51. ocr=True, recovery=False, use_pdf2docx_api=False, lang="ch", det=True, rec=True, type="ocr",
  52. ocr_version="PP-OCRv3", structure_version="PP-StructureV2"
  53. )
  54. self.__update(**kwargs)
  55. def __update(self, **kwargs):
  56. for k, v in kwargs:
  57. self.__dict__[k] = v
  58. def __setattr__(self, key: "str", value):
  59. self.__dict__[key] = value
  60. def __getattribute__(self, key: "str"):
  61. assert key in self.__dict__.keys()
  62. return self.__dict__[key]
  63. def rand_str(size: "int" = 8) -> "str":
  64. return "".join([__StrBase[randint(0, __StrBaseLen)] for _ in range(size)])
  65. def current_time() -> "str":
  66. return strftime("%Y-%m-%d_%H-%M-%S", localtime())
  67. def get_ext_name(name: "str") -> "str":
  68. return name.split(".")[-1].lower()
  69. def is_image_ext(ext: "str") -> bool:
  70. return ext in __AcceptExtNames
  71. def Response(message: "str" = None, data=None):
  72. if message is None:
  73. return jsonify(success=True, message="操作成功", data=data)
  74. return jsonify(success=False, message=message, data=data)
  75. def recognize(content: "str") -> "tuple[list, tuple]":
  76. img = cv2.imdecode(np.fromstring(content, np.uint8), 1) # noqa
  77. return __OcrEngine.ocr(img)[0], img.shape
  78. def draw_img(shape: "tuple", data: "list[dict]", path: "str", drop: "float" = 0.5):
  79. img = np.ones(shape, dtype=np.uint8) * 255
  80. seed(0)
  81. for one in data:
  82. if one["rate"] < drop:
  83. continue
  84. color = (randint(0, 255), randint(0, 255), randint(0, 255))
  85. text = draw_box_txt_fine((shape[1], shape[0]), one["pos"], one["word"], font_path="static/simfang.ttf")
  86. pts = np.array(one["pos"], np.int32).reshape((-1, 1, 2))
  87. cv2.polylines(text, [pts], True, color, 1) # noqa
  88. img = cv2.bitwise_and(img, text) # noqa
  89. cv2.imwrite(path, np.array(img)) # noqa