@@ -108,7 +108,7 @@ def test_consistency(
108108 err_msg = err_msg ,
109109 )
110110
111- def test_jit (
111+ def test_export (
112112 self ,
113113 ) -> None :
114114 rng = np .random .default_rng (GLOBAL_SEED )
@@ -132,8 +132,6 @@ def test_jit(
132132 [False , True ], # use_econf_tebd
133133 ):
134134 dtype = PRECISION_DICT [prec ]
135- rtol , atol = get_tols (prec )
136- err_msg = f"idt={ idt } prec={ prec } "
137135 # dpa1 new impl
138136 dd0 = DescrptDPA1 (
139137 self .rcut ,
@@ -151,6 +149,9 @@ def test_jit(
151149 )
152150 dd0 .se_atten .mean = torch .tensor (davg , dtype = dtype , device = env .DEVICE )
153151 dd0 .se_atten .dstd = torch .tensor (dstd , dtype = dtype , device = env .DEVICE )
154- # dd1 = DescrptDPA1.deserialize(dd0.serialize())
155- model = torch .jit .script (dd0 )
156- # model = torch.jit.script(dd1)
152+
153+ coord_ext = torch .tensor (self .coord_ext , dtype = dtype , device = env .DEVICE )
154+ atype_ext = torch .tensor (self .atype_ext , dtype = int , device = env .DEVICE )
155+ nlist = torch .tensor (self .nlist , dtype = int , device = env .DEVICE )
156+
157+ _ = torch .export .export (dd0 , (coord_ext , atype_ext , nlist ))
0 commit comments