-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
executable file
·67 lines (46 loc) · 1.63 KB
/
Copy pathapp.py
File metadata and controls
executable file
·67 lines (46 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import streamlit as st
import logging
from src.cfg import MODEL_NAME
from src.model import MyModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import matplotlib.pyplot as plt
import numpy as np
# Human-readable labels for the classifier outputs: logit index i is reported
# as classes[i]. NOTE(review): assumes this order matches the label encoding
# used at training time — TODO confirm against the training code.
classes = [
    "C",
    "Java",
    "Javascript",
    "Python",
    "R",
]
# `st.cache` is deprecated (since Streamlit 1.18) in favour of
# `st.cache_resource`, the decorator intended for shared, unhashable objects
# such as models. Prefer it when available and fall back to the legacy
# decorator on older Streamlit releases. The `or` short-circuits, so `st.cache`
# is never touched on versions that provide `cache_resource`.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model(path):
    """Build the classifier from a saved state dict and load its tokenizer.

    The result is cached by Streamlit, so script reruns reuse the same objects
    instead of reloading the checkpoint from disk.

    Args:
        path: Filesystem path to a checkpoint produced by ``torch.save`` that
            contains a state dict for ``MyModel``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is already in eval mode, and
        the tokenizer is the ``AutoTokenizer`` for ``MODEL_NAME``.
    """
    logging.info("Loading model")
    # torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(path, map_location="cpu")
    config = {"model_name": MODEL_NAME}
    # NOTE(review): pretrained=False presumably skips fetching backbone weights,
    # since the checkpoint supplies them — TODO confirm in src.model.
    model = MyModel(config, pretrained=False)
    model.load_state_dict(checkpoint)
    # The cached instance is shared by every rerun, so put it in inference mode
    # once, here, rather than relying on each caller to do so.
    model.eval()
    return model, AutoTokenizer.from_pretrained(MODEL_NAME)
# Fetch the classifier and its tokenizer. load_model is wrapped in a Streamlit
# cache decorator, so reruns of this script reuse the cached pair rather than
# re-reading the checkpoint.
model, tokenizer = load_model(path="outputs/model_1.pt")
# Serve predictions with the network in evaluation (inference) mode.
model.eval()
def predict(s, max_length=256):
    """Run the classifier on a single free-text project description.

    Args:
        s: Raw description text. Runs of whitespace are collapsed to single
            spaces before tokenization.
        max_length: Token length the input is padded and truncated to. The
            default of 256 is the value this app has always used.

    Returns:
        Whatever ``model(...)`` returns for a batch of one. The UI treats it as
        a tensor of class logits (presumably shape ``(1, num_classes)`` — TODO
        confirm). The forward pass runs under ``torch.no_grad()``, so the result
        carries no autograd graph and can be converted with ``.numpy()``.
    """
    words = s.split()
    feed = [" ".join(words)]
    # NOTE(review): this is a whitespace word count, not a subword-token count.
    # Confirm it matches how `lengths` was computed when the model was trained.
    length = [len(words)]
    # truncation=True makes the handling of over-long inputs explicit. Without
    # it, transformers warns and leaves truncation to a library default.
    data = tokenizer(
        feed,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    logging.debug("Tokenized input: %s", data)
    logging.debug("Input lengths: %s", length)
    # Inference only: skip autograd bookkeeping. This saves memory and lets
    # callers call .numpy() on the output directly.
    with torch.no_grad():
        return model(
            input_ids=data["input_ids"],
            attention_mask=data["attention_mask"],
            lengths=torch.tensor(length),
        )
"""
# What language should your project be written in?
"""
form = st.form(key="input")
text = form.text_area("Input project description (English ASCII descriptions only)")
# x = st.text_area("Input project description (English ASCII descriptions only)")
x = form.form_submit_button("Submit")
if x:
logits = predict(text).numpy()[0]
choice = np.argmax(logits)
probs = np.exp(logits) / np.exp(logits).sum()
fig, ax = plt.subplots()
ax.barh(classes, probs)
plt.title("Class probability percentages")
st.write(f"Your project should be written in {classes[choice]}.")
st.pyplot(fig)