Skip to content

Commit fde2e02

Browse files
committed
Add go-tfdata integration (tests)
1 parent b6b3cd0 commit fde2e02

4 files changed

Lines changed: 149 additions & 10 deletions

File tree

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ bin/tarp-full: $(cmds) $(datapipes)
1515
bin/tarp -h
1616

1717
test:
18-
cd datapipes && go test -v
18+
cd dpipes && go test -v
19+
20+
test-tfdata:
21+
cd dpipes && go test -v --tags=gitlabnvidia
1922

2023
dtest:
2124
cd datapipes && debug=stdout go test -v | tee ../test.log

dpipes/go.mod

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module github.com/tmbdev/tarp/dpipes
33
go 1.14
44

55
require (
6+
github.com/NVIDIA/go-tfdata v0.3.1
67
github.com/shamaton/msgpack v1.1.1
7-
github.com/stretchr/testify v1.2.2
8+
github.com/stretchr/testify v1.3.0
89
)

dpipes/gotfdata_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package dpipes
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"io/ioutil"
8+
"os"
9+
"testing"
10+
11+
"github.com/NVIDIA/go-tfdata/tfdata/core"
12+
"github.com/NVIDIA/go-tfdata/tfdata/transform"
13+
"github.com/stretchr/testify/assert"
14+
)
15+
16+
type (
17+
SamplesReader struct {
18+
pipe Pipe
19+
}
20+
)
21+
22+
func (r *SamplesReader) Read() (sample *core.Sample, err error) {
23+
s, ok := <-r.pipe
24+
if !ok {
25+
return nil, io.EOF
26+
}
27+
28+
return tarpSampleToTfDataSample(s), nil
29+
}
30+
31+
func TFRecordSink(t *testing.T, writer io.Writer) Sink {
32+
return func(pipe Pipe) {
33+
w := core.NewTFRecordWriter(writer)
34+
samplesReader := &SamplesReader{pipe}
35+
tfExamplesReader := transform.SamplesToTFExample(samplesReader)
36+
err := w.WriteMessages(tfExamplesReader)
37+
38+
assert.NoError(t, err)
39+
}
40+
}
41+
42+
func TFRecordSource(t *testing.T, reader io.Reader) Source {
43+
return func(pipe Pipe) {
44+
defer close(pipe)
45+
var (
46+
ex *core.TFExample
47+
err error
48+
r core.TFExampleReader
49+
)
50+
r = core.NewTFRecordReader(reader)
51+
for ex, err = r.Read(); err == nil; ex, err = r.Read() {
52+
pipe <- tfExampleTarpSample(ex)
53+
}
54+
if err != io.EOF {
55+
assert.Fail(t, "expected to get io.EOF, got %v instead", err)
56+
}
57+
}
58+
}
59+
60+
func SamplesChecker(t *testing.T, target int) Process {
61+
return func(in, out Pipe) {
62+
total := 0
63+
for s := range in {
64+
assert.Equal(t, s["txt"], Bytes(fmt.Sprintf("%d", total)))
65+
assert.Equal(t, s["__key__"], Bytes(fmt.Sprintf("%06d", total)))
66+
total++
67+
out <- s
68+
}
69+
close(out)
70+
assert.Equal(t, target, total)
71+
}
72+
}
73+
74+
func tarpSampleToTfDataSample(sample Sample) *core.Sample {
75+
s := core.NewSample()
76+
for k, v := range sample {
77+
s.Entries[k] = v
78+
}
79+
return s
80+
}
81+
82+
func tfExampleTarpSample(example *core.TFExample) Sample {
83+
s := make(map[string]Bytes, len(example.GetFeatures().Feature))
84+
for k, v := range example.GetFeatures().Feature {
85+
var b Bytes
86+
err := json.Unmarshal(v.GetBytesList().Value[0], &b)
87+
if err != nil {
88+
panic(err)
89+
}
90+
s[k] = b // assume that all TFExample features are just a list of bytes
91+
}
92+
return s
93+
}
94+
95+
func PrepareTarSource() Source {
96+
return func(pipe Pipe) {
97+
for i := 0; i < 1; i++ {
98+
pipe <- Sample{
99+
"__key__": Bytes(fmt.Sprintf("%06d", i)),
100+
"txt": Bytes(fmt.Sprintf("%d", i)),
101+
}
102+
}
103+
close(pipe)
104+
}
105+
}
106+
107+
func prepareTar(t *testing.T) *os.File {
108+
var (
109+
sinkFd *os.File
110+
err error
111+
)
112+
sinkFd, err = ioutil.TempFile("", "go-tfdata-*.tar")
113+
assert.NoError(t, err)
114+
115+
sink := TarSink(sinkFd)
116+
Processing(PrepareTarSource(), nil, sink)
117+
return sinkFd
118+
}
119+
120+
func TestGoTfData(t *testing.T) {
121+
var (
122+
sourceFd = prepareTar(t)
123+
sinkFd *os.File
124+
err error
125+
)
126+
127+
defer os.RemoveAll(sourceFd.Name())
128+
sourceFd, err = os.Open(sourceFd.Name())
129+
assert.NoError(t, err)
130+
131+
sinkFd, err = ioutil.TempFile("", "go-tfdata-*.tfrecord")
132+
assert.NoError(t, err)
133+
defer os.RemoveAll(sinkFd.Name())
134+
135+
Processing(TarSource(sourceFd), nil, TFRecordSink(t, sinkFd))
136+
sinkFd.Close()
137+
sourceFd, err = os.Open(sinkFd.Name())
138+
assert.NoError(t, err)
139+
sinkFd, err = os.OpenFile(os.DevNull, os.O_RDWR, os.ModeAppend)
140+
assert.NoError(t, err)
141+
142+
Processing(TFRecordSource(t, sourceFd), SamplesChecker(t, 1), TFRecordSink(t, sinkFd))
143+
}

go.mod

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,3 @@ module github.com/tmbdev/tarp
33
replace github.com/tmbdev/tarp/dpipes => ./dpipes
44

55
go 1.14
6-
7-
require (
8-
github.com/bcicen/ctop v0.7.3 // indirect
9-
github.com/jessevdk/go-flags v1.4.0
10-
github.com/maruel/panicparse v1.3.0 // indirect
11-
github.com/stretchr/testify v1.2.2
12-
github.com/tmbdev/tarp/dpipes v0.0.0-20200330012711-53823ac810b9
13-
)

0 commit comments

Comments
 (0)