Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inet: ScopedLwIPLock for better safety and added locks at necessary places #28655

Merged
merged 4 commits into from
Aug 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 87 additions & 79 deletions src/inet/UDPEndPointImplLwIP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,26 @@ static_assert(LWIP_VERSION_MAJOR > 1, "CHIP requires LwIP 2.0 or later");
namespace chip {
namespace Inet {

namespace {
/**
* @brief
* RAII locking for LwIP core to simplify management of
* LOCK_TCPIP_CORE()/UNLOCK_TCPIP_CORE() calls.
*/
class ScopedLwIPLock
{
public:
ScopedLwIPLock() { LOCK_TCPIP_CORE(); }
~ScopedLwIPLock() { UNLOCK_TCPIP_CORE(); }
};
} // anonymous namespace

EndpointQueueFilter * UDPEndPointImplLwIP::sQueueFilter = nullptr;

CHIP_ERROR UDPEndPointImplLwIP::BindImpl(IPAddressType addressType, const IPAddress & address, uint16_t port,
InterfaceId interfaceId)
{
// Lock LwIP stack
LOCK_TCPIP_CORE();
ScopedLwIPLock lwipLock;

// Make sure we have the appropriate type of PCB.
CHIP_ERROR res = GetPCB(addressType);
Expand All @@ -90,9 +103,6 @@ CHIP_ERROR UDPEndPointImplLwIP::BindImpl(IPAddressType addressType, const IPAddr
res = LwIPBindInterface(mUDP, interfaceId);
}

// Unlock LwIP stack
UNLOCK_TCPIP_CORE();

return res;
}

Expand All @@ -101,7 +111,7 @@ CHIP_ERROR UDPEndPointImplLwIP::BindInterfaceImpl(IPAddressType addrType, Interf
// A lock is required because the LwIP thread may be referring to intf_filter,
// while this code running in the Inet application is potentially modifying it.
// NOTE: this only supports LwIP interfaces whose number is no bigger than 9.
LOCK_TCPIP_CORE();
ScopedLwIPLock lwipLock;

// Make sure we have the appropriate type of PCB.
CHIP_ERROR err = GetPCB(addrType);
Expand All @@ -110,9 +120,6 @@ CHIP_ERROR UDPEndPointImplLwIP::BindInterfaceImpl(IPAddressType addrType, Interf
{
err = LwIPBindInterface(mUDP, intfId);
}

UNLOCK_TCPIP_CORE();

return err;
}

Expand All @@ -134,6 +141,8 @@ CHIP_ERROR UDPEndPointImplLwIP::LwIPBindInterface(struct udp_pcb * aUDP, Interfa

InterfaceId UDPEndPointImplLwIP::GetBoundInterface() const
{
ScopedLwIPLock lwipLock;

#if HAVE_LWIP_UDP_BIND_NETIF
return InterfaceId(netif_get_by_index(mUDP->netif_idx));
#else
Expand All @@ -148,14 +157,9 @@ uint16_t UDPEndPointImplLwIP::GetBoundPort() const

CHIP_ERROR UDPEndPointImplLwIP::ListenImpl()
{
// Lock LwIP stack
LOCK_TCPIP_CORE();
ScopedLwIPLock lwipLock;

udp_recv(mUDP, LwIPReceiveUDPMessage, this);

// Unlock LwIP stack
UNLOCK_TCPIP_CORE();

return CHIP_NO_ERROR;
}

Expand All @@ -174,53 +178,53 @@ CHIP_ERROR UDPEndPointImplLwIP::SendMsgImpl(const IPPacketInfo * pktInfo, System
VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_NO_MEMORY);
}

// Lock LwIP stack
LOCK_TCPIP_CORE();
CHIP_ERROR res = CHIP_NO_ERROR;
err_t lwipErr = ERR_VAL;

// Make sure we have the appropriate type of PCB based on the destination address.
CHIP_ERROR res = GetPCB(destAddr.Type());
if (res != CHIP_NO_ERROR)
// Adding a scope here to unlock the LwIP core when the lock is no longer required.
{
UNLOCK_TCPIP_CORE();
return res;
}
ScopedLwIPLock lwipLock;

// Send the message to the specified address/port.
// If an outbound interface has been specified, call a specific version of the UDP sendto()
// function that accepts the target interface.
// If a source address has been specified, temporarily override the local_ip of the PCB.
// This results in LwIP using the given address being as the source address for the generated
// packet, as if the PCB had been bound to that address.
err_t lwipErr = ERR_VAL;
const IPAddress & srcAddr = pktInfo->SrcAddress;
const uint16_t & destPort = pktInfo->DestPort;
const InterfaceId & intfId = pktInfo->Interface;
// Make sure we have the appropriate type of PCB based on the destination address.
res = GetPCB(destAddr.Type());
if (res != CHIP_NO_ERROR)
{
return res;
}

ip_addr_t lwipSrcAddr = srcAddr.ToLwIPAddr();
ip_addr_t lwipDestAddr = destAddr.ToLwIPAddr();
// Send the message to the specified address/port.
// If an outbound interface has been specified, call a specific version of the UDP sendto()
// function that accepts the target interface.
// If a source address has been specified, temporarily override the local_ip of the PCB.
// This results in LwIP using the given address being as the source address for the generated
// packet, as if the PCB had been bound to that address.
const IPAddress & srcAddr = pktInfo->SrcAddress;
const uint16_t & destPort = pktInfo->DestPort;
const InterfaceId & intfId = pktInfo->Interface;

ip_addr_t boundAddr;
ip_addr_copy(boundAddr, mUDP->local_ip);
ip_addr_t lwipSrcAddr = srcAddr.ToLwIPAddr();
ip_addr_t lwipDestAddr = destAddr.ToLwIPAddr();

if (!ip_addr_isany(&lwipSrcAddr))
{
ip_addr_copy(mUDP->local_ip, lwipSrcAddr);
}
ip_addr_t boundAddr;
ip_addr_copy(boundAddr, mUDP->local_ip);

if (intfId.IsPresent())
{
lwipErr = udp_sendto_if(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort,
intfId.GetPlatformInterface());
}
else
{
lwipErr = udp_sendto(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort);
}
if (!ip_addr_isany(&lwipSrcAddr))
{
ip_addr_copy(mUDP->local_ip, lwipSrcAddr);
}

ip_addr_copy(mUDP->local_ip, boundAddr);
if (intfId.IsPresent())
{
lwipErr = udp_sendto_if(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort,
intfId.GetPlatformInterface());
}
else
{
lwipErr = udp_sendto(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort);
}

// Unlock LwIP stack
UNLOCK_TCPIP_CORE();
ip_addr_copy(mUDP->local_ip, boundAddr);
}

if (lwipErr != ERR_OK)
{
Expand All @@ -232,9 +236,7 @@ CHIP_ERROR UDPEndPointImplLwIP::SendMsgImpl(const IPPacketInfo * pktInfo, System

void UDPEndPointImplLwIP::CloseImpl()
{

// Lock LwIP stack
LOCK_TCPIP_CORE();
ScopedLwIPLock lwipLock;

// Since UDP PCB is released synchronously here, but UDP endpoint itself might have to wait
// for destruction asynchronously, there could be more allocated UDP endpoints than UDP PCBs.
Expand All @@ -260,9 +262,6 @@ void UDPEndPointImplLwIP::CloseImpl()
}
}
}

// Unlock LwIP stack
UNLOCK_TCPIP_CORE();
}

void UDPEndPointImplLwIP::Free()
Expand Down Expand Up @@ -473,19 +472,23 @@ CHIP_ERROR UDPEndPointImplLwIP::IPv4JoinLeaveMulticastGroupImpl(InterfaceId aInt
const ip4_addr_t lIPv4Address = aAddress.ToIPv4();
err_t lStatus;

if (aInterfaceId.IsPresent())
{
ScopedLwIPLock lwipLock;

struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId);
VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE);
if (aInterfaceId.IsPresent())
{

lStatus = join ? igmp_joingroup_netif(lNetif, &lIPv4Address) //
: igmp_leavegroup_netif(lNetif, &lIPv4Address);
}
else
{
lStatus = join ? igmp_joingroup(IP4_ADDR_ANY4, &lIPv4Address) //
: igmp_leavegroup(IP4_ADDR_ANY4, &lIPv4Address);
struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId);
VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE);

lStatus = join ? igmp_joingroup_netif(lNetif, &lIPv4Address) //
: igmp_leavegroup_netif(lNetif, &lIPv4Address);
}
else
{
lStatus = join ? igmp_joingroup(IP4_ADDR_ANY4, &lIPv4Address) //
: igmp_leavegroup(IP4_ADDR_ANY4, &lIPv4Address);
}
}

if (lStatus == ERR_MEM)
Expand All @@ -504,17 +507,22 @@ CHIP_ERROR UDPEndPointImplLwIP::IPv6JoinLeaveMulticastGroupImpl(InterfaceId aInt
#ifdef HAVE_IPV6_MULTICAST
const ip6_addr_t lIPv6Address = aAddress.ToIPv6();
err_t lStatus;
if (aInterfaceId.IsPresent())
{
struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId);
VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE);
lStatus = join ? mld6_joingroup_netif(lNetif, &lIPv6Address) //
: mld6_leavegroup_netif(lNetif, &lIPv6Address);
}
else

{
lStatus = join ? mld6_joingroup(IP6_ADDR_ANY6, &lIPv6Address) //
: mld6_leavegroup(IP6_ADDR_ANY6, &lIPv6Address);
ScopedLwIPLock lwipLock;

if (aInterfaceId.IsPresent())
{
struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId);
VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE);
lStatus = join ? mld6_joingroup_netif(lNetif, &lIPv6Address) //
: mld6_leavegroup_netif(lNetif, &lIPv6Address);
}
else
{
lStatus = join ? mld6_joingroup(IP6_ADDR_ANY6, &lIPv6Address) //
: mld6_leavegroup(IP6_ADDR_ANY6, &lIPv6Address);
}
}

if (lStatus == ERR_MEM)
Expand Down