core.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from .parts import *
  2. from .utility import *
  3. from copy import deepcopy
  4. from .argument import ArgType
  5. from concurrent.futures import ThreadPoolExecutor
  6. __all__ = ["HuiMvOcr"]
  7. class HuiMvOcr:
  8. __worker_count = 1
  9. def __init__(self, args: "ArgType"):
  10. self.det = Detector(args)
  11. self.cls = Classifier(args)
  12. self.rec = Recognizer(args)
  13. self.det_box_type = args.det_box_type
  14. self.drop_score = args.drop_score
  15. self.crop_image_res_index = 0
  16. def ocr_one(self, img, cls: "bool" = False, use_space: "bool" = True):
  17. ori_im = img.copy()
  18. dt_boxes = self.det(img)
  19. if dt_boxes is None:
  20. return None
  21. dt_boxes = sorted_boxes(dt_boxes)
  22. size = len(dt_boxes)
  23. img_crop_list = [...] * size
  24. for i in range(size):
  25. box = deepcopy(dt_boxes[i])
  26. if self.det_box_type == "quad":
  27. img_crop_list[i] = get_rotate_crop_image(ori_im, box)
  28. else:
  29. img_crop_list[i] = get_min_area_rect_crop(ori_im, box)
  30. if cls:
  31. img_crop_list = self.cls(img_crop_list)
  32. rec_res = self.rec(img_crop_list, use_space=use_space)
  33. filter_rec_res = []
  34. for text, score in rec_res:
  35. if score > self.drop_score:
  36. filter_rec_res.append((text, score))
  37. return filter_rec_res
  38. def ocr_multi(self, img_list, cls: "bool" = False, use_space: "bool" = True):
  39. pool = ThreadPoolExecutor(HuiMvOcr.__worker_count)
  40. loop = range(len(img_list))
  41. tasks = [pool.submit(self.ocr_one, img_list[i], cls, use_space) for i in loop]
  42. return [tasks[i].result() for i in loop]