Skip to content
This repository was archived by the owner on Apr 19, 2026. It is now read-only.

Commit 658b7b4

Browse files
jiminhasayantan-nervana
authored andcommitted
Jiminha/gatherndindex (#445)
1 parent 45611ed commit 658b7b4

5 files changed

Lines changed: 52 additions & 5 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ Once TensorFlow's dependencies are installed, clone the `ngraph-bridge` repo:
8888

8989
git clone https://github.com/tensorflow/ngraph-bridge.git
9090
cd ngraph-bridge
91-
git checkout v0.19.0-rc9
91+
git checkout v0.19.0-rc10
9292

9393
Run the following Python script to build TensorFlow, nGraph, and the bridge. Use Python 3.5:
9494

ngraph_bridge/ngraph_builder.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,7 +2239,8 @@ static Status TranslateGatherNdOp(const Node* op,
22392239

22402240
auto ng_params_shape = ng_params->get_shape();
22412241
size_t ng_params_rank = ng_params_shape.size();
2242-
size_t ng_indices_rank = ng_indices->get_shape().size();
2242+
auto ng_indices_shape = ng_indices->get_shape();
2243+
size_t ng_indices_rank = ng_indices_shape.size();
22432244

22442245
for (size_t i = 0; i < ng_params_rank; i++) {
22452246
if (ng_params_shape[i] == 0) {
@@ -2250,7 +2251,7 @@ static Status TranslateGatherNdOp(const Node* op,
22502251
}
22512252
}
22522253

2253-
if ((ng_indices_rank - 1) > ng_params_rank) {
2254+
if ((ng_indices_shape[ng_indices_rank - 1]) > ng_params_rank) {
22542255
return errors::InvalidArgument(
22552256
"The last dimension of indices can be at most the rank of params");
22562257
}

ngraph_bridge/version.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
// candidate such as v0.7.0-rc0
3333
// The code in master will always have the last released version number
3434
// with a suffix of '-master'
35-
#define NG_TF_VERSION_SUFFIX "-rc9"
35+
#define NG_TF_VERSION_SUFFIX "-rc10"
3636

3737
#define VERSION_STR_HELPER(x) #x
3838
#define VERSION_STR(x) VERSION_STR_HELPER(x)

python/setup.in.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def get_tag(self):
5959

6060
setup(
6161
name='ngraph_tensorflow_bridge',
62-
version='0.19.0rc9',
62+
version='0.19.0rc10',
6363
description='Intel nGraph compiler and runtime for TensorFlow',
6464
long_description=long_description,
6565
long_description_content_type="text/markdown",

test/python/test_gathernd.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# ==============================================================================
2+
# Copyright 2018-2019 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""nGraph TensorFlow bridge gather_nd operation test
17+
18+
"""
19+
from __future__ import absolute_import
20+
from __future__ import division
21+
from __future__ import print_function
22+
23+
import pytest
24+
25+
import tensorflow as tf
26+
import os
27+
import numpy as np
28+
29+
from common import NgraphTest
30+
31+
32+
class TestGatherNDOperations(NgraphTest):
33+
34+
def test_gather_nd(self):
35+
val = tf.placeholder(tf.float32, shape=(5, 10))
36+
indices = np.zeros([1, 3, 3, 1], dtype=np.int32)
37+
out = tf.gather_nd(val, indices, batch_dims=0, name='output')
38+
39+
def run_test(sess):
40+
return sess.run((out,),
41+
feed_dict={val: np.arange(50).reshape([5, 10])})[0]
42+
43+
self.with_ngraph(run_test)
44+
45+
assert (
46+
self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()

0 commit comments

Comments
 (0)