In this blog post, I show how to index tensors with multi-dimensional index arrays.
First, I will show multi-dimensional array indexing with NumPy. Then I will show two approaches in TensorFlow:
- tf.while_loop
- tf.gather_nd
I will conclude that tf.gather_nd is much more efficient (faster) than tf.while_loop.
Reference
In [1]:
import tensorflow as tf
import numpy as np
import time
def generate_sample(height=20, width=5):
    """Create a toy 2-D array plus one random column index per row.

    Parameters
    ----------
    height : int
        Number of rows of the generated array.
    width : int
        Number of columns; the random column indices lie in ``[0, width)``.

    Returns
    -------
    tuple
        ``(nparr, pos, height)``:

        - ``nparr`` is ``np.arange(height * width)`` reshaped to
          ``(height, width)``.
        - ``pos`` is an integer array of shape ``(height,)``.
        - ``height`` is returned unchanged; callers use it as the row count.

    Both arrays are printed, each preceded by its label, as a side effect.
    """
    grid = np.arange(height * width).reshape(height, width)
    print("nparr")
    print(grid)

    # One column per row, drawn uniformly from [0, width).
    columns = np.random.randint(width, size=height)
    print("pos")
    print(columns)

    return grid, columns, height
# Small demo input: a 20 x 5 array and one random column index per row.
nparr, pos, N = generate_sample(20, 5)
Indexing with NumPy
In [2]:
# NumPy fancy indexing: pair each row index with that row's chosen column,
# which picks exactly one element per row.
npindex = nparr[range(len(nparr)), pos]
npindex
Out[2]:
The tf.while_loop approach
In [3]:
def slice_while_loop(tfarr, tfpos, N):
    """Pick ``tfarr[i, tfpos[i]]`` for every row ``i`` using ``tf.while_loop``.

    Parameters
    ----------
    tfarr : tf.Tensor, shape (N, P)
        Array to index. The result has the same dtype as ``tfarr``.
    tfpos : tf.Tensor of int32, shape (N,)
        Column index to take from each row.
    N : int or scalar int32 tensor
        Number of rows to process.

    Returns
    -------
    tf.Tensor, shape (1, N)
        The picked elements laid out as a single row. Index the evaluated
        result with ``[0]`` to get a 1-D array. The shape is ``(1, N)`` even
        when ``N == 0``.
    """
    def body(i, acc):
        # (1, 1) slice holding the element chosen for row i.
        picked = tf.reshape(tfarr[i, tfpos[i]], (1, -1))
        # Grow the accumulator by one column per iteration.
        return i + 1, tf.concat([acc, picked], axis=1)

    def condition(i, acc):
        return i < N

    i0 = tf.constant(0, dtype="int32")
    # Start from an empty (1, 0) row of tfarr's dtype. Every iteration can then
    # concatenate unconditionally (no tf.cond special case for i == 0). The
    # loop variable also keeps tfarr's dtype instead of a hard-coded float32,
    # which broke the while_loop for non-float32 inputs.
    acc0 = tf.zeros((1, 0), dtype=tfarr.dtype)
    _, tfindex = tf.while_loop(
        condition,
        body,
        loop_vars=[i0, acc0],
        # The column count grows each iteration, so leave that dimension open.
        shape_invariants=[i0.get_shape(), tf.TensorShape([1, None])],
    )
    return tfindex
# Graph inputs: float32 data and int32 column positions.
tfarr = tf.constant(nparr, dtype=tf.float32)
tfpos = tf.constant(pos, dtype=tf.int32)
tfindex = slice_while_loop(tfarr, tfpos, N)

with tf.Session() as sess:
    # The while-loop result is a single row; print that row.
    print(sess.run(fetches=tfindex)[0])
The tf.gather_nd approach
In [4]:
def slice_gather_nd(tfarr, tfpos, N):
    """Pick ``tfarr[i, tfpos[i]]`` for every row ``i`` with one ``tf.gather_nd``.

    Parameters
    ----------
    tfarr : tf.Tensor
        Array to index; assumed shape (N, P).
    tfpos : tensor-like of int32 or int64, shape (N,)
        Column index to take from each row.
    N : int or scalar integer tensor
        Number of rows, i.e. the length of ``tfpos``.

    Returns
    -------
    tf.Tensor
        The picked elements. For a 2-D ``tfarr`` this has shape (N,).

    The index tensor and the result tensor are printed as a side effect
    (graph-construction debug output).
    """
    tfpos = tf.convert_to_tensor(tfpos)
    # Build the (N, 2) coordinate list [[0, tfpos[0]], [1, tfpos[1]], ...].
    # The row range uses tfpos's dtype so that int64 positions can be stacked
    # (tf.range defaults to int32, and tf.stack requires matching dtypes).
    rows = tf.range(N, dtype=tfpos.dtype)
    indices = tf.stack([rows, tfpos], axis=1)
    print("indices",indices)
    tfindex = tf.gather_nd(params=tfarr, indices=indices)
    print("tfindex",tfindex)
    return tfindex
# Rebuild the graph inputs for the gather_nd version.
tfarr = tf.constant(nparr, dtype=tf.float32)
tfpos = tf.constant(pos, dtype=tf.int32)
N = len(nparr)  # row count of the 2-D sample
tfindex = slice_gather_nd(tfarr, tfpos, N)

with tf.Session() as sess:
    print(sess.run(fetches=tfindex))
Timing comparison with a large array
In [5]:
# Large benchmark input: 100000 rows x 5 columns.
nparr, pos, N = generate_sample(width=5, height=100000)
tfarr = tf.constant(nparr, dtype=tf.float32)
tfpos = tf.constant(pos, dtype=tf.int32)

# Build both graphs here, so graph construction is not part of the timed
# cells that follow.
tfindex_loop = slice_while_loop(tfarr, tfpos, N)
tfindex_gather_nd = slice_gather_nd(tfarr, tfpos, N)
Measure the run time of each approach with `%%timeit`. The clear winner is tf.gather_nd!
In [6]:
%%timeit -n 3 -r 2
# Time evaluation of the tf.while_loop graph: 2 repeats (-r) of 3 loops (-n).
# A new Session is opened in every loop, so session start-up is included in
# the measured time.
with tf.Session() as sess:
    sess.run(tfindex_loop)
In [7]:
%%timeit -n 3 -r 2
# Time evaluation of the tf.gather_nd graph under the same settings as the
# while-loop cell: 2 repeats (-r) of 3 loops (-n), one new Session per loop.
with tf.Session() as sess:
    sess.run(tfindex_gather_nd)
The clear winner is tf.gather_nd!