1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- from .parts import *
- from .utility import *
- from time import time
- from numpy import ndarray
- from copy import deepcopy
- from .argument import ArgType
- from utils.logger import Logger
- from threading import Lock
- from concurrent.futures import ThreadPoolExecutor
- __all__ = ["Engine", "HuiMvOCR"]
class Engine:
    """Single OCR pipeline: text detection, optional angle classification, recognition.

    ``Detector``, ``Classifier``, ``Recognizer``, ``sorted_boxes`` and the crop
    helpers come from the star-imported ``.parts`` / ``.utility`` modules.
    """

    # NOTE(review): not referenced anywhere in this file; kept for compatibility.
    __worker_count = 1

    def __init__(self, args: "ArgType"):
        self.det = Detector(args)
        self.cls = Classifier(args)
        self.rec = Recognizer(args)
        # "quad" selects get_rotate_crop_image; any other value selects
        # get_min_area_rect_crop (see __call__).
        self.det_box_type = args.det_box_type
        # Recognition results whose score is not strictly above this are dropped.
        self.drop_score = args.drop_score
        # NOTE(review): not read anywhere in this file; kept for compatibility.
        self.crop_image_res_index = 0

    def __call__(self, img, cls: "bool" = True, use_space: "bool" = True):
        """Detect, optionally classify, and recognize text regions in *img*.

        Args:
            img: input image; must support ``.copy()`` (presumably a numpy
                ndarray -- confirm against Detector).
            cls: when true, the cropped regions are passed through the
                classifier before recognition.
            use_space: forwarded unchanged to the recognizer.

        Returns:
            ``None`` when the detector returns ``None``; otherwise a list of
            ``(text, score)`` tuples with ``score > self.drop_score``, in the
            order given by ``sorted_boxes``.
        """
        # Copy taken before detection; all crops are cut from this copy.
        ori_im = img.copy()
        dt_boxes = self.det(img)
        if dt_boxes is None:
            return None
        dt_boxes = sorted_boxes(dt_boxes)

        # The crop strategy is fixed per engine, so choose it once, not per box.
        crop = (
            get_rotate_crop_image
            if self.det_box_type == "quad"
            else get_min_area_rect_crop
        )
        # Each box is deep-copied before cropping, presumably so the crop helper
        # cannot mutate dt_boxes in place.
        img_crop_list = [crop(ori_im, deepcopy(box)) for box in dt_boxes]

        if cls:
            img_crop_list = self.cls(img_crop_list)
        rec_res = self.rec(img_crop_list, use_space=use_space)
        return [(text, score) for text, score in rec_res if score > self.drop_score]
class Handler:
    """Serialize access to one Engine and count calls pending on it.

    ``tasks`` counts calls that have entered ``__call__`` and not yet finished,
    including those still waiting for the engine lock. HuiMvOCR reads it to
    pick the least-loaded handler.
    """

    def __init__(self, args: "ArgType", eid: "int" = -1):
        start = time()
        self.engine = Engine(args=args)
        # In-flight calls (waiting or running); only modified under __tasks_lock.
        self.tasks = 0
        Logger.info(f"Engine[{eid}] initialized in {time() - start}s")
        # Held while the engine runs so only one thread uses it at a time.
        self.__lock = Lock()
        # Separate lock for the counter: threads update it while NOT holding
        # __lock, and `+=` on an attribute is not atomic across threads.
        self.__tasks_lock = Lock()
        self.__eid = eid

    def __call__(self, img: "ndarray", cls: "bool" = False, use_space: "bool" = False):
        """Run the engine on *img*, blocking until the engine is free.

        Returns whatever ``Engine.__call__`` returns. If the engine raises, the
        engine lock is released and the pending-task counter restored before
        the exception propagates, so the handler stays usable.
        """
        start = time()
        with self.__tasks_lock:
            self.tasks += 1
        try:
            with self.__lock:
                res = self.engine(img, cls, use_space)
        finally:
            with self.__tasks_lock:
                self.tasks -= 1
        Logger.info(f"Engine[{self.__eid}] finished a task in {time() - start}s")
        return res
class HuiMvOCR:
    """Pool of Handler-wrapped engines with least-loaded dispatch.

    One Handler (hence one Engine) is created per worker, and a thread pool of
    the same size drives ``rec_multi``.
    """

    def __init__(self, args: "ArgType"):
        worker_total = args.workers
        self.loop = range(worker_total)
        # Every handler index except the first; handler 0 is the initial
        # candidate when scanning for the least-loaded handler.
        self.loop_except = range(1, worker_total)
        self.handlers = [Handler(args, eid) for eid in self.loop]
        self.pool = ThreadPoolExecutor(worker_total)

    def rec_one(self, img: "ndarray", cls: "bool" = True, use_space: "bool" = True):
        """Recognize *img* on the handler with the fewest pending tasks.

        Handlers are scanned in index order; the scan stops as soon as the
        current best candidate has zero pending tasks, and ties keep the
        lower index.
        """
        chosen = 0
        for candidate in self.loop_except:
            best_load = self.handlers[chosen].tasks
            if best_load == 0:
                break
            if self.handlers[candidate].tasks < best_load:
                chosen = candidate
        return self.handlers[chosen](img, cls, use_space)

    def rec_multi(self, images: "list[ndarray]", cls: "bool" = False, use_space: "bool" = False):
        """Recognize each image in *images* concurrently on the thread pool.

        Returns the per-image results in the same order as *images*. Note the
        defaults (cls=False, use_space=False) differ from rec_one's.
        """
        pending = [
            self.pool.submit(self.rec_one, image, cls, use_space)
            for image in images
        ]
        return [job.result() for job in pending]
|