
idc ocr v3.0: HuiMvOCR, ~3s -> 1.58s

Tinger, 2 years ago
parent
commit
9364829782

+ 1 - 2
.gitignore

@@ -3,5 +3,4 @@ venv
 .idea
 static/images/*
 
-test.py
-hmOCR
+test.py

+ 9 - 11
blues/com.py

@@ -1,6 +1,6 @@
-from flask import Blueprint, views, render_template, request
 from utils.util import *
 from utils.conf import MAX_CONTENT_LENGTH
+from flask import Blueprint, views, render_template, request
 
 com = Blueprint("com", __name__, url_prefix="/com")
 
@@ -21,23 +21,21 @@ class ComView(views.MethodView):
         content = pic.read()
         if len(content) > MAX_CONTENT_LENGTH:
             return Response("文件过大,请重新选择")
-        cur, rnd = current_time(), rand_str()
-        raw_path = f"static/images/{cur}_{rnd}.{ext}"
-        rec_path = f"static/images/{cur}_{rnd}-rec.{ext}"
-        save_img(raw_path, content)
+        file_path = f"static/images/{current_time()}_{rand_str()}.{ext}"
+        save_img(file_path, content)
 
-        ocr_res, img_shape = recognize(content)
+        img = read_img(content)
+        ocr_res = Engine.ocr_one(img, cls=True)
         kind = request.form.get("type")
         if kind is not None:
             kind = kind.lower()
         if kind == "raw":
-            return Response(data=[{"pos": it[0], "word": it[1][0], "rate": it[1][1]} for it in ocr_res])
+            return Response(data=ocr_res)
         elif kind == "html":
-            data = [{"pos": it[0], "word": it[1][0], "rate": it[1][1]} for it in ocr_res]
-            draw_img(img_shape, data, rec_path)
-            return render_template("com_result.html", raw=raw_path, rec=rec_path, data=data)
+            data = [{"word": it[0], "rate": it[1], "index": i + 1} for i, it in enumerate(ocr_res)]
+            return render_template("com_result.html", raw=file_path, data=data)
         else:
-            return Response(data=[it[1][0] for it in ocr_res])
+            return Response(data=[it[0] for it in ocr_res])
 
 
 com.add_url_rule("/", view_func=ComView.as_view("com"))
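
With recognize/draw_img gone, the endpoint now returns the engine's (text, score) tuples directly. A hypothetical client call against a local dev server (the host/port and upload field name are not visible in this hunk and are assumptions):

import requests

with open("sample.jpg", "rb") as fp:
    resp = requests.post(
        "http://127.0.0.1:5000/com/",   # assumed dev host/port
        data={"type": "raw"},           # "raw": (text, score) pairs; "html": table page; default: text only
        files={"picture": fp},          # field name is an assumption
    )
print(resp.json())  # Response() wraps it as {"success": ..., "message": ..., "data": [...]}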

+ 21 - 16
blues/idc.py

@@ -1,7 +1,8 @@
-from flask import Blueprint, views, render_template, request
-from utils.util import *
 import re
+from time import time
+from utils.util import *
 from utils.conf import MAX_CONTENT_LENGTH
+from flask import Blueprint, views, render_template, request
 
 idc = Blueprint("idc", __name__, url_prefix="/idc")
 
@@ -65,6 +66,7 @@ class IdcView(views.MethodView):
 
     @staticmethod
     def post():
+        start = time()
         which = request.form.get("which")
         if which is not None:
             which = which.lower()
@@ -80,11 +82,12 @@ class IdcView(views.MethodView):
         if len(content) > MAX_CONTENT_LENGTH:
             return Response("文件过大,请重新选择")
 
-        images = rotate(content)
-        rec = rec_multi(images)
+        img = read_img(content)
+        images = rot_img(img)
+        rec = Engine.ocr_multi(images, cls=True, use_space=False)
         info, msg, sta, idx = {}, "识别失败,请重新选择", False, 0
-        for idx, (ocr_res, _) in enumerate(rec):
-            words = [it[1][0].replace(" ", "") for it in ocr_res]
+        for idx, ocr_res in enumerate(rec):
+            words = [it[0].replace(" ", "") for it in ocr_res]
             if which == "face":
                 if not words or not words[0].startswith("姓名"):
                     continue
@@ -95,6 +98,8 @@ class IdcView(views.MethodView):
                 info, msg, sta = get_icon_info(words)
             if sta:
                 break
+
+        info["duration"] = time() - start
         if sta:
             raw_path = f"static/images/{current_time()}_{rand_str()}.{ext}"
             save_img(raw_path, images[idx])
@@ -105,6 +110,7 @@ class IdcView(views.MethodView):
 class IdcHtmlView(views.MethodView):
     @staticmethod
     def post():
+        start = time()
         which = request.form.get("which")
         if which is not None:
             which = which.lower()
@@ -120,11 +126,12 @@ class IdcHtmlView(views.MethodView):
         if len(content) > MAX_CONTENT_LENGTH:
             return Response("文件过大,请重新选择")
 
-        images = rotate(content)
-        rec = rec_multi(images)
+        img = read_img(content)
+        images = rot_img(img)
+        rec = Engine.ocr_multi(images, cls=True, use_space=False)
         info, msg, sta, idx = {}, "识别失败,请重新选择", False, 0
-        for idx, (ocr_res, _) in enumerate(rec):
-            words = [it[1][0].replace(" ", "") for it in ocr_res]
+        for idx, ocr_res in enumerate(rec):
+            words = [it[0].replace(" ", "") for it in ocr_res]
             if which == "face":
                 if not words or not words[0].startswith("姓名"):
                     continue
@@ -136,15 +143,13 @@ class IdcHtmlView(views.MethodView):
             if sta:
                 break
 
-        cut, rnd = current_time(), rand_str()
-        raw_path = f"static/images/{cut}_{rnd}.{ext}"
-        rec_path = f"static/images/{cut}_{rnd}_rec.{ext}"
-        save_img(raw_path, images[idx])
-        draw_img(rec[idx][1], [{"pos": it[0], "word": it[1][0], "rate": it[1][1]} for it in rec[idx][0]], rec_path)
+        file_path = f"static/images/{current_time()}_{rand_str()}.{ext}"
+        save_img(file_path, images[idx])
 
         info["SUCCESS"] = str(sta).upper()
         info["MESSAGE"] = msg
-        return render_template("k-v_result.html", raw=raw_path, rec=rec_path, data=info)
+        info["DURATION"] = time() - start  # noqa
+        return render_template("k-v_result.html", raw=file_path, data=info)
 
 
 idc.add_url_rule("/", view_func=IdcView.as_view("idc"))
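
Both ID-card views now share the same pattern: decode the upload once, fan out to the four rotations, and keep the first orientation whose text matches the requested side; the new duration field in the response exposes the end-to-end latency the commit message cites. A hypothetical client call (field names other than "which" are assumptions):

import requests

with open("id_face.jpg", "rb") as fp:
    resp = requests.post(
        "http://127.0.0.1:5000/idc/",   # assumed dev host/port
        data={"which": "face"},          # or "icon" for the emblem side
        files={"picture": fp},           # upload field name is an assumption
    )
print(resp.json())  # JSON from the Response helper: success / message / data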

+ 3 - 0
hmOCR/__init__.py

@@ -0,0 +1,3 @@
+from .core import HuiMvOcr
+from .argument import Args, ArgType
+from .utility import *

+ 41 - 0
hmOCR/argument.py

@@ -0,0 +1,41 @@
+__all__ = ["Args", "ArgType"]
+
+
+class Args:
+    def __init__(self, **kwargs):
+        self.__update(
+            use_gpu=False, precision="fp32", use_tensorrt=False,
+            # gpu
+            gpu_mem=500, max_batch_size=6, min_subgraph_size=15,
+            # cpu
+            enable_mkldnn=True, cpu_threads=16,
+            # detector
+            det_model_dir="hmOCR/static/det", det_algorithm="DB", det_limit_side_len=960,
+            det_limit_type="max", det_db_thresh=0.3, det_db_box_thresh=0.6, det_db_unclip_ratio=1.5,
+            det_use_dilation=False, det_db_score_mode="fast", det_box_type="quad",
+            # classifier
+            cls_model_dir="hmOCR/static/cls", cls_image_shape="3, 48, 192",
+            cls_batch_num=6, cls_thresh=0.9, cls_label_list=["0", "180"],
+            # recognizer
+            rec_model_dir="hmOCR/static/rec", rec_algorithm="SVTR_LCNet",
+            rec_image_shape="3, 48, 320", rec_batch_num=8, max_text_length=25,
+            rec_char_dict_path="hmOCR/static/key-set.txt", use_space_char=False,
+            # OCR
+            drop_score=0.5,
+            # test
+            image_dir="static/test_image", warmup=True
+        )
+        self.__update(**kwargs)
+
+    def __update(self, **kwargs):
+        for k, v in kwargs.items():
+            self.__dict__[k] = v
+
+    def __getattr__(self, key: "str"):
+        return self.__dict__[key]
+
+    def __setattr__(self, key: "str", value):
+        self.__dict__[key] = value
+
+
+ArgType = Args
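
Args is a plain attribute bag: the defaults above are written into __dict__ first, then caller kwargs override them. A small sketch (override values are illustrative, not tuned recommendations):

args = Args(cpu_threads=8, rec_batch_num=16, drop_score=0.6)
assert args.cpu_threads == 8          # overridden
assert args.det_algorithm == "DB"     # untouched default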

+ 56 - 0
hmOCR/core.py

@@ -0,0 +1,56 @@
+from .parts import *
+from .utility import *
+from copy import deepcopy
+from .argument import ArgType
+from concurrent.futures import ThreadPoolExecutor
+
+__all__ = ["HuiMvOcr"]
+
+
+class HuiMvOcr:
+    __worker_count = 1
+
+    def __init__(self, args: "ArgType"):
+        self.det = Detector(args)
+        self.cls = Classifier(args)
+        self.rec = Recognizer(args)
+
+        self.det_box_type = args.det_box_type
+        self.drop_score = args.drop_score
+        self.crop_image_res_index = 0
+
+    def ocr_one(self, img, cls: "bool" = False, use_space: "bool" = True):
+        ori_im = img.copy()
+        dt_boxes = self.det(img)
+        if dt_boxes is None:
+            return None
+
+        dt_boxes = sorted_boxes(dt_boxes)
+        size = len(dt_boxes)
+        img_crop_list = [None] * size
+
+        for i in range(size):
+            box = deepcopy(dt_boxes[i])
+            if self.det_box_type == "quad":
+                img_crop_list[i] = get_rotate_crop_image(ori_im, box)
+            else:
+                img_crop_list[i] = get_min_area_rect_crop(ori_im, box)
+
+        if cls:
+            img_crop_list = self.cls(img_crop_list)
+
+        rec_res = self.rec(img_crop_list, use_space=use_space)
+
+        filter_rec_res = []
+        for text, score in rec_res:
+            if score > self.drop_score:
+                filter_rec_res.append((text, score))
+
+        return filter_rec_res
+
+    def ocr_multi(self, img_list, cls: "bool" = False, use_space: "bool" = True):
+        # the context manager guarantees the worker pool is shut down afterwards
+        with ThreadPoolExecutor(HuiMvOcr.__worker_count) as pool:
+            tasks = [pool.submit(self.ocr_one, img, cls, use_space)
+                     for img in img_list]
+            return [task.result() for task in tasks]
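
Putting the three parts together: a minimal sketch of the new engine, assuming the repo layout introduced in this commit (model files under hmOCR/static plus the bundled test images):

import cv2
from hmOCR import HuiMvOcr, Args

engine = HuiMvOcr(Args())                            # loads det/cls/rec predictors once
img = cv2.imread("hmOCR/static/test_image/01.jpg")   # bundled sample image
res = engine.ocr_one(img, cls=True, use_space=False)
# res: [(text, score), ...] already filtered by drop_score (default 0.5)

Note that with __worker_count = 1, the pool in ocr_multi processes the images one at a time, so the four-rotation call is effectively sequential.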

+ 47 - 0
hmOCR/enter.py

@@ -0,0 +1,47 @@
+import cv2
+import numpy as np
+from utility import *
+from argument import *
+from core import HuiMvOcr
+from time import time
+
+
+def main(args: "ArgType"):
+    image_file_list = get_image_file_list(args.image_dir)
+    engine = HuiMvOcr(args)
+
+    # warm up 10 times
+    if args.warmup:
+        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
+        for _ in range(10):
+            engine.ocr_one(img)
+    # single
+    print("single start")
+    total_time, res = 0, []
+    for image_file in image_file_list:
+        img, flag_gif = check_and_read(image_file)
+        if not flag_gif:
+            img = cv2.imread(image_file)  # noqa
+        st = time()
+        rec_res = engine.ocr_one(img, cls=True, use_space=False)
+        elapse = time() - st
+        total_time += elapse
+        res.append(rec_res)
+    print(f"single total time: {total_time}")
+    for i in range(len(res)):
+        print(f"file: {image_file_list[i]}")
+        print("res:", res[i])
+
+    # multi
+    images = [cv2.imread(file) for file in image_file_list]  # noqa
+    print("\n\nmulti start")
+    st = time()
+    res = engine.ocr_multi(images, cls=True, use_space=False)
+    print(f"multi total time {time() - st}")
+    for i in range(len(res)):
+        print(f"file: {image_file_list[i]}")
+        print("res:", res[i])
+
+
+if __name__ == "__main__":
+    main(Args())
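
Note: enter.py imports utility, argument and core without the package prefix, so it is meant to be launched from inside the hmOCR directory (python enter.py), not via python -m from the repo root.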

+ 5 - 0
hmOCR/parts/__init__.py

@@ -0,0 +1,5 @@
+from .detector import Detector
+from .classifier import Classifier
+from .recognizer import Recognizer
+
+__all__ = ["Detector", "Classifier", "Recognizer"]

+ 80 - 0
hmOCR/parts/classifier.py

@@ -0,0 +1,80 @@
+import cv2
+import math
+import numpy as np
+from .utils import *
+from copy import deepcopy
+from hmOCR.argument import Args
+
+
+class Classifier:
+    def __init__(self, args: "Args"):
+        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
+        self.cls_batch_num = args.cls_batch_num
+        self.cls_thresh = args.cls_thresh
+        postprocess_params = {
+            "name": "ClsPostProcess",
+            "label_list": args.cls_label_list
+        }
+        self.postprocess_op = build_post_process(postprocess_params)
+        self.predictor, self.input_tensor, self.output_tensors = create_predictor(args, "cls")
+
+    def resize_norm_img(self, img):
+        imgC, imgH, imgW = self.cls_image_shape
+        h = img.shape[0]
+        w = img.shape[1]
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+        resized_image = cv2.resize(img, (resized_w, imgH))  # noqa
+        resized_image = resized_image.astype("float32")
+        if self.cls_image_shape[0] == 1:
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+        else:
+            resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+        padding_im[:, :, 0:resized_w] = resized_image
+        return padding_im
+
+    def __call__(self, img_list):
+        img_list = deepcopy(img_list)
+        img_num = len(img_list)
+        # Calculate the aspect ratio of all text bars
+        width_list = []
+        for img in img_list:
+            width_list.append(img.shape[1] / float(img.shape[0]))
+        # Sorting can speed up the cls process
+        indices = np.argsort(np.array(width_list))
+
+        batch_num = self.cls_batch_num
+
+        for beg_img_no in range(0, img_num, batch_num):
+            end_img_no = min(img_num, beg_img_no + batch_num)
+            norm_img_batch = []
+            max_wh_ratio = 0
+            for ino in range(beg_img_no, end_img_no):
+                h, w = img_list[indices[ino]].shape[0:2]
+                wh_ratio = w * 1.0 / h
+                max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            for ino in range(beg_img_no, end_img_no):
+                norm_img = self.resize_norm_img(img_list[indices[ino]])
+                norm_img = norm_img[np.newaxis, :]
+                norm_img_batch.append(norm_img)
+            norm_img_batch = np.concatenate(norm_img_batch)
+            norm_img_batch = norm_img_batch.copy()
+
+            self.input_tensor.copy_from_cpu(norm_img_batch)
+            self.predictor.run()
+            prob_out = self.output_tensors[0].copy_to_cpu()
+            self.predictor.try_shrink_memory()
+
+            cls_result = self.postprocess_op(prob_out)
+            for rno in range(len(cls_result)):
+                label, score = cls_result[rno]
+                if "180" in label and score > self.cls_thresh:
+                    img_list[indices[beg_img_no + rno]] = cv2.rotate(img_list[indices[beg_img_no + rno]], cv2.ROTATE_180)  # noqa
+        return img_list
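
The classifier only decides between 0° and 180° (cls_label_list) and flips a crop in place when the "180" label clears cls_thresh. A standalone sketch, assuming the model files exist under hmOCR/static/cls:

import cv2
from hmOCR.argument import Args
from hmOCR.parts.classifier import Classifier

cls = Classifier(Args())
crops = [cv2.imread("hmOCR/static/test_image/01.jpg")]  # stand-in for detector crops
upright = cls(crops)  # same order; upside-down crops come back rotated 180°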

+ 111 - 0
hmOCR/parts/detector.py

@@ -0,0 +1,111 @@
+import numpy as np
+from .utils import *
+from hmOCR.argument import Args
+
+
+class Detector:
+    def __init__(self, args: "Args"):
+        self.det_algorithm = args.det_algorithm  # DB
+        self.det_box_type = args.det_box_type
+        pre_process_list = [{
+            "DetResizeForTest": {
+                "limit_side_len": args.det_limit_side_len,
+                "limit_type": args.det_limit_type,
+            }
+        }, {
+            "NormalizeImage": {
+                "std": [0.229, 0.224, 0.225],
+                "mean": [0.485, 0.456, 0.406],
+                "scale": 1 / 255,
+                "order": "hwc"
+            }
+        }, {
+            "ToCHWImage": None
+        }, {
+            "KeepKeys": {
+                "keep_keys": ["image", "shape"]
+            }
+        }]
+        postprocess_params = {
+            "name": "DBPostProcess", "thresh": args.det_db_thresh,
+            "box_thresh": args.det_db_box_thresh, "max_candidates": 1000,
+            "unclip_ratio": args.det_db_unclip_ratio, "use_dilation": args.det_use_dilation,
+            "score_mode": args.det_db_score_mode, "box_type": args.det_box_type
+        }
+
+        self.pre_operator = create_operators(pre_process_list)
+        self.post_operator = build_post_process(postprocess_params)
+        self.predictor, self.input_tensor, self.output_tensors = create_predictor(args, "det")
+
+    @staticmethod
+    def order_points_clockwise(pts):
+        rect = np.zeros((4, 2), dtype="float32")
+        s = pts.sum(axis=1)
+        rect[0] = pts[np.argmin(s)]
+        rect[2] = pts[np.argmax(s)]
+        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
+        diff = np.diff(np.array(tmp), axis=1)
+        rect[1] = tmp[np.argmin(diff)]
+        rect[3] = tmp[np.argmax(diff)]
+        return rect
+
+    @staticmethod
+    def clip_det_res(points, img_height, img_width):
+        for pno in range(points.shape[0]):
+            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+        return points
+
+    def filter_tag_det_res(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            if type(box) is list:
+                box = np.array(box)
+            box = self.order_points_clockwise(box)
+            box = self.clip_det_res(box, img_height, img_width)
+            rect_width = int(np.linalg.norm(box[0] - box[1]))
+            rect_height = int(np.linalg.norm(box[0] - box[3]))
+            if rect_width <= 3 or rect_height <= 3:
+                continue
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            if type(box) is list:
+                box = np.array(box)
+            box = self.clip_det_res(box, img_height, img_width)
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        data = {"image": img}
+
+        data = transform(data, self.pre_operator)
+        img, shape_list = data
+        if img is None:
+            return None
+        img = np.expand_dims(img, axis=0)
+        shape_list = np.expand_dims(shape_list, axis=0)
+        img = img.copy()
+
+        self.input_tensor.copy_from_cpu(img)
+        self.predictor.run()
+
+        outputs = [out.copy_to_cpu() for out in self.output_tensors]
+        preds = {"maps": outputs[0]}
+        post_result = self.post_operator(preds, shape_list)
+        dt_boxes = post_result[0]["points"]
+
+        if self.det_box_type == "poly":
+            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
+        else:
+            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
+
+        return dt_boxes
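
Detection can also be exercised on its own; with the default det_box_type="quad" the call returns clockwise quadrilaterals clipped to the image. A sketch, assuming the bundled models and test images:

import cv2
from hmOCR.argument import Args
from hmOCR.parts.detector import Detector

det = Detector(Args())
boxes = det(cv2.imread("hmOCR/static/test_image/01.jpg"))
print(boxes.shape)  # (n, 4, 2): n text regions, 4 corner points each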

+ 450 - 0
hmOCR/parts/operator.py

@@ -0,0 +1,450 @@
+import re
+import cv2
+import pyclipper
+import numpy as np
+from PIL import Image
+from paddle import Tensor
+from shapely.geometry import Polygon
+
+__all__ = [
+    "DetResizeForTest", "NormalizeImage", "ToCHWImage", "KeepKeys",
+    "DBPostProcess", "ClsPostProcess", "CTCLabelDecode"
+]
+
+
+class DetResizeForTest:
+    def __init__(self, **kwargs):
+        if "limit_side_len" in kwargs:
+            self.limit_side_len = kwargs["limit_side_len"]
+            self.limit_type = kwargs.get("limit_type", "min")
+        else:
+            self.limit_side_len = 736
+            self.limit_type = "min"
+
+    def __call__(self, data):
+        img = data["image"]
+        src_h, src_w, _ = img.shape
+        if sum([src_h, src_w]) < 64:
+            img = self.image_padding(img)
+
+        img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+        data["image"] = img
+        data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
+        return data
+
+    @staticmethod
+    def image_padding(im, value=0):
+        h, w, c = im.shape
+        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
+        im_pad[:h, :w, :] = im
+        return im_pad
+
+    def resize_image_type0(self, img):
+        """
+        resize image to a size multiple of 32 which is required by the network
+        args:
+            img(array): array with shape [h, w, c]
+        return(tuple):
+            img, (ratio_h, ratio_w)
+        """
+        limit_side_len = self.limit_side_len
+        h, w, c = img.shape
+
+        # limit the max side
+        if self.limit_type == "max":
+            if max(h, w) > limit_side_len:
+                if h > w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.
+        elif self.limit_type == "min":
+            if min(h, w) < limit_side_len:
+                if h < w:
+                    ratio = float(limit_side_len) / h
+                else:
+                    ratio = float(limit_side_len) / w
+            else:
+                ratio = 1.
+        elif self.limit_type == "resize_long":
+            ratio = float(limit_side_len) / max(h, w)
+        else:
+            raise Exception("not support limit type, image ")
+        resize_h = int(h * ratio)
+        resize_w = int(w * ratio)
+
+        resize_h = max(int(round(resize_h / 32) * 32), 32)
+        resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+        try:
+            if int(resize_w) <= 0 or int(resize_h) <= 0:
+                return None, (None, None)
+            img = cv2.resize(img, (int(resize_w), int(resize_h)))  # noqa
+        except Exception as e:
+            print(img.shape, resize_w, resize_h, e)
+            exit(0)
+        ratio_h = resize_h / float(h)
+        ratio_w = resize_w / float(w)
+        return img, [ratio_h, ratio_w]
+
+
+class NormalizeImage:
+    def __init__(self, scale, mean, std, order="chw"):
+        self.scale = np.float32(scale)
+        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype("float32")
+        self.std = np.array(std).reshape(shape).astype("float32")
+
+    def __call__(self, data):
+        img = data["image"]
+        if isinstance(img, Image.Image):
+            img = np.array(img)  # noqa
+        assert isinstance(img, np.ndarray), "invalid input img in NormalizeImage"
+        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
+        return data
+
+
+class ToCHWImage:
+    def __call__(self, data):
+        img = data["image"]
+        if isinstance(img, Image.Image):
+            img = np.array(img)  # noqa
+        data["image"] = img.transpose((2, 0, 1))
+        return data
+
+
+class KeepKeys:
+    def __init__(self, keep_keys):
+        self.keep_keys = keep_keys
+
+    def __call__(self, data):
+        return [data[key] for key in self.keep_keys]
+
+
+class DBPostProcess:
+    def __init__(
+            self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=2.0,
+            use_dilation=False, score_mode="fast", box_type="quad"
+    ):
+        self.thresh = thresh
+        self.box_thresh = box_thresh
+        self.max_candidates = max_candidates
+        self.unclip_ratio = unclip_ratio
+        self.min_size = 3
+        self.score_mode = score_mode
+        self.box_type = box_type
+        assert score_mode in [
+            "slow", "fast"
+        ], f"Score mode must be in [slow, fast] but got: {score_mode}"
+
+        self.dilation_kernel = None if not use_dilation else np.array(
+            [[1, 1], [1, 1]])
+
+    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+        """
+        _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}
+        """
+
+        bitmap = _bitmap
+        height, width = bitmap.shape
+        boxes, scores = [], []
+
+        contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)  # noqa
+
+        for contour in contours[:self.max_candidates]:
+            epsilon = 0.002 * cv2.arcLength(contour, True)  # noqa
+            approx = cv2.approxPolyDP(contour, epsilon, True)  # noqa
+            points = approx.reshape((-1, 2))
+            if points.shape[0] < 4:
+                continue
+
+            score = self.box_score_fast(pred, points.reshape(-1, 2))
+            if self.box_thresh > score:
+                continue
+
+            if points.shape[0] > 2:
+                box = self.unclip(points, self.unclip_ratio)
+                if len(box) > 1:
+                    continue
+            else:
+                continue
+            box = box.reshape(-1, 2)
+
+            _, s_side = self.get_mini_boxes(box.reshape((-1, 1, 2)))
+            if s_side < self.min_size + 2:
+                continue
+
+            box = np.array(box)
+            box[:, 0] = np.clip(
+                np.round(box[:, 0] / width * dest_width), 0, dest_width)
+            box[:, 1] = np.clip(
+                np.round(box[:, 1] / height * dest_height), 0, dest_height)
+            boxes.append(box.tolist())
+            scores.append(score)
+        return boxes, scores
+
+    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+        """
+        _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}
+        """
+
+        bitmap, contours = _bitmap, None
+        height, width = bitmap.shape
+
+        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)  # noqa
+        if len(outs) == 3:
+            img, contours, _ = outs[0], outs[1], outs[2]
+        elif len(outs) == 2:
+            contours, _ = outs[0], outs[1]
+
+        num_contours = min(len(contours), self.max_candidates)
+
+        boxes = []
+        scores = []
+        for index in range(num_contours):
+            contour = contours[index]
+            points, s_side = self.get_mini_boxes(contour)
+            if s_side < self.min_size:
+                continue
+            points = np.array(points)
+            if self.score_mode == "fast":
+                score = self.box_score_fast(pred, points.reshape(-1, 2))
+            else:
+                score = self.box_score_slow(pred, contour)
+            if self.box_thresh > score:
+                continue
+
+            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)  # noqa
+            box, s_side = self.get_mini_boxes(box)
+            if s_side < self.min_size + 2:
+                continue
+            box = np.array(box)
+
+            box[:, 0] = np.clip(
+                np.round(box[:, 0] / width * dest_width), 0, dest_width)
+            box[:, 1] = np.clip(
+                np.round(box[:, 1] / height * dest_height), 0, dest_height)
+            boxes.append(box.astype("int32"))
+            scores.append(score)
+        return np.array(boxes, dtype="int32"), scores
+
+    @staticmethod
+    def unclip(box, unclip_ratio):
+        poly = Polygon(box)
+        distance = poly.area * unclip_ratio / poly.length
+        offset = pyclipper.PyclipperOffset()
+        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+        expanded = np.array(offset.Execute(distance))
+        return expanded
+
+    @staticmethod
+    def get_mini_boxes(contour):
+        bounding_box = cv2.minAreaRect(contour)  # noqa
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])  # noqa
+
+        if points[1][1] > points[0][1]:
+            index_1 = 0
+            index_4 = 1
+        else:
+            index_1 = 1
+            index_4 = 0
+        if points[3][1] > points[2][1]:
+            index_2 = 2
+            index_3 = 3
+        else:
+            index_2 = 3
+            index_3 = 2
+
+        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
+        return box, min(bounding_box[1])
+
+    @staticmethod
+    def box_score_fast(bitmap, _box):
+        """
+        box_score_fast: score the box by the mean probability inside the quad box polygon
+        """
+        h, w = bitmap.shape[:2]
+        box = _box.copy()
+        x_min = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
+        x_max = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
+        y_min = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
+        y_max = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
+
+        mask = np.zeros((y_max - y_min + 1, x_max - x_min + 1), dtype=np.uint8)
+        box[:, 0] = box[:, 0] - x_min
+        box[:, 1] = box[:, 1] - y_min
+        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)  # noqa
+        return cv2.mean(bitmap[y_min:y_max + 1, x_min:x_max + 1], mask)[0]  # noqa
+
+    @staticmethod
+    def box_score_slow(bitmap, contour):
+        """
+        box_score_slow: score the box by the mean probability inside the original contour polygon
+        """
+        h, w = bitmap.shape[:2]
+        contour = contour.copy()
+        contour = np.reshape(contour, (-1, 2))
+
+        x_min = np.clip(np.min(contour[:, 0]), 0, w - 1)
+        x_max = np.clip(np.max(contour[:, 0]), 0, w - 1)
+        y_min = np.clip(np.min(contour[:, 1]), 0, h - 1)
+        y_max = np.clip(np.max(contour[:, 1]), 0, h - 1)
+
+        mask = np.zeros((y_max - y_min + 1, x_max - x_min + 1), dtype=np.uint8)
+
+        contour[:, 0] = contour[:, 0] - x_min
+        contour[:, 1] = contour[:, 1] - y_min
+
+        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)  # noqa
+        return cv2.mean(bitmap[y_min:y_max + 1, x_min:x_max + 1], mask)[0]  # noqa
+
+    def __call__(self, outs_dict, shape_list):
+        pred = outs_dict["maps"]
+        if isinstance(pred, Tensor):
+            pred = pred.numpy()
+        pred = pred[:, 0, :, :]
+        segmentation = pred > self.thresh
+
+        boxes_batch = []
+        for batch_index in range(pred.shape[0]):
+            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
+            if self.dilation_kernel is not None:
+                mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel)  # noqa
+            else:
+                mask = segmentation[batch_index]
+            if self.box_type == "poly":
+                boxes, scores = self.polygons_from_bitmap(pred[batch_index],
+                                                          mask, src_w, src_h)
+            elif self.box_type == "quad":
+                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
+                                                       src_w, src_h)
+            else:
+                raise ValueError("box_type can only be one of ['quad', 'poly']")
+
+            boxes_batch.append({"points": boxes})
+        return boxes_batch
+
+
+class ClsPostProcess:
+    """ Convert between text-label and text-index """
+
+    def __init__(self, label_list=None):
+        self.label_list = label_list
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        label_list = self.label_list
+        if label_list is None:
+            label_list = {idx: idx for idx in range(preds.shape[-1])}
+
+        if isinstance(preds, Tensor):
+            preds = preds.numpy()
+
+        pred_ids = preds.argmax(axis=1)
+        decode_out = [(label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_ids)]
+        if label is None:
+            return decode_out
+        label = [(label_list[idx], 1.0) for idx in label]
+        return decode_out, label
+
+
+class __BaseRecDecoder:
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        self.reverse = False
+        self.character_str = []
+
+        if character_dict_path is None:
+            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz "
+            dict_character = list(self.character_str)
+        else:
+            with open(character_dict_path, "rb") as fin:
+                lines = fin.readlines()
+                for line in lines:
+                    line = line.decode("utf-8").strip("\n").strip("\r\n")
+                    self.character_str.append(line)
+            self.character_str.append(" ")
+            dict_character = list(self.character_str)
+
+        dict_character = self.add_special_char(dict_character)
+        self.max_index = len(dict_character) - 1
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+
+    @staticmethod
+    def pred_reverse(pred):
+        pred_re = []
+        c_current = ""
+        for c in pred:
+            if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
+                if c_current != "":
+                    pred_re.append(c_current)
+                pred_re.append(c)
+                c_current = ""
+            else:
+                c_current += c
+        if c_current != "":
+            pred_re.append(c_current)
+
+        return "".join(pred_re[::-1])
+
+    def add_special_char(self, dict_character):
+        return dict_character
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False, use_space=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+            if is_remove_duplicate:
+                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
+            for ignored_token in ignored_tokens:
+                selection &= text_index[batch_idx] != ignored_token
+
+            char_list = []
+            for index in text_index[batch_idx][selection]:
+                if index == self.max_index and not use_space:
+                    continue
+                char_list.append(self.character[index])
+            if text_prob is not None:
+                conf_list = text_prob[batch_idx][selection]
+            else:
+                conf_list = [1] * len(selection)
+            if len(conf_list) == 0:
+                conf_list = [0]
+
+            text = "".join(char_list)
+
+            result_list.append((text, np.mean(conf_list).tolist()))
+        return result_list
+
+    @staticmethod
+    def get_ignored_tokens():
+        return [0]  # for ctc blank
+
+
+class CTCLabelDecode(__BaseRecDecoder):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None):
+        super(CTCLabelDecode, self).__init__(character_dict_path)
+
+    def __call__(self, preds, use_space=False, *args, **kwargs):
+        if isinstance(preds, tuple) or isinstance(preds, list):
+            preds = preds[-1]
+        if isinstance(preds, Tensor):
+            preds = preds.numpy()
+        preds_idx = preds.argmax(axis=2)
+        preds_prob = preds.max(axis=2)
+        return self.decode(preds_idx, preds_prob, is_remove_duplicate=True, use_space=use_space)
+
+    def add_special_char(self, dict_character):
+        dict_character = ["blank"] + dict_character
+        return dict_character

+ 81 - 0
hmOCR/parts/recognizer.py

@@ -0,0 +1,81 @@
+import cv2
+import math
+import numpy as np
+from .utils import *
+from hmOCR.argument import Args
+from .operator import CTCLabelDecode
+
+
+class Recognizer:
+    def __init__(self, args: "Args"):
+        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
+        self.rec_batch_num = args.rec_batch_num
+        # rec_algorithm: only "SVTR_LCNet" now
+        # self.rec_algorithm = args.rec_algorithm
+
+        self.post_op = CTCLabelDecode(args.rec_char_dict_path)
+        self.predictor, self.input_tensor, self.output_tensors = create_predictor(args, "rec")
+
+    def resize_norm_img(self, img, max_wh_ratio):
+        imgC, imgH, imgW = self.rec_image_shape
+        assert imgC == img.shape[2]
+        imgW = int((imgH * max_wh_ratio))
+
+        h, w = img.shape[:2]
+        ratio = w / float(h)
+        if math.ceil(imgH * ratio) > imgW:
+            resized_w = imgW
+        else:
+            resized_w = int(math.ceil(imgH * ratio))
+
+        resized_image = cv2.resize(img, (resized_w, imgH))  # noqa
+        resized_image = resized_image.astype("float32")
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+        padding_im[:, :, 0:resized_w] = resized_image
+        return padding_im
+
+    def __call__(self, img_list, use_space=False):
+        img_num = len(img_list)
+        width_list = []
+        for img in img_list:
+            width_list.append(img.shape[1] / float(img.shape[0]))
+        # Sorting can speed up the recognition process
+        indices = np.argsort(np.array(width_list))
+        rec_res = [["", 0.0]] * img_num
+        batch_num = self.rec_batch_num
+
+        for beg_img_no in range(0, img_num, batch_num):
+            end_img_no = min(img_num, beg_img_no + batch_num)
+            norm_img_batch = []
+
+            imgC, imgH, imgW = self.rec_image_shape[:3]
+            max_wh_ratio = imgW / imgH
+            for ino in range(beg_img_no, end_img_no):
+                h, w = img_list[indices[ino]].shape[0:2]
+                wh_ratio = w * 1.0 / h
+                max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            for ino in range(beg_img_no, end_img_no):
+                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
+                norm_img = norm_img[np.newaxis, :]
+                norm_img_batch.append(norm_img)
+            norm_img_batch = np.concatenate(norm_img_batch)
+            norm_img_batch = norm_img_batch.copy()
+
+            self.input_tensor.copy_from_cpu(norm_img_batch)
+            self.predictor.run()
+            outputs = []
+            for output_tensor in self.output_tensors:
+                output = output_tensor.copy_to_cpu()
+                outputs.append(output)
+            if len(outputs) != 1:
+                preds = outputs
+            else:
+                preds = outputs[0]
+            rec_result = self.post_op(preds, use_space=use_space)
+            for rno in range(len(rec_result)):
+                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
+
+        return rec_res
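
Crops are sorted by aspect ratio so each batch pads to a similar width: the widest crop in a batch fixes imgW = imgH * max_wh_ratio and everything else is zero-padded to it. Worked numbers for the default rec_image_shape of 3, 48, 320:

imgH, base_w = 48, 320
max_wh_ratio = base_w / imgH                 # 6.67 floor from the default shape
for h, w in [(32, 100), (48, 600)]:          # two crops landing in one batch
    max_wh_ratio = max(max_wh_ratio, w / h)  # -> 12.5, set by the 48x600 crop
imgW = int(imgH * max_wh_ratio)              # 600: batch width after padding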

+ 160 - 0
hmOCR/parts/utils.py

@@ -0,0 +1,160 @@
+from .operator import *  # noqa
+from copy import deepcopy
+from os import path, popen
+from platform import system
+from paddle import inference, fluid
+
+__all__ = ["create_operators", "build_post_process", "create_predictor", "transform"]
+
+
+def create_operators(op_param_list):
+    ops = []
+    for operator in op_param_list:
+        op_name = list(operator)[0]
+        param = {} if operator[op_name] is None else operator[op_name]
+        op = eval(op_name)(**param)
+        ops.append(op)
+    return ops
+
+
+def transform(data, ops=None):
+    if ops is None:
+        ops = []
+    for op in ops:
+        data = op(data)
+        if data is None:
+            return None
+    return data
+
+
+def build_post_process(config):
+    config = deepcopy(config)
+    module_name = config.pop("name")
+
+    return eval(module_name)(**config)
+
+
+def __get_gpu_id():
+    if system() == "Windows":
+        return 0
+
+    if not fluid.core.is_compiled_with_rocm():
+        cmd = "env | grep CUDA_VISIBLE_DEVICES"
+    else:
+        cmd = "env | grep HIP_VISIBLE_DEVICES"
+    env_cuda = popen(cmd).readlines()
+    if len(env_cuda) == 0:
+        return 0
+    else:
+        gpu_id = env_cuda[0].strip().split("=")[1]
+        return int(gpu_id.split(",")[0])  # first visible device; handles multi-digit ids
+
+
+def __get_output_tensors(args, mode, predictor):
+    output_names = predictor.get_output_names()
+    output_tensors = []
+    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet"]:
+        output_name = "softmax_0.tmp_0"
+        if output_name in output_names:
+            return [predictor.get_output_handle(output_name)]
+        else:
+            for output_name in output_names:
+                output_tensor = predictor.get_output_handle(output_name)
+                output_tensors.append(output_tensor)
+    else:
+        for output_name in output_names:
+            output_tensor = predictor.get_output_handle(output_name)
+            output_tensors.append(output_tensor)
+    return output_tensors
+
+
+def create_predictor(args, mode):
+    if mode == "det":
+        model_dir = args.det_model_dir
+    elif mode == "cls":
+        model_dir = args.cls_model_dir
+    else:  # rec
+        model_dir = args.rec_model_dir
+
+    if model_dir is None:
+        print("no model_dir defined in args")
+        exit(0)
+
+    file_names, model_path, param_path = ["model", "inference"], None, None
+    for file_name in file_names:
+        model_file_path = path.join(model_dir, f"{file_name}.pdmodel")
+        params_file_path = path.join(model_dir, f"{file_name}.pdiparams")
+        if path.exists(model_file_path) and path.exists(params_file_path):
+            model_path, param_path = model_file_path, params_file_path
+            break
+    if model_path is None:
+        raise ValueError(f"not find model.pdmodel or inference.pdmodel in {model_dir}")
+    if param_path is None:
+        raise ValueError(f"not find model.pdiparams or inference.pdiparams in {model_dir}")
+
+    config = inference.Config(model_path, param_path)
+
+    precision = inference.PrecisionType.Float32
+    if hasattr(args, "precision"):
+        if args.precision == "fp16" and args.use_tensorrt:
+            precision = inference.PrecisionType.Half
+        elif args.precision == "int8":
+            precision = inference.PrecisionType.Int8
+        else:
+            precision = inference.PrecisionType.Float32
+
+    if args.use_gpu:
+        gpu_id = __get_gpu_id()
+        if gpu_id is None:
+            print(
+                "WARING:",
+                "GPU is not found in current device by nvidia-smi.",
+                "Please check your device or ignore it if run on jetson."
+            )
+        config.enable_use_gpu(args.gpu_mem, 0)
+        if args.use_tensorrt:
+            config.enable_tensorrt_engine(
+                workspace_size=1 << 30,
+                precision_mode=precision,
+                max_batch_size=args.max_batch_size,
+                min_subgraph_size=args.min_subgraph_size,  # skip TensorRT subgraphs below this minimum size
+                use_calib_mode=False
+            )
+
+            # collect shape
+            trt_shape_f = path.join(model_dir, f"{mode}_trt_dynamic_shape.txt")
+
+            if not path.exists(trt_shape_f):
+                config.collect_shape_range_info(trt_shape_f)
+                print(f"collect dynamic shape info into : {trt_shape_f}")
+            try:
+                config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
+            except Exception as E:
+                print(E)
+                print("Please keep your paddlepaddle-gpu >= 2.3.0!")
+    else:
+        config.disable_gpu()
+        if args.enable_mkldnn:
+            config.set_mkldnn_cache_capacity(10)
+            config.enable_mkldnn()
+            if hasattr(args, "cpu_threads"):
+                config.set_cpu_math_library_num_threads(args.cpu_threads)
+            else:
+                config.set_cpu_math_library_num_threads(10)
+
+    config.enable_memory_optim()
+    config.disable_glog_info()
+    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
+    config.delete_pass("matmul_transpose_reshape_fuse_pass")
+
+    config.switch_use_feed_fetch_ops(False)
+    config.switch_ir_optim(True)
+
+    predictor = inference.create_predictor(config)
+    input_names = predictor.get_input_names()
+
+    input_tensor = None
+    for name in input_names:
+        input_tensor = predictor.get_input_handle(name)
+    output_tensors = __get_output_tensors(args, mode, predictor)
+    return predictor, input_tensor, output_tensors
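
create_predictor accepts either model.* or inference.* file pairs; after the renames below, the expected layout is:

hmOCR/static/
    det/inference.pdmodel + inference.pdiparams
    cls/inference.pdmodel + inference.pdiparams
    rec/inference.pdmodel + inference.pdiparams
    key-set.txt   # recognizer character dictionary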

models/cls/inference.pdiparams → hmOCR/static/cls/inference.pdiparams


models/cls/inference.pdiparams.info → hmOCR/static/cls/inference.pdiparams.info


models/cls/inference.pdmodel → hmOCR/static/cls/inference.pdmodel


models/det/inference.pdiparams → hmOCR/static/det/inference.pdiparams


models/det/inference.pdiparams.info → hmOCR/static/det/inference.pdiparams.info


models/det/inference.pdmodel → hmOCR/static/det/inference.pdmodel


static/ppocr/ppocr_keys_v1.txt → hmOCR/static/key-set.txt


models/rec/inference.pdiparams → hmOCR/static/rec/inference.pdiparams


models/rec/inference.pdiparams.info → hmOCR/static/rec/inference.pdiparams.info


models/rec/inference.pdmodel → hmOCR/static/rec/inference.pdmodel


BIN
hmOCR/static/test_image/01.jpg


BIN
hmOCR/static/test_image/02.jpg


BIN
hmOCR/static/test_image/03.jpg


BIN
hmOCR/static/test_image/04.jpg


+ 109 - 0
hmOCR/utility.py

@@ -0,0 +1,109 @@
+import cv2
+import numpy as np
+from os import path, listdir
+
+__all__ = [
+    "sorted_boxes", "get_rotate_crop_image", "get_min_area_rect_crop",
+    "get_image_file_list", "check_and_read"
+]
+
+
+def sorted_boxes(dt_boxes):
+    num_boxes = dt_boxes.shape[0]
+    _sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+    _boxes = list(_sorted_boxes)
+
+    for i in range(num_boxes - 1):
+        for j in range(i, -1, -1):
+            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (_boxes[j + 1][0][0] < _boxes[j][0][0]):
+                tmp = _boxes[j]
+                _boxes[j] = _boxes[j + 1]  # noqa
+                _boxes[j + 1] = tmp
+            else:
+                break
+    return _boxes
+
+
+def get_rotate_crop_image(img, points):
+    """
+    img_height, img_width = img.shape[0:2]
+    left = int(np.min(points[:, 0]))
+    right = int(np.max(points[:, 0]))
+    top = int(np.min(points[:, 1]))
+    bottom = int(np.max(points[:, 1]))
+    img_crop = img[top:bottom, left:right, :].copy()
+    points[:, 0] = points[:, 0] - left
+    points[:, 1] = points[:, 1] - top
+    """
+    assert len(points) == 4, "shape of points must be 4*2"
+    img_crop_width = int(max(np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3])))
+    img_crop_height = int(max(np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2])))
+    pts_std = np.float32([[0, 0], [img_crop_width, 0], [img_crop_width, img_crop_height], [0, img_crop_height]])
+    M = cv2.getPerspectiveTransform(points, pts_std)  # noqa
+    dst_img = cv2.warpPerspective(  # noqa
+        img, M, (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC  # noqa
+    )
+    dst_img_height, dst_img_width = dst_img.shape[0:2]
+    if dst_img_height * 1.0 / dst_img_width >= 1.5:
+        dst_img = np.rot90(dst_img)
+    return dst_img
+
+
+def get_min_area_rect_crop(img, points):
+    bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))  # noqa
+    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])  # noqa
+
+    if points[1][1] > points[0][1]:
+        index_a = 0
+        index_d = 1
+    else:
+        index_a = 1
+        index_d = 0
+    if points[3][1] > points[2][1]:
+        index_b = 2
+        index_c = 3
+    else:
+        index_b = 3
+        index_c = 2
+
+    box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+    crop_img = get_rotate_crop_image(img, np.array(box))
+    return crop_img
+
+
+def _check_image_file(file_path):
+    img_end = ("jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif")
+    return any([file_path.lower().endswith(e) for e in img_end])
+
+
+def get_image_file_list(img_file):
+    images = []
+    if img_file is None or not path.exists(img_file):
+        raise Exception(f"not found any img file in {img_file}")
+
+    if path.isfile(img_file) and _check_image_file(img_file):
+        images.append(img_file)
+    elif path.isdir(img_file):
+        for single_file in listdir(img_file):
+            file_path = path.join(img_file, single_file)
+            if path.isfile(file_path) and _check_image_file(file_path):
+                images.append(file_path)
+    if len(images) == 0:
+        raise Exception(f"not found any img file in {img_file}")
+    images = sorted(images)
+    return images
+
+
+def check_and_read(img_path):
+    if path.basename(img_path)[-3:].lower() == "gif":
+        gif = cv2.VideoCapture(img_path)  # noqa
+        ret, frame = gif.read()
+        if not ret:
+            print(f"Cannot read {img_path}. This gif image maybe corrupted.")
+            return None, False
+        if len(frame.shape) == 2 or frame.shape[-1] == 1:
+            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)  # noqa
+        img = frame[:, :, ::-1]  # BGR -> RGB
+        return img, True
+    return None, False
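
sorted_boxes orders detections top-to-bottom, then left-to-right within a row: the initial sort is by (y, x) of the top-left point, and the bubble pass swaps neighbours whose y differs by less than 10px but whose x order is wrong. For example:

import numpy as np
from hmOCR.utility import sorted_boxes

boxes = np.array([
    [[200, 12], [260, 12], [260, 40], [200, 40]],  # row 1, right
    [[10, 10], [90, 10], [90, 40], [10, 40]],      # row 1, left
    [[15, 80], [95, 80], [95, 110], [15, 110]],    # row 2
], dtype="float32")
print([b[0].tolist() for b in sorted_boxes(boxes)])
# [[10.0, 10.0], [200.0, 12.0], [15.0, 80.0]]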

BIN
static/ppocr/simfang.ttf


+ 16 - 29
templates/com_result.html

@@ -13,18 +13,13 @@
         .img-line {
             width: 100%;
             box-sizing: border-box;
-            padding: 0 100px;
             display: flex;
-            justify-content: space-between;
+            justify-content: center;
             margin-top: 50px;
         }
 
-        .img-box {
-            margin: 0 20px;
-        }
-
         img {
-            width: 500px;
+            width: 600px;
             height: auto;
             box-sizing: border-box;
             padding: 5px;
@@ -40,58 +35,50 @@
         }
 
         table {
+            width: 100%;
             border: none;
             background-color: aqua;
         }
 
         .col1 {
-            width: 60%;
+            text-align: center;
+            width: 5%;
         }
 
         .col2 {
-            width: 15%;
+            width: 80%;
         }
 
         .col3 {
-            width: 25%;
+            width: 15%;
         }
 
         td, th {
             background-color: white;
         }
-
-        .center {
-            text-align: center;
-        }
     </style>
 </head>
 <body>
-<h1>识别结果展示页面</h1>
 <div class="img-line">
-    <div class="img-box">
-        <h1>原图</h1>
-        <a target="_blank" href="/{{ raw }}"><img src="/{{ raw }}" alt="raw"></a>
-    </div>
-    <div class="img-box">
-        <h1>结果</h1>
-        <a target="_blank" href="/{{ rec }}"><img src="/{{ rec }}" alt="rec"></a>
-    </div>
+    <a target="_blank" href="/{{ raw }}">
+        <img src="/{{ raw }}" alt="raw">
+    </a>
 </div>
 <div class="data-table">
     <table>
         <thead>
         <tr>
-            <th class="col1">内容</th>
-            <th class="col2">概率</th>
-            <th class="col3">位置</th>
+            <th class="col1">序号</th>
+            <th class="col2">内容</th>
+            <th class="col3">概率</th>
         </tr>
         </thead>
         <tbody>
         {% for item in data %}
             <tr>
-                <td>{{ item.word }}</td>
-                <td class="center">{{ item.rate }}</td>
-                <td>{{ item.pos }}</td>
+                <td class="col1">{{ item.index }}</td>
+                <td class="col2">{{ item.word }}</td>
+                <td class="col3">{{ item.rate }}</td>
             </tr>
         {% endfor %}
         </tbody>

+ 6 - 17
templates/k-v_result.html

@@ -1,5 +1,5 @@
 <!DOCTYPE html>
-<html lang="en">
+<html lang="zh">
 <head>
     <meta charset="UTF-8">
     <title>键值对OCR结果展示</title>
@@ -13,18 +13,13 @@
         .img-line {
             width: 100%;
             box-sizing: border-box;
-            padding: 0 100px;
             display: flex;
-            justify-content: space-between;
+            justify-content: center;
             margin-top: 50px;
         }
 
-        .img-box {
-            margin: 0 20px;
-        }
-
         img {
-            width: 500px;
+            width: 600px;
             height: auto;
             box-sizing: border-box;
             padding: 5px;
@@ -63,16 +58,10 @@
     </style>
 </head>
 <body>
-<h1>识别结果展示页面</h1>
 <div class="img-line">
-    <div class="img-box">
-        <h1>原图</h1>
-        <a target="_blank" href="/{{ raw }}"><img src="/{{ raw }}" alt="raw"></a>
-    </div>
-    <div class="img-box">
-        <h1>结果</h1>
-        <a target="_blank" href="/{{ rec }}"><img src="/{{ rec }}" alt="rec"></a>
-    </div>
+    <a target="_blank" href="/{{ raw }}">
+        <img src="/{{ raw }}" alt="raw">
+    </a>
 </div>
 <div class="data-table">
     <table>

+ 20 - 107
utils/util.py

@@ -2,76 +2,25 @@ import cv2
 import numpy as np
 from typing import Union
 from flask import jsonify
-from paddleocr import PaddleOCR
-from random import randint, seed
+from random import randint
+from hmOCR import HuiMvOcr, Args
 from time import localtime, strftime
-from concurrent.futures import ThreadPoolExecutor
-from paddleocr.tools.infer.utility import draw_box_txt_fine
 
 __all__ = [
-    "Args", "Response", "rand_str", "current_time", "get_ext_name", "is_image_ext", "recognize", "draw_img",
-    "json_all", "str_include", "rec_multi", "save_img", "rotate"
+    "Response", "rand_str", "current_time", "get_ext_name", "is_image_ext",
+    "json_all", "str_include", "read_img", "rot_img", "save_img", "Engine"
 ]
 
 __StrBase = "qwertyuioplkjhgfdsazxcvbnm1234567890ZXCVBNMLKJHGFDSAQWERTYUIOP"
 __StrBaseLen = len(__StrBase) - 1
 __AcceptExtNames = ["jpg", "jpeg", "bmp", "png", "rgb", "tif", "tiff", "gif", "pdf"]
-__EngineNum = 4
-__Engines = [PaddleOCR(
-    use_gpu=False,
-    enable_mkldnn=True,
-    det_model_dir="models/det/",
-    rec_model_dir="models/rec/",
-    cls_model_dir="models/cls/",
-    use_angle_cls=True
-) for _ in range(__EngineNum)]
-
-
-class Args:
-    def __init__(self, **kwargs):
-        self.__update(
-            use_gpu=False, use_xpu=False, use_npu=False, ir_optim=True, use_tensorrt=False,
-            min_subgraph_size=15, precision="fp32", gpu_mem=500, image_dir=None, page_num=0,
-            det_algorithm="DB", det_model_dir="models/det/", det_limit_side_len=960, det_limit_type="max",
-            det_box_type="quad", det_db_thresh=0.3, det_db_box_thresh=0.6, det_db_unclip_ratio=1.5,
-            max_batch_size=10, use_dilation=False, det_db_score_mode="fast", det_east_score_thresh=0.8,
-            det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_sast_score_thresh=0.5,
-            det_sast_nms_thresh=0.2, det_pse_thresh=0, det_pse_box_thresh=0.85, det_pse_min_area=16,
-            det_pse_scale=1, scales=[8, 16, 32], alpha=1.0, beta=1.0, fourier_degree=5,
-            rec_algorithm="SVTR_LCNet", rec_model_dir="models/rec/", rec_image_inverse=True,
-            rec_image_shape="3, 48, 320", rec_batch_num=6, max_text_length=25,
-            rec_char_dict_path="venv/lib/site-packages/paddleocr/ppocr/utils/ppocr_keys_v1.txt",
-            use_space_char=True, vis_font_path="static/simfang.ttf", drop_score=0.5,
-            e2e_algorithm="PGNet", e2e_model_dir=None, e2e_limit_side_len=768, e2e_limit_type="max",
-            e2e_pgnet_score_thresh=0.5, e2e_char_dict_path="./ppocr/utils/ic15_dict.txt",
-            e2e_pgnet_valid_set="totaltext", e2e_pgnet_mode="fast", use_angle_cls=True,
-            cls_model_dir="models/cls/", cls_image_shape="3, 48, 192", label_list=["0", "180"],
-            cls_batch_num=6, cls_thresh=0.9, enable_mkldnn=True, cpu_threads=10, use_pdserving=False,
-            sr_model_dir=None, sr_image_shape="3, 32, 128", sr_batch_num=1,
-            draw_img_save_dir="static/rec_res/", save_crop_res=False, crop_res_save_dir="./output",
-            use_mp=False, benchmark=False, save_log_path="./log_output/",
-            show_log=True, use_onnx=False, output="./output", table_max_len=488, table_algorithm="TableAttn",
-            table_model_dir=None, merge_no_span_structure=True, table_char_dict_path=None,
-            layout_model_dir=None, layout_dict_path=None, layout_score_threshold=0.5,
-            layout_nms_threshold=0.5, kie_algorithm="LayoutXLM", ser_model_dir=None, re_model_dir=None,
-            use_visual_backbone=True, ser_dict_path="../train_data/XFUND/class_list_xfun.txt",
-            ocr_order_method=None, mode="structure", image_orientation=False, layout=True, table=True,
-            ocr=True, recovery=False, use_pdf2docx_api=False, lang="ch", det=True, rec=True, type="ocr",
-            ocr_version="PP-OCRv3", structure_version="PP-StructureV2"
-        )
-
-        self.__update(**kwargs)
-
-    def __update(self, **kwargs):
-        for k, v in kwargs:
-            self.__dict__[k] = v
-
-    def __setattr__(self, key: "str", value):
-        self.__dict__[key] = value
-
-    def __getattribute__(self, key: "str"):
-        assert key in self.__dict__.keys()
-        return self.__dict__[key]
+Engine = HuiMvOcr(Args())
+
+
+def Response(message: "str" = None, data=None):
+    if message is None:
+        return jsonify(success=True, message="操作成功", data=data)
+    return jsonify(success=False, message=message, data=data)
 
 
 def rand_str(size: "int" = 8) -> "str":
@@ -90,46 +39,7 @@ def is_image_ext(ext: "str") -> bool:
     return ext in __AcceptExtNames
 
 
-def Response(message: "str" = None, data=None):
-    if message is None:
-        return jsonify(success=True, message="操作成功", data=data)
-    return jsonify(success=False, message=message, data=data)
-
-
-def _rec(img, which: "int" = 0) -> "tuple[list, tuple]":
-    return __Engines[which % __EngineNum].ocr(img)[0], img.shape
-
-
-def rec_multi(images: "list[np.ndarray]") -> "list[tuple]":  # list[_rec]
-    pool = ThreadPoolExecutor(__EngineNum)
-    tasks = [pool.submit(_rec, one, i) for i, one in enumerate(images)]
-
-    return [task.result() for task in tasks]
-
-
-def recognize(content: "str") -> "tuple[list, tuple]":
-    img = cv2.imdecode(np.fromstring(content, np.uint8), 1)  # noqa
-
-    return _rec(img)
-
-
-def draw_img(shape: "tuple", data: "list[dict]", path: "str", drop: "float" = 0.5):
-    img = np.ones(shape, dtype=np.uint8) * 255
-    seed(0)
-
-    for one in data:
-        if one["rate"] < drop:
-            continue
-        color = (randint(0, 255), randint(0, 255), randint(0, 255))
-        text = draw_box_txt_fine((shape[1], shape[0]), one["pos"], one["word"], font_path="static/ppocr/simfang.ttf")
-        pts = np.array(one["pos"], np.int32).reshape((-1, 1, 2))
-        cv2.polylines(text, [pts], True, color, 1)  # noqa
-        img = cv2.bitwise_and(img, text)  # noqa
-
-    cv2.imwrite(path, np.array(img))  # noqa
-
-
-def json_all(data: "dict or list") -> "bool":
+def json_all(data: "Union[list, dict]") -> "bool":
     if isinstance(data, list):
         for item in data:
             if isinstance(item, str) and not item:
@@ -154,14 +64,17 @@ def str_include(str_long: "str", str_short: "str") -> "bool":
     return True
 
 
+def read_img(content: "bytes") -> "np.ndarray":
+    return cv2.imdecode(np.frombuffer(content, np.uint8), 1)  # noqa
+
+
+def rot_img(img: "np.ndarray") -> "list[np.ndarray]":
+    return [img, np.rot90(img), np.rot90(img, 2), np.rot90(img, 3)]
+
+
 def save_img(filename: "str", content: "Union[bytes, np.ndarray]"):
     if isinstance(content, np.ndarray):
         return cv2.imwrite(filename, content)  # noqa
     with open(filename, "wb") as fp:
         fp.write(content)
         fp.close()
-
-
-def rotate(content: "str") -> "list[np.ndarray]":
-    img = cv2.imdecode(np.fromstring(content, np.uint8), 1)  # noqa
-    return [img, np.rot90(img), np.rot90(img, 2), np.rot90(img, 3)]
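
The old per-request rotate/recognize/draw trio collapses into three small helpers plus the shared Engine. End-to-end, the views now do the equivalent of the following (sample path assumed):

from utils.util import Engine, read_img, rot_img

with open("hmOCR/static/test_image/01.jpg", "rb") as fp:
    img = read_img(fp.read())  # bytes -> BGR ndarray
results = Engine.ocr_multi(rot_img(img), cls=True, use_space=False)
# results[i]: (text, score) list for the i-th rotation (0°, 90°, 180°, 270°)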