11import os
2+ import gc
23import torch
3- from rwkv .model import RWKV
4+ # from rwkv.model import RWKV
5+ from .rwkv_hiddens import RWKV
46from huggingface_hub import hf_hub_download
57from transformers import AutoTokenizer , GPT2TokenizerFast , PreTrainedModel , PretrainedConfig
68from transformers .modeling_outputs import CausalLMOutput
@@ -12,7 +14,7 @@ class RWKVConfig(PretrainedConfig):
1214 def __init__ (self , ** kwargs ):
1315 super ().__init__ (** kwargs )
1416 self .hidden_size = 2048
15- self .num_hidden_layers = 120
17+ self .num_hidden_layers = 25
1618 self .is_encoder_decoder = False
1719 self .architectures = ["RWKV-LM" ]
1820
@@ -21,7 +23,7 @@ def __init__(self):
2123 super ().__init__ (RWKVConfig ())
2224 weights_path = "/home/kyle/HF-MODEL/rwkv-4-pile-1b5/models--BlinkDL--rwkv-4-pile-1b5/snapshots/6ea995eaa87a17af560c9b41ce1a3d92355c5a49/RWKV-4-Pile-1B5-20220903-8040.pth"
2325 # weights_path = "/home/kyle/HF-MODEL/rwkv-4-pile-14b/models--BlinkDL--rwkv-4-pile-14b/snapshots/939b6851f96122b7b49bd00d446b3b49481214dd/RWKV-4-Pile-14B-20230213-8019.pth"
24- self .model = RWKV (model = weights_path , strategy = 'cuda bf16 ' )
26+ self .model = RWKV (model = weights_path , strategy = 'cuda fp16 ' )
2527
2628 def forward (
2729 self ,
@@ -37,7 +39,7 @@ def forward(
3739 token , states = self .model .forward (inputs , None )
3840 mock_embedding_state = states [0 ].clone ()
3941 output_states = [mock_embedding_state ] + states
40- response = CausalLMOutput (logits = token , hidden_states = output_states )
42+ response = CausalLMOutput (logits = token . detach (). clone () , hidden_states = [ state . detach () for state in output_states ] )
4143 return response
4244
4345 # @staticmethod
0 commit comments