Skip to content

Commit 0b5b87a

Browse files
authored
WireGuard proxy: Refactor (#6287)
And #6303 (comment)
1 parent d27b3e4 commit 0b5b87a

20 files changed

Lines changed: 1550 additions & 1294 deletions

File tree

app/proxyman/inbound/always.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,23 @@ func NewAlwaysOnInboundHandler(ctx context.Context, tag string, receiverConfig *
5757
if err != nil {
5858
return nil, err
5959
}
60-
61-
// Set tag and sniffing config in context before creating proxy
62-
// This allows proxies like TUN to access these settings
63-
ctx = session.ContextWithInbound(ctx, &session.Inbound{Tag: tag})
64-
if receiverConfig.SniffingSettings != nil {
65-
ctx = session.ContextWithContent(ctx, &session.Content{
66-
SniffingRequest: sniffingRequest,
67-
})
60+
src := net.TCPDestination(net.AnyIP, 0)
61+
if receiverConfig.Listen != nil {
62+
src.Address = receiverConfig.Listen.AsAddress()
63+
}
64+
if receiverConfig.PortList != nil && len(receiverConfig.PortList.Range) > 0 {
65+
src.Port = net.Port(receiverConfig.PortList.Range[0].From)
6866
}
69-
rawProxy, err := common.CreateObject(ctx, proxyConfig)
67+
mss, err := internet.ToMemoryStreamConfig(receiverConfig.StreamSettings)
68+
if err != nil {
69+
return nil, errors.New("failed to parse stream config").Base(err).AtWarning()
70+
}
71+
72+
newCtx := session.ContextWithInbound(ctx, &session.Inbound{Tag: tag, Source: src})
73+
newCtx = session.ContextWithContent(newCtx, &session.Content{SniffingRequest: sniffingRequest})
74+
newCtx = session.ContextWithStreamSettings(newCtx, mss)
75+
76+
rawProxy, err := common.CreateObject(newCtx, proxyConfig)
7077
if err != nil {
7178
return nil, err
7279
}
@@ -92,11 +99,6 @@ func NewAlwaysOnInboundHandler(ctx context.Context, tag string, receiverConfig *
9299
address = net.AnyIP
93100
}
94101

95-
mss, err := internet.ToMemoryStreamConfig(receiverConfig.StreamSettings)
96-
if err != nil {
97-
return nil, errors.New("failed to parse stream config").Base(err).AtWarning()
98-
}
99-
100102
if receiverConfig.ReceiveOriginalDestination {
101103
if mss.SocketSettings == nil {
102104
mss.SocketSettings = &internet.SocketConfig{}

app/proxyman/outbound/handler.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,9 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou
108108

109109
ctx = session.ContextWithFullHandler(ctx, h)
110110

111-
rawProxyHandler, err := common.CreateObject(ctx, proxyConfig)
111+
newCtx := session.ContextWithStreamSettings(ctx, h.streamSettings)
112+
113+
rawProxyHandler, err := common.CreateObject(newCtx, proxyConfig)
112114
if err != nil {
113115
return nil, err
114116
}

common/session/context.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ const (
2626
fullHandlerKey ctx.SessionKey = 10 // outbound gets full handler
2727
mitmAlpn11Key ctx.SessionKey = 11 // used by TLS dialer
2828
mitmServerNameKey ctx.SessionKey = 12 // used by TLS dialer
29+
30+
streamSettingsKey ctx.SessionKey = 13
2931
)
3032

3133
func ContextWithInbound(ctx context.Context, inbound *Inbound) context.Context {
@@ -192,3 +194,11 @@ func MitmServerNameFromContext(ctx context.Context) string {
192194
}
193195
return ""
194196
}
197+
198+
func ContextWithStreamSettings(ctx context.Context, streamSettings any) context.Context {
199+
return context.WithValue(ctx, streamSettingsKey, streamSettings)
200+
}
201+
202+
func StreamSettingsFromContext(ctx context.Context) any {
203+
return ctx.Value(streamSettingsKey)
204+
}

infra/conf/wireguard.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package conf
33
import (
44
"encoding/base64"
55
"encoding/hex"
6+
"strconv"
67
"strings"
78

89
"github.com/xtls/xray-core/common/errors"
@@ -37,8 +38,9 @@ func (c *WireGuardPeerConfig) Build() (proto.Message, error) {
3738
}
3839

3940
config.Endpoint = c.Endpoint
40-
// default 0
41-
config.KeepAlive = c.KeepAlive
41+
if c.KeepAlive != 0 {
42+
config.KeepAlive = strconv.FormatUint(uint64(c.KeepAlive), 10)
43+
}
4244
if c.AllowedIPs == nil {
4345
config.AllowedIps = []string{"0.0.0.0/0", "::0/0"}
4446
} else {
@@ -56,7 +58,6 @@ type WireGuardConfig struct {
5658
Address []string `json:"address"`
5759
Peers []*WireGuardPeerConfig `json:"peers"`
5860
MTU int32 `json:"mtu"`
59-
NumWorkers int32 `json:"workers"`
6061
Reserved []byte `json:"reserved"`
6162
DomainStrategy string `json:"domainStrategy"`
6263
}
@@ -93,9 +94,6 @@ func (c *WireGuardConfig) Build() (proto.Message, error) {
9394
} else {
9495
config.Mtu = c.MTU
9596
}
96-
// these a fallback code exists in wireguard-go code,
97-
// we don't need to process fallback manually
98-
config.NumWorkers = c.NumWorkers
9997

10098
if len(c.Reserved) != 0 && len(c.Reserved) != 3 {
10199
return nil, errors.New(`"reserved" should be empty or 3 bytes`)

infra/conf/wireguard_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,10 @@ func TestWireGuardConfig(t *testing.T) {
3838
// also can read from hex form directly
3939
PublicKey: "6e65ce0be17517110c17d77288ad87e7fd5252dcc7d09b95a39d61db03df832a",
4040
Endpoint: "127.0.0.1:1234",
41-
KeepAlive: 0,
4241
AllowedIps: []string{"0.0.0.0/0", "::0/0"},
4342
},
4443
},
4544
Mtu: 1300,
46-
NumWorkers: 2,
4745
DomainStrategy: wireguard.DeviceConfig_FORCE_IP64,
4846
NoKernelTun: false,
4947
},

proxy/hysteria/client.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ type Client struct {
2929
}
3030

3131
func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
32+
v := core.MustFromContext(ctx)
33+
p := v.GetFeature(policy.ManagerType()).(policy.Manager)
34+
35+
streamSettings := session.StreamSettingsFromContext(ctx).(*internet.MemoryStreamConfig)
36+
if _, ok := streamSettings.ProtocolSettings.(*hysteria.Config); !ok {
37+
return nil, errors.New("not hysteria transport")
38+
}
3239
if config.Server == nil {
3340
return nil, errors.New(`no target server found`)
3441
}
@@ -37,12 +44,10 @@ func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) {
3744
return nil, errors.New("failed to get server spec").Base(err)
3845
}
3946

40-
v := core.MustFromContext(ctx)
41-
client := &Client{
47+
return &Client{
4248
server: server,
43-
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
44-
}
45-
return client, nil
49+
policyManager: p,
50+
}, nil
4651
}
4752

4853
func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {

proxy/hysteria/server.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/xtls/xray-core/features/routing"
1717
"github.com/xtls/xray-core/proxy/hysteria/account"
1818
"github.com/xtls/xray-core/transport"
19+
"github.com/xtls/xray-core/transport/internet"
1920
"github.com/xtls/xray-core/transport/internet/hysteria"
2021
"github.com/xtls/xray-core/transport/internet/stat"
2122
)
@@ -27,6 +28,14 @@ type Server struct {
2728
}
2829

2930
func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
31+
v := core.MustFromContext(ctx)
32+
p := v.GetFeature(policy.ManagerType()).(policy.Manager)
33+
34+
streamSettings := session.StreamSettingsFromContext(ctx).(*internet.MemoryStreamConfig)
35+
if _, ok := streamSettings.ProtocolSettings.(*hysteria.Config); !ok {
36+
return nil, errors.New("not hysteria transport")
37+
}
38+
3039
validator := account.NewValidator()
3140
for _, user := range config.Users {
3241
u, err := user.ToMemoryUser()
@@ -39,14 +48,11 @@ func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
3948
}
4049
}
4150

42-
v := core.MustFromContext(ctx)
43-
s := &Server{
51+
return &Server{
4452
config: config,
4553
validator: validator,
46-
policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
47-
}
48-
49-
return s, nil
54+
policyManager: p,
55+
}, nil
5056
}
5157

5258
func (s *Server) HysteriaInboundValidator() *account.Validator {

0 commit comments

Comments
 (0)