-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathoffline_structuredOutput.py
More file actions
46 lines (31 loc) · 1.38 KB
/
offline_structuredOutput.py
File metadata and controls
46 lines (31 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from vllm import LLM, SamplingParams
from pydantic import BaseModel
from enum import Enum
from vllm.sampling_params import GuidedDecodingParams
# Hugging Face identifier of the model to run; replace it with the model
# you want to use.
model_name = "google/gemma-2-9b-it"

# TODO: point this at your own local model directory on the cluster.
cache_dir = "/netscratch/thomas/models/"

# Instruction sent to the model.
prompt = "Generate a JSON with the brand, model and car_type of the most iconic car from the 90's"
# Guided decoding by JSON using Pydantic schema
class CarType(str, Enum):
    """Closed set of body styles accepted for the ``car_type`` field.

    The ``str`` mix-in makes every member a plain string, so the member
    values (not the member names) are what ends up in the JSON schema and
    in the generated output.
    """

    # NOTE(review): value casing is not uniform ("sedan" is lower-case while
    # the others are capitalised); kept exactly as-is because these strings
    # define the schema the model is constrained to.
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"
class CarDescription(BaseModel):
    """Shape of the JSON object the model is asked to generate.

    The JSON schema derived from this class (``model_json_schema()``) is
    what gets handed to vLLM's guided decoding further down in the script.
    """

    # Manufacturer name, free-form text.
    brand: str
    # Model designation of the car, free-form text.
    model: str
    # Body style; restricted to the values declared in CarType.
    car_type: CarType
# Derive the JSON schema from the Pydantic model and wrap it so vLLM can
# constrain generation to that schema.
json_schema = CarDescription.model_json_schema()
guided_params = GuidedDecodingParams(json=json_schema)

# Sampling configuration WITH guided decoding attached.
# The script previously rebuilt `sampling_params` right before generate()
# without `guided_decoding`, which silently discarded the schema constraint
# and produced unconstrained text. There is now a single definition.
sampling_params = SamplingParams(
    temperature=0.2,
    max_tokens=3000,
    guided_decoding=guided_params,
)

# Load the model. `download_dir` makes vLLM download/load the weights from
# `cache_dir` instead of the default Hugging Face cache; `cache_dir` was
# defined at the top of the script but never passed on before.
llm = LLM(model=model_name, download_dir=cache_dir)

outputs = llm.generate(prompt, sampling_params=sampling_params)

# One RequestOutput per prompt; take the first (only) completion of each.
for output in outputs:
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}")
    print(f"Generated JSON: {generated_text!r}")
    print("---------------------------------------------------")