operator.py

import re

import cv2
import pyclipper
import numpy as np
from PIL import Image
from paddle import Tensor
from shapely.geometry import Polygon

__all__ = [
    "DetResizeForTest", "NormalizeImage", "ToCHWImage", "KeepKeys",
    "DBPostProcess", "ClsPostProcess", "CTCLabelDecode"
]


class DetResizeForTest:
    def __init__(self, **kwargs):
        if "limit_side_len" in kwargs:
            self.limit_side_len = kwargs["limit_side_len"]
            self.limit_type = kwargs.get("limit_type", "min")
        else:
            self.limit_side_len = 736
            self.limit_type = "min"

    def __call__(self, data):
        img = data["image"]
        src_h, src_w, _ = img.shape
        # pad very small images up to at least 32 x 32 before resizing
        if sum([src_h, src_w]) < 64:
            img = self.image_padding(img)
        img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        data["image"] = img
        data["shape"] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    @staticmethod
    def image_padding(im, value=0):
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type0(self, img):
        """
        Resize the image to a size that is a multiple of 32, as required by the network.
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # choose a scaling ratio according to the configured side-length limit
        if self.limit_type == "max":
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == "min":
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == "resize_long":
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception(f"unsupported limit_type: {self.limit_type}")

        # round both sides to the nearest multiple of 32 (at least 32)
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))  # noqa
        except Exception as e:
            print(img.shape, resize_w, resize_h, e)
            exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]
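
# Usage sketch (illustrative only; the file path and parameter values below are
# assumptions, not part of this module):
#
#     resize_op = DetResizeForTest(limit_side_len=736, limit_type="min")
#     data = resize_op({"image": cv2.imread("doc.jpg")})
#     resized_img, (src_h, src_w, ratio_h, ratio_w) = data["image"], data["shape"]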


class NormalizeImage:
    """Normalize an image with the given scale, mean, and std."""

    def __init__(self, scale, mean, std, order="chw"):
        self.scale = np.float32(scale)
        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype("float32")
        self.std = np.array(std).reshape(shape).astype("float32")

    def __call__(self, data):
        img = data["image"]
        if isinstance(img, Image.Image):
            img = np.array(img)  # noqa
        assert isinstance(img, np.ndarray), "invalid input img in NormalizeImage"
        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
        return data
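
# Usage sketch (illustrative only; the scale/mean/std values are common
# ImageNet-style defaults, shown here as an assumption rather than required values):
#
#     normalize_op = NormalizeImage(
#         scale=1.0 / 255.0,
#         mean=[0.485, 0.456, 0.406],
#         std=[0.229, 0.224, 0.225],
#         order="hwc",
#     )
#     data = normalize_op({"image": resized_img})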


class ToCHWImage:
    """Convert an HWC image to CHW layout."""

    def __call__(self, data):
        img = data["image"]
        if isinstance(img, Image.Image):
            img = np.array(img)  # noqa
        data["image"] = img.transpose((2, 0, 1))
        return data


class KeepKeys:
    """Keep only the requested keys, returned as a list in the given order."""

    def __init__(self, keep_keys):
        self.keep_keys = keep_keys

    def __call__(self, data):
        return [data[key] for key in self.keep_keys]
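
# Usage sketch: the four operators above are typically chained as a detection
# preprocessing pipeline (illustrative only; the exact operator arguments are
# assumptions):
#
#     ops = [
#         DetResizeForTest(limit_side_len=736, limit_type="min"),
#         NormalizeImage(scale=1.0 / 255.0,
#                        mean=[0.485, 0.456, 0.406],
#                        std=[0.229, 0.224, 0.225],
#                        order="hwc"),
#         ToCHWImage(),
#         KeepKeys(keep_keys=["image", "shape"]),
#     ]
#     data = {"image": cv2.imread("doc.jpg")}
#     for op in ops:
#         data = op(data)
#     image, shape_info = data  # KeepKeys returns a list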


class DBPostProcess:
    """Post-processing for Differentiable Binarization (DB) text detection."""

    def __init__(
        self, thresh=0.3, box_thresh=0.7, max_candidates=1000, unclip_ratio=2.0,
        use_dilation=False, score_mode="fast", box_type="quad"
    ):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type
        assert score_mode in [
            "slow", "fast"
        ], f"Score mode must be in [slow, fast] but got: {score_mode}"
        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """
        _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}
        """
        bitmap = _bitmap
        height, width = bitmap.shape
        boxes, scores = [], []

        contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)  # noqa
        for contour in contours[:self.max_candidates]:
            # approximate the contour with a coarser polygon
            epsilon = 0.002 * cv2.arcLength(contour, True)  # noqa
            approx = cv2.approxPolyDP(contour, epsilon, True)  # noqa
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue

            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            if points.shape[0] > 2:
                box = self.unclip(points, self.unclip_ratio)
                if len(box) > 1:
                    continue
            else:
                continue
            box = box.reshape(-1, 2)

            _, s_side = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if s_side < self.min_size + 2:
                continue

            # map the polygon from bitmap coordinates back to the source image size
            box = np.array(box)
            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.tolist())
            scores.append(score)
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        """
        _bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}
        """
        bitmap, contours = _bitmap, None
        height, width = bitmap.shape

        # cv2.findContours returns 3 values in OpenCV 3.x and 2 values in OpenCV 4.x
        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)  # noqa
        if len(outs) == 3:
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            contours, _ = outs[0], outs[1]
        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, s_side = self.get_mini_boxes(contour)
            if s_side < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue

            # expand the shrunken text region, then take its minimum-area rectangle
            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)  # noqa
            box, s_side = self.get_mini_boxes(box)
            if s_side < self.min_size + 2:
                continue

            # map the box from bitmap coordinates back to the source image size
            box = np.array(box)
            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype("int32"))
            scores.append(score)
        return np.array(boxes, dtype="int32"), scores

    @staticmethod
    def unclip(box, unclip_ratio):
        # Expand the detected polygon outward; the offset distance follows the
        # DB box-expansion formula: distance = area * unclip_ratio / perimeter.
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    @staticmethod
    def get_mini_boxes(contour):
        # Return the minimum-area rectangle of the contour with its four corners
        # ordered clockwise from the top-left, plus the length of its shorter side.
        bounding_box = cv2.minAreaRect(contour)  # noqa
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])  # noqa

        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
        return box, min(bounding_box[1])

    @staticmethod
    def box_score_fast(bitmap, _box):
        """
        box_score_fast: use the mean score inside the box polygon as the box score
        """
        h, w = bitmap.shape[:2]
        box = _box.copy()
        x_min = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
        x_max = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
        y_min = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
        y_max = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)

        mask = np.zeros((y_max - y_min + 1, x_max - x_min + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - x_min
        box[:, 1] = box[:, 1] - y_min
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)  # noqa
        return cv2.mean(bitmap[y_min:y_max + 1, x_min:x_max + 1], mask)[0]  # noqa

    @staticmethod
    def box_score_slow(bitmap, contour):
        """
        box_score_slow: use the mean score inside the original contour polygon as the box score
        """
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        x_min = np.clip(np.min(contour[:, 0]), 0, w - 1)
        x_max = np.clip(np.max(contour[:, 0]), 0, w - 1)
        y_min = np.clip(np.min(contour[:, 1]), 0, h - 1)
        y_max = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((y_max - y_min + 1, x_max - x_min + 1), dtype=np.uint8)
        contour[:, 0] = contour[:, 0] - x_min
        contour[:, 1] = contour[:, 1] - y_min
        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)  # noqa
        return cv2.mean(bitmap[y_min:y_max + 1, x_min:x_max + 1], mask)[0]  # noqa

    def __call__(self, outs_dict, shape_list):
        pred = outs_dict["maps"]
        if isinstance(pred, Tensor):
            pred = pred.numpy()
        pred = pred[:, 0, :, :]
        # binarize the probability map
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel)  # noqa
            else:
                mask = segmentation[batch_index]
            if self.box_type == "poly":
                boxes, scores = self.polygons_from_bitmap(pred[batch_index],
                                                          mask, src_w, src_h)
            elif self.box_type == "quad":
                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                       src_w, src_h)
            else:
                raise ValueError("box_type can only be one of ['quad', 'poly']")
            boxes_batch.append({"points": boxes})
        return boxes_batch
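
# Usage sketch (illustrative only; `det_model` and its output layout are
# assumptions about the caller, not defined in this module):
#
#     post_process = DBPostProcess(thresh=0.3, box_thresh=0.7, box_type="quad")
#     # `maps` has shape [N, 1, H, W]; `shape_list` holds one
#     # [src_h, src_w, ratio_h, ratio_w] row per image, as produced by DetResizeForTest.
#     results = post_process({"maps": det_model(batch)}, shape_list)
#     boxes = results[0]["points"]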


class ClsPostProcess:
    """ Convert between text-label and text-index """

    def __init__(self, label_list=None):
        self.label_list = label_list

    def __call__(self, preds, label=None, *args, **kwargs):
        label_list = self.label_list
        if label_list is None:
            label_list = {idx: idx for idx in range(preds.shape[-1])}

        if isinstance(preds, Tensor):
            preds = preds.numpy()

        pred_ids = preds.argmax(axis=1)
        decode_out = [(label_list[idx], preds[i, idx]) for i, idx in enumerate(pred_ids)]
        if label is None:
            return decode_out
        label = [(label_list[idx], 1.0) for idx in label]
        return decode_out, label
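
# Usage sketch (illustrative only; the probabilities and label names below are
# made-up examples):
#
#     cls_post = ClsPostProcess(label_list=["0", "180"])
#     preds = np.array([[0.9, 0.1], [0.2, 0.8]], dtype="float32")
#     cls_post(preds)  # -> [("0", 0.9), ("180", 0.8)]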


class __BaseRecDecoder:
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None):
        self.beg_str = "sos"
        self.end_str = "eos"
        self.reverse = False
        self.character_str = []

        if character_dict_path is None:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz "
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, "rb") as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode("utf-8").strip("\n").strip("\r\n")
                    self.character_str.append(line)
            self.character_str.append(" ")
            dict_character = list(self.character_str)

        dict_character = self.add_special_char(dict_character)
        self.max_index = len(dict_character) - 1
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character

    @staticmethod
    def pred_reverse(pred):
        # Reverse the prediction segment by segment, keeping runs of Latin letters,
        # digits, and common symbols in their original left-to-right order.
        pred_re = []
        c_current = ""
        for c in pred:
            if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
                if c_current != "":
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ""
            else:
                c_current += c
        if c_current != "":
            pred_re.append(c_current)
        return "".join(pred_re[::-1])

    def add_special_char(self, dict_character):
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False, use_space=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            # drop repeated indices (CTC collapse) and ignored tokens
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            char_list = []
            for index in text_index[batch_idx][selection]:
                # the last dictionary entry is the space character; skip it unless requested
                if index == self.max_index and not use_space:
                    continue
                char_list.append(self.character[index])

            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]

            text = "".join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    @staticmethod
    def get_ignored_tokens():
        return [0]  # for ctc blank


class CTCLabelDecode(__BaseRecDecoder):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None):
        super(CTCLabelDecode, self).__init__(character_dict_path)

    def __call__(self, preds, use_space=False, *args, **kwargs):
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        if isinstance(preds, Tensor):
            preds = preds.numpy()
        # greedy (best-path) decoding: take the most likely class at each time step
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        return self.decode(preds_idx, preds_prob, is_remove_duplicate=True, use_space=use_space)

    def add_special_char(self, dict_character):
        # index 0 is reserved for the CTC blank token
        dict_character = ["blank"] + dict_character
        return dict_character
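
# Usage sketch (illustrative only; the index values below assume the default
# built-in dictionary "0123456789abcdefghijklmnopqrstuvwxyz " with "blank" at index 0):
#
#     decoder = CTCLabelDecode()
#     # `preds` is the recognition head output with shape [N, T, num_classes];
#     # greedy decoding collapses repeats and drops blanks, e.g. for one sequence
#     # of per-step argmax indices [11, 11, 0, 12, 0, 0, 13]:
#     decoder.decode(np.array([[11, 11, 0, 12, 0, 0, 13]]),
#                    is_remove_duplicate=True)  # -> [("abc", 1.0)]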