util.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import cv2
  2. import numpy as np
  3. from flask import jsonify
  4. from paddleocr import PaddleOCR
  5. from random import randint, seed
  6. from time import localtime, strftime
  7. from paddleocr.tools.infer.utility import draw_box_txt_fine
  8. __all__ = [
  9. "Args", "Response", "rand_str", "current_time", "get_ext_name", "is_image_ext", "recognize", "draw_img"
  10. ]
  11. __StrBase = "qwertyuioplkjhgfdsazxcvbnm1234567890ZXCVBNMLKJHGFDSAQWERTYUIOP"
  12. __StrBaseLen = len(__StrBase) - 1
  13. __AcceptExtNames = ["jpg", "jpeg", "bmp", "png", "rgb", "tif", "tiff", "gif", "pdf"]
  14. __OcrEngine = PaddleOCR(
  15. use_gpu=False,
  16. enable_mkldnn=True,
  17. det_model_dir="models/det/",
  18. rec_model_dir="models/rec/",
  19. cls_model_dir="models/cls/",
  20. use_angle_cls=True,
  21. use_space_char=True
  22. )
  23. class Args:
  24. def __init__(self, **kwargs):
  25. self.__update(
  26. use_gpu=False, use_xpu=False, use_npu=False, ir_optim=True, use_tensorrt=False,
  27. min_subgraph_size=15, precision="fp32", gpu_mem=500, image_dir=None, page_num=0,
  28. det_algorithm="DB", det_model_dir="models/det/", det_limit_side_len=960, det_limit_type="max",
  29. det_box_type="quad", det_db_thresh=0.3, det_db_box_thresh=0.6, det_db_unclip_ratio=1.5,
  30. max_batch_size=10, use_dilation=False, det_db_score_mode="fast", det_east_score_thresh=0.8,
  31. det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_sast_score_thresh=0.5,
  32. det_sast_nms_thresh=0.2, det_pse_thresh=0, det_pse_box_thresh=0.85, det_pse_min_area=16,
  33. det_pse_scale=1, scales=[8, 16, 32], alpha=1.0, beta=1.0, fourier_degree=5,
  34. rec_algorithm="SVTR_LCNet", rec_model_dir="models/rec/", rec_image_inverse=True,
  35. rec_image_shape="3, 48, 320", rec_batch_num=6, max_text_length=25,
  36. rec_char_dict_path="venv/lib/site-packages/paddleocr/ppocr/utils/ppocr_keys_v1.txt",
  37. use_space_char=True, vis_font_path="static/simfang.ttf", drop_score=0.5,
  38. e2e_algorithm="PGNet", e2e_model_dir=None, e2e_limit_side_len=768, e2e_limit_type="max",
  39. e2e_pgnet_score_thresh=0.5, e2e_char_dict_path="./ppocr/utils/ic15_dict.txt",
  40. e2e_pgnet_valid_set="totaltext", e2e_pgnet_mode="fast", use_angle_cls=True,
  41. cls_model_dir="models/cls/", cls_image_shape="3, 48, 192", label_list=["0", "180"],
  42. cls_batch_num=6, cls_thresh=0.9, enable_mkldnn=True, cpu_threads=10, use_pdserving=False,
  43. sr_model_dir=None, sr_image_shape="3, 32, 128", sr_batch_num=1,
  44. draw_img_save_dir="static/rec_res/", save_crop_res=False, crop_res_save_dir="./output",
  45. use_mp=False, benchmark=False, save_log_path="./log_output/",
  46. show_log=True, use_onnx=False, output="./output", table_max_len=488, table_algorithm="TableAttn",
  47. table_model_dir=None, merge_no_span_structure=True, table_char_dict_path=None,
  48. layout_model_dir=None, layout_dict_path=None, layout_score_threshold=0.5,
  49. layout_nms_threshold=0.5, kie_algorithm="LayoutXLM", ser_model_dir=None, re_model_dir=None,
  50. use_visual_backbone=True, ser_dict_path="../train_data/XFUND/class_list_xfun.txt",
  51. ocr_order_method=None, mode="structure", image_orientation=False, layout=True, table=True,
  52. ocr=True, recovery=False, use_pdf2docx_api=False, lang="ch", det=True, rec=True, type="ocr",
  53. ocr_version="PP-OCRv3", structure_version="PP-StructureV2"
  54. )
  55. self.__update(**kwargs)
  56. def __update(self, **kwargs):
  57. for k, v in kwargs:
  58. self.__dict__[k] = v
  59. def __setattr__(self, key: "str", value):
  60. self.__dict__[key] = value
  61. def __getattribute__(self, key: "str"):
  62. assert key in self.__dict__.keys()
  63. return self.__dict__[key]
  64. def rand_str(size: "int" = 8) -> "str":
  65. return "".join([__StrBase[randint(0, __StrBaseLen)] for _ in range(size)])
  66. def current_time() -> "str":
  67. return strftime("%Y-%m-%d_%H-%M-%S", localtime())
  68. def get_ext_name(name: "str") -> "str":
  69. return name.split(".")[-1].lower()
  70. def is_image_ext(ext: "str") -> bool:
  71. return ext in __AcceptExtNames
  72. def Response(message: "str" = None, data=None):
  73. if message is None:
  74. return jsonify(success=True, message="操作成功", data=data)
  75. return jsonify(success=False, message=message, data=data)
  76. def recognize(content: "str") -> "tuple[list, tuple]":
  77. img = cv2.imdecode(np.fromstring(content, np.uint8), 1) # noqa
  78. return __OcrEngine.ocr(img)[0], img.shape
  79. def draw_img(shape: "tuple", data: "list[dict]", path: "str", drop: "float" = 0.5):
  80. img = np.ones(shape, dtype=np.uint8) * 255
  81. seed(0)
  82. for one in data:
  83. if one["rate"] < drop:
  84. continue
  85. color = (randint(0, 255), randint(0, 255), randint(0, 255))
  86. text = draw_box_txt_fine((shape[1], shape[0]), one["pos"], one["word"], font_path="static/simfang.ttf")
  87. pts = np.array(one["pos"], np.int32).reshape((-1, 1, 2))
  88. cv2.polylines(text, [pts], True, color, 1) # noqa
  89. img = cv2.bitwise_and(img, text) # noqa
  90. cv2.imwrite(path, np.array(img)) # noqa