When using orient-rcnn in JDet, sampling in the sampler fails with an exception:

```python
num_expected_pos = int(self.num * self.pos_fraction)
pos_inds = self._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()  # this line raises the error
num_sampled_pos = pos_inds.numel()
num_expected_neg = self.num - num_sampled_pos
```
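
The positive/negative split in this excerpt is plain arithmetic, restated below as standalone Python. `sample_budget` is a hypothetical helper, it assumes `unique()` removes no duplicates, and `num=256` / `pos_fraction=0.5` are assumed example values, not settings taken from this JDet config.

```python
def sample_budget(num, pos_fraction, num_pos_available):
    """Split a fixed sampling budget into positives and negatives.

    Restates the excerpt: positives are capped at int(num * pos_fraction),
    and the expected number of negatives is whatever remains of num.
    """
    num_expected_pos = int(num * pos_fraction)
    num_sampled_pos = min(num_pos_available, num_expected_pos)
    num_expected_neg = num - num_sampled_pos
    return num_sampled_pos, num_expected_neg


# num=256 and pos_fraction=0.5 are assumed example values.
print(sample_budget(256, 0.5, 5))    # (5, 251)
print(sample_budget(256, 0.5, 290))  # (128, 128)
print(sample_budget(256, 0.5, 0))    # (0, 256): no positives at all
```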

On inspection, the problem appears to come from `_sample_pos()`, the function that produces `pos_inds`:

```python
def _sample_pos(self, assign_result, num_expected, **kwargs):
    """Randomly sample some positive samples."""
    pos_inds = jt.nonzero(assign_result.gt_inds > 0)  # this call occasionally returns wrong inds
    print(pos_inds.size())  # prints the sizes listed below
    if pos_inds.numel() != 0:
        pos_inds = pos_inds.squeeze(1)
    if pos_inds.numel() <= num_expected:
        return pos_inds
    else:
        return self.random_choice(pos_inds, num_expected)
```
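
As the printed sizes in this report show, `jt.nonzero` returns an (N, 1) index tensor for the 1-D mask `assign_result.gt_inds > 0`: one row per matching entry, one column per input dimension. N counts the anchors with `gt_inds > 0`, so an image where no anchor is positive yields (0, 1), and since `numel()` is then 0 the `squeeze(1)` branch is skipped. The sketch below uses NumPy's `np.argwhere`, which follows the same (N, ndim) convention, as a stand-in for Jittor; `sample_pos` is a hypothetical helper, not JDet code, that flattens the result so both cases come back 1-D.

```python
import numpy as np


def sample_pos(gt_inds, num_expected, rng=None):
    """Hypothetical NumPy stand-in for _sample_pos.

    np.argwhere returns an (N, ndim) array, so a 1-D gt_inds gives
    shape (N, 1), with N == 0 when no entry is positive. reshape(-1)
    turns both cases into a flat 1-D index array.
    """
    rng = np.random.default_rng() if rng is None else rng
    pos_inds = np.argwhere(gt_inds > 0).reshape(-1)
    if pos_inds.size <= num_expected:
        return pos_inds
    # Random subset without replacement, analogous to random_choice.
    return rng.choice(pos_inds, size=num_expected, replace=False)


no_pos = np.array([0, 0, -1, 0])      # no entry is > 0
print(np.argwhere(no_pos > 0).shape)  # (0, 1)
print(sample_pos(no_pos, 8).shape)    # (0,)
```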

Below are the `size()` values of the returned `pos_inds`, followed by the traceback:

```
[16,1,]
[20,1,]
[45,1,]
[6,1,]
[11,1,]
[15,1,]
[15,1,]
[4,1,]
[11,1,]
[27,1,]
[290,1,]
[8,1,]
[65,1,]
[18,1,]
[5,1,]
[23,1,]
[2,1,]
[5,1,]
[50,1,]
[2,1,]
[31,1,]
[1,1,]
[0,1,]   # this returned result is the problem
Traceback (most recent call last):
  File "tools/run_net.py", line 56, in <module>
    main()
  File "tools/run_net.py", line 47, in main
    runner.run()
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jdet-0.2.0.0-py3.8.egg/jdet/runner/runner.py", line 84, in run
    self.train()
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jdet-0.2.0.0-py3.8.egg/jdet/runner/runner.py", line 126, in train
    losses = self.model(images,targets)
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jittor/__init__.py", line 951, in __call__
    return self.execute(*args, **kw)
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jdet-0.2.0.0-py3.8.egg/jdet/models/networks/rcnn.py", line 39, in execute
    proposals_list, rpn_losses = self.rpn(features,targets)
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jittor/__init__.py", line 951, in __call__
    return self.execute(*args, **kw)
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jdet-0.2.0.0-py3.8.egg/jdet/models/roi_heads/oriented_rpn_head.py", line 487, in execute
    losses = self.loss(*outs,targets)
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jdet-0.2.0.0-py3.8.egg/jdet/models/roi_heads/oriented_rpn_head.py", line 457, in loss
    labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,num_total_pos, num_total_neg = self.get_targets(anchor_list, valid_flag_list, targets)
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jdet-0.2.0.0-py3.8.egg/jdet/models/roi_heads/oriented_rpn_head.py", line 376, in get_targets
    all_bbox_weights, pos_inds_list, neg_inds_list, sampling_results_list) = multi_apply(self._get_targets_single, anchor_list, valid_flag_list, targets)
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jdet-0.2.0.0-py3.8.egg/jdet/utils/general.py", line 53, in multi_apply
    return tuple(map(list, zip(*map_results)))
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jdet-0.2.0.0-py3.8.egg/jdet/models/roi_heads/oriented_rpn_head.py", line 308, in _get_targets_single
    sampling_result = self.sampler.sample(assign_result, anchors, target_bboxes)
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jdet-0.2.0.0-py3.8.egg/jdet/models/boxes/sampler.py", line 96, in sample
    pos_inds = pos_inds.unique()
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jittor/misc.py", line 573, in unique
    input_flatten = input_flatten.view(orig_shape[0], -1)
  File "/home/yuxk/.conda/envs/pytorch/lib/python3.8/site-packages/jittor/__init__.py", line 553, in reshape
    return origin_reshape(x, shape)
RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.ops.reshape)).

Types of your inputs are:
 self   = module,
 args   = (Var, tuple, ),

The function declarations are:
 VarHolder* reshape(VarHolder* x,  NanoVector shape)

Failed reason:[f 0919 21:23:05.270914 56 reshape_op.cc:47] Check failed: y_items != 0 && x_items % y_items == 0  reshape shape is invalid for input of size  0
```

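So the faulty `pos_inds` return value is what triggers the later failure. What I can't figure out is why a `pos_inds` of size (0, 1) comes back at all: why would the first dimension be 0?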

Thank you very much for your feedback.
In earlier versions of Jittor, the `unique` operator failed when given a tensor of size 0. We have published an updated Jittor on pip;
run `python -m pip install jittor -U` to upgrade.
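
The traceback places the failure inside Jittor's `unique`, which calls `view(orig_shape[0], -1)`. With `orig_shape[0] == 0` the `-1` cannot be inferred, which is consistent with the `y_items != 0` check in the error message. On a Jittor version without the fix, one possible workaround is to skip `unique()` when the index tensor is empty (in Jittor terms, only calling it when `pos_inds.numel() > 0`). The sketch below demonstrates both points with NumPy as a stand-in for Jittor; `unique_or_empty` is a hypothetical name, not a JDet or Jittor function.

```python
import numpy as np


def unique_or_empty(inds):
    """Hypothetical guard: only call unique() on a non-empty index array."""
    flat = inds.reshape(-1)  # (N, 1) -> (N,), and (0, 1) -> (0,)
    if flat.size == 0:
        return flat          # nothing to deduplicate, so skip unique()
    return np.unique(flat)   # sorted, duplicates removed


empty = np.zeros((0, 1), dtype=np.int64)  # same shape as the failing pos_inds

# The reshape pattern from the traceback, view(orig_shape[0], -1), cannot
# infer the -1 dimension when orig_shape[0] == 0; NumPy rejects it too.
try:
    empty.reshape(empty.shape[0], -1)
except ValueError as err:
    print("reshape failed:", err)

print(unique_or_empty(empty).shape)                # (0,)
print(unique_or_empty(np.array([[3], [1], [3]])))  # [1 3]
```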

Thanks for the reply. The bug where sampling returned empty data is resolved: it was caused by the threshold settings.
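
The thread does not say which threshold was involved. Assuming it was a positive-IoU threshold in the anchor assigner, the sketch below shows how a threshold above every anchor's best IoU leaves no positive indices, which is the (0, 1) case in the log. It is a simplified, hypothetical max-IoU rule in plain Python: `positive_indices` is not a JDet function, the IoU values are made up, and real assigners apply extra rules (for example, promoting each ground-truth box's best-matching anchor).

```python
def positive_indices(max_ious, pos_iou_thr):
    """Indices of anchors whose best IoU with any gt box reaches the threshold.

    Simplified max-IoU rule for illustration only.
    """
    return [i for i, iou in enumerate(max_ious) if iou >= pos_iou_thr]


ious = [0.12, 0.45, 0.66, 0.31]     # made-up per-anchor max IoUs for one image
print(positive_indices(ious, 0.5))  # [2]
print(positive_indices(ious, 0.7))  # []: no positives, like the (0, 1) size
```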