Skip to content

Commit 168e91c

Browse files
committed
Fix LINE generator output and pin numpy for TF 2.10 CI
1 parent 9f1a812 commit 168e91c

2 files changed

Lines changed: 14 additions & 4 deletions

File tree

.github/workflows/ci.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ jobs:
5959
- name: Install dependencies
6060
run: |
6161
python -m pip install --upgrade pip
62+
if [[ "${{ matrix.tf-version }}" == "2.10.0" ]]; then
63+
python -m pip install -q "numpy<2"
64+
fi
6265
python -m pip install -q "tensorflow==${{ matrix.tf-version }}"
6366
if [[ "${{ matrix.tf-version }}" == 1.* ]]; then
6467
python -m pip install -q "protobuf==3.20.3"
@@ -67,6 +70,9 @@ jobs:
6770
python -m pip install -q "tf-keras~=2.20"
6871
fi
6972
python -m pip install -e ".[test]"
73+
if [[ "${{ matrix.tf-version }}" == "2.10.0" ]]; then
74+
python -m pip install -q "numpy<2"
75+
fi
7076
7177
- name: Test with pytest
7278
timeout-minutes: 180

ge/models/line.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,20 @@ def batch_iter(self, node2idx):
166166
cur_t = edges[shuffle_indices[i]][1]
167167
h.append(cur_h)
168168
t.append(cur_t)
169-
sign = np.ones(len(h))
169+
sign = np.ones(len(h), dtype=np.float32)
170170
else:
171-
sign = np.ones(len(h)) * -1
171+
sign = np.ones(len(h), dtype=np.float32) * -1
172172
t = []
173173
for i in range(len(h)):
174174
t.append(alias_sample(
175175
self.node_accept, self.node_alias))
176176

177+
heads = np.asarray(h, dtype=np.int32)
178+
tails = np.asarray(t, dtype=np.int32)
177179
if self.order == 'all':
178-
yield ([np.array(h), np.array(t)], [sign, sign])
180+
yield ((heads, tails), (sign, sign))
179181
else:
180-
yield ([np.array(h), np.array(t)], [sign])
182+
yield ((heads, tails), (sign,))
181183
mod += 1
182184
mod %= mod_size
183185
if mod == 0:
@@ -218,6 +220,8 @@ def train(self, batch_size=1024, epochs=1, initial_epoch=0, verbose=1, times=1):
218220
verbose=verbose,
219221
)
220222
except TypeError:
223+
if not hasattr(self.model, "fit_generator"):
224+
raise
221225
hist = self.model.fit_generator(
222226
self.batch_it,
223227
epochs=epochs,

0 commit comments

Comments
 (0)