'LinearClsHead' object has no attribute 'simple_test'

    科技2022-07-12  128

    https://github.com/open-mmlab/mmclassification

    目前官方还未提供推理(inference)代码。我自己写了一个推理脚本,但运行时报错 `'LinearClsHead' object has no attribute 'simple_test'`,需要修改源码才能解决。

    在 linear_head.py 中为 `LinearClsHead` 类增加一个 `simple_test` 方法即可:

    def simple_test(self, x):
        """Return the raw classification scores for test-time features *x*.

        Added to ``LinearClsHead`` so that ``model(img, return_loss=False)``
        can reach the head at test time. The result is returned as-is (no
        softmax or other post-processing), so callers receive whatever
        ``self.fc`` produces and can convert it themselves
        (e.g. ``.detach().numpy()``).

        Args:
            x: Features from the backbone/neck. If a tuple is passed
                (multi-stage features), only its last element is used.

        Returns:
            The output of ``self.fc`` applied to the (last) feature.
        """
        # NOTE(review): depending on the mmcls version, the neck may hand the
        # head a tuple of per-stage features; the linear classifier only
        # consumes the final one. A plain tensor passes through unchanged.
        if isinstance(x, tuple):
            x = x[-1]
        cls_score = self.fc(x)
        return cls_score

    然后编写前向推理脚本 infer.py:

    # -*- coding: utf-8 -*-
    # @Time : 2020/10/3 8:35 PM
    # @Author : zxq
    # @File : model_infer.py
    # @Software: PyCharm
    """Classify every image in a folder with a trained mmcls model and print the results.

    Written before mmclassification shipped an inference API. It calls
    ``model(batch, return_loss=False)``, which reaches the head's
    ``simple_test``, so the classifier head (``LinearClsHead``) must provide
    that method.
    """
    import argparse
    import os

    import cv2
    import mmcv
    import numpy as np
    import torch
    from mmcv import Config
    from mmcls.models import build_classifier

    # Resize target passed to mmcv.imresize.
    # NOTE(review): hard-coded; confirm it matches the Resize/crop steps in
    # cfg.test_pipeline, otherwise the model sees differently-sized inputs
    # than it was evaluated with.
    INPUT_SIZE = (256, 256)


    def _parse_args():
        """Parse command-line options; the defaults are the script's original paths."""
        parser = argparse.ArgumentParser(
            description='Run an mmcls classifier on every image in a folder.')
        parser.add_argument('--config', default='../../configs/imagenet/ciga_call_cfg.py',
                            help='model config file')
        parser.add_argument('--data-path', default='/home/zxq/PycharmProjects/data/ciga_call/test2',
                            help='folder containing the images to classify')
        parser.add_argument('--weights', default='../../work_dir/epoch_100.pth',
                            help='checkpoint file (a "state_dict" entry is used if present)')
        return parser.parse_args()


    def _get_normalize_params(pipeline):
        """Return ``(mean, std, to_rgb)`` taken from the ``Normalize`` step of *pipeline*.

        If several ``Normalize`` steps exist, the last one wins. ``to_rgb``
        falls back to True when the step omits it (the default of
        ``mmcv.imnormalize``, which the script previously relied on).

        Raises:
            ValueError: if *pipeline* contains no ``Normalize`` step.
        """
        # `==`, not `is`: identity comparison of strings only works when both
        # happen to be the same interned object.
        normalize_steps = [step for step in pipeline if step['type'] == 'Normalize']
        if not normalize_steps:
            raise ValueError("cfg.test_pipeline has no 'Normalize' step")
        step = normalize_steps[-1]
        return np.array(step['mean']), np.array(step['std']), step.get('to_rgb', True)


    def _preprocess(img, mean_value, std_value, to_rgb):
        """Resize and normalize an HxWxC image array into a 1xCxHxW float tensor."""
        img_resized = mmcv.imresize(img, INPUT_SIZE)
        img_normalized = mmcv.imnormalize(img_resized, mean_value, std_value, to_rgb=to_rgb)
        # HWC -> CHW, then add the batch dimension.
        chw = np.ascontiguousarray(np.transpose(img_normalized, [2, 0, 1]))
        return torch.from_numpy(chw).float().unsqueeze(0)


    def main():
        """Load the model once, then print the predicted class index for each image."""
        args = _parse_args()
        cfg = Config.fromfile(args.config)

        model = build_classifier(cfg.model)
        # Weights do not depend on the image, so load them once before the loop.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.weights, map_location='cpu')
        model.load_state_dict(checkpoint.get('state_dict', checkpoint))
        model.eval()

        # NOTE(review): this directory is created but nothing is written to it yet.
        save_path = os.path.join(os.path.dirname(cfg.data.test.data_prefix), 'test_result')
        mmcv.mkdir_or_exist(save_path)

        mean_value, std_value, to_rgb = _get_normalize_params(cfg.test_pipeline)

        # sorted(): os.listdir order is arbitrary; keep the output reproducible.
        for img_name in sorted(os.listdir(args.data_path)):
            img_path = os.path.join(args.data_path, img_name)
            print(img_path)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None (rather than raising) for unreadable files.
                print('  skipped: not a readable image')
                continue
            batch_data = _preprocess(img, mean_value, std_value, to_rgb)
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                model_output = model(batch_data, return_loss=False).cpu().numpy()
            cls_output = np.argmax(model_output, axis=1)
            print(cls_output)


    if __name__ == '__main__':
        main()

     

    Processed: 0.010, SQL: 8