Yumi's Blog

Multidimensional indexing with tensorflow

In this blog post, I show how to use multi-dimensional index arrays to index tensors.

First, I will show multi-dimensional array indexing with numpy. Then I will show two approaches in tensorflow:

  • tf.while_loop
  • tf.gather_nd

In the end, tf.gather_nd turns out to be much faster than tf.while_loop.

Reference

numpy indexing

In [1]:
import tensorflow as tf
import numpy as np
import time 
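def generate_sample(height=20, width = 5):
    # (height, width) matrix holding 0, 1, ..., height*width - 1
    nparr = np.arange(width*height).reshape(height,width)
    print("nparr")
    print(nparr)
    # one random column index in [0, width) for every row
    pos = np.random.randint(width,size=height)
    print("pos")
    print(pos)
    return(nparr, pos, height)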
nparr, pos, N = generate_sample(height=20, width = 5)
nparr
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]
 [25 26 27 28 29]
 [30 31 32 33 34]
 [35 36 37 38 39]
 [40 41 42 43 44]
 [45 46 47 48 49]
 [50 51 52 53 54]
 [55 56 57 58 59]
 [60 61 62 63 64]
 [65 66 67 68 69]
 [70 71 72 73 74]
 [75 76 77 78 79]
 [80 81 82 83 84]
 [85 86 87 88 89]
 [90 91 92 93 94]
 [95 96 97 98 99]]
pos
[1 4 4 3 2 1 2 4 3 0 4 4 2 0 1 3 3 1 3 2]

Indexing

In [2]:
npindex=nparr[range(nparr.shape[0]),pos]
npindex
Out[2]:
array([ 1,  9, 14, 18, 22, 26, 32, 39, 43, 45, 54, 59, 62, 65, 71, 78, 83,
       86, 93, 97])
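The indexing above uses NumPy's integer-array ("advanced") indexing: the two index arrays are paired elementwise, so element k of the result is nparr[k, pos[k]]. The sketch below checks this against an explicit loop and against np.take_along_axis. It uses its own small, hypothetical arrays with a fixed seed rather than the random ones above, and it also shows that nparr[:, pos] is a different operation that selects whole columns.

```python
import numpy as np

# Small stand-ins for nparr and pos (hypothetical data, fixed seed for reproducibility)
rng = np.random.default_rng(0)
arr = np.arange(20).reshape(4, 5)
pos = rng.integers(0, 5, size=4)

# Advanced indexing: row k is paired with column pos[k]
fancy = arr[np.arange(arr.shape[0]), pos]

# The same selection written as an explicit Python loop
looped = np.array([arr[k, pos[k]] for k in range(arr.shape[0])])

# np.take_along_axis expresses the same pairing along axis 1
along = np.take_along_axis(arr, pos[:, None], axis=1)[:, 0]

# By contrast, arr[:, pos] takes every row for each column in pos -> shape (4, 4);
# its diagonal, arr[i, pos[i]], is exactly the advanced-indexing result
columns = arr[:, pos]

print(fancy, looped, along, columns.shape)
```

Only the paired form (a row index array plus a column index array of the same length) gives one element per row; mixing a slice with an index array broadcasts instead.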

tf.while_loop approach

In [3]:
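def slice_while_loop(tfarr,tfpos,N):
    '''
    tfarr : tf.Tensor of shape (N, P)
    tfpos : int32 tf.Tensor of shape (N,), one column index per row
    N     : integer, number of rows of tfarr
    Returns a tensor of shape (1, N) whose i-th entry is tfarr[i, tfpos[i]].
    '''
    def body(i,x):
        # pick one element and make it a (1, 1) tensor
        _x = tf.reshape(tfarr[i,tfpos[i]],(1,-1))
        # first iteration: replace the empty start value; afterwards: append along axis 1
        x  = tf.cond(tf.equal(i,0),
                     lambda : _x,
                     lambda : tf.concat([x,_x],axis=1))
        i += 1
        return i, x
    def condition(i,x):
        return i < N

    i   = tf.constant(0,dtype="int32")
    # empty start value; its shape changes across iterations, hence shape_invariants below
    out = tf.ones((0,0),dtype="float32")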

    _n, tfindex = tf.while_loop(condition,
                      body,
                      loop_vars = [i, out],
                      shape_invariants=[i.get_shape(),
                                       tf.TensorShape([None,None])])
    return tfindex


tfarr = tf.constant(nparr,dtype="float32")
tfpos = tf.constant(pos,dtype="int32")

tfindex = slice_while_loop(tfarr,tfpos,N)
with tf.Session() as sess:
    print(sess.run(tfindex)[0])
[ 1.  9. 14. 18. 22. 26. 32. 39. 43. 45. 54. 59. 62. 65. 71. 78. 83. 86.
 93. 97.]
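To see what slice_while_loop computes, here is a NumPy sketch that mirrors its loop body step by step (the function name while_loop_mirror is hypothetical, and this is not TensorFlow code). Each iteration takes one scalar, reshapes it to (1, 1), and concatenates it onto the running result along axis 1, so the final value has shape (1, N); that is why the cell above prints sess.run(tfindex)[0]. Because every concatenation allocates a new, slightly longer array, the copying grows quadratically with N, which is plausibly one contributor to the loop's cost alongside the per-iteration overhead of executing graph ops.

```python
import numpy as np

def while_loop_mirror(arr, pos):
    """Hypothetical NumPy mirror of slice_while_loop (not TensorFlow code)."""
    n = arr.shape[0]
    i = 0
    x = np.ones((0, 0), dtype=np.float32)        # same empty start value as `out`
    while i < n:                                 # condition(i, x)
        _x = np.reshape(arr[i, pos[i]], (1, -1)) # one element, shape (1, 1)
        # the first iteration replaces the empty start value; later ones append
        x = _x if i == 0 else np.concatenate([x, _x], axis=1)
        i += 1
    return x                                     # shape (1, n)

arr = np.arange(20, dtype=np.float32).reshape(4, 5)
pos = np.array([1, 4, 0, 2])
out = while_loop_mirror(arr, pos)
print(out.shape, out[0])  # (1, 4) [ 1.  9. 10. 17.]
```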

tf.gather_nd approach

In [4]:
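def slice_gather_nd(tfarr,tfpos,N):
    # (N, 2) array of (row, column) coordinates: [[0, tfpos[0]], [1, tfpos[1]], ...]
    indices = tf.transpose(tf.stack([tf.range(N),tfpos]))
    print("indices",indices)
    # gather one element of tfarr per coordinate row -> shape (N,)
    tfindex = tf.gather_nd(params=tfarr,indices=indices)
    print("tfindex",tfindex)
    return tfindex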

tfarr = tf.constant(nparr,dtype="float32")
tfpos = tf.constant(pos,dtype="int32")
N = nparr.shape[0]

tfindex = slice_gather_nd(tfarr,tfpos,N)
with tf.Session() as sess:
    print(sess.run(tfindex))
indices Tensor("transpose:0", shape=(20, 2), dtype=int32)
tfindex Tensor("GatherNd:0", shape=(20,), dtype=float32)
[ 1.  9. 14. 18. 22. 26. 32. 39. 43. 45. 54. 59. 62. 65. 71. 78. 83. 86.
 93. 97.]
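tf.gather_nd treats each row of indices as a coordinate into params: with indices of shape (N, 2), output element k is params[indices[k, 0], indices[k, 1]]. The NumPy sketch below illustrates that rule with a hypothetical helper, gather_nd_numpy, which covers only the case used in this post (index rows of length d <= params.ndim) and does not replicate TensorFlow's handling of invalid indices. It also builds the (N, 2) index array the same way slice_gather_nd does and shows that stacking along axis 1 gives the same array without a transpose.

```python
import numpy as np

def gather_nd_numpy(params, indices):
    """Hypothetical NumPy analogue of tf.gather_nd for this post's use case.

    Each row of `indices` (shape (N, d), d <= params.ndim) is one coordinate;
    the result has shape (N,) + params.shape[d:].
    """
    params = np.asarray(params)
    indices = np.asarray(indices)
    # split the last axis into d index arrays, one per leading dimension of params
    return params[tuple(indices[..., j] for j in range(indices.shape[-1]))]

arr = np.arange(20, dtype=np.float32).reshape(4, 5)
pos = np.array([1, 4, 0, 2])
n = arr.shape[0]

# Same construction as slice_gather_nd: stack to (2, N), then transpose to (N, 2)
idx_transposed = np.stack([np.arange(n), pos]).T
# Equivalent: stack along axis 1 to get (N, 2) directly
idx_axis1 = np.stack([np.arange(n), pos], axis=1)

picked = gather_nd_numpy(arr, idx_axis1)
rows = gather_nd_numpy(arr, idx_axis1[:, :1])  # depth-1 coordinates select whole rows
print(idx_axis1.shape, picked)  # (4, 2) [ 1.  9. 10. 17.]
```

tf.stack also takes an axis argument, so in the TensorFlow version the tf.transpose call could likely be replaced by tf.stack([tf.range(N), tfpos], axis=1); the sketch only verifies the NumPy equivalent.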

Time comparison with a large array (100000 rows, 5 columns).

In [5]:
nparr, pos, N = generate_sample(height=100000, width = 5)
tfarr = tf.constant(nparr,dtype="float32")
tfpos = tf.constant(pos,dtype="int32")

tfindex_loop = slice_while_loop(tfarr,tfpos,N)
tfindex_gather_nd = slice_gather_nd(tfarr,tfpos,N)
nparr
[[     0      1      2      3      4]
 [     5      6      7      8      9]
 [    10     11     12     13     14]
 ...
 [499985 499986 499987 499988 499989]
 [499990 499991 499992 499993 499994]
 [499995 499996 499997 499998 499999]]
pos
[0 3 2 ... 3 4 2]
indices Tensor("transpose_1:0", shape=(100000, 2), dtype=int32)
tfindex Tensor("GatherNd_1:0", shape=(100000,), dtype=float32)

Evaluate the computation speed with the %%timeit magic. Each timed cell opens a new tf.Session, so the numbers include session start-up as well as the computation itself.

In [6]:
%%timeit -n 3 -r 2
with tf.Session() as sess:
    sess.run(tfindex_loop)
5.65 s ± 44.9 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)
In [7]:
%%timeit -n 3 -r 2 
with tf.Session() as sess:
    sess.run(tfindex_gather_nd)
99.4 ms ± 1.95 ms per loop (mean ± std. dev. of 2 runs, 3 loops each)

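The clear winner is tf.gather_nd: about 99 ms per loop versus about 5.65 s for tf.while_loop, roughly 57 times faster on this array.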
