Experiment: UNet Road Segmentation
2024-10-22 14:52:45

Introduction to the UNet model:

(Reference: "UNet — Line by Line Explanation. Example UNet Implementation" by Jeremy Zhang, Towards Data Science.)

U-Net is a deep learning architecture for image segmentation that is widely used in areas such as medical imaging. It is called "U-Net" because the network layout resembles the letter "U".

U-Net was originally proposed for biomedical image segmentation, in particular for precisely segmenting cell images. Its distinguishing feature is that it combines a convolutional encoder (which captures image features) with a decoder (which produces the segmentation result) in a symmetric structure. This lets the network preserve image context while accurately capturing features at different scales.

The main components of U-Net are:

  1. Encoder: a stack of convolution and pooling layers that progressively extracts features from the image. Different levels represent different degrees of abstraction, from low-level features (such as edges) to high-level features (such as texture and shape).
  2. Skip connections: a key feature of U-Net. After each encoder stage, the feature map at that resolution is connected to the corresponding decoder stage, passing fine-grained detail to the decoder and helping it recover details more accurately (a minimal sketch follows this list).
  3. Decoder: a stack of convolution and upsampling (un-pooling) layers that maps the encoded features back to the original image size and produces the segmentation result. The skip connections let the decoder fuse information from different levels while generating the segmentation.
  4. Final convolution layer: the last decoder stage uses a convolution, typically followed by a suitable activation (sigmoid or softmax), to produce pixel-wise predictions.
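
The RoadNet built later in this post stacks Encoder and Decoder blocks with a residual shortcut inside each block, but it does not concatenate encoder feature maps into the decoder. For reference, here is a minimal sketch of the classic U-Net skip connection; the class name, channel sizes, and shapes are illustrative and not taken from the experiment code.

import paddle

# Minimal sketch of a classic U-Net skip connection (illustrative only):
# the decoder upsamples its feature map, concatenates the encoder feature map
# of the same resolution, then convolves the merged tensor.
class SkipBlock(paddle.nn.Layer):
    def __init__(self, dec_channels, enc_channels, out_channels):
        super(SkipBlock, self).__init__()
        self.up = paddle.nn.Upsample(scale_factor=2.0)
        self.conv = paddle.nn.Conv2D(dec_channels + enc_channels,
                                     out_channels,
                                     kernel_size=3,
                                     padding='same')
        self.relu = paddle.nn.ReLU()

    def forward(self, decoder_feat, encoder_feat):
        y = self.up(decoder_feat)                     # e.g. 28x28 -> 56x56
        y = paddle.concat([y, encoder_feat], axis=1)  # the skip connection
        return self.relu(self.conv(y))

# Shape check: decoder 128ch @ 28x28 + encoder 64ch @ 56x56 -> 64ch @ 56x56
block = SkipBlock(128, 64, 64)
out = block(paddle.randn([1, 128, 28, 28]), paddle.randn([1, 64, 56, 56]))
print(out.shape)  # [1, 64, 56, 56]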

Dataset splitting

Splitting a dataset into training, validation, and test sets is standard practice in machine learning and deep learning. Its purpose is to evaluate model performance and estimate how well the model generalizes. The split supports model development and tuning and helps avoid overfitting (performing well on the training data but poorly on new data).

The role of each set:

  1. Training set: the data the model uses to learn and adjust its parameters. The model iterates over the training set many times, gradually updating its weights and biases to minimize the loss function. Performance on the training set improves steadily, but that does not guarantee good performance on unseen data.
  2. Validation set: used to tune the model's hyperparameters (learning rate, regularization strength, etc.) and optimize performance. Evaluating on the validation set during training shows how the model behaves on data it has not seen. If the model does well on the training set but poorly on the validation set, it is likely overfitting; the validation results then guide hyperparameter adjustments toward better generalization.
  3. Test set: used to estimate performance on real-world data. The test set contains data the model has never seen and provides the final assessment of generalization. Its results indicate whether the model is good enough to deploy in practice.

Splitting the data this way makes it easier to monitor the model, avoid overfitting, and obtain a reliable estimate of generalization performance. It also supports iterating on and improving the model during development, leading to more accurate and robust models.

Put another way:

When we train a model, we usually split the data into three parts, training, validation, and test, to make sure it performs well in different situations. It is like splitting study into practice, pre-exam review, and the final exam.

  1. Training set: like practice problems. The model learns from the training set and gradually improves, trying to find patterns so that it does better and better on the exercises.
  2. Validation set: like reviewing before the exam. We use the validation set to adjust the model's "strategy", for example which techniques to use or how fast to learn, so that it is better prepared for the real exam, the test set.
  3. Test set: the final exam. It contains problems the model has never seen, so we can see how it performs in a realistic setting and whether it has genuinely learned rather than memorized.

Splitting the data into these three parts lets us monitor the learning process and prevents the model from simply memorizing the training examples. It also lets us tune the model so it performs well across situations, and the test set tells us whether the model is ready for real-world use.
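
As a concrete illustration of such a split (my own sketch, not how this experiment actually splits its data; the experiment uses the filename-based modulo rule in the data-processing script below), here is a simple 80/10/10 random partition of a list of sample paths:

import random

# Hypothetical 80/10/10 train/validation/test split of a list of sample paths
samples = ['sample_{:03d}.jpg'.format(i) for i in range(100)]
random.seed(0)  # fix the seed so the split is reproducible
random.shuffle(samples)

n_train = int(0.8 * len(samples))
n_val = int(0.1 * len(samples))

train_set = samples[:n_train]
val_set = samples[n_train:n_train + n_val]
test_set = samples[n_train + n_val:]

print(len(train_set), len(val_set), len(test_set))  # 80 10 10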

Road segmentation experiment:

1. Data processing

import os
import cv2

# Create the folder structure for the new dataset
os.makedirs('./UAS_NEW/train/LABEL', exist_ok=True)
os.makedirs('./UAS_NEW/train/SIGHT', exist_ok=True)
os.makedirs('./UAS_NEW/train/LABEL_IMAGE', exist_ok=True)
os.makedirs('./UAS_NEW/test/LABEL', exist_ok=True)
os.makedirs('./UAS_NEW/test/SIGHT', exist_ok=True)
os.makedirs('./UAS_NEW/test/LABEL_IMAGE', exist_ok=True)
os.makedirs('./UAS_NEW/predict/LABEL', exist_ok=True)
os.makedirs('./UAS_NEW/predict/SIGHT', exist_ok=True)
os.makedirs('./UAS_NEW/predict/LABEL_IMAGE', exist_ok=True)

IMAGE_SIZE = (224, 224)

# Walk the original data, resize every image, and sort it into the new splits
for dirname1 in os.listdir('./UAS'):
    image_path = './UAS/' + dirname1
    for dirname2 in os.listdir(image_path):
        if dirname2 == 'train':
            image_path2 = image_path + '/train'
            for filename in os.listdir(image_path2):
                if 'jpg' in filename:
                    # roughly one of every nine training images goes to the predict split
                    if int(filename.split('.')[0].split('t')[-1]) % 9 != 1:
                        sight_path = image_path2 + '/' + filename
                        pic = cv2.imread(sight_path)
                        pic = cv2.resize(pic, IMAGE_SIZE)
                        cv2.imwrite('./UAS_NEW/train/SIGHT/' + filename, pic)
                    else:
                        sight_path = image_path2 + '/' + filename
                        pic = cv2.imread(sight_path)
                        pic = cv2.resize(pic, IMAGE_SIZE)
                        cv2.imwrite('./UAS_NEW/predict/SIGHT/' + filename, pic)
                elif 'Graph' in filename:
                    if int(filename.split('.')[0].split('h')[-1]) % 9 != 1:
                        sight_path = image_path2 + '/' + filename
                        pic = cv2.imread(sight_path)
                        pic = cv2.resize(pic, IMAGE_SIZE)
                        cv2.imwrite('./UAS_NEW/train/LABEL_IMAGE/' + filename, pic)
                    else:
                        sight_path = image_path2 + '/' + filename
                        pic = cv2.imread(sight_path)
                        pic = cv2.resize(pic, IMAGE_SIZE)
                        cv2.imwrite('./UAS_NEW/predict/LABEL_IMAGE/' + filename, pic)
                else:
                    if int(filename.split('.')[0].split('l')[-1]) % 9 != 1:
                        sight_path = image_path2 + '/' + filename
                        pic = cv2.imread(sight_path)
                        pic = cv2.resize(pic, IMAGE_SIZE)
                        cv2.imwrite('./UAS_NEW/train/LABEL/' + filename, pic)
                    else:
                        sight_path = image_path2 + '/' + filename
                        pic = cv2.imread(sight_path)
                        pic = cv2.resize(pic, IMAGE_SIZE)
                        cv2.imwrite('./UAS_NEW/predict/LABEL/' + filename, pic)
        else:
            image_path2 = image_path + '/test'
            for filename in os.listdir(image_path2):
                if 'jpg' in filename:
                    sight_path = image_path2 + '/' + filename
                    pic = cv2.imread(sight_path)
                    pic = cv2.resize(pic, IMAGE_SIZE)
                    cv2.imwrite('./UAS_NEW/test/SIGHT/' + filename, pic)
                elif 'Graph' in filename:
                    sight_path = image_path2 + '/' + filename
                    pic = cv2.imread(sight_path)
                    pic = cv2.resize(pic, IMAGE_SIZE)
                    cv2.imwrite('./UAS_NEW/test/LABEL_IMAGE/' + filename, pic)
                else:
                    sight_path = image_path2 + '/' + filename
                    pic = cv2.imread(sight_path)
                    pic = cv2.resize(pic, IMAGE_SIZE)
                    cv2.imwrite('./UAS_NEW/test/LABEL/' + filename, pic)

def CollectImgName(filepath, output):
    # Pair each image with its mask label and colour label image, one sample per line
    images = sorted(os.listdir(filepath + '/SIGHT'))
    labels = sorted(os.listdir(filepath + '/LABEL'))
    label_images = sorted(os.listdir(filepath + '/LABEL_IMAGE'))

    with open(output, 'w') as f:
        for i in range(len(images)):
            f.write(filepath + '/SIGHT/' + images[i] + '\t'
                    + filepath + '/LABEL/' + labels[i] + '\t'
                    + filepath + '/LABEL_IMAGE/' + label_images[i] + '\n')

CollectImgName('UAS_NEW/train', './train.txt')
CollectImgName('UAS_NEW/test', './test.txt')
CollectImgName('UAS_NEW/predict', './predict.txt')

2. Spot-checking samples

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image as PilImage

with open('./train.txt', 'r') as f:
    i = 0
    for line in f.readlines():
        image_path, label_path, label_image_path = line.strip().split('\t')
        image = np.array(PilImage.open(image_path))
        label = np.array(PilImage.open(label_path))
        label_image = np.array(PilImage.open(label_image_path))

        if i > 2:
            break

        # Display the image, its mask label, and its colour label side by side
        plt.figure()

        plt.subplot(1, 3, 1)
        plt.title('Train Image')
        plt.imshow(image.astype('uint8'))
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.title('Label')
        plt.imshow(label.astype('uint8'), cmap='gray')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.title('Label Image')
        plt.imshow(label_image.astype('uint8'))
        plt.axis('off')

        plt.show()
        i = i + 1

3. Building the UNet model: the encoder

import paddle

# Encoder block: two conv + BN stages, a max-pooling downsample, and a residual branch
class Encoder(paddle.nn.Layer):
    def __init__(self, in_channels, out_channels):
        super(Encoder, self).__init__()

        self.relus = paddle.nn.LayerList(
            [paddle.nn.ReLU() for i in range(2)])
        self.separable_conv_01 = paddle.nn.Conv2D(in_channels,
                                                  out_channels,
                                                  kernel_size=3,
                                                  padding='same')
        self.bns = paddle.nn.LayerList(
            [paddle.nn.BatchNorm2D(out_channels) for i in range(2)])

        self.separable_conv_02 = paddle.nn.Conv2D(out_channels,
                                                  out_channels,
                                                  kernel_size=3,
                                                  padding='same')
        self.pool = paddle.nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
        self.residual_conv = paddle.nn.Conv2D(in_channels,
                                              out_channels,
                                              kernel_size=1,
                                              stride=2,
                                              padding='same')

    def forward(self, inputs):
        previous_block_activation = inputs

        y = self.relus[0](inputs)
        y = self.separable_conv_01(y)
        y = self.bns[0](y)
        y = self.relus[1](y)
        y = self.separable_conv_02(y)
        y = self.bns[1](y)
        y = self.pool(y)

        # 1x1 strided convolution on the block input forms the residual connection
        residual = self.residual_conv(previous_block_activation)
        y = paddle.add(y, residual)

        return y
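
A quick forward-pass check (my addition, run after defining the Encoder class above) confirms that each Encoder block halves the spatial resolution while switching to the requested channel count:

# Shape check: Encoder(32, 64) on a 32x224x224 feature map halves H and W
enc = Encoder(32, 64)
x = paddle.randn([1, 32, 224, 224])
print(enc(x).shape)  # [1, 64, 112, 112]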

4. Building the UNet model: the decoder

import paddle

# Decoder block: two transposed conv + BN stages, a 2x upsample, and a residual branch
class Decoder(paddle.nn.Layer):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()

        self.relus = paddle.nn.LayerList(
            [paddle.nn.ReLU() for i in range(2)])
        self.conv_transpose_01 = paddle.nn.Conv2DTranspose(in_channels,
                                                           out_channels,
                                                           kernel_size=3,
                                                           padding=1)
        self.conv_transpose_02 = paddle.nn.Conv2DTranspose(out_channels,
                                                           out_channels,
                                                           kernel_size=3,
                                                           padding=1)
        self.bns = paddle.nn.LayerList(
            [paddle.nn.BatchNorm2D(out_channels) for i in range(2)]
        )
        self.upsamples = paddle.nn.LayerList(
            [paddle.nn.Upsample(scale_factor=2.0) for i in range(2)]
        )
        self.residual_conv = paddle.nn.Conv2D(in_channels,
                                              out_channels,
                                              kernel_size=1,
                                              padding='same')

    def forward(self, inputs):
        previous_block_activation = inputs

        y = self.relus[0](inputs)
        y = self.conv_transpose_01(y)
        y = self.bns[0](y)
        y = self.relus[1](y)
        y = self.conv_transpose_02(y)
        y = self.bns[1](y)
        y = self.upsamples[0](y)

        # Upsample the block input and project it with a 1x1 conv for the residual connection
        residual = self.upsamples[1](previous_block_activation)
        residual = self.residual_conv(residual)

        y = paddle.add(y, residual)

        return y
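
Symmetrically, each Decoder block doubles the spatial resolution (again my own check, run after defining the Decoder class above):

# Shape check: Decoder(256, 128) doubles H and W via Upsample(scale_factor=2.0)
dec = Decoder(256, 128)
x = paddle.randn([1, 256, 14, 14])
print(dec(x).shape)  # [1, 128, 28, 28]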

6. Building the full UNet model

import paddle
from Library.Encoder import Encoder
from Library.Decoder import Decoder

# Full network: a stem convolution, a chain of Encoders, a chain of Decoders, and an output head
class RoadNet(paddle.nn.Layer):
    def __init__(self, num_classes):
        super(RoadNet, self).__init__()

        self.conv_1 = paddle.nn.Conv2D(3, 32,
                                       kernel_size=3,
                                       stride=2,
                                       padding='same')
        self.bn = paddle.nn.BatchNorm2D(32)
        self.relu = paddle.nn.ReLU()

        in_channels = 32
        self.encoders = []
        self.encoder_list = [64, 128, 256]
        self.decoder_list = [256, 128, 64, 32]

        # Define the encoder sub-layers in a loop from the channel configuration
        # instead of writing each one out by hand
        for out_channels in self.encoder_list:
            block = self.add_sublayer('encoder_{}'.format(out_channels),
                                      Encoder(in_channels, out_channels))
            self.encoders.append(block)
            in_channels = out_channels

        self.decoders = []

        # Define the decoder sub-layers in a loop from the channel configuration
        for out_channels in self.decoder_list:
            block = self.add_sublayer('decoder_{}'.format(out_channels),
                                      Decoder(in_channels, out_channels))
            self.decoders.append(block)
            in_channels = out_channels

        self.output_conv = paddle.nn.Conv2D(in_channels,
                                            num_classes,
                                            kernel_size=3,
                                            padding='same')

    def forward(self, inputs):
        y = self.conv_1(inputs)
        y = self.bn(y)
        y = self.relu(y)
        for encoder in self.encoders:
            y = encoder(y)
        for decoder in self.decoders:
            y = decoder(y)
        y = self.output_conv(y)
        return y
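
Putting the pieces together, a 224x224 RGB input should come back as a per-pixel map with num_classes channels at the original resolution. A small verification sketch (my addition, assuming the Encoder, Decoder, and RoadNet classes above are importable):

# End-to-end shape check: 3x224x224 in -> num_classes x 224x224 out
net = RoadNet(num_classes=2)
x = paddle.randn([1, 3, 224, 224])
print(net(x).shape)  # [1, 2, 224, 224]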

7. Dataset reader (images only)

from paddle.io import Dataset
from paddle.vision.transforms import transforms as T
from PIL import Image as PilImage
import numpy as np
import io

IMAGE_SIZE = (224, 224)

class PredictDataset(Dataset):
    # Dataset definition: returns preprocessed images only (no labels)
    def __init__(self, mode='train'):
        # Collect the image paths listed in ./{mode}.txt
        self.image_size = IMAGE_SIZE
        self.mode = mode.lower()
        self.train_images = []

        with open('./{}.txt'.format(self.mode), 'r') as f:
            for line in f.readlines():
                image = line.strip().split('\t')[0]
                self.train_images.append(image)

    def _load_img(self, path, color_mode='rgb', transforms=[]):
        # Unified image-loading helper that normalizes size and channels
        with open(path, 'rb') as f:
            img = PilImage.open(io.BytesIO(f.read()))
            if color_mode == 'grayscale':
                # if the image is not already an 8-bit, 16-bit or 32-bit grayscale image,
                # convert it to an 8-bit grayscale image
                if img.mode not in ('L', 'I;16', 'I'):
                    img = img.convert('L')
            elif color_mode == 'rgba':
                if img.mode != 'RGBA':
                    img = img.convert('RGBA')
            elif color_mode == 'rgb':
                if img.mode != 'RGB':
                    img = img.convert('RGB')
            else:
                raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')

            return T.Compose([
                T.Resize(self.image_size)
            ] + transforms)(img)

    def __getitem__(self, idx):
        # Return the preprocessed image
        train_image = self._load_img(self.train_images[idx],
                                     transforms=[
                                         T.Transpose(),
                                         T.Normalize(mean=127.5, std=127.5)
                                     ])

        train_image = np.array(train_image, dtype='float32')
        return train_image

    def __len__(self):
        return len(self.train_images)
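
The training and inference scripts below import a RoadDataset class from Library.RoadDataset that this post does not show. The following is only a minimal sketch of what it presumably looks like, assuming it mirrors PredictDataset but additionally returns the grayscale mask from the LABEL column as an integer class map (road pixels assumed to be stored as 255), and returns images only for the predict split:

from paddle.io import Dataset
from paddle.vision.transforms import transforms as T
from PIL import Image as PilImage
import numpy as np

IMAGE_SIZE = (224, 224)

class RoadDataset(Dataset):
    # Hypothetical reader: yields (image, label) pairs listed in ./{mode}.txt
    def __init__(self, mode='train'):
        self.image_size = IMAGE_SIZE
        self.mode = mode.lower()
        self.train_images = []
        self.label_images = []
        with open('./{}.txt'.format(self.mode), 'r') as f:
            for line in f.readlines():
                image, label, _ = line.strip().split('\t')
                self.train_images.append(image)
                self.label_images.append(label)

    def _load_img(self, path, color_mode='rgb', transforms=[]):
        img = PilImage.open(path)
        img = img.convert('L' if color_mode == 'grayscale' else 'RGB')
        return T.Compose([T.Resize(self.image_size)] + transforms)(img)

    def __getitem__(self, idx):
        image = self._load_img(self.train_images[idx],
                               transforms=[T.Transpose(),
                                           T.Normalize(mean=127.5, std=127.5)])
        image = np.array(image, dtype='float32')
        if self.mode == 'predict':
            # prediction split: no label needed (mirrors PredictDataset above)
            return image
        label = self._load_img(self.label_images[idx], color_mode='grayscale')
        # assume road pixels are stored as 255: map {0, 255} -> {0, 1};
        # int64 labels of shape (H, W) match CrossEntropyLoss(axis=1)
        label = (np.array(label) > 127).astype('int64')
        return image, label

    def __len__(self):
        return len(self.train_images)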

9. Training

import paddle
from Library.RoadDataset import RoadDataset
from Library.RoadNet import RoadNet

num_classes = 2
IMAGE_SIZE = (224, 224)
network = RoadNet(num_classes)
model = paddle.Model(network)

train_dataset = RoadDataset(mode='train')  # training dataset
val_dataset = RoadDataset(mode='test')     # validation dataset

optim = paddle.optimizer.RMSProp(learning_rate=0.001,
                                 rho=0.9,
                                 momentum=0.0,
                                 epsilon=1e-07,
                                 centered=False,
                                 parameters=model.parameters())

model.prepare(optim, paddle.nn.CrossEntropyLoss(axis=1))

paddle.set_device('gpu')
model.fit(train_dataset, val_dataset, epochs=15, batch_size=32, verbose=1, save_dir='./work')

10. Inference

from Library.RoadDataset import RoadDataset
from Library.RoadNet import RoadNet
import paddle
import pickle

num_classes = 2
IMAGE_SIZE = (224, 224)
network = RoadNet(num_classes)
model = paddle.Model(network)
model.load("./work/final")  # load the checkpoint saved by model.fit
optim = paddle.optimizer.RMSProp(learning_rate=0.001,
                                 rho=0.9,
                                 momentum=0.0,
                                 epsilon=1e-07,
                                 centered=False,
                                 parameters=model.parameters())

model.prepare(optim, paddle.nn.CrossEntropyLoss(axis=1))
paddle.set_device('gpu')

predict_dataset = RoadDataset(mode='predict')
predict_results = model.predict(predict_dataset)

# Serialize the raw predictions so the visualization and evaluation steps can reuse them
with open("./predict_results.txt", 'wb') as f:
    result = pickle.dumps(predict_results)
    f.write(result)

11. Visualizing the predictions

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image as PilImage
from paddle.vision.transforms import transforms as T
import pickle

with open('./predict_results.txt', 'rb') as f:
    content = f.read()
    predict_results = pickle.loads(content)

plt.figure(figsize=(10, 10))

i = 0
mask_idx = 0
IMAGE_SIZE = (224, 224)
with open('./predict.txt', 'r') as f:
    for line in f.readlines():
        image_path, label_path, label_image_path = line.strip().split('\t')
        resize_t = T.Compose([
            T.Resize(IMAGE_SIZE)
        ])
        image = resize_t(PilImage.open(image_path))
        label = resize_t(PilImage.open(label_image_path))

        image = np.array(image).astype('uint8')
        label = np.array(label).astype('uint8')

        if i > 8:
            break
        plt.subplot(3, 3, i + 1)
        plt.imshow(image)
        plt.title('Input Image')
        plt.axis("off")

        plt.subplot(3, 3, i + 2)
        plt.imshow(label, cmap='gray')
        plt.title('Label')
        plt.axis("off")

        # The model has a single output, so predict_results[0] holds all predictions.
        # Index it with the current sample's position and extract the mask for display.
        data = predict_results[0][mask_idx][0].transpose((1, 2, 0))
        mask = np.argmax(data, axis=-1)

        plt.subplot(3, 3, i + 3)
        plt.imshow(mask.astype('uint8'), cmap='gray')
        plt.title('Predict')
        plt.axis("off")
        i += 3
        mask_idx += 1
plt.show()

12. Model evaluation (for reference)

__all__ = ['SegmentationMetric']

import numpy as np
import cv2
from paddle.vision.transforms import transforms as T
from PIL import Image as PilImage
import pickle

class SegmentationMetric(object):
    def __init__(self, numClass):
        self.numClass = numClass
        self.confusionMatrix = np.zeros((self.numClass,) * 2)  # empty confusion matrix

    def pixelAccuracy(self):
        # overall pixel accuracy: fraction of correctly classified pixels
        # PA = acc = (TP + TN) / (TP + TN + FP + FN)
        acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
        return acc

    def classPixelAccuracy(self):
        # per-class pixel accuracy (effectively the per-class precision)
        # acc = TP / (TP + FP)
        classAcc = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)
        return classAcc  # an array such as [0.90, 0.80, 0.96]: the accuracy of classes 1, 2, 3

    def meanPixelAccuracy(self):
        """
        Mean Pixel Accuracy (MPA): a simple refinement of PA that computes the
        per-class accuracy and then averages it over all classes.
        """
        classAcc = self.classPixelAccuracy()
        meanAcc = np.nanmean(classAcc)  # np.nanmean averages while ignoring NaN entries
        return meanAcc  # a scalar, e.g. np.nanmean([0.90, 0.80, 0.96, nan, nan]) = (0.90 + 0.80 + 0.96) / 3 ≈ 0.89

    def IntersectionOverUnion(self):
        # Intersection = TP, Union = TP + FP + FN
        # IoU = TP / (TP + FP + FN)
        intersection = np.diag(self.confusionMatrix)  # diagonal: per-class true positives
        union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
            self.confusionMatrix)  # row sums + column sums - diagonal
        IoU = intersection / union  # array of per-class IoU values
        return IoU

    def meanIntersectionOverUnion(self):
        mIoU = np.nanmean(self.IntersectionOverUnion())  # average of the per-class IoU
        return mIoU

    def genConfusionMatrix(self, imgPredict, imgLabel):
        """
        Same as the fast_hist() function in FCN's score.py: build the confusion matrix.
        :param imgPredict: predicted class map
        :param imgLabel: ground-truth class map
        :return: confusion matrix
        """
        # remove classes from unlabeled pixels in gt image and predict
        mask = (imgLabel >= 0) & (imgLabel < self.numClass)
        label = self.numClass * imgLabel[mask] + imgPredict[mask]
        count = np.bincount(label, minlength=self.numClass ** 2)
        confusionMatrix = count.reshape(self.numClass, self.numClass)
        return confusionMatrix

    def Frequency_Weighted_Intersection_over_Union(self):
        """
        FWIoU (frequency weighted IoU): a refinement of mIoU that weights each
        class's IoU by how often that class occurs.
        FWIoU = sum over classes of [(TP+FN)/(TP+FP+TN+FN)] * [TP / (TP + FP + FN)]
        """
        freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix)
        iu = np.diag(self.confusionMatrix) / (
            np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) -
            np.diag(self.confusionMatrix))
        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    def addBatch(self, imgPredict, imgLabel):
        assert imgPredict.shape == imgLabel.shape
        self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)  # accumulate the confusion matrix
        return self.confusionMatrix

    def reset(self):
        self.confusionMatrix = np.zeros((self.numClass, self.numClass))

# Label colour conversion: white background (255,255,255) -> black, road pixels (0,255,255) -> white
def green2black(img):
    for i in range(len(img)):
        for j in range(len(img[i])):
            if (img[i][j][0] == 255) and (img[i][j][1] == 255) and (img[i][j][2] == 255):
                img[i][j][0] = 0
                img[i][j][1] = 0
                img[i][j][2] = 0
            elif (img[i][j][0] == 0) and (img[i][j][1] == 255) and (img[i][j][2] == 255):
                img[i][j][0] = 255
    return img

# Turn a 0/1 class map into a black/white RGB image
def create_img(arr):
    img = np.ones((224, 224, 3), dtype=np.float32)
    for i in range(len(arr)):
        for j in range(len(arr[i])):
            if arr[i][j] == 0:
                img[i][j][0] = 0
                img[i][j][1] = 0
                img[i][j][2] = 0
            else:
                img[i][j][0] = 255
                img[i][j][1] = 255
                img[i][j][2] = 255
    return img


with open('./predict_results.txt', 'rb') as f:
    content = f.read()
    predict_results = pickle.loads(content)

# Evaluation loop
mask_idx = 0
metric, hist, pa, cpa, mpa, IoU, mIoU = [], [], [], [], [], [], []
IMAGE_SIZE = (224, 224)
with open('./predict.txt', 'r') as f:
    for line in f.readlines():
        image_path, label_path, label_image_path = line.strip().split('\t')
        resize_t = T.Compose([
            T.Resize(IMAGE_SIZE)
        ])
        label = resize_t(PilImage.open(label_image_path))
        label = np.array(label).astype('uint8')
        # The model has a single output, so predict_results[0] holds all predictions.
        # Index it with the current sample's position and extract the mask.
        data = predict_results[0][mask_idx][0].transpose((1, 2, 0))
        mask = np.argmax(data, axis=-1)
        mask_idx += 1

        imgPredict = create_img(mask.astype('uint8'))
        imgLabel = green2black(label)
        imgPredict = np.array(cv2.cvtColor(imgPredict, cv2.COLOR_BGR2GRAY) / 255., dtype=np.uint8)
        imgLabel = np.array(cv2.cvtColor(imgLabel, cv2.COLOR_BGR2GRAY) / 255., dtype=np.uint8)
        # imgPredict = np.array([0, 0, 1, 1, 2, 2])  # can be replaced directly with a predicted image
        # imgLabel = np.array([0, 0, 1, 1, 2, 2])    # can be replaced directly with a label image

        metric = SegmentationMetric(2)  # 2 = number of classes; set to however many classes you have
        hist.append(metric.addBatch(imgPredict, imgLabel))
        pa.append(metric.pixelAccuracy())
        cpa.append(metric.classPixelAccuracy())
        mpa.append(metric.meanPixelAccuracy())
        IoU.append(metric.IntersectionOverUnion())
        mIoU.append(metric.meanIntersectionOverUnion())

print('hist is :\n', np.mean(hist, axis=0))  # average confusion matrix over all samples
print('PA is : %f' % np.mean(pa))
print('IoU is : ', np.mean(IoU))
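
As a quick usage check (my addition, mirroring the commented-out synthetic arrays above), a perfect prediction should score 1.0 on every metric:

# Sanity check of SegmentationMetric on synthetic 1-D "images"
imgPredict = np.array([0, 0, 1, 1, 2, 2])
imgLabel = np.array([0, 0, 1, 1, 2, 2])
m = SegmentationMetric(3)
m.addBatch(imgPredict, imgLabel)
print(m.pixelAccuracy())              # 1.0
print(m.meanIntersectionOverUnion())  # 1.0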