Skip to content

Commit 1b29c1f

Browse files
committed
new traindata append to file. preparation for streamer
1 parent d847c7d commit 1b29c1f

4 files changed

Lines changed: 133 additions & 10 deletions

File tree

compiled/interface/trainData.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ class typeContainer{
3535
void push_back(simpleArrayBase& a);
3636
void move_back(simpleArrayBase& a);
3737

38+
bool operator==(const typeContainer& rhs)const;
39+
bool operator!=(const typeContainer& rhs)const{
40+
return !(*this==rhs);
41+
}
42+
43+
3844
simpleArrayBase& at(size_t idx);
3945
const simpleArrayBase& at(size_t idx)const;
4046

@@ -83,6 +89,11 @@ class trainData{
8389
public:
8490

8591

92+
93+
bool operator==(const trainData& rhs)const;
94+
bool operator!=(const trainData& rhs)const{
95+
return !(*this==rhs);
96+
}
8697
//takes ownership
8798
//these need to be separated by input type because python does not allow for overload
8899
//but then the py interface can be made generic to accept differnt types
@@ -204,6 +215,9 @@ class trainData{
204215
const std::vector<std::vector<int> > & weightShapes()const{return weight_shapes_;}
205216

206217
void writeToFile(std::string filename)const;
218+
void addToFile(std::string filename)const;
219+
220+
void addToFileP(FILE *& f)const;
207221

208222
void readFromFile(std::string filename){
209223
priv_readFromFile(filename,false);
@@ -276,6 +290,9 @@ class trainData{
276290

277291
void priv_readFromFile(std::string filename, bool memcp);
278292

293+
trainData priv_readFromFileP(FILE *& f, const std::string& filename)const;
294+
void priv_readSelfFromFileP(FILE *& f, const std::string& filename);
295+
279296
void checkFile(FILE *& f, const std::string& filename="")const;
280297

281298

compiled/src/c_trainData.C

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ using namespace djc;
1616
BOOST_PYTHON_MODULE(c_trainData) {
1717
Py_Initialize();
1818
np::initialize();
19+
using namespace p;
1920
p::class_<trainData >("trainData")
2021

22+
.def(self==self)
23+
.def(self!=self)
2124

2225
//excplicit overloading
2326
.def<int (trainData::*)(simpleArray_float32&)>("storeFeatureArray", &trainData::storeFeatureArray)
@@ -29,6 +32,7 @@ BOOST_PYTHON_MODULE(c_trainData) {
2932
.def<int (trainData::*)(simpleArray_float32&)>("storeWeightArray", &trainData::storeWeightArray)
3033
.def<int (trainData::*)(simpleArray_int32&)>("storeWeightArray", &trainData::storeWeightArray)
3134

35+
3236
// .def("featureList", &trainData::featureList)
3337
// .def("truthList", &trainData::truthList)
3438
// .def("weightList", &trainData::weightList)
@@ -46,6 +50,7 @@ BOOST_PYTHON_MODULE(c_trainData) {
4650
.def("readFromFile", &trainData::readFromFile)
4751
.def("readFromFileBuffered", &trainData::readFromFileBuffered)
4852
.def("writeToFile", &trainData::writeToFile)
53+
.def("addToFile", &trainData::addToFile)
4954

5055

5156
.def("copy", &trainData::copy)

compiled/src/trainData.cpp

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,25 @@ void typeContainer::move_back(simpleArrayBase& a){
3232
sorting_.push_back({isint,iarrs_.size()-1});
3333
}
3434
}
35+
bool typeContainer::operator==(const typeContainer& rhs)const{
36+
if(size() != rhs.size())
37+
return false;
38+
if(farrs_.size() != rhs.farrs_.size())
39+
return false;
40+
41+
if(sorting_ != rhs.sorting_)
42+
return false;
43+
44+
for(size_t i=0;i<farrs_.size();i++){
45+
if(farrs_.at(i) != rhs.farrs_.at(i))
46+
return false;
47+
}
48+
for(size_t i=0;i<iarrs_.size();i++){
49+
if(iarrs_.at(i) != rhs.iarrs_.at(i))
50+
return false;
51+
}
52+
return true;
53+
}
3554
simpleArrayBase& typeContainer::at(size_t idx){
3655
if(idx>=sorting_.size())
3756
throw std::out_of_range("typeContainer::at: requested "+std::to_string(idx)+" of "+std::to_string(sorting_.size()));
@@ -119,6 +138,22 @@ void typeContainer::readFromFile_priv(FILE *& ifile, bool justmetadata){
119138

120139
////////////////// trainData //////////////////////
121140

141+
bool trainData::operator==(const trainData& rhs)const{
142+
143+
if(feature_arrays_ != rhs.feature_arrays_)
144+
return false;
145+
if(truth_arrays_ != rhs.truth_arrays_)
146+
return false;
147+
if(weight_arrays_ != rhs.weight_arrays_)
148+
return false;
149+
if(feature_shapes_ != rhs.feature_shapes_)
150+
return false;
151+
if(truth_shapes_ != rhs.truth_shapes_)
152+
return false;
153+
if(weight_shapes_ != rhs. weight_shapes_)
154+
return false;
155+
return true;
156+
}
122157

123158

124159
int trainData::storeFeatureArray(simpleArrayBase & a){
@@ -267,6 +302,19 @@ bool trainData::validSlice(size_t splitindex_begin, size_t splitindex_end)const{
267302
void trainData::writeToFile(std::string filename)const{
268303

269304
FILE *ofile = fopen(filename.data(), "wb");
305+
addToFileP(ofile);
306+
fclose(ofile);
307+
308+
}
309+
310+
void trainData::addToFile(std::string filename)const{
311+
312+
FILE *ofile = fopen(filename.data(), "ab");
313+
addToFileP(ofile);
314+
fclose(ofile);
315+
}
316+
317+
void trainData::addToFileP(FILE *& ofile)const{
270318
float version = DJCDATAVERSION;
271319
io::writeToFile(&version, ofile);
272320

@@ -279,15 +327,13 @@ void trainData::writeToFile(std::string filename)const{
279327
feature_arrays_.writeToFile(ofile);
280328
truth_arrays_.writeToFile(ofile);
281329
weight_arrays_.writeToFile(ofile);
282-
fclose(ofile);
283-
284330
}
285331

286332
void trainData::priv_readFromFile(std::string filename, bool memcp){
287333
clear();
288334
FILE *ifile = fopen(filename.data(), "rb");
289335
char *buf = 0;
290-
if(memcp){
336+
if(false && memcp){
291337
FILE *diskfile = ifile;
292338
//check if exists before trying to memcp.
293339
checkFile(ifile, filename); //not set at start but won't be used
@@ -307,6 +353,37 @@ void trainData::priv_readFromFile(std::string filename, bool memcp){
307353
ifile = fmemopen(buf,fsize,"r");
308354
}
309355

356+
priv_readSelfFromFileP(ifile,filename);
357+
//check for eof and add until done. the append step can be heavily optimized! FIXME
358+
//read one more byte
359+
int ch = getc(ifile);
360+
while(! feof(ifile)){
361+
fseek(ifile,-1,SEEK_CUR);
362+
append(priv_readFromFileP(ifile,filename));
363+
ch = getc(ifile);
364+
}
365+
366+
fclose(ifile);
367+
if(buf){
368+
delete buf;
369+
}
370+
}
371+
372+
trainData trainData::priv_readFromFileP(FILE *& ifile, const std::string& filename)const{
373+
//include file version check
374+
trainData out;
375+
out.checkFile(ifile, filename);
376+
out.readNested(out.feature_shapes_, ifile);
377+
out.readNested(out.truth_shapes_, ifile);
378+
out.readNested(out.weight_shapes_, ifile);
379+
380+
out.feature_arrays_ .readFromFile(ifile);
381+
out.truth_arrays_.readFromFile(ifile);
382+
out.weight_arrays_.readFromFile(ifile);
383+
return out;
384+
}
385+
386+
void trainData::priv_readSelfFromFileP(FILE *& ifile, const std::string& filename){
310387
checkFile(ifile, filename);
311388
readNested(feature_shapes_, ifile);
312389
readNested(truth_shapes_, ifile);
@@ -315,12 +392,6 @@ void trainData::priv_readFromFile(std::string filename, bool memcp){
315392
feature_arrays_ .readFromFile(ifile);
316393
truth_arrays_.readFromFile(ifile);
317394
weight_arrays_.readFromFile(ifile);
318-
319-
fclose(ifile);
320-
if(buf){
321-
delete buf;
322-
}
323-
324395
}
325396

326397
void trainData::readMetaDataFromFile(const std::string& filename){

testing/unit/TestTrainData.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,37 @@ def test_store(self):
5353

5454
def test_readWrite(self):
5555
print('TestTrainData: readWrite')
56-
self.sub_test_store(True)
56+
self.sub_test_store(True)
57+
58+
def nestedEqual(self,l,l2):
59+
for a,b in zip(l,l2):
60+
if not np.all(a==b):
61+
return False
62+
return True
63+
64+
def test_AddToFile(self):
65+
print('TestTrainData: AddToFile')
66+
67+
td = TrainData()
68+
x,y,w = self.createSimpleArray('int32'), self.createSimpleArray('float32'), self.createSimpleArray('int32')
69+
xo,yo,wo = x.copy(),y.copy(),w.copy()
70+
x2,y2,_ = self.createSimpleArray('float32'), self.createSimpleArray('float32'), self.createSimpleArray('int32')
71+
x2o,y2o = x2.copy(),y2.copy()
72+
td._store([x,x2], [y,y2], [w])
73+
74+
td.writeToFile("testfile.tdjctd")
75+
td.addToFile("testfile.tdjctd")
76+
77+
78+
td2 = TrainData()
79+
td2._store([xo,x2o], [yo,y2o], [wo])
80+
td2.append(td)
81+
82+
td.readFromFile("testfile.tdjctd")
83+
os.system('rm -f testfile.tdjctd')
84+
85+
86+
self.assertEqual(td,td2)
5787

5888
def test_split(self):
5989
print('TestTrainData: split')

0 commit comments

Comments
 (0)