diff options
Diffstat (limited to 'src/model.rs')
-rw-r--r-- | src/model.rs | 197 |
1 files changed, 197 insertions, 0 deletions
diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..a7675f2 --- /dev/null +++ b/src/model.rs @@ -0,0 +1,197 @@ +// Copyright 2019 Hristo Venev +// +// See COPYING. + +use base64; +use crate::bin; +use crate::ip::{Ipv4Addr, Ipv6Addr, Ipv4Net, Ipv6Net, NetParseError}; +use std::{fmt}; +use std::collections::{HashMap}; +use std::str::FromStr; + +pub type KeyParseError = base64::DecodeError; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct Key([u8; 32]); + +impl Key { + pub fn from_bytes(s: &[u8]) -> Result<Key, KeyParseError> { + let mut v = Key([0u8; 32]); + let l = base64::decode_config_slice(s, base64::STANDARD, &mut v.0)?; + if l != v.0.len() { + return Err(base64::DecodeError::InvalidLength); + } + Ok(v) + } +} + +impl fmt::Display for Key { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", base64::display::Base64Display::with_config(&self.0, base64::STANDARD)) + } +} + +impl FromStr for Key { + type Err = KeyParseError; + #[inline] + fn from_str(s: &str) -> Result<Key, base64::DecodeError> { + Key::from_bytes(s.as_bytes()) + } +} + +impl serde::Serialize for Key { + fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> { + if ser.is_human_readable() { + ser.collect_str(self) + } else { + ser.serialize_bytes(&self.0) + } + } +} + +impl<'de> serde::Deserialize<'de> for Key { + fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> { + if de.is_human_readable() { + struct KeyVisitor; + impl<'de> serde::de::Visitor<'de> for KeyVisitor { + type Value = Key; + + #[inline] + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("WireGuard key") + } + + #[inline] + fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> { + s.parse().map_err(E::custom) + } + } + de.deserialize_str(KeyVisitor) + } else { + serde::Deserialize::deserialize(de).map(Key) + } + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct Endpoint { + address: Ipv6Addr, + port: u16, +} + +impl Endpoint { + #[inline] + pub fn ipv6_address(&self) -> Ipv6Addr { + self.address + } + + #[inline] + pub fn ipv4_address(&self) -> Option<Ipv4Addr> { + let seg = self.address.octets(); + let (first, second) = array_refs![&seg, 12, 4]; + if *first == [0,0,0,0,0,0,0,0,0,0,0xff,0xff] { + Some(Ipv4Addr::from(*second)) + } else { + None + } + } + + #[inline] + pub fn port(&self) -> u16 { + self.port + } +} + +impl fmt::Display for Endpoint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(ipv4) = self.ipv4_address() { + write!(f, "{}:", ipv4)?; + } else { + write!(f, "[{}]:", self.ipv6_address())?; + } + write!(f, "{}", self.port()) + } +} + +impl FromStr for Endpoint { + type Err = NetParseError; + fn from_str(s: &str) -> Result<Endpoint, NetParseError> { + use std::net; + net::SocketAddr::from_str(s) + .map_err(|_| NetParseError) + .map(|v| Endpoint { + address: match v.ip() { + net::IpAddr::V4(a) => a.to_ipv6_mapped(), + net::IpAddr::V6(a) => a, + }, + port: v.port(), + }) + } +} + +impl serde::Serialize for Endpoint { + fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> { + if ser.is_human_readable() { + ser.collect_str(self) + } else { + let mut buf = [0u8; 16 + 2]; + let (buf_addr, buf_port) = mut_array_refs![&mut buf, 16, 2]; + *buf_addr = self.address.octets(); + *buf_port = crate::bin::u16_to_be(self.port); + ser.serialize_bytes(&buf) + } + } +} + +impl<'de> serde::Deserialize<'de> for Endpoint { + fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> { + if de.is_human_readable() { + struct EndpointVisitor; + impl<'de> serde::de::Visitor<'de> for EndpointVisitor { + type Value = Endpoint; + + #[inline] + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("IP:port") + } + + #[inline] + fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> { + s.parse().map_err(E::custom) + } + } + de.deserialize_str(EndpointVisitor) + } else { + let buf = <[u8; 16 + 2] as serde::Deserialize>::deserialize(de)?; + let (buf_addr, buf_port) = array_refs![&buf, 16, 2]; + Ok(Endpoint { + address: (*buf_addr).into(), + port: bin::u16_from_be(*buf_port), + }) + } + } +} + +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] +pub struct Peer { + pub endpoint: Option<Endpoint>, + pub psk: Option<Key>, + pub keepalive: u32, + pub ipv4: Vec<Ipv4Net>, + pub ipv6: Vec<Ipv6Net>, +} + +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] +pub struct Config { + pub peers: HashMap<Key, Peer>, +} + +impl Default for Config { + #[inline] + fn default() -> Config { + Config { + peers: HashMap::new(), + } + } +} |