- Tutorial workflow: define a custom dataset -> load a pretrained model -> modify the model's modules -> train the model -> test the model
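- Custom dataset. The dataset class consists of the three parts below. The result returned by the special method __getitem__ must contain the fields that training uses later. To see which fields are validated, check the check_targets function in roi_heads.py.
- The dataset should inherit from the standard torch.utils.data.Dataset base class;
- Implement __len__, which returns the number of images;
- Implement __getitem__, which reads the raw sample at a given index and returns the processed result (data transformation and augmentation logic lives in this function);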
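A minimal sketch of such a dataset. It assumes synthetic images rather than files on disk, and the class name ToyDetectionDataset and its parameters are hypothetical. The target keys follow the TorchVision fine-tuning tutorial. boxes is a float32 [N, 4] tensor in (x1, y1, x2, y2) pixel coordinates and labels is an int64 [N] tensor; these are the fields check_targets validates for Faster R-CNN. image_id, area and iscrowd are used by the COCO-style evaluation helpers rather than by the model itself. The transforms(img, target) call signature is an assumption modeled on the reference scripts.

```python
import torch
from torch.utils.data import Dataset


class ToyDetectionDataset(Dataset):
    """Hypothetical dataset: synthetic images, each with one labeled box."""

    def __init__(self, num_images=4, size=128, transforms=None):
        self.num_images = num_images
        self.size = size
        self.transforms = transforms  # assumed to take and return (img, target)

    def __len__(self):
        # Number of images in the dataset.
        return self.num_images

    def __getitem__(self, idx):
        # Seeded per index so the same idx always yields the same image.
        g = torch.Generator().manual_seed(idx)
        img = torch.rand(3, self.size, self.size, generator=g)  # float in [0, 1]

        # One box in (x1, y1, x2, y2) pixel coordinates, 40 x 30 pixels.
        x1, y1 = 10.0 + idx, 20.0
        boxes = torch.tensor([[x1, y1, x1 + 40.0, y1 + 30.0]], dtype=torch.float32)
        labels = torch.ones((1,), dtype=torch.int64)  # class 1; 0 is background

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
            "iscrowd": torch.zeros((1,), dtype=torch.int64),
        }
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target
```

Detection models accept a list of images that may differ in size. For that reason, batches are usually built with collate_fn=lambda batch: tuple(zip(*batch)) rather than the default collate, which stacks tensors.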
- Load a pretrained model
- Recognition (classification) pretrained models
>>> torchvision.__version__
'0.4.1'
>>> dir(torchvision.models)
['AlexNet', 'DenseNet', 'GoogLeNet', 'Inception3', 'MNASNet', 'MobileNetV2', 'ResNet', 'ShuffleNetV2', 'SqueezeNet', 'VGG', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_utils', 'alexnet', 'densenet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'detection', 'googlenet', 'inception', 'inception_v3', 'mnasnet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet', 'mobilenet_v2', 'resnet', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext50_32x4d', 'segmentation', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 'shufflenetv2', 'squeezenet', 'squeezenet1_0', 'squeezenet1_1', 'utils', 'vgg', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'video', 'wide_resnet101_2', 'wide_resnet50_2']
- Detection pretrained models
>>> dir(torchvision.models.detection)
['FasterRCNN', 'KeypointRCNN', 'MaskRCNN', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_utils', 'backbone_utils', 'faster_rcnn', 'fasterrcnn_resnet50_fpn', 'generalized_rcnn', 'image_list', 'keypoint_rcnn', 'keypointrcnn_resnet50_fpn', 'mask_rcnn', 'maskrcnn_resnet50_fpn', 'roi_heads', 'rpn', 'transform']
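- Detection framework and its modules. torchvision.models.detection.FasterRCNN is the class defined in the faster_rcnn submodule and re-exported by the detection package, so the two dir() listings below are identical.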
>>> dir(torchvision.models.detection.FasterRCNN)
['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_get_name', '_load_from_state_dict', '_named_members', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_save_to_state_dict', '_slow_forward', '_tracing_name', '_version', 'add_module', 'apply', 'buffers', 'children', 'cpu', 'cuda', 'double', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'half', 'load_state_dict', 'modules', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_parameter', 'requires_grad_', 'share_memory', 'state_dict', 'to', 'train', 'type', 'zero_grad']
>>> dir(torchvision.models.detection.faster_rcnn)
['AnchorGenerator', 'F', 'FastRCNNPredictor', 'FasterRCNN', 'GeneralizedRCNN', 'GeneralizedRCNNTransform', 'MultiScaleRoIAlign', 'OrderedDict', 'RPNHead', 'RegionProposalNetwork', 'RoIHeads', 'TwoMLPHead', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'fasterrcnn_resnet50_fpn', 'load_state_dict_from_url', 'misc_nn_ops', 'model_urls', 'nn', 'resnet_fpn_backbone', 'torch']
>>> dir(torchvision.models.detection.faster_rcnn.FasterRCNN)
['__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__gt__', '__hash__', '__init__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_get_name', '_load_from_state_dict', '_named_members', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_save_to_state_dict', '_slow_forward', '_tracing_name', '_version', 'add_module', 'apply', 'buffers', 'children', 'cpu', 'cuda', 'double', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'half', 'load_state_dict', 'modules', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_parameter', 'requires_grad_', 'share_memory', 'state_dict', 'to', 'train', 'type', 'zero_grad']
>>> dir(torchvision.models.detection.faster_rcnn.FasterRCNN)==dir(torchvision.models.detection.FasterRCNN)
True
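- Modify the modules the model depends on. The modules Faster R-CNN is built from can be seen in faster_rcnn.py, linked in the references below. torchvision provides the following modules in torchvision.ops.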
>>> dir(torchvision.ops)
['FeaturePyramidNetwork', 'MultiScaleRoIAlign', 'RoIAlign', 'RoIPool', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_utils', 'box_iou', 'boxes', 'feature_pyramid_network', 'misc', 'nms', 'poolers', 'roi_align', 'roi_pool']
- Train and test the model
PS: the helper files that the training code depends on are in the references/detection directory linked below.
References:
- TorchVision Object Detection Finetuning Tutorial
- https://github.com/pytorch/vision/blob/master/torchvision/models/detection/faster_rcnn.py
- https://github.com/pytorch/vision/tree/master/references/detection