-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
154 lines (132 loc) · 5 KB
/
main.py
File metadata and controls
154 lines (132 loc) · 5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch

# Fix PyTorch's global RNG seed at import time so every experiment toggled on
# below runs reproducibly.
torch.manual_seed(32)
def _run_travel_data_filter():
    """Disabled experiment: clean the check-in data, keeping only users with
    enough qualifying records.

    Was previously kept as a module-level triple-quoted string and toggled by
    commenting the opening quotes; as a function the code is syntax-checked on
    import and runs only when called explicitly.
    """
    from data_featuring_scripts import TravelDataFilter

    data_filter = TravelDataFilter(
        need_filter=False,
        dataset_path="datasets/example_sets/travel_behavior/dataset_TIST2015_Checkins_withPOIs_example.parquet",
    )
    data_filter.filter_user_and_checkins()
def _run_language_converter_default():
    """Disabled experiment: convert check-in records into natural-language
    diaries using the converter's default dataset path.

    Converted from a string-literal "commented out" block into a never-called
    function so it is syntax-checked on import.
    """
    from data_featuring_scripts import LanguageConverter

    language_converter = LanguageConverter()
    language_converter.form_diary()
def _run_language_converter_valid():
    """Disabled experiment: convert the *validated* check-in parquet into
    natural-language diaries.

    Converted from a string-literal "commented out" block into a never-called
    function so it is syntax-checked on import.
    """
    from data_featuring_scripts import LanguageConverter

    language_converter = LanguageConverter(
        dataset_path="datasets/example_sets/travel_behavior/dataset_TIST2015_Checkins_withPOIs_example_valid.parquet"
    )
    language_converter.form_diary()
def _run_travel_bcp_dataset_check():
    """Disabled experiment: smoke-test the travel BCP language dataset.

    Builds the dataset, splits it using its precomputed train/val/test index
    lists, wraps each split in a DataLoader, and iterates the test loader once
    to verify every batch can be produced and collated. No assertions are
    made; a clean pass through the loop is the success criterion.

    Converted from a string-literal "commented out" block into a never-called
    function; the unused ``pprint`` import from the original was dropped.
    """
    from torch.utils.data import DataLoader, Subset

    from data_featuring_scripts import TravelBCPDatasetLanguage
    from utils import ProgressBar

    bcp_dataset_language = TravelBCPDatasetLanguage(
        context_window_size=7,
        dataset_path="datasets/example_sets/travel_behavior/dataset_TIST2015_Checkins_withPOIs_example_valid.parquet",
        diary_path="datasets/example_sets/travel_behavior/dataset_TIST2015_Checkins_withPOIs_example_valid_diaries.parquet",
    )

    # Split membership is dictated by the dataset's own index lists.
    train_set = Subset(bcp_dataset_language, bcp_dataset_language.train_indices)
    val_set = Subset(bcp_dataset_language, bcp_dataset_language.val_indices)
    test_set = Subset(bcp_dataset_language, bcp_dataset_language.test_indices)
    print(len(train_set), len(val_set), len(test_set))

    # Only the training loader shuffles; eval loaders keep a stable order.
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=8)
    val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=8)
    test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=8)
    print(len(train_loader), len(val_loader), len(test_loader))

    for batch in ProgressBar(test_loader, description="test travel bcp dataset discrete train loader"):
        # Collation check only: every key should yield a batch-sized value.
        batch_shape = {key: len(value) for key, value in batch.items()}
        # print(batch['context Day'])
        # input()
def _run_yunwu_api_smoke_test():
    """Disabled experiment: smoke-test the Yunwu chat API with a small batch
    of single-turn conversations and print the raw responses.

    Converted from a string-literal "commented out" block into a never-called
    function so it is syntax-checked on import.
    """
    import asyncio

    from apis import YunwuAPI

    yunwu_api = YunwuAPI()
    message_list = [
        [{"role": "user", "content": "Hello, who are you?"}],
        [{"role": "user", "content": "What can you do?"}],
        [{"role": "user", "content": "Tell me a joke."}],
    ]
    result = asyncio.run(yunwu_api.chat_completion_async_batch(message_list))
    print(result)
def _run_toy_baseline():
    """Disabled experiment: run the toy baseline tester.

    Converted from a string-literal "commented out" block into a never-called
    function. The comments below preserve an earlier, abandoned attempt at an
    LLM baseline on the media example dataset (kept for reference).
    """
    # LLM-baseline attempt on the media example dataset (kept for reference):
    # import pandas as pd
    # from apis import YunwuAPI
    # import asyncio
    # import json
    # from data_featuring_scripts import MediaBCPDatasetLanguage, MediaLanguageConverter
    # media_dataset = MediaBCPDatasetLanguage()
    # length = len(media_dataset)
    # system_prompt = (
    #     "you are a human media behavior analyst. Please generate the predicted rating and comment based on the user's current state, the context of the business, and the user's historical behavior."
    #     "comment should be written as if you are the user."
    #     "the output should be a JSON object with keys 'rating' and 'comment'."
    # )
    # message_list = []
    # user_id_list = []
    # for i in range(length):
    #     info_dict = media_dataset[i]
    #     llm_text = (
    #         info_dict.get('context current', '') +
    #         info_dict.get('context business', '') +
    #         info_dict.get('context user', '') +
    #         info_dict.get('context diaries', '')
    #     )
    #     message = [
    #         {"role": "system", "content": system_prompt},
    #         {"role": "user", "content": llm_text}
    #     ]
    #     message_list.append(message)
    #     user_id_list.append(info_dict.get('user_id', f'idx_{i}'))
    # # Instantiate YunwuAPI and issue the batched call.
    # yunwu_api = YunwuAPI()
    # result = asyncio.run(yunwu_api.chat_completion_async_batch(message_list))
    # # Collect results; malformed JSON responses become None entries.
    # ratings = []
    # comments = []
    # for res in result:
    #     try:
    #         data = json.loads(res)
    #         ratings.append(data.get('rating', None))
    #         comments.append(data.get('comment', None))
    #     except Exception as e:
    #         ratings.append(None)
    #         comments.append(None)
    # # Save as a DataFrame.
    # result_df = pd.DataFrame({
    #     'user_id': user_id_list,
    #     'rating': ratings,
    #     'comment': comments
    # })
    # print(result_df.head())
    # result_df.to_csv('yunwuapi_results.csv', index=False)
    from baseline_tester import toy_baseline

    toy_baseline()
def _run_travel_discrete_train_test():
    """Disabled experiment: train the discrete travel-behavior baseline, then
    run its tester.

    Converted from a string-literal "commented out" block into a never-called
    function so it is syntax-checked on import.
    """
    from baseline_tester import TravelBehaviorDiscreteTester, TravelBehaviorDiscreteTrainer

    travel_behavior_discrete_trainer = TravelBehaviorDiscreteTrainer()
    travel_behavior_discrete_tester = TravelBehaviorDiscreteTester()
    travel_behavior_discrete_trainer.train()
    travel_behavior_discrete_tester.test()
def _run_travel_language_test():
    """Disabled experiment: evaluate the language travel-behavior baseline.

    Converted from a string-literal "commented out" block into a never-called
    function; the unused ``TravelBehaviorLanguageTrainer`` import from the
    original block was dropped (only the tester is exercised here).
    """
    from baseline_tester import TravelBehaviorLanguageTester

    travel_behavior_language_tester = TravelBehaviorLanguageTester()
    travel_behavior_language_tester.test()