Browse Source

idc ocr v2.0: support any angle for identity card. 4 thread will be arranged to computed in four aspects, and only the ture result will be saved.

Tinger 2 years ago
parent
commit
ffb282a949
3 changed files with 83 additions and 62 deletions
  1. 1 3
      blues/com.py
  2. 52 55
      blues/idc.py
  3. 30 4
      utils/util.py

+ 1 - 3
blues/com.py

@@ -24,9 +24,7 @@ class ComView(views.MethodView):
         cur, rnd = current_time(), rand_str()
         cur, rnd = current_time(), rand_str()
         raw_path = f"static/images/{cur}_{rnd}.{ext}"
         raw_path = f"static/images/{cur}_{rnd}.{ext}"
         rec_path = f"static/images/{cur}_{rnd}-rec.{ext}"
         rec_path = f"static/images/{cur}_{rnd}-rec.{ext}"
-        with open(raw_path, "wb") as fp:
-            fp.write(content)
-            fp.close()
+        save_img(raw_path, content)
 
 
         ocr_res, img_shape = recognize(content)
         ocr_res, img_shape = recognize(content)
         kind = request.form.get("type")
         kind = request.form.get("type")

+ 52 - 55
blues/idc.py

@@ -5,7 +5,7 @@ from utils.conf import MAX_CONTENT_LENGTH
 
 
 idc = Blueprint("idc", __name__, url_prefix="/idc")
 idc = Blueprint("idc", __name__, url_prefix="/idc")
 
 
-__CN = "中国CHINA"
+__exclude = "中国CHINA *#★☆"
 __face_ptn = r"^姓名(?P<name>.+)性别(?P<gender>男|女)民族(?P<nation>.+)" \
 __face_ptn = r"^姓名(?P<name>.+)性别(?P<gender>男|女)民族(?P<nation>.+)" \
              r"出生(?P<year>\d{4})年(?P<month>\d\d)月(?P<day>\d\d)日" \
              r"出生(?P<year>\d{4})年(?P<month>\d\d)月(?P<day>\d\d)日" \
              r"住址(?P<addr>.+)公民身份号码(?P<idn>\d{17}\d|x|X)$"
              r"住址(?P<addr>.+)公民身份号码(?P<idn>\d{17}\d|x|X)$"
@@ -14,20 +14,12 @@ __icon_ptn = r"^中华人民共和国居民身份证签发机关(?P<agent>.+)" \
              r"[^\d]+(?P<to_year>\d{4})\.(?P<to_month>\d{2})\.(?P<to_day>\d{2})$"
              r"[^\d]+(?P<to_year>\d{4})\.(?P<to_month>\d{2})\.(?P<to_day>\d{2})$"
 
 
 
 
-# 需要图片在PC上看着是:横长竖宽
 def get_face_info(data: "list[str]") -> "tuple[dict, str, bool]":
 def get_face_info(data: "list[str]") -> "tuple[dict, str, bool]":
     res = {"name": "", "gender": "", "nation": "", "birth": {"year": "", "month": "", "day": ""}, "addr": "", "idn": ""}
     res = {"name": "", "gender": "", "nation": "", "birth": {"year": "", "month": "", "day": ""}, "addr": "", "idn": ""}
-
     if len(data) < 5:  # 最少 5 个识别结果
     if len(data) < 5:  # 最少 5 个识别结果
         return res, "请使用正确的身份证人像面照片", False
         return res, "请使用正确的身份证人像面照片", False
-    deal = [item.replace(" ", "") for item in data if not str_include(__CN, item)]
-    if not deal[0].startswith("姓名"):  # 非正,逆序后尝试
-        deal.reverse()
-    if not deal[0].startswith("姓名"):
-        return res, "请确保照片为:横长竖宽,正面朝上", False
-
-    str_all = "".join(deal)
-    print(str_all)
+
+    str_all = "".join([item for item in data if not str_include(__exclude, item)])
     if match := re.match(__face_ptn, str_all):
     if match := re.match(__face_ptn, str_all):
         res["name"] = match.group("name")
         res["name"] = match.group("name")
         res["gender"] = match.group("gender")
         res["gender"] = match.group("gender")
@@ -39,24 +31,17 @@ def get_face_info(data: "list[str]") -> "tuple[dict, str, bool]":
         }
         }
         res["addr"] = match.group("addr")
         res["addr"] = match.group("addr")
         res["idn"] = match.group("idn")
         res["idn"] = match.group("idn")
-        return res, "", True
+        return res, str_all, True
 
 
     return res, "识别失败,请重新选择", False
     return res, "识别失败,请重新选择", False
 
 
 
 
 def get_icon_info(data: "list[str]"):
 def get_icon_info(data: "list[str]"):
     res = {"agent": "", "from": {"year": "", "month": "", "day": ""}, "to": {"year": "", "month": "", "day": ""}}
     res = {"agent": "", "from": {"year": "", "month": "", "day": ""}, "to": {"year": "", "month": "", "day": ""}}
-
     if len(data) < 4:  # 最少 4 个识别结果
     if len(data) < 4:  # 最少 4 个识别结果
         return res, "请使用正确的身份证国徽面照片", False
         return res, "请使用正确的身份证国徽面照片", False
-    deal = [item.replace(" ", "") for item in data if not str_include(__CN, item)]
-    if not deal[0].startswith("中华"):  # 非正,逆序后尝试
-        deal.reverse()
-    if not deal[0].startswith("中华"):
-        return res, "请确保照片为:横长竖宽,正面朝上", False
-
-    str_all = "".join(deal)
-    print(str_all)
+
+    str_all = "".join([item for item in data if not str_include(__exclude, item)])
     if match := re.match(__icon_ptn, str_all):
     if match := re.match(__icon_ptn, str_all):
         res["agent"] = match.group("agent")
         res["agent"] = match.group("agent")
         res["from"] = {
         res["from"] = {
@@ -69,7 +54,7 @@ def get_icon_info(data: "list[str]"):
             "month": match.group("to_month"),
             "month": match.group("to_month"),
             "day": match.group("to_day"),
             "day": match.group("to_day"),
         }
         }
-        return res, "", True
+        return res, str_all, True
     return res, "识别失败,请重新选择", False
     return res, "识别失败,请重新选择", False
 
 
 
 
@@ -80,6 +65,11 @@ class IdcView(views.MethodView):
 
 
     @staticmethod
     @staticmethod
     def post():
     def post():
+        which = request.form.get("which")
+        if which is not None:
+            which = which.lower()
+        if which not in ["face", "icon"]:
+            return Response(f"not recognized arg <which>: '{which}'")
         pic = request.files.get("picture")
         pic = request.files.get("picture")
         if pic is None:
         if pic is None:
             return Response("empty body")
             return Response("empty body")
@@ -89,26 +79,25 @@ class IdcView(views.MethodView):
         content = pic.read()
         content = pic.read()
         if len(content) > MAX_CONTENT_LENGTH:
         if len(content) > MAX_CONTENT_LENGTH:
             return Response("文件过大,请重新选择")
             return Response("文件过大,请重新选择")
-        raw_path = f"static/images/{current_time()}_{rand_str()}.{ext}"
-        with open(raw_path, "wb") as fp:
-            fp.write(content)
-            fp.close()
-
-        which = request.form.get("which")
-        if which is not None:
-            which = which.lower()
-        if which not in ["face", "icon"]:
-            return Response(f"not recognized arg <which>: '{which}'")
 
 
-        ocr_res, _ = recognize(content)
-        words = [it[1][0] for it in ocr_res]
-        if which == "face":
-            info, msg, sta = get_face_info(words)
+        images = rotate(content)
+        rec = rec_multi(images)
+        info, msg, sta, idx = {}, "识别失败,请重新选择", False, 0
+        for idx, (ocr_res, _) in enumerate(rec):
+            words = [it[1][0].replace(" ", "") for it in ocr_res]
+            if which == "face":
+                if not words[0].startswith("姓名"):
+                    continue
+                info, msg, sta = get_face_info(words)
+            else:
+                if not words[0].startswith("中华"):
+                    continue
+                info, msg, sta = get_icon_info(words)
             if sta:
             if sta:
-                return Response(data=info)
-            return Response(msg, info)
-        info, msg, sta = get_icon_info(words)
+                break
         if sta:
         if sta:
+            raw_path = f"static/images/{current_time()}_{rand_str()}.{ext}"
+            save_img(raw_path, images[idx])
             return Response(data=info)
             return Response(data=info)
         return Response(msg, info)
         return Response(msg, info)
 
 
@@ -116,6 +105,11 @@ class IdcView(views.MethodView):
 class IdcHtmlView(views.MethodView):
 class IdcHtmlView(views.MethodView):
     @staticmethod
     @staticmethod
     def post():
     def post():
+        which = request.form.get("which")
+        if which is not None:
+            which = which.lower()
+        if which not in ["face", "icon"]:
+            return Response(f"not recognized arg <which>: '{which}'")
         pic = request.files.get("picture")
         pic = request.files.get("picture")
         if pic is None:
         if pic is None:
             return Response("empty body")
             return Response("empty body")
@@ -125,26 +119,29 @@ class IdcHtmlView(views.MethodView):
         content = pic.read()
         content = pic.read()
         if len(content) > MAX_CONTENT_LENGTH:
         if len(content) > MAX_CONTENT_LENGTH:
             return Response("文件过大,请重新选择")
             return Response("文件过大,请重新选择")
+
+        images = rotate(content)
+        rec = rec_multi(images)
+        info, msg, sta, idx = {}, "识别失败,请重新选择", False, 0
+        for idx, (ocr_res, _) in enumerate(rec):
+            words = [it[1][0].replace(" ", "") for it in ocr_res]
+            if which == "face":
+                if not words[0].startswith("姓名"):
+                    continue
+                info, msg, sta = get_face_info(words)
+            else:
+                if not words[0].startswith("中华"):
+                    continue
+                info, msg, sta = get_icon_info(words)
+            if sta:
+                break
+
         cut, rnd = current_time(), rand_str()
         cut, rnd = current_time(), rand_str()
         raw_path = f"static/images/{cut}_{rnd}.{ext}"
         raw_path = f"static/images/{cut}_{rnd}.{ext}"
         rec_path = f"static/images/{cut}_{rnd}_rec.{ext}"
         rec_path = f"static/images/{cut}_{rnd}_rec.{ext}"
-        with open(raw_path, "wb") as fp:
-            fp.write(content)
-            fp.close()
-
-        which = request.form.get("which")
-        if which is not None:
-            which = which.lower()
-        if which not in ["face", "icon"]:
-            return Response(f"not recognized arg <which>: '{which}'")
+        save_img(raw_path, images[idx])
+        draw_img(rec[idx][1], [{"pos": it[0], "word": it[1][0], "rate": it[1][1]} for it in rec[idx][0]], rec_path)
 
 
-        ocr_res, img_shape = recognize(content)
-        words = [it[1][0] for it in ocr_res]
-        draw_img(img_shape, [{"pos": it[0], "word": it[1][0], "rate": it[1][1]} for it in ocr_res], rec_path)
-        if which == "face":
-            info, msg, sta = get_face_info(words)
-        else:
-            info, msg, sta = get_icon_info(words)
         info["SUCCESS"] = str(sta).upper()
         info["SUCCESS"] = str(sta).upper()
         info["MESSAGE"] = msg
         info["MESSAGE"] = msg
         return render_template("k-v_result.html", raw=raw_path, rec=rec_path, data=info)
         return render_template("k-v_result.html", raw=raw_path, rec=rec_path, data=info)

+ 30 - 4
utils/util.py

@@ -1,20 +1,22 @@
 import cv2
 import cv2
 import numpy as np
 import numpy as np
+from typing import Union
 from flask import jsonify
 from flask import jsonify
 from paddleocr import PaddleOCR
 from paddleocr import PaddleOCR
 from random import randint, seed
 from random import randint, seed
 from time import localtime, strftime
 from time import localtime, strftime
+from concurrent.futures import ThreadPoolExecutor
 from paddleocr.tools.infer.utility import draw_box_txt_fine
 from paddleocr.tools.infer.utility import draw_box_txt_fine
 
 
 __all__ = [
 __all__ = [
     "Args", "Response", "rand_str", "current_time", "get_ext_name", "is_image_ext", "recognize", "draw_img",
     "Args", "Response", "rand_str", "current_time", "get_ext_name", "is_image_ext", "recognize", "draw_img",
-    "json_all", "str_include"
+    "json_all", "str_include", "rec_multi", "save_img", "rotate"
 ]
 ]
 
 
 __StrBase = "qwertyuioplkjhgfdsazxcvbnm1234567890ZXCVBNMLKJHGFDSAQWERTYUIOP"
 __StrBase = "qwertyuioplkjhgfdsazxcvbnm1234567890ZXCVBNMLKJHGFDSAQWERTYUIOP"
 __StrBaseLen = len(__StrBase) - 1
 __StrBaseLen = len(__StrBase) - 1
 __AcceptExtNames = ["jpg", "jpeg", "bmp", "png", "rgb", "tif", "tiff", "gif", "pdf"]
 __AcceptExtNames = ["jpg", "jpeg", "bmp", "png", "rgb", "tif", "tiff", "gif", "pdf"]
-__OcrEngine = PaddleOCR(
+__Engines = [PaddleOCR(
     use_gpu=False,
     use_gpu=False,
     enable_mkldnn=True,
     enable_mkldnn=True,
     det_model_dir="models/det/",
     det_model_dir="models/det/",
@@ -22,7 +24,7 @@ __OcrEngine = PaddleOCR(
     cls_model_dir="models/cls/",
     cls_model_dir="models/cls/",
     use_angle_cls=True,
     use_angle_cls=True,
     use_space_char=True
     use_space_char=True
-)
+) for _ in range(4)]
 
 
 
 
 class Args:
 class Args:
@@ -94,10 +96,21 @@ def Response(message: "str" = None, data=None):
     return jsonify(success=False, message=message, data=data)
     return jsonify(success=False, message=message, data=data)
 
 
 
 
+def _rec(img, which: "int" = 0) -> "tuple[list, tuple]":
+    return __Engines[which % 4].ocr(img)[0], img.shape
+
+
+def rec_multi(images: "list[np.ndarray]") -> "list[tuple]":  # list[_rec]
+    pool = ThreadPoolExecutor(4)
+    tasks = [pool.submit(_rec, one, i) for i, one in enumerate(images)]
+
+    return [task.result() for task in tasks]
+
+
 def recognize(content: "str") -> "tuple[list, tuple]":
 def recognize(content: "str") -> "tuple[list, tuple]":
     img = cv2.imdecode(np.fromstring(content, np.uint8), 1)  # noqa
     img = cv2.imdecode(np.fromstring(content, np.uint8), 1)  # noqa
 
 
-    return __OcrEngine.ocr(img)[0], img.shape
+    return _rec(img)
 
 
 
 
 def draw_img(shape: "tuple", data: "list[dict]", path: "str", drop: "float" = 0.5):
 def draw_img(shape: "tuple", data: "list[dict]", path: "str", drop: "float" = 0.5):
@@ -139,3 +152,16 @@ def str_include(str_long: "str", str_short: "str") -> "bool":
         if it not in str_long:
         if it not in str_long:
             return False
             return False
     return True
     return True
+
+
+def save_img(filename: "str", content: "Union[bytes, np.ndarray]"):
+    if isinstance(content, np.ndarray):
+        return cv2.imwrite(filename, content)  # noqa
+    with open(filename, "wb") as fp:
+        fp.write(content)
+        fp.close()
+
+
+def rotate(content: "str") -> "list[np.ndarray]":
+    img = cv2.imdecode(np.fromstring(content, np.uint8), 1)  # noqa
+    return [img, np.rot90(img), np.rot90(img, 2), np.rot90(img, 3)]