Failed reason:[f 0908 20:15:00.939677 12 pyjt_jit_op_maker.cc:22828] Not a valid keyword: dtype

  1. Error message
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_4502/1038824653.py in <module>
     80     print("output_edge.dtype=========",output_edge.dtype)
     81     print("edge.dtype============",edge.dtype)
---> 82     loss_edge = loss_Focalloss(output_edge, edge) * exp_args.edgeRatio
     83 
     84 

/tmp/ipykernel_4502/1038824653.py in loss_Focalloss(pred, label, gamma)
     65     one = jt.array([1.], dtype="float32")
     66     fg_label = jt.greater_equal(label, one)
---> 67     fg_num = jt.sum(jt.cast(fg_label, dtype="float32"))
     68     loss_focal = sigmoid_focal_loss(pred, label, weight=fg_num)
     69     return loss_focal

RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.ops.unary)).

Types of your inputs are:
 self	= module,
 args	= (Var, ),
 kwargs	= {dtype=str, },

The function declarations are:
 VarHolder* unary(VarHolder* x,  NanoString op)

Failed reason:[f 0908 20:15:00.939677 12 pyjt_jit_op_maker.cc:22828] Not a valid keyword: dtype
  2. Code that raises the error
def loss_Focalloss(pred, label, gamma=2.0):
    N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
    pred = pred.reshape([N, C, -1]) # N,C,H,W -> N,C, H*W
    pred = pred.transpose((0, 2, 1)) # N,C,H*W -> N, H*W, C
    pred = pred.reshape([-1, C]) # N,H*W,C -> N*H*W, C
    label = label.reshape([-1, 1]) # N,H,W -> N*H*W, 1
    
    label = jt.squeeze(label, dim=1) # remove the extra dimension
    label = jt.array(label, dtype="int64")
    print("label.dtype=************",label.dtype)
    label = jt.nn.one_hot(label, num_classes=2)
    label = jt.array(label, dtype="float32")
    
    one = jt.array([1.], dtype="float32")
    fg_label = jt.greater_equal(label, one)
    fg_num = jt.sum(jt.cast(fg_label, dtype="float32"))
    loss_focal = sigmoid_focal_loss(pred, label, weight=fg_num)
    return loss_focal
  3. Suspected cause
    Is it because the cast function hasn't been finished yet? It appears in the documentation, but calling it fails.

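The declaration of jittor.cast has no keyword parameters. As the error message shows, jt.cast dispatches to unary(VarHolder* x, NanoString op), so the target type has to be passed positionally, not as dtype=. Just call jt.cast(var, jt.float32) directly.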
