Skip to content

Commit 7fbf795

Browse files
authored
Merge pull request #5 from weiguoz/impl_interface
Implement SQLFlow client interface
2 parents bab290d + ec019ed commit 7fbf795

File tree

5 files changed

+259
-23
lines changed

5 files changed

+259
-23
lines changed

pom.xml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,36 @@
3535
<artifactId>grpc-stub</artifactId>
3636
<version>${grpc.version}</version>
3737
</dependency>
38+
<dependency>
39+
<groupId>org.apache.commons</groupId>
40+
<artifactId>commons-lang3</artifactId>
41+
<version>3.9</version>
42+
</dependency>
43+
<!-- coding tools -->
44+
<dependency>
45+
<groupId>org.projectlombok</groupId>
46+
<artifactId>lombok</artifactId>
47+
<version>1.18.10</version>
48+
</dependency>
49+
<!-- unit test tools -->
3850
<dependency>
3951
<groupId>junit</groupId>
4052
<artifactId>junit</artifactId>
4153
<version>${junit.version}</version>
4254
<scope>test</scope>
4355
</dependency>
56+
<dependency>
57+
<groupId>io.grpc</groupId>
58+
<artifactId>grpc-testing</artifactId>
59+
<version>${grpc.version}</version>
60+
<scope>test</scope>
61+
</dependency>
62+
<dependency>
63+
<groupId>org.mockito</groupId>
64+
<artifactId>mockito-core</artifactId>
65+
<version>3.1.0</version>
66+
<scope>test</scope>
67+
</dependency>
4468
</dependencies>
4569

4670
<build>

src/main/java/org/sqlflow/client/SQLFlow.java

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,45 +15,48 @@
1515

1616
package org.sqlflow.client;
1717

18-
import java.net.ConnectException;
18+
import io.grpc.StatusRuntimeException;
19+
import org.sqlflow.client.model.RequestHeader;
1920
import proto.Sqlflow.JobStatus;
2021

2122
public interface SQLFlow {
2223
/**
23-
* Open the connection(channel) to the SQLFlow server. The serverUrl argument always ends with a
24-
* port.
24+
* Open a channel to the SQLFlow server. The serverUrl argument always ends with a port.
2525
*
2626
* @param serverUrl an address the SQLFlow server exposed.
2727
* <p>Example: "localhost:50051"
28-
* @throws ConnectException when encountering the bad network.
2928
*/
30-
void open(String serverUrl) throws ConnectException;
29+
void init(String serverUrl);
3130

3231
/**
3332
* Submit a task to SQLFlow server. This method return immediately.
3433
*
35-
* @param sql: sql program. *
34+
* @param header: specify datasource, user ...
35+
* @param sql: sql program.
3636
* <p>Example: "SELECT * FROM iris.test; SELECT * FROM iris.iris TO TRAIN DNNClassifier
3737
* COLUMN..." *
3838
* @return return a job id for tracking.
39-
* @throws Exception TODO(weiguo): more precise
39+
* @throws IllegalArgumentException header or sql error
40+
* @throws StatusRuntimeException
4041
*/
41-
String submit(String sql) throws Exception;
42+
String submit(RequestHeader header, String sql)
43+
throws IllegalArgumentException, StatusRuntimeException;
4244

4345
/**
4446
* Fetch the job status by job id. The job id always returned by submit. By fetch(), we are able
4547
* to tracking the job status
4648
*
4749
* @param jobId specific the job we are going to track
4850
* @return see @code proto.JobStatus.Code
49-
* @throws Exception TODO(weiguo): more precise
51+
* @throws StatusRuntimeException
5052
*/
51-
JobStatus fetch(String jobId) throws Exception;
53+
JobStatus fetch(String jobId) throws StatusRuntimeException;
5254

5355
/**
54-
* Close the opened connection(channel) to SQLFlow server
56+
* Close the opened channel to SQLFlow server. Waits for the channel to become terminated, giving
57+
* up if the timeout is reached.
5558
*
56-
* @throws Exception TODO(weiguo): more precise
59+
* @throws InterruptedException thrown by awaitTermination
5760
*/
58-
void close() throws Exception;
61+
void shutdown() throws InterruptedException;
5962
}

src/main/java/org/sqlflow/client/impl/SQLFlowImpl.java

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,78 @@
1515

1616
package org.sqlflow.client.impl;
1717

18-
import java.net.ConnectException;
18+
import io.grpc.ManagedChannel;
19+
import io.grpc.ManagedChannelBuilder;
20+
import io.grpc.StatusRuntimeException;
21+
import java.util.concurrent.TimeUnit;
22+
import org.apache.commons.lang3.StringUtils;
1923
import org.sqlflow.client.SQLFlow;
24+
import org.sqlflow.client.model.RequestHeader;
25+
import proto.SQLFlowGrpc;
26+
import proto.Sqlflow.Job;
2027
import proto.Sqlflow.JobStatus;
28+
import proto.Sqlflow.Request;
29+
import proto.Sqlflow.Session;
2130

2231
public class SQLFlowImpl implements SQLFlow {
23-
// private final ManagedChannel channel;
24-
// private final SQLFlowGrpc.SQLFlowBlockingStub blockingStub;
32+
private ManagedChannel channel;
33+
private SQLFlowGrpc.SQLFlowBlockingStub blockingStub;
2534

26-
public void open(String serverUrl) throws ConnectException {}
35+
public void init(String serverUrl) {
36+
this.channel = ManagedChannelBuilder.forTarget(serverUrl).usePlaintext().build();
37+
blockingStub = SQLFlowGrpc.newBlockingStub(channel);
38+
}
2739

28-
public String submit(String sql) throws Exception {
29-
return null;
40+
public SQLFlowImpl(ManagedChannel channel) {
41+
this.channel = channel;
42+
blockingStub = SQLFlowGrpc.newBlockingStub(channel);
3043
}
3144

32-
public JobStatus fetch(String jobId) throws Exception {
33-
return null;
45+
public String submit(RequestHeader header, String sql)
46+
throws IllegalArgumentException, StatusRuntimeException {
47+
if (header == null || StringUtils.isAnyBlank(header.getDataSource(), header.getUserId())) {
48+
throw new IllegalArgumentException("data source and userId are not allowed to be empty");
49+
}
50+
if (StringUtils.isBlank(sql)) {
51+
throw new IllegalArgumentException("sql is empty");
52+
}
53+
54+
Session session =
55+
Session.newBuilder()
56+
.setDbConnStr(header.getDataSource())
57+
.setUserId(header.getUserId())
58+
.setExitOnSubmit(header.isExitOnSubmit())
59+
.setHiveLocation(StringUtils.defaultString(header.getHiveLocation()))
60+
.setHdfsNamenodeAddr(StringUtils.defaultString(header.getHdfsNameNode()))
61+
.setHdfsUser(StringUtils.defaultString(header.getHdfsUser()))
62+
.setHdfsPass(StringUtils.defaultString(header.getHdfsPassword()))
63+
.build();
64+
Request req = Request.newBuilder().setSession(session).setSql(sql).build();
65+
try {
66+
Job job = blockingStub.submit(req);
67+
return job.getId();
68+
} catch (StatusRuntimeException e) {
69+
// TODO(weiguo) logger.error
70+
throw e;
71+
}
3472
}
3573

36-
public void close() throws Exception {}
74+
public JobStatus fetch(String jobId) throws StatusRuntimeException {
75+
Job req = Job.newBuilder().setId(jobId).build();
76+
try {
77+
return blockingStub.fetch(req);
78+
} catch (StatusRuntimeException e) {
79+
// TODO(weiguo) logger.error
80+
throw e;
81+
}
82+
}
83+
84+
public void shutdown() throws InterruptedException {
85+
try {
86+
channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
87+
} catch (InterruptedException e) {
88+
// TODO(weiguo) logger.error
89+
throw e;
90+
}
91+
}
3792
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright 2019 The SQLFlow Authors. All rights reserved.
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
package org.sqlflow.client.model;
17+
18+
import lombok.Data;
19+
20+
@Data
21+
public class RequestHeader {
22+
/**
23+
* database source, for:
24+
*
25+
* <p>maxcomputer
26+
* maxcompute://{accesskey_id}:{accesskey_secret}@{endpoint}?curr_project={curr_project}&scheme={scheme}
27+
*
28+
* <p>mysql
29+
* mysql://{username}:{password}@tcp({address})/{dbname}[?param1=value1&...&paramN=valueN]
30+
*
31+
* <p>hive
32+
* hive://user:password@ip:port/dbname[?auth=<auth_mechanism>&session.<cfg_key1>=<cfg_value1>...&session<cfg_keyN>=valueN]
33+
*/
34+
private String dataSource;
35+
36+
/** user who submits the SQL task. */
37+
private String userId;
38+
39+
/* for alps */
40+
private boolean exitOnSubmit;
41+
42+
/* hive */
43+
private String hiveLocation;
44+
private String hdfsNameNode;
45+
private String hdfsUser;
46+
private String hdfsPassword;
47+
}

src/test/java/org/sqlflow/client/SQLFlowTest.java

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,116 @@
1515

1616
package org.sqlflow.client;
1717

18+
import static org.junit.Assert.assertEquals;
19+
import static org.mockito.AdditionalAnswers.delegatesTo;
20+
import static org.mockito.Mockito.mock;
21+
import static org.mockito.Mockito.verify;
22+
23+
import io.grpc.ManagedChannel;
24+
import io.grpc.inprocess.InProcessChannelBuilder;
25+
import io.grpc.inprocess.InProcessServerBuilder;
26+
import io.grpc.stub.StreamObserver;
27+
import io.grpc.testing.GrpcCleanupRule;
28+
import org.junit.After;
29+
import org.junit.Before;
30+
import org.junit.Rule;
1831
import org.junit.Test;
32+
import org.junit.runner.RunWith;
33+
import org.junit.runners.JUnit4;
34+
import org.mockito.ArgumentCaptor;
35+
import org.mockito.ArgumentMatchers;
36+
import org.sqlflow.client.impl.SQLFlowImpl;
37+
import org.sqlflow.client.model.RequestHeader;
38+
import proto.SQLFlowGrpc;
39+
import proto.Sqlflow.Job;
40+
import proto.Sqlflow.JobStatus;
41+
import proto.Sqlflow.JobStatus.Code;
42+
import proto.Sqlflow.Request;
43+
import proto.Sqlflow.Session;
1944

45+
@RunWith(JUnit4.class)
2046
public class SQLFlowTest {
47+
private SQLFlow client;
48+
/**
49+
* This rule manages automatic graceful shutdown for the registered servers and channels at the
50+
* end of test.
51+
*/
52+
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
53+
54+
private final SQLFlowGrpc.SQLFlowImplBase grpcService =
55+
mock(
56+
SQLFlowGrpc.SQLFlowImplBase.class,
57+
delegatesTo(
58+
new SQLFlowGrpc.SQLFlowImplBase() {
59+
public void submit(Request request, StreamObserver<Job> response) {
60+
Session session = request.getSession();
61+
String userId = session.getUserId();
62+
response.onNext(
63+
Job.newBuilder().setId(mockJobId(userId, request.getSql())).build());
64+
response.onCompleted();
65+
}
66+
67+
public void fetch(Job request, StreamObserver<JobStatus> response) {
68+
String jobId = request.getId();
69+
response.onNext(
70+
JobStatus.newBuilder()
71+
.setCodeValue(Code.PENDING_VALUE)
72+
.setMessage(mockMessage(jobId))
73+
.build());
74+
response.onCompleted();
75+
}
76+
}));
77+
78+
@Before
79+
public void setUp() throws Exception {
80+
String serverName = InProcessServerBuilder.generateName();
81+
grpcCleanup.register(
82+
InProcessServerBuilder.forName(serverName)
83+
.directExecutor()
84+
.addService(grpcService)
85+
.build()
86+
.start());
87+
88+
ManagedChannel channel =
89+
grpcCleanup.register(InProcessChannelBuilder.forName(serverName).directExecutor().build());
90+
client = new SQLFlowImpl(channel);
91+
}
92+
2193
@Test
22-
public void testDummy() {}
94+
public void testSubmit() {
95+
String userId = "314159";
96+
String sql = "SELECT * TO TRAIN DNNClassify WITH ... COLUMN ... INTO ..";
97+
98+
ArgumentCaptor<Request> requestCaptor = ArgumentCaptor.forClass(Request.class);
99+
RequestHeader header = new RequestHeader();
100+
header.setUserId(userId);
101+
header.setDataSource("mysql://root@root@127.0.0.1:3306/iris");
102+
String jobId = client.submit(header, sql);
103+
assertEquals(mockJobId(userId, sql), jobId);
104+
verify(grpcService)
105+
.submit(requestCaptor.capture(), ArgumentMatchers.<StreamObserver<Job>>any());
106+
assertEquals(sql, requestCaptor.getValue().getSql());
107+
}
108+
109+
private String mockJobId(String userId, String sql) {
110+
return userId + "/" + sql;
111+
}
112+
113+
private String mockMessage(String jobId) {
114+
return "Hello " + jobId;
115+
}
116+
117+
@Test
118+
public void testFetch() {
119+
String jobId = "this is a job id";
120+
JobStatus jobStatus = client.fetch(jobId);
121+
122+
assertEquals(Code.PENDING_VALUE, jobStatus.getCode().getNumber());
123+
assertEquals(mockMessage(jobId), jobStatus.getMessage());
124+
}
125+
126+
@After
127+
public void tearDown() throws Exception {
128+
client.shutdown();
129+
}
23130
}

0 commit comments

Comments
 (0)