recognizer.py

import cv2
import math
import numpy as np
from .utils import *
from hmOCR.argument import Args
from .operator import CTCLabelDecode

class Recognizer:
    def __init__(self, args: "Args"):
        # Target input shape [C, H, W] for the recognition model
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.rec_batch_num = args.rec_batch_num
        # rec_algorithm: only "SVTR_LCNet" now
        # self.rec_algorithm = args.rec_algorithm
        # CTC decoder maps model output indices back to characters
        self.post_op = CTCLabelDecode(args.rec_char_dict_path)
        self.predictor, self.input_tensor, self.output_tensors = create_predictor(args, "rec")

    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        # Widen the target so the batch's widest image still fits
        imgW = int(imgH * max_wh_ratio)
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))  # noqa
        resized_image = resized_image.astype("float32")
        # HWC -> CHW, scale to [0, 1], then normalize to [-1, 1]
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        # Zero-pad on the right so every image in the batch shares one width
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list, use_space=False):
        img_num = len(img_list)
        # Aspect ratio (w / h) of every crop
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting by aspect ratio can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        rec_res = [["", 0.0]] * img_num
        batch_num = self.rec_batch_num
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            imgC, imgH, imgW = self.rec_image_shape[:3]
            max_wh_ratio = imgW / imgH
            # The widest image in the batch determines the padded width
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            # Run the batch through the predictor
            self.input_tensor.copy_from_cpu(norm_img_batch)
            self.predictor.run()
            outputs = []
            for output_tensor in self.output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
            if len(outputs) != 1:
                preds = outputs
            else:
                preds = outputs[0]
            # CTC-decode and scatter results back to the original input order
            rec_result = self.post_op(preds, use_space=use_space)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
        return rec_res
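

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module): how Recognizer is typically driven.
# The keyword arguments below are assumptions -- the real hmOCR.argument.Args
# constructor may require additional fields (e.g. the "rec" model paths used
# by create_predictor); only rec_image_shape, rec_batch_num and
# rec_char_dict_path are read directly in this file, and the dictionary path
# shown here is hypothetical.
#
#     from hmOCR.argument import Args
#
#     args = Args(
#         rec_image_shape="3, 48, 320",             # C, H, W fed to the model
#         rec_batch_num=6,                          # crops per inference batch
#         rec_char_dict_path="keys/char_dict.txt",  # hypothetical dictionary path
#     )
#     recognizer = Recognizer(args)
#
#     crops = [cv2.imread("word_0.png"), cv2.imread("word_1.png")]
#     results = recognizer(crops, use_space=True)
#     # results[i] is [text, confidence] and follows the order of `crops`,
#     # even though recognition internally runs in width-sorted batches.
# ---------------------------------------------------------------------------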