总结
本笔记改编自《深度学习经典案例解析:基于MATLAB》一书中的内容,使用ResNet网络来对X光胸片进行检测,判断是否有肺炎。
不过老实说,这个数据集数量太少。并且我自己也无法通过肉眼分辨X光胸片到底有没有肺炎。
所以本笔记更多的是介绍Matlab 如何调用ResNet、如何做数据增强等知识点。
背景
数据集
在新冠肺炎诊断过程中,除了病毒核酸检测外,X射线的胸正位图像也是其中一个重要环节。
在本节中,我们将使用开源胸片数据集,训练一个ResNet网络,用以分类一张胸片是属于健康人,还是属于新冠肺炎患者,以期帮助医生更准确地做出判断。所使用的图像数据集由Adrian Rosebrock博士发布,可以直接从网址(https://www.pyimagesearch.com/2020/03/16/detecting-covid-19-in-x-ray-images-with-keras-tensorflow-and-deep-learning/)下载。其中包含25张健康人胸片和25张新冠肺炎患者胸片,原作者还提供了Python的Keras库和VGG16网络的代码,本文只用了其中的数据集,ResNet网络程序由MATLAB实现。
ResNet网络简介
ResNet是深度残差网络(Residual Network)的简称,由微软的何凯明等人于2015年提出,一举夺得当年ImageNet大规模视觉识别竞赛图像分类和物体识别的冠军。其核心思想是在每2个网络层之间加入一个残差连接,缓解深层网络中的梯度消失问题,使得训练数百甚至数千层成为可能。
在图像处理领域,一般认为神经网络更深的网络层能够提取出更加“高级”的图像特征。为了实现先进的性能,网络往往都需要数十甚至数百层,此时,梯度消失和梯度爆炸问题成为训练深层网络的一大障碍,可能导致网络无法收敛。通过在常规网络中引入残差连接,极大地缓解了网络梯度消失的问题。除此之外,残差连接还具有实现简单、计算开销小、对原有网络结构影响小等优点。因此,无论在图像处理领域,还是在自然语言处理领域,现有的很多最先进的深度神经网络都将残差连接作为一种常用技巧。
实战笔记
加载数据集
通过资源管理器查看数据集
加载数据集
1 | % 加载图像数据 |
Label | Count | |
---|---|---|
1 | covid | 25 |
2 | normal | 25 |
查看数据集
1 | % 随机显示数据集中的部分图像 |
分割测试集和训练集
1 | % 分割测试集和训练集 |
加载预训练好的网络
1 | % 确定训练数据中新冠图片标签类别数量:2类 |
根据新数据设定新的全连接层的参数,
1 | net = setLearnRateFactor(net,"fc1000/Weights",10); |
数据增强
数据增强,目的是增加数据,增加模型鲁棒性
1 | % 数据增强的参数 |
对网络进行训练
1 | % 对训练参数进行设置 |
进行测试并查看结果
1 | classNames = categories(imds.Labels) |
计算分类准确率
1 | YTest = img_ds_Test.Labels; |
accuracy = 0.9286
创建并显示混淆矩阵
1 | figure |
知识点
-
matlab如何调用预训练的ResNet网络
-
官方提供了三个型号的resnet:
"resnet18"
、"resnet50"
、"resnet101"
-
Matlab 2024a起,推荐使用imagePretrainedNetwork来加载预训练模型,可以直接指定输出分类数目,更方便迁移学习
1
net = imagePretrainedNetwork("resnet50",NumClasses=numClasses);
-
-
matlab如何做数据增强
-
augmentedImageDatastore
1
2
3
4
5
6
7
8
9
10
11
12
13% 数据增强的参数
augmenter = imageDataAugmenter(...
'RandRotation',[-5 5],...
'RandXReflection',1,...
'RandYReflection',1,...
'RandXShear',[-0.05 0.05],...
'RandYShear',[-0.05 0.05]);
% 将批量训练图像的大小调整为与输入层的大小相同
aug_img_ds_train = augmentedImageDatastore([224 224],img_ds_Train,'DataAugmentation',augmenter,'ColorPreprocessing','gray2rgb');
% 将批量测试图像的大小调整为与输入层的大小相同
aug_img_ds_test = augmentedImageDatastore([224 224], img_ds_Test,'ColorPreprocessing','gray2rgb');
-