Skip to content

Commit 47cda22

Browse files
committed
librdmacm: Fix SOCK_STREAM and SOCK_DGRAM types
Updated type checks to identify socket types even when additional flags are present in the type field. Changed the comparison to use bitwise AND for more accurate detection. Signed-off-by: Batsheva Black <bblack@nvidia.com>
1 parent 41696fc commit 47cda22

File tree

1 file changed

+36
-32
lines changed

1 file changed

+36
-32
lines changed

librdmacm/rsocket.c

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ struct ds_qp {
319319

320320
struct rsocket {
321321
int type;
322+
int category;
322323
int index;
323324
fastlock_t slock;
324325
fastlock_t rlock;
@@ -693,7 +694,7 @@ static void rs_remove(struct rsocket *rs)
693694
}
694695

695696
/* We only inherit from listening sockets */
696-
static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
697+
static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type, int category)
697698
{
698699
struct rsocket *rs;
699700

@@ -702,8 +703,10 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
702703
return NULL;
703704

704705
rs->type = type;
706+
rs->category = category;
707+
705708
rs->index = -1;
706-
if (type == SOCK_DGRAM) {
709+
if (category == SOCK_DGRAM) {
707710
rs->udp_sock = -1;
708711
rs->epfd = -1;
709712
}
@@ -714,7 +717,7 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
714717
rs->sq_inline = inherited_rs->sq_inline;
715718
rs->sq_size = inherited_rs->sq_size;
716719
rs->rq_size = inherited_rs->rq_size;
717-
if (type == SOCK_STREAM) {
720+
if (category == SOCK_STREAM) {
718721
rs->ctrl_max_seqno = inherited_rs->ctrl_max_seqno;
719722
rs->target_iomap_size = inherited_rs->target_iomap_size;
720723
}
@@ -724,7 +727,7 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
724727
rs->sq_inline = def_inline;
725728
rs->sq_size = def_sqsize;
726729
rs->rq_size = def_rqsize;
727-
if (type == SOCK_STREAM) {
730+
if (category == SOCK_STREAM) {
728731
rs->ctrl_max_seqno = RS_QP_CTRL_SIZE;
729732
rs->target_iomap_size = def_iomap_size;
730733
}
@@ -744,7 +747,7 @@ static int rs_set_nonblocking(struct rsocket *rs, int arg)
744747
struct ds_qp *qp;
745748
int ret = 0;
746749

747-
if (rs->type == SOCK_STREAM) {
750+
if (rs->category == SOCK_STREAM) {
748751
if (rs->cm_id->recv_cq_channel)
749752
ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg);
750753

@@ -1097,7 +1100,7 @@ static void ds_free(struct rsocket *rs)
10971100

10981101
static void rs_free(struct rsocket *rs)
10991102
{
1100-
if (rs->type == SOCK_DGRAM) {
1103+
if (rs->category == SOCK_DGRAM) {
11011104
ds_free(rs);
11021105
return;
11031106
}
@@ -1248,18 +1251,19 @@ int rsocket(int domain, int type, int protocol)
12481251
struct rsocket *rs;
12491252
int index, ret;
12501253

1254+
int category = type & ~(SOCK_CLOEXEC | SOCK_NONBLOCK);
12511255
if ((domain != AF_INET && domain != AF_INET6 && domain != AF_IB) ||
1252-
((type != SOCK_STREAM) && (type != SOCK_DGRAM)) ||
1253-
(type == SOCK_STREAM && protocol && protocol != IPPROTO_TCP) ||
1254-
(type == SOCK_DGRAM && protocol && protocol != IPPROTO_UDP))
1256+
((!(category == SOCK_STREAM)) && (!(category == SOCK_DGRAM))) ||
1257+
((category == SOCK_STREAM) && protocol && protocol != IPPROTO_TCP) ||
1258+
((category == SOCK_DGRAM) && protocol && protocol != IPPROTO_UDP))
12551259
return ERR(ENOTSUP);
12561260

12571261
rs_configure();
1258-
rs = rs_alloc(NULL, type);
1262+
rs = rs_alloc(NULL, type, category);
12591263
if (!rs)
12601264
return ERR(ENOMEM);
12611265

1262-
if (type == SOCK_STREAM) {
1266+
if (category == SOCK_STREAM) {
12631267
ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP);
12641268
if (ret)
12651269
goto err;
@@ -1293,7 +1297,7 @@ int rbind(int socket, const struct sockaddr *addr, socklen_t addrlen)
12931297
rs = idm_lookup(&idm, socket);
12941298
if (!rs)
12951299
return ERR(EBADF);
1296-
if (rs->type == SOCK_STREAM) {
1300+
if (rs->category == SOCK_STREAM) {
12971301
ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr);
12981302
if (!ret)
12991303
rs->state = rs_bound;
@@ -1359,7 +1363,7 @@ static void rs_accept(struct rsocket *rs)
13591363
if (ret)
13601364
return;
13611365

1362-
new_rs = rs_alloc(rs, rs->type);
1366+
new_rs = rs_alloc(rs, rs->type, rs->category);
13631367
if (!new_rs)
13641368
goto err;
13651369
new_rs->cm_id = cm_id;
@@ -1772,7 +1776,7 @@ int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen)
17721776
rs = idm_lookup(&idm, socket);
17731777
if (!rs)
17741778
return ERR(EBADF);
1775-
if (rs->type == SOCK_STREAM) {
1779+
if (rs->category == SOCK_STREAM) {
17761780
memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen);
17771781
ret = rs_do_connect(rs);
17781782
save_errno = errno;
@@ -2575,7 +2579,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags)
25752579
rs = idm_at(&idm, socket);
25762580
if (!rs)
25772581
return ERR(EBADF);
2578-
if (rs->type == SOCK_DGRAM) {
2582+
if (rs->category == SOCK_DGRAM) {
25792583
fastlock_acquire(&rs->rlock);
25802584
ret = ds_recvfrom(rs, buf, len, flags, NULL, NULL);
25812585
fastlock_release(&rs->rlock);
@@ -2645,7 +2649,7 @@ ssize_t rrecvfrom(int socket, void *buf, size_t len, int flags,
26452649
rs = idm_at(&idm, socket);
26462650
if (!rs)
26472651
return ERR(EBADF);
2648-
if (rs->type == SOCK_DGRAM) {
2652+
if (rs->category == SOCK_DGRAM) {
26492653
fastlock_acquire(&rs->rlock);
26502654
ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen);
26512655
fastlock_release(&rs->rlock);
@@ -2850,7 +2854,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags)
28502854
rs = idm_at(&idm, socket);
28512855
if (!rs)
28522856
return ERR(EBADF);
2853-
if (rs->type == SOCK_DGRAM) {
2857+
if (rs->category == SOCK_DGRAM) {
28542858
fastlock_acquire(&rs->slock);
28552859
ret = dsend(rs, buf, len, flags);
28562860
fastlock_release(&rs->slock);
@@ -2937,7 +2941,7 @@ ssize_t rsendto(int socket, const void *buf, size_t len, int flags,
29372941
rs = idm_at(&idm, socket);
29382942
if (!rs)
29392943
return ERR(EBADF);
2940-
if (rs->type == SOCK_STREAM) {
2944+
if (rs->category == SOCK_STREAM) {
29412945
if (dest_addr || addrlen)
29422946
return ERR(EISCONN);
29432947

@@ -3246,7 +3250,7 @@ static int rs_poll_rs(struct rsocket *rs, int events,
32463250
int ret;
32473251

32483252
check_cq:
3249-
if ((rs->type == SOCK_STREAM) && ((rs->state & rs_connected) ||
3253+
if ((rs->category == SOCK_STREAM) && ((rs->state & rs_connected) ||
32503254
(rs->state == rs_disconnected) || (rs->state & rs_error))) {
32513255
rs_process_cq(rs, nonblock, test);
32523256

@@ -3263,7 +3267,7 @@ static int rs_poll_rs(struct rsocket *rs, int events,
32633267
}
32643268

32653269
return revents;
3266-
} else if (rs->type == SOCK_DGRAM) {
3270+
} else if (rs->category == SOCK_DGRAM) {
32673271
ds_process_cqs(rs, nonblock, test);
32683272

32693273
revents = 0;
@@ -3335,7 +3339,7 @@ static int rs_poll_arm(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
33353339
if (fds[i].revents)
33363340
return 1;
33373341

3338-
if (rs->type == SOCK_STREAM) {
3342+
if (rs->category == SOCK_STREAM) {
33393343
if (rs->state >= rs_connected)
33403344
rfds[i].fd = rs->cm_id->recv_cq_channel->fd;
33413345
else
@@ -3363,7 +3367,7 @@ static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
33633367
if (rs) {
33643368
if (rfds[i].revents) {
33653369
fastlock_acquire(&rs->cq_wait_lock);
3366-
if (rs->type == SOCK_STREAM)
3370+
if (rs->category == SOCK_STREAM)
33673371
rs_get_cq_event(rs);
33683372
else
33693373
ds_get_cq_event(rs);
@@ -3611,7 +3615,7 @@ int rclose(int socket)
36113615
rs = idm_lookup(&idm, socket);
36123616
if (!rs)
36133617
return EBADF;
3614-
if (rs->type == SOCK_STREAM) {
3618+
if (rs->category == SOCK_STREAM) {
36153619
if (rs->state & rs_connected)
36163620
rshutdown(socket, SHUT_RDWR);
36173621
if (rs->opts & RS_OPT_KEEPALIVE)
@@ -3649,7 +3653,7 @@ int rgetpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
36493653
rs = idm_lookup(&idm, socket);
36503654
if (!rs)
36513655
return ERR(EBADF);
3652-
if (rs->type == SOCK_STREAM) {
3656+
if (rs->category == SOCK_STREAM) {
36533657
rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen);
36543658
return 0;
36553659
} else {
@@ -3664,7 +3668,7 @@ int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
36643668
rs = idm_lookup(&idm, socket);
36653669
if (!rs)
36663670
return ERR(EBADF);
3667-
if (rs->type == SOCK_STREAM) {
3671+
if (rs->category == SOCK_STREAM) {
36683672
rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen);
36693673
return 0;
36703674
} else {
@@ -3710,7 +3714,7 @@ int rsetsockopt(int socket, int level, int optname,
37103714
rs = idm_lookup(&idm, socket);
37113715
if (!rs)
37123716
return ERR(EBADF);
3713-
if (rs->type == SOCK_DGRAM && level != SOL_RDMA) {
3717+
if ((rs->category == SOCK_DGRAM) && level != SOL_RDMA) {
37143718
ret = setsockopt(rs->udp_sock, level, optname, optval, optlen);
37153719
if (ret)
37163720
return ret;
@@ -3721,7 +3725,7 @@ int rsetsockopt(int socket, int level, int optname,
37213725
opts = &rs->so_opts;
37223726
switch (optname) {
37233727
case SO_REUSEADDR:
3724-
if (rs->type == SOCK_STREAM) {
3728+
if (rs->category == SOCK_STREAM) {
37253729
ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID,
37263730
RDMA_OPTION_ID_REUSEADDR,
37273731
(void *) optval, optlen);
@@ -3733,8 +3737,8 @@ int rsetsockopt(int socket, int level, int optname,
37333737
opt_on = *(int *) optval;
37343738
break;
37353739
case SO_RCVBUF:
3736-
if ((rs->type == SOCK_STREAM && !rs->rbuf) ||
3737-
(rs->type == SOCK_DGRAM && !rs->qp_list))
3740+
if (((rs->category == SOCK_STREAM) && !rs->rbuf) ||
3741+
((rs->category == SOCK_DGRAM) && !rs->qp_list))
37383742
rs->rbuf_size = (*(uint32_t *) optval) << 1;
37393743
ret = 0;
37403744
break;
@@ -3806,7 +3810,7 @@ int rsetsockopt(int socket, int level, int optname,
38063810
opts = &rs->ipv6_opts;
38073811
switch (optname) {
38083812
case IPV6_V6ONLY:
3809-
if (rs->type == SOCK_STREAM) {
3813+
if (rs->category == SOCK_STREAM) {
38103814
ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID,
38113815
RDMA_OPTION_ID_AFONLY,
38123816
(void *) optval, optlen);
@@ -4808,7 +4812,7 @@ uint32_t epoll_rs(int fd, uint32_t events)
48084812
struct rsocket *rs = idm_lookup(&idm, fd);
48094813

48104814
check_cq:
4811-
if ((rs->type & SOCK_STREAM) && ((rs->state & rs_connected) ||
4815+
if ((rs->category == SOCK_STREAM) && ((rs->state & rs_connected) ||
48124816
(rs->state == rs_disconnected) || (rs->state & rs_error))) {
48134817
rs_process_cq(rs, 1, rs_poll_all);
48144818

@@ -4824,7 +4828,7 @@ uint32_t epoll_rs(int fd, uint32_t events)
48244828
}
48254829

48264830
return revents;
4827-
} else if (rs->type & SOCK_DGRAM) {
4831+
} else if (rs->category == SOCK_DGRAM) {
48284832
ds_process_cqs(rs, 1, rs_poll_all);
48294833

48304834
if ((events & EPOLLIN) && rs_have_rdata(rs))

0 commit comments

Comments
 (0)