[udp] allow binding to multicast address (#11901)

This commit allows binding to multicast address, which means the socket
would only accept frames targeting to a particular multicast address.
This prevents other datagrams destined to the same port being delivered
to this socket.

Note that binding to a multicast address doesn't automatically subscribe
to the multicast group.
This commit is contained in:
Yakun Xu
2025-09-10 23:51:27 +08:00
committed by GitHub
parent 9491dccf42
commit e1d8a05c29
7 changed files with 249 additions and 14 deletions
+3 -2
View File
@@ -58,7 +58,7 @@ bool Udp::SocketHandle::Matches(const MessageInfo &aMessageInfo) const
VerifyOrExit(GetSockName().mPort == aMessageInfo.GetSockPort());
VerifyOrExit(aMessageInfo.GetSockAddr().IsMulticast() || GetSockName().GetAddress().IsUnspecified() ||
VerifyOrExit(GetSockName().GetAddress().IsUnspecified() ||
GetSockName().GetAddress() == aMessageInfo.GetSockAddr());
// Verify source if connected socket
@@ -254,7 +254,8 @@ Error Udp::Bind(SocketHandle &aSocket, const SockAddr &aSockAddr)
SuccessOrExit(error = Plat::BindToNetif(aSocket));
#endif
VerifyOrExit(aSockAddr.GetAddress().IsUnspecified() || Get<ThreadNetif>().HasUnicastAddress(aSockAddr.GetAddress()),
VerifyOrExit(aSockAddr.GetAddress().IsUnspecified() || aSockAddr.GetAddress().IsMulticast() ||
Get<ThreadNetif>().HasUnicastAddress(aSockAddr.GetAddress()),
error = kErrorInvalidArgs);
aSocket.mSockName = aSockAddr;
+1
View File
@@ -52,6 +52,7 @@ target_link_libraries(ot-fake-ftd INTERFACE
add_executable(ot-ftd-gtest
dataset_test.cpp
udp_test.cpp
)
target_link_libraries(ot-ftd-gtest
ot-fake-ftd
+1 -10
View File
@@ -39,22 +39,13 @@
#include "gmock/gmock.h"
#include "fake_platform.hpp"
#include "mock_callback.hpp"
using namespace ot;
using ::testing::AnyNumber;
using ::testing::AtLeast;
using ::testing::MockFunction;
using ::testing::Truly;
template <typename R, typename... A> class MockCallback : public testing::MockFunction<R(A...)>
{
public:
static R CallWithContext(A... aArgs, void *aContext)
{
return static_cast<MockCallback *>(aContext)->Call(aArgs...);
};
};
TEST(otDatasetSetActiveTlvs, shouldTriggerStateCallbackOnSuccess)
{
FakePlatform fakePlatform;
+1
View File
@@ -65,6 +65,7 @@ FakePlatform::FakePlatform()
assert(sPlatform == nullptr);
sPlatform = this;
fprintf(stderr, "fake platform start\r\n");
mTransmitFrame.mPsdu = mTransmitBuffer;
#if OPENTHREAD_CONFIG_MULTIPLE_INSTANCE_ENABLE
+2 -2
View File
@@ -150,8 +150,8 @@ protected:
template <uint64_t FakePlatform::*T> void HandleSchedule();
otRadioFrame mTransmitFrame;
uint8_t mTransmitBuffer[OT_RADIO_FRAME_MAX_SIZE];
otRadioFrame mTransmitFrame{};
uint8_t mTransmitBuffer[OT_RADIO_FRAME_MAX_SIZE]{};
uint8_t mChannel = 0;
uint8_t mReceiveAtChannel = 0;
+47
View File
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2025, The OpenThread Authors.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* 3. Neither the name of the copyright holder nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef OT_TESTS_GTEST_MOCK_CALLBACK
#define OT_TESTS_GTEST_MOCK_CALLBACK
#include "gmock/gmock.h"
template <typename R, typename... A> class MockCallback : public testing::MockFunction<R(A...)>
{
public:
static R CallWithContext(A... aArgs, void *aContext)
{
return static_cast<MockCallback *>(aContext)->Call(aArgs...);
};
static R CallWithContextAhead(void *aContext, A... aArgs)
{
return static_cast<MockCallback *>(aContext)->Call(aArgs...);
};
};
#endif // OT_TESTS_GTEST_MOCK_CALLBACK
+194
View File
@@ -0,0 +1,194 @@
/*
* Copyright (c) 2025, The OpenThread Authors.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* 3. Neither the name of the copyright holder nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <openthread/border_agent.h>
#include <openthread/dataset.h>
#include <openthread/dataset_ftd.h>
#include <openthread/instance.h>
#include <openthread/ip6.h>
#include <openthread/thread.h>
#include <openthread/platform/time.h>
#include "net/socket.hpp"
#include "openthread/error.h"
#include "openthread/message.h"
#include "openthread/udp.h"
#include "fake_platform.hpp"
#include "mock_callback.hpp"
#include "core/net/ip6_address.hpp"
using namespace ot;
using MockReceiveCallback = MockCallback<void, otMessage *, const otMessageInfo *>;
class UdpTest : public ::testing::Test
{
protected:
void SetUp() override
{
otOperationalDataset dataset;
otOperationalDatasetTlvs datasetTlvs;
ASSERT_EQ(OT_ERROR_NONE, otDatasetCreateNewNetwork(FakePlatform::CurrentInstance(), &dataset));
otDatasetConvertToTlvs(&dataset, &datasetTlvs);
ASSERT_EQ(OT_ERROR_NONE, otDatasetSetActiveTlvs(FakePlatform::CurrentInstance(), &datasetTlvs));
ASSERT_EQ(OT_ERROR_NONE, otIp6SetEnabled(FakePlatform::CurrentInstance(), true));
ASSERT_EQ(OT_ERROR_NONE, otThreadSetEnabled(FakePlatform::CurrentInstance(), true));
mFakePlatform.GoInMs(10000);
}
FakePlatform mFakePlatform;
};
TEST_F(UdpTest, shouldSuccessWhenBindingMulticastAddressAndReceiveFromIt)
{
otUdpSocket receiver;
MockReceiveCallback receiverCallback;
ASSERT_EQ(OT_ERROR_NONE, otUdpOpen(FakePlatform::CurrentInstance(), &receiver,
&MockReceiveCallback::CallWithContextAhead, &receiverCallback));
Ip6::SockAddr listenAddr;
ASSERT_EQ(OT_ERROR_NONE, listenAddr.GetAddress().FromString("ff02::21"));
listenAddr.SetPort(2121);
ASSERT_EQ(OT_ERROR_NONE, otUdpBind(FakePlatform::CurrentInstance(), &receiver, &listenAddr, OT_NETIF_UNSPECIFIED));
ASSERT_EQ(OT_ERROR_NONE, otIp6SubscribeMulticastAddress(FakePlatform::CurrentInstance(), &listenAddr.mAddress));
EXPECT_CALL(receiverCallback, Call).Times(1);
otUdpSocket sender{};
MockReceiveCallback senderCallback;
ASSERT_EQ(OT_ERROR_NONE, otUdpOpen(FakePlatform::CurrentInstance(), &sender,
&MockReceiveCallback::CallWithContextAhead, &senderCallback));
otMessageInfo messageInfo{};
messageInfo.mPeerAddr = listenAddr.mAddress;
messageInfo.mPeerPort = listenAddr.mPort;
messageInfo.mMulticastLoop = true;
otMessage *message = otUdpNewMessage(FakePlatform::CurrentInstance(), nullptr);
ASSERT_NE(message, nullptr);
ASSERT_EQ(otMessageAppend(message, "multicast", sizeof("multicast") - 1), OT_ERROR_NONE);
ASSERT_EQ(OT_ERROR_NONE, otUdpSend(FakePlatform::CurrentInstance(), &sender, message, &messageInfo));
mFakePlatform.GoInMs(1000);
ASSERT_EQ(OT_ERROR_NONE, otUdpClose(FakePlatform::CurrentInstance(), &sender));
ASSERT_EQ(OT_ERROR_NONE, otUdpClose(FakePlatform::CurrentInstance(), &receiver));
}
TEST_F(UdpTest, shouldSuccessWhenBindingMulticastAddressAndNoReceiveFromDifferentMulticast)
{
otUdpSocket receiver;
MockReceiveCallback receiverCallback;
ASSERT_EQ(OT_ERROR_NONE, otUdpOpen(FakePlatform::CurrentInstance(), &receiver,
&MockReceiveCallback::CallWithContextAhead, &receiverCallback));
Ip6::Address group1;
Ip6::Address group2;
Ip6::SockAddr listenAddr;
ASSERT_EQ(OT_ERROR_NONE, group1.FromString("ff02::21"));
ASSERT_EQ(OT_ERROR_NONE, group2.FromString("ff02::22"));
listenAddr.SetAddress(group1);
listenAddr.SetPort(2121);
ASSERT_EQ(OT_ERROR_NONE, otUdpBind(FakePlatform::CurrentInstance(), &receiver, &listenAddr, OT_NETIF_UNSPECIFIED));
ASSERT_EQ(OT_ERROR_NONE, otIp6SubscribeMulticastAddress(FakePlatform::CurrentInstance(), &group1));
ASSERT_EQ(OT_ERROR_NONE, otIp6SubscribeMulticastAddress(FakePlatform::CurrentInstance(), &group2));
EXPECT_CALL(receiverCallback, Call).Times(0);
otUdpSocket sender{};
MockReceiveCallback senderCallback;
ASSERT_EQ(OT_ERROR_NONE, otUdpOpen(FakePlatform::CurrentInstance(), &sender,
&MockReceiveCallback::CallWithContextAhead, &senderCallback));
otMessageInfo messageInfo{};
messageInfo.mPeerAddr = group2;
messageInfo.mPeerPort = listenAddr.mPort;
messageInfo.mMulticastLoop = true;
otMessage *message = otUdpNewMessage(FakePlatform::CurrentInstance(), nullptr);
ASSERT_NE(message, nullptr);
ASSERT_EQ(otMessageAppend(message, "multicast", sizeof("multicast") - 1), OT_ERROR_NONE);
ASSERT_EQ(OT_ERROR_NONE, otUdpSend(FakePlatform::CurrentInstance(), &sender, message, &messageInfo));
mFakePlatform.GoInMs(1000);
ASSERT_EQ(OT_ERROR_NONE, otUdpClose(FakePlatform::CurrentInstance(), &sender));
ASSERT_EQ(OT_ERROR_NONE, otUdpClose(FakePlatform::CurrentInstance(), &receiver));
}
TEST_F(UdpTest, shouldSuccessWhenBindingMulticastAddressAndNoReceiveIfNotSubscribed)
{
otUdpSocket receiver;
MockReceiveCallback receiverCallback;
ASSERT_EQ(OT_ERROR_NONE, otUdpOpen(FakePlatform::CurrentInstance(), &receiver,
&MockReceiveCallback::CallWithContextAhead, &receiverCallback));
Ip6::SockAddr listenAddr;
ASSERT_EQ(OT_ERROR_NONE, listenAddr.GetAddress().FromString("ff02::21"));
listenAddr.SetPort(2121);
ASSERT_EQ(OT_ERROR_NONE, otUdpBind(FakePlatform::CurrentInstance(), &receiver, &listenAddr, OT_NETIF_UNSPECIFIED));
EXPECT_CALL(receiverCallback, Call).Times(0);
otUdpSocket sender{};
MockReceiveCallback senderCallback;
ASSERT_EQ(OT_ERROR_NONE, otUdpOpen(FakePlatform::CurrentInstance(), &sender,
&MockReceiveCallback::CallWithContextAhead, &senderCallback));
otMessageInfo messageInfo{};
messageInfo.mPeerAddr = listenAddr.GetAddress();
messageInfo.mPeerPort = listenAddr.mPort;
messageInfo.mMulticastLoop = true;
otMessage *message = otUdpNewMessage(FakePlatform::CurrentInstance(), nullptr);
ASSERT_NE(message, nullptr);
ASSERT_EQ(otMessageAppend(message, "multicast", sizeof("multicast") - 1), OT_ERROR_NONE);
ASSERT_EQ(OT_ERROR_NONE, otUdpSend(FakePlatform::CurrentInstance(), &sender, message, &messageInfo));
mFakePlatform.GoInMs(1000);
ASSERT_EQ(OT_ERROR_NONE, otUdpClose(FakePlatform::CurrentInstance(), &sender));
ASSERT_EQ(OT_ERROR_NONE, otUdpClose(FakePlatform::CurrentInstance(), &receiver));
}