detector.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import numpy as np
  2. from .utils import *
  3. from hmOCR.argument import Args
  class Detector:
      """Text detector: pre-processes an image, runs the ``"det"`` predictor and
      converts its output map into clipped boxes via ``DBPostProcess``."""

      def __init__(self, args: "Args"):
          """Build pre-processing ops, the DB post-processor and the predictor.

          Args:
              args: configuration object; the ``det_*`` attributes read below
                  select resizing limits, DB thresholds and the box type.
          """
          # Stored only; post-processing below is always DBPostProcess.
          self.det_algorithm = args.det_algorithm  # DB
          # "poly" -> clip-only filtering in __call__; otherwise quad filtering.
          self.det_box_type = args.det_box_type
          pre_process_list = [{
              "DetResizeForTest": {
                  "limit_side_len": args.det_limit_side_len,
                  "limit_type": args.det_limit_type,
              }
          }, {
              "NormalizeImage": {
                  # The widely used ImageNet channel statistics.
                  "std": [0.229, 0.224, 0.225],
                  "mean": [0.485, 0.456, 0.406],
                  "scale": 1 / 255,
                  "order": "hwc"
              }
          }, {
              "ToCHWImage": None
          }, {
              # Pre-processing output is reduced to [image, shape], unpacked in __call__.
              "KeepKeys": {
                  "keep_keys": ["image", "shape"]
              }
          }]
          postprocess_params = {
              # max_candidates is hard-coded here; presumably caps candidate regions.
              "name": "DBPostProcess", "thresh": args.det_db_thresh,
              "box_thresh": args.det_db_box_thresh, "max_candidates": 1000,
              "unclip_ratio": args.det_db_unclip_ratio, "use_dilation": args.det_use_dilation,
              "score_mode": args.det_db_score_mode, "box_type": args.det_box_type
          }
          self.pre_operator = create_operators(pre_process_list)
          self.post_operator = build_post_process(postprocess_params)
          self.predictor, self.input_tensor, self.output_tensors = create_predictor(args, "det")

      @staticmethod
      def order_points_clockwise(pts):
          """Reorder four corner points as [min x+y, min y-x, max x+y, max y-x].

          In image coordinates (y pointing down) this is usually top-left,
          top-right, bottom-right, bottom-left.

          Args:
              pts: array of points, expected shape (4, 2) with x in column 0.

          Returns:
              A new float32 array of shape (4, 2).

          NOTE(review): if the two remaining points have equal y-x, argmin and
          argmax both pick index 0, so rect[1] and rect[3] become the same point.
          Strongly rotated (~45 degree) boxes can also come out mis-ordered.
          """
          rect = np.zeros((4, 2), dtype="float32")
          s = pts.sum(axis=1)
          rect[0] = pts[np.argmin(s)]
          rect[2] = pts[np.argmax(s)]
          # The two points that are neither the min-sum nor the max-sum corner.
          tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
          diff = np.diff(np.array(tmp), axis=1)  # y - x for each remaining point
          rect[1] = tmp[np.argmin(diff)]
          rect[3] = tmp[np.argmax(diff)]
          return rect

      @staticmethod
      def clip_det_res(points, img_height, img_width):
          """Clamp points to the image in place and truncate them to integers.

          Column 0 is clamped to ``[0, img_width - 1]`` and column 1 to
          ``[0, img_height - 1]``. The values are then truncated with ``int()``
          and written back into ``points``, which is also returned.
          """
          for pno in range(points.shape[0]):
              points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
              points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
          return points

      @staticmethod
      def _stack_boxes(boxes):
          """Stack per-box point arrays into one ndarray.

          Equal-shaped boxes are stacked exactly as ``np.array`` would. Boxes
          with different point counts (possible for polygon output) form a
          ragged sequence, for which NumPy >= 1.24 raises ``ValueError``. Such
          boxes are returned as a 1-D object array holding one array per box.
          """
          try:
              return np.array(boxes)
          except ValueError:
              stacked = np.empty(len(boxes), dtype=object)
              for i, box in enumerate(boxes):
                  stacked[i] = box
              return stacked

      def filter_tag_det_res(self, dt_boxes, image_shape):
          """Order, clip and size-filter quadrilateral boxes.

          Each box is reordered with ``order_points_clockwise`` and clipped to
          ``image_shape[:2]`` (height, width). It is dropped if either the 0->1
          or the 0->3 edge is 3 pixels or shorter.

          Returns:
              ``np.ndarray`` of the kept boxes (shape ``(0,)`` when none remain).
          """
          img_height, img_width = image_shape[0:2]
          kept = []
          for box in dt_boxes:
              # Lists become arrays; an ndarray box is used as-is.
              box = self.order_points_clockwise(np.asarray(box))
              box = self.clip_det_res(box, img_height, img_width)
              rect_width = int(np.linalg.norm(box[0] - box[1]))
              rect_height = int(np.linalg.norm(box[0] - box[3]))
              if rect_width <= 3 or rect_height <= 3:
                  continue
              kept.append(box)
          return self._stack_boxes(kept)

      def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
          """Clip polygon boxes to ``image_shape[:2]`` (height, width) without filtering.

          Array boxes are clipped in place. Polygons with different point counts
          are returned as a 1-D object array instead of raising on newer NumPy.
          """
          img_height, img_width = image_shape[0:2]
          clipped = [
              self.clip_det_res(np.asarray(box), img_height, img_width)
              for box in dt_boxes
          ]
          return self._stack_boxes(clipped)

      def __call__(self, img):
          """Detect text boxes in ``img``.

          Args:
              img: input image array; its first two shape entries are used as
                  (height, width) when clipping the detected boxes.

          Returns:
              ``np.ndarray`` of boxes clipped to the bounds of ``img``, or
              ``None`` when pre-processing yields no image.
          """
          # Only the input shape is needed later, so record it instead of
          # copying the whole image.
          ori_shape = img.shape
          data = transform({"image": img}, self.pre_operator)
          if data is None:
              # Presumably a pre-processing op rejected the image; unpacking
              # None would otherwise raise TypeError.
              return None
          img, shape_list = data
          if img is None:
              return None
          # Add a leading batch axis of size 1.
          img = np.expand_dims(img, axis=0)
          shape_list = np.expand_dims(shape_list, axis=0)
          # NOTE(review): ToCHWImage presumably leaves a transposed (non-contiguous)
          # view; copy() hands the predictor a contiguous buffer.
          img = img.copy()
          self.input_tensor.copy_from_cpu(img)
          self.predictor.run()
          outputs = [out.copy_to_cpu() for out in self.output_tensors]
          preds = {"maps": outputs[0]}
          post_result = self.post_operator(preds, shape_list)
          dt_boxes = post_result[0]["points"]
          if self.det_box_type == "poly":
              return self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
          return self.filter_tag_det_res(dt_boxes, ori_shape)