-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest.py
More file actions
73 lines (53 loc) · 2.61 KB
/
test.py
File metadata and controls
73 lines (53 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# Locations of the fine-tuned checkpoint and its tokenizer. Both are relative
# paths, so they are resolved against the current working directory.
MODEL_PATH = "./nas_model_250m"  # your final SFT model
TOKENIZER_PATH = "./tokenizer/nas_tokenizer_final"

# Use CUDA when PyTorch reports it is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def load_model(model_path=None, tokenizer_path=None):
    """Load the GPT-2 tokenizer and LM-head model and prepare them for inference.

    Args:
        model_path: Location passed to ``GPT2LMHeadModel.from_pretrained``.
            ``None`` (the default) means use the module-level ``MODEL_PATH``.
        tokenizer_path: Location passed to ``GPT2TokenizerFast.from_pretrained``.
            ``None`` (the default) means use the module-level ``TOKENIZER_PATH``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to the
        module-level ``device`` and switched to eval mode.
    """
    # Resolve defaults at call time (not def time) so that reassigning the
    # module constants after import still takes effect, as before.
    if model_path is None:
        model_path = MODEL_PATH
    if tokenizer_path is None:
        tokenizer_path = TOKENIZER_PATH

    tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_path)
    model = GPT2LMHeadModel.from_pretrained(model_path)

    # The tokenizer may have no pad token; reuse EOS so padding has a valid id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model.to(device)
    # Inference only: eval mode turns off training-time behaviour such as dropout.
    model.eval()
    return tokenizer, model
def generate(
    tokenizer,
    model,
    prompt,
    *,
    max_new_tokens=150,
    temperature=0.2,
    top_p=0.85,
    repetition_penalty=1.1,
    strip_prompt=False,
):
    """Sample a completion for ``prompt`` and return it as decoded text.

    Args:
        tokenizer: Tokenizer used to encode ``prompt`` (as PyTorch tensors) and
            to decode the generated ids.
        model: Model exposing ``generate`` and ``device``, e.g. the one
            returned by ``load_model``.
        prompt: Prompt text to complete.
        max_new_tokens: Maximum number of tokens to generate after the prompt.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.
        strip_prompt: If True, decode only the tokens that follow the prompt.
            If False (the default, kept for backward compatibility), decode the
            whole returned sequence.

    Returns:
        The decoded text with special tokens removed.
    """
    # Put the inputs on the device the model actually lives on, instead of
    # assuming it matches the module-level ``device``.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=0,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Only the first (and only) sequence in the batch is decoded.
    sequence = outputs[0]
    if strip_prompt:
        # For a decoder-only model the returned sequence starts with the prompt
        # ids, so drop that prefix to keep just the completion.
        sequence = sequence[inputs["input_ids"].shape[1]:]
    return tokenizer.decode(sequence, skip_special_tokens=True)
def main():
    """Load the model once, then print a sampled completion for each canned prompt."""
    tokenizer, model = load_model()

    # Every prompt follows the same visible layout:
    # "### Instruction [CATEGORY]:\n<task>\n\n### Response:\n"
    # (presumably the template used during fine-tuning — confirm against training data).
    test_prompts = [
        "### Instruction [SAMBA]:\nCreate a Samba share at /mnt/media for user pi\n\n### Response:\n",
        "### Instruction [SAMBA]:\nList all active Samba connections\n\n### Response:\n",
        "### Instruction [SAMBA]:\nRestart the Samba service completely\n\n### Response:\n",
        "### Instruction [RAID]:\nCheck the status of mdadm RAID arrays\n\n### Response:\n",
        "### Instruction [RAID]:\nExpand RAID1 array by adding disk /dev/sdc\n\n### Response:\n",
        "### Instruction [NFS]:\nMount an NFS share from 192.168.1.10:/exports/data to /mnt/nfs\n\n### Response:\n",
        "### Instruction [NFS]:\nExport /srv/nas to the local subnet 192.168.1.0/24 via NFS\n\n### Response:\n",
        "### Instruction [PERMISSION]:\nChange ownership of /srv/nas/photos to user pi and group users\n\n### Response:\n",
        "### Instruction [PERMISSION]:\nSet permissions of /srv/nas/public to 777\n\n### Response:\n",
        "### Instruction [SERVICE]:\nCheck if the ssh service is active\n\n### Response:\n",
        "### Instruction [SERVICE]:\nRestart the ssh service\n\n### Response:\n",
        "### Instruction [SERVICE]:\nReboot the server immediately\n\n### Response:\n"
    ]

    banner = "=" * 80
    for text in test_prompts:
        # Header rule, label, then the raw prompt, each on its own line.
        print(banner, "PROMPT:\n", text, sep="\n")
        print("\nMODEL OUTPUT:\n")
        print(generate(tokenizer, model, text))
        # Closing rule followed by one blank line between entries.
        print(banner, end="\n\n")


if __name__ == "__main__":
    main()