Skip to content

Commit b87dcdc

Browse files
committed
reset name to uu and ll
1 parent b8fcbe6 commit b87dcdc

2 files changed

Lines changed: 15 additions & 13 deletions

File tree

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,8 @@ def forward(
926926
rr = rr * exclude_mask[:, :, None]
927927

928928
# nfnl x nt_i x 3: direction vectors
929+
# nt_i = nnei
930+
# nt_j = nnei
929931
rr_i = rr[:, :, 1:]
930932
# nfnl x nt_j x 3
931933
rr_j = rr[:, :, 1:]

deepmd/utils/tabulate.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -198,36 +198,36 @@ def build(
198198
)
199199
idx += 1
200200
elif self.descrpt_type == "T_TEBD":
201-
# 1. Find the global range [global_lower, global_upper] of cos(theta) across all types
202-
global_upper = np.max(upper)
203-
global_lower = np.min(lower)
201+
# 1. Find the global range [ll, uu] of cos(theta) across all types
202+
uu = np.max(upper)
203+
ll = np.min(lower)
204204

205205
# 2. Create a unique input grid xx for this shared geometric network based on the global range
206206
xx = np.arange(
207-
extrapolate * global_lower, global_lower, stride1, dtype=self.data_type
207+
extrapolate * ll, ll, stride1, dtype=self.data_type
208208
)
209209
xx = np.append(
210210
xx,
211-
np.arange(global_lower, global_upper, stride0, dtype=self.data_type),
211+
np.arange(ll, uu, stride0, dtype=self.data_type),
212212
)
213213
xx = np.append(
214214
xx,
215215
np.arange(
216-
global_upper,
217-
extrapolate * global_upper,
216+
uu,
217+
extrapolate * uu,
218218
stride1,
219219
dtype=self.data_type,
220220
),
221221
)
222222
xx = np.append(
223-
xx, np.array([extrapolate * global_upper], dtype=self.data_type)
223+
xx, np.array([extrapolate * uu], dtype=self.data_type)
224224
)
225225

226226
# 3. Calculate the number of spline points
227227
nspline = (
228-
(global_upper - global_lower) / stride0
229-
+ ((extrapolate * global_upper - global_upper) / stride1)
230-
+ ((global_lower - extrapolate * global_lower) / stride1)
228+
(uu - ll) / stride0
229+
+ ((extrapolate * uu - uu) / stride1)
230+
+ ((ll - extrapolate * ll) / stride1)
231231
).astype(int)
232232

233233
# 4. Call _generate_spline_table only once to generate the table for this shared network
@@ -236,8 +236,8 @@ def build(
236236
geometric_net_name,
237237
xx,
238238
0,
239-
global_upper,
240-
global_lower,
239+
uu,
240+
ll,
241241
stride0,
242242
stride1,
243243
extrapolate,

0 commit comments

Comments
 (0)