# core.py (3.2 KB)

import logging

# Wildcard imports stay ahead of the explicit imports below (as in the original
# ordering), so the explicit names always take precedence over anything the
# wildcards re-export.
from .parts import *
from .utility import *

from copy import deepcopy
from queue import Queue
from threading import Event, Lock, Thread
from time import sleep

from numpy import ndarray

from .argument import ArgType
# Names exported by ``from <this module> import *``.
__all__ = ["Engine", "HuiMvOCR"]
# Module logger; HuiMvOCR's worker threads log through it.
logger = logging.getLogger("hm-ocr")
# NOTE(review): setting the level at import time overwrites any level the
# application assigned to the "hm-ocr" logger before importing this module —
# confirm this is intended for a library module.
logger.setLevel(logging.INFO)
  12. class Engine:
  13. __worker_count = 1
  14. def __init__(self, args: "ArgType"):
  15. self.det = Detector(args)
  16. self.cls = Classifier(args)
  17. self.rec = Recognizer(args)
  18. self.det_box_type = args.det_box_type
  19. self.drop_score = args.drop_score
  20. self.crop_image_res_index = 0
  21. def __call__(self, img, cls: "bool" = True, use_space: "bool" = True):
  22. ori_im = img.copy()
  23. dt_boxes = self.det(img)
  24. if dt_boxes is None:
  25. return None
  26. dt_boxes = sorted_boxes(dt_boxes)
  27. size = len(dt_boxes)
  28. img_crop_list = [...] * size
  29. for i in range(size):
  30. box = deepcopy(dt_boxes[i])
  31. if self.det_box_type == "quad":
  32. img_crop_list[i] = get_rotate_crop_image(ori_im, box)
  33. else:
  34. img_crop_list[i] = get_min_area_rect_crop(ori_im, box)
  35. if cls:
  36. img_crop_list = self.cls(img_crop_list)
  37. rec_res = self.rec(img_crop_list, use_space=use_space)
  38. filter_rec_res = []
  39. for text, score in rec_res:
  40. if score > self.drop_score:
  41. filter_rec_res.append((text, score))
  42. return filter_rec_res
  43. class HuiMvOCR:
  44. __lock = Lock()
  45. __tasks = [] # item: [img: "ndarray", ocr_args: "dict", callback: "fn", callback_args: "dict"]
  46. def __init__(self, args: "ArgType"):
  47. self.interval = args.interval
  48. for i in range(args.workers):
  49. Thread(target=self.__processor, args=(Engine(args), i), daemon=True).start()
  50. @staticmethod
  51. def __processor(ocr: "Engine", eid: "int"):
  52. logger.info(f"================ Engine[{eid}] initialized ================")
  53. while True:
  54. if HuiMvOCR.__tasks:
  55. HuiMvOCR.__lock.acquire()
  56. img, ocr_args, callback, callback_args = HuiMvOCR.__tasks.pop(0)
  57. HuiMvOCR.__lock.release()
  58. res = ocr(img)
  59. callback(res, **callback_args)
  60. sleep(0.1)
  61. def rec_one(self, img: "ndarray", cls: "bool" = True, use_space: "bool" = True):
  62. def callback(res):
  63. foo[1] = res
  64. foo[0] = 1
  65. foo = [0, None] # finish_count, result
  66. args = {"cls": cls, "use_space": use_space}
  67. HuiMvOCR.__tasks.append([img, args, callback, {}])
  68. while foo[0] < 1:
  69. sleep(self.interval)
  70. return foo[1]
  71. def rec_multi(self, images: "list[ndarray]", cls: "bool" = False, use_space: "bool" = False):
  72. def callback(res, index):
  73. foo[1][index] = res
  74. foo[0] += 1
  75. size, args = len(images), {"cls": cls, "use_space": use_space}
  76. foo = [0, [...] * size] # finish_count, result
  77. for i in range(size):
  78. HuiMvOCR.__tasks.append([images[i], args, callback, {"index": i}])
  79. while foo[0] < size:
  80. sleep(self.interval)
  81. return foo[1]