infer.py

import argparse
import random

import numpy as np
import torch
import torchvision
import cv2
from torchvision.ops.boxes import batched_nms

from models import build_model

def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")

    # * Backbone
    # If this is changed to resnet101, the checkpoint path in --resume below must be changed to match
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')

    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")

    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")

    # * Matcher
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")

    # * Loss coefficients
    parser.add_argument('--mask_loss_coef', default=1, type=float)
    parser.add_argument('--dice_loss_coef', default=1, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--eos_coef', default=0.1, type=float,
                        help="Relative classification weight of the no-object class")

    # Dataset parameters
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', type=str, default="coco")
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')

    # Directory of images to run detection on
    parser.add_argument('--source_dir', default='demo/images',
                        help='path of the input images')
    # Directory where detection results are saved
    parser.add_argument('--output_dir', default='demo/outputs',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cpu',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    # Checkpoint that matches the resnet50 backbone
    parser.add_argument('--resume', default='demo/weights/detr-r50-e632da11.pth',
                        help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', default=True)
    parser.add_argument('--num_workers', default=2, type=int)

    # Distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser

def init():
    args, _ = get_args_parser().parse_known_args()
    device = torch.device(args.device)
    model, _, _ = build_model(args)
    checkpoint = torch.load(args.resume, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()  # inference only: disable dropout
    return model, device

# COCO classes
Classes = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# Build the model once at import time and reuse it for every call to inference()
Model, Device = init()
ToTensor = torchvision.transforms.ToTensor()

def box_cxcywh_to_xyxy(x):
    # Convert boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    # Scale normalized boxes to absolute pixel coordinates for the given (width, height)
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

def filter_boxes(scores, boxes, confidence=0.7, apply_nms=True, iou=0.5):
    # Keep only predictions whose best class probability exceeds the confidence threshold
    keep = scores.max(-1).values > confidence
    scores, boxes = scores[keep], boxes[keep]

    if apply_nms:
        # Class-aware non-maximum suppression on the remaining boxes
        top_scores, labels = scores.max(-1)
        keep = batched_nms(boxes, top_scores, labels, iou)
        scores, boxes = scores[keep], boxes[keep]

    return scores, boxes

def plot_one_box(x, img, color=None, label=None, line_thickness=1):
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255],
                    thickness=tf, lineType=cv2.LINE_AA)

def inference(img: "np.ndarray") -> "np.ndarray":
    # Load the image as a 1xCxHxW tensor on the target device (CPU or GPU)
    imgTensor = ToTensor(img)
    imgTensor = torch.reshape(imgTensor, [-1, imgTensor.shape[0], imgTensor.shape[1], imgTensor.shape[2]])
    imgTensor = imgTensor.to(Device)

    # Model(image) returns the detections for the image
    with torch.no_grad():
        inferenceResult = Model(imgTensor)

    # Per-query class probabilities (the trailing no-object logit is dropped)
    scores = inferenceResult["pred_logits"].softmax(-1)[0, :, :-1].cpu()
    # Box coordinates, rescaled from normalized cxcywh to absolute xyxy pixels
    boxes = rescale_bboxes(inferenceResult["pred_boxes"][0].cpu(), (imgTensor.shape[3], imgTensor.shape[2]))
    scores, boxes = filter_boxes(scores, boxes)
    scores, boxes = scores.data.numpy(), boxes.data.numpy()

    # Draw the detections on the image
    for i in range(boxes.shape[0]):
        class_id = scores[i].argmax()
        label = Classes[class_id]
        confidence = scores[i].max()
        text = f"{label} {confidence:.3f}"
        plot_one_box(boxes[i], img, label=text)

    # Return the annotated image
    return img

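# --- Hedged sketch, not part of the original script: --source_dir and --output_dir are
# parsed above but never used by main(). A batch driver over those directories could look
# roughly like the helper below; run_source_dir is a hypothetical name, and it assumes
# source_dir contains image files that OpenCV can decode.
def run_source_dir(source_dir: str, output_dir: str) -> None:
    import os  # local import so this sketch stays self-contained
    os.makedirs(output_dir, exist_ok=True)
    for name in sorted(os.listdir(source_dir)):
        img = cv2.imread(os.path.join(source_dir, name))
        if img is None:
            # skip files OpenCV cannot read (e.g. non-image files)
            continue
        cv2.imwrite(os.path.join(output_dir, name), inference(img))
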
def main():
    # Detect objects in a single test image and save the annotated result
    img = cv2.imread("./123.png")
    out = inference(img)
    cv2.imwrite("./out.jpg", out)


if __name__ == "__main__":
    main()
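# Example invocation (an assumption, not stated in the original file): this script needs to
# live inside the DETR repo so that `from models import build_model` resolves, the checkpoint
# given by --resume must exist, and main() expects ./123.png next to the script:
#   python infer.py --device cpu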