diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/bin.rs | 7 | ||||
| -rw-r--r-- | src/builder.rs | 172 | ||||
| -rw-r--r-- | src/config.rs | 6 | ||||
| -rw-r--r-- | src/ip.rs | 99 | ||||
| -rw-r--r-- | src/main.rs | 24 | ||||
| -rw-r--r-- | src/model.rs | 197 | ||||
| -rw-r--r-- | src/proto.rs | 34 | ||||
| -rw-r--r-- | src/wg.rs | 243 | 
8 files changed, 449 insertions, 333 deletions
| @@ -7,10 +7,12 @@ pub fn i64_to_be(v: i64) -> [u8; 8] {      u64_to_be(v as u64)  } +#[inline]  pub fn i64_from_be(v: [u8; 8]) -> i64 {      u64_from_be(v) as i64  } +#[inline]  pub fn u64_to_be(v: u64) -> [u8; 8] {      [          (v >> 56) as u8, @@ -24,6 +26,7 @@ pub fn u64_to_be(v: u64) -> [u8; 8] {      ]  } +#[inline]  pub fn u64_from_be(v: [u8; 8]) -> u64 {      (u64::from(v[0]) << 56)          | (u64::from(v[1]) << 48) @@ -35,18 +38,22 @@ pub fn u64_from_be(v: [u8; 8]) -> u64 {          | u64::from(v[7])  } +#[inline]  pub fn u32_to_be(v: u32) -> [u8; 4] {      [(v >> 24) as u8, (v >> 16) as u8, (v >> 8) as u8, v as u8]  } +#[inline]  pub fn u32_from_be(v: [u8; 4]) -> u32 {      (u32::from(v[0]) << 24) | (u32::from(v[1]) << 16) | (u32::from(v[2]) << 8) | u32::from(v[3])  } +#[inline]  pub fn u16_to_be(v: u16) -> [u8; 2] {      [(v >> 8) as u8, v as u8]  } +#[inline]  pub fn u16_from_be(v: [u8; 2]) -> u16 {      (u16::from(v[0]) << 8) | u16::from(v[1])  } diff --git a/src/builder.rs b/src/builder.rs new file mode 100644 index 0000000..cc42fa8 --- /dev/null +++ b/src/builder.rs @@ -0,0 +1,172 @@ +// Copyright 2019 Hristo Venev +// +// See COPYING. + +use crate::{config, model, proto}; +use std::{error, fmt}; +use std::collections::hash_map; + + +#[derive(Debug)] +pub struct ConfigError { +    pub url: String, +    pub peer: model::Key, +    pub important: bool, +    err: &'static str, +} + +impl ConfigError { +    fn new(err: &'static str, s: &config::Source, p: &proto::Peer, important: bool) -> Self { +        ConfigError { +            url: s.url.clone(), +            peer: p.public_key.clone(), +            important, +            err, +        } +    } +} + +impl error::Error for ConfigError {} +impl fmt::Display for ConfigError { +    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +        write!( +            f, +            "{} [{}] from [{}]: {}", +            if self.important { +                "Invalid peer" +            } else { +                "Misconfigured peer" +            }, +            self.peer, +            self.url, +            self.err +        ) +    } +} + +pub struct ConfigBuilder<'a> { +    c: model::Config, +    err: Vec<ConfigError>, +    public_key: model::Key, +    pc: &'a config::PeerConfig, +} + +impl<'a> ConfigBuilder<'a> { +    #[inline] +    pub fn new(public_key: model::Key, pc: &'a config::PeerConfig) -> Self { +        ConfigBuilder { +            c: model::Config::default(), +            err: vec![], +            public_key, +            pc, +        } +    } + +    #[inline] +    pub fn build(self) -> (model::Config, Vec<ConfigError>) { +        (self.c, self.err) +    } + +    #[inline] +    pub fn add_server( +        &mut self, +        s: &config::Source, +        p: &proto::Server, +    ) { +        if p.peer.public_key == self.public_key { +            return; +        } + +        let pc = self.pc; +        let ent = insert_peer(&mut self.c, &mut self.err, s, &p.peer, |ent| { +            ent.psk = s.psk.clone(); +            ent.endpoint = Some(p.endpoint.clone()); +            ent.keepalive = pc.fix_keepalive(p.keepalive); +        }); + +        add_peer(&mut self.err, ent, s, &p.peer) +    } + +    #[inline] +    pub fn add_road_warrior( +        &mut self, +        s: &config::Source, +        p: &proto::RoadWarrior, +    ) { +        if p.peer.public_key == self.public_key { +            self.err.push(ConfigError::new("The local peer cannot be a road warrior", s, &p.peer, true)); +            return; +        } + +        let ent = if p.base == self.public_key { +            insert_peer(&mut self.c, &mut self.err, s, &p.peer, |_| {}) +        } else { +            match self.c.peers.get_mut(&p.base) { +                Some(ent) => ent, +                None => { +                    self.err.push(ConfigError::new("Unknown base peer", s, &p.peer, true)); +                    return; +                } +            } +        }; +        add_peer(&mut self.err, ent, s, &p.peer) +    } +} + +#[inline] +fn insert_peer<'b>( +    c: &'b mut model::Config, +    err: &mut Vec<ConfigError>, +    s: &config::Source, +    p: &proto::Peer, +    update: impl for<'c> FnOnce(&'c mut model::Peer) -> (), +) -> &'b mut model::Peer { +    match c.peers.entry(p.public_key.clone()) { +        hash_map::Entry::Occupied(ent) => { +            err.push(ConfigError::new("Duplicate public key", s, p, true)); +            ent.into_mut() +        } +        hash_map::Entry::Vacant(ent) => { +            let ent = ent.insert(model::Peer { +                endpoint: None, +                psk: None, +                keepalive: 0, +                ipv4: vec![], +                ipv6: vec![], +            }); +            update(ent); +            ent +        } +    } +} + +fn add_peer(err: &mut Vec<ConfigError>, ent: &mut model::Peer, s: &config::Source, p: &proto::Peer) { +    let mut added = false; +    let mut removed = false; + +    for i in p.ipv4.iter() { +        if s.ipv4.contains(i) { +            ent.ipv4.push(*i); +            added = true; +        } else { +            removed = true; +        } +    } +    for i in p.ipv6.iter() { +        if s.ipv6.contains(i) { +            ent.ipv6.push(*i); +            added = true; +        } else { +            removed = true; +        } +    } + +    if removed { +        let msg = if added { +            "Some IPs removed" +        } else { +            "All IPs removed" +        }; +        err.push(ConfigError::new(msg, s, p, !added)); +    } +} diff --git a/src/config.rs b/src/config.rs index 2effacb..98db02d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,13 +3,14 @@  // See COPYING.  use crate::ip::{Ipv4Set, Ipv6Set}; +use crate::model::{Key};  use serde_derive;  #[serde(deny_unknown_fields)]  #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]  pub struct Source {      pub url: String, -    pub psk: Option<String>, +    pub psk: Option<Key>,      pub ipv4: Ipv4Set,      pub ipv6: Ipv6Set,  } @@ -56,14 +57,17 @@ impl PeerConfig {      }  } +#[inline]  fn default_min_keepalive() -> u32 {      10  } +#[inline]  fn default_max_keepalive() -> u32 {      0  } +#[inline]  fn default_refresh_sec() -> u32 {      1200  } @@ -2,20 +2,20 @@  //  // See COPYING. -use crate::bin;  use serde;  use std::iter::{FromIterator, IntoIterator}; -use std::net::{Ipv4Addr, Ipv6Addr}; +pub use std::net::{Ipv4Addr, Ipv6Addr};  use std::str::FromStr; -use std::{error, fmt, iter, net}; +use std::{error, fmt, iter};  #[derive(Debug)] -pub struct NetParseError {} +pub struct NetParseError;  impl error::Error for NetParseError {}  impl fmt::Display for NetParseError { +    #[inline]      fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { -        write!(f, "Invalid IP network") +        write!(f, "Invalid address")      }  } @@ -49,6 +49,7 @@ macro_rules! per_proto {          }          impl fmt::Display for $nett { +            #[inline]              fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {                  write!(f, "{}/{}", self.address, self.prefix_len)              } @@ -58,14 +59,14 @@ macro_rules! per_proto {              type Err = NetParseError;              fn from_str(s: &str) -> Result<$nett, NetParseError> {                  let (addr, pfx) = pfx_split(s)?; -                let addr = $addrt::from_str(addr).map_err(|_| NetParseError {})?; +                let addr = $addrt::from_str(addr).map_err(|_| NetParseError)?;                  let r = $nett {                      address: addr,                      prefix_len: pfx,                  };                  if !r.is_valid() { -                    return Err(NetParseError {}); +                    return Err(NetParseError);                  }                  Ok(r)              } @@ -74,7 +75,7 @@ macro_rules! per_proto {          impl serde::Serialize for $nett {              fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {                  if ser.is_human_readable() { -                    ser.serialize_str(&format!("{}", self)) +                    ser.collect_str(self)                  } else {                      let mut buf = [0u8; $bytes + 1];                      *array_mut_ref![&mut buf, 0, $bytes] = self.address.octets(); @@ -91,10 +92,12 @@ macro_rules! per_proto {                      impl<'de> serde::de::Visitor<'de> for NetVisitor {                          type Value = $nett; +                        #[inline]                          fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {                              f.write_str($expecting)                          } +                        #[inline]                          fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> {                              s.parse().map_err(E::custom)                          } @@ -107,7 +110,7 @@ macro_rules! per_proto {                          prefix_len: buf[$bytes],                      };                      if r.is_valid() { -                        return Err(serde::de::Error::custom(NetParseError {})); +                        return Err(serde::de::Error::custom(NetParseError));                      }                      Ok(r)                  } @@ -201,6 +204,7 @@ macro_rules! per_proto {          }          impl FromIterator<$nett> for $sett { +            #[inline]              fn from_iter<I: IntoIterator<Item = $nett>>(it: I) -> $sett {                  let mut r = $sett::new();                  for net in it { @@ -273,12 +277,14 @@ macro_rules! per_proto {          }          impl serde::Serialize for $sett { +            #[inline]              fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {                  <Vec<$nett> as serde::Serialize>::serialize(&self.nets, ser)              }          }          impl<'de> serde::Deserialize<'de> for $sett { +            #[inline]              fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {                  <Vec<$nett> as serde::Deserialize>::deserialize(de).map($sett::from)              } @@ -329,81 +335,10 @@ fn pfx_split(s: &str) -> Result<(&str, u8), NetParseError> {      let i = match s.find('/') {          Some(i) => i,          None => { -            return Err(NetParseError {}); +            return Err(NetParseError);          }      };      let (addr, pfx) = s.split_at(i); -    let pfx = u8::from_str(&pfx[1..]).map_err(|_| NetParseError {})?; +    let pfx = u8::from_str(&pfx[1..]).map_err(|_| NetParseError)?;      Ok((addr, pfx))  } - -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] -pub struct Endpoint { -    pub address: Ipv6Addr, -    pub port: u16, -} - -impl fmt::Display for Endpoint { -    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { -        if self.address.segments()[5] == 0xffff { -            write!(f, "{}:", self.address.to_ipv4().unwrap())?; -        } else { -            write!(f, "[{}]:", self.address)?; -        } -        write!(f, "{}", self.port) -    } -} - -impl FromStr for Endpoint { -    type Err = net::AddrParseError; -    fn from_str(s: &str) -> Result<Endpoint, net::AddrParseError> { -        net::SocketAddr::from_str(s).map(|v| Endpoint { -            address: match v.ip() { -                net::IpAddr::V4(a) => a.to_ipv6_mapped(), -                net::IpAddr::V6(a) => a, -            }, -            port: v.port(), -        }) -    } -} - -impl serde::Serialize for Endpoint { -    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> { -        if ser.is_human_readable() { -            ser.serialize_str(&format!("{}", self)) -        } else { -            let mut buf = [0u8; 16 + 2]; -            let (buf_addr, buf_port) = mut_array_refs![&mut buf, 16, 2]; -            *buf_addr = self.address.octets(); -            *buf_port = crate::bin::u16_to_be(self.port); -            ser.serialize_bytes(&buf) -        } -    } -} - -impl<'de> serde::Deserialize<'de> for Endpoint { -    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> { -        if de.is_human_readable() { -            struct EndpointVisitor; -            impl<'de> serde::de::Visitor<'de> for EndpointVisitor { -                type Value = Endpoint; - -                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { -                    f.write_str("ip:port") -                } - -                fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> { -                    s.parse().map_err(E::custom) -                } -            } -            de.deserialize_str(EndpointVisitor) -        } else { -            let buf = <[u8; 16 + 2] as serde::Deserialize>::deserialize(de)?; -            let (buf_addr, buf_port) = array_refs![&buf, 16, 2]; -            Ok(Endpoint { -                address: (*buf_addr).into(), -                port: bin::u16_from_be(*buf_port), -            }) -        } -    } -} diff --git a/src/main.rs b/src/main.rs index 088d68e..c56cb91 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,8 +9,10 @@ use std::io;  use std::time::{Duration, Instant, SystemTime};  mod bin; +mod builder;  mod config;  mod ip; +mod model;  mod proto;  mod wg; @@ -22,6 +24,7 @@ struct Source {  }  impl Source { +    #[inline]      fn new(config: config::Source) -> Source {          Source {              config, @@ -37,7 +40,7 @@ pub struct Device {      peer_config: config::PeerConfig,      update_config: config::UpdateConfig,      sources: Vec<Source>, -    current: wg::Config, +    current: model::Config,  }  impl Device { @@ -49,15 +52,15 @@ impl Device {              peer_config: c.peer_config,              update_config: c.update_config,              sources: c.sources.into_iter().map(Source::new).collect(), -            current: wg::Config::default(), +            current: model::Config::default(),          })      }      fn make_config(          &self, -        public_key: &str, +        public_key: model::Key,          ts: SystemTime, -    ) -> (wg::Config, Vec<wg::ConfigError>, SystemTime) { +    ) -> (model::Config, Vec<builder::ConfigError>, SystemTime) {          let mut t_cfg = ts + Duration::from_secs(1 << 30);          let mut sources: Vec<(&Source, &proto::SourceConfig)> = vec![];          for src in self.sources.iter() { @@ -78,20 +81,21 @@ impl Device {              }          } -        let mut cfg = wg::ConfigBuilder::new(public_key, &self.peer_config); -        let mut errs = vec![]; +        let mut cfg = builder::ConfigBuilder::new(public_key, &self.peer_config); +          for (src, sc) in sources.iter() {              for peer in sc.servers.iter() { -                cfg.add_server(&mut errs, &src.config, peer); +                cfg.add_server(&src.config, peer);              }          } +          for (src, sc) in sources.iter() {              for peer in sc.road_warriors.iter() { -                cfg.add_road_warrior(&mut errs, &src.config, peer); +                cfg.add_road_warrior(&src.config, peer);              }          } -        let cfg = cfg.build(); +        let (cfg, errs) = cfg.build();          (cfg, errs, t_cfg)      } @@ -135,7 +139,7 @@ impl Device {          let now = Instant::now();          let sysnow = SystemTime::now();          let public_key = self.dev.get_public_key()?; -        let (config, errors, t_cfg) = self.make_config(&public_key, sysnow); +        let (config, errors, t_cfg) = self.make_config(public_key, sysnow);          let time_to_cfg = t_cfg              .duration_since(sysnow)              .unwrap_or(Duration::from_secs(0)); diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..a7675f2 --- /dev/null +++ b/src/model.rs @@ -0,0 +1,197 @@ +// Copyright 2019 Hristo Venev +// +// See COPYING. + +use base64; +use crate::bin; +use crate::ip::{Ipv4Addr, Ipv6Addr, Ipv4Net, Ipv6Net, NetParseError}; +use std::{fmt}; +use std::collections::{HashMap}; +use std::str::FromStr; + +pub type KeyParseError = base64::DecodeError; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct Key([u8; 32]); + +impl Key { +    pub fn from_bytes(s: &[u8]) -> Result<Key, KeyParseError> { +        let mut v = Key([0u8; 32]); +        let l = base64::decode_config_slice(s, base64::STANDARD, &mut v.0)?; +        if l != v.0.len() { +            return Err(base64::DecodeError::InvalidLength); +        } +        Ok(v) +    } +} + +impl fmt::Display for Key { +    #[inline] +    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +        write!(f, "{}", base64::display::Base64Display::with_config(&self.0, base64::STANDARD)) +    } +} + +impl FromStr for Key { +    type Err = KeyParseError; +    #[inline] +    fn from_str(s: &str) -> Result<Key, base64::DecodeError> { +        Key::from_bytes(s.as_bytes()) +    } +} + +impl serde::Serialize for Key { +    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> { +        if ser.is_human_readable() { +            ser.collect_str(self) +        } else { +            ser.serialize_bytes(&self.0) +        } +    } +} + +impl<'de> serde::Deserialize<'de> for Key { +    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> { +        if de.is_human_readable() { +            struct KeyVisitor; +            impl<'de> serde::de::Visitor<'de> for KeyVisitor { +                type Value = Key; + +                #[inline] +                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { +                    f.write_str("WireGuard key") +                } + +                #[inline] +                fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> { +                    s.parse().map_err(E::custom) +                } +            } +            de.deserialize_str(KeyVisitor) +        } else { +            serde::Deserialize::deserialize(de).map(Key) +        } +    } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct Endpoint { +    address: Ipv6Addr, +    port: u16, +} + +impl Endpoint { +    #[inline] +    pub fn ipv6_address(&self) -> Ipv6Addr { +        self.address +    } + +    #[inline] +    pub fn ipv4_address(&self) -> Option<Ipv4Addr> { +        let seg = self.address.octets(); +        let (first, second) = array_refs![&seg, 12, 4]; +        if *first == [0,0,0,0,0,0,0,0,0,0,0xff,0xff] { +            Some(Ipv4Addr::from(*second)) +        } else { +            None +        } +    } + +    #[inline] +    pub fn port(&self) -> u16 { +        self.port +    } +} + +impl fmt::Display for Endpoint { +    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +        if let Some(ipv4) = self.ipv4_address() { +            write!(f, "{}:", ipv4)?; +        } else { +            write!(f, "[{}]:", self.ipv6_address())?; +        } +        write!(f, "{}", self.port()) +    } +} + +impl FromStr for Endpoint { +    type Err = NetParseError; +    fn from_str(s: &str) -> Result<Endpoint, NetParseError> { +        use std::net; +        net::SocketAddr::from_str(s) +            .map_err(|_| NetParseError) +            .map(|v| Endpoint { +                address: match v.ip() { +                    net::IpAddr::V4(a) => a.to_ipv6_mapped(), +                    net::IpAddr::V6(a) => a, +                }, +                port: v.port(), +            }) +    } +} + +impl serde::Serialize for Endpoint { +    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> { +        if ser.is_human_readable() { +            ser.collect_str(self) +        } else { +            let mut buf = [0u8; 16 + 2]; +            let (buf_addr, buf_port) = mut_array_refs![&mut buf, 16, 2]; +            *buf_addr = self.address.octets(); +            *buf_port = crate::bin::u16_to_be(self.port); +            ser.serialize_bytes(&buf) +        } +    } +} + +impl<'de> serde::Deserialize<'de> for Endpoint { +    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> { +        if de.is_human_readable() { +            struct EndpointVisitor; +            impl<'de> serde::de::Visitor<'de> for EndpointVisitor { +                type Value = Endpoint; + +                #[inline] +                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { +                    f.write_str("IP:port") +                } + +                #[inline] +                fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> { +                    s.parse().map_err(E::custom) +                } +            } +            de.deserialize_str(EndpointVisitor) +        } else { +            let buf = <[u8; 16 + 2] as serde::Deserialize>::deserialize(de)?; +            let (buf_addr, buf_port) = array_refs![&buf, 16, 2]; +            Ok(Endpoint { +                address: (*buf_addr).into(), +                port: bin::u16_from_be(*buf_port), +            }) +        } +    } +} + +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] +pub struct Peer { +    pub endpoint: Option<Endpoint>, +    pub psk: Option<Key>, +    pub keepalive: u32, +    pub ipv4: Vec<Ipv4Net>, +    pub ipv6: Vec<Ipv6Net>, +} + +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] +pub struct Config { +    pub peers: HashMap<Key, Peer>, +} + +impl Default for Config { +    #[inline] +    fn default() -> Config { +        Config { +            peers: HashMap::new(), +        } +    } +} diff --git a/src/proto.rs b/src/proto.rs index 414eee8..bfc3cbb 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -2,15 +2,15 @@  //  // See COPYING. +use crate::ip::{Ipv4Net, Ipv6Net}; +use crate::model::{Key, Endpoint};  use serde_derive;  use std::time::SystemTime; -use crate::ip::{Endpoint, Ipv4Net, Ipv6Net}; -  #[serde(deny_unknown_fields)]  #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]  pub struct Peer { -    pub public_key: String, +    pub public_key: Key,      #[serde(default = "Vec::new")]      pub ipv4: Vec<Ipv4Net>,      #[serde(default = "Vec::new")] @@ -32,11 +32,7 @@ pub struct Server {  pub struct RoadWarrior {      #[serde(flatten)]      pub peer: Peer, -    pub base: String, -} - -fn default_peer_keepalive() -> u32 { -    0 +    pub base: Key,  }  #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] @@ -62,10 +58,16 @@ pub struct Source {      pub next: Option<SourceNextConfig>,  } +#[inline] +fn default_peer_keepalive() -> u32 { +    0 +} +  mod serde_utc {      use crate::bin;      use chrono::{DateTime, SecondsFormat, TimeZone, Utc};      use serde::*; +    use std::fmt;      use std::time::SystemTime;      pub fn serialize<S: Serializer>(t: &SystemTime, ser: S) -> Result<S::Ok, S::Error> { @@ -83,9 +85,19 @@ mod serde_utc {      pub fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result<SystemTime, D::Error> {          if de.is_human_readable() { -            let s: String = String::deserialize(de)?; -            let t = DateTime::parse_from_rfc3339(&s).map_err(de::Error::custom)?; -            Ok(t.into()) +            struct RFC3339Visitor; +            impl<'de> serde::de::Visitor<'de> for RFC3339Visitor { +                type Value = SystemTime; + +                fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { +                    f.write_str("RFC3339 time") +                } + +                fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> { +                    DateTime::parse_from_rfc3339(s).map_err(de::Error::custom).map(SystemTime::from) +                } +            } +            de.deserialize_str(RFC3339Visitor)          } else {              let mut buf = <[u8; 12]>::deserialize(de)?;              let (buf_secs, buf_nanos) = array_refs![&mut buf, 8, 4]; @@ -2,207 +2,17 @@  //  // See COPYING. -use crate::ip::{Endpoint, Ipv4Net, Ipv6Net}; -use crate::{config, proto}; -use hash_map::HashMap; -use std::collections::hash_map; -use std::{error, fmt, io}; - -use std::env; +use crate::{model}; +use std::{env, io};  use std::ffi::{OsStr, OsString};  use std::process::{Command, Stdio}; -#[derive(Debug)] -pub struct ConfigError { -    pub url: String, -    pub peer: String, -    pub important: bool, -    err: &'static str, -} - -impl ConfigError { -    fn new(err: &'static str, s: &config::Source, p: &proto::Peer, important: bool) -> Self { -        ConfigError { -            url: s.url.clone(), -            peer: p.public_key.clone(), -            important, -            err, -        } -    } -} - -impl error::Error for ConfigError {} -impl fmt::Display for ConfigError { -    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { -        write!( -            f, -            "{} [{}] from [{}]: {}", -            if self.important { -                "Invalid peer" -            } else { -                "Misconfigured peer" -            }, -            self.peer, -            self.url, -            self.err -        ) -    } -} - -#[derive(Clone, PartialEq, Eq, Debug)] -struct Peer { -    endpoint: Option<Endpoint>, -    psk: Option<String>, -    keepalive: u32, -    ipv4: Vec<Ipv4Net>, -    ipv6: Vec<Ipv6Net>, -} - -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct Config { -    peers: HashMap<String, Peer>, -} - -impl Default for Config { -    fn default() -> Config { -        Config { -            peers: HashMap::new(), -        } -    } -} - -pub struct ConfigBuilder<'a> { -    peers: HashMap<String, Peer>, -    public_key: &'a str, -    pc: &'a config::PeerConfig, -} - -impl<'a> ConfigBuilder<'a> { -    pub fn new(public_key: &'a str, pc: &'a config::PeerConfig) -> Self { -        ConfigBuilder { -            peers: HashMap::new(), -            public_key, -            pc, -        } -    } - -    pub fn build(self) -> Config { -        Config { peers: self.peers } -    } - -    fn insert_with<'b>( -        &'b mut self, -        err: &mut Vec<ConfigError>, -        s: &config::Source, -        p: &proto::Peer, -        update: impl for<'c> FnOnce(&'c mut Peer) -> (), -    ) -> &'b mut Peer { -        match self.peers.entry(p.public_key.clone()) { -            hash_map::Entry::Occupied(ent) => { -                err.push(ConfigError::new("Duplicate public key", s, p, true)); -                ent.into_mut() -            } -            hash_map::Entry::Vacant(ent) => { -                let ent = ent.insert(Peer { -                    endpoint: None, -                    psk: None, -                    keepalive: 0, -                    ipv4: vec![], -                    ipv6: vec![], -                }); -                update(ent); -                ent -            } -        } -    } - -    fn add_peer(err: &mut Vec<ConfigError>, ent: &mut Peer, s: &config::Source, p: &proto::Peer) { -        let mut added = false; -        let mut removed = false; - -        for i in p.ipv4.iter() { -            if s.ipv4.contains(i) { -                ent.ipv4.push(*i); -                added = true; -            } else { -                removed = true; -            } -        } -        for i in p.ipv6.iter() { -            if s.ipv6.contains(i) { -                ent.ipv6.push(*i); -                added = true; -            } else { -                removed = true; -            } -        } - -        if removed { -            let msg = if added { -                "Some IPs removed" -            } else { -                "All IPs removed" -            }; -            err.push(ConfigError::new(msg, s, p, !added)); -        } -    } - -    pub fn add_server( -        &mut self, -        err: &mut Vec<ConfigError>, -        s: &config::Source, -        p: &proto::Server, -    ) { -        if !valid_key(&p.peer.public_key) { -            err.push(ConfigError::new("Invalid public key", s, &p.peer, true)); -            return; -        } - -        if p.peer.public_key == self.public_key { -            return; -        } - -        let pc = self.pc; -        let ent = self.insert_with(err, s, &p.peer, |ent| { -            ent.psk = s.psk.clone(); -            ent.endpoint = Some(p.endpoint.clone()); -            ent.keepalive = pc.fix_keepalive(p.keepalive); -        }); - -        Self::add_peer(err, ent, s, &p.peer) -    } - -    pub fn add_road_warrior( -        &mut self, -        err: &mut Vec<ConfigError>, -        s: &config::Source, -        p: &proto::RoadWarrior, -    ) { -        if !valid_key(&p.peer.public_key) { -            err.push(ConfigError::new("Invalid public key", s, &p.peer, true)); -            return; -        } - -        let ent = if p.base == self.public_key { -            self.insert_with(err, s, &p.peer, |_| {}) -        } else { -            match self.peers.get_mut(&p.base) { -                Some(ent) => ent, -                None => { -                    err.push(ConfigError::new("Unknown base peer", s, &p.peer, true)); -                    return; -                } -            } -        }; -        Self::add_peer(err, ent, s, &p.peer) -    } -} -  pub struct Device {      ifname: String,  }  impl Device { +    #[inline]      pub fn new(ifname: String) -> io::Result<Self> {          Ok(Device { ifname })      } @@ -220,7 +30,7 @@ impl Device {          })      } -    pub fn get_public_key(&self) -> io::Result<String> { +    pub fn get_public_key(&self) -> io::Result<model::Key> {          let mut proc = Device::wg_command();          proc.stdin(Stdio::null());          proc.stdout(Stdio::piped()); @@ -237,17 +47,17 @@ impl Device {          if out.ends_with(b"\n") {              out.remove(out.len() - 1);          } -        String::from_utf8(out) +        model::Key::from_bytes(&out)              .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid public key"))      } -    pub fn apply_diff(&mut self, old: &Config, new: &Config) -> io::Result<()> { +    pub fn apply_diff(&mut self, old: &model::Config, new: &model::Config) -> io::Result<()> {          let mut proc = Device::wg_command();          proc.stdin(Stdio::piped());          proc.arg("set");          proc.arg(&self.ifname); -        let mut psks = Vec::<&str>::new(); +        let mut psks = String::new();          for (pubkey, conf) in new.peers.iter() {              let old_endpoint; @@ -261,7 +71,7 @@ impl Device {              }              proc.arg("peer"); -            proc.arg(pubkey); +            proc.arg(format!("{}", pubkey));              if old_endpoint != conf.endpoint {                  if let Some(ref endpoint) = conf.endpoint { @@ -273,7 +83,10 @@ impl Device {              if let Some(psk) = &conf.psk {                  proc.arg("preshared-key");                  proc.arg("/dev/stdin"); -                psks.push(psk); +                { +                    use std::fmt::Write; +                    writeln!(&mut psks, "{}", psk).unwrap(); +                }              }              let mut ips = String::new(); @@ -302,7 +115,7 @@ impl Device {                  continue;              }              proc.arg("peer"); -            proc.arg(pubkey); +            proc.arg(format!("{}", pubkey));              proc.arg("remove");          } @@ -310,9 +123,7 @@ impl Device {          {              use std::io::Write;              let stdin = proc.stdin.as_mut().unwrap(); -            for psk in psks { -                writeln!(stdin, "{}", psk)?; -            } +            write!(stdin, "{}", psks)?;          }          let r = proc.wait()?; @@ -322,29 +133,3 @@ impl Device {          Ok(())      }  } - -fn valid_key(s: &str) -> bool { -    let s = s.as_bytes(); -    if s.len() != 44 { -        return false; -    } -    if s[43] != b'=' { -        return false; -    } -    for c in s[0..42].iter().cloned() { -        if c >= b'0' && c <= b'9' { -            continue; -        } -        if c >= b'A' && c <= b'Z' { -            continue; -        } -        if c >= b'a' && c <= b'z' { -            continue; -        } -        if c == b'+' || c <= b'/' { -            continue; -        } -        return false; -    } -    b"048AEIMQUYcgkosw".contains(&s[42]) -} | 
