# core.py (3.0 KB)

  1. from .parts import *
  2. from .utility import *
  3. from time import time
  4. from numpy import ndarray
  5. from copy import deepcopy
  6. from .argument import ArgType
  7. from utils.logger import Logger
  8. from threading import Lock
  9. from concurrent.futures import ThreadPoolExecutor
  10. __all__ = ["Engine", "HuiMvOCR"]
  11. class Engine:
  12. __worker_count = 1
  13. def __init__(self, args: "ArgType"):
  14. self.det = Detector(args)
  15. self.cls = Classifier(args)
  16. self.rec = Recognizer(args)
  17. self.det_box_type = args.det_box_type
  18. self.drop_score = args.drop_score
  19. self.crop_image_res_index = 0
  20. def __call__(self, img, cls: "bool" = True, use_space: "bool" = True):
  21. ori_im = img.copy()
  22. dt_boxes = self.det(img)
  23. if dt_boxes is None:
  24. return None
  25. dt_boxes = sorted_boxes(dt_boxes)
  26. size = len(dt_boxes)
  27. img_crop_list = [...] * size
  28. for i in range(size):
  29. box = deepcopy(dt_boxes[i])
  30. if self.det_box_type == "quad":
  31. img_crop_list[i] = get_rotate_crop_image(ori_im, box)
  32. else:
  33. img_crop_list[i] = get_min_area_rect_crop(ori_im, box)
  34. if cls:
  35. img_crop_list = self.cls(img_crop_list)
  36. rec_res = self.rec(img_crop_list, use_space=use_space)
  37. filter_rec_res = []
  38. for text, score in rec_res:
  39. if score > self.drop_score:
  40. filter_rec_res.append((text, score))
  41. return filter_rec_res
  42. class Handler:
  43. def __init__(self, args: "ArgType", eid: "int" = -1):
  44. start = time()
  45. self.engine = Engine(args=args)
  46. self.tasks = 0
  47. Logger.info(f"Engine[{eid}] initialized in {time() - start}s")
  48. self.__lock = Lock()
  49. self.__eid = eid
  50. def __call__(self, img: "ndarray", cls: "bool" = False, use_space: "bool" = False):
  51. start = time()
  52. self.tasks += 1
  53. self.__lock.acquire()
  54. res = self.engine(img, cls, use_space)
  55. self.tasks -= 1
  56. self.__lock.release()
  57. Logger.info(f"Engine[{self.__eid}] finished a task in {time() - start}s")
  58. return res
  59. class HuiMvOCR:
  60. def __init__(self, args: "ArgType"):
  61. self.loop = range(args.workers)
  62. self.loop_except = range(1, args.workers)
  63. self.handlers = [Handler(args, i) for i in self.loop]
  64. self.pool = ThreadPoolExecutor(args.workers)
  65. def rec_one(self, img: "ndarray", cls: "bool" = True, use_space: "bool" = True):
  66. index = 0
  67. for cur in self.loop_except:
  68. if self.handlers[index].tasks == 0:
  69. break
  70. if self.handlers[cur].tasks < self.handlers[index].tasks:
  71. index = cur
  72. return self.handlers[index](img, cls, use_space)
  73. def rec_multi(self, images: "list[ndarray]", cls: "bool" = False, use_space: "bool" = False):
  74. tasks = [
  75. self.pool.submit(self.rec_one, img, cls, use_space)
  76. for img in images
  77. ]
  78. return [task.result() for task in tasks]