- 错误信息
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_4502/1038824653.py in <module>
80 print("output_edge.dtype=========",output_edge.dtype)
81 print("edge.dtype============",edge.dtype)
---> 82 loss_edge = loss_Focalloss(output_edge, edge) * exp_args.edgeRatio
83
84
/tmp/ipykernel_4502/1038824653.py in loss_Focalloss(pred, label, gamma)
65 one = jt.array([1.], dtype="float32")
66 fg_label = jt.greater_equal(label, one)
---> 67 fg_num = jt.sum(jt.cast(fg_label, dtype="float32"))
68 loss_focal = sigmoid_focal_loss(pred, label, weight=fg_num)
69 return loss_focal
RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.ops.unary)).
Types of your inputs are:
self = module,
args = (Var, ),
kwargs = {dtype=str, },
The function declarations are:
VarHolder* unary(VarHolder* x, NanoString op)
Failed reason:[f 0908 20:15:00.939677 12 pyjt_jit_op_maker.cc:22828] Not a valid keyword: dtype
- 错误代码
def loss_Focalloss(pred, label, gamma=2.0):
    """Flatten a segmentation prediction/label pair and apply ``sigmoid_focal_loss``.

    Args:
        pred: prediction Var; four dimensions (N, C, H, W) are read from
            ``pred.shape``. It is flattened to (N*H*W, C).
        label: label Var, reshaped to ``[-1, 1]`` and then squeezed.
            Presumably (N, H, W) integer class ids in {0, 1} -- TODO confirm,
            since it is one-hot encoded with ``num_classes=2``.
        gamma: currently unused -- it is not forwarded to
            ``sigmoid_focal_loss``. NOTE(review): pass it through once that
            helper's signature is confirmed.

    Returns:
        Whatever ``sigmoid_focal_loss`` (defined elsewhere) returns.
    """
    N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
    pred = pred.reshape([N, C, -1])      # N,C,H,W -> N,C,H*W
    pred = pred.transpose((0, 2, 1))     # N,C,H*W -> N,H*W,C
    pred = pred.reshape([-1, C])         # N,H*W,C -> N*H*W,C
    label = label.reshape([-1, 1])       # N,H,W -> N*H*W,1
    label = jt.squeeze(label, dim=1)     # drop the size-1 axis -> N*H*W
    label = jt.array(label, dtype="int64")  # one_hot expects integer indices
    print("label.dtype=************", label.dtype)
    label = jt.nn.one_hot(label, num_classes=2)  # -> N*H*W,2
    label = jt.array(label, dtype="float32")
    one = jt.array([1.], dtype="float32")
    fg_label = jt.greater_equal(label, one)
    # jt.cast is bound to jt.ops.unary(x, op): the target dtype must be passed
    # positionally. Passing it as the keyword ``dtype=`` raises
    # "RuntimeError: ... Not a valid keyword: dtype".
    # NOTE(review): ``label`` is already one-hot here, so (assuming class ids
    # in {0, 1}) every row holds exactly one 1.0 and fg_num equals the total
    # pixel count, not the foreground count. If a foreground count is
    # intended, compute it from the integer labels before one_hot -- confirm
    # against sigmoid_focal_loss's expectations.
    fg_num = jt.sum(jt.cast(fg_label, "float32"))
    loss_focal = sigmoid_focal_loss(pred, label, weight=fg_num)
    return loss_focal
- 错误原因
是不是因为 `jt.cast` 这个函数还没有实现好?文档里有这个函数,但按 `jt.cast(fg_label, dtype="float32")` 调用时会报上面的错误。