@@ -21,7 +21,7 @@ def __init__(self, hypernet, dtype=torch.float32):
2121 self .stepping_class = 'fixed'
2222 self .op1 = self .order + 1
2323
24- def step (self , f , x , t , dt , k1 = None ):
24+ def step (self , f , x , t , dt , k1 = None , args = None ):
2525 _ , x_sol , _ = super ().step (f , x , t , dt , k1 )
2626 return None , x_sol + dt ** (self .op1 ) * self .hypernet (t , x ), None
2727
@@ -32,7 +32,7 @@ def __init__(self, hypernet, dtype=torch.float32):
3232 self .stepping_class = 'fixed'
3333 self .op1 = self .order + 1
3434
35- def step (self , f , x , t , dt , k1 = None ):
35+ def step (self , f , x , t , dt , k1 = None , args = None ):
3636 _ , x_sol , _ = super ().step (f , x , t , dt , k1 )
3737 return None , x_sol + dt ** (self .op1 ) * self .hypernet (t , x ), None
3838
@@ -42,6 +42,6 @@ def __init__(self, hypernet, dtype=torch.float32):
4242 self .hypernet = hypernet
4343 self .op1 = self .order + 1
4444
45- def step (self , f , x , t , dt , k1 = None ):
45+ def step (self , f , x , t , dt , k1 = None , args = None ):
4646 _ , x_sol , _ = super ().step (f , x , t , dt , k1 )
4747 return None , x_sol + dt ** (self .op1 ) * self .hypernet (t , x ), None
0 commit comments