Skip to content

Commit 445e1bd

Browse files
add prefix injection
1 parent ca0c25a commit 445e1bd

4 files changed

Lines changed: 18 additions & 3 deletions

File tree

api/_output_handler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def generate_output_stream(
2727
verbose=True,
2828
random_state=None,
2929
as_json=True,
30+
target="",
3031
):
3132
rng = np.random.default_rng(random_state)
3233
TERMINATOR = tokenizer.eos_token
@@ -47,7 +48,7 @@ def generate_output_stream(
4748
]
4849
)
4950

50-
output = ""
51+
output = target
5152
output_ids = torch.tensor([], dtype=torch.long)
5253

5354
else:
@@ -71,7 +72,7 @@ def generate_output_stream(
7172

7273
prompt = [
7374
{"role": "user", "content": init_prompt},
74-
{"role": "assistant", "content": ""},
75+
{"role": "assistant", "content": target},
7576
]
7677
prompt_ids = tokenizer.apply_chat_template(prompt, return_tensors="pt")[0][
7778
:-1

api/api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ async def generate(
8080
random_state: Optional[int] = Form(None),
8181
data: Optional[list] = Form(None),
8282
safenudge: bool = Form(False),
83+
target: Optional[str] = Form(None),
8384
):
8485
"""
8586
Generate an output stream based on a prompt, using a LLM (Llama 3.2 1B Instruct).
@@ -98,6 +99,7 @@ async def generate(
9899
verbose=verbose,
99100
random_state=random_state,
100101
data=data,
102+
target=target or "",
101103
)
102104
return StreamingResponse(data, media_type="application/json")
103105
else:
@@ -112,7 +114,7 @@ async def generate(
112114
).generate_moderated(
113115
prompt=init_prompt,
114116
clf=ml_models["SAFENUDGE_MODEL"],
115-
target="",
117+
target=target or "",
116118
tau=tau,
117119
max_tokens=max_new_tokens,
118120
verbose=verbose,

client/src/App.jsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ const DEFAULT_SETTINGS = {
1313
maxNewTokens: 300,
1414
k: 10,
1515
T: 1,
16+
target: "",
1617
};
1718

1819
function resolveRandomSeed(input) {
@@ -131,6 +132,7 @@ export default function App() {
131132
verbose: false,
132133
random_state: seed,
133134
sleep_time: settings.sleepTime,
135+
target: settings.target || undefined,
134136
};
135137
runStream(streamGenerate, params);
136138
},

client/src/components/SettingsPanel.jsx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,16 @@ export default function SettingsPanel({ settings, onChange }) {
9191
Advanced
9292
</summary>
9393
<div className="p-2 flex flex-col gap-3">
94+
<div className="flex flex-col gap-1">
95+
<span>Assistant prefix</span>
96+
<textarea
97+
rows={2}
98+
placeholder="Optional assistant prefix..."
99+
value={settings.target}
100+
onChange={setField("target")}
101+
className="w-full bg-transparent border border-border px-2 py-1 text-fg resize-none"
102+
/>
103+
</div>
94104
<Row label="Top-k (k)">
95105
<input
96106
type="number"

0 commit comments

Comments
 (0)