Browse code

dco_linux: fix async message reception

Currently whenever we send a PEER_GET request to ovpn, we also
set the CB that is supposed to parse the reply.

However, due to the async nature of netlink messages, we could
get an unrelated notification, sent by ovpn (kernel) upon some event,
after userland has set the CB, but before parsing the awaited reply.

When this happens, the notification is then parsed with the
configured CB instead of the notification parser, thus effectively
rejecting the notification and losing the event.

To fix this inconsistency, make ovpn_handle_msg() the default and
only netlink parser CB. It is configured upon DCO initialization
and is never removed.

ovpn_handle_msg() will check the message type and will call the
corresponding handler. This way, no matter what message we get at
what time, we'll always parse it correctly.

As a bonus we can also simplify the nl_sendmsg() API as we
don't need to pass the cb and its argument anymore.

The ID of the NLCTRL family is now also stored in the DCO
context as we need it to check when we receive a mcast ID
lookup message.

Change-Id: I23ad79e14844aefde9ece34dadef0b75ff267201
Github: closes OpenVPN/openvpn#793
Signed-off-by: Antonio Quartulli <antonio@mandelbit.com>
Acked-by: Gert Doering <gert@greenie.muc.de>
Message-Id: <20250725172708.19456-1-gert@greenie.muc.de>
URL: https://www.mail-archive.com/openvpn-devel@lists.sourceforge.net/msg32339.html
Signed-off-by: Gert Doering <gert@greenie.muc.de>

Antonio Quartulli authored on 2025/07/26 02:27:02
Showing 2 changed files
... ...
@@ -167,23 +167,19 @@ ovpn_nl_recvmsgs(dco_context_t *dco, const char *prefix)
167 167
 }
168 168
 
169 169
 /**
170
- * Send a prepared netlink message and registers cb as callback if non-null.
170
+ * Send a prepared netlink message.
171 171
  *
172 172
  * The method will also free nl_msg
173 173
  * @param dco       The dco context to use
174 174
  * @param nl_msg    the message to use
175
- * @param cb        An optional callback if the caller expects an answer
176
- * @param cb_arg    An optional param to pass to the callback
177 175
  * @param prefix    A prefix to report in the error message to give the user context
178 176
  * @return          status of sending the message
179 177
  */
180 178
 static int
181
-ovpn_nl_msg_send(dco_context_t *dco, struct nl_msg *nl_msg, ovpn_nl_cb cb,
182
-                 void *cb_arg, const char *prefix)
179
+ovpn_nl_msg_send(dco_context_t *dco, struct nl_msg *nl_msg, const char *prefix)
183 180
 {
184 181
     dco->status = 1;
185 182
 
186
-    nl_cb_set(dco->nl_cb, NL_CB_VALID, NL_CB_CUSTOM, cb, cb_arg);
187 183
     nl_send_auto(dco->nl_sock, nl_msg);
188 184
 
189 185
     while (dco->status == 1)
... ...
@@ -285,7 +281,7 @@ dco_new_peer(dco_context_t *dco, unsigned int peerid, int sd,
285 285
     }
286 286
     nla_nest_end(nl_msg, attr);
287 287
 
288
-    ret = ovpn_nl_msg_send(dco, nl_msg, NULL, NULL, __func__);
288
+    ret = ovpn_nl_msg_send(dco, nl_msg, __func__);
289 289
 
290 290
 nla_put_failure:
291 291
     nlmsg_free(nl_msg);
... ...
@@ -385,6 +381,29 @@ ovpn_nl_cb_error(struct sockaddr_nl (*nla) __attribute__ ((unused)),
385 385
 }
386 386
 
387 387
 static void
388
+ovpn_dco_register(dco_context_t *dco)
389
+{
390
+    msg(D_DCO_DEBUG, __func__);
391
+    ovpn_get_mcast_id(dco);
392
+
393
+    if (dco->ovpn_dco_mcast_id < 0)
394
+    {
395
+        msg(M_FATAL, "cannot get mcast group: %s",  nl_geterror(dco->ovpn_dco_mcast_id));
396
+    }
397
+
398
+    /* Register for ovpn-dco specific multicast messages that the kernel may
399
+     * send
400
+     */
401
+    int ret = nl_socket_add_membership(dco->nl_sock, dco->ovpn_dco_mcast_id);
402
+    if (ret)
403
+    {
404
+        msg(M_FATAL, "%s: failed to join groups: %d", __func__, ret);
405
+    }
406
+}
407
+
408
+static int ovpn_handle_msg(struct nl_msg *msg, void *arg);
409
+
410
+static void
388 411
 ovpn_dco_init_netlink(dco_context_t *dco)
389 412
 {
390 413
     dco->ovpn_dco_id = resolve_ovpn_netlink_id(M_FATAL);
... ...
@@ -420,11 +439,15 @@ ovpn_dco_init_netlink(dco_context_t *dco)
420 420
 
421 421
     nl_socket_set_cb(dco->nl_sock, dco->nl_cb);
422 422
 
423
+    dco->dco_message_peer_id = -1;
423 424
     nl_cb_err(dco->nl_cb, NL_CB_CUSTOM, ovpn_nl_cb_error, &dco->status);
424 425
     nl_cb_set(dco->nl_cb, NL_CB_FINISH, NL_CB_CUSTOM, ovpn_nl_cb_finish,
425 426
               &dco->status);
426 427
     nl_cb_set(dco->nl_cb, NL_CB_ACK, NL_CB_CUSTOM, ovpn_nl_cb_finish,
427 428
               &dco->status);
429
+    nl_cb_set(dco->nl_cb, NL_CB_VALID, NL_CB_CUSTOM, ovpn_handle_msg, dco);
430
+
431
+    ovpn_dco_register(dco);
428 432
 
429 433
     /* The async PACKET messages confuse libnl and it will drop them with
430 434
      * wrong sequence numbers (NLE_SEQ_MISMATCH), so disable libnl's sequence
... ...
@@ -476,27 +499,6 @@ ovpn_dco_uninit_netlink(dco_context_t *dco)
476 476
     CLEAR(dco);
477 477
 }
478 478
 
479
-static void
480
-ovpn_dco_register(dco_context_t *dco)
481
-{
482
-    msg(D_DCO_DEBUG, __func__);
483
-    ovpn_get_mcast_id(dco);
484
-
485
-    if (dco->ovpn_dco_mcast_id < 0)
486
-    {
487
-        msg(M_FATAL, "cannot get mcast group: %s",  nl_geterror(dco->ovpn_dco_mcast_id));
488
-    }
489
-
490
-    /* Register for ovpn-dco specific multicast messages that the kernel may
491
-     * send
492
-     */
493
-    int ret = nl_socket_add_membership(dco->nl_sock, dco->ovpn_dco_mcast_id);
494
-    if (ret)
495
-    {
496
-        msg(M_FATAL, "%s: failed to join groups: %d", __func__, ret);
497
-    }
498
-}
499
-
500 479
 int
501 480
 open_tun_dco(struct tuntap *tt, openvpn_net_ctx_t *ctx, const char *dev)
502 481
 {
... ...
@@ -516,10 +518,6 @@ open_tun_dco(struct tuntap *tt, openvpn_net_ctx_t *ctx, const char *dev)
516 516
         msg(M_FATAL, "DCO: cannot retrieve ifindex for interface %s", dev);
517 517
     }
518 518
 
519
-    tt->dco.dco_message_peer_id = -1;
520
-
521
-    ovpn_dco_register(&tt->dco);
522
-
523 519
     return 0;
524 520
 }
525 521
 
... ...
@@ -548,7 +546,7 @@ dco_swap_keys(dco_context_t *dco, unsigned int peerid)
548 548
     NLA_PUT_U32(nl_msg, OVPN_A_KEYCONF_PEER_ID, peerid);
549 549
     nla_nest_end(nl_msg, attr);
550 550
 
551
-    ret = ovpn_nl_msg_send(dco, nl_msg, NULL, NULL, __func__);
551
+    ret = ovpn_nl_msg_send(dco, nl_msg, __func__);
552 552
 
553 553
 nla_put_failure:
554 554
     nlmsg_free(nl_msg);
... ...
@@ -572,7 +570,7 @@ dco_del_peer(dco_context_t *dco, unsigned int peerid)
572 572
     NLA_PUT_U32(nl_msg, OVPN_A_PEER_ID, peerid);
573 573
     nla_nest_end(nl_msg, attr);
574 574
 
575
-    ret = ovpn_nl_msg_send(dco, nl_msg, NULL, NULL, __func__);
575
+    ret = ovpn_nl_msg_send(dco, nl_msg, __func__);
576 576
 
577 577
 nla_put_failure:
578 578
     nlmsg_free(nl_msg);
... ...
@@ -598,7 +596,7 @@ dco_del_key(dco_context_t *dco, unsigned int peerid,
598 598
     NLA_PUT_U32(nl_msg, OVPN_A_KEYCONF_SLOT, slot);
599 599
     nla_nest_end(nl_msg, keyconf);
600 600
 
601
-    ret = ovpn_nl_msg_send(dco, nl_msg, NULL, NULL, __func__);
601
+    ret = ovpn_nl_msg_send(dco, nl_msg, __func__);
602 602
 
603 603
 nla_put_failure:
604 604
     nlmsg_free(nl_msg);
... ...
@@ -657,7 +655,7 @@ dco_new_key(dco_context_t *dco, unsigned int peerid, int keyid,
657 657
     nla_nest_end(nl_msg, key_conf);
658 658
 
659 659
 
660
-    ret = ovpn_nl_msg_send(dco, nl_msg, NULL, NULL, __func__);
660
+    ret = ovpn_nl_msg_send(dco, nl_msg, __func__);
661 661
 
662 662
 nla_put_failure:
663 663
     nlmsg_free(nl_msg);
... ...
@@ -686,7 +684,7 @@ dco_set_peer(dco_context_t *dco, unsigned int peerid,
686 686
                 keepalive_timeout);
687 687
     nla_nest_end(nl_msg, attr);
688 688
 
689
-    ret = ovpn_nl_msg_send(dco, nl_msg, NULL, NULL, __func__);
689
+    ret = ovpn_nl_msg_send(dco, nl_msg, __func__);
690 690
 
691 691
 nla_put_failure:
692 692
     nlmsg_free(nl_msg);
... ...
@@ -754,7 +752,7 @@ ovpn_get_mcast_id(dco_context_t *dco)
754 754
 
755 755
     /* Even though 'nlctrl' is a constant, there seem to be no library
756 756
      * provided define for it */
757
-    int ctrlid = genl_ctrl_resolve(dco->nl_sock, "nlctrl");
757
+    dco->ctrlid = genl_ctrl_resolve(dco->nl_sock, "nlctrl");
758 758
 
759 759
     struct nl_msg *nl_msg = nlmsg_alloc();
760 760
     if (!nl_msg)
... ...
@@ -762,12 +760,12 @@ ovpn_get_mcast_id(dco_context_t *dco)
762 762
         return -ENOMEM;
763 763
     }
764 764
 
765
-    genlmsg_put(nl_msg, 0, 0, ctrlid, 0, 0, CTRL_CMD_GETFAMILY, 0);
765
+    genlmsg_put(nl_msg, 0, 0, dco->ctrlid, 0, 0, CTRL_CMD_GETFAMILY, 0);
766 766
 
767 767
     int ret = -EMSGSIZE;
768 768
     NLA_PUT_STRING(nl_msg, CTRL_ATTR_FAMILY_NAME, OVPN_FAMILY_NAME);
769 769
 
770
-    ret = ovpn_nl_msg_send(dco, nl_msg, mcast_family_handler, dco, __func__);
770
+    ret = ovpn_nl_msg_send(dco, nl_msg, __func__);
771 771
 
772 772
 nla_put_failure:
773 773
     nlmsg_free(nl_msg);
... ...
@@ -879,31 +877,34 @@ dco_update_peer_stat(struct context_2 *c2, struct nlattr *tb[], uint32_t id)
879 879
 }
880 880
 
881 881
 static int
882
-dco_parse_peer_multi(struct nl_msg *msg, void *arg)
882
+ovpn_handle_peer_multi(dco_context_t *dco, struct nlattr *attrs[])
883 883
 {
884
-    struct nlattr *tb[OVPN_A_MAX + 1];
885
-    struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg));
886
-
887 884
     msg(D_DCO_DEBUG, "%s: parsing message...", __func__);
888 885
 
889
-    nla_parse(tb, OVPN_A_MAX, genlmsg_attrdata(gnlh, 0),
890
-              genlmsg_attrlen(gnlh, 0), NULL);
886
+    /* this function assumes openvpn is running in multipeer mode as
887
+     * it accesses c->multi
888
+     */
889
+    if (dco->ifmode != OVPN_MODE_MP)
890
+    {
891
+        msg(M_WARN, "%s: can't parse 'multi-peer' message on P2P instance", __func__);
892
+        return NL_SKIP;
893
+    }
891 894
 
892
-    if (!tb[OVPN_A_PEER])
895
+    if (!attrs[OVPN_A_PEER])
893 896
     {
894 897
         return NL_SKIP;
895 898
     }
896 899
 
897 900
     struct nlattr *tb_peer[OVPN_A_PEER_MAX + 1];
898
-    nla_parse_nested(tb_peer, OVPN_A_PEER_MAX, tb[OVPN_A_PEER], NULL);
901
+    nla_parse_nested(tb_peer, OVPN_A_PEER_MAX, attrs[OVPN_A_PEER], NULL);
899 902
 
900 903
     if (!tb_peer[OVPN_A_PEER_ID])
901 904
     {
902
-        msg(M_WARN, "%s: no peer-id provided in reply", __func__);
905
+        msg(M_WARN, "ovpn-dco: no peer-id provided in (MULTI) PEER_GET reply");
903 906
         return NL_SKIP;
904 907
     }
905 908
 
906
-    struct multi_context *m = arg;
909
+    struct multi_context *m = dco->c->multi;
907 910
     uint32_t peer_id = nla_get_u32(tb_peer[OVPN_A_PEER_ID]);
908 911
 
909 912
     if (peer_id >= m->max_clients || !m->instances[peer_id])
... ...
@@ -919,39 +920,53 @@ dco_parse_peer_multi(struct nl_msg *msg, void *arg)
919 919
 }
920 920
 
921 921
 static int
922
-dco_parse_peer(struct nl_msg *msg, void *arg)
922
+ovpn_handle_peer(dco_context_t *dco, struct nlattr *attrs[])
923 923
 {
924
-    struct context *c = arg;
925
-    struct nlattr *tb[OVPN_A_MAX + 1];
926
-    struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg));
927
-
928 924
     msg(D_DCO_DEBUG, "%s: parsing message...", __func__);
929 925
 
930
-    nla_parse(tb, OVPN_A_MAX, genlmsg_attrdata(gnlh, 0),
931
-              genlmsg_attrlen(gnlh, 0), NULL);
932
-
933
-    if (!tb[OVPN_A_PEER])
926
+    if (!attrs[OVPN_A_PEER])
934 927
     {
935 928
         msg(D_DCO_DEBUG, "%s: malformed reply", __func__);
936 929
         return NL_SKIP;
937 930
     }
938 931
 
939 932
     struct nlattr *tb_peer[OVPN_A_PEER_MAX + 1];
940
-    nla_parse_nested(tb_peer, OVPN_A_PEER_MAX, tb[OVPN_A_PEER], NULL);
933
+    nla_parse_nested(tb_peer, OVPN_A_PEER_MAX, attrs[OVPN_A_PEER], NULL);
941 934
 
942 935
     if (!tb_peer[OVPN_A_PEER_ID])
943 936
     {
944
-        msg(M_WARN, "%s: no peer-id provided in reply", __func__);
937
+        msg(M_WARN, "ovpn-dco: no peer-id provided in PEER_GET reply");
945 938
         return NL_SKIP;
946 939
     }
947 940
 
948 941
     uint32_t peer_id = nla_get_u32(tb_peer[OVPN_A_PEER_ID]);
949
-    if (c->c2.tls_multi->dco_peer_id != peer_id)
942
+    struct context_2 *c2;
943
+
944
+    if (dco->ifmode == OVPN_MODE_P2P)
945
+    {
946
+        c2 = &dco->c->c2;
947
+    }
948
+    else
949
+    {
950
+        struct multi_instance *mi = dco->c->multi->instances[peer_id];
951
+        if (!mi)
952
+        {
953
+            msg(M_WARN, "%s: received data for a non-existing peer %u", __func__, peer_id);
954
+            return NL_SKIP;
955
+        }
956
+
957
+        c2 = &mi->context.c2;
958
+    }
959
+
960
+    /* at this point this check should never fail for MP mode,
961
+     * but it's still fully valid for P2P mode
962
+     */
963
+    if (c2->tls_multi->dco_peer_id != peer_id)
950 964
     {
951 965
         return NL_SKIP;
952 966
     }
953 967
 
954
-    dco_update_peer_stat(&c->c2, tb_peer, peer_id);
968
+    dco_update_peer_stat(c2, tb_peer, peer_id);
955 969
 
956 970
     return NL_OK;
957 971
 }
... ...
@@ -1120,9 +1135,22 @@ ovpn_handle_msg(struct nl_msg *msg, void *arg)
1120 1120
 {
1121 1121
     dco_context_t *dco = arg;
1122 1122
 
1123
-    struct genlmsghdr *gnlh = nlmsg_data(nlmsg_hdr(msg));
1124 1123
     struct nlattr *attrs[OVPN_A_MAX + 1];
1125 1124
     struct nlmsghdr *nlh = nlmsg_hdr(msg);
1125
+    struct genlmsghdr *gnlh = genlmsg_hdr(nlh);
1126
+
1127
+    msg(D_DCO_DEBUG, "ovpn-dco: received netlink message type=%u cmd=%u flags=%#.4x",
1128
+        nlh->nlmsg_type, gnlh->cmd, nlh->nlmsg_flags);
1129
+
1130
+    /* if we get a message from the NLCTRL family, it means
1131
+     * this is the reply to the mcast ID resolution request
1132
+     * and we parse it accordingly.
1133
+     */
1134
+    if (nlh->nlmsg_type == dco->ctrlid)
1135
+    {
1136
+        msg(D_DCO_DEBUG, "ovpn-dco: received CTRLID message");
1137
+        return mcast_family_handler(msg, dco);
1138
+    }
1126 1139
 
1127 1140
     if (!genlmsg_valid_hdr(nlh, 0))
1128 1141
     {
... ...
@@ -1146,6 +1174,21 @@ ovpn_handle_msg(struct nl_msg *msg, void *arg)
1146 1146
      */
1147 1147
     switch (gnlh->cmd)
1148 1148
     {
1149
+        case OVPN_CMD_PEER_GET:
1150
+        {
1151
+            /* this message is part of a peer list dump, hence triggered
1152
+             * by a MP/server instance
1153
+             */
1154
+            if (nlh->nlmsg_flags & NLM_F_MULTI)
1155
+            {
1156
+                return ovpn_handle_peer_multi(dco, attrs);
1157
+            }
1158
+            else
1159
+            {
1160
+                return ovpn_handle_peer(dco, attrs);
1161
+            }
1162
+        }
1163
+
1149 1164
         case OVPN_CMD_PEER_DEL_NTF:
1150 1165
         {
1151 1166
             return ovpn_handle_peer_del_ntf(dco, attrs);
... ...
@@ -1174,7 +1217,6 @@ int
1174 1174
 dco_do_read(dco_context_t *dco)
1175 1175
 {
1176 1176
     msg(D_DCO_DEBUG, __func__);
1177
-    nl_cb_set(dco->nl_cb, NL_CB_VALID, NL_CB_CUSTOM, ovpn_handle_msg, dco);
1178 1177
 
1179 1178
     return ovpn_nl_recvmsgs(dco, __func__);
1180 1179
 }
... ...
@@ -1189,7 +1231,7 @@ dco_get_peer_stats_multi(dco_context_t *dco, struct multi_context *m,
1189 1189
 
1190 1190
     nlmsg_hdr(nl_msg)->nlmsg_flags |= NLM_F_DUMP;
1191 1191
 
1192
-    int ret = ovpn_nl_msg_send(dco, nl_msg, dco_parse_peer_multi, m, __func__);
1192
+    int ret = ovpn_nl_msg_send(dco, nl_msg, __func__);
1193 1193
 
1194 1194
     nlmsg_free(nl_msg);
1195 1195
 
... ...
@@ -1227,7 +1269,7 @@ dco_get_peer_stats(struct context *c, const bool raise_sigusr1_on_err)
1227 1227
     NLA_PUT_U32(nl_msg, OVPN_A_PEER_ID, peer_id);
1228 1228
     nla_nest_end(nl_msg, attr);
1229 1229
 
1230
-    ret = ovpn_nl_msg_send(dco, nl_msg, dco_parse_peer, c, __func__);
1230
+    ret = ovpn_nl_msg_send(dco, nl_msg, __func__);
1231 1231
 
1232 1232
 nla_put_failure:
1233 1233
     nlmsg_free(nl_msg);
... ...
@@ -66,6 +66,7 @@ typedef struct
66 66
     int status;
67 67
 
68 68
     struct context *c;
69
+    int ctrlid;
69 70
 
70 71
     enum ovpn_mode ifmode;
71 72