Yumi's Blog

Learn about Fully Convolutional Networks for semantic segmentation

In this blog post, I will walk through a semantic segmentation problem and review fully convolutional networks. In an image for semantic segmentation, each pixel is labeled with the class of its enclosing object or region. For example, a pixel might belong to a road, car, building or person. The semantic segmentation problem requires making a classification at every pixel.

I will use a Fully Convolutional Network (FCN) to classify every pixel.

To understand the semantic segmentation problem, let's look at the example data prepared by divamgupta. Note: I will use this example data rather than a famous segmentation dataset, e.g., Pascal VOC2012, because the latter requires pre-processing.

Reference

First, I download data from:

https://drive.google.com/file/d/0B0d9ZiqAgFkiOHR1NTJhWVJMNEU/view

and save the downloaded dataset1 folder in the current directory.

In [1]:
dir_data = "dataset1/"
dir_seg = dir_data + "/annotations_prepped_train/"
dir_img = dir_data + "/images_prepped_train/"

Visualize a single segmentation image

In this data, there are 12 segmentation classes, and the images are taken from a camera on a driving car.

In [2]:
import cv2, os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
## seaborn uses a white grid by default, so I will turn it off.
sns.set_style("whitegrid", {'axes.grid' : False})


ldseg = np.array(os.listdir(dir_seg))
## pick the first image file
fnm = ldseg[0]
print(fnm)

## read in the original image and segmentation labels
seg = cv2.imread(dir_seg + fnm ) # (360, 480, 3)
img_is = cv2.imread(dir_img + fnm )
print("seg.shape={}, img_is.shape={}".format(seg.shape,img_is.shape))

## Check the number of labels
mi, ma = np.min(seg), np.max(seg)
n_classes = ma - mi + 1
print("minimum seg = {}, maximum seg = {}, Total number of segmentation classes = {}".format(mi,ma, n_classes))

fig = plt.figure(figsize=(5,5))
ax = fig.add_subplot(1,1,1)
ax.imshow(img_is)
ax.set_title("original image")
plt.show()

fig = plt.figure(figsize=(15,10))
for k in range(mi,ma+1):
    ax = fig.add_subplot(3,n_classes/3,k+1)
    ax.imshow((seg == k)*1.0)
    ax.set_title("label = {}".format(k))


plt.show()
0016E5_01620.png
seg.shape=(360, 480, 3), img_is.shape=(360, 480, 3)
minimum seg = 0, maximum seg = 11, Total number of segmentation classes = 12

Data preprocessing: Resize image

To simplify the problem, I will reshape all the images to the same size: (224,224). Why (224,224)? This is the image shape used in VGG, and the FCN model in this blog uses a network that takes advantage of the VGG structure. The FCN model becomes easier to explain when the image shape is (224,224): after the five max-pooling layers, the feature map is exactly 224/32 = 7 pixels on each side. However, FCN does not require the image shape to be (224,224).

Let's visualize what the images look like after resizing. The resized images look fine.

In [3]:
import random
def give_color_to_seg_img(seg,n_classes):
    '''
    seg : (input_width,input_height,3)
    '''
    
    if len(seg.shape)==3:
        seg = seg[:,:,0]
    seg_img = np.zeros( (seg.shape[0],seg.shape[1],3) ).astype('float')
    colors = sns.color_palette("hls", n_classes)
    
    for c in range(n_classes):
        segc = (seg == c)
        seg_img[:,:,0] += (segc*( colors[c][0] ))
        seg_img[:,:,1] += (segc*( colors[c][1] ))
        seg_img[:,:,2] += (segc*( colors[c][2] ))

    return(seg_img)

input_height , input_width = 224 , 224
output_height , output_width = 224 , 224


ldseg = np.array(os.listdir(dir_seg))
for fnm in ldseg[np.random.choice(len(ldseg),3,replace=False)]:
    fnm = fnm.split(".")[0]
    seg = cv2.imread(dir_seg + fnm + ".png") # (360, 480, 3)
    img_is = cv2.imread(dir_img + fnm + ".png")
    seg_img = give_color_to_seg_img(seg,n_classes)

    fig = plt.figure(figsize=(20,40))
    ax = fig.add_subplot(1,4,1)
    ax.imshow(seg_img)
    
    ax = fig.add_subplot(1,4,2)
    ax.imshow(img_is/255.0)
    ax.set_title("original image {}".format(img_is.shape[:2]))
    
    ax = fig.add_subplot(1,4,3)
    ax.imshow(cv2.resize(seg_img,(input_height , input_width)))
    
    ax = fig.add_subplot(1,4,4)
    ax.imshow(cv2.resize(img_is,(output_height , output_width))/255.0)
    ax.set_title("resized to {}".format((output_height , output_width)))
    plt.show()

Resize all the images. We have 367 images in this dataset.

In [4]:
def getImageArr( path , width , height ):
    img = cv2.imread(path, 1)
    img = np.float32(cv2.resize(img, ( width , height ))) / 127.5 - 1
    return img

def getSegmentationArr( path , nClasses ,  width , height  ):

    seg_labels = np.zeros((  height , width  , nClasses ))
    img = cv2.imread(path, 1)
    img = cv2.resize(img, ( width , height ))
    img = img[:, : , 0]

    for c in range(nClasses):
        seg_labels[: , : , c ] = (img == c ).astype(int)
    ##seg_labels = np.reshape(seg_labels, ( width*height,nClasses  ))
    return seg_labels




images = os.listdir(dir_img)
images.sort()
segmentations  = os.listdir(dir_seg)
segmentations.sort()
    
X = []
Y = []
for im , seg in zip(images,segmentations) :
    X.append( getImageArr(dir_img + im , input_width , input_height )  )
    Y.append( getSegmentationArr( dir_seg + seg , n_classes , output_width , output_height )  )

X, Y = np.array(X) , np.array(Y)
print(X.shape,Y.shape)
((367, 224, 224, 3), (367, 224, 224, 12))

Import Keras and TensorFlow to develop deep learning FCN models

In [5]:
## Import usual libraries
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
import keras, sys, time, warnings
from keras.models import *
from keras.layers import *
import pandas as pd 
warnings.filterwarnings("ignore")

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.95
config.gpu_options.visible_device_list = "2" 
set_session(tf.Session(config=config))   

print("python {}".format(sys.version))
print("keras version {}".format(keras.__version__)); del keras
print("tensorflow version {}".format(tf.__version__))
Using TensorFlow backend.
python 2.7.13 |Anaconda 4.3.1 (64-bit)| (default, Dec 20 2016, 23:09:15) 
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
keras version 2.0.6
tensorflow version 1.2.1

From classifier to dense FCN

Recent successful deep learning models such as VGG were originally designed for the classification task. The network stacks convolution layers together with down-sampling layers, such as max-pooling, and finally stacks fully connected layers. Appending fully connected layers enables the network to learn using global information, but the spatial arrangement of the input falls away.

Fully convolutional network

For the segmentation task, however, spatial information has to be preserved to make a pixel-wise classification. FCN allows this by turning all the fully connected layers of VGG into convolutional layers.

Fully convolutional indicates that the neural network is composed of convolutional layers without any of the fully connected layers usually found at the end of the network. Fully Convolutional Networks for Semantic Segmentation motivates the use of fully convolutional networks by "convolutionalizing" popular CNN architectures; e.g., VGG can also be viewed as an FCN.

... fully connected layers can also be viewed as convolutions with kernels that cover their entire input regions. Doing so casts them into fully convolutional networks that take input of any size and output classification maps. (Section 3.1)
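
To see this equivalence concretely, here is a minimal sketch (not part of the original notebook; it assumes the 7 x 7 x 512 feature map that VGG16 produces for a 224 x 224 input) comparing a fully connected layer on the flattened feature map with a 7 x 7 convolution that covers the entire input region. Both have exactly the same number of parameters.

from keras.models import Model
from keras.layers import Input, Flatten, Dense, Conv2D

feat = Input(shape=(7, 7, 512))  ## same shape as VGG16's block5_pool output for a 224x224 input

## classifier view: flatten, then a fully connected layer with 4096 units
fc = Dense(4096, activation='relu')(Flatten()(feat))
print(Model(feat, fc).count_params())    ## 7*7*512*4096 + 4096 = 102,764,544

## convolutional view: one 7x7 convolution with 4096 filters
conv = Conv2D(4096, (7, 7), activation='relu', padding='valid')(feat)
print(Model(feat, conv).count_params())  ## same parameter count; output shape is (1, 1, 4096)

This is exactly the parameter count of the conv6 layer in the FCN8 model below, which uses padding='same' so that the 7 x 7 spatial grid is kept.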

The model I use in this blog post is FCN8 from Fully Convolutional Networks for Semantic Segmentation. It decapitates the VGG16 net by discarding the final classifier layer and converts all fully connected layers to convolutions. Fully Convolutional Networks for Semantic Segmentation then appends a 1 x 1 convolution with channel dimension equal to the number of segmentation classes (in our case, 12) to predict scores at each of the coarse output locations, followed by upsampling (deconvolution) layers which bring the low-resolution feature maps back to the output image size. In our example, the output image size is (output_height, output_width) = (224,224).

Upsampling

The upsampling layer brings a low-resolution image back to a higher resolution. There are various upsampling methods. This presentation gives a good overview. For example, one may double the image resolution by repeating each pixel in both directions. This is the so-called nearest neighbor approach, implemented in Keras's UpSampling2D. Another method is bilinear upsampling, which linearly interpolates from the nearest four inputs.
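
As a small illustration (not from the original post), Keras's UpSampling2D with its default nearest neighbor behaviour simply repeats every pixel in both directions:

import numpy as np
from keras.models import Sequential
from keras.layers import UpSampling2D

m = Sequential([UpSampling2D(size=(2, 2), input_shape=(2, 2, 1))])
x = np.arange(4, dtype=np.float32).reshape(1, 2, 2, 1)  ## a tiny 2x2 "image"
print(m.predict(x)[0, :, :, 0])
## [[ 0.  0.  1.  1.]
##  [ 0.  0.  1.  1.]
##  [ 2.  2.  3.  3.]
##  [ 2.  2.  3.  3.]]

Note that this layer has no learnable parameters; it only copies values.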

These upsampling layers have no weights/parameters, so the model is not very flexible. Instead, FCN8 uses an upsampling procedure called backwards convolution (sometimes called deconvolution) with some output stride. This method simply reverses the forward and backward passes of convolution and is implemented in Keras's Conv2DTranspose. This deconvolution upsampling layer is well explained in the blog post Up-sampling with Transposed Convolution.
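
Here is a hedged sketch of what Conv2DTranspose does in this setting: when kernel_size equals strides, every input pixel is expanded into a learned patch, so for example a 14 x 14 score map is upsampled to 28 x 28 while the kernel weights are trained together with the rest of the network.

from keras.models import Sequential
from keras.layers import Conv2DTranspose

up = Sequential([Conv2DTranspose(12, kernel_size=(2, 2), strides=(2, 2),
                                 use_bias=False, input_shape=(14, 14, 12))])
print(up.output_shape)    ## (None, 28, 28, 12)
print(up.count_params())  ## 2*2*12*12 = 576 learnable weights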

In FCN8, the upsampled conv7 output is combined with skip connections from the pool3 and pool4 layers before the final 8x upsampling. See the details in Fully Convolutional Networks for Semantic Segmentation.

I downloaded the VGG16 weights from fchollet's GitHub. This is a massive .h5 file (57MB).

In [6]:
## location of VGG weights
VGG_Weights_path = "../FacialKeypoint/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5"
In [7]:
def FCN8( nClasses ,  input_height=224, input_width=224):
    ## input_height and input_width must be divisible by 32 because max-pooling with filter size (2,2) is applied 5 times,
    ## which makes the input_height and input_width 2^5 = 32 times smaller
    assert input_height%32 == 0
    assert input_width%32 == 0
    IMAGE_ORDERING =  "channels_last" 

    img_input = Input(shape=(input_height,input_width, 3)) ## Assume 224,224,3
    
    ## Block 1
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', data_format=IMAGE_ORDERING )(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2', data_format=IMAGE_ORDERING )(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool', data_format=IMAGE_ORDERING )(x)
    f1 = x
    
    # Block 2
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1', data_format=IMAGE_ORDERING )(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2', data_format=IMAGE_ORDERING )(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool', data_format=IMAGE_ORDERING )(x)
    f2 = x

    # Block 3
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1', data_format=IMAGE_ORDERING )(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2', data_format=IMAGE_ORDERING )(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3', data_format=IMAGE_ORDERING )(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool', data_format=IMAGE_ORDERING )(x)
    pool3 = x

    # Block 4
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1', data_format=IMAGE_ORDERING )(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2', data_format=IMAGE_ORDERING )(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3', data_format=IMAGE_ORDERING )(x)
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool', data_format=IMAGE_ORDERING )(x)## (None, 14, 14, 512) 

    # Block 5
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1', data_format=IMAGE_ORDERING )(pool4)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2', data_format=IMAGE_ORDERING )(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3', data_format=IMAGE_ORDERING )(x)
    pool5 = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool', data_format=IMAGE_ORDERING )(x)## (None, 7, 7, 512)

    #x = Flatten(name='flatten')(x)
    #x = Dense(4096, activation='relu', name='fc1')(x)
    # <--> o = ( Conv2D( 4096 , ( 7 , 7 ) , activation='relu' , padding='same', data_format=IMAGE_ORDERING))(o)
    # assuming that the input_height = input_width = 224 as in VGG data
    
    #x = Dense(4096, activation='relu', name='fc2')(x)
    # <--> o = ( Conv2D( 4096 , ( 1 , 1 ) , activation='relu' , padding='same', data_format=IMAGE_ORDERING))(o)   
    # assuming that the input_height = input_width = 224 as in VGG data
    
    #x = Dense(1000 , activation='softmax', name='predictions')(x)
    # <--> o = ( Conv2D( nClasses ,  ( 1 , 1 ) ,kernel_initializer='he_normal' , data_format=IMAGE_ORDERING))(o)
    # assuming that the input_height = input_width = 224 as in VGG data
    
    
    vgg  = Model(  img_input , pool5  )
    vgg.load_weights(VGG_Weights_path) ## loading VGG weights for the encoder parts of FCN8
    
    n = 4096
    o = ( Conv2D( n , ( 7 , 7 ) , activation='relu' , padding='same', name="conv6", data_format=IMAGE_ORDERING))(pool5)
    conv7 = ( Conv2D( n , ( 1 , 1 ) , activation='relu' , padding='same', name="conv7", data_format=IMAGE_ORDERING))(o)
    
    
    ## 4 times upsampling of the conv7 scores: (None, 7, 7, 4096) -> (None, 28, 28, nClasses)
    conv7_4 = Conv2DTranspose( nClasses , kernel_size=(4,4) ,  strides=(4,4) , use_bias=False, data_format=IMAGE_ORDERING )(conv7)

    ## 1x1 convolution on pool4, then 2 times upsampling: (None, 14, 14, 512) -> (None, 28, 28, nClasses)
    pool411 = ( Conv2D( nClasses , ( 1 , 1 ) , activation='relu' , padding='same', name="pool4_11", data_format=IMAGE_ORDERING))(pool4)
    pool411_2 = (Conv2DTranspose( nClasses , kernel_size=(2,2) ,  strides=(2,2) , use_bias=False, data_format=IMAGE_ORDERING ))(pool411)
    
    pool311 = ( Conv2D( nClasses , ( 1 , 1 ) , activation='relu' , padding='same', name="pool3_11", data_format=IMAGE_ORDERING))(pool3)
        
    o = Add(name="add")([pool411_2, pool311, conv7_4 ])
    o = Conv2DTranspose( nClasses , kernel_size=(8,8) ,  strides=(8,8) , use_bias=False, data_format=IMAGE_ORDERING )(o)
    o = (Activation('softmax'))(o)
    
    model = Model(img_input, o)

    return model

model = FCN8(nClasses     = n_classes,  
             input_height = 224, 
             input_width  = 224)
model.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 224, 224, 3)   0                                            
____________________________________________________________________________________________________
block1_conv1 (Conv2D)            (None, 224, 224, 64)  1792        input_1[0][0]                    
____________________________________________________________________________________________________
block1_conv2 (Conv2D)            (None, 224, 224, 64)  36928       block1_conv1[0][0]               
____________________________________________________________________________________________________
block1_pool (MaxPooling2D)       (None, 112, 112, 64)  0           block1_conv2[0][0]               
____________________________________________________________________________________________________
block2_conv1 (Conv2D)            (None, 112, 112, 128) 73856       block1_pool[0][0]                
____________________________________________________________________________________________________
block2_conv2 (Conv2D)            (None, 112, 112, 128) 147584      block2_conv1[0][0]               
____________________________________________________________________________________________________
block2_pool (MaxPooling2D)       (None, 56, 56, 128)   0           block2_conv2[0][0]               
____________________________________________________________________________________________________
block3_conv1 (Conv2D)            (None, 56, 56, 256)   295168      block2_pool[0][0]                
____________________________________________________________________________________________________
block3_conv2 (Conv2D)            (None, 56, 56, 256)   590080      block3_conv1[0][0]               
____________________________________________________________________________________________________
block3_conv3 (Conv2D)            (None, 56, 56, 256)   590080      block3_conv2[0][0]               
____________________________________________________________________________________________________
block3_pool (MaxPooling2D)       (None, 28, 28, 256)   0           block3_conv3[0][0]               
____________________________________________________________________________________________________
block4_conv1 (Conv2D)            (None, 28, 28, 512)   1180160     block3_pool[0][0]                
____________________________________________________________________________________________________
block4_conv2 (Conv2D)            (None, 28, 28, 512)   2359808     block4_conv1[0][0]               
____________________________________________________________________________________________________
block4_conv3 (Conv2D)            (None, 28, 28, 512)   2359808     block4_conv2[0][0]               
____________________________________________________________________________________________________
block4_pool (MaxPooling2D)       (None, 14, 14, 512)   0           block4_conv3[0][0]               
____________________________________________________________________________________________________
block5_conv1 (Conv2D)            (None, 14, 14, 512)   2359808     block4_pool[0][0]                
____________________________________________________________________________________________________
block5_conv2 (Conv2D)            (None, 14, 14, 512)   2359808     block5_conv1[0][0]               
____________________________________________________________________________________________________
block5_conv3 (Conv2D)            (None, 14, 14, 512)   2359808     block5_conv2[0][0]               
____________________________________________________________________________________________________
block5_pool (MaxPooling2D)       (None, 7, 7, 512)     0           block5_conv3[0][0]               
____________________________________________________________________________________________________
conv6 (Conv2D)                   (None, 7, 7, 4096)    102764544   block5_pool[0][0]                
____________________________________________________________________________________________________
pool4_11 (Conv2D)                (None, 14, 14, 12)    6156        block4_pool[0][0]                
____________________________________________________________________________________________________
conv7 (Conv2D)                   (None, 7, 7, 4096)    16781312    conv6[0][0]                      
____________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTransp (None, 28, 28, 12)    576         pool4_11[0][0]                   
____________________________________________________________________________________________________
pool3_11 (Conv2D)                (None, 28, 28, 12)    3084        block3_pool[0][0]                
____________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTransp (None, 28, 28, 12)    786432      conv7[0][0]                      
____________________________________________________________________________________________________
add (Add)                        (None, 28, 28, 12)    0           conv2d_transpose_2[0][0]         
                                                                   pool3_11[0][0]                   
                                                                   conv2d_transpose_1[0][0]         
____________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTransp (None, 224, 224, 12)  9216        add[0][0]                        
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 224, 224, 12)  0           conv2d_transpose_3[0][0]         
====================================================================================================
Total params: 135,066,008
Trainable params: 135,066,008
Non-trainable params: 0
____________________________________________________________________________________________________

Split between training and testing data

In [8]:
from sklearn.utils import shuffle
train_rate = 0.85
index_train = np.random.choice(X.shape[0],int(X.shape[0]*train_rate),replace=False)
index_test  = list(set(range(X.shape[0])) - set(index_train))
                            
X, Y = shuffle(X,Y)
X_train, y_train = X[index_train],Y[index_train]
X_test, y_test = X[index_test],Y[index_test]
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
((311, 224, 224, 3), (311, 224, 224, 12))
((56, 224, 224, 3), (56, 224, 224, 12))

Training starts here

In [9]:
from keras import optimizers


sgd = optimizers.SGD(lr=1E-2, decay=5**(-4), momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])

hist1 = model.fit(X_train,y_train,
                  validation_data=(X_test,y_test),
                  batch_size=32,epochs=200,verbose=2)
Train on 311 samples, validate on 56 samples
Epoch 1/200
15s - loss: 2.6345 - acc: 0.0885 - val_loss: 2.4823 - val_acc: 0.0958
Epoch 2/200
8s - loss: 2.4792 - acc: 0.1003 - val_loss: 2.4750 - val_acc: 0.1053
Epoch 3/200
8s - loss: 2.4671 - acc: 0.1113 - val_loss: 2.4546 - val_acc: 0.1229
Epoch 4/200
8s - loss: 2.4317 - acc: 0.1392 - val_loss: 2.3907 - val_acc: 0.1665
Epoch 5/200
8s - loss: 2.3087 - acc: 0.2091 - val_loss: 2.2193 - val_acc: 0.2802
Epoch 6/200
8s - loss: 2.3756 - acc: 0.2297 - val_loss: 2.2791 - val_acc: 0.2472
Epoch 7/200
8s - loss: 2.1612 - acc: 0.2812 - val_loss: 2.1125 - val_acc: 0.3080
Epoch 8/200
8s - loss: 2.0254 - acc: 0.3275 - val_loss: 1.8800 - val_acc: 0.3197
Epoch 9/200
8s - loss: 1.7938 - acc: 0.3446 - val_loss: 1.7250 - val_acc: 0.3914
Epoch 10/200
8s - loss: 1.6222 - acc: 0.4469 - val_loss: 1.5345 - val_acc: 0.4825
Epoch 11/200
8s - loss: 1.4211 - acc: 0.5345 - val_loss: 1.3592 - val_acc: 0.5589
Epoch 12/200
8s - loss: 1.2757 - acc: 0.6089 - val_loss: 1.2417 - val_acc: 0.6301
Epoch 13/200
8s - loss: 1.1592 - acc: 0.6525 - val_loss: 1.1354 - val_acc: 0.6541
Epoch 14/200
8s - loss: 1.1307 - acc: 0.6640 - val_loss: 1.0794 - val_acc: 0.6612
Epoch 15/200
8s - loss: 1.0512 - acc: 0.6718 - val_loss: 1.0449 - val_acc: 0.6650
Epoch 16/200
8s - loss: 1.0254 - acc: 0.6750 - val_loss: 1.0158 - val_acc: 0.6710
Epoch 17/200
8s - loss: 0.9942 - acc: 0.6804 - val_loss: 0.9911 - val_acc: 0.6725
Epoch 18/200
8s - loss: 0.9881 - acc: 0.6805 - val_loss: 0.9929 - val_acc: 0.6754
Epoch 19/200
8s - loss: 0.9522 - acc: 0.6880 - val_loss: 0.9654 - val_acc: 0.6873
Epoch 20/200
8s - loss: 0.9418 - acc: 0.6931 - val_loss: 0.9677 - val_acc: 0.6752
Epoch 21/200
9s - loss: 0.9383 - acc: 0.6941 - val_loss: 0.9180 - val_acc: 0.6993
Epoch 22/200
8s - loss: 0.9033 - acc: 0.7078 - val_loss: 0.8939 - val_acc: 0.7059
Epoch 23/200
8s - loss: 0.9017 - acc: 0.7161 - val_loss: 0.8951 - val_acc: 0.7040
Epoch 24/200
8s - loss: 0.9023 - acc: 0.7166 - val_loss: 0.8691 - val_acc: 0.7270
Epoch 25/200
8s - loss: 0.8655 - acc: 0.7333 - val_loss: 0.8652 - val_acc: 0.7400
Epoch 26/200
8s - loss: 0.8823 - acc: 0.7300 - val_loss: 0.8429 - val_acc: 0.7408
Epoch 27/200
8s - loss: 0.8444 - acc: 0.7441 - val_loss: 0.8817 - val_acc: 0.7391
Epoch 28/200
8s - loss: 0.8403 - acc: 0.7477 - val_loss: 1.0048 - val_acc: 0.6672
Epoch 29/200
8s - loss: 0.8483 - acc: 0.7415 - val_loss: 0.8370 - val_acc: 0.7534
Epoch 30/200
9s - loss: 0.8088 - acc: 0.7586 - val_loss: 0.7951 - val_acc: 0.7606
Epoch 31/200
8s - loss: 0.7974 - acc: 0.7621 - val_loss: 0.8393 - val_acc: 0.7411
Epoch 32/200
8s - loss: 0.8218 - acc: 0.7510 - val_loss: 0.7869 - val_acc: 0.7621
Epoch 33/200
8s - loss: 0.7785 - acc: 0.7685 - val_loss: 0.7917 - val_acc: 0.7659
Epoch 34/200
8s - loss: 0.7657 - acc: 0.7732 - val_loss: 0.7743 - val_acc: 0.7665
Epoch 35/200
8s - loss: 0.7687 - acc: 0.7710 - val_loss: 0.7638 - val_acc: 0.7735
Epoch 36/200
8s - loss: 0.8037 - acc: 0.7571 - val_loss: 0.7540 - val_acc: 0.7759
Epoch 37/200
9s - loss: 0.7501 - acc: 0.7776 - val_loss: 0.7492 - val_acc: 0.7795
Epoch 38/200
9s - loss: 0.7475 - acc: 0.7782 - val_loss: 0.7511 - val_acc: 0.7778
Epoch 39/200
8s - loss: 0.7543 - acc: 0.7749 - val_loss: 0.7368 - val_acc: 0.7852
Epoch 40/200
8s - loss: 0.7326 - acc: 0.7834 - val_loss: 0.7496 - val_acc: 0.7782
Epoch 41/200
8s - loss: 0.7495 - acc: 0.7766 - val_loss: 0.7210 - val_acc: 0.7873
Epoch 42/200
8s - loss: 0.7289 - acc: 0.7830 - val_loss: 0.7379 - val_acc: 0.7840
Epoch 43/200
8s - loss: 0.7182 - acc: 0.7874 - val_loss: 0.7095 - val_acc: 0.7922
Epoch 44/200
8s - loss: 0.7171 - acc: 0.7876 - val_loss: 0.7171 - val_acc: 0.7911
Epoch 45/200
8s - loss: 0.7031 - acc: 0.7930 - val_loss: 0.7119 - val_acc: 0.7880
Epoch 46/200
9s - loss: 0.7188 - acc: 0.7866 - val_loss: 0.7189 - val_acc: 0.7826
Epoch 47/200
9s - loss: 0.6940 - acc: 0.7949 - val_loss: 0.6950 - val_acc: 0.7961
Epoch 48/200
9s - loss: 0.6936 - acc: 0.7952 - val_loss: 0.7239 - val_acc: 0.7902
Epoch 49/200
8s - loss: 0.6873 - acc: 0.7970 - val_loss: 0.6993 - val_acc: 0.7970
Epoch 50/200
8s - loss: 0.6846 - acc: 0.7980 - val_loss: 0.6805 - val_acc: 0.8023
Epoch 51/200
8s - loss: 0.6744 - acc: 0.8015 - val_loss: 0.6895 - val_acc: 0.7996
Epoch 52/200
9s - loss: 0.6755 - acc: 0.8016 - val_loss: 0.6692 - val_acc: 0.8064
Epoch 53/200
8s - loss: 0.6861 - acc: 0.7974 - val_loss: 0.6680 - val_acc: 0.8047
Epoch 54/200
8s - loss: 0.6654 - acc: 0.8041 - val_loss: 0.6657 - val_acc: 0.8046
Epoch 55/200
8s - loss: 0.6617 - acc: 0.8052 - val_loss: 0.6636 - val_acc: 0.8078
Epoch 56/200
8s - loss: 0.6483 - acc: 0.8098 - val_loss: 0.6645 - val_acc: 0.8079
Epoch 57/200
8s - loss: 0.6594 - acc: 0.8064 - val_loss: 0.6597 - val_acc: 0.8080
Epoch 58/200
8s - loss: 0.6406 - acc: 0.8127 - val_loss: 0.6548 - val_acc: 0.8131
Epoch 59/200
8s - loss: 0.6406 - acc: 0.8125 - val_loss: 0.6512 - val_acc: 0.8124
Epoch 60/200
8s - loss: 0.6574 - acc: 0.8074 - val_loss: 0.6667 - val_acc: 0.8055
Epoch 61/200
8s - loss: 0.6369 - acc: 0.8139 - val_loss: 0.6523 - val_acc: 0.8072
Epoch 62/200
8s - loss: 0.6305 - acc: 0.8154 - val_loss: 0.6471 - val_acc: 0.8136
Epoch 63/200
8s - loss: 0.6239 - acc: 0.8180 - val_loss: 0.6321 - val_acc: 0.8162
Epoch 64/200
8s - loss: 0.6331 - acc: 0.8147 - val_loss: 0.6675 - val_acc: 0.8008
Epoch 65/200
8s - loss: 0.6255 - acc: 0.8175 - val_loss: 0.6303 - val_acc: 0.8174
Epoch 66/200
9s - loss: 0.6105 - acc: 0.8226 - val_loss: 0.6577 - val_acc: 0.8082
Epoch 67/200
8s - loss: 0.6204 - acc: 0.8191 - val_loss: 0.6384 - val_acc: 0.8173
Epoch 68/200
8s - loss: 0.6063 - acc: 0.8238 - val_loss: 0.6641 - val_acc: 0.8031
Epoch 69/200
9s - loss: 0.6169 - acc: 0.8200 - val_loss: 0.6190 - val_acc: 0.8229
Epoch 70/200
8s - loss: 0.5995 - acc: 0.8258 - val_loss: 0.6584 - val_acc: 0.8073
Epoch 71/200
9s - loss: 0.6114 - acc: 0.8218 - val_loss: 0.6136 - val_acc: 0.8259
Epoch 72/200
9s - loss: 0.5907 - acc: 0.8294 - val_loss: 0.6093 - val_acc: 0.8273
Epoch 73/200
9s - loss: 0.5867 - acc: 0.8302 - val_loss: 0.6287 - val_acc: 0.8201
Epoch 74/200
8s - loss: 0.5837 - acc: 0.8319 - val_loss: 0.6119 - val_acc: 0.8234
Epoch 75/200
8s - loss: 0.5882 - acc: 0.8300 - val_loss: 0.6042 - val_acc: 0.8270
Epoch 76/200
9s - loss: 0.5782 - acc: 0.8326 - val_loss: 0.6350 - val_acc: 0.8145
Epoch 77/200
8s - loss: 0.6022 - acc: 0.8251 - val_loss: 0.6051 - val_acc: 0.8257
Epoch 78/200
8s - loss: 0.5695 - acc: 0.8358 - val_loss: 0.5962 - val_acc: 0.8315
Epoch 79/200
8s - loss: 0.5739 - acc: 0.8347 - val_loss: 0.6215 - val_acc: 0.8208
Epoch 80/200
9s - loss: 0.5800 - acc: 0.8328 - val_loss: 0.5933 - val_acc: 0.8322
Epoch 81/200
8s - loss: 0.5616 - acc: 0.8387 - val_loss: 0.5864 - val_acc: 0.8325
Epoch 82/200
8s - loss: 0.5552 - acc: 0.8405 - val_loss: 0.6287 - val_acc: 0.8190
Epoch 83/200
8s - loss: 0.5966 - acc: 0.8275 - val_loss: 0.5844 - val_acc: 0.8338
Epoch 84/200
8s - loss: 0.5517 - acc: 0.8419 - val_loss: 0.5838 - val_acc: 0.8343
Epoch 85/200
9s - loss: 0.5468 - acc: 0.8435 - val_loss: 0.5854 - val_acc: 0.8335
Epoch 86/200
8s - loss: 0.5727 - acc: 0.8341 - val_loss: 0.5845 - val_acc: 0.8349
Epoch 87/200
8s - loss: 0.5448 - acc: 0.8445 - val_loss: 0.5862 - val_acc: 0.8345
Epoch 88/200
8s - loss: 0.5627 - acc: 0.8383 - val_loss: 0.5756 - val_acc: 0.8370
Epoch 89/200
8s - loss: 0.5380 - acc: 0.8466 - val_loss: 0.5701 - val_acc: 0.8387
Epoch 90/200
9s - loss: 0.5374 - acc: 0.8465 - val_loss: 0.5705 - val_acc: 0.8384
Epoch 91/200
9s - loss: 0.5294 - acc: 0.8488 - val_loss: 0.5668 - val_acc: 0.8395
Epoch 92/200
9s - loss: 0.5707 - acc: 0.8349 - val_loss: 0.6709 - val_acc: 0.7998
Epoch 93/200
8s - loss: 0.5632 - acc: 0.8368 - val_loss: 0.5680 - val_acc: 0.8394
Epoch 94/200
8s - loss: 0.5289 - acc: 0.8487 - val_loss: 0.5656 - val_acc: 0.8393
Epoch 95/200
8s - loss: 0.5234 - acc: 0.8503 - val_loss: 0.5604 - val_acc: 0.8418
Epoch 96/200
9s - loss: 0.5292 - acc: 0.8485 - val_loss: 0.5594 - val_acc: 0.8418
Epoch 97/200
8s - loss: 0.5144 - acc: 0.8534 - val_loss: 0.5702 - val_acc: 0.8396
Epoch 98/200
8s - loss: 0.5140 - acc: 0.8533 - val_loss: 0.5863 - val_acc: 0.8350
Epoch 99/200
8s - loss: 0.5246 - acc: 0.8499 - val_loss: 0.5567 - val_acc: 0.8421
Epoch 100/200
8s - loss: 0.5066 - acc: 0.8558 - val_loss: 0.5548 - val_acc: 0.8433
Epoch 101/200
8s - loss: 0.5209 - acc: 0.8512 - val_loss: 0.5506 - val_acc: 0.8440
Epoch 102/200
8s - loss: 0.5029 - acc: 0.8571 - val_loss: 0.5689 - val_acc: 0.8388
Epoch 103/200
8s - loss: 0.5168 - acc: 0.8516 - val_loss: 0.5542 - val_acc: 0.8422
Epoch 104/200
8s - loss: 0.5036 - acc: 0.8565 - val_loss: 0.5650 - val_acc: 0.8380
Epoch 105/200
8s - loss: 0.5074 - acc: 0.8544 - val_loss: 0.5675 - val_acc: 0.8383
Epoch 106/200
8s - loss: 0.5117 - acc: 0.8540 - val_loss: 0.5447 - val_acc: 0.8457
Epoch 107/200
8s - loss: 0.4990 - acc: 0.8579 - val_loss: 0.5515 - val_acc: 0.8438
Epoch 108/200
8s - loss: 0.4947 - acc: 0.8594 - val_loss: 0.5631 - val_acc: 0.8369
Epoch 109/200
8s - loss: 0.4960 - acc: 0.8579 - val_loss: 0.5633 - val_acc: 0.8401
Epoch 110/200
8s - loss: 0.4961 - acc: 0.8579 - val_loss: 0.5500 - val_acc: 0.8432
Epoch 111/200
8s - loss: 0.4937 - acc: 0.8588 - val_loss: 0.5465 - val_acc: 0.8448
Epoch 112/200
8s - loss: 0.4907 - acc: 0.8603 - val_loss: 0.5764 - val_acc: 0.8357
Epoch 113/200
9s - loss: 0.4873 - acc: 0.8613 - val_loss: 0.5452 - val_acc: 0.8454
Epoch 114/200
8s - loss: 0.4769 - acc: 0.8646 - val_loss: 0.5358 - val_acc: 0.8477
Epoch 115/200
9s - loss: 0.5442 - acc: 0.8441 - val_loss: 0.5497 - val_acc: 0.8426
Epoch 116/200
8s - loss: 0.4905 - acc: 0.8603 - val_loss: 0.5376 - val_acc: 0.8477
Epoch 117/200
8s - loss: 0.4764 - acc: 0.8648 - val_loss: 0.5345 - val_acc: 0.8485
Epoch 118/200
8s - loss: 0.4947 - acc: 0.8581 - val_loss: 0.5379 - val_acc: 0.8469
Epoch 119/200
9s - loss: 0.4701 - acc: 0.8672 - val_loss: 0.5381 - val_acc: 0.8480
Epoch 120/200
8s - loss: 0.4740 - acc: 0.8652 - val_loss: 0.5315 - val_acc: 0.8489
Epoch 121/200
9s - loss: 0.4642 - acc: 0.8685 - val_loss: 0.5343 - val_acc: 0.8487
Epoch 122/200
8s - loss: 0.4665 - acc: 0.8674 - val_loss: 0.5459 - val_acc: 0.8451
Epoch 123/200
8s - loss: 0.4630 - acc: 0.8688 - val_loss: 0.5298 - val_acc: 0.8500
Epoch 124/200
9s - loss: 0.4556 - acc: 0.8714 - val_loss: 0.5318 - val_acc: 0.8486
Epoch 125/200
8s - loss: 0.4677 - acc: 0.8671 - val_loss: 0.5622 - val_acc: 0.8413
Epoch 126/200
8s - loss: 0.4634 - acc: 0.8683 - val_loss: 0.5258 - val_acc: 0.8508
Epoch 127/200
8s - loss: 0.4568 - acc: 0.8708 - val_loss: 0.5472 - val_acc: 0.8445
Epoch 128/200
9s - loss: 0.4681 - acc: 0.8669 - val_loss: 0.5265 - val_acc: 0.8507
Epoch 129/200
9s - loss: 0.4569 - acc: 0.8702 - val_loss: 0.5280 - val_acc: 0.8505
Epoch 130/200
9s - loss: 0.4532 - acc: 0.8714 - val_loss: 0.5234 - val_acc: 0.8517
Epoch 131/200
8s - loss: 0.4531 - acc: 0.8714 - val_loss: 0.5580 - val_acc: 0.8403
Epoch 132/200
8s - loss: 0.4724 - acc: 0.8654 - val_loss: 0.5249 - val_acc: 0.8520
Epoch 133/200
8s - loss: 0.4497 - acc: 0.8728 - val_loss: 0.5377 - val_acc: 0.8453
Epoch 134/200
9s - loss: 0.4525 - acc: 0.8708 - val_loss: 0.5580 - val_acc: 0.8396
Epoch 135/200
8s - loss: 0.4450 - acc: 0.8744 - val_loss: 0.5307 - val_acc: 0.8494
Epoch 136/200
9s - loss: 0.4604 - acc: 0.8680 - val_loss: 0.5272 - val_acc: 0.8502
Epoch 137/200
8s - loss: 0.4410 - acc: 0.8753 - val_loss: 0.5218 - val_acc: 0.8521
Epoch 138/200
9s - loss: 0.4524 - acc: 0.8719 - val_loss: 0.5386 - val_acc: 0.8466
Epoch 139/200
8s - loss: 0.4381 - acc: 0.8767 - val_loss: 0.5209 - val_acc: 0.8518
Epoch 140/200
8s - loss: 0.4336 - acc: 0.8779 - val_loss: 0.5237 - val_acc: 0.8519
Epoch 141/200
8s - loss: 0.4395 - acc: 0.8754 - val_loss: 0.5429 - val_acc: 0.8450
Epoch 142/200
8s - loss: 0.4376 - acc: 0.8760 - val_loss: 0.5281 - val_acc: 0.8500
Epoch 143/200
8s - loss: 0.4379 - acc: 0.8757 - val_loss: 0.5400 - val_acc: 0.8448
Epoch 144/200
9s - loss: 0.4380 - acc: 0.8758 - val_loss: 0.5325 - val_acc: 0.8481
Epoch 145/200
8s - loss: 0.4335 - acc: 0.8777 - val_loss: 0.5372 - val_acc: 0.8464
Epoch 146/200
9s - loss: 0.4421 - acc: 0.8745 - val_loss: 0.5167 - val_acc: 0.8524
Epoch 147/200
8s - loss: 0.4237 - acc: 0.8808 - val_loss: 0.5236 - val_acc: 0.8504
Epoch 148/200
9s - loss: 0.4263 - acc: 0.8797 - val_loss: 0.5258 - val_acc: 0.8496
Epoch 149/200
9s - loss: 0.4366 - acc: 0.8767 - val_loss: 0.5343 - val_acc: 0.8469
Epoch 150/200
8s - loss: 0.4365 - acc: 0.8767 - val_loss: 0.5168 - val_acc: 0.8528
Epoch 151/200
9s - loss: 0.4244 - acc: 0.8801 - val_loss: 0.5549 - val_acc: 0.8411
Epoch 152/200
9s - loss: 0.4284 - acc: 0.8787 - val_loss: 0.5144 - val_acc: 0.8535
Epoch 153/200
8s - loss: 0.4207 - acc: 0.8818 - val_loss: 0.5240 - val_acc: 0.8505
Epoch 154/200
8s - loss: 0.4340 - acc: 0.8767 - val_loss: 0.5102 - val_acc: 0.8543
Epoch 155/200
9s - loss: 0.4172 - acc: 0.8827 - val_loss: 0.5235 - val_acc: 0.8511
Epoch 156/200
8s - loss: 0.4235 - acc: 0.8805 - val_loss: 0.5109 - val_acc: 0.8546
Epoch 157/200
9s - loss: 0.4127 - acc: 0.8841 - val_loss: 0.5160 - val_acc: 0.8527
Epoch 158/200
8s - loss: 0.4219 - acc: 0.8803 - val_loss: 0.5179 - val_acc: 0.8531
Epoch 159/200
8s - loss: 0.4110 - acc: 0.8844 - val_loss: 0.5185 - val_acc: 0.8527
Epoch 160/200
8s - loss: 0.4288 - acc: 0.8776 - val_loss: 0.5277 - val_acc: 0.8486
Epoch 161/200
8s - loss: 0.4074 - acc: 0.8856 - val_loss: 0.5185 - val_acc: 0.8527
Epoch 162/200
9s - loss: 0.4089 - acc: 0.8852 - val_loss: 0.5154 - val_acc: 0.8536
Epoch 163/200
9s - loss: 0.4194 - acc: 0.8816 - val_loss: 0.5074 - val_acc: 0.8555
Epoch 164/200
8s - loss: 0.4028 - acc: 0.8872 - val_loss: 0.5131 - val_acc: 0.8540
Epoch 165/200
9s - loss: 0.4149 - acc: 0.8829 - val_loss: 0.5130 - val_acc: 0.8534
Epoch 166/200
9s - loss: 0.4024 - acc: 0.8870 - val_loss: 0.5213 - val_acc: 0.8528
Epoch 167/200
8s - loss: 0.4022 - acc: 0.8870 - val_loss: 0.5251 - val_acc: 0.8500
Epoch 168/200
9s - loss: 0.4081 - acc: 0.8850 - val_loss: 0.5180 - val_acc: 0.8533
Epoch 169/200
9s - loss: 0.4013 - acc: 0.8874 - val_loss: 0.5397 - val_acc: 0.8472
Epoch 170/200
8s - loss: 0.4096 - acc: 0.8842 - val_loss: 0.5206 - val_acc: 0.8523
Epoch 171/200
8s - loss: 0.3963 - acc: 0.8890 - val_loss: 0.5155 - val_acc: 0.8535
Epoch 172/200
8s - loss: 0.4020 - acc: 0.8866 - val_loss: 0.5463 - val_acc: 0.8437
Epoch 173/200
9s - loss: 0.4096 - acc: 0.8845 - val_loss: 0.5082 - val_acc: 0.8554
Epoch 174/200
8s - loss: 0.3943 - acc: 0.8897 - val_loss: 0.5112 - val_acc: 0.8550
Epoch 175/200
8s - loss: 0.3940 - acc: 0.8895 - val_loss: 0.5319 - val_acc: 0.8510
Epoch 176/200
8s - loss: 0.3995 - acc: 0.8873 - val_loss: 0.5217 - val_acc: 0.8527
Epoch 177/200
8s - loss: 0.3987 - acc: 0.8877 - val_loss: 0.5183 - val_acc: 0.8515
Epoch 178/200
8s - loss: 0.3957 - acc: 0.8883 - val_loss: 0.5058 - val_acc: 0.8560
Epoch 179/200
8s - loss: 0.3901 - acc: 0.8906 - val_loss: 0.5072 - val_acc: 0.8556
Epoch 180/200
8s - loss: 0.3956 - acc: 0.8891 - val_loss: 0.5070 - val_acc: 0.8561
Epoch 181/200
8s - loss: 0.3882 - acc: 0.8911 - val_loss: 0.5173 - val_acc: 0.8525
Epoch 182/200
8s - loss: 0.3977 - acc: 0.8871 - val_loss: 0.5101 - val_acc: 0.8549
Epoch 183/200
8s - loss: 0.3849 - acc: 0.8925 - val_loss: 0.5121 - val_acc: 0.8536
Epoch 184/200
8s - loss: 0.3819 - acc: 0.8936 - val_loss: 0.5203 - val_acc: 0.8521
Epoch 185/200
8s - loss: 0.3919 - acc: 0.8899 - val_loss: 0.5063 - val_acc: 0.8558
Epoch 186/200
8s - loss: 0.3897 - acc: 0.8908 - val_loss: 0.5236 - val_acc: 0.8498
Epoch 187/200
8s - loss: 0.3873 - acc: 0.8915 - val_loss: 0.5080 - val_acc: 0.8568
Epoch 188/200
8s - loss: 0.3859 - acc: 0.8915 - val_loss: 0.5192 - val_acc: 0.8529
Epoch 189/200
8s - loss: 0.3792 - acc: 0.8940 - val_loss: 0.5178 - val_acc: 0.8533
Epoch 190/200
8s - loss: 0.3963 - acc: 0.8874 - val_loss: 0.5385 - val_acc: 0.8456
Epoch 191/200
8s - loss: 0.3862 - acc: 0.8910 - val_loss: 0.5071 - val_acc: 0.8562
Epoch 192/200
9s - loss: 0.3838 - acc: 0.8927 - val_loss: 0.5186 - val_acc: 0.8518
Epoch 193/200
9s - loss: 0.3828 - acc: 0.8927 - val_loss: 0.5325 - val_acc: 0.8474
Epoch 194/200
9s - loss: 0.3877 - acc: 0.8913 - val_loss: 0.5066 - val_acc: 0.8558
Epoch 195/200
8s - loss: 0.3775 - acc: 0.8946 - val_loss: 0.5084 - val_acc: 0.8563
Epoch 196/200
8s - loss: 0.3758 - acc: 0.8951 - val_loss: 0.5235 - val_acc: 0.8532
Epoch 197/200
8s - loss: 0.3846 - acc: 0.8916 - val_loss: 0.5046 - val_acc: 0.8566
Epoch 198/200
8s - loss: 0.3737 - acc: 0.8957 - val_loss: 0.5085 - val_acc: 0.8562
Epoch 199/200
8s - loss: 0.3785 - acc: 0.8937 - val_loss: 0.5080 - val_acc: 0.8565
Epoch 200/200
9s - loss: 0.3731 - acc: 0.8956 - val_loss: 0.5104 - val_acc: 0.8552

Plot the change in loss over epochs

In [10]:
for key in ['loss', 'val_loss']:
    plt.plot(hist1.history[key],label=key)
plt.legend()
plt.show()

Calculate intersection over union for each segmentation class

In [11]:
y_pred = model.predict(X_test)
y_predi = np.argmax(y_pred, axis=3)
y_testi = np.argmax(y_test, axis=3)
print(y_testi.shape,y_predi.shape)
((56, 224, 224), (56, 224, 224))
In [12]:
def IoU(Yi,y_predi):
    ## per-class Intersection over Union: IoU = TP/(TP + FP + FN)
    ## the mean IoU is then the average of the per-class IoUs

    IoUs = []
    Nclass = int(np.max(Yi)) + 1
    for c in range(Nclass):
        TP = np.sum( (Yi == c)&(y_predi==c) )
        FP = np.sum( (Yi != c)&(y_predi==c) )
        FN = np.sum( (Yi == c)&(y_predi != c)) 
        IoU = TP/float(TP + FP + FN)
        print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c,TP,FP,FN,IoU))
        IoUs.append(IoU)
    mIoU = np.mean(IoUs)
    print("_________________")
    print("Mean IoU: {:4.3f}".format(mIoU))
    
IoU(y_testi,y_predi)
class 00: #TP=396127, #FP= 44118, #FN=25057, IoU=0.851
class 01: #TP=636671, #FP=136114, #FN=50651, IoU=0.773
class 02: #TP=    11, #FP=   217, #FN=35459, IoU=0.000
class 03: #TP=840769, #FP= 33361, #FN=42747, IoU=0.917
class 04: #TP= 87924, #FP= 30018, #FN=46374, IoU=0.535
class 05: #TP=224159, #FP= 58411, #FN=48666, IoU=0.677
class 06: #TP=  1621, #FP=  3406, #FN=38029, IoU=0.038
class 07: #TP= 12823, #FP= 12332, #FN=15184, IoU=0.318
class 08: #TP=137585, #FP= 34734, #FN=40379, IoU=0.647
class 09: #TP=  1043, #FP=  3328, #FN=19466, IoU=0.044
class 10: #TP=   362, #FP=  1100, #FN= 9388, IoU=0.033
class 11: #TP= 63859, #FP= 49763, #FN=35502, IoU=0.428
_________________
Mean IoU: 0.438
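
As a quick sanity check of the IoU arithmetic (a toy example, not part of the original notebook), consider a single 2 x 2 "image" with two classes:

y_true_toy = np.array([[[0, 0],
                        [1, 1]]])
y_pred_toy = np.array([[[0, 1],
                        [1, 1]]])
IoU(y_true_toy, y_pred_toy)
## class 0: TP=1, FP=0, FN=1, so IoU = 1/(1+0+1) = 0.500
## class 1: TP=2, FP=1, FN=0, so IoU = 2/(2+1+0) = 0.667
## Mean IoU = (0.500 + 0.667)/2 = 0.583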

Visualize the model performance

Looks reasonable!

In [13]:
shape = (224,224)
n_classes = 12  ## use all 12 segmentation classes when assigning colors

for i in range(10):
    img_is  = (X_test[i] + 1)*(255.0/2)
    seg = y_predi[i]
    segtest = y_testi[i]

    fig = plt.figure(figsize=(10,30))    
    ax = fig.add_subplot(1,3,1)
    ax.imshow(img_is/255.0)
    ax.set_title("original")
    
    ax = fig.add_subplot(1,3,2)
    ax.imshow(give_color_to_seg_img(seg,n_classes))
    ax.set_title("predicted class")
    
    ax = fig.add_subplot(1,3,3)
    ax.imshow(give_color_to_seg_img(segtest,n_classes))
    ax.set_title("true class")
    plt.show()
