diff --git a/src/inet/UDPEndPointImplLwIP.cpp b/src/inet/UDPEndPointImplLwIP.cpp index 7ec594127b0f14..d3e87e1dd0ca42 100644 --- a/src/inet/UDPEndPointImplLwIP.cpp +++ b/src/inet/UDPEndPointImplLwIP.cpp @@ -62,13 +62,26 @@ static_assert(LWIP_VERSION_MAJOR > 1, "CHIP requires LwIP 2.0 or later"); namespace chip { namespace Inet { +namespace { +/** + * @brief + * RAII locking for LwIP core to simplify management of + * LOCK_TCPIP_CORE()/UNLOCK_TCPIP_CORE() calls. + */ +class ScopedLwIPLock +{ +public: + ScopedLwIPLock() { LOCK_TCPIP_CORE(); } + ~ScopedLwIPLock() { UNLOCK_TCPIP_CORE(); } +}; +} // anonymous namespace + EndpointQueueFilter * UDPEndPointImplLwIP::sQueueFilter = nullptr; CHIP_ERROR UDPEndPointImplLwIP::BindImpl(IPAddressType addressType, const IPAddress & address, uint16_t port, InterfaceId interfaceId) { - // Lock LwIP stack - LOCK_TCPIP_CORE(); + ScopedLwIPLock lwipLock; // Make sure we have the appropriate type of PCB. CHIP_ERROR res = GetPCB(addressType); @@ -90,9 +103,6 @@ CHIP_ERROR UDPEndPointImplLwIP::BindImpl(IPAddressType addressType, const IPAddr res = LwIPBindInterface(mUDP, interfaceId); } - // Unlock LwIP stack - UNLOCK_TCPIP_CORE(); - return res; } @@ -101,7 +111,7 @@ CHIP_ERROR UDPEndPointImplLwIP::BindInterfaceImpl(IPAddressType addrType, Interf // A lock is required because the LwIP thread may be referring to intf_filter, // while this code running in the Inet application is potentially modifying it. // NOTE: this only supports LwIP interfaces whose number is no bigger than 9. - LOCK_TCPIP_CORE(); + ScopedLwIPLock lwipLock; // Make sure we have the appropriate type of PCB. CHIP_ERROR err = GetPCB(addrType); @@ -110,9 +120,6 @@ CHIP_ERROR UDPEndPointImplLwIP::BindInterfaceImpl(IPAddressType addrType, Interf { err = LwIPBindInterface(mUDP, intfId); } - - UNLOCK_TCPIP_CORE(); - return err; } @@ -134,6 +141,8 @@ CHIP_ERROR UDPEndPointImplLwIP::LwIPBindInterface(struct udp_pcb * aUDP, Interfa InterfaceId UDPEndPointImplLwIP::GetBoundInterface() const { + ScopedLwIPLock lwipLock; + #if HAVE_LWIP_UDP_BIND_NETIF return InterfaceId(netif_get_by_index(mUDP->netif_idx)); #else @@ -148,14 +157,9 @@ uint16_t UDPEndPointImplLwIP::GetBoundPort() const CHIP_ERROR UDPEndPointImplLwIP::ListenImpl() { - // Lock LwIP stack - LOCK_TCPIP_CORE(); + ScopedLwIPLock lwipLock; udp_recv(mUDP, LwIPReceiveUDPMessage, this); - - // Unlock LwIP stack - UNLOCK_TCPIP_CORE(); - return CHIP_NO_ERROR; } @@ -174,53 +178,53 @@ CHIP_ERROR UDPEndPointImplLwIP::SendMsgImpl(const IPPacketInfo * pktInfo, System VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_NO_MEMORY); } - // Lock LwIP stack - LOCK_TCPIP_CORE(); + CHIP_ERROR res = CHIP_NO_ERROR; + err_t lwipErr = ERR_VAL; - // Make sure we have the appropriate type of PCB based on the destination address. - CHIP_ERROR res = GetPCB(destAddr.Type()); - if (res != CHIP_NO_ERROR) + // Adding a scope here to unlock the LwIP core when the lock is no longer required. { - UNLOCK_TCPIP_CORE(); - return res; - } + ScopedLwIPLock lwipLock; - // Send the message to the specified address/port. - // If an outbound interface has been specified, call a specific version of the UDP sendto() - // function that accepts the target interface. - // If a source address has been specified, temporarily override the local_ip of the PCB. - // This results in LwIP using the given address being as the source address for the generated - // packet, as if the PCB had been bound to that address. - err_t lwipErr = ERR_VAL; - const IPAddress & srcAddr = pktInfo->SrcAddress; - const uint16_t & destPort = pktInfo->DestPort; - const InterfaceId & intfId = pktInfo->Interface; + // Make sure we have the appropriate type of PCB based on the destination address. + res = GetPCB(destAddr.Type()); + if (res != CHIP_NO_ERROR) + { + return res; + } - ip_addr_t lwipSrcAddr = srcAddr.ToLwIPAddr(); - ip_addr_t lwipDestAddr = destAddr.ToLwIPAddr(); + // Send the message to the specified address/port. + // If an outbound interface has been specified, call a specific version of the UDP sendto() + // function that accepts the target interface. + // If a source address has been specified, temporarily override the local_ip of the PCB. + // This results in LwIP using the given address being as the source address for the generated + // packet, as if the PCB had been bound to that address. + const IPAddress & srcAddr = pktInfo->SrcAddress; + const uint16_t & destPort = pktInfo->DestPort; + const InterfaceId & intfId = pktInfo->Interface; - ip_addr_t boundAddr; - ip_addr_copy(boundAddr, mUDP->local_ip); + ip_addr_t lwipSrcAddr = srcAddr.ToLwIPAddr(); + ip_addr_t lwipDestAddr = destAddr.ToLwIPAddr(); - if (!ip_addr_isany(&lwipSrcAddr)) - { - ip_addr_copy(mUDP->local_ip, lwipSrcAddr); - } + ip_addr_t boundAddr; + ip_addr_copy(boundAddr, mUDP->local_ip); - if (intfId.IsPresent()) - { - lwipErr = udp_sendto_if(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort, - intfId.GetPlatformInterface()); - } - else - { - lwipErr = udp_sendto(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort); - } + if (!ip_addr_isany(&lwipSrcAddr)) + { + ip_addr_copy(mUDP->local_ip, lwipSrcAddr); + } - ip_addr_copy(mUDP->local_ip, boundAddr); + if (intfId.IsPresent()) + { + lwipErr = udp_sendto_if(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort, + intfId.GetPlatformInterface()); + } + else + { + lwipErr = udp_sendto(mUDP, System::LwIPPacketBufferView::UnsafeGetLwIPpbuf(msg), &lwipDestAddr, destPort); + } - // Unlock LwIP stack - UNLOCK_TCPIP_CORE(); + ip_addr_copy(mUDP->local_ip, boundAddr); + } if (lwipErr != ERR_OK) { @@ -232,9 +236,7 @@ CHIP_ERROR UDPEndPointImplLwIP::SendMsgImpl(const IPPacketInfo * pktInfo, System void UDPEndPointImplLwIP::CloseImpl() { - - // Lock LwIP stack - LOCK_TCPIP_CORE(); + ScopedLwIPLock lwipLock; // Since UDP PCB is released synchronously here, but UDP endpoint itself might have to wait // for destruction asynchronously, there could be more allocated UDP endpoints than UDP PCBs. @@ -260,9 +262,6 @@ void UDPEndPointImplLwIP::CloseImpl() } } } - - // Unlock LwIP stack - UNLOCK_TCPIP_CORE(); } void UDPEndPointImplLwIP::Free() @@ -473,19 +472,23 @@ CHIP_ERROR UDPEndPointImplLwIP::IPv4JoinLeaveMulticastGroupImpl(InterfaceId aInt const ip4_addr_t lIPv4Address = aAddress.ToIPv4(); err_t lStatus; - if (aInterfaceId.IsPresent()) { + ScopedLwIPLock lwipLock; - struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId); - VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE); + if (aInterfaceId.IsPresent()) + { - lStatus = join ? igmp_joingroup_netif(lNetif, &lIPv4Address) // - : igmp_leavegroup_netif(lNetif, &lIPv4Address); - } - else - { - lStatus = join ? igmp_joingroup(IP4_ADDR_ANY4, &lIPv4Address) // - : igmp_leavegroup(IP4_ADDR_ANY4, &lIPv4Address); + struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId); + VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE); + + lStatus = join ? igmp_joingroup_netif(lNetif, &lIPv4Address) // + : igmp_leavegroup_netif(lNetif, &lIPv4Address); + } + else + { + lStatus = join ? igmp_joingroup(IP4_ADDR_ANY4, &lIPv4Address) // + : igmp_leavegroup(IP4_ADDR_ANY4, &lIPv4Address); + } } if (lStatus == ERR_MEM) @@ -504,17 +507,22 @@ CHIP_ERROR UDPEndPointImplLwIP::IPv6JoinLeaveMulticastGroupImpl(InterfaceId aInt #ifdef HAVE_IPV6_MULTICAST const ip6_addr_t lIPv6Address = aAddress.ToIPv6(); err_t lStatus; - if (aInterfaceId.IsPresent()) - { - struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId); - VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE); - lStatus = join ? mld6_joingroup_netif(lNetif, &lIPv6Address) // - : mld6_leavegroup_netif(lNetif, &lIPv6Address); - } - else + { - lStatus = join ? mld6_joingroup(IP6_ADDR_ANY6, &lIPv6Address) // - : mld6_leavegroup(IP6_ADDR_ANY6, &lIPv6Address); + ScopedLwIPLock lwipLock; + + if (aInterfaceId.IsPresent()) + { + struct netif * const lNetif = FindNetifFromInterfaceId(aInterfaceId); + VerifyOrReturnError(lNetif != nullptr, INET_ERROR_UNKNOWN_INTERFACE); + lStatus = join ? mld6_joingroup_netif(lNetif, &lIPv6Address) // + : mld6_leavegroup_netif(lNetif, &lIPv6Address); + } + else + { + lStatus = join ? mld6_joingroup(IP6_ADDR_ANY6, &lIPv6Address) // + : mld6_leavegroup(IP6_ADDR_ANY6, &lIPv6Address); + } } if (lStatus == ERR_MEM)