Skip to content

Commit 377194e

Browse files
authored
feat: Implement transformations logic. (#1800)
This is a first draft of SDK changes that add transformation support, so that transformer plugins can be created. I don't want to merge it yet, because the channel logic is not well thought out; I would rather first run a local test sync with some transformation.
1 parent 1b9be18 commit 377194e

8 files changed

Lines changed: 318 additions & 2 deletions

File tree

examples/simple_plugin/plugin/client.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ func (*Client) Read(context.Context, *schema.Table, chan<- arrow.Record) error {
6464
return nil
6565
}
6666

67+
func (*Client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arrow.Record) error {
68+
// Not implemented, just used for testing destination packaging
69+
return nil
70+
}
71+
6772
func Configure(_ context.Context, logger zerolog.Logger, spec []byte, opts plugin.NewClientOptions) (plugin.Client, error) {
6873
if opts.NoConnection {
6974
return &Client{

internal/memdb/memdb.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,10 @@ func (c *client) deleteRecord(_ context.Context, msg *message.WriteDeleteRecord)
306306
c.memoryDB[tableName] = filteredTable
307307
}
308308

309+
func (*client) Transform(_ context.Context, _ <-chan arrow.Record, _ chan<- arrow.Record) error {
310+
return nil
311+
}
312+
309313
func evaluatePredicate(pred message.Predicate, record arrow.Record) bool {
310314
sc := record.Schema()
311315
indices := sc.FieldIndices(pred.Column)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package reversertransformer
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/apache/arrow/go/v16/arrow"
8+
"github.com/apache/arrow/go/v16/arrow/array"
9+
"github.com/apache/arrow/go/v16/arrow/memory"
10+
"github.com/cloudquery/plugin-sdk/v4/plugin"
11+
"github.com/rs/zerolog"
12+
)
13+
14+
// client is a sample transformer that reverses string columns; it is mostly used for testing the transformer plugin flow.
15+
type client struct {
16+
plugin.UnimplementedDestination
17+
plugin.UnimplementedSource
18+
}
19+
20+
type Option func(*client)
21+
22+
type Spec struct {
23+
}
24+
25+
func GetNewClient(options ...Option) plugin.NewClientFunc {
26+
c := &client{}
27+
for _, opt := range options {
28+
opt(c)
29+
}
30+
return func(context.Context, zerolog.Logger, []byte, plugin.NewClientOptions) (plugin.Client, error) {
31+
return c, nil
32+
}
33+
}
34+
35+
func (*client) GetSpec() any {
36+
return &Spec{}
37+
}
38+
39+
func (*client) Close(context.Context) error {
40+
return nil
41+
}
42+
43+
func (c *client) Transform(ctx context.Context, recvRecords <-chan arrow.Record, sendRecords chan<- arrow.Record) error {
44+
for {
45+
select {
46+
case record, ok := <-recvRecords:
47+
if !ok {
48+
return nil
49+
}
50+
reversedRecord, err := c.reverseStrings(record)
51+
if err != nil {
52+
return err
53+
}
54+
sendRecords <- reversedRecord
55+
case <-ctx.Done():
56+
return nil
57+
}
58+
}
59+
}
60+
61+
func (*client) reverseStrings(record arrow.Record) (arrow.Record, error) {
62+
for i, column := range record.Columns() {
63+
if column.DataType().ID() != arrow.STRING {
64+
continue
65+
}
66+
newColumnData := []string{}
67+
for i := 0; i < column.Len(); i++ {
68+
if !column.IsValid(i) {
69+
continue
70+
}
71+
s := column.ValueStr(i)
72+
runes := []rune(s)
73+
for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
74+
runes[i], runes[j] = runes[j], runes[i]
75+
}
76+
newColumnData = append(newColumnData, string(runes))
77+
}
78+
fmt.Println("new column data is ", newColumnData)
79+
mem := memory.NewGoAllocator()
80+
bld := array.NewStringBuilder(mem)
81+
82+
// create an array with 4 values, no null
83+
bld.AppendValues(newColumnData, nil)
84+
var err error
85+
record, err = record.SetColumn(i, bld.NewStringArray())
86+
if err != nil {
87+
return nil, err
88+
}
89+
}
90+
return record, nil
91+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package reversertransformer
2+
3+
import (
4+
"context"
5+
"io"
6+
"testing"
7+
8+
"github.com/apache/arrow/go/v16/arrow"
9+
"github.com/apache/arrow/go/v16/arrow/array"
10+
"github.com/apache/arrow/go/v16/arrow/memory"
11+
pb "github.com/cloudquery/plugin-pb-go/pb/plugin/v3"
12+
internalPlugin "github.com/cloudquery/plugin-sdk/v4/internal/servers/plugin/v3"
13+
"github.com/cloudquery/plugin-sdk/v4/plugin"
14+
"github.com/stretchr/testify/require"
15+
"google.golang.org/grpc"
16+
"google.golang.org/grpc/metadata"
17+
)
18+
19+
var mem = memory.NewGoAllocator()
20+
21+
func TestReverserTransformer(t *testing.T) {
22+
p := plugin.NewPlugin("test", "development", GetNewClient())
23+
s := internalPlugin.Server{
24+
Plugin: p,
25+
}
26+
_, err := s.Init(context.Background(), &pb.Init_Request{
27+
Spec: []byte("{}"),
28+
NoConnection: true,
29+
InvocationId: "26b550f9-c6f8-4b4b-9ec4-773bab288ee6",
30+
})
31+
require.NoError(t, err)
32+
requests := makeRequestsFromStrings("hello", "world")
33+
stream := mockTransformServer{incomingMessages: requests}
34+
require.NoError(t, s.Transform(&stream))
35+
require.Equal(t, 2, len(stream.outgoingMessages))
36+
37+
record1, err := pb.NewRecordFromBytes(stream.outgoingMessages[0].Record)
38+
require.NoError(t, err)
39+
record2, err := pb.NewRecordFromBytes(stream.outgoingMessages[1].Record)
40+
require.NoError(t, err)
41+
42+
require.Equal(t, "olleh", record1.Column(0).ValueStr(0))
43+
require.Equal(t, "dlrow", record2.Column(0).ValueStr(0))
44+
}
45+
46+
func makeRequestsFromStrings(s ...string) []*pb.Transform_Request {
47+
requests := make([]*pb.Transform_Request, len(s))
48+
for i, str := range s {
49+
requests[i] = makeRequestFromString(str)
50+
}
51+
return requests
52+
}
53+
54+
func makeRequestFromString(s string) *pb.Transform_Request {
55+
record := makeRecordFromString(s)
56+
bs, _ := pb.RecordToBytes(record)
57+
return &pb.Transform_Request{Record: bs}
58+
}
59+
60+
func makeRecordFromString(s string) arrow.Record {
61+
str := array.NewStringBuilder(mem)
62+
str.AppendString(s)
63+
arr := str.NewStringArray()
64+
schema := arrow.NewSchema([]arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, nil)
65+
66+
return array.NewRecord(schema, []arrow.Array{arr}, 1)
67+
}
68+
69+
type mockTransformServer struct {
70+
grpc.ServerStream
71+
incomingMessages []*pb.Transform_Request
72+
outgoingMessages []*pb.Transform_Response
73+
}
74+
75+
func (*mockTransformServer) SendAndClose(*pb.Transform_Response) error {
76+
return nil
77+
}
78+
func (s *mockTransformServer) Recv() (*pb.Transform_Request, error) {
79+
if len(s.incomingMessages) > 0 {
80+
msg := s.incomingMessages[0]
81+
s.incomingMessages = s.incomingMessages[1:]
82+
return msg, nil
83+
}
84+
return nil, io.EOF
85+
}
86+
func (s *mockTransformServer) Send(resp *pb.Transform_Response) error {
87+
s.outgoingMessages = append(s.outgoingMessages, resp)
88+
return nil
89+
}
90+
func (*mockTransformServer) SetHeader(metadata.MD) error {
91+
return nil
92+
}
93+
func (*mockTransformServer) SendHeader(metadata.MD) error {
94+
return nil
95+
}
96+
func (*mockTransformServer) SetTrailer(metadata.MD) {
97+
}
98+
func (mockTransformServer) Context() context.Context {
99+
return context.Background()
100+
}
101+
func (mockTransformServer) SendMsg(any) error {
102+
return nil
103+
}
104+
func (mockTransformServer) RecvMsg(any) error {
105+
return nil
106+
}

internal/servers/plugin/v3/plugin.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,87 @@ func (s *Server) Write(stream pb.Plugin_WriteServer) error {
394394
}
395395
}
396396

397+
func (s *Server) Transform(stream pb.Plugin_TransformServer) error {
398+
var (
399+
recvRecords = make(chan arrow.Record)
400+
sendRecords = make(chan arrow.Record)
401+
pluginStopsWriter = make(chan struct{})
402+
doneReading = false
403+
ctx = stream.Context()
404+
eg, gctx = errgroup.WithContext(ctx)
405+
)
406+
407+
// Run the plugin's transform with both channels.
408+
//
409+
// When the plugin is done, it must return with either an error or nil.
410+
// The plugin must not close either channel.
411+
eg.Go(func() error {
412+
err := s.Plugin.Transform(gctx, recvRecords, sendRecords)
413+
close(pluginStopsWriter)
414+
doneReading = true
415+
return err
416+
})
417+
418+
// Write transformed records from transformer to destination.
419+
//
420+
// Currently the `sendRecords` channel is never closed. Instead, the plugin finishes this goroutine
421+
// when it returns, either with an error or null.
422+
//
423+
// The reading never closes the writer, because it's up to the Plugin to decide when to finish
424+
// writing, regardless of if the reading finished.
425+
eg.Go(func() error {
426+
for {
427+
select {
428+
case record := <-sendRecords:
429+
recordBytes, err := pb.RecordToBytes(record)
430+
if err != nil {
431+
return status.Errorf(codes.Internal, "failed to convert record to bytes: %v", err)
432+
}
433+
434+
if err := stream.Send(&pb.Transform_Response{Record: recordBytes}); err != nil {
435+
return fmt.Errorf("error sending response: %w", err)
436+
}
437+
case <-pluginStopsWriter:
438+
return nil
439+
}
440+
}
441+
})
442+
443+
// Read records from source to transformer
444+
//
445+
// If there's an error receiving or deserialising records, or if there are no more records,
446+
// the `recvRecords` channel will be closed. This will tell the plugin's transformer that
447+
// no more transforming can be done.
448+
//
449+
// The writer cannot stop the reader even on error, but the plugin will when it returns,
450+
// by setting `doneReading` to true.
451+
eg.Go(func() error {
452+
for {
453+
req, err := stream.Recv()
454+
if err == io.EOF {
455+
close(recvRecords)
456+
return nil
457+
}
458+
if err != nil {
459+
close(recvRecords)
460+
return fmt.Errorf("Error receiving request: %v", err)
461+
}
462+
if doneReading {
463+
return nil
464+
}
465+
record, err := pb.NewRecordFromBytes(req.Record)
466+
if err != nil {
467+
close(recvRecords)
468+
return status.Errorf(codes.InvalidArgument, "failed to create record: %v", err)
469+
}
470+
471+
recvRecords <- record
472+
}
473+
})
474+
475+
return eg.Wait()
476+
}
477+
397478
func (s *Server) Close(ctx context.Context, _ *pb.Close_Request) (*pb.Close_Response, error) {
398479
return &pb.Close_Response{}, s.Plugin.Close(ctx)
399480
}

plugin/plugin.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@ type NewClientFunc func(context.Context, zerolog.Logger, []byte, NewClientOption
2626
type Client interface {
2727
SourceClient
2828
DestinationClient
29+
TransformerClient
2930
}
3031

31-
type UnimplementedDestination struct{}
32+
type UnimplementedDestination struct {
33+
UnimplementedTransformer
34+
}
3235

3336
func (UnimplementedDestination) Write(context.Context, <-chan message.WriteMessage) error {
3437
return ErrNotImplemented
@@ -38,7 +41,9 @@ func (UnimplementedDestination) Read(context.Context, *schema.Table, chan<- arro
3841
return ErrNotImplemented
3942
}
4043

41-
type UnimplementedSource struct{}
44+
type UnimplementedSource struct {
45+
UnimplementedTransformer
46+
}
4247

4348
func (UnimplementedSource) Sync(context.Context, SyncOptions, chan<- message.SyncMessage) error {
4449
return ErrNotImplemented
@@ -48,6 +53,12 @@ func (UnimplementedSource) Tables(context.Context, TableOptions) (schema.Tables,
4853
return nil, ErrNotImplemented
4954
}
5055

56+
type UnimplementedTransformer struct{}
57+
58+
func (UnimplementedTransformer) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
59+
return ErrNotImplemented
60+
}
61+
5162
// Plugin is the base structure required to pass to sdk.serve
5263
// We take a declarative approach to API here similar to Cobra
5364
type Plugin struct {

plugin/plugin_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ func (c *testPluginClient) Write(_ context.Context, res <-chan message.WriteMess
5656
func (*testPluginClient) Close(context.Context) error {
5757
return nil
5858
}
59+
func (*testPluginClient) Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error {
60+
return nil
61+
}
5962

6063
func TestPluginSuccess(t *testing.T) {
6164
ctx := context.Background()

plugin/plugin_transformer.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package plugin
2+
3+
import (
4+
"context"
5+
6+
"github.com/apache/arrow/go/v16/arrow"
7+
)
8+
9+
type TransformerClient interface {
10+
Transform(context.Context, <-chan arrow.Record, chan<- arrow.Record) error
11+
}
12+
13+
func (p *Plugin) Transform(ctx context.Context, recvRecords <-chan arrow.Record, sendRecords chan<- arrow.Record) error {
14+
return p.client.Transform(ctx, recvRecords, sendRecords)
15+
}

0 commit comments

Comments
 (0)