關於 Hire-MLP 的 input size

您好,我是在您們的 github 看到 hire-mlp 的 code。
我覺得寫得非常簡潔明瞭,實在是太厲害了!

但因為 Hire-MLP 論文中,圖形大小是設為 224x224x3。
而我的圖像大小是 torch.Size([1, 5000, 12]),想請問這樣的圖像大小該如何使用 Hire-MLP 模型 ?

我實際經過代碼復現:
在 train 階段沒有問題,並將每一個的 epoch 的 weight 存成 pkl 檔。
但跑至 predict 階段,就會出現 bug。

想詢問是因為一開始圖片大小的 size 設置不對,還是在 def train_all 時就發生問題了?

以上是我的問題,非常感謝您!

Hire-MLP 模型架構是完全使用您們 release 出來的 code,以下 train_all 與 pred 的 code,則是我自己撰寫的:

hire_mlp_pytorch.py

import torch
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from torch.utils.data import DataLoader
from torch.autograd import Variable
import os


def pair(x): return x if isinstance(x, tuple) else (x, x)


class PreNormResidual(nn.Module):
    def __init__(self, dim, fn, norm=nn.LayerNorm):
        super().__init__()
        self.fn = fn
        self.norm = norm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x


class PatchEmbedding(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, stride, padding, norm_layer=False):
        super().__init__()
        self.reduction = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size,
                      stride=stride, padding=padding),

            nn.Identity() if (not norm_layer) else nn.Sequential(
                Rearrange('b c h w -> b h w c'),
                nn.LayerNorm(dim_out),
                Rearrange('b h w c -> b c h w'),
            )
        )

    def forward(self, x):
        return self.reduction(x)


class FeedForward(nn.Module):
    def __init__(self, dim_in, hidden_dim, dim_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, hidden_dim, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim_out, kernel_size=1),
        )

    def forward(self, x):
        return self.net(x)


class CrossRegion(nn.Module):
    def __init__(self, step=1, dim=1):
        super().__init__()
        self.step = step
        self.dim = dim

    def forward(self, x):
        return torch.roll(x, self.step, self.dim)


class InnerRegionW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b c h (w group) -> b (c w) h group', w=self.w)
        )

    def forward(self, x):
        return self.region(x)


class InnerRegionH(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.region = nn.Sequential(
            Rearrange('b c (h group) w -> b (c h) group w', h=self.h)
        )

    def forward(self, x):
        return self.region(x)


class InnerRegionRestoreW(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w
        self.region = nn.Sequential(
            Rearrange('b (c w) h group -> b c h (w group)', w=self.w)
        )

    def forward(self, x):
        return self.region(x)


class InnerRegionRestoreH(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.h = h
        self.region = nn.Sequential(
            Rearrange('b (c h) group w -> b c (h group) w', h=self.h)
        )

    def forward(self, x):
        return self.region(x)


class HireMLPBlock(nn.Module):
    def __init__(self, h, w, d_model, cross_region_step=1, cross_region_id=0, cross_region_interval=2, padding_type='constant'):
        super().__init__()

        assert (padding_type in ['constant',
                'reflect', 'replicate', 'circular'])
        self.padding_type = padding_type
        self.w = w
        self.h = h

        # cross region every cross_region_interval HireMLPBlock
        self.cross_region = (cross_region_id % cross_region_interval == 0)

        if self.cross_region:
            self.cross_regionW = CrossRegion(step=cross_region_step, dim=3)
            self.cross_regionH = CrossRegion(step=cross_region_step, dim=2)
            self.cross_region_restoreW = CrossRegion(
                step=-cross_region_step, dim=3)
            self.cross_region_restoreH = CrossRegion(
                step=-cross_region_step, dim=2)
        else:
            self.cross_regionW = nn.Identity()
            self.cross_regionH = nn.Identity()
            self.cross_region_restoreW = nn.Identity()
            self.cross_region_restoreH = nn.Identity()

        self.inner_regionW = InnerRegionW(w)
        self.inner_regionH = InnerRegionH(h)
        self.inner_region_restoreW = InnerRegionRestoreW(w)
        self.inner_region_restoreH = InnerRegionRestoreH(h)

        self.proj_h = FeedForward(h * d_model, d_model // 2, h * d_model)
        self.proj_w = FeedForward(w * d_model, d_model // 2, w * d_model)
        self.proj_c = nn.Conv2d(d_model, d_model, kernel_size=1)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)

        B, C, H, W = x.shape
        padding_num_w = W % self.w
        padding_num_h = H % self.h
        x = nn.functional.pad(
            x, (0, self.w - padding_num_w, 0, self.h - padding_num_h), self.padding_type)

        x_h = self.inner_regionH(self.cross_regionH(x))
        x_w = self.inner_regionW(self.cross_regionW(x))

        x_h = self.proj_h(x_h)
        x_w = self.proj_w(x_w)
        x_c = self.proj_c(x)

        x_h = self.cross_region_restoreH(self.inner_region_restoreH(x_h))
        x_w = self.cross_region_restoreW(self.inner_region_restoreW(x_w))

        out = x_c + x_h + x_w

        out = out[:, :, 0:H, 0:W]
        out = out.permute(0, 2, 3, 1)
        return out


class HireMLPStage(nn.Module):
    def __init__(self, h, w, d_model_in, d_model_out, depth, cross_region_step, cross_region_interval, expansion_factor=2, dropout=0., pooling=False, padding_type='constant'):
        super().__init__()

        self.pooling = pooling
        self.patch_merge = nn.Sequential(
            Rearrange('b h w c -> b c h w'),
            PatchEmbedding(d_model_in, d_model_out, kernel_size=3,
                           stride=2, padding=1, norm_layer=False),
            Rearrange('b c h w -> b h w c'),
        )

        self.model = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(d_model_in, nn.Sequential(
                    HireMLPBlock(
                        h, w, d_model_in, cross_region_step=cross_region_step, cross_region_id=i_depth + 1, cross_region_interval=cross_region_interval, padding_type=padding_type
                    )
                ), norm=nn.LayerNorm),
                PreNormResidual(d_model_in, nn.Sequential(
                    nn.Linear(d_model_in, d_model_in * expansion_factor),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_model_in * expansion_factor, d_model_in),
                    nn.Dropout(dropout),
                ), norm=nn.LayerNorm),
            ) for i_depth in range(depth)]
        )

    def forward(self, x):
        x = self.model(x)
        if self.pooling:
            x = self.patch_merge(x)
        return x


class HireMLP(nn.Module):
    def __init__(
        self,
        patch_size=4,
        in_channels=1,
        num_classes=2,
        d_model=[64, 128, 320, 512],
        h=[4, 3, 3, 2],
        w=[4, 3, 3, 2],
        cross_region_step=[2, 2, 1, 1],
        cross_region_interval=2,
        depth=[4, 6, 24, 3],
        expansion_factor=2,
        patcher_norm=False,
        padding_type='constant',
    ):
        patch_size = pair(patch_size)
        super().__init__()
        self.patcher = PatchEmbedding(
            dim_in=in_channels, dim_out=d_model[0], kernel_size=7, stride=patch_size, padding=3, norm_layer=patcher_norm)

        self.layers = nn.ModuleList()
        for i_layer in range(len(depth)):
            i_depth = depth[i_layer]
            i_stage = HireMLPStage(h[i_layer], w[i_layer], d_model[i_layer], d_model_out=d_model[i_layer + 1] if (i_layer + 1 < len(depth)) else d_model[-1],
                                   depth=i_depth, cross_region_step=cross_region_step[
                                       i_layer], cross_region_interval=cross_region_interval,
                                   expansion_factor=expansion_factor, pooling=((i_layer + 1) < len(depth)), padding_type=padding_type)
            self.layers.append(i_stage)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_model[-1]),
            Reduce('b h w c -> b c', 'mean'),
            nn.Linear(d_model[-1], num_classes)
        )

    def forward(self, x):
        embedding = self.patcher(x)
        embedding = embedding.permute(0, 2, 3, 1)
        for layer in self.layers:
            embedding = layer(embedding)
        out = self.mlp_head(embedding)
        return out


def train_all(patch_size,
              in_channels,
              num_classes,
              d_model,
              h,
              w,
              cross_region_step,
              cross_region_interval,
              depth,
              expansion_factor,
              patcher_norm,
              padding_type,
              num_epochs,
              learning_rate,
              batch_size,
              model_path,
              label_leadImg,
              num_early_stop):

    ls_loss_history = []

    print('start training')
    train_dataloader = DataLoader(
        label_leadImg, shuffle=True, batch_size=batch_size)

    loss_history = []
    count = 0
    #early_stopping = earlystop.EarlyStopping(patience=num_early_stop)

    # Build models
    # model = MLPMixer(image_size=image_size, channels=channels, patch_size_h=patch_size_h,
    #                  patch_size_w=patch_size_w, dim=dim, depth=depth, num_classes=num_classes).cuda(index_gpu)

    net = HireMLP(patch_size=patch_size, in_channels=in_channels, num_classes=num_classes, d_model=d_model,
                  h=h, w=w, cross_region_step=cross_region_step, cross_region_interval=cross_region_interval,
                  depth=depth, expansion_factor=expansion_factor, patcher_norm=patcher_norm, padding_type=padding_type).cuda()

    model = torch.nn.DataParallel(net)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()

    params = model.parameters()
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    iHoldCount = 0
    iHoldCorrect = 0

    # train
    for epoch in range(num_epochs):
        for i, (images, label) in enumerate(train_dataloader):
            images = Variable(images.cuda())
            label = Variable(label.cuda())
            label = label.squeeze(-1)

            optimizer.zero_grad()

            output = model(images)
            # print(output.shape)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            prediction = output.data.max(1)[1]

            # print(prediction.cpu().data.numpy())
            # print(label.cpu().data.numpy())

            iHoldCount = iHoldCount + len(label.cpu().data.numpy())
            iHoldCorrect = iHoldCorrect + \
                (prediction.cpu().data.numpy() == label.cpu().data.numpy()).sum()

            # print(iHoldCount)
            # print(iHoldCorrect)
        print("Epoch number {}\n Current loss {}\n".format(epoch+1, loss.item()))

        # save by 10, 30, ...
        if (epoch > 5):
            # save current model parameter
            torch.save(model.state_dict(), model_path +
                       str(epoch+1) + 'hire_mlp.pkl')

        loss_history.append(loss.item())
        # early stopping
        # early_stopping(loss.item())
        # if early_stopping.early_stop:
        #    print("Early Stop!")
        #    break
        if (num_early_stop > 0):
            if (len(loss_history) > 1 and loss_history[-1] > loss_history[-2]):
                count += 1
                if count > num_early_stop:
                    break

    print('The train accuracy is: {:.6f}'.format(iHoldCorrect / iHoldCount))
    ls_loss_history.append(loss_history)


def pred(ds,
         patch_size,
         in_channels,
         num_classes,
         d_model,
         h,
         w,
         cross_region_step,
         cross_region_interval,
         depth,
         expansion_factor,
         patcher_norm,
         padding_type,
         model_path,
         model_index,
         b_label):

    # load well-trained model
    # model = MLPMixer(image_size=image_size, channels=channels, patch_size_h=patch_size_h, patch_size_w=patch_size_w, dim=dim, depth=depth, num_classes=num_classes).cuda(gpu_index)
    net = HireMLP(patch_size=patch_size, in_channels=in_channels, num_classes=num_classes, d_model=d_model,
                  h=h, w=w, cross_region_step=cross_region_step, cross_region_interval=cross_region_interval,
                  depth=depth, expansion_factor=expansion_factor, patcher_norm=patcher_norm, padding_type=padding_type).cuda()

    model = torch.nn.DataParallel(net)

    state_dict_path = torch.load(os.path.join(
        model_path, str(model_index)+'hire_mlp.pkl'))
    # state_dict_path = torch.load(os.path.join(path_resnet_model, 'densenet121.pkl'))
    model.load_state_dict(state_dict_path)

    # evaluation mode
    model.eval()

    # num = 0
    ls_y_pred = []
    ls_y_pred_prob = []

    # Loss
    criterion = nn.CrossEntropyLoss()

    # create dataloader by dataset
    # can I increase batch num? 1 may be the cause for inefficency
    dataloader = DataLoader(ds, shuffle=False, batch_size=128)

    with torch.no_grad():
        LOSS = 0
        ACC = 0
        for i, (images, label) in enumerate(dataloader):
            images = Variable(images.cuda())
            label = Variable(label.cuda())
            label = label.squeeze(-1)

            output = model(images)
            loss = criterion(output, label)
            LOSS = LOSS + loss.item()

            ls_y_pred = ls_y_pred + \
                output.cpu().detach().max(1)[1].numpy().squeeze().tolist()

            ACC = ACC + \
                output.cpu().detach().max(1)[1].eq(
                    label.cpu().detach()).sum().item()

            ls_y_pred_prob = ls_y_pred_prob + \
                output.cpu().detach().max(1)[0].numpy().squeeze().tolist()

    return ls_y_pred, ls_y_pred_prob, (LOSS/len(dataloader.dataset)), (ACC/len(dataloader.dataset))

實際上在執行時,train 階段沒問題,但 predict 階段會跑出以下的 error:

from hire_mlp_pytorch import hire_mlp_pytorch
import torch
import random
from random import shuffle
import sklearn
print(sklearn.__version__)
import ecg_functions_new
import pandas as pd
import numpy as np
import pickle
import sys

# load training data
train_list1 = pickle.load(open(r'/home/train_shuffle_0505_shuffle.pkl', 'rb'))
# load testing data
test_list = pickle.load(open(r'/home/test_shuffle_0505_shuffle.pkl', 'rb'))

# load training data
train_list1 = pickle.load(open(r'/home/u7080189/ECG_NSAID/train_shuffle_0505_shuffle.pkl', 'rb'))
# load testing data
test_list = pickle.load(open(r'/home/u7080189/ECG_NSAID/test_shuffle_0505_shuffle.pkl', 'rb'))

print(len(train_list1))  # 27288
print(len(test_list))  # 13441

# change size to [1,5000,12]


# x = train_list1[0][0]
# y = x[np.newaxis, :, :]
# print(x.shape,y.shape)

train_list2 = []
for i in range(len(train_list1)):
    z = train_list1[i][0][np.newaxis, :, :]
    a = torch.tensor(train_list1[i][1], dtype=torch.int64)
    train_list2.append((z,a))

test_list2 = []
ls_label = []
for i in range(len(test_list)):
    z = test_list[i][0][np.newaxis, :, :]
    a = torch.tensor(test_list[i][1], dtype=torch.int64)
    test_list2.append((z,a))
    ls_label.append(a)   

# training all 

num_epochs = 101
learning_rate = 0.0001
batch_size = 128

num_classes = 2

file_name = "hire_mlp_weight"
model_path = '/home/u7080189/ECG_NSAID/hire_mlp/'

iHold = 5

num_early_stop = 10

# train all
model_path = '/home/u7080189/ECG_NSAID/hire_mlp/'

ls_test_dataloader, ls_ls_y_true, ls_ls_y_pred, ls_loss_history = hire_mlp_pytorch.train_all(patch_size=4,
                                                                                             in_channels=1,
                                                                                             num_classes=num_classes,
                                                                                             d_model=[64, 128, 320, 512],
                                                                                             h=[4, 3, 3, 2],
                                                                                             w=[4, 3, 3, 2],
                                                                                             cross_region_step=[2, 2, 1, 1],
                                                                                             cross_region_interval=2,
                                                                                             depth=[4, 6, 24, 3],
                                                                                             expansion_factor=2,
                                                                                             patcher_norm=False,
                                                                                             padding_type='constant',
                                                                                             num_epochs=num_epochs,
                                                                                             learning_rate=learning_rate,
                                                                                             batch_size=batch_size,
                                                                                             model_path=model_path,
                                                                                             label_leadImg=train_list2,
                                                                                             num_early_stop=num_early_stop)


# testing all
# 為什麼取test_list2[0:13440],測試資料有13441筆,但在跑測試時最後一筆會報錯,可能是存檔的時候有損毀
model_index = 7
for i in range(27):
    model_path = r'/home/u7080189/ECG_NSAID/hire_mlp/'
    b_label = True
    ls_y_pred, ls_y_pred_prob, test_loss, test_acc = hire_mlp_pytorch.pred(ds=test_list2[0:13440],
                                                                           patch_size=4,
                                                                           in_channels=1,
                                                                           num_classes=2,
                                                                           d_model=[64, 128, 320, 512],
                                                                           h=[4, 3, 3, 2],
                                                                           w=[4, 3, 3, 2],
                                                                           cross_region_step=[2, 2, 1, 1],
                                                                           cross_region_interval=2,
                                                                           depth=[4, 6, 24, 3],
                                                                           expansion_factor=2,
                                                                           patcher_norm=True,
                                                                           padding_type='constant',
                                                                           model_path=model_path,
                                                                           model_index=model_index,
                                                                           b_label=b_label)
    print('index:', model_index, test_acc)
    model_index +=1

test 階段會出現以下 error

RuntimeError: Error(s) in loading state_dict for HireMLP:
	Missing key(s) in state_dict: "patcher.reduction.0.weight", "patcher.reduction.0.bias", "patcher.reduction.1.1.weight", "patcher.reduction.1.1.bias", "layers.0.patch_merge.1.reduction.0.weight", "layers.0.patch_merge.1.reduction.0.bias", "layers.0.model.0.0.fn.0.proj_h.net.0.weight", "layers.0.model.0.0.fn.0.proj_h.net.0.bias", "layers.0.model.0.0.fn.0.proj_h.net.2.weight", "layers.0.model.0.0.fn.0.proj_h.net.2.bias", "layers.0.model.0.0.fn.0.proj_w.net.0.weight", "layers.0.model.0.0.fn.0.proj_w.net.0.bias", "layers.0.model.0.0.fn.0.proj_w.net.2.weight", "layers.0.model.0.0.fn.0.proj_w.net.2.bias", "layers.0.model.0.0.fn.0.proj_c.weight", "layers.0.model.0.0.fn.0.proj_c.bias", "layers.0.model.0.0.norm.weight", "layers.0.model.0.0.norm.bias", "layers.0.model.0.1.fn.0.weight", "layers.0.model.0.1.fn.0.bias", "layers.0.model.0.1.fn.3.weight", "layers.0.model.0.1.fn.3.bias", "layers.0.model.0.1.norm.weight", "layers.0.model.0.1.norm.bias", "layers.0.model.1.0.fn.0.proj_h.net.0.weight", "layers.0.model.1.0.fn.0.proj_h.net.0.bias", "layers.0.model.1.0.fn.0.proj_h.net.2.weight", "layers.0.model.1.0.fn.0.proj_h.net.2.bias", "layers.0.model.1.0.fn.0.proj_w.net.0.weight", "layers.0.model.1.0.fn.0.proj_w.net.0.bias", "layers.0.model.1.0.fn.0.proj_w.net.2.weight", "layers.0.model.1.0.fn.0.proj_w.net.2.bias", "layers.0.model.1.0.fn.0.proj_c.weight", "layers.0.model.1.0.fn.0.proj_c.bias", "layers.0.model.1.0.norm.weight", "layers.0.model.1.0.norm.bias", "layers.0.model.1.1.fn.0.weight", "layers.0.model.1.1.fn.0.bias", "layers.0.model.1.1.fn.3.weight", "layers.0.model.1.1.fn.3.bias", "layers.0.model.1.1.norm.weight", "layers.0.model.1.1.norm.bias", "layers.0.model.2.0.fn.0.proj_h.net.0.weight", "layers.0.model.2.0.fn.0.proj_h.net.0.bias", "layers.0.model.2.0.fn.0.proj_h.net.2.weight", "layers.0.model.2.0.fn.0.proj_h.net.2.bias", "layers.0.model.2.0.fn.0.proj_w.net.0.weight", "layers.0.model.2.0.fn.0.proj_w.net.0.bias", "layers.0.model.2.0.fn.0.proj_w.net.2.weight", "layers.0.model.2.0.fn.0.proj_w.net.2.bias", "layers.0.model.2.0.fn.0.proj_c.weight", "layers.0.model.2.0.fn.0.proj_c.bias", "layers.0.model.2.0.norm.weight", "layers.0.model.2.0.norm.bias", "layers.0.model.2.1.fn.0.weight", "layers.0.model.2.1.fn.0.bias", "layers.0.model.2.1.fn.3.weight", "layers.0.model.2.1.fn.3.bias", "layers.0.model.2.1.norm.weight", "layers.0.model.2.1.norm.bias", "layers.0.model.3.0.fn.0.proj_h.net.0.weight", "layers.0.model.3.0.fn.0.proj_h.net.0.bias", "layers.0.model.3.0.fn.0.proj_h.net.2.weight", "layers.0.model.3.0.fn.0.proj_h.net.2.bias", "layers.0.model.3.0.fn.0.proj_w.net.0.weight", "layers.0.model.3.0.fn.0.proj_w.net.0.bias", "layers.0.model.3.0.fn.0.proj_w.net.2.weight", "layers.0.model.3.0.fn.0.proj_w.net.2.bias", "layers.0.model.3.0.fn.0.proj_c.weight", "layers.0.model.3.0.fn.0.proj_c.bias", "layers.0.model.3.0.norm.weight", "layers.0.model.3.0.norm.bias", "layers.0.model.3.1.fn.0.weight", "layers.0.model.3.1.fn.0.bias", "layers.0.model.3.1.fn.3.weight", "layers.0.model.3.1.fn.3.bias", "layers.0.model.3.1.norm.weight", "layers.0.model.3.1.norm.bias", "layers.1.patch_merge.1.reduction.0.weight", "layers.1.patch_merge.1.reduction.0.bias", "layers.1.model.0.0.fn.0.proj_h.net.0.weight", "layers.1.model.0.0.fn.0.proj_h.net.0.bias", "layers.1.model.0.0.fn.0.proj_h.net.2.weight", "layers.1.model.0.0.fn.0.proj_h.net.2.bias", "layers.1.model.0.0.fn.0.proj_w.net.0.weight", "layers.1.model.0.0.fn.0.proj_w.net.0.bias", "layers.1.model.0.0.fn.0.proj_w.net.2.weight", "layers.1.model.0.0.fn.0.proj_w.net.2.bias", "layers.1.model.0.0.fn.0.proj_c.weight", "layers.1.model.0.0.fn.0.proj_c.bias", "layers.1.model.0.0.norm.weight", "layers.1.model.0.0.norm.bias", "layers.1.model.0.1.fn.0.weight", "layers.1.model.0.1.fn.0.bias", "layers.1.model.0.1.fn.3.weight", "layers.1.model.0.1.fn.3.bias", "layers.1.model.0.1.norm.weight", "layers.1.model.0.1.norm.bias", "layers.1.model.1.0.fn.0.proj_h.net.0.weight", "layers.1.model.1.0.fn.0.proj_h.net.0.bias", "layers.1.model.1.0.fn.0.proj_h.net.2.weight", "layers.1.model.1.0.fn.0.proj_h.net.2.bias", "layers.1.model.1.0.fn.0.proj_w.net.0.weight", "layers.1.model.1.0.fn.0.proj_w.net.0.bias", "layers.1.model.1.0.fn.0.proj_w.net.2.weight", "layers.1.model.1.0.fn.0.proj_w.net.2.bias", "layers.1.model.1.0.fn.0.proj_c.weight", "layers.1.model.1.0.fn.0.proj_c.bias", "layers.1.model.1.0.norm.weight", "layers.1.model.1.0.norm.bias", "layers.1.model.1.1.fn.0.weight", "layers.1.model.1.1.fn.0.bias", "layers.1.model.1.1.fn.3.weight", "layers.1.model.1.1.fn.3.bias", "layers.1.model.1.1.norm.weight", "layers.1.model.1.1.norm.bias".
	Unexpected key(s) in state_dict: "module.patcher.reduction.0.weight", "module.patcher.reduction.0.bias", "module.layers.0.patch_merge.1.reduction.0.weight", "module.layers.0.patch_merge.1.reduction.0.bias", "module.layers.0.model.0.0.fn.0.proj_h.net.0.weight", "module.layers.0.model.0.0.fn.0.proj_h.net.0.bias", "module.layers.0.model.0.0.fn.0.proj_h.net.2.weight", "module.layers.0.model.0.0.fn.0.proj_h.net.2.bias", "module.layers.0.model.0.0.fn.0.proj_w.net.0.weight", "module.layers.0.model.0.0.fn.0.proj_w.net.0.bias", "module.layers.0.model.0.0.fn.0.proj_w.net.2.weight", "module.layers.0.model.0.0.fn.0.proj_w.net.2.bias", "module.layers.0.model.0.0.fn.0.proj_c.weight", "module.layers.0.model.0.0.fn.0.proj_c.bias", "module.layers.0.model.0.0.norm.weight", "module.layers.0.model.0.0.norm.bias", "module.layers.0.model.0.1.fn.0.weight", "module.layers.0.model.0.1.fn.0.bias", "module.layers.0.model.0.1.fn.3.weight", "module.layers.0.model.0.1.fn.3.bias", "module.layers.0.model.0.1.norm.weight", "module.layers.0.model.0.1.norm.bias", "module.layers.0.model.1.0.fn.0.proj_h.net.0.weight", "module.layers.0.model.1.0.fn.0.proj_h.net.0.bias", "module.layers.0.model.1.0.fn.0.proj_h.net.2.weight", "module.layers.0.model.1.0.fn.0.proj_h.net.2.bias", "module.layers.0.model.1.0.fn.0.proj_w.net.0.weight", "module.layers.0.model.1.0.fn.0.proj_w.net.0.bias", "module.layers.0.model.1.0.fn.0.proj_w.net.2.weight", "module.layers.0.model.1.0.fn.0.proj_w.net.2.bias", "module.layers.0.model.1.0.fn.0.proj_c.weight", "module.layers.0.model.1.0.fn.0.proj_c.bias", "module.layers.0.model.1.0.norm.weight", "module.layers.0.model.1.0.norm.bias", "module.layers.0.model.1.1.fn.0.weight", "module.layers.0.model.1.1.fn.0.bias", "module.layers.0.model.1.1.fn.3.weight", "module.layers.0.model.1.1.fn.3.bias", "module.layers.0.model.1.1.norm.weight", "module.layers.0.model.1.1.norm.bias", "module.layers.0.model.2.0.fn.0.proj_h.net.0.weight", "module.layers.0.model.2.0.fn.0.proj_h.net.0.bias", "module.layers.0.model.2.0.fn.0.proj_h.net.2.weight", "module.layers.0.model.2.0.fn.0.proj_h.net.2.bias", "module.layers.0.model.2.0.fn.0.proj_w.net.0.weight", "module.layers.0.model.2.0.fn.0.proj_w.net.0.bias", "module.layers.0.model.2.0.fn.0.proj_w.net.2.weight", "module.layers.0.model.2.0.fn.0.proj_w.net.2.bias", "module.layers.0.model.2.0.fn.0.proj_c.weight", "module.layers.0.model.2.0.fn.0.proj_c.bias", "module.layers.0.model.2.0.norm.weight", "module.layers.0.model.2.0.norm.bias", "module.layers.0.model.2.1.fn.0.weight", "module.layers.0.model.2.1.fn.0.bias", "module.layers.0.model.2.1.fn.3.weight", "module.layers.0.model.2.1.fn.3.bias", "module.layers.0.model.2.1.norm.weight", "module.layers.0.model.2.1.norm.bias", "module.layers.0.model.3.0.fn.0.proj_h.net.0.weight", "module.layers.0.model.3.0.fn.0.proj_h.net.0.bias", "module.layers.0.model.3.0.fn.0.proj_h.net.2.weight", "module.layers.0.model.3.0.fn.0.proj_h.net.2.bias", "module.layers.0.model.3.0.fn.0.proj_w.net.0.weight", "module.layers.0.model.3.0.fn.0.proj_w.net.0.bias", "module.layers.0.model.3.0.fn.0.proj_w.net.2.weight", "module.layers.0.model.3.0.fn.0.proj_w.net.2.bias", "module.layers.0.model.3.0.fn.0.proj_c.weight", "module.layers.0.model.3.0.fn.0.proj_c.bias", "module.layers.0.model.3.0.norm.weight", "module.layers.0.model.3.0.norm.bias", "module.layers.0.model.3.1.fn.0.weight", "module.layers.0.model.3.1.fn.0.bias", "module.layers.0.model.3.1.fn.3.weight", "module.layers.0.model.3.1.fn.3.bias", "module.layers.0.model.3.1.norm.weight", "module.layers.0.model.3.1.norm.bias", "module.layers.1.patch_merge.1.reduction.0.weight", "module.layers.1.patch_merge.1.reduction.0.bias", "module.layers.1.model.0.0.fn.0.proj_h.net.0.weight", "module.layers.1.model.0.0.fn.0.proj_h.net.0.bias", "module.layers.1.model.0.0.fn.0.proj_h.net.2.weight", "module.layers.1.model.0.0.fn.0.proj_h.net.2.bias", "module.layers.1.model.0.0.fn.0.proj_w.net.0.weight", "module.layers.1.model.0.0.fn.0.proj_w.net.0.bias", "module.layers.1.model.0.0.fn.0.proj_w.net.2.weight", "module.layers.1.model.0.0.fn.0.proj_w.net.2.bias", "module.layers.1.model.0.0.fn.0.proj_c.weight", "module.layers.1.model.0.0.fn.0.proj_c.bias", "module.layers.1.model.0.0.norm.weight", "module.layers.1.model.0.0.norm.bias", "module.layers.1.model.0.1.fn.0.weight", "module.layers.1.model.0.1.fn.0.bias", "module.layers.1.model.0.1.fn.3.weight", "module.layers.1.model.0.1.fn.3.bias", "module.layers.1.model.0.1.norm.weight", "module.layers.1.model.0.1.norm.bias", "module.layers.1.model.1.0.fn.0.proj_h.net.0.weight", "module.layers.1.model.1.0.fn.0.proj_h.net.0.bias", "module.layers.1.model.1.0.fn.0.proj_h.net.2.weight", "module.layers.1.model.1.0.fn.0.proj_h.net.2.bias", "module.layers.1.model.1.0.fn.0.proj_w.net.0.weight", "module.layers.1.model.1.0.fn.0.proj_w.net.0.bias", "module.layers.1.model.1.0.fn.0.proj_w.net.2.weight", "module.layers.1.model.1.0.fn.0.proj_w.net.2.bias", "module.layers.1.model.1.0.fn.0.proj_c.weight", "module.layers.1.model.1.0.fn.0.proj_c.bias", "module.layers.1.model.1.0.norm.weight", "module.layers.1.model.1.0.norm.bias", "module.layers.1.model.1.1.fn.0.weight", "module.layers.1.model.1.1.fn.0.bias", "module.layers.1.model.1.1.fn.3.weight", "module.layers.1.model.1.1.fn.3.bias", "module.layers.1.model.1.1.norm.weight", "module.layers.1.model.1.1.norm.bias".

您好,

您的代码是用 Jittor 实现的吗?看上去这个问题似乎是 PyTorch 模型导入出现了问题,可能需要咨询开源项目作者或者 PyTorch 相关论坛。