123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- import numpy as np
- from .utils import *
- from hmOCR.argument import Args
class Detector:
    """Text detector: pre-process an image, run the "det" predictor, post-process to boxes.

    All pipeline components are built in ``__init__`` from the ``det_*`` fields of
    ``args``; calling the instance on an image returns the filtered detection boxes.
    """

    def __init__(self, args: "Args"):
        """Build the pre-processing ops, the DB post-processor and the predictor.

        Args:
            args: Configuration object. The ``det_*`` fields read below select the
                resize limits, DB post-processing thresholds and the box type.
        """
        self.det_algorithm = args.det_algorithm  # DB
        self.det_box_type = args.det_box_type
        pre_process_list = [
            {
                "DetResizeForTest": {
                    "limit_side_len": args.det_limit_side_len,
                    "limit_type": args.det_limit_type,
                }
            },
            {
                # Per-channel mean/std equal the common ImageNet statistics;
                # presumably what the detection model was trained with.
                "NormalizeImage": {
                    "std": [0.229, 0.224, 0.225],
                    "mean": [0.485, 0.456, 0.406],
                    "scale": 1 / 255,
                    "order": "hwc",
                }
            },
            {"ToCHWImage": None},
            # Only these two entries are kept; __call__ unpacks them in this order.
            {"KeepKeys": {"keep_keys": ["image", "shape"]}},
        ]
        postprocess_params = {
            "name": "DBPostProcess",
            "thresh": args.det_db_thresh,
            "box_thresh": args.det_db_box_thresh,
            "max_candidates": 1000,
            "unclip_ratio": args.det_db_unclip_ratio,
            "use_dilation": args.det_use_dilation,
            "score_mode": args.det_db_score_mode,
            "box_type": args.det_box_type,
        }
        self.pre_operator = create_operators(pre_process_list)
        self.post_operator = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = create_predictor(args, "det")

    @staticmethod
    def order_points_clockwise(pts):
        """Return the four points of ``pts`` ordered TL, TR, BR, BL.

        Assumes image coordinates (x to the right, y downward) and a (4, 2) input.
        The point with the smallest x + y becomes top-left and the largest x + y
        bottom-right. Of the remaining two, the one with the smallest y - x is
        top-right and the other bottom-left.

        Returns:
            A new float32 array of shape (4, 2); ``pts`` is not modified.
        """
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(tmp, axis=1)  # y - x for each remaining point
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    @staticmethod
    def clip_det_res(points, img_height, img_width):
        """Clamp points into the image and floor them to whole pixels, in place.

        Args:
            points: ndarray of shape (N, 2) holding (x, y) pairs; modified in place.
            img_height: Image height; y is clamped to [0, img_height - 1].
            img_width: Image width; x is clamped to [0, img_width - 1].

        Returns:
            ``points`` itself, after modification.
        """
        # Values are clamped to >= 0 first, so floor() matches int() truncation.
        points[:, 0] = np.floor(np.clip(points[:, 0], 0, img_width - 1))
        points[:, 1] = np.floor(np.clip(points[:, 1], 0, img_height - 1))
        return points

    @staticmethod
    def _stack_boxes(boxes):
        """Stack boxes into one ndarray, or a 1-D object ndarray when their shapes differ."""
        if len({np.shape(b) for b in boxes}) <= 1:
            return np.array(boxes)
        # NumPy >= 1.24 refuses to build ragged arrays implicitly; build it explicitly.
        stacked = np.empty(len(boxes), dtype=object)
        for i, box in enumerate(boxes):
            stacked[i] = box
        return stacked

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and size-filter quadrilateral boxes.

        Each box is reordered with ``order_points_clockwise`` and clipped to the
        image. It is dropped when either side (TL->TR width or TL->BL height),
        truncated to an int, is 3 pixels or less.

        Args:
            dt_boxes: Iterable of 4-point boxes (lists or arrays).
            image_shape: Shape of the original image; only (height, width) are used.

        Returns:
            ndarray of the surviving boxes with shape (K, 4, 2), or ``np.array([])``
            when nothing survives.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(np.asarray(box))
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        return np.array(dt_boxes_new)

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip polygon boxes to the image without reordering or filtering them.

        Args:
            dt_boxes: Iterable of polygons, each an (N_i, 2) list or array. N_i may
                differ between polygons. Inputs are copied, not mutated.
            image_shape: Shape of the original image; only (height, width) are used.

        Returns:
            A regular ndarray when all polygons share one shape; otherwise a 1-D
            object ndarray holding one array per polygon.
        """
        img_height, img_width = image_shape[0:2]
        # np.array copies, so clip_det_res's in-place write leaves the caller's data intact.
        dt_boxes_new = [
            self.clip_det_res(np.array(box), img_height, img_width) for box in dt_boxes
        ]
        return self._stack_boxes(dt_boxes_new)

    def __call__(self, img):
        """Detect text regions in ``img``.

        Args:
            img: Input image array. Presumably HWC, matching the "hwc" order given
                to NormalizeImage. TODO: confirm the colour order with callers.

        Returns:
            The boxes from ``filter_tag_det_res_only_clip`` when ``det_box_type`` is
            "poly", otherwise from ``filter_tag_det_res``. Returns ``None`` when
            pre-processing yields no data or no image.
        """
        # Only the original shape is needed for clipping; no need to copy the image.
        ori_shape = img.shape
        data = transform({"image": img}, self.pre_operator)
        # NOTE(review): transform presumably returns None when an op drops the sample;
        # check before unpacking so that case returns None instead of raising TypeError.
        if data is None:
            return None
        img, shape_list = data
        if img is None:
            return None
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        # expand_dims returns a view; presumably copied so copy_from_cpu gets its own buffer.
        img = img.copy()
        self.input_tensor.copy_from_cpu(img)
        self.predictor.run()
        outputs = [out.copy_to_cpu() for out in self.output_tensors]
        preds = {"maps": outputs[0]}
        post_result = self.post_operator(preds, shape_list)
        dt_boxes = post_result[0]["points"]
        if self.det_box_type == "poly":
            return self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        return self.filter_tag_det_res(dt_boxes, ori_shape)
|