Переглянути джерело

v3.5: 使用消息通知机制取代轮询,从而修复轮询下无任务时的空消耗

Tinger 2 роки тому
батько
коміт
c03c477a0e
4 змінених файлів з 54 додано та 52 видалено
  1. 1 1
      hmOCR/argument.py
  2. 40 44
      hmOCR/core.py
  3. 7 7
      hmOCR/parts/utils.py
  4. 6 0
      utils/logger.py

+ 1 - 1
hmOCR/argument.py

@@ -21,7 +21,7 @@ class Args:
             rec_image_shape="3, 48, 320", rec_batch_num=8, max_text_length=25,
             rec_char_dict_path="hmOCR/static/key-set.txt", use_space_char=False,
             # OCR
-            drop_score=0.5, workers=5, interval=0.1,
+            drop_score=0.5, workers=5
         )
         self.__update(**kwargs)
 

+ 40 - 44
hmOCR/core.py

@@ -1,15 +1,14 @@
-import logging
 from .parts import *
 from .utility import *
-from time import sleep
+from time import time
 from numpy import ndarray
 from copy import deepcopy
 from .argument import ArgType
-from threading import Thread, Lock
+from utils.logger import Logger
+from threading import Lock
+from concurrent.futures import ThreadPoolExecutor
 
 __all__ = ["Engine", "HuiMvOCR"]
-logger = logging.getLogger("hm-ocr")
-logger.setLevel(logging.INFO)
 
 
 class Engine:
@@ -54,49 +53,46 @@ class Engine:
         return filter_rec_res
 
 
-class HuiMvOCR:
-    __lock = Lock()
-    __tasks = []  # item: [img: "ndarray", ocr_args: "dict", callback: "fn", callback_args: "dict"]
+class Handler:
+    def __init__(self, args: "ArgType", eid: "int" = -1):
+        start = time()
+        self.engine = Engine(args=args)
+        self.tasks = 0
+        Logger.info(f"Engine[{eid}] initialized in {time() - start}s")
+        self.__lock = Lock()
+        self.__eid = eid
+
+    def __call__(self, img: "ndarray", cls: "bool" = False, use_space: "bool" = False):
+        start = time()
+        self.tasks += 1
+        self.__lock.acquire()
+        res = self.engine(img, cls, use_space)
+        self.tasks -= 1
+        self.__lock.release()
+        Logger.info(f"Engine[{self.__eid}] finished a task in {time() - start}s")
+        return res
+
 
+class HuiMvOCR:
     def __init__(self, args: "ArgType"):
-        self.interval = args.interval
-
-        for i in range(args.workers):
-            Thread(target=self.__processor, args=(Engine(args), i), daemon=True).start()
-
-    @staticmethod
-    def __processor(ocr: "Engine", eid: "int"):
-        logger.info(f"================ Engine[{eid}] initialized ================")
-        while True:
-            if HuiMvOCR.__tasks:
-                HuiMvOCR.__lock.acquire()
-                img, ocr_args, callback, callback_args = HuiMvOCR.__tasks.pop(0)
-                HuiMvOCR.__lock.release()
-                res = ocr(img)
-                callback(res, **callback_args)
-            sleep(0.1)
+        self.loop = range(args.workers)
+        self.loop_except = range(1, args.workers)
+        self.handlers = [Handler(args, i) for i in self.loop]
+        self.pool = ThreadPoolExecutor(args.workers)
 
     def rec_one(self, img: "ndarray", cls: "bool" = True, use_space: "bool" = True):
-        def callback(res):
-            foo[1] = res
-            foo[0] = 1
-
-        foo = [0, None]  # finish_count, result
-        args = {"cls": cls, "use_space": use_space}
-        HuiMvOCR.__tasks.append([img, args, callback, {}])
-        while foo[0] < 1:
-            sleep(self.interval)
-        return foo[1]
+        index = 0
+        for cur in self.loop_except:
+            if self.handlers[index].tasks == 0:
+                break
+            if self.handlers[cur].tasks < self.handlers[index].tasks:
+                index = cur
+        return self.handlers[index](img, cls, use_space)
 
     def rec_multi(self, images: "list[ndarray]", cls: "bool" = False, use_space: "bool" = False):
-        def callback(res, index):
-            foo[1][index] = res
-            foo[0] += 1
+        tasks = [
+            self.pool.submit(self.rec_one, img, cls, use_space)
+            for img in images
+        ]
 
-        size, args = len(images), {"cls": cls, "use_space": use_space}
-        foo = [0, [...] * size]  # finish_count, result
-        for i in range(size):
-            HuiMvOCR.__tasks.append([images[i], args, callback, {"index": i}])
-        while foo[0] < size:
-            sleep(self.interval)
-        return foo[1]
+        return [task.result() for task in tasks]

+ 7 - 7
hmOCR/parts/utils.py

@@ -2,6 +2,7 @@ from .operator import *  # noqa
 from copy import deepcopy
 from os import path, popen
 from platform import system
+from utils.logger import Logger
 from paddle import inference, fluid
 
 __all__ = ["create_operators", "build_post_process", "create_predictor", "transform"]
@@ -77,7 +78,7 @@ def create_predictor(args, mode):
         model_dir = args.rec_model_dir
 
     if model_dir is None:
-        print("no model_dir defined in args")
+        Logger.error("no model_dir defined in args")
         exit(0)
 
     file_names, model_path, param_path = ["model", "inference"], None, None
@@ -106,9 +107,8 @@ def create_predictor(args, mode):
     if args.use_gpu:
         gpu_id = __get_gpu_id()
         if gpu_id is None:
-            print(
-                "WARING:",
-                "GPU is not found in current device by nvidia-smi.",
+            Logger.warning(
+                "GPU is not found in current device by nvidia-smi. "
                 "Please check your device or ignore it if run on jetson."
             )
         config.enable_use_gpu(args.gpu_mem, 0)
@@ -126,12 +126,12 @@ def create_predictor(args, mode):
 
             if not path.exists(trt_shape_f):
                 config.collect_shape_range_info(trt_shape_f)
-                print(f"collect dynamic shape info into : {trt_shape_f}")
+                Logger.warning(f"collect dynamic shape info into : {trt_shape_f}")
             try:
                 config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
             except Exception as E:
-                print(E)
-                print("Please keep your paddlepaddle-gpu >= 2.3.0!")
+                Logger.error(E)
+                Logger.error("Please keep your paddlepaddle-gpu >= 2.3.0!")
     else:
         config.disable_gpu()
         if args.enable_mkldnn:

+ 6 - 0
utils/logger.py

@@ -0,0 +1,6 @@
+import logging
+
+__all__ = ["Logger"]
+
+Logger = logging.getLogger("HuiMvOCR")
+Logger.setLevel(logging.INFO)