Skip to content

Commit f4c88d5

Browse files
committed
Try exposing hiddens
1 parent 15baa83 commit f4c88d5

4 files changed

Lines changed: 1017 additions & 143 deletions

File tree

elk/rwkv_lm/rwkv_hf.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
2+
import gc
23
import torch
3-
from rwkv.model import RWKV
4+
# from rwkv.model import RWKV
5+
from .rwkv_hiddens import RWKV
46
from huggingface_hub import hf_hub_download
57
from transformers import AutoTokenizer, GPT2TokenizerFast, PreTrainedModel, PretrainedConfig
68
from transformers.modeling_outputs import CausalLMOutput
@@ -12,7 +14,7 @@ class RWKVConfig(PretrainedConfig):
1214
def __init__(self, **kwargs):
1315
super().__init__(**kwargs)
1416
self.hidden_size = 2048
15-
self.num_hidden_layers = 120
17+
self.num_hidden_layers = 25
1618
self.is_encoder_decoder = False
1719
self.architectures = ["RWKV-LM"]
1820

@@ -21,7 +23,7 @@ def __init__(self):
2123
super().__init__(RWKVConfig())
2224
weights_path = "/home/kyle/HF-MODEL/rwkv-4-pile-1b5/models--BlinkDL--rwkv-4-pile-1b5/snapshots/6ea995eaa87a17af560c9b41ce1a3d92355c5a49/RWKV-4-Pile-1B5-20220903-8040.pth"
2325
# weights_path = "/home/kyle/HF-MODEL/rwkv-4-pile-14b/models--BlinkDL--rwkv-4-pile-14b/snapshots/939b6851f96122b7b49bd00d446b3b49481214dd/RWKV-4-Pile-14B-20230213-8019.pth"
24-
self.model = RWKV(model=weights_path, strategy='cuda bf16')
26+
self.model = RWKV(model=weights_path, strategy='cuda fp16')
2527

2628
def forward(
2729
self,
@@ -37,7 +39,7 @@ def forward(
3739
token, states = self.model.forward(inputs, None)
3840
mock_embedding_state = states[0].clone()
3941
output_states = [mock_embedding_state] + states
40-
response = CausalLMOutput(logits=token, hidden_states=output_states)
42+
response = CausalLMOutput(logits=token.detach().clone(), hidden_states=[state.detach() for state in output_states])
4143
return response
4244

4345
# @staticmethod

0 commit comments

Comments
 (0)