2424import org .springframework .stereotype .Service ;
2525import org .springframework .web .servlet .mvc .method .annotation .StreamingResponseBody ;
2626
27- import java .io .IOException ;
28- import java .io .InputStream ;
29- import java .net .URI ;
27+ import java .io .IOException ;
28+ import java .io .InputStream ;
29+ import java .net .Inet4Address ;
30+ import java .net .InetAddress ;
31+ import java .net .Inet6Address ;
32+ import java .net .URI ;
33+ import java .net .URISyntaxException ;
34+ import java .net .UnknownHostException ;
3035import java .net .http .HttpClient ;
3136import java .net .http .HttpRequest ;
3237import java .net .http .HttpResponse ;
3338import java .nio .charset .StandardCharsets ;
34- import java .time .Duration ;
35- import java .util .HashMap ;
36- import java .util .Map ;
39+ import java .time .Duration ;
40+ import java .util .Arrays ;
41+ import java .util .HashMap ;
42+ import java .util .List ;
43+ import java .util .Map ;
44+ import java .util .Set ;
3745
3846/**
3947 * The type AiChat v1 service.
4351@ Slf4j
4452@ Service
4553public class AiChatV1ServiceImpl implements AiChatV1Service {
46- private final OpenAIConfig config = new OpenAIConfig ();
47- private HttpClient httpClient = HttpClient .newBuilder ()
48- .connectTimeout (Duration .ofSeconds (config .getTimeoutSeconds ()))
49- .build ();
54+ private final OpenAIConfig config ;
55+ private final HttpClient httpClient ;
56+
57+ public AiChatV1ServiceImpl (OpenAIConfig config ) {
58+ this .config = config ;
59+ this .httpClient = HttpClient .newBuilder ()
60+ .connectTimeout (Duration .ofSeconds (config .getTimeoutSeconds ()))
61+ .followRedirects (HttpClient .Redirect .NEVER )
62+ .build ();
63+ }
5064
5165 /**
5266 * chatCompletion.
@@ -65,6 +79,9 @@ public Object chatCompletion(ChatRequest request) throws Exception {
6579 // 规范化URL处理
6680 String normalizedUrl = normalizeApiUrl (baseUrl );
6781
82+ // 对最终请求 URL 做安全校验(在 normalize 之后,确保校验的是真正发出的地址)
83+ validateFinalUrl (normalizedUrl );
84+
6885 HttpRequest .Builder requestBuilder = HttpRequest .newBuilder ()
6986 .uri (URI .create (normalizedUrl ))
7087 .header ("Content-Type" , "application/json" )
@@ -233,8 +250,137 @@ private StreamingResponseBody processStreamResponse(HttpRequest.Builder requestB
233250 };
234251 }
235252
236- private String getApiKey (String encryptApiKey ) throws Exception {
237- String sm4Key = System .getenv ("SM4KEY" );
253+ private static final Set <String > LOOPBACK_HOSTS = Set .of ("localhost" , "127.0.0.1" , "::1" , "[::1]" );
254+
255+ void validateFinalUrl (String finalUrl ) {
256+ URI uri ;
257+ try {
258+ uri = new URI (finalUrl );
259+ } catch (URISyntaxException e ) {
260+ throw new ServiceException ("400" , "Invalid baseUrl format" );
261+ }
262+
263+ String host = uri .getHost ();
264+ if (host == null || host .isEmpty ()) {
265+ throw new ServiceException ("400" , "Invalid baseUrl: missing host" );
266+ }
267+
268+ boolean isLoopback = LOOPBACK_HOSTS .contains (host .toLowerCase ());
269+
270+ List <String > allowedHosts = config .getAllowedHosts ();
271+
272+ if (allowedHosts == null || allowedHosts .isEmpty ()) {
273+ if (!config .isAllowAnyHost ()) {
274+ throw new ServiceException ("500" , "No AI allowed hosts configured" );
275+ }
276+
277+ enforceHttpsAndIpCheck (uri , host );
278+ return ;
279+ }
280+
281+ boolean matched = allowedHosts .stream ()
282+ .anyMatch (allowed -> allowed .equalsIgnoreCase (host ));
283+ if (!matched ) {
284+ throw new ServiceException ("400" ,
285+ "Host not allowed: " + host + ". Allowed hosts: " + allowedHosts );
286+ }
287+
288+ if (isLoopback ) {
289+ return ;
290+ }
291+
292+ enforceHttpsAndIpCheck (uri , host );
293+ }
294+
295+ void enforceHttpsAndIpCheck (URI uri , String host ) {
296+ String scheme = uri .getScheme ();
297+ if (scheme == null || !"https" .equalsIgnoreCase (scheme )) {
298+ throw new ServiceException ("400" , "Only HTTPS protocol is allowed for custom baseUrl" );
299+ }
300+
301+ try {
302+ InetAddress [] addresses = resolveHostAddresses (host );
303+ boolean hasBlockedAddress = Arrays .stream (addresses ).anyMatch (this ::isBlockedAddress );
304+ if (hasBlockedAddress ) {
305+ throw new ServiceException ("400" , "Internal network addresses are not allowed" );
306+ }
307+ } catch (UnknownHostException e ) {
308+ throw new ServiceException ("400" , "Unable to resolve host: " + host );
309+ }
310+ }
311+
312+ InetAddress [] resolveHostAddresses (String host ) throws UnknownHostException {
313+ return InetAddress .getAllByName (host );
314+ }
315+
316+ boolean isBlockedAddress (InetAddress address ) {
317+ if (address .isLoopbackAddress ()
318+ || address .isSiteLocalAddress ()
319+ || address .isLinkLocalAddress ()
320+ || address .isAnyLocalAddress ()
321+ || address .isMulticastAddress ()) {
322+ return true ;
323+ }
324+
325+ if (address instanceof Inet4Address ) {
326+ return isBlockedIpv4 ((Inet4Address ) address );
327+ }
328+ if (address instanceof Inet6Address ) {
329+ return isBlockedIpv6 ((Inet6Address ) address );
330+ }
331+ return false ;
332+ }
333+
334+ private boolean isBlockedIpv4 (Inet4Address address ) {
335+ byte [] octets = address .getAddress ();
336+ int first = octets [0 ] & 0xFF ;
337+ int second = octets [1 ] & 0xFF ;
338+ int third = octets [2 ] & 0xFF ;
339+
340+ if (first == 0 ) {
341+ return true ;
342+ }
343+ if (first == 100 && second >= 64 && second <= 127 ) {
344+ return true ;
345+ }
346+ if (first == 192 && second == 0 && third == 0 ) {
347+ return true ;
348+ }
349+ if (first == 192 && second == 0 && third == 2 ) {
350+ return true ;
351+ }
352+ if (first == 198 && (second == 18 || second == 19 )) {
353+ return true ;
354+ }
355+ if (first == 198 && second == 51 && third == 100 ) {
356+ return true ;
357+ }
358+ if (first == 203 && second == 0 && third == 113 ) {
359+ return true ;
360+ }
361+ return first >= 240 ;
362+ }
363+
364+ private boolean isBlockedIpv6 (Inet6Address address ) {
365+ byte [] octets = address .getAddress ();
366+ int first = octets [0 ] & 0xFF ;
367+ int second = octets [1 ] & 0xFF ;
368+
369+ if ((first & 0xFE ) == 0xFC ) {
370+ return true ;
371+ }
372+ if (first == 0x20 && second == 0x01 ) {
373+ int third = octets [2 ] & 0xFF ;
374+ int fourth = octets [3 ] & 0xFF ;
375+ if (third == 0x0D && fourth == 0xB8 ) {
376+ return true ;
377+ }
378+ }
379+ return first == 0xFF ;
380+ }
381+
382+ private String getApiKey (String encryptApiKey ) throws Exception {
383+ String sm4Key = System .getenv ("SM4KEY" );
238384
239385 if (encryptApiKey .startsWith ("EKEY_" )) {
240386 String encryptBase64ApiKey = encryptApiKey .substring (5 );
0 commit comments