Skip to content

Commit 959cb7e

Browse files
committed
librdmacm: Fix SOCK_STREAM and SOCK_DGRAM types
Updated type checks to identify socket types even when additional flags are present in the type field. Changed the comparison to use bitwise AND for more accurate detection. Signed-off-by: Batsheva Black <bblack@nvidia.com>
1 parent 5001f0b commit 959cb7e

1 file changed

Lines changed: 36 additions & 32 deletions

File tree

librdmacm/rsocket.c

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ struct ds_qp {
319319

320320
struct rsocket {
321321
int type;
322+
int category;
322323
int index;
323324
fastlock_t slock;
324325
fastlock_t rlock;
@@ -692,7 +693,7 @@ static void rs_remove(struct rsocket *rs)
692693
}
693694

694695
/* We only inherit from listening sockets */
695-
static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
696+
static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type, int category)
696697
{
697698
struct rsocket *rs;
698699

@@ -701,8 +702,10 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
701702
return NULL;
702703

703704
rs->type = type;
705+
rs->category = category;
706+
704707
rs->index = -1;
705-
if (type == SOCK_DGRAM) {
708+
if (category == SOCK_DGRAM) {
706709
rs->udp_sock = -1;
707710
rs->epfd = -1;
708711
}
@@ -713,7 +716,7 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
713716
rs->sq_inline = inherited_rs->sq_inline;
714717
rs->sq_size = inherited_rs->sq_size;
715718
rs->rq_size = inherited_rs->rq_size;
716-
if (type == SOCK_STREAM) {
719+
if (category == SOCK_STREAM) {
717720
rs->ctrl_max_seqno = inherited_rs->ctrl_max_seqno;
718721
rs->target_iomap_size = inherited_rs->target_iomap_size;
719722
}
@@ -723,7 +726,7 @@ static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
723726
rs->sq_inline = def_inline;
724727
rs->sq_size = def_sqsize;
725728
rs->rq_size = def_rqsize;
726-
if (type == SOCK_STREAM) {
729+
if (category == SOCK_STREAM) {
727730
rs->ctrl_max_seqno = RS_QP_CTRL_SIZE;
728731
rs->target_iomap_size = def_iomap_size;
729732
}
@@ -743,7 +746,7 @@ static int rs_set_nonblocking(struct rsocket *rs, int arg)
743746
struct ds_qp *qp;
744747
int ret = 0;
745748

746-
if (rs->type == SOCK_STREAM) {
749+
if (rs->category == SOCK_STREAM) {
747750
if (rs->cm_id->recv_cq_channel)
748751
ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg);
749752

@@ -1096,7 +1099,7 @@ static void ds_free(struct rsocket *rs)
10961099

10971100
static void rs_free(struct rsocket *rs)
10981101
{
1099-
if (rs->type == SOCK_DGRAM) {
1102+
if (rs->category == SOCK_DGRAM) {
11001103
ds_free(rs);
11011104
return;
11021105
}
@@ -1247,18 +1250,19 @@ int rsocket(int domain, int type, int protocol)
12471250
struct rsocket *rs;
12481251
int index, ret;
12491252

1253+
int category = type & ~(SOCK_CLOEXEC | SOCK_NONBLOCK);
12501254
if ((domain != AF_INET && domain != AF_INET6 && domain != AF_IB) ||
1251-
((type != SOCK_STREAM) && (type != SOCK_DGRAM)) ||
1252-
(type == SOCK_STREAM && protocol && protocol != IPPROTO_TCP) ||
1253-
(type == SOCK_DGRAM && protocol && protocol != IPPROTO_UDP))
1255+
((!(category == SOCK_STREAM)) && (!(category == SOCK_DGRAM))) ||
1256+
((category == SOCK_STREAM) && protocol && protocol != IPPROTO_TCP) ||
1257+
((category == SOCK_DGRAM) && protocol && protocol != IPPROTO_UDP))
12541258
return ERR(ENOTSUP);
12551259

12561260
rs_configure();
1257-
rs = rs_alloc(NULL, type);
1261+
rs = rs_alloc(NULL, type, category);
12581262
if (!rs)
12591263
return ERR(ENOMEM);
12601264

1261-
if (type == SOCK_STREAM) {
1265+
if (category == SOCK_STREAM) {
12621266
ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP);
12631267
if (ret)
12641268
goto err;
@@ -1292,7 +1296,7 @@ int rbind(int socket, const struct sockaddr *addr, socklen_t addrlen)
12921296
rs = idm_lookup(&idm, socket);
12931297
if (!rs)
12941298
return ERR(EBADF);
1295-
if (rs->type == SOCK_STREAM) {
1299+
if (rs->category == SOCK_STREAM) {
12961300
ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr);
12971301
if (!ret)
12981302
rs->state = rs_bound;
@@ -1358,7 +1362,7 @@ static void rs_accept(struct rsocket *rs)
13581362
if (ret)
13591363
return;
13601364

1361-
new_rs = rs_alloc(rs, rs->type);
1365+
new_rs = rs_alloc(rs, rs->type, rs->category);
13621366
if (!new_rs)
13631367
goto err;
13641368
new_rs->cm_id = cm_id;
@@ -1771,7 +1775,7 @@ int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen)
17711775
rs = idm_lookup(&idm, socket);
17721776
if (!rs)
17731777
return ERR(EBADF);
1774-
if (rs->type == SOCK_STREAM) {
1778+
if (rs->category == SOCK_STREAM) {
17751779
memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen);
17761780
ret = rs_do_connect(rs);
17771781
save_errno = errno;
@@ -2573,7 +2577,7 @@ ssize_t rrecv(int socket, void *buf, size_t len, int flags)
25732577
rs = idm_at(&idm, socket);
25742578
if (!rs)
25752579
return ERR(EBADF);
2576-
if (rs->type == SOCK_DGRAM) {
2580+
if (rs->category == SOCK_DGRAM) {
25772581
fastlock_acquire(&rs->rlock);
25782582
ret = ds_recvfrom(rs, buf, len, flags, NULL, NULL);
25792583
fastlock_release(&rs->rlock);
@@ -2643,7 +2647,7 @@ ssize_t rrecvfrom(int socket, void *buf, size_t len, int flags,
26432647
rs = idm_at(&idm, socket);
26442648
if (!rs)
26452649
return ERR(EBADF);
2646-
if (rs->type == SOCK_DGRAM) {
2650+
if (rs->category == SOCK_DGRAM) {
26472651
fastlock_acquire(&rs->rlock);
26482652
ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen);
26492653
fastlock_release(&rs->rlock);
@@ -2848,7 +2852,7 @@ ssize_t rsend(int socket, const void *buf, size_t len, int flags)
28482852
rs = idm_at(&idm, socket);
28492853
if (!rs)
28502854
return ERR(EBADF);
2851-
if (rs->type == SOCK_DGRAM) {
2855+
if (rs->category == SOCK_DGRAM) {
28522856
fastlock_acquire(&rs->slock);
28532857
ret = dsend(rs, buf, len, flags);
28542858
fastlock_release(&rs->slock);
@@ -2935,7 +2939,7 @@ ssize_t rsendto(int socket, const void *buf, size_t len, int flags,
29352939
rs = idm_at(&idm, socket);
29362940
if (!rs)
29372941
return ERR(EBADF);
2938-
if (rs->type == SOCK_STREAM) {
2942+
if (rs->category == SOCK_STREAM) {
29392943
if (dest_addr || addrlen)
29402944
return ERR(EISCONN);
29412945

@@ -3244,7 +3248,7 @@ static int rs_poll_rs(struct rsocket *rs, int events,
32443248
int ret;
32453249

32463250
check_cq:
3247-
if ((rs->type == SOCK_STREAM) && ((rs->state & rs_connected) ||
3251+
if ((rs->category == SOCK_STREAM) && ((rs->state & rs_connected) ||
32483252
(rs->state == rs_disconnected) || (rs->state & rs_error))) {
32493253
rs_process_cq(rs, nonblock, test);
32503254

@@ -3261,7 +3265,7 @@ static int rs_poll_rs(struct rsocket *rs, int events,
32613265
}
32623266

32633267
return revents;
3264-
} else if (rs->type == SOCK_DGRAM) {
3268+
} else if (rs->category == SOCK_DGRAM) {
32653269
ds_process_cqs(rs, nonblock, test);
32663270

32673271
revents = 0;
@@ -3333,7 +3337,7 @@ static int rs_poll_arm(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
33333337
if (fds[i].revents)
33343338
return 1;
33353339

3336-
if (rs->type == SOCK_STREAM) {
3340+
if (rs->category == SOCK_STREAM) {
33373341
if (rs->state >= rs_connected)
33383342
rfds[i].fd = rs->cm_id->recv_cq_channel->fd;
33393343
else
@@ -3361,7 +3365,7 @@ static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
33613365
if (rs) {
33623366
if (rfds[i].revents) {
33633367
fastlock_acquire(&rs->cq_wait_lock);
3364-
if (rs->type == SOCK_STREAM)
3368+
if (rs->category == SOCK_STREAM)
33653369
rs_get_cq_event(rs);
33663370
else
33673371
ds_get_cq_event(rs);
@@ -3609,7 +3613,7 @@ int rclose(int socket)
36093613
rs = idm_lookup(&idm, socket);
36103614
if (!rs)
36113615
return EBADF;
3612-
if (rs->type == SOCK_STREAM) {
3616+
if (rs->category == SOCK_STREAM) {
36133617
if (rs->state & rs_connected)
36143618
rshutdown(socket, SHUT_RDWR);
36153619
if (rs->opts & RS_OPT_KEEPALIVE)
@@ -3647,7 +3651,7 @@ int rgetpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
36473651
rs = idm_lookup(&idm, socket);
36483652
if (!rs)
36493653
return ERR(EBADF);
3650-
if (rs->type == SOCK_STREAM) {
3654+
if (rs->category == SOCK_STREAM) {
36513655
rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen);
36523656
return 0;
36533657
} else {
@@ -3662,7 +3666,7 @@ int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
36623666
rs = idm_lookup(&idm, socket);
36633667
if (!rs)
36643668
return ERR(EBADF);
3665-
if (rs->type == SOCK_STREAM) {
3669+
if (rs->category == SOCK_STREAM) {
36663670
rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen);
36673671
return 0;
36683672
} else {
@@ -3708,7 +3712,7 @@ int rsetsockopt(int socket, int level, int optname,
37083712
rs = idm_lookup(&idm, socket);
37093713
if (!rs)
37103714
return ERR(EBADF);
3711-
if (rs->type == SOCK_DGRAM && level != SOL_RDMA) {
3715+
if ((rs->category == SOCK_DGRAM) && level != SOL_RDMA) {
37123716
ret = setsockopt(rs->udp_sock, level, optname, optval, optlen);
37133717
if (ret)
37143718
return ret;
@@ -3719,7 +3723,7 @@ int rsetsockopt(int socket, int level, int optname,
37193723
opts = &rs->so_opts;
37203724
switch (optname) {
37213725
case SO_REUSEADDR:
3722-
if (rs->type == SOCK_STREAM) {
3726+
if (rs->category == SOCK_STREAM) {
37233727
ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID,
37243728
RDMA_OPTION_ID_REUSEADDR,
37253729
(void *) optval, optlen);
@@ -3731,8 +3735,8 @@ int rsetsockopt(int socket, int level, int optname,
37313735
opt_on = *(int *) optval;
37323736
break;
37333737
case SO_RCVBUF:
3734-
if ((rs->type == SOCK_STREAM && !rs->rbuf) ||
3735-
(rs->type == SOCK_DGRAM && !rs->qp_list))
3738+
if (((rs->category == SOCK_STREAM) && !rs->rbuf) ||
3739+
((rs->category == SOCK_DGRAM) && !rs->qp_list))
37363740
rs->rbuf_size = (*(uint32_t *) optval) << 1;
37373741
ret = 0;
37383742
break;
@@ -3804,7 +3808,7 @@ int rsetsockopt(int socket, int level, int optname,
38043808
opts = &rs->ipv6_opts;
38053809
switch (optname) {
38063810
case IPV6_V6ONLY:
3807-
if (rs->type == SOCK_STREAM) {
3811+
if (rs->category == SOCK_STREAM) {
38083812
ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID,
38093813
RDMA_OPTION_ID_AFONLY,
38103814
(void *) optval, optlen);
@@ -4806,7 +4810,7 @@ uint32_t epoll_rs(int fd, uint32_t events)
48064810
struct rsocket *rs = idm_lookup(&idm, fd);
48074811

48084812
check_cq:
4809-
if ((rs->type & SOCK_STREAM) && ((rs->state & rs_connected) ||
4813+
if ((rs->category == SOCK_STREAM) && ((rs->state & rs_connected) ||
48104814
(rs->state == rs_disconnected) || (rs->state & rs_error))) {
48114815
rs_process_cq(rs, 1, rs_poll_all);
48124816

@@ -4822,7 +4826,7 @@ uint32_t epoll_rs(int fd, uint32_t events)
48224826
}
48234827

48244828
return revents;
4825-
} else if (rs->type & SOCK_DGRAM) {
4829+
} else if (rs->category == SOCK_DGRAM) {
48264830
ds_process_cqs(rs, 1, rs_poll_all);
48274831

48284832
if ((events & EPOLLIN) && rs_have_rdata(rs))

0 commit comments

Comments
 (0)