Skip to content

Commit 0f6808c

Browse files
committed
added specifications of experiments
1 parent c8c470a commit 0f6808c

68 files changed

Lines changed: 6442 additions & 0 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

New folder/cls.py

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
from dataclasses import dataclass
2+
from doctest import ELLIPSIS_MARKER
3+
from itertools import chain
4+
from typing import List, TypeVar, cast
5+
6+
import numpy as np
7+
import pandas as pd
8+
9+
from src.data_new.types import Background, JSONSerializable, PersonDocument, EncodedDocument
10+
from src.tasks.base import Task
11+
12+
T = TypeVar("T")
13+
14+
HEXACO_COLS = ['HEXACO_consc3',
15+
'HEXACO_agree7',
16+
'HEXACO_open8R',
17+
'HEXACO_agree8',
18+
'HEXACO_extra9R',
19+
'HEXACO_consc9R',
20+
'HEXACO_agree4',
21+
'HEXACO_emo7R',
22+
'HEXACO_extra8',
23+
'HEXACO_agree1R',
24+
'HEXACO_open6',
25+
'HEXACO_extra5',
26+
'HEXACO_open10R',
27+
'HEXACO_open3',
28+
'HEXACO_open5R',
29+
'HEXACO_agree9R',
30+
'HEXACO_emo8',
31+
'HEXACO_hh4',
32+
'HEXACO_agree5',
33+
'HEXACO_agree6R',
34+
'HEXACO_open9',
35+
'HEXACO_extra10R',
36+
'HEXACO_consc8R',
37+
'HEXACO_agree10',
38+
'HEXACO_consc5R',
39+
'HEXACO_extra7',
40+
'HEXACO_emo3',
41+
'HEXACO_hh3R',
42+
'HEXACO_consc2R',
43+
'HEXACO_consc10R',
44+
'HEXACO_extra2R',
45+
'HEXACO_emo1',
46+
'HEXACO_extra6',
47+
'HEXACO_agree2',
48+
'HEXACO_emo5',
49+
'HEXACO_extra4',
50+
'HEXACO_hh8',
51+
'HEXACO_consc7',
52+
'HEXACO_hh7R',
53+
'HEXACO_consc6',
54+
'HEXACO_emo10R',
55+
'HEXACO_hh2',
56+
'HEXACO_hh5R',
57+
'HEXACO_consc4R',
58+
'HEXACO_hh1R',
59+
'HEXACO_emo6',
60+
'HEXACO_hh9R',
61+
'HEXACO_emo2R',
62+
'HEXACO_extra3R',
63+
'HEXACO_open7R',
64+
'HEXACO_emo9',
65+
'HEXACO_hh6R',
66+
'HEXACO_open4',
67+
'HEXACO_agree3R',
68+
'HEXACO_hh10',
69+
'HEXACO_emo4R',
70+
'HEXACO_open2',
71+
'HEXACO_open1R',
72+
'HEXACO_extra1',
73+
'HEXACO_consc1']
74+
75+
@dataclass
76+
class CLS(Task):
77+
"""
78+
Pulls data from somewhere and uses it for classification?
79+
80+
.. todo::
81+
Describe CLS
82+
"""
83+
# CLS Specific params
84+
pooled: bool = False
85+
num_pooled_sep: int = 0
86+
87+
88+
def __post_init__(self) -> None:
89+
import warnings
90+
if self.pooled:
91+
raise NotImplementedError("Pooled version is not implemented")
92+
93+
# CLS Specific params
94+
def get_document(self, person_sentences: pd.DataFrame) -> PersonDocument:
95+
document = super().get_document(person_sentences)
96+
target = int(person_sentences.TARGET.iloc[0])
97+
document.task_info = cast(JSONSerializable, target) # makes mypy happy
98+
99+
return document
100+
101+
def encode_document(self, document: PersonDocument) -> "CLSEncodedDocument":
102+
103+
prefix_sentence = (
104+
["[CLS]"] + Background.get_sentence(document.background) + ["[SEP]"]
105+
)
106+
sentences = [prefix_sentence] + [s + ["[SEP]"] for s in document.sentences]
107+
sentence_lengths = [len(x) for x in sentences]
108+
109+
def expand(x: List[T]) -> List[T]:
110+
assert len(x) == len(sentence_lengths)
111+
return list(
112+
chain.from_iterable(
113+
length * [i] for length, i in zip(sentence_lengths, x)
114+
)
115+
)
116+
117+
abspos_expanded = expand([0] + document.abspos)
118+
age_expanded = expand([0.0] + document.age) # todo abs_age vs age?
119+
assert document.segment is not None
120+
segment_expanded = expand([1] + document.segment)
121+
122+
token2index = self.datamodule.vocabulary.token2index
123+
unk_id = token2index["[UNK]"]
124+
125+
flat_sentences = np.concatenate(sentences)
126+
token_ids = np.array([token2index.get(x, unk_id) for x in flat_sentences])
127+
128+
length = len(token_ids)
129+
130+
input_ids = np.zeros((4, self.max_length))
131+
input_ids[0, :length] = token_ids
132+
input_ids[1, :length] = abspos_expanded
133+
input_ids[2, :length] = age_expanded
134+
input_ids[3, :length] = segment_expanded
135+
136+
padding_mask = np.repeat(False, self.max_length)
137+
padding_mask[:length] = True
138+
139+
original_sequence = np.zeros(self.max_length)
140+
original_sequence[:length] = token_ids
141+
142+
target = np.array(document.task_info).astype(np.float32)
143+
144+
sequence_id = np.array(document.person_id)
145+
146+
if self.pooled:
147+
sep_pos = self.extract_sep_positions(token_ids)
148+
else:
149+
sep_pos = np.array([0])
150+
151+
152+
return CLSEncodedDocument(
153+
sequence_id=sequence_id,
154+
input_ids=input_ids,
155+
padding_mask=padding_mask,
156+
target=target,
157+
sep_pos=sep_pos,
158+
original_sequence=original_sequence,
159+
)
160+
161+
def extract_sep_positions(self, token_ids: np.ndarray) -> np.ndarray:
162+
163+
token2index = self.datamodule.vocabulary.token2index
164+
sep_id = token2index["[SEP]"]
165+
166+
MAX_LEN = self.num_pooled_sep
167+
_sep_pos = np.where(token_ids == sep_id)[0]
168+
sep_pos = np.zeros(MAX_LEN)
169+
170+
if len(_sep_pos) >= MAX_LEN:
171+
offset = len(_sep_pos) - MAX_LEN
172+
_sep_pos = _sep_pos[offset:]
173+
174+
sep_pos[: len(_sep_pos)] = _sep_pos
175+
return sep_pos
176+
177+
178+
@dataclass
179+
class CLS_HAN(CLS):
180+
181+
sentence_num: int = 312
182+
word_num: int = 9
183+
def encode_document(self, document: PersonDocument) -> "HANEncodedDocument":
184+
prefix_sentence = (
185+
["[CLS]"] + Background.get_sentence(document.background)
186+
)
187+
sentences = [prefix_sentence] + [s for s in document.sentences]
188+
189+
abspos = [0.0] + document.abspos[-self.sentence_num+1:]
190+
age = [0.0] + document.age[-self.sentence_num+1:]
191+
position = np.zeros((2, self.sentence_num))
192+
position[0, 0:len(abspos)] = abspos
193+
position[1, 0:len(age)] = age
194+
195+
196+
token2index = self.datamodule.vocabulary.token2index
197+
unk_id = token2index["[UNK]"]
198+
199+
token_ids = np.zeros((self.sentence_num, self.word_num))
200+
201+
for i,j in enumerate(sentences[-self.sentence_num:]):
202+
token_ids[i][0:len(j)] = list(
203+
map(
204+
lambda x: token2index.get(x, unk_id),
205+
j,
206+
),
207+
)
208+
209+
target = np.array(document.task_info).astype(int)
210+
sequence_id = np.array(document.person_id)
211+
212+
213+
return HANEncodedDocument(
214+
sequence_id=sequence_id,
215+
input_ids=token_ids,
216+
padding_mask=token_ids.astype(bool),
217+
target=target,
218+
position=position,
219+
original_sequence=token_ids,
220+
)
221+
222+
223+
@dataclass
224+
class PSY(CLS):
225+
# TASK
226+
def get_document(self, person_sentences: pd.DataFrame) -> PersonDocument:
227+
document = super(CLS, self).get_document(person_sentences)
228+
usecols = ['HH','EM','EX','AG','CO','OP', "SDO", "SVOa", "RISK", "CRTi", "CRTr"]
229+
usecols += [c + "_w" for c in usecols[:-1]]
230+
target = []
231+
for col in usecols:
232+
target.append(float(person_sentences[col].iloc[0]))
233+
document.task_info = cast(JSONSerializable, target) # makes mypy happy
234+
return document
235+
236+
@dataclass
237+
class HEXACO(CLS):
238+
# TASK
239+
def get_document(self, person_sentences: pd.DataFrame) -> PersonDocument:
240+
document = super(CLS, self).get_document(person_sentences)
241+
usecols = ['HH','EM','EX','AG','CO','OP', "SDO", "SVOa", "RISK", "CRTi", "CRTr"] + HEXACO_COLS
242+
target = []
243+
for col in usecols:
244+
target.append(float(person_sentences[col].iloc[0]))
245+
document.task_info = cast(JSONSerializable, target) # makes mypy happy
246+
return document
247+
248+
249+
250+
@dataclass
251+
class PSY_HAN(PSY):
252+
253+
sentence_num: int = 312
254+
word_num: int = 9
255+
def encode_document(self, document: PersonDocument) -> "HANEncodedDocument":
256+
prefix_sentence = (
257+
["[CLS]"] + Background.get_sentence(document.background)
258+
)
259+
sentences = [prefix_sentence] + [s for s in document.sentences]
260+
261+
abspos = [0.0] + document.abspos[-self.sentence_num+1:]
262+
age = [0.0] + document.age[-self.sentence_num+1:]
263+
position = np.zeros((2, self.sentence_num))
264+
position[0, 0:len(abspos)] = abspos
265+
position[1, 0:len(age)] = age
266+
267+
268+
token2index = self.datamodule.vocabulary.token2index
269+
unk_id = token2index["[UNK]"]
270+
271+
token_ids = np.zeros((self.sentence_num, self.word_num))
272+
273+
for i,j in enumerate(sentences[-self.sentence_num:]):
274+
token_ids[i][0:len(j)] = list(
275+
map(
276+
lambda x: token2index.get(x, unk_id),
277+
j,
278+
),
279+
)
280+
281+
target = np.array(document.task_info).astype(float)
282+
sequence_id = np.array(document.person_id)
283+
284+
285+
return HANEncodedDocument(
286+
sequence_id=sequence_id,
287+
input_ids=token_ids,
288+
padding_mask=token_ids.astype(bool),
289+
target=target,
290+
position=position,
291+
original_sequence=token_ids,
292+
)
293+
294+
295+
296+
297+
@dataclass
298+
class CLSEncodedDocument(EncodedDocument[CLS]):
299+
sequence_id: np.ndarray
300+
input_ids: np.ndarray
301+
padding_mask: np.ndarray
302+
target: np.ndarray
303+
sep_pos: np.ndarray
304+
original_sequence: np.ndarray
305+
306+
@dataclass
307+
class HANEncodedDocument(EncodedDocument[CLS_HAN]):
308+
sequence_id: np.ndarray
309+
input_ids: np.ndarray
310+
position: np.ndarray
311+
padding_mask: np.ndarray
312+
target: np.ndarray
313+
original_sequence: np.ndarray
314+
315+
316+

0 commit comments

Comments
 (0)