44 "context"
55 "errors"
66 "net/url"
7+ "slices"
78 "testing"
89 "time"
910
@@ -49,9 +50,9 @@ func (m *mock_connection_authorizer) Authorize(_ context.Context, conn connectio
4950
5051type mock_connection struct {}
5152
52- func (m * mock_connection ) Read (_ context.Context ) ([]byte , error ) { return nil , nil }
53+ func (m * mock_connection ) Read (_ context.Context ) ([]byte , error ) { return nil , nil }
5354func (m * mock_connection ) Write (_ context.Context , _ []byte ) error { return nil }
54- func (m * mock_connection ) Close (_ context.Context ) error { return nil }
55+ func (m * mock_connection ) Close (_ context.Context ) error { return nil }
5556
5657func TestConnect (t * testing.T ) {
5758 gen_err := errors .New ("generate failed" )
@@ -70,36 +71,60 @@ func TestConnect(t *testing.T) {
7071 expect_err error
7172 expect_dialed bool
7273 expect_authorized bool
74+ verify func (* testing.T , * mock_dialer , * connection.CreateConnectionOutput )
7375 }{
7476 {
7577 name : "success" ,
7678 generator : & mock_subprotocol_generator {result : "header-xyz" },
7779 dialer : & mock_dialer {conn : conn },
7880 authorizer : & mock_connection_authorizer {timeout : timeout },
7981 input : connection.CreateConnectionInput {Url : endpoint },
80- expect_err : nil ,
8182 expect_dialed : true ,
8283 expect_authorized : true ,
84+ verify : func (t * testing.T , dialer * mock_dialer , output * connection.CreateConnectionOutput ) {
85+ if dialer .received .Url != endpoint {
86+ t .Errorf ("dialer.received.Url = %v, want %v" , dialer .received .Url , endpoint )
87+ }
88+ got := dialer .received .Subprotocols
89+ if len (got ) == 0 || got [len (got )- 1 ] != "header-xyz" {
90+ t .Errorf ("dialer subprotocols last element = %v, want %q" , got , "header-xyz" )
91+ }
92+ if output == nil {
93+ t .Fatal ("expected non-nil output" )
94+ }
95+ if output .Connection != conn {
96+ t .Errorf ("output.Connection = %v, want %v" , output .Connection , conn )
97+ }
98+ if output .Timeout != timeout {
99+ t .Errorf ("output.Timeout = %v, want %v" , output .Timeout , timeout )
100+ }
101+ },
83102 },
84103 {
85- name : "subprotocol error does not call dialer" ,
86- generator : & mock_subprotocol_generator {err : gen_err },
87- dialer : & mock_dialer {},
88- authorizer : & mock_connection_authorizer {},
89- input : connection.CreateConnectionInput {Url : endpoint },
90- expect_err : gen_err ,
91- expect_dialed : false ,
92- expect_authorized : false ,
104+ name : "subprotocol error does not call dialer" ,
105+ generator : & mock_subprotocol_generator {err : gen_err },
106+ dialer : & mock_dialer {},
107+ authorizer : & mock_connection_authorizer {},
108+ input : connection.CreateConnectionInput {Url : endpoint },
109+ expect_err : gen_err ,
93110 },
94111 {
95- name : "dialer error does not call authorizer" ,
96- generator : & mock_subprotocol_generator {result : "header-xyz" },
97- dialer : & mock_dialer {err : dial_err },
98- authorizer : & mock_connection_authorizer {},
99- input : connection.CreateConnectionInput {Url : endpoint },
100- expect_err : dial_err ,
101- expect_dialed : true ,
102- expect_authorized : false ,
112+ name : "dialer error does not call authorizer" ,
113+ generator : & mock_subprotocol_generator {result : "header-xyz" },
114+ dialer : & mock_dialer {err : dial_err },
115+ authorizer : & mock_connection_authorizer {},
116+ input : connection.CreateConnectionInput {Url : endpoint },
117+ expect_err : dial_err ,
118+ expect_dialed : true ,
119+ verify : func (t * testing.T , dialer * mock_dialer , _ * connection.CreateConnectionOutput ) {
120+ if dialer .received .Url != endpoint {
121+ t .Errorf ("dialer.received.Url = %v, want %v" , dialer .received .Url , endpoint )
122+ }
123+ got := dialer .received .Subprotocols
124+ if len (got ) == 0 || got [len (got )- 1 ] != "header-xyz" {
125+ t .Errorf ("dialer subprotocols last element = %v, want %q" , got , "header-xyz" )
126+ }
127+ },
103128 },
104129 {
105130 name : "authorizer error is returned" ,
@@ -110,6 +135,15 @@ func TestConnect(t *testing.T) {
110135 expect_err : auth_err ,
111136 expect_dialed : true ,
112137 expect_authorized : true ,
138+ verify : func (t * testing.T , dialer * mock_dialer , _ * connection.CreateConnectionOutput ) {
139+ if dialer .received .Url != endpoint {
140+ t .Errorf ("dialer.received.Url = %v, want %v" , dialer .received .Url , endpoint )
141+ }
142+ got := dialer .received .Subprotocols
143+ if len (got ) == 0 || got [len (got )- 1 ] != "header-xyz" {
144+ t .Errorf ("dialer subprotocols last element = %v, want %q" , got , "header-xyz" )
145+ }
146+ },
113147 },
114148 {
115149 name : "generated subprotocol appended to input subprotocols" ,
@@ -120,19 +154,34 @@ func TestConnect(t *testing.T) {
120154 Url : endpoint ,
121155 Subprotocols : []string {"header-abc" },
122156 },
123- expect_err : nil ,
124157 expect_dialed : true ,
125158 expect_authorized : true ,
159+ verify : func (t * testing.T , dialer * mock_dialer , output * connection.CreateConnectionOutput ) {
160+ got := dialer .received .Subprotocols
161+ if len (got ) == 0 || got [len (got )- 1 ] != "header-xyz" {
162+ t .Errorf ("dialer subprotocols last element = %v, want %q" , got , "header-xyz" )
163+ }
164+ if ! slices .Contains (got , "header-abc" ) {
165+ t .Errorf ("input subprotocol %q missing from dialer received %v" , "header-abc" , got )
166+ }
167+ if output == nil {
168+ t .Fatal ("expected non-nil output" )
169+ }
170+ },
126171 },
127172 {
128173 name : "input URL forwarded to dialer" ,
129174 generator : & mock_subprotocol_generator {result : "header-xyz" },
130175 dialer : & mock_dialer {conn : conn },
131176 authorizer : & mock_connection_authorizer {timeout : timeout },
132177 input : connection.CreateConnectionInput {Url : endpoint },
133- expect_err : nil ,
134178 expect_dialed : true ,
135179 expect_authorized : true ,
180+ verify : func (t * testing.T , dialer * mock_dialer , _ * connection.CreateConnectionOutput ) {
181+ if dialer .received .Url != endpoint {
182+ t .Errorf ("dialer.received.Url = %v, want %v" , dialer .received .Url , endpoint )
183+ }
184+ },
136185 },
137186 }
138187
@@ -145,50 +194,14 @@ func TestConnect(t *testing.T) {
145194 if ! errors .Is (err , tt .expect_err ) {
146195 t .Errorf ("got error %v, want %v" , err , tt .expect_err )
147196 }
148-
149197 if tt .dialer .called != tt .expect_dialed {
150198 t .Errorf ("dialer.called = %v, want %v" , tt .dialer .called , tt .expect_dialed )
151199 }
152-
153200 if tt .authorizer .called != tt .expect_authorized {
154201 t .Errorf ("authorizer.called = %v, want %v" , tt .authorizer .called , tt .expect_authorized )
155202 }
156-
157- if tt .expect_dialed {
158- if tt .dialer .received .Url != tt .input .Url {
159- t .Errorf ("dialer.received.Url = %v, want %v" , tt .dialer .received .Url , tt .input .Url )
160- }
161-
162- expected_last := tt .generator .result
163- got := tt .dialer .received .Subprotocols
164- if len (got ) == 0 || got [len (got )- 1 ] != expected_last {
165- t .Errorf ("dialer subprotocols last element = %v, want %q" , got , expected_last )
166- }
167-
168- for _ , p := range tt .input .Subprotocols {
169- found := false
170- for _ , g := range got {
171- if g == p {
172- found = true
173- break
174- }
175- }
176- if ! found {
177- t .Errorf ("input subprotocol %q missing from dialer received %v" , p , got )
178- }
179- }
180- }
181-
182- if err == nil {
183- if output == nil {
184- t .Fatal ("expected non-nil output" )
185- }
186- if output .Connection != conn {
187- t .Errorf ("output.Connection = %v, want %v" , output .Connection , conn )
188- }
189- if output .Timeout != tt .authorizer .timeout {
190- t .Errorf ("output.Timeout = %v, want %v" , output .Timeout , tt .authorizer .timeout )
191- }
203+ if tt .verify != nil {
204+ tt .verify (t , tt .dialer , output )
192205 }
193206 })
194207 }
0 commit comments