Skip to content

Commit a255b82

Browse files
committed
add splitter fixed
1 parent a7ce570 commit a255b82

2 files changed

Lines changed: 67 additions & 4 deletions

File tree

wtrec/splitter.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from random import sample
2+
import numpy as np
3+
import os
4+
5+
6+
# Name of the target column; the split functions in this module separate it
# from the remaining columns to form the Y part of each split.
Y_COLUMN = 'label'
7+
# Columns removed from the data before it is split into X and Y.
# NOTE(review): the identifier looks like a misspelling of "DROP_COLUMNS";
# it is left as-is because other code refers to it by this name.
DROP_COLUMS = ['raw', 'path']
8+
9+
10+
def _shuffle(data):
11+
return data.reindex(np.random.permutation(data.index))
12+
13+
14+
def split_random(data, test_size=0.3):
    """Randomly split *data* into train and test features and labels.

    The rows are shuffled, the columns listed in ``DROP_COLUMS`` are
    removed, and the first ``int(len(data) * test_size)`` shuffled rows
    become the test set; the remaining rows become the train set.

    Args:
        data: DataFrame containing the ``Y_COLUMN`` column and the columns
            listed in ``DROP_COLUMS``.
        test_size: Fraction of the rows to put into the test set.
    Returns:
        Tuple ``(train_X, train_Y, test_X, test_Y)`` where the ``*_Y``
        entries hold the ``Y_COLUMN`` column and the ``*_X`` entries hold
        all remaining columns.
    """
    num_test = int(data.shape[0] * test_size)

    # Rows keep their original index labels after shuffling.
    shuffled = _shuffle(data).drop(DROP_COLUMS, axis=1)

    test = shuffled[:num_test]
    train = shuffled[num_test:]

    test_Y = test[Y_COLUMN]
    test_X = test.drop(Y_COLUMN, axis=1)

    train_Y = train[Y_COLUMN]
    train_X = train.drop(Y_COLUMN, axis=1)

    return train_X, train_Y, test_X, test_Y
31+
32+
33+
def _train_test_indices(num_samples, idx_file_path, test_size=0.3):
34+
if not os.path.isfile(idx_file_path):
35+
test_indices = np.array(sample(range(num_samples), k=int(num_samples * test_size)))
36+
test_indices.dump(idx_file_path)
37+
38+
test_indices = np.load(idx_file_path)
39+
train_indices = np.array([i for i in range(num_samples) if i not in test_indices])
40+
41+
return train_indices, test_indices
42+
43+
44+
def split_fixed(data, idx_file_path, test_size=0.3):
    """Split *data* into train and test parts using persisted indices.

    The columns listed in ``DROP_COLUMS`` are removed, then rows are
    selected by position using the indices returned by
    ``_train_test_indices``. Because the test indices are stored in
    *idx_file_path*, repeated calls with the same file and the same number
    of rows produce the same split.

    Args:
        data: DataFrame containing the ``Y_COLUMN`` column and the columns
            listed in ``DROP_COLUMS``.
        idx_file_path: Path of the file holding (or receiving) the test
            indices.
        test_size: Fraction of rows used for the test set; forwarded to
            ``_train_test_indices``, which only uses it when the index
            file does not exist yet.
    Returns:
        Tuple ``(train_X, train_Y, test_X, test_Y)`` where the ``*_Y``
        entries hold the ``Y_COLUMN`` column and the ``*_X`` entries hold
        all remaining columns.
    """
    features = data.drop(DROP_COLUMS, axis=1)

    train_indices, test_indices = _train_test_indices(
        features.shape[0], idx_file_path, test_size=test_size
    )

    # iloc selects by position, matching how the indices were generated.
    test = features.iloc[test_indices]
    train = features.iloc[train_indices]

    test_Y = test[Y_COLUMN]
    test_X = test.drop(Y_COLUMN, axis=1)

    train_Y = train[Y_COLUMN]
    train_X = train.drop(Y_COLUMN, axis=1)

    return train_X, train_Y, test_X, test_Y

wtrec/transformer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def _add_global_layout_features(self):
6565

6666
def _add_layout_features(self):
6767
for i in self.cols:
68-
total_rowspan = np.sum([int(bs(x, 'html.parser').find_all(['td', 'th'])[0].attrs.get('rowspan', 0)) for x in i[0]])
68+
total_rowspan = np.sum(
69+
[int(bs(x, 'html.parser').find_all(['td', 'th'])[0].attrs.get('rowspan', 0)) for x in i[0]]
70+
)
6971
num_rowspan = len([1 for x in i[0] if 'rowspan' in bs(x, 'html.parser').find_all(['td', 'th'])[0].attrs])
7072
features = DataFrame({
7173
'avg_length': [np.mean([len(str(elem)) for elem in i[1]])],
@@ -77,7 +79,9 @@ def _add_layout_features(self):
7779
self.obj = concat([self.obj, features])
7880

7981
for i in self.rows:
80-
total_colspan = np.sum([int(bs(x, 'html.parser').find_all(['td', 'th'])[0].attrs.get('colspan', 0)) for x in i[0]])
82+
total_colspan = np.sum(
83+
[int(bs(x, 'html.parser').find_all(['td', 'th'])[0].attrs.get('colspan', 0)) for x in i[0]]
84+
)
8185
num_colspan = len([1 for x in i[0] if 'colspan' in bs(x, 'html.parser').find_all(['td', 'th'])[0].attrs])
8286
features = DataFrame({
8387
'avg_length': [np.mean([len(str(elem)) for elem in i[1]])],
@@ -147,7 +151,7 @@ def transform_for_baseline(raw_dataframe):
147151
try:
148152
with_features = with_features.append(_BaselineSample(row).transform(), ignore_index=True)
149153
except IndexError: # FIXME: only tables with min shape 2x2 in dataset!
150-
continue
154+
print(row['path'])
151155

152156
return with_features
153157

@@ -176,6 +180,7 @@ def transform_for_approach(raw_dataframe):
176180
Args:
177181
Dataframe with columns raw and label
178182
Returns:
179-
Dataframe with columns raw, label and feature space
183+
Dataframe with columns raw, label and imagepath
184+
Generates image representations of web table
180185
"""
181186
pass

0 commit comments

Comments
 (0)