Skip to content

Commit 7986adf

Browse files
committed
add ci
1 parent ffb0c98 commit 7986adf

5 files changed

Lines changed: 136 additions & 10 deletions

File tree

pom.xml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,31 @@
3838
<artifactId>commons-lang3</artifactId>
3939
<version>3.9</version>
4040
</dependency>
41+
<!-- coding tools -->
4142
<dependency>
4243
<groupId>org.projectlombok</groupId>
4344
<artifactId>lombok</artifactId>
4445
<version>1.18.10</version>
4546
</dependency>
47+
<!-- unit test tools -->
4648
<dependency>
4749
<groupId>junit</groupId>
4850
<artifactId>junit</artifactId>
4951
<version>${junit.version}</version>
5052
<scope>test</scope>
5153
</dependency>
54+
<dependency>
55+
<groupId>io.grpc</groupId>
56+
<artifactId>grpc-testing</artifactId>
57+
<version>${grpc.version}</version>
58+
<scope>test</scope>
59+
</dependency>
60+
<dependency>
61+
<groupId>org.mockito</groupId>
62+
<artifactId>mockito-core</artifactId>
63+
<version>3.1.0</version>
64+
<scope>test</scope>
65+
</dependency>
5266
</dependencies>
5367

5468
<build>

src/main/java/org/sqlflow/client/SQLFlow.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
package org.sqlflow.client;
1717

1818
import io.grpc.StatusRuntimeException;
19-
import org.sqlflow.client.models.RequestHeader;
19+
import org.sqlflow.client.model.RequestHeader;
2020
import proto.Sqlflow.JobStatus;
2121

2222
public interface SQLFlow {
@@ -58,5 +58,5 @@ String submit(RequestHeader header, String sql)
5858
*
5959
* @throws InterruptedException thrown by awaitTermination
6060
*/
61-
void release() throws InterruptedException;
61+
void shutdown() throws InterruptedException;
6262
}

src/main/java/org/sqlflow/client/impl/SQLFlowImpl.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import java.util.concurrent.TimeUnit;
2222
import org.apache.commons.lang3.StringUtils;
2323
import org.sqlflow.client.SQLFlow;
24-
import org.sqlflow.client.models.RequestHeader;
24+
import org.sqlflow.client.model.RequestHeader;
2525
import proto.SQLFlowGrpc;
2626
import proto.Sqlflow.Job;
2727
import proto.Sqlflow.JobStatus;
@@ -37,6 +37,11 @@ public void init(String serverUrl) {
3737
blockingStub = SQLFlowGrpc.newBlockingStub(channel);
3838
}
3939

40+
public SQLFlowImpl(ManagedChannel channel) {
41+
this.channel = channel;
42+
blockingStub = SQLFlowGrpc.newBlockingStub(channel);
43+
}
44+
4045
public String submit(RequestHeader header, String sql)
4146
throws IllegalArgumentException, StatusRuntimeException {
4247
if (header == null || StringUtils.isAnyBlank(header.getDataSource(), header.getUserId())) {
@@ -51,10 +56,10 @@ public String submit(RequestHeader header, String sql)
5156
.setDbConnStr(header.getDataSource())
5257
.setUserId(header.getUserId())
5358
.setExitOnSubmit(header.isExitOnSubmit())
54-
.setHiveLocation(header.getHiveLocation())
55-
.setHdfsNamenodeAddr(header.getHdfsNameNode())
56-
.setHdfsUser(header.getHdfsUser())
57-
.setHdfsPass(header.getHdfsPassword())
59+
.setHiveLocation(StringUtils.defaultString(header.getHiveLocation()))
60+
.setHdfsNamenodeAddr(StringUtils.defaultString(header.getHdfsNameNode()))
61+
.setHdfsUser(StringUtils.defaultString(header.getHdfsUser()))
62+
.setHdfsPass(StringUtils.defaultString(header.getHdfsPassword()))
5863
.build();
5964
Request req = Request.newBuilder().setSession(session).setSql(sql).build();
6065
try {
@@ -76,7 +81,7 @@ public JobStatus fetch(String jobId) throws StatusRuntimeException {
7681
}
7782
}
7883

79-
public void release() throws InterruptedException {
84+
public void shutdown() throws InterruptedException {
8085
try {
8186
channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
8287
} catch (InterruptedException e) {

src/main/java/org/sqlflow/client/models/RequestHeader.java renamed to src/main/java/org/sqlflow/client/model/RequestHeader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* limitations under the License.
1414
*/
1515

16-
package org.sqlflow.client.models;
16+
package org.sqlflow.client.model;
1717

1818
import lombok.Data;
1919

src/test/java/org/sqlflow/client/SQLFlowTest.java

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,116 @@
1515

1616
package org.sqlflow.client;
1717

18+
import static org.junit.Assert.assertEquals;
19+
import static org.mockito.AdditionalAnswers.delegatesTo;
20+
import static org.mockito.Mockito.mock;
21+
import static org.mockito.Mockito.verify;
22+
23+
import io.grpc.ManagedChannel;
24+
import io.grpc.inprocess.InProcessChannelBuilder;
25+
import io.grpc.inprocess.InProcessServerBuilder;
26+
import io.grpc.stub.StreamObserver;
27+
import io.grpc.testing.GrpcCleanupRule;
28+
import org.junit.After;
29+
import org.junit.Before;
30+
import org.junit.Rule;
1831
import org.junit.Test;
32+
import org.junit.runner.RunWith;
33+
import org.junit.runners.JUnit4;
34+
import org.mockito.ArgumentCaptor;
35+
import org.mockito.ArgumentMatchers;
36+
import org.sqlflow.client.impl.SQLFlowImpl;
37+
import org.sqlflow.client.model.RequestHeader;
38+
import proto.SQLFlowGrpc;
39+
import proto.Sqlflow.Job;
40+
import proto.Sqlflow.JobStatus;
41+
import proto.Sqlflow.JobStatus.Code;
42+
import proto.Sqlflow.Request;
43+
import proto.Sqlflow.Session;
1944

45+
@RunWith(JUnit4.class)
2046
public class SQLFlowTest {
47+
private SQLFlow client;
48+
/**
49+
* This rule manages automatic graceful shutdown for the registered servers and channels at the
50+
* end of test.
51+
*/
52+
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
53+
54+
private final SQLFlowGrpc.SQLFlowImplBase grpcService =
55+
mock(
56+
SQLFlowGrpc.SQLFlowImplBase.class,
57+
delegatesTo(
58+
new SQLFlowGrpc.SQLFlowImplBase() {
59+
public void submit(Request request, StreamObserver<Job> response) {
60+
Session session = request.getSession();
61+
String userId = session.getUserId();
62+
response.onNext(
63+
Job.newBuilder().setId(mockJobId(userId, request.getSql())).build());
64+
response.onCompleted();
65+
}
66+
67+
public void fetch(Job request, StreamObserver<JobStatus> response) {
68+
String jobId = request.getId();
69+
response.onNext(
70+
JobStatus.newBuilder()
71+
.setCodeValue(Code.PENDING_VALUE)
72+
.setMessage(mockMessage(jobId))
73+
.build());
74+
response.onCompleted();
75+
}
76+
}));
77+
78+
@Before
79+
public void setUp() throws Exception {
80+
String serverName = InProcessServerBuilder.generateName();
81+
grpcCleanup.register(
82+
InProcessServerBuilder.forName(serverName)
83+
.directExecutor()
84+
.addService(grpcService)
85+
.build()
86+
.start());
87+
88+
ManagedChannel channel =
89+
grpcCleanup.register(InProcessChannelBuilder.forName(serverName).directExecutor().build());
90+
client = new SQLFlowImpl(channel);
91+
}
92+
2193
@Test
22-
public void testDummy() {}
94+
public void testSubmit() {
95+
String userId = "314159";
96+
String sql = "SELECT * TO TRAIN DNNClassify WITH ... COLUMN ... INTO ..";
97+
98+
ArgumentCaptor<Request> requestCaptor = ArgumentCaptor.forClass(Request.class);
99+
RequestHeader header = new RequestHeader();
100+
header.setUserId(userId);
101+
header.setDataSource("mysql://root@root@127.0.0.1:3306/iris");
102+
String jobId = client.submit(header, sql);
103+
assertEquals(mockJobId(userId, sql), jobId);
104+
verify(grpcService)
105+
.submit(requestCaptor.capture(), ArgumentMatchers.<StreamObserver<Job>>any());
106+
assertEquals(sql, requestCaptor.getValue().getSql());
107+
}
108+
109+
private String mockJobId(String userId, String sql) {
110+
return userId + "/" + sql;
111+
}
112+
113+
private String mockMessage(String jobId) {
114+
return "Hello " + jobId;
115+
}
116+
117+
@Test
118+
public void testFetch() {
119+
String jobId = "this is a job id";
120+
JobStatus jobStatus = client.fetch(jobId);
121+
122+
assertEquals(Code.PENDING_VALUE, jobStatus.getCode().getNumber());
123+
assertEquals(mockMessage(jobId), jobStatus.getMessage());
124+
}
125+
126+
@After
127+
public void tearDown() throws Exception {
128+
client.shutdown();
129+
}
23130
}

0 commit comments

Comments
 (0)