Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(transport): Use socket2 to obtain orig_dst #3626

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 14 additions & 142 deletions linkerd/proxy/transport/src/orig_dst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use futures::prelude::*;
use linkerd_error::Result;
use linkerd_io as io;
use linkerd_stack::Param;
use std::{net::SocketAddr, pin::Pin};
use std::pin::Pin;
use tokio::net::TcpStream;

#[derive(Copy, Clone, Debug, Default)]
Expand Down Expand Up @@ -83,14 +83,7 @@ where

let incoming = incoming.map(|res| {
let (inner, tcp) = res?;
let orig_dst = match inner.param() {
// IPv4-mapped IPv6 addresses are unwrapped by BindTcp::bind() and received here as
// SocketAddr::V4. We must call getsockopt with IPv4 constants (via
// orig_dst_addr_v4) even if it originally was an IPv6
Remote(ClientAddr(SocketAddr::V4(_))) => orig_dst_addr_v4(&tcp)?,
Remote(ClientAddr(SocketAddr::V6(_))) => orig_dst_addr_v6(&tcp)?,
};
let orig_dst = OrigDstAddr(orig_dst);
let (orig_dst, tcp) = orig_dst(tcp)?;
let addrs = Addrs { inner, orig_dst };
Ok((addrs, tcp))
});
Expand All @@ -99,139 +92,18 @@ where
}
}

#[cfg(target_os = "linux")]
#[allow(unsafe_code)]
fn orig_dst_addr_v4(sock: &TcpStream) -> io::Result<SocketAddr> {
use std::os::unix::io::AsRawFd;
fn orig_dst(sock: TcpStream) -> io::Result<(OrigDstAddr, TcpStream)> {
let sock = {
let stream = tokio::net::TcpStream::into_std(sock)?;
socket2::Socket::from(stream)
};

let fd = sock.as_raw_fd();
unsafe { linux::so_original_dst_v4(fd) }
}

#[cfg(target_os = "linux")]
#[allow(unsafe_code)]
fn orig_dst_addr_v6(sock: &TcpStream) -> io::Result<SocketAddr> {
use std::os::unix::io::AsRawFd;

let fd = sock.as_raw_fd();
unsafe { linux::so_original_dst_v6(fd) }
}

#[cfg(not(target_os = "linux"))]
fn orig_dst_addr_v4(_: &TcpStream) -> io::Result<SocketAddr> {
Err(io::Error::new(
io::ErrorKind::Other,
"SO_ORIGINAL_DST not supported on this operating system",
))
}

#[cfg(not(target_os = "linux"))]
fn orig_dst_addr_v6(_: &TcpStream) -> io::Result<SocketAddr> {
Err(io::Error::new(
io::ErrorKind::Other,
"SO_ORIGINAL_DST not supported on this operating system",
))
}

#[cfg(target_os = "linux")]
#[allow(unsafe_code)]
mod linux {
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::unix::io::RawFd;
use std::{io, mem};

pub unsafe fn so_original_dst(fd: RawFd, level: i32, optname: i32) -> io::Result<SocketAddr> {
let mut sockaddr: libc::sockaddr_storage = mem::zeroed();
let mut sockaddr_len: libc::socklen_t = mem::size_of::<libc::sockaddr_storage>() as u32;

let ret = libc::getsockopt(
fd,
level,
optname,
&mut sockaddr as *mut _ as *mut _,
&mut sockaddr_len as *mut _ as *mut _,
);
if ret != 0 {
return Err(io::Error::last_os_error());
}

mk_addr(&sockaddr, sockaddr_len)
}
let orig_dst = sock.original_dst()?.as_socket().ok_or(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid address format",
))?;

pub unsafe fn so_original_dst_v4(fd: RawFd) -> io::Result<SocketAddr> {
so_original_dst(fd, libc::SOL_IP, libc::SO_ORIGINAL_DST)
}

pub unsafe fn so_original_dst_v6(fd: RawFd) -> io::Result<SocketAddr> {
so_original_dst(fd, libc::SOL_IPV6, libc::IP6T_SO_ORIGINAL_DST)
}

// Borrowed with love from net2-rs
// https://github.com/rust-lang-nursery/net2-rs/blob/1b4cb4fb05fbad750b271f38221eab583b666e5e/src/socket.rs#L103
//
// Copyright (c) 2014 The Rust Project Developers
fn mk_addr(storage: &libc::sockaddr_storage, len: libc::socklen_t) -> io::Result<SocketAddr> {
match storage.ss_family as libc::c_int {
libc::AF_INET => {
assert!(len as usize >= mem::size_of::<libc::sockaddr_in>());

let sa = {
let sa = storage as *const _ as *const libc::sockaddr_in;
unsafe { *sa }
};

let bits = ntoh32(sa.sin_addr.s_addr);
let ip = Ipv4Addr::new(
(bits >> 24) as u8,
(bits >> 16) as u8,
(bits >> 8) as u8,
bits as u8,
);
let port = sa.sin_port;
Ok(SocketAddr::V4(SocketAddrV4::new(ip, ntoh16(port))))
}
libc::AF_INET6 => {
assert!(len as usize >= mem::size_of::<libc::sockaddr_in6>());

let sa = {
let sa = storage as *const _ as *const libc::sockaddr_in6;
unsafe { *sa }
};

let arr = sa.sin6_addr.s6_addr;
let ip = Ipv6Addr::new(
(arr[0] as u16) << 8 | (arr[1] as u16),
(arr[2] as u16) << 8 | (arr[3] as u16),
(arr[4] as u16) << 8 | (arr[5] as u16),
(arr[6] as u16) << 8 | (arr[7] as u16),
(arr[8] as u16) << 8 | (arr[9] as u16),
(arr[10] as u16) << 8 | (arr[11] as u16),
(arr[12] as u16) << 8 | (arr[13] as u16),
(arr[14] as u16) << 8 | (arr[15] as u16),
);

let port = sa.sin6_port;
let flowinfo = sa.sin6_flowinfo;
let scope_id = sa.sin6_scope_id;
Ok(SocketAddr::V6(SocketAddrV6::new(
ip,
ntoh16(port),
flowinfo,
scope_id,
)))
}
_ => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid argument",
)),
}
}

fn ntoh16(i: u16) -> u16 {
<u16>::from_be(i)
}

fn ntoh32(i: u32) -> u32 {
<u32>::from_be(i)
}
let stream: std::net::TcpStream = socket2::Socket::into(sock);
let stream = tokio::net::TcpStream::from_std(stream)?;
Ok((OrigDstAddr(orig_dst), stream))
}
Loading