Yumi's Blog

Classification with Mahalanobis distance + full covariance using tensorflow

This blog discusses how to calculate the Mahalanobis distance using tensorflow. I will consider the full covariance approach, i.e., each cluster has its own general covariance matrix, so I do not assume a common covariance matrix across clusters, unlike the previous post. Calculating the Mahalanobis distance is important for classification when each cluster has a different covariance structure. Let's take a look at this situation using toy data.
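
As a reminder, the Mahalanobis distance between a sample $x$ and a cluster with mean $\mu$ and covariance matrix $\Sigma$ is defined as

$$d_M(x, \mu; \Sigma) = \sqrt{(x-\mu)^\top \Sigma^{-1} (x-\mu)}$$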

Bonus: This blog post goes over how to use tf.while_loop

Example Data

In the following toy data, I generate 60 samples from a 2-d Gaussian mixture model with three components: 20 samples from each 2-d Gaussian component. Notice that each Gaussian distribution has a different covariance matrix.

In [1]:
import numpy as np 
np.random.seed(0)
import matplotlib.pyplot as plt
N_CLUSTER  = 3 
N_FEATURES = 2 
_N = 20
N_SAMPLE  = _N * N_CLUSTER


S1 = np.identity(N_FEATURES)      ## circular cluster
S1[0,1] = S1[1,0] = 0
S2 = np.identity(N_FEATURES)*0.1  ## circular cluster with small variance
S3 = np.identity(N_FEATURES)*3
S3[0,1] = S3[1,0] = -2.8          ## strong negative correlation

mu1 = np.array([1,  2])
mu2 = np.array([-2,-2])
mu3 = np.array([-3, 2])

npMEANS  = np.array([mu1,mu2,mu3])
npSIGMAS = np.array([S1,S2,S3])

colors   = np.array(["red","green","blue"])
npX = []
plt.figure(figsize=(5,5))
for icluster in range(N_CLUSTER):
    color     = colors[icluster]
    npmean    = npMEANS[icluster]
    s_cluster = np.random.multivariate_normal(npmean,npSIGMAS[icluster],_N)
    npX.extend(s_cluster)
    plt.plot(npmean[0],npmean[1],"X",color=color)
    plt.plot(s_cluster[:,0],s_cluster[:,1],"p",alpha=0.3,color=color,label="cluster={}".format(icluster))
plt.title("Sample distribution from {} clusters".format(N_CLUSTER))
plt.legend()
plt.show()

GT = [0 for _ in range(_N)] + [1 for _ in range(_N)] + [2 for _ in range(_N)] ## ground-truth cluster labels
npX = np.array(npX)
print("npX:      Data Dimension = (N_SAMPLE,N_FEATURES)  = {}".format(npX.shape))
print("npMEANS:  Data Dimension = (N_CLUSTER,N_FEATURES) = {}".format(npMEANS.shape))
print("npSIGMAS: Data Dimension = (N_CLUSTER,N_FEATURES,N_FEATURES) = {}".format(npSIGMAS.shape))

for icluster in range(N_CLUSTER):
    print("\n***CLUSTER={}***".format(icluster))
    print(">>>MEAN")
    print(npMEANS[icluster])
    print(">>>SIGMA")
    print(npSIGMAS[icluster])
npX:      Data Dimension = (N_SAMPLE,N_FEATURES)  = (60, 2)
npMEANS:  Data Dimension = (N_CLUSTER,N_FEATURES) = (3, 2)
npSIGMAS: Data Dimension = (N_CLUSTER,N_FEATURES,N_FEATURES) = (3, 2, 2)

***CLUSTER=0***
>>>MEAN
[1 2]
>>>SIGMA
[[1. 0.]
 [0. 1.]]

***CLUSTER=1***
>>>MEAN
[-2 -2]
>>>SIGMA
[[0.1 0. ]
 [0.  0.1]]

***CLUSTER=2***
>>>MEAN
[-3  2]
>>>SIGMA
[[ 3.  -2.8]
 [-2.8  3. ]]

In the plot above, the samples are colored by their ground-truth clusters. Due to the differences in the covariance matrices, the shape of each cluster is different!

- red cluster: circular
- blue cluster: high negative correlation
- green cluster: circular, small variance

If you use the Euclidean distance, the covariance matrix of each cluster is implicitly assumed to be the same (an identity matrix), so many blue points at the bottom of the blue cluster are classified into the red cluster.
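
In other words, the Euclidean distance is the special case of the Mahalanobis distance with $\Sigma = I$:

$$d_E(x, \mu) = \sqrt{(x-\mu)^\top (x-\mu)} = d_M(x, \mu; I)$$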

The script below shows this issue.

Euclidean distance based classification (numpy)

The classification accuracy is 0.9.

In [2]:
from scipy.spatial import distance
import scipy
npEuclidean = [[0 for _ in range(npMEANS.shape[0])]
            for _ in range(npX.shape[0])]
for isample in range(npX.shape[0]):
    for icluster in range(npMEANS.shape[0]):
        npEuclidean[isample][icluster] = distance.euclidean(npX[isample],npMEANS[icluster])
npEuclidean= np.array(npEuclidean)

pred = npEuclidean.argmin(axis=1)
plt.figure(figsize=(5,5))
plt.scatter(npX[:,0],npX[:,1],c=colors[pred])
plt.title("Classification using Euclidean distance acc={}".format(np.mean(GT == pred)))
plt.show()

Mahalanobis distance based classification (numpy)

If you use the Mahalanobis distance to calculate the distance between samples and cluster centers, the classification performance improves to 0.98.

In [3]:
from scipy.spatial import distance
import scipy
npMAHALANOBIS = [[0 for _ in range(npMEANS.shape[0])]
            for _ in range(npX.shape[0])]
for isample in range(npX.shape[0]):
    for icluster in range(npMEANS.shape[0]):
        npMAHALANOBIS[isample][icluster] = distance.mahalanobis(npX[isample],npMEANS[icluster],
                                                                VI=scipy.linalg.pinv(npSIGMAS[icluster]))
npMAHALANOBIS = np.array(npMAHALANOBIS)

pred = npMAHALANOBIS.argmin(axis=1)
plt.figure(figsize=(5,5))
plt.scatter(npX[:,0],npX[:,1],c=colors[pred])
plt.title("Classification using Mahalanobis distance, acc={:3.2f}".format(np.mean(GT == pred)))
plt.show()
print("npMAHALANOBIS")
npMAHALANOBIS
npMAHALANOBIS
Out[3]:
array([[ 1.80886884, 20.50794437,  9.87289163],
       [ 2.44530783, 23.4049362 , 11.44350799],
       [ 2.10780571, 18.11904241,  7.98912043],
       [ 0.96206913, 17.43996842,  7.73389073],
       [ 0.42337367, 16.6867373 ,  6.88705391],
       [ 1.46138974, 19.90831722,  8.88687686],
       [ 0.77070307, 17.64471867,  7.83949613],
       [ 0.55529553, 17.50455002,  7.64973936],
       [ 1.50809887, 18.60042218,  8.52829531],
       [ 0.90966528, 14.44753654,  5.6756414 ],
       [ 2.63533191, 14.78376952,  3.32954469],
       [ 1.13932386, 15.98354025,  6.72254601],
       [ 2.69573099, 18.50690896,  7.94438202],
       [ 0.19269571, 15.43185405,  6.22624559],
       [ 2.12330575, 22.46329734, 11.13544953],
       [ 0.40867542, 17.0651693 ,  7.25274973],
       [ 2.17064925,  9.24047182,  2.33160146],
       [ 0.38142871, 15.59128177,  6.10851391],
       [ 1.72027104, 21.20380045, 10.23947028],
       [ 0.49133392, 14.31748096,  5.35888769],
       [ 5.55818978,  1.76519525,  6.16364726],
       [ 4.8963245 ,  2.59169493,  4.75696051],
       [ 5.20772623,  0.67205241,  5.41809113],
       [ 5.06236096,  1.4744448 ,  5.14278905],
       [ 5.37264771,  1.62785893,  5.81281208],
       [ 5.08088469,  0.97547625,  5.17655799],
       [ 5.39642988,  1.28639583,  5.79483163],
       [ 4.89778468,  0.429258  ,  4.76177196],
       [ 4.91102675,  0.30969952,  4.78279185],
       [ 5.21292689,  0.73071586,  5.43689182],
       [ 5.21954574,  0.76254936,  5.45272851],
       [ 5.5923296 ,  1.90820819,  6.21039728],
       [ 5.06942627,  0.43921288,  5.08849233],
       [ 5.21627247,  1.69461325,  5.47728254],
       [ 5.16455747,  0.90878416,  5.35180415],
       [ 4.83168346,  0.7404118 ,  4.57555954],
       [ 5.12292219,  1.68018711,  5.07995302],
       [ 5.10217037,  0.79425667,  5.13045246],
       [ 5.31280939,  1.04563598,  5.65490645],
       [ 5.04569747,  0.31657468,  5.07777588],
       [ 2.42566847,  8.30881401,  1.47277375],
       [ 5.28772819, 15.4093536 ,  1.60526826],
       [ 6.71149929, 24.39395134,  2.41024695],
       [ 6.37024254, 21.16558702,  1.19243206],
       [ 2.37005382,  8.73988553,  1.50279061],
       [ 2.94217956, 11.70266824,  1.28721544],
       [ 4.09988693, 15.11363531,  0.99860016],
       [ 4.46137739, 15.88891474,  0.79135495],
       [ 3.50194497, 14.56219972,  1.78590136],
       [ 4.10338449, 14.15969423,  0.4215473 ],
       [ 8.12381008, 25.96984533,  2.3157528 ],
       [ 2.40604389,  8.21804927,  1.59808076],
       [ 1.95921587,  9.71997443,  2.27021625],
       [ 3.65513219,  9.81943601,  0.85426539],
       [ 7.76768041, 27.28381972,  2.42685601],
       [ 7.71646691, 26.63001481,  2.07573924],
       [ 2.11342027, 10.48309795,  2.09524648],
       [ 3.29608231, 12.04325452,  0.84602723],
       [ 5.87419015, 19.50540307,  0.95985125],
       [ 4.93863668, 17.76637109,  1.10795244]])

I hope you are now convinced that the Mahalanobis distance is preferable, especially when the cluster shapes are different. Now, we will calculate the Mahalanobis distance with tensorflow!

Tensorflow calculation

Previously, Calculate Mahalanobis distance with tensorflow 2.0 discussed how to use a Euclidean distance function to compute the Mahalanobis distance. The idea was to first decompose the inverse of the covariance matrix with a Cholesky decomposition and then standardize the samples. We will reuse code from there.
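
To spell out the trick: if $L$ is the lower-triangular Cholesky factor of the inverse covariance, $\Sigma^{-1} = L L^\top$, then

$$d_M(x, \mu; \Sigma)^2 = (x-\mu)^\top L L^\top (x-\mu) = \lVert L^\top (x-\mu) \rVert_2^2$$

so right-multiplying the row-vector samples and means by $L$ "standardizes" them, and the plain Euclidean distance on the standardized data equals the Mahalanobis distance.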

The following code calculates the pairwise Euclidean distances between two sets of points.
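
It relies on the expansion $\lVert a - b \rVert^2 = \lVert a \rVert^2 + \lVert b \rVert^2 - 2\,a^\top b$, which lets all pairwise distances be computed with a single matrix multiplication.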

In [4]:
import tensorflow as tf
def Euclidean(A,B):
    ## pairwise distances via ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    p1  = tf.reshape(tf.reduce_sum(tf.square(A), axis=1), (-1, 1)) # (N, 1)
    p2  = tf.reshape(tf.reduce_sum(tf.square(B), axis=1), (1, -1)) # (1, M)
    res = tf.sqrt(p1 + p2 - 2 * tf.matmul(A, B, transpose_b=True)) # (N, M)
    return(res)
    
## Sanity check    
npA = np.random.random((N_SAMPLE,N_FEATURES))    
npB = np.random.random((N_CLUSTER,N_FEATURES))
tfA = tf.constant(npA,dtype="float32")
tfB = tf.constant(npB,dtype="float32")
with tf.Session() as sess:
    tfE = sess.run(Euclidean(tfA,tfB))
assert tfE.shape == (N_SAMPLE,N_CLUSTER)

from scipy.spatial.distance import cdist
print("Average MSE={}".format(np.mean((tfE-cdist(npA,npB,metric="euclidean"))**2)))
Average MSE=1.0233338821564651e-13

Convert numpy arrays to tensor objects

I will keep:

  • tfSinv_half: the Cholesky factor of the inverse of each cluster's covariance matrix
  • tfMEANS: the standardized cluster means, obtained by right-multiplying each mean by the Cholesky factor of the inverse of its cluster's covariance matrix
  • tfX: the sample data
In [5]:
def get_mu1_Sinv_half(mu1,S1):
    Sinv      = np.linalg.pinv(S1)       ## inverse covariance matrix
    #Sinv_half = np.linalg.cholesky(Sinv + 0.01*np.identity(Sinv.shape[0])) 
    Sinv_half = np.linalg.cholesky(Sinv) ## lower-triangular L with L L^T = Sinv
    mu1_stand = np.matmul(mu1,Sinv_half) ## standardized mean: mu @ L
    return(mu1_stand,Sinv_half)


npMEANS_stand = []
npSinv_half   = []
for icluster in range(N_CLUSTER):
    mu_stand,Sinv_half = get_mu1_Sinv_half(npMEANS[icluster], # (N_FEATURES,)
                                           npSIGMAS[icluster]) # (N_FEATURES,N_FEATURES)
    npMEANS_stand.append(mu_stand)
    npSinv_half.append(Sinv_half)
npMEANS_stand = np.array(npMEANS_stand)
npSinv_half = np.array(npSinv_half)


tfSinv_half = tf.constant(npSinv_half,dtype="float32")
print("tfSinv_half:",tfSinv_half.get_shape())


tfMEANS = tf.constant(npMEANS_stand,dtype="float32")
print("tfMEANS:",tfMEANS.get_shape())


tfX = tf.constant(npX,dtype="float32")
print("tfX:",tfX.get_shape())
tfSinv_half: (3, 2, 2)
tfMEANS: (3, 2)
tfX: (60, 2)
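
As a quick sanity check (my own addition, not part of the original notebook), the plain Euclidean distance between a standardized sample and a standardized cluster mean should reproduce scipy's Mahalanobis distance:

from scipy.spatial import distance
isample, icluster = 0, 2            ## pick one sample and one cluster
L_chol  = npSinv_half[icluster]     ## Cholesky factor of pinv(SIGMA)
d_std   = np.linalg.norm(npX[isample].dot(L_chol) - npMEANS_stand[icluster])
d_scipy = distance.mahalanobis(npX[isample], npMEANS[icluster],
                               VI=np.linalg.pinv(npSIGMAS[icluster]))
print(d_std, d_scipy) ## both should be about 9.8729 (cf. npMAHALANOBIS[0,2])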

Example of tf.while_loop: Simple looping with indexing within the loop

I will use tf.while_loop to calculate the Mahalanobis distance. Here I show an example usage of tf.while_loop.

The following code simply slices out each of the three standardized cluster means in turn and returns the final standardized cluster mean.

In [6]:
def body(i,x):
    x = tfMEANS[i]
    i += 1
    return(i,x)

def condition(i,x):
    return i < N_CLUSTER

x = tf.ones(N_FEATURES,dtype="float32")
i = tf.constant(0,dtype="int32")

# while_loop must be instantiated AFTER all the global variables are defined
print("Looping N_CLUSTER={}".format(N_CLUSTER))
wl = tf.while_loop(condition,
                   body,
                   loop_vars=[i,x])
with tf.Session() as sess:
    print(sess.run(wl))
Looping N_CLUSTER=3
(3, array([-1.8225913,  1.1547005], dtype=float32))

Looping with Euclidean distance function

You can modify the code above to compute the Mahalanobis distance between the first sample and each cluster center. The code below returns the Mahalanobis distance between the first sample and the final cluster center.

In [7]:
def body(i,out):
    mu_stand  = tf.reshape(tfMEANS[i],(1,-1)) ## (1, N_FEATURES)
    Sinv_half = tfSinv_half[i]                ## (N_FEATURES, N_FEATURES)
    X_stand   = tf.matmul(tfXi,Sinv_half)     ## standardize the sample
    out = Euclidean(X_stand,mu_stand)
    i += 1
    return(i,out)

def condition(i,x):
    return i < N_CLUSTER

INDEX_FIRST_SAMPLE = 0
tfXi = tf.reshape(tfX[INDEX_FIRST_SAMPLE],(1,-1))
print("the dimension of sample:",tfXi.get_shape())
x = tf.ones((1,1),dtype="float32")
i = tf.constant(0,dtype="int32")

# while_loop must be instantiated AFTER all the global variables are defined

wl = tf.while_loop(condition,
                   body,
                   loop_vars=[i,x])

with tf.Session() as sess:
    print(sess.run(wl))
    
print("Mahalanobis distance between the first sample and the 3rd cluster center (numpy) {}".format(npMAHALANOBIS[0,-1]))
the dimension of sample: (1, 2)
(3, array([[9.872893]], dtype=float32))
Mahalanobis distance between the first sample and the 3rd cluster center (numpy) 9.872891632237177

Mahalanobis distance calculation

Now modify the function above to obtain the Mahalanobis distances between ALL samples and ALL cluster centers.

Please see the example in the tf.while_loop documentation to understand how to handle the shape invariants of the looping variables.
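
As a warm-up, here is a minimal toy sketch (my own addition, in the same TF1 style as the rest of this post) of why shape_invariants is needed: the loop variable out gains one column per iteration, so its static shape must be declared as partially unknown.

i   = tf.constant(0, dtype="int32")
out = tf.zeros((2,0), dtype="float32")          ## start with zero columns

def toy_body(i, out):
    col = tf.fill((2,1), tf.cast(i, "float32")) ## a new (2,1) column
    return(i + 1, tf.concat([out,col], axis=1)) ## out grows by one column

wl = tf.while_loop(lambda i, out: i < 3,
                   toy_body,
                   loop_vars=[i,out],
                   shape_invariants=[i.get_shape(),
                                     tf.TensorShape([2,None])]) ## column count unknown
with tf.Session() as sess:
    print(sess.run(wl)[1]) ## [[0. 1. 2.] [0. 1. 2.]]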

In [8]:
def EfficientMahalanobis(tfX,tfMEANS,tfSinv_half):
    """
    tfX         : (N_SAMPLE, N_FEATURES)
    tfSinv_half : (N_CLUSTER, N_FEATURES,N_FEATURES)
    tfMEANS     : (N_CLUSTER, N_FEATURES)
    
    Global variables need to be defined. 
        N_CLUSTER
        
    OUTPUT:
    
    Mdist : (N_SAMPLE,N_CLUSTER) 
    
    """
    def Euclidean(A,B):
        ## pairwise distances via ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        p1  = tf.reshape(tf.reduce_sum(tf.square(A), axis=1), (-1, 1)) # (N, 1)
        p2  = tf.reshape(tf.reduce_sum(tf.square(B), axis=1), (1, -1)) # (1, M)
        res = tf.sqrt(p1 + p2 - 2 * tf.matmul(A, B, transpose_b=True)) # (N, M)
        return(res)
    
    def body(i,out):
        mu_stand  = tf.reshape(tfMEANS[i],(1,-1)) ## (1, N_FEATURES)
        Sinv_half = tfSinv_half[i]                ## (N_FEATURES, N_FEATURES)
        X_stand   = tf.matmul(tfX,Sinv_half)      ## standardize ALL samples at once
        euc       = Euclidean(X_stand,mu_stand)   ## (N_SAMPLE, 1)
        out       = tf.cond(tf.equal(i,0),
                            lambda : euc,
                            lambda : tf.concat([out,euc],axis=1)) ## append one column per cluster
        i += 1
        return(i,out)

    def condition(i,x):
        return i < N_CLUSTER


    i = tf.constant(0,dtype="int32")
    out = tf.ones((0,0),dtype="float32")
    print("out            ",out)
    # while_loop must be instantiated AFTER all the global variables are defined
    _ncluster, Mdist = tf.while_loop(condition,
                       body,
                       loop_vars=[i,out],
                       shape_invariants=[i.get_shape(), 
                                         tf.TensorShape([None, # N_SAMPLE 
                                                         None])]) # N_CLUSTER
    print("_ncluster      ",_ncluster)
    print("Mdist          ",Mdist)
    return(Mdist)

Mdist = EfficientMahalanobis(tfX,tfMEANS,tfSinv_half)
with tf.Session() as sess:
     tfMAHALANOBIS = sess.run(Mdist)
print("tfMAHALANOBIS.shape",tfMAHALANOBIS.shape)    
tfMAHALANOBIS
out             Tensor("ones_2:0", shape=(0, 0), dtype=float32)
_ncluster       Tensor("while_2/Exit:0", shape=(), dtype=int32)
Mdist           Tensor("while_2/Exit_1:0", shape=(?, ?), dtype=float32)
tfMAHALANOBIS.shape (60, 3)
Out[8]:
array([[ 1.8088686 , 20.507944  ,  9.872893  ],
       [ 2.4453077 , 23.404936  , 11.443509  ],
       [ 2.107806  , 18.119043  ,  7.98912   ],
       [ 0.96206915, 17.439968  ,  7.7338896 ],
       [ 0.42337447, 16.686737  ,  6.8870544 ],
       [ 1.4613898 , 19.908318  ,  8.886876  ],
       [ 0.7707028 , 17.644718  ,  7.839496  ],
       [ 0.55529493, 17.50455   ,  7.64974   ],
       [ 1.508099  , 18.60042   ,  8.5282955 ],
       [ 0.909665  , 14.4475355 ,  5.6756415 ],
       [ 2.6353319 , 14.78377   ,  3.3295443 ],
       [ 1.139324  , 15.98354   ,  6.7225456 ],
       [ 2.6957312 , 18.506908  ,  7.944382  ],
       [ 0.19269733, 15.431853  ,  6.226246  ],
       [ 2.1233056 , 22.463297  , 11.135449  ],
       [ 0.40867597, 17.065168  ,  7.2527504 ],
       [ 2.170649  ,  9.240472  ,  2.3316014 ],
       [ 0.38142985, 15.591281  ,  6.1085134 ],
       [ 1.7202713 , 21.203802  , 10.2394705 ],
       [ 0.49133426, 14.317481  ,  5.358888  ],
       [ 5.5581894 ,  1.7651972 ,  6.1636477 ],
       [ 4.896325  ,  2.591696  ,  4.75696   ],
       [ 5.2077265 ,  0.67204523,  5.418091  ],
       [ 5.062361  ,  1.4744425 ,  5.142789  ],
       [ 5.372648  ,  1.6278567 ,  5.8128123 ],
       [ 5.080885  ,  0.975476  ,  5.176558  ],
       [ 5.3964295 ,  1.2863964 ,  5.794832  ],
       [ 4.897785  ,  0.42926118,  4.761772  ],
       [ 4.9110265 ,  0.3097043 ,  4.782792  ],
       [ 5.212927  ,  0.7307194 ,  5.436893  ],
       [ 5.219546  ,  0.76254964,  5.4527287 ],
       [ 5.5923295 ,  1.9082062 ,  6.2103972 ],
       [ 5.069426  ,  0.43920565,  5.0884933 ],
       [ 5.2162724 ,  1.6946104 ,  5.477282  ],
       [ 5.1645575 ,  0.9087805 ,  5.3518043 ],
       [ 4.8316836 ,  0.74040693,  4.5755596 ],
       [ 5.1229224 ,  1.680187  ,  5.0799527 ],
       [ 5.10217   ,  0.7942569 ,  5.130453  ],
       [ 5.312809  ,  1.0456353 ,  5.654906  ],
       [ 5.0456977 ,  0.31659907,  5.0777755 ],
       [ 2.4256685 ,  8.308814  ,  1.4727737 ],
       [ 5.287729  , 15.409354  ,  1.6052685 ],
       [ 6.7114997 , 24.39395   ,  2.410247  ],
       [ 6.370243  , 21.165588  ,  1.1924319 ],
       [ 2.3700538 ,  8.739886  ,  1.5027907 ],
       [ 2.9421794 , 11.702669  ,  1.2872156 ],
       [ 4.099887  , 15.113636  ,  0.99860024],
       [ 4.461377  , 15.888915  ,  0.7913555 ],
       [ 3.501945  , 14.562199  ,  1.7859012 ],
       [ 4.1033845 , 14.159694  ,  0.4215471 ],
       [ 8.12381   , 25.969843  ,  2.3157535 ],
       [ 2.4060435 ,  8.218049  ,  1.5980806 ],
       [ 1.959216  ,  9.7199745 ,  2.2702162 ],
       [ 3.655132  ,  9.819436  ,  0.8542645 ],
       [ 7.76768   , 27.28382   ,  2.4268565 ],
       [ 7.716467  , 26.630013  ,  2.075739  ],
       [ 2.1134202 , 10.483098  ,  2.0952466 ],
       [ 3.2960823 , 12.043255  ,  0.84602755],
       [ 5.8741903 , 19.505404  ,  0.95985067],
       [ 4.938637  , 17.76637   ,  1.1079527 ]], dtype=float32)
In [9]:
print("MSE between the tensorflow and numpy Mahalanobis calculations",np.mean((tfMAHALANOBIS - npMAHALANOBIS)**2))
MSE between the tensorflow and numpy Mahalanobis calculations 4.83147260130798e-12

Finally, check the classification performance.

In [10]:
pred = tfMAHALANOBIS.argmin(axis=1)
plt.figure(figsize=(5,5))
plt.scatter(npX[:,0],npX[:,1],c=colors[pred])
plt.title("Classification using Mahalanobis distance, acc={:3.2f}".format(np.mean(GT == pred)))
plt.show()
