diff --git a/projects/LSKNet/README.md b/projects/LSKNet/README.md new file mode 100644 index 000000000..ada0a67e0 --- /dev/null +++ b/projects/LSKNet/README.md @@ -0,0 +1,124 @@ +# LSKNet + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/large-selective-kernel-network-for-remote/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=large-selective-kernel-network-for-remote) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/large-selective-kernel-network-for-remote/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=large-selective-kernel-network-for-remote) + +## Abstract + +Recent research on remote sensing object detection has largely focused on improving the representation of oriented bounding boxes but has overlooked the unique prior knowledge presented in remote sensing scenarios. Such prior knowledge can be useful because tiny remote sensing objects may be mistakenly detected without referencing a sufficiently long-range context, and the long-range context required by different types of objects can vary. In this paper, we take these priors into account and propose the Large Selective Kernel Network (LSKNet). LSKNet can dynamically adjust its large spatial receptive field to better model the ranging context of various objects in remote sensing scenarios. To the best of our knowledge, this is the first time that large and selective kernel mechanisms have been explored in the field of remote sensing object detection. Without bells and whistles, LSKNet sets new state-of-the-art scores on standard benchmarks, i.e., HRSC2016 (98.46% mAP), DOTA-v1.0 (81.85% mAP) and FAIR1M-v1.0 (47.87% mAP). Based on a similar technique, we rank 2nd place in 2022 the Greater Bay Area International Algorithm Competition + +## Description + +Author: @Yuxuan Li. +This project is an implementation of "Large Selective Kernel Network for Remote Sensing Object Detection" at: [https://arxiv.org/pdf/2303.09030.pdf](https://arxiv.org/pdf/2303.09030.pdf) + +## Usage + +### Training commands + +In MMRotate's root directory, run the following command to train the model: + +```bash +python tools/train.py projects/LSKNet/configs/lsk_t_fpn_1x_dota_le90.py +``` + +### Testing commands + +In MMRotate's root directory, run the following command to test the model: + +```bash +python tools/test.py projects/LSKNet/configs/lsk_t_fpn_1x_dota_le90.py ${CHECKPOINT_PATH} +``` + +## Results + +Imagenet 300-epoch pre-trained LSKNet-T backbone: [Download](https://download.openmmlab.com/mmrotate/v1.0/lsknet/backbones/lsk_t_backbone-2ef8a593.pth) + +Imagenet 300-epoch pre-trained LSKNet-S backbone: [Download](https://download.openmmlab.com/mmrotate/v1.0/lsknet/backbones/lsk_s_backbone-e9d2e551.pth) + +DOTA1.0 + +| Model | mAP | Angle | lr schd | Batch Size | Configs | Download | note | +| :--------------------------------------------------------: | :---: | :---: | :-----: | :--------: | :--------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------: | +| [RTMDet-l](https://arxiv.org/abs/2212.07784) (1024,1024,-) | 81.33 | - | 3x-ema | 8 | - | - | Prev. Best | +| LSKNet_T (1024,1024,200) | 81.37 | le90 | 1x | 2\*8 | [lsk_t_fpn_1x_dota_le90](./configs/lsknet/lsk_t_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v1.0/lsknet/lsk_t_fpn_1x_dota_le90/lsk_t_fpn_1x_dota_le90_20230206-3ccee254.pth) \| [log](https://download.openmmlab.com/mmrotate/v1.0/lsknet/lsk_t_fpn_1x_dota_le90/lsk_t_fpn_1x_dota_le90_20230206.log) | | +| LSKNet_S (1024,1024,200) | 81.64 | le90 | 1x | 1\*8 | [lsk_s_fpn_1x_dota_le90](./configs/lsknet/lsk_s_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v1.0/lsknet/lsk_s_fpn_1x_dota_le90/lsk_s_fpn_1x_dota_le90_20230116-99749191.pth) \| [log](https://download.openmmlab.com/mmrotate/v1.0/lsknet/lsk_s_fpn_1x_dota_le90/lsk_s_fpn_1x_dota_le90_20230116.log) | | +| LSKNet_S\* (1024,1024,200) | 81.85 | le90 | 1x | 1\*8 | [lsk_s_ema_fpn_1x_dota_le90](./configs/lsknet/lsk_s_ema_fpn_1x_dota_le90.py) | [model](https://download.openmmlab.com/mmrotate/v1.0/lsknet/lsk_s_ema_fpn_1x_dota_le90/lsk_s_ema_fpn_1x_dota_le90_20230212-30ed4041.pth) \| [log](https://download.openmmlab.com/mmrotate/v1.0/lsknet/lsk_s_ema_fpn_1x_dota_le90/lsk_s_ema_fpn_1x_dota_le90_20230212.log) | EMA Finetune | + + + +HRSC2016 + +| Model | mAP(07) | mAP(12) | Angle | lr schd | Batch Size | Configs | Download | note | +| :------------------------------------------: | :-----: | :-----: | :---: | :-----: | :--------: | :-------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------: | +| [RTMDet-l](https://arxiv.org/abs/2212.07784) | 90.60 | 97.10 | le90 | 3x | - | - | - | Prev. Best | +| [ReDet](https://arxiv.org/abs/2103.07733) | 90.46 | 97.63 | le90 | 3x | 2\*4 | [redet_re50_refpn_3x_hrsc_le90](./configs/redet/redet_re50_refpn_3x_hrsc_le90.py) | - | Prev. Best | +| LSKNet_S | 90.65 | 98.46 | le90 | 3x | 1\*8 | [lsk_s_fpn_3x_hrsc_le90](./configs/lsknet/lsk_s_fpn_3x_hrsc_le90.py) | [model](https://download.openmmlab.com/mmrotate/v1.0/lsknet/lsk_s_fpn_3x_hrsc_le90/lsk_s_fpn_3x_hrsc_le90_20230205-4a4a39ce.pth) \| [log](https://download.openmmlab.com/mmrotate/v1.0/lsknet/lsk_s_fpn_3x_hrsc_le90/lsk_s_fpn_3x_hrsc_le90_20230205-4a4a39ce.pth) | | + +## Citation + +If you use this toolbox or benchmark in your research, please cite this project. + +```bibtex +@article{li2023large, + title = {Large Selective Kernel Network for Remote Sensing Object Detection}, + author = {Li, Yuxuan and Hou, Qibin and Zheng, Zhaohui and Cheng, Mingming and Yang, Jian and Li, Xiang}, + journal={ArXiv}, + year={2023} +} +``` + +## Checklist + + + +- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [x] Finish the code + + + + - [x] Basic docstrings & proper citation + + + + - [x] Test-time correctness + + + + - [x] A full README + + + +- [ ] Milestone 2: Indicates a successful model implementation. + + - [ ] Training-time correctness + + + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + + + + - [ ] Unit tests + + + + - [ ] Code polishing + + + + - [ ] Metafile.yml + + + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + + + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/LSKNet/configs/lsk_s_ema_fpn_1x_dota_le90.py b/projects/LSKNet/configs/lsk_s_ema_fpn_1x_dota_le90.py new file mode 100644 index 000000000..bb24e8ee9 --- /dev/null +++ b/projects/LSKNet/configs/lsk_s_ema_fpn_1x_dota_le90.py @@ -0,0 +1,153 @@ +_base_ = [ + 'mmrotate::_base_/datasets/dota_ms.py', + 'mmrotate::_base_/schedules/schedule_1x.py', + 'mmrotate::_base_/default_runtime.py' +] + +custom_imports = dict(imports=['projects.LSKNet.lsknet']) + +angle_version = 'le90' +model = dict( + type='mmdet.FasterRCNN', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + boxtype2tensor=False), + backbone=dict( + type='LSKNet', + embed_dims=[64, 128, 320, 512], + drop_rate=0.1, + drop_path_rate=0.1, + depths=[2, 2, 4, 2], + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmrotate/v1.0/lsknet/\ +backbones/lsk_s_backbone-e9d2e551.pth'), + norm_cfg=dict(type='SyncBN', requires_grad=True)), + neck=dict( + type='FPN', + in_channels=[64, 128, 320, 512], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='OrientedRPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='mmdet.AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + use_box_type=True), + bbox_coder=dict( + type='MidpointOffsetCoder', + angle_version=angle_version, + target_means=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0, 0.5, 0.5]), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', + beta=0.1111111111111111, + loss_weight=1.0)), + roi_head=dict( + type='mmdet.StandardRoIHead', + bbox_roi_extractor=dict( + type='RotatedSingleRoIExtractor', + roi_layer=dict( + type='RoIAlignRotated', + out_size=7, + sample_num=2, + clockwise=True), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='mmdet.Shared2FCBBoxHead', + predict_box_type='rbox', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=15, + reg_predictor_cfg=dict(type='mmdet.Linear'), + cls_predictor_cfg=dict(type='mmdet.Linear'), + bbox_coder=dict( + type='DeltaXYWHTRBBoxCoder', + angle_version=angle_version, + norm_factor=None, + edge_swap=True, + proj_xy=True, + target_means=(.0, .0, .0, .0, .0), + target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)), + reg_class_agnostic=True, + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', beta=1.0, loss_weight=1.0))), + train_cfg=dict( + rpn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBbox2HBboxOverlaps2D')), + sampler=dict( + type='mmdet.RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + iou_calculator=dict(type='RBboxOverlaps2D'), + ignore_iof_thr=-1), + sampler=dict( + type='mmdet.RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000))) + +optim_wrapper = dict( + optimizer=dict( + _delete_=True, + type='AdamW', + lr=0.0002, + betas=(0.9, 0.999), + weight_decay=0.05)) + +custom_hooks = [dict(type='EMAHook')] diff --git a/projects/LSKNet/configs/lsk_s_fpn_1x_dota_le90.py b/projects/LSKNet/configs/lsk_s_fpn_1x_dota_le90.py new file mode 100644 index 000000000..2330c096d --- /dev/null +++ b/projects/LSKNet/configs/lsk_s_fpn_1x_dota_le90.py @@ -0,0 +1,151 @@ +_base_ = [ + 'mmrotate::_base_/datasets/dota_ms.py', + 'mmrotate::_base_/schedules/schedule_1x.py', + 'mmrotate::_base_/default_runtime.py' +] + +custom_imports = dict(imports=['projects.LSKNet.lsknet']) + +angle_version = 'le90' +model = dict( + type='mmdet.FasterRCNN', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + boxtype2tensor=False), + backbone=dict( + type='LSKNet', + embed_dims=[64, 128, 320, 512], + drop_rate=0.1, + drop_path_rate=0.1, + depths=[2, 2, 4, 2], + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmrotate/v1.0/lsknet/\ +backbones/lsk_s_backbone-e9d2e551.pth'), + norm_cfg=dict(type='SyncBN', requires_grad=True)), + neck=dict( + type='FPN', + in_channels=[64, 128, 320, 512], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='OrientedRPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='mmdet.AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + use_box_type=True), + bbox_coder=dict( + type='MidpointOffsetCoder', + angle_version=angle_version, + target_means=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0, 0.5, 0.5]), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', + beta=0.1111111111111111, + loss_weight=1.0)), + roi_head=dict( + type='mmdet.StandardRoIHead', + bbox_roi_extractor=dict( + type='RotatedSingleRoIExtractor', + roi_layer=dict( + type='RoIAlignRotated', + out_size=7, + sample_num=2, + clockwise=True), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='mmdet.Shared2FCBBoxHead', + predict_box_type='rbox', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=15, + reg_predictor_cfg=dict(type='mmdet.Linear'), + cls_predictor_cfg=dict(type='mmdet.Linear'), + bbox_coder=dict( + type='DeltaXYWHTRBBoxCoder', + angle_version=angle_version, + norm_factor=None, + edge_swap=True, + proj_xy=True, + target_means=(.0, .0, .0, .0, .0), + target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)), + reg_class_agnostic=True, + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', beta=1.0, loss_weight=1.0))), + train_cfg=dict( + rpn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBbox2HBboxOverlaps2D')), + sampler=dict( + type='mmdet.RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + iou_calculator=dict(type='RBboxOverlaps2D'), + ignore_iof_thr=-1), + sampler=dict( + type='mmdet.RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000))) + +optim_wrapper = dict( + optimizer=dict( + _delete_=True, + type='AdamW', + lr=0.0002, + betas=(0.9, 0.999), + weight_decay=0.05)) diff --git a/projects/LSKNet/configs/lsk_s_fpn_3x_hrsc_le90.py b/projects/LSKNet/configs/lsk_s_fpn_3x_hrsc_le90.py new file mode 100644 index 000000000..e818d4aa9 --- /dev/null +++ b/projects/LSKNet/configs/lsk_s_fpn_3x_hrsc_le90.py @@ -0,0 +1,151 @@ +_base_ = [ + 'mmrotate::_base_/datasets/hrsc.py', + 'mmrotate::_base_/schedules/schedule_3x.py', + 'mmrotate::_base_/default_runtime.py' +] + +custom_imports = dict(imports=['projects.LSKNet.lsknet']) + +angle_version = 'le90' +model = dict( + type='mmdet.FasterRCNN', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + boxtype2tensor=False), + backbone=dict( + type='LSKNet', + embed_dims=[64, 128, 320, 512], + drop_rate=0.1, + drop_path_rate=0.1, + depths=[2, 2, 4, 2], + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmrotate/v1.0/lsknet/\ +backbones/lsk_s_backbone-e9d2e551.pth'), + norm_cfg=dict(type='SyncBN', requires_grad=True)), + neck=dict( + type='FPN', + in_channels=[64, 128, 320, 512], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='OrientedRPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='mmdet.AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + use_box_type=True), + bbox_coder=dict( + type='MidpointOffsetCoder', + angle_version=angle_version, + target_means=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0, 0.5, 0.5]), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', + beta=0.1111111111111111, + loss_weight=1.0)), + roi_head=dict( + type='mmdet.StandardRoIHead', + bbox_roi_extractor=dict( + type='RotatedSingleRoIExtractor', + roi_layer=dict( + type='RoIAlignRotated', + out_size=7, + sample_num=2, + clockwise=True), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='mmdet.Shared2FCBBoxHead', + predict_box_type='rbox', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=1, + reg_predictor_cfg=dict(type='mmdet.Linear'), + cls_predictor_cfg=dict(type='mmdet.Linear'), + bbox_coder=dict( + type='DeltaXYWHTRBBoxCoder', + angle_version=angle_version, + norm_factor=None, + edge_swap=True, + proj_xy=True, + target_means=(.0, .0, .0, .0, .0), + target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)), + reg_class_agnostic=True, + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', beta=1.0, loss_weight=1.0))), + train_cfg=dict( + rpn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBbox2HBboxOverlaps2D')), + sampler=dict( + type='mmdet.RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + iou_calculator=dict(type='RBboxOverlaps2D'), + ignore_iof_thr=-1), + sampler=dict( + type='mmdet.RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000))) + +optim_wrapper = dict( + optimizer=dict( + _delete_=True, + type='AdamW', + lr=0.0002, + betas=(0.9, 0.999), + weight_decay=0.05)) diff --git a/projects/LSKNet/configs/lsk_t_fpn_1x_dota_le90.py b/projects/LSKNet/configs/lsk_t_fpn_1x_dota_le90.py new file mode 100644 index 000000000..5291971aa --- /dev/null +++ b/projects/LSKNet/configs/lsk_t_fpn_1x_dota_le90.py @@ -0,0 +1,151 @@ +_base_ = [ + 'mmrotate::_base_/datasets/dota_ms.py', + 'mmrotate::_base_/schedules/schedule_1x.py', + 'mmrotate::_base_/default_runtime.py' +] + +custom_imports = dict(imports=['projects.LSKNet.lsknet']) + +angle_version = 'le90' +model = dict( + type='mmdet.FasterRCNN', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + boxtype2tensor=False), + backbone=dict( + type='LSKNet', + embed_dims=[32, 64, 160, 256], + drop_rate=0.1, + drop_path_rate=0.1, + depths=[3, 3, 5, 2], + init_cfg=dict( + type='Pretrained', + checkpoint='https://download.openmmlab.com/mmrotate/v1.0/lsknet/\ +backbones/lsk_t_backbone-2ef8a593.pth'), + norm_cfg=dict(type='SyncBN', requires_grad=True)), + neck=dict( + type='FPN', + in_channels=[64, 128, 320, 512], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='OrientedRPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='mmdet.AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + use_box_type=True), + bbox_coder=dict( + type='MidpointOffsetCoder', + angle_version=angle_version, + target_means=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0, 0.5, 0.5]), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', + beta=0.1111111111111111, + loss_weight=1.0)), + roi_head=dict( + type='mmdet.StandardRoIHead', + bbox_roi_extractor=dict( + type='RotatedSingleRoIExtractor', + roi_layer=dict( + type='RoIAlignRotated', + out_size=7, + sample_num=2, + clockwise=True), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='mmdet.Shared2FCBBoxHead', + predict_box_type='rbox', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=15, + reg_predictor_cfg=dict(type='mmdet.Linear'), + cls_predictor_cfg=dict(type='mmdet.Linear'), + bbox_coder=dict( + type='DeltaXYWHTRBBoxCoder', + angle_version=angle_version, + norm_factor=None, + edge_swap=True, + proj_xy=True, + target_means=(.0, .0, .0, .0, .0), + target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)), + reg_class_agnostic=True, + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', beta=1.0, loss_weight=1.0))), + train_cfg=dict( + rpn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBbox2HBboxOverlaps2D')), + sampler=dict( + type='mmdet.RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + iou_calculator=dict(type='RBboxOverlaps2D'), + ignore_iof_thr=-1), + sampler=dict( + type='mmdet.RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000))) + +optim_wrapper = dict( + optimizer=dict( + _delete_=True, + type='AdamW', + lr=0.0002, + betas=(0.9, 0.999), + weight_decay=0.05)) diff --git a/projects/LSKNet/lsknet/__init__.py b/projects/LSKNet/lsknet/__init__.py new file mode 100644 index 000000000..f69267803 --- /dev/null +++ b/projects/LSKNet/lsknet/__init__.py @@ -0,0 +1,3 @@ +from .lsknet import LSKNet + +__all__ = ['LSKNet'] diff --git a/projects/LSKNet/lsknet/lsknet.py b/projects/LSKNet/lsknet/lsknet.py new file mode 100644 index 000000000..655b98b28 --- /dev/null +++ b/projects/LSKNet/lsknet/lsknet.py @@ -0,0 +1,377 @@ +import math +import warnings +from functools import partial + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.utils.weight_init import (constant_init, normal_init, + trunc_normal_init) +from mmcv.runner import BaseModule +from torch.nn.modules.utils import _pair as to_2tuple + +from mmrotate.registry import MODELS + + +class Mlp(BaseModule): + """An implementation of Mlp of LSKNet. + + Refer to + mmclassification/mmcls/models/backbones/van.py. + Args: + in_features (int): The feature dimension. Same as + `MultiheadAttention`. + hidden_features (int): The hidden dimension of Mlps. + act_cfg (dict, optional): The activation config for Mlps. + Default: dict(type='GELU'). + drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0., + init_cfg=None): + super(Mlp, self).__init__(init_cfg=init_cfg) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class LSKmodule(BaseModule): + """LSK module(LSK) of LSKNet. + + Args: + dim (int): Number of input channels. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, dim, init_cfg=None): + super(LSKmodule, self).__init__(init_cfg=init_cfg) + self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) + self.conv_spatial = nn.Conv2d( + dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3) + self.conv1 = nn.Conv2d(dim, dim // 2, 1) + self.conv2 = nn.Conv2d(dim, dim // 2, 1) + self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3) + self.conv = nn.Conv2d(dim // 2, dim, 1) + + def forward(self, x): + attn1 = self.conv0(x) + attn2 = self.conv_spatial(attn1) + + attn1 = self.conv1(attn1) + attn2 = self.conv2(attn2) + + attn = torch.cat([attn1, attn2], dim=1) + avg_attn = torch.mean(attn, dim=1, keepdim=True) + max_attn, _ = torch.max(attn, dim=1, keepdim=True) + agg = torch.cat([avg_attn, max_attn], dim=1) + sig = self.conv_squeeze(agg).sigmoid() + attn = attn1 * sig[:, 0, :, :].unsqueeze( + 1) + attn2 * sig[:, 1, :, :].unsqueeze(1) + attn = self.conv(attn) + return x * attn + + +class Attention(BaseModule): + """Basic attention module in LSKblock. + + Args: + d_model (int): Number of input channels. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, d_model, init_cfg=None): + super(Attention, self).__init__(init_cfg=init_cfg) + + self.proj_1 = nn.Conv2d(d_model, d_model, 1) + self.activation = nn.GELU() + self.spatial_gating_unit = LSKmodule(d_model) + self.proj_2 = nn.Conv2d(d_model, d_model, 1) + + def forward(self, x): + shorcut = x.clone() + x = self.proj_1(x) + x = self.activation(x) + x = self.spatial_gating_unit(x) + x = self.proj_2(x) + x = x + shorcut + return x + + +class Block(BaseModule): + """A block of LSK. + + Args: + dim (int): Number of input channels. + mlp_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop (float): Dropout rate after embedding. Defaults to 0. + drop_path (float): Stochastic depth rate. Defaults to 0.1. + act_layer (dict, optional): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (obj:`mmcv.ConfigDict`): The Config for normalization. + Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + dim, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_cfg=None, + init_cfg=None): + super(Block, self).__init__(init_cfg=init_cfg) + if norm_cfg: + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + else: + self.norm1 = nn.BatchNorm2d(dim) + self.norm2 = nn.BatchNorm2d(dim) + self.attn = Attention(dim) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + layer_scale_init_value = 1e-2 + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.attn(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + return x + + +class OverlapPatchEmbed(BaseModule): + """Image to Patch Embedding of LSK. + + Args: + patch_size (int): OverlapPatchEmbed patch size. Defaults to 7 + stride (int): OverlapPatchEmbed stride. Defaults to 4 + in_chans (int): Number of input channels. Defaults to 3. + embed_dim (int): The hidden dimension of OverlapPatchEmbed. + norm_cfg (obj:`mmcv.ConfigDict`): The Config for normalization. + Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + patch_size=7, + stride=4, + in_chans=3, + embed_dim=768, + norm_cfg=None, + init_cfg=None): + super(OverlapPatchEmbed, self).__init__(init_cfg=init_cfg) + patch_size = to_2tuple(patch_size) + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2)) + if norm_cfg: + self.norm = build_norm_layer(norm_cfg, embed_dim)[1] + else: + self.norm = nn.BatchNorm2d(embed_dim) + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + Returns: + tuple: Contains merged results and its spatial shape. + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - H (list[int]): Height shape of x + - W (list[int]): Weight shape of x + """ + x = self.proj(x) + _, _, H, W = x.shape + x = self.norm(x) + return x, H, W + + +@MODELS.register_module() +class LSKNet(BaseModule): + """Large Selective Kernel Network. + + A PyTorch implement of : `Large Selective Kernel Network for + Remote Sensing Object Detection.` + PDF: https://arxiv.org/pdf/2303.09030.pdf + Inspiration from + https://github.com/zcablii/Large-Selective-Kernel-Network + Args: + in_chans (int): The num of input channels. Defaults to 3. + embed_dims (List[int]): Embedding channels of each LSK block. + Defaults to [64, 128, 256, 512] + mlp_ratios (List[int]): Mlp ratios. Defaults to [8, 8, 4, 4] + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + depths (List[int]): Number of LSK block in each stage. + Defaults to [3, 4, 6, 3] + num_stages (int): Number of stages. Defaults to 4 + pretrained (bool): If the model weight is pretrained. Defaults to None, + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to None. + """ + + def __init__(self, + in_chans=3, + embed_dims=[64, 128, 256, 512], + mlp_ratios=[8, 8, 4, 4], + drop_rate=0., + drop_path_rate=0., + depths=[3, 4, 6, 3], + num_stages=4, + pretrained=None, + init_cfg=None, + norm_cfg=None): + super(LSKNet, self).__init__(init_cfg=init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + self.depths = depths + self.num_stages = num_stages + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + + for i in range(num_stages): + patch_embed = OverlapPatchEmbed( + patch_size=7 if i == 0 else 3, + stride=4 if i == 0 else 2, + in_chans=in_chans if i == 0 else embed_dims[i - 1], + embed_dim=embed_dims[i], + norm_cfg=norm_cfg) + + block = nn.ModuleList([ + Block( + dim=embed_dims[i], + mlp_ratio=mlp_ratios[i], + drop=drop_rate, + drop_path=dpr[cur + j], + norm_cfg=norm_cfg) for j in range(depths[i]) + ]) + norm_layer = partial(nn.LayerNorm, eps=1e-6) + norm = norm_layer(embed_dims[i]) + cur += depths[i] + + setattr(self, f'patch_embed{i + 1}', patch_embed) + setattr(self, f'block{i + 1}', block) + setattr(self, f'norm{i + 1}', norm) + + def init_weights(self): + print('init cfg', self.init_cfg) + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super(LSKNet, self).init_weights() + + def freeze_patch_emb(self): + self.patch_embed1.requires_grad = False + + @torch.jit.ignore + def no_weight_decay(self): + return { + 'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token' + } + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear( + self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + outs = [] + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + block = getattr(self, f'block{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, H, W = patch_embed(x) + for blk in block: + x = blk(x) + x = x.flatten(2).transpose(1, 2) + x = norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + return outs + + def forward(self, x): + x = self.forward_features(x) + return x + + +class DWConv(nn.Module): + """Depth-wise convolution + Args: + dim (int): In/out channel of the Depth-wise convolution. + """ + + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x): + x = self.dwconv(x) + return x