diff --git a/src/backend/libc/net/read_sockaddr.rs b/src/backend/libc/net/read_sockaddr.rs index 08939d4f8..c54c3dc47 100644 --- a/src/backend/libc/net/read_sockaddr.rs +++ b/src/backend/libc/net/read_sockaddr.rs @@ -9,7 +9,10 @@ use crate::backend::c; use crate::ffi::CStr; use crate::io; #[cfg(target_os = "linux")] -use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp}; +use crate::net::{ + netlink::SocketAddrNetlink, + xdp::{SockaddrXdpFlags, SocketAddrXdp}, +}; use crate::net::{Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrV4, SocketAddrV6}; use core::mem::size_of; @@ -208,6 +211,17 @@ pub(crate) unsafe fn read_sockaddr( u32::from_be(decode.sxdp_shared_umem_fd), ))) } + #[cfg(target_os = "linux")] + c::AF_NETLINK => { + if len < size_of::() { + return Err(io::Errno::INVAL); + } + let decode = &*storage.cast::(); + Ok(SocketAddrAny::Netlink(SocketAddrNetlink::new( + decode.nl_pid, + decode.nl_groups, + ))) + } _ => Err(io::Errno::INVAL), } } @@ -327,6 +341,12 @@ unsafe fn inner_read_sockaddr_os( u32::from_be(decode.sxdp_shared_umem_fd), )) } + #[cfg(target_os = "linux")] + c::AF_NETLINK => { + assert!(len >= size_of::()); + let decode = &*storage.cast::(); + SocketAddrAny::Netlink(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups)) + } other => unimplemented!("{:?}", other), } } diff --git a/src/backend/linux_raw/net/read_sockaddr.rs b/src/backend/linux_raw/net/read_sockaddr.rs index 23e1d641d..ab556e5d6 100644 --- a/src/backend/linux_raw/net/read_sockaddr.rs +++ b/src/backend/linux_raw/net/read_sockaddr.rs @@ -5,7 +5,10 @@ use crate::backend::c; use crate::io; #[cfg(target_os = "linux")] -use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp}; +use crate::net::{ + netlink::SocketAddrNetlink, + xdp::{SockaddrXdpFlags, SocketAddrXdp}, +}; use crate::net::{Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6}; use core::mem::size_of; use core::slice; @@ -127,6 +130,17 @@ pub(crate) unsafe fn read_sockaddr( u32::from_be(decode.sxdp_shared_umem_fd), ))) } + #[cfg(target_os = "linux")] + c::AF_NETLINK => { + if len < size_of::() { + return Err(io::Errno::INVAL); + } + let decode = &*storage.cast::(); + Ok(SocketAddrAny::Netlink(SocketAddrNetlink::new( + decode.nl_pid, + decode.nl_groups, + ))) + } _ => Err(io::Errno::NOTSUP), } } @@ -216,6 +230,12 @@ pub(crate) unsafe fn read_sockaddr_os(storage: *const c::sockaddr, len: usize) - u32::from_be(decode.sxdp_shared_umem_fd), )) } + #[cfg(target_os = "linux")] + c::AF_NETLINK => { + assert!(len >= size_of::()); + let decode = &*storage.cast::(); + SocketAddrAny::Netlink(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups)) + } other => unimplemented!("{:?}", other), } } diff --git a/src/net/socket_addr_any.rs b/src/net/socket_addr_any.rs index fd72a7b3c..cb9ba539e 100644 --- a/src/net/socket_addr_any.rs +++ b/src/net/socket_addr_any.rs @@ -10,10 +10,10 @@ #![allow(unsafe_code)] use crate::backend::c; -#[cfg(target_os = "linux")] -use crate::net::xdp::SocketAddrXdp; #[cfg(unix)] use crate::net::SocketAddrUnix; +#[cfg(target_os = "linux")] +use crate::net::{netlink::SocketAddrNetlink, xdp::SocketAddrXdp}; use crate::net::{AddressFamily, SocketAddr, SocketAddrV4, SocketAddrV6}; use crate::{backend, io}; #[cfg(feature = "std")] @@ -39,6 +39,9 @@ pub enum SocketAddrAny { /// `struct sockaddr_xdp` #[cfg(target_os = "linux")] Xdp(SocketAddrXdp), + /// `struct sockaddr_nl` + #[cfg(target_os = "linux")] + Netlink(SocketAddrNetlink), } impl From for SocketAddrAny { @@ -84,6 +87,8 @@ impl SocketAddrAny { Self::Unix(_) => AddressFamily::UNIX, #[cfg(target_os = "linux")] Self::Xdp(_) => AddressFamily::XDP, + #[cfg(target_os = "linux")] + Self::Netlink(_) => AddressFamily::NETLINK, } } @@ -103,6 +108,8 @@ impl SocketAddrAny { SocketAddrAny::Unix(a) => a.write_sockaddr(storage), #[cfg(target_os = "linux")] SocketAddrAny::Xdp(a) => a.write_sockaddr(storage), + #[cfg(target_os = "linux")] + SocketAddrAny::Netlink(a) => a.write_sockaddr(storage), } } @@ -128,6 +135,8 @@ impl fmt::Debug for SocketAddrAny { Self::Unix(unix) => unix.fmt(fmt), #[cfg(target_os = "linux")] Self::Xdp(xdp) => xdp.fmt(fmt), + #[cfg(target_os = "linux")] + Self::Netlink(nl) => nl.fmt(fmt), } } } @@ -158,6 +167,8 @@ unsafe impl SocketAddress for SocketAddrAny { Self::Unix(a) => a.with_sockaddr(f), #[cfg(target_os = "linux")] Self::Xdp(a) => a.with_sockaddr(f), + #[cfg(target_os = "linux")] + Self::Netlink(a) => a.with_sockaddr(f), } } } diff --git a/src/net/types.rs b/src/net/types.rs index 2373cdc77..fd9aeefdf 100644 --- a/src/net/types.rs +++ b/src/net/types.rs @@ -915,6 +915,8 @@ pub mod netlink { use { super::{new_raw_protocol, Protocol}, crate::backend::c, + crate::net::SocketAddress, + core::mem, }; /// `NETLINK_UNUSED` @@ -1024,6 +1026,68 @@ pub mod netlink { /// `NETLINK_GET_STRICT_CHK` #[cfg(linux_kernel)] pub const GET_STRICT_CHK: Protocol = Protocol(new_raw_protocol(c::NETLINK_GET_STRICT_CHK as _)); + + /// A Netlink socket address. + /// + /// Used to bind to a Netlink socket. + /// + /// Not ABI compatible with `struct sockaddr_nl` + #[derive(Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)] + #[cfg(linux_kernel)] + pub struct SocketAddrNetlink { + /// Port ID + pid: u32, + + /// Multicast groups mask + groups: u32, + } + + #[cfg(linux_kernel)] + impl SocketAddrNetlink { + /// Construct a netlink address + #[inline] + pub fn new(pid: u32, groups: u32) -> Self { + Self { pid, groups } + } + + /// Return port id. + #[inline] + pub fn pid(&self) -> u32 { + self.pid + } + + /// Set port id. + #[inline] + pub fn set_pid(&mut self, pid: u32) { + self.pid = pid; + } + + /// Return multicast groups mask. + #[inline] + pub fn groups(&self) -> u32 { + self.groups + } + + /// Set multicast groups mask. + #[inline] + pub fn set_groups(&mut self, groups: u32) { + self.groups = groups; + } + } + + #[cfg(linux_kernel)] + #[allow(unsafe_code)] + unsafe impl SocketAddress for SocketAddrNetlink { + type CSockAddr = c::sockaddr_nl; + + fn encode(&self) -> Self::CSockAddr { + let mut addr: c::sockaddr_nl = unsafe { mem::zeroed() }; + addr.nl_family = c::AF_NETLINK as _; + addr.nl_pid = self.pid; + addr.nl_groups = self.groups; + addr + } + } } /// `ETH_P_*` constants.