-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
103 lines (93 loc) · 3.47 KB
/
app.py
File metadata and controls
103 lines (93 loc) · 3.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# app.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LongformerTokenizer, LongformerModel
from torch.utils.data import DataLoader, IterableDataset
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
import argparse
import gradio as gr
import higher
import pandas as pd
from tqdm import tqdm
import diskcache as dc
from datasets import load_dataset
import zlib
import numpy as np
from PIL import Image
from torchvision import transforms
from typing import Any, Dict, List, Optional
import torch.utils.checkpoint as checkpoint_utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# (Include all model class definitions here as in train.py)
# For brevity, it's assumed that the model classes are defined in a separate module, e.g., model.py
# If not, you can copy the class definitions from train.py into this script as well.
# Import model and tokenizer from train.py or define them here
# Here, we'll assume they are defined in train.py and imported accordingly
from train import XlinxChatModel, LiquidFoundationTokenizer, generate_response_api
def generate_response_gradio(user_text):
assistant_reply = generate_response_api(
model,
tokenizer,
user_text,
session_id="gradio_session",
max_new_tokens=100,
temperature=0.7,
top_k=50,
top_p=0.9
)
return assistant_reply
def load_model_and_tokenizer(checkpoint_path: str, device: torch.device):
tokenizer = LiquidFoundationTokenizer(adapt_dim=64).to(device)
model = XlinxChatModel(
token_dim=256,
channel_dim=256,
expert_dim=128,
adapt_dim=64,
num_experts=4,
num_layers=2,
hidden_dim=32,
num_heads=4,
semantic_hidden_dim=128,
semantic_num_heads=4,
semantic_num_layers=1,
dropout_rate=0.1,
max_drop_prob=0.05,
layerdrop_prob=0.05,
dropblock_block_size=7,
dropblock_prob=0.05,
combination_activation='gelu',
combination_norm_type='batchnorm',
norm_type='batchnorm',
dynamic_layer_threshold=0.4
).to(device)
if os.path.exists(checkpoint_path):
model.load_model(checkpoint_path)
print(f"Loaded model checkpoint from '{checkpoint_path}'.")
else:
raise FileNotFoundError(f"No checkpoint found at '{checkpoint_path}'.")
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Serve XlinxChatModel with Gradio")
parser.add_argument('--checkpoint', type=str, default='checkpoint.pth.tar', help="Path to the model checkpoint")
args = parser.parse_args()
# Load model and tokenizer
model, tokenizer = load_model_and_tokenizer(args.checkpoint, device)
# Define Gradio Interface
iface = gr.Interface(
fn=generate_response_gradio,
inputs=[
gr.inputs.Textbox(lines=2, placeholder="Введите ваше сообщение здесь...")
],
outputs="text",
title="XlinxChatModel Chatbot",
description="Чат-бот с возможностями AGI, продвинутым рассуждением и саморегуляцией.",
examples=[
["Привет, как дела?"],
["Расскажи мне историю об искусственном интеллекте."]
],
live=False
)
iface.launch()