class Generator(nn.Module):
    """Fully connected generator network.

    Builds a label-embedding table of shape (n_classes, n_classes) and an
    MLP that maps a vector of size ``latent_dim + n_classes`` to
    ``prod(img_shape)`` values squashed into [-1, 1] by a final Tanh.
    ``n_classes``, ``latent_dim`` and ``img_shape`` are read from the
    enclosing module scope.
    """

    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)

        def block(in_feat, out_feat, normalize=True):
            """Return [Linear, (BatchNorm1d), LeakyReLU] as a plain list."""
            linear = nn.Linear(in_feat, out_feat)
            # NOTE(review): 0.8 is passed positionally as the second
            # argument; in torch.nn.BatchNorm1d that slot is `eps`, not
            # `momentum` -- confirm this is intended.
            norm = [nn.BatchNorm1d(out_feat, 0.8)] if normalize else []
            return [linear, *norm, nn.LeakyReLU(0.2)]

        hidden_sizes = (128, 256, 512, 1024)
        stack = []
        in_size = latent_dim + n_classes
        for idx, out_size in enumerate(hidden_sizes):
            # Only the first hidden block skips batch normalisation.
            stack.extend(block(in_size, out_size, normalize=idx > 0))
            in_size = out_size

        # Output layer: one value per image element, bounded by Tanh.
        stack.append(nn.Linear(in_size, int(np.prod(img_shape))))
        stack.append(nn.Tanh())

        # `*stack` unpacks the list so each layer becomes a separate
        # positional argument of nn.Sequential.
        self.model = nn.Sequential(*stack)
`*` 是 Python 中对可迭代对象(iterable)的解包操作符:在函数调用时,它把可迭代对象中的元素逐个展开,作为位置参数传给函数。
举个例子:
def f(a, b, c):
    """Return a + b + c, added left to right."""
    partial = a + b
    return partial + c


# Each call below unpacks an iterable into the three positional arguments.
f(*range(3))
# equivalent to f(0, 1, 2), which evaluates to 3

vec = [3, 4, 5]
f(*vec)
# equivalent to f(3, 4, 5), which evaluates to 12

tup = (6, 7, 8)
f(*tup)
# equivalent to f(6, 7, 8), which evaluates to 21
十分感谢,下午的时候已经弄懂了。我原先还以为这是 Jittor 特有的写法,谢谢!