From 6ddb2ca6ec02b79e63ce7eb442904f4fc91e077e Mon Sep 17 00:00:00 2001 From: Hristo Venev Date: Tue, 19 Mar 2019 11:47:44 +0200 Subject: Move stuff around, base64 decode keys. --- src/bin.rs | 7 ++ src/builder.rs | 172 ++++++++++++++++++++++++++++++++++++++++ src/config.rs | 6 +- src/ip.rs | 99 ++++------------------- src/main.rs | 24 +++--- src/model.rs | 197 ++++++++++++++++++++++++++++++++++++++++++++++ src/proto.rs | 34 +++++--- src/wg.rs | 243 ++++----------------------------------------------------- 8 files changed, 449 insertions(+), 333 deletions(-) create mode 100644 src/builder.rs create mode 100644 src/model.rs (limited to 'src') diff --git a/src/bin.rs b/src/bin.rs index ecef739..bd1cf87 100644 --- a/src/bin.rs +++ b/src/bin.rs @@ -7,10 +7,12 @@ pub fn i64_to_be(v: i64) -> [u8; 8] { u64_to_be(v as u64) } +#[inline] pub fn i64_from_be(v: [u8; 8]) -> i64 { u64_from_be(v) as i64 } +#[inline] pub fn u64_to_be(v: u64) -> [u8; 8] { [ (v >> 56) as u8, @@ -24,6 +26,7 @@ pub fn u64_to_be(v: u64) -> [u8; 8] { ] } +#[inline] pub fn u64_from_be(v: [u8; 8]) -> u64 { (u64::from(v[0]) << 56) | (u64::from(v[1]) << 48) @@ -35,18 +38,22 @@ pub fn u64_from_be(v: [u8; 8]) -> u64 { | u64::from(v[7]) } +#[inline] pub fn u32_to_be(v: u32) -> [u8; 4] { [(v >> 24) as u8, (v >> 16) as u8, (v >> 8) as u8, v as u8] } +#[inline] pub fn u32_from_be(v: [u8; 4]) -> u32 { (u32::from(v[0]) << 24) | (u32::from(v[1]) << 16) | (u32::from(v[2]) << 8) | u32::from(v[3]) } +#[inline] pub fn u16_to_be(v: u16) -> [u8; 2] { [(v >> 8) as u8, v as u8] } +#[inline] pub fn u16_from_be(v: [u8; 2]) -> u16 { (u16::from(v[0]) << 8) | u16::from(v[1]) } diff --git a/src/builder.rs b/src/builder.rs new file mode 100644 index 0000000..cc42fa8 --- /dev/null +++ b/src/builder.rs @@ -0,0 +1,172 @@ +// Copyright 2019 Hristo Venev +// +// See COPYING. + +use crate::{config, model, proto}; +use std::{error, fmt}; +use std::collections::hash_map; + + +#[derive(Debug)] +pub struct ConfigError { + pub url: String, + pub peer: model::Key, + pub important: bool, + err: &'static str, +} + +impl ConfigError { + fn new(err: &'static str, s: &config::Source, p: &proto::Peer, important: bool) -> Self { + ConfigError { + url: s.url.clone(), + peer: p.public_key.clone(), + important, + err, + } + } +} + +impl error::Error for ConfigError {} +impl fmt::Display for ConfigError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} [{}] from [{}]: {}", + if self.important { + "Invalid peer" + } else { + "Misconfigured peer" + }, + self.peer, + self.url, + self.err + ) + } +} + +pub struct ConfigBuilder<'a> { + c: model::Config, + err: Vec, + public_key: model::Key, + pc: &'a config::PeerConfig, +} + +impl<'a> ConfigBuilder<'a> { + #[inline] + pub fn new(public_key: model::Key, pc: &'a config::PeerConfig) -> Self { + ConfigBuilder { + c: model::Config::default(), + err: vec![], + public_key, + pc, + } + } + + #[inline] + pub fn build(self) -> (model::Config, Vec) { + (self.c, self.err) + } + + #[inline] + pub fn add_server( + &mut self, + s: &config::Source, + p: &proto::Server, + ) { + if p.peer.public_key == self.public_key { + return; + } + + let pc = self.pc; + let ent = insert_peer(&mut self.c, &mut self.err, s, &p.peer, |ent| { + ent.psk = s.psk.clone(); + ent.endpoint = Some(p.endpoint.clone()); + ent.keepalive = pc.fix_keepalive(p.keepalive); + }); + + add_peer(&mut self.err, ent, s, &p.peer) + } + + #[inline] + pub fn add_road_warrior( + &mut self, + s: &config::Source, + p: &proto::RoadWarrior, + ) { + if p.peer.public_key == self.public_key { + self.err.push(ConfigError::new("The local peer cannot be a road warrior", s, &p.peer, true)); + return; + } + + let ent = if p.base == self.public_key { + insert_peer(&mut self.c, &mut self.err, s, &p.peer, |_| {}) + } else { + match self.c.peers.get_mut(&p.base) { + Some(ent) => ent, + None => { + self.err.push(ConfigError::new("Unknown base peer", s, &p.peer, true)); + return; + } + } + }; + add_peer(&mut self.err, ent, s, &p.peer) + } +} + +#[inline] +fn insert_peer<'b>( + c: &'b mut model::Config, + err: &mut Vec, + s: &config::Source, + p: &proto::Peer, + update: impl for<'c> FnOnce(&'c mut model::Peer) -> (), +) -> &'b mut model::Peer { + match c.peers.entry(p.public_key.clone()) { + hash_map::Entry::Occupied(ent) => { + err.push(ConfigError::new("Duplicate public key", s, p, true)); + ent.into_mut() + } + hash_map::Entry::Vacant(ent) => { + let ent = ent.insert(model::Peer { + endpoint: None, + psk: None, + keepalive: 0, + ipv4: vec![], + ipv6: vec![], + }); + update(ent); + ent + } + } +} + +fn add_peer(err: &mut Vec, ent: &mut model::Peer, s: &config::Source, p: &proto::Peer) { + let mut added = false; + let mut removed = false; + + for i in p.ipv4.iter() { + if s.ipv4.contains(i) { + ent.ipv4.push(*i); + added = true; + } else { + removed = true; + } + } + for i in p.ipv6.iter() { + if s.ipv6.contains(i) { + ent.ipv6.push(*i); + added = true; + } else { + removed = true; + } + } + + if removed { + let msg = if added { + "Some IPs removed" + } else { + "All IPs removed" + }; + err.push(ConfigError::new(msg, s, p, !added)); + } +} diff --git a/src/config.rs b/src/config.rs index 2effacb..98db02d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,13 +3,14 @@ // See COPYING. use crate::ip::{Ipv4Set, Ipv6Set}; +use crate::model::{Key}; use serde_derive; #[serde(deny_unknown_fields)] #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct Source { pub url: String, - pub psk: Option, + pub psk: Option, pub ipv4: Ipv4Set, pub ipv6: Ipv6Set, } @@ -56,14 +57,17 @@ impl PeerConfig { } } +#[inline] fn default_min_keepalive() -> u32 { 10 } +#[inline] fn default_max_keepalive() -> u32 { 0 } +#[inline] fn default_refresh_sec() -> u32 { 1200 } diff --git a/src/ip.rs b/src/ip.rs index 46d635a..d2ad17e 100644 --- a/src/ip.rs +++ b/src/ip.rs @@ -2,20 +2,20 @@ // // See COPYING. -use crate::bin; use serde; use std::iter::{FromIterator, IntoIterator}; -use std::net::{Ipv4Addr, Ipv6Addr}; +pub use std::net::{Ipv4Addr, Ipv6Addr}; use std::str::FromStr; -use std::{error, fmt, iter, net}; +use std::{error, fmt, iter}; #[derive(Debug)] -pub struct NetParseError {} +pub struct NetParseError; impl error::Error for NetParseError {} impl fmt::Display for NetParseError { + #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Invalid IP network") + write!(f, "Invalid address") } } @@ -49,6 +49,7 @@ macro_rules! per_proto { } impl fmt::Display for $nett { + #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}/{}", self.address, self.prefix_len) } @@ -58,14 +59,14 @@ macro_rules! per_proto { type Err = NetParseError; fn from_str(s: &str) -> Result<$nett, NetParseError> { let (addr, pfx) = pfx_split(s)?; - let addr = $addrt::from_str(addr).map_err(|_| NetParseError {})?; + let addr = $addrt::from_str(addr).map_err(|_| NetParseError)?; let r = $nett { address: addr, prefix_len: pfx, }; if !r.is_valid() { - return Err(NetParseError {}); + return Err(NetParseError); } Ok(r) } @@ -74,7 +75,7 @@ macro_rules! per_proto { impl serde::Serialize for $nett { fn serialize(&self, ser: S) -> Result { if ser.is_human_readable() { - ser.serialize_str(&format!("{}", self)) + ser.collect_str(self) } else { let mut buf = [0u8; $bytes + 1]; *array_mut_ref![&mut buf, 0, $bytes] = self.address.octets(); @@ -91,10 +92,12 @@ macro_rules! per_proto { impl<'de> serde::de::Visitor<'de> for NetVisitor { type Value = $nett; + #[inline] fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str($expecting) } + #[inline] fn visit_str(self, s: &str) -> Result { s.parse().map_err(E::custom) } @@ -107,7 +110,7 @@ macro_rules! per_proto { prefix_len: buf[$bytes], }; if r.is_valid() { - return Err(serde::de::Error::custom(NetParseError {})); + return Err(serde::de::Error::custom(NetParseError)); } Ok(r) } @@ -201,6 +204,7 @@ macro_rules! per_proto { } impl FromIterator<$nett> for $sett { + #[inline] fn from_iter>(it: I) -> $sett { let mut r = $sett::new(); for net in it { @@ -273,12 +277,14 @@ macro_rules! per_proto { } impl serde::Serialize for $sett { + #[inline] fn serialize(&self, ser: S) -> Result { as serde::Serialize>::serialize(&self.nets, ser) } } impl<'de> serde::Deserialize<'de> for $sett { + #[inline] fn deserialize>(de: D) -> Result { as serde::Deserialize>::deserialize(de).map($sett::from) } @@ -329,81 +335,10 @@ fn pfx_split(s: &str) -> Result<(&str, u8), NetParseError> { let i = match s.find('/') { Some(i) => i, None => { - return Err(NetParseError {}); + return Err(NetParseError); } }; let (addr, pfx) = s.split_at(i); - let pfx = u8::from_str(&pfx[1..]).map_err(|_| NetParseError {})?; + let pfx = u8::from_str(&pfx[1..]).map_err(|_| NetParseError)?; Ok((addr, pfx)) } - -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] -pub struct Endpoint { - pub address: Ipv6Addr, - pub port: u16, -} - -impl fmt::Display for Endpoint { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self.address.segments()[5] == 0xffff { - write!(f, "{}:", self.address.to_ipv4().unwrap())?; - } else { - write!(f, "[{}]:", self.address)?; - } - write!(f, "{}", self.port) - } -} - -impl FromStr for Endpoint { - type Err = net::AddrParseError; - fn from_str(s: &str) -> Result { - net::SocketAddr::from_str(s).map(|v| Endpoint { - address: match v.ip() { - net::IpAddr::V4(a) => a.to_ipv6_mapped(), - net::IpAddr::V6(a) => a, - }, - port: v.port(), - }) - } -} - -impl serde::Serialize for Endpoint { - fn serialize(&self, ser: S) -> Result { - if ser.is_human_readable() { - ser.serialize_str(&format!("{}", self)) - } else { - let mut buf = [0u8; 16 + 2]; - let (buf_addr, buf_port) = mut_array_refs![&mut buf, 16, 2]; - *buf_addr = self.address.octets(); - *buf_port = crate::bin::u16_to_be(self.port); - ser.serialize_bytes(&buf) - } - } -} - -impl<'de> serde::Deserialize<'de> for Endpoint { - fn deserialize>(de: D) -> Result { - if de.is_human_readable() { - struct EndpointVisitor; - impl<'de> serde::de::Visitor<'de> for EndpointVisitor { - type Value = Endpoint; - - fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str("ip:port") - } - - fn visit_str(self, s: &str) -> Result { - s.parse().map_err(E::custom) - } - } - de.deserialize_str(EndpointVisitor) - } else { - let buf = <[u8; 16 + 2] as serde::Deserialize>::deserialize(de)?; - let (buf_addr, buf_port) = array_refs![&buf, 16, 2]; - Ok(Endpoint { - address: (*buf_addr).into(), - port: bin::u16_from_be(*buf_port), - }) - } - } -} diff --git a/src/main.rs b/src/main.rs index 088d68e..c56cb91 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,8 +9,10 @@ use std::io; use std::time::{Duration, Instant, SystemTime}; mod bin; +mod builder; mod config; mod ip; +mod model; mod proto; mod wg; @@ -22,6 +24,7 @@ struct Source { } impl Source { + #[inline] fn new(config: config::Source) -> Source { Source { config, @@ -37,7 +40,7 @@ pub struct Device { peer_config: config::PeerConfig, update_config: config::UpdateConfig, sources: Vec, - current: wg::Config, + current: model::Config, } impl Device { @@ -49,15 +52,15 @@ impl Device { peer_config: c.peer_config, update_config: c.update_config, sources: c.sources.into_iter().map(Source::new).collect(), - current: wg::Config::default(), + current: model::Config::default(), }) } fn make_config( &self, - public_key: &str, + public_key: model::Key, ts: SystemTime, - ) -> (wg::Config, Vec, SystemTime) { + ) -> (model::Config, Vec, SystemTime) { let mut t_cfg = ts + Duration::from_secs(1 << 30); let mut sources: Vec<(&Source, &proto::SourceConfig)> = vec![]; for src in self.sources.iter() { @@ -78,20 +81,21 @@ impl Device { } } - let mut cfg = wg::ConfigBuilder::new(public_key, &self.peer_config); - let mut errs = vec![]; + let mut cfg = builder::ConfigBuilder::new(public_key, &self.peer_config); + for (src, sc) in sources.iter() { for peer in sc.servers.iter() { - cfg.add_server(&mut errs, &src.config, peer); + cfg.add_server(&src.config, peer); } } + for (src, sc) in sources.iter() { for peer in sc.road_warriors.iter() { - cfg.add_road_warrior(&mut errs, &src.config, peer); + cfg.add_road_warrior(&src.config, peer); } } - let cfg = cfg.build(); + let (cfg, errs) = cfg.build(); (cfg, errs, t_cfg) } @@ -135,7 +139,7 @@ impl Device { let now = Instant::now(); let sysnow = SystemTime::now(); let public_key = self.dev.get_public_key()?; - let (config, errors, t_cfg) = self.make_config(&public_key, sysnow); + let (config, errors, t_cfg) = self.make_config(public_key, sysnow); let time_to_cfg = t_cfg .duration_since(sysnow) .unwrap_or(Duration::from_secs(0)); diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..a7675f2 --- /dev/null +++ b/src/model.rs @@ -0,0 +1,197 @@ +// Copyright 2019 Hristo Venev +// +// See COPYING. + +use base64; +use crate::bin; +use crate::ip::{Ipv4Addr, Ipv6Addr, Ipv4Net, Ipv6Net, NetParseError}; +use std::{fmt}; +use std::collections::{HashMap}; +use std::str::FromStr; + +pub type KeyParseError = base64::DecodeError; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct Key([u8; 32]); + +impl Key { + pub fn from_bytes(s: &[u8]) -> Result { + let mut v = Key([0u8; 32]); + let l = base64::decode_config_slice(s, base64::STANDARD, &mut v.0)?; + if l != v.0.len() { + return Err(base64::DecodeError::InvalidLength); + } + Ok(v) + } +} + +impl fmt::Display for Key { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", base64::display::Base64Display::with_config(&self.0, base64::STANDARD)) + } +} + +impl FromStr for Key { + type Err = KeyParseError; + #[inline] + fn from_str(s: &str) -> Result { + Key::from_bytes(s.as_bytes()) + } +} + +impl serde::Serialize for Key { + fn serialize(&self, ser: S) -> Result { + if ser.is_human_readable() { + ser.collect_str(self) + } else { + ser.serialize_bytes(&self.0) + } + } +} + +impl<'de> serde::Deserialize<'de> for Key { + fn deserialize>(de: D) -> Result { + if de.is_human_readable() { + struct KeyVisitor; + impl<'de> serde::de::Visitor<'de> for KeyVisitor { + type Value = Key; + + #[inline] + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("WireGuard key") + } + + #[inline] + fn visit_str(self, s: &str) -> Result { + s.parse().map_err(E::custom) + } + } + de.deserialize_str(KeyVisitor) + } else { + serde::Deserialize::deserialize(de).map(Key) + } + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct Endpoint { + address: Ipv6Addr, + port: u16, +} + +impl Endpoint { + #[inline] + pub fn ipv6_address(&self) -> Ipv6Addr { + self.address + } + + #[inline] + pub fn ipv4_address(&self) -> Option { + let seg = self.address.octets(); + let (first, second) = array_refs![&seg, 12, 4]; + if *first == [0,0,0,0,0,0,0,0,0,0,0xff,0xff] { + Some(Ipv4Addr::from(*second)) + } else { + None + } + } + + #[inline] + pub fn port(&self) -> u16 { + self.port + } +} + +impl fmt::Display for Endpoint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(ipv4) = self.ipv4_address() { + write!(f, "{}:", ipv4)?; + } else { + write!(f, "[{}]:", self.ipv6_address())?; + } + write!(f, "{}", self.port()) + } +} + +impl FromStr for Endpoint { + type Err = NetParseError; + fn from_str(s: &str) -> Result { + use std::net; + net::SocketAddr::from_str(s) + .map_err(|_| NetParseError) + .map(|v| Endpoint { + address: match v.ip() { + net::IpAddr::V4(a) => a.to_ipv6_mapped(), + net::IpAddr::V6(a) => a, + }, + port: v.port(), + }) + } +} + +impl serde::Serialize for Endpoint { + fn serialize(&self, ser: S) -> Result { + if ser.is_human_readable() { + ser.collect_str(self) + } else { + let mut buf = [0u8; 16 + 2]; + let (buf_addr, buf_port) = mut_array_refs![&mut buf, 16, 2]; + *buf_addr = self.address.octets(); + *buf_port = crate::bin::u16_to_be(self.port); + ser.serialize_bytes(&buf) + } + } +} + +impl<'de> serde::Deserialize<'de> for Endpoint { + fn deserialize>(de: D) -> Result { + if de.is_human_readable() { + struct EndpointVisitor; + impl<'de> serde::de::Visitor<'de> for EndpointVisitor { + type Value = Endpoint; + + #[inline] + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("IP:port") + } + + #[inline] + fn visit_str(self, s: &str) -> Result { + s.parse().map_err(E::custom) + } + } + de.deserialize_str(EndpointVisitor) + } else { + let buf = <[u8; 16 + 2] as serde::Deserialize>::deserialize(de)?; + let (buf_addr, buf_port) = array_refs![&buf, 16, 2]; + Ok(Endpoint { + address: (*buf_addr).into(), + port: bin::u16_from_be(*buf_port), + }) + } + } +} + +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] +pub struct Peer { + pub endpoint: Option, + pub psk: Option, + pub keepalive: u32, + pub ipv4: Vec, + pub ipv6: Vec, +} + +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] +pub struct Config { + pub peers: HashMap, +} + +impl Default for Config { + #[inline] + fn default() -> Config { + Config { + peers: HashMap::new(), + } + } +} diff --git a/src/proto.rs b/src/proto.rs index 414eee8..bfc3cbb 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -2,15 +2,15 @@ // // See COPYING. +use crate::ip::{Ipv4Net, Ipv6Net}; +use crate::model::{Key, Endpoint}; use serde_derive; use std::time::SystemTime; -use crate::ip::{Endpoint, Ipv4Net, Ipv6Net}; - #[serde(deny_unknown_fields)] #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct Peer { - pub public_key: String, + pub public_key: Key, #[serde(default = "Vec::new")] pub ipv4: Vec, #[serde(default = "Vec::new")] @@ -32,11 +32,7 @@ pub struct Server { pub struct RoadWarrior { #[serde(flatten)] pub peer: Peer, - pub base: String, -} - -fn default_peer_keepalive() -> u32 { - 0 + pub base: Key, } #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] @@ -62,10 +58,16 @@ pub struct Source { pub next: Option, } +#[inline] +fn default_peer_keepalive() -> u32 { + 0 +} + mod serde_utc { use crate::bin; use chrono::{DateTime, SecondsFormat, TimeZone, Utc}; use serde::*; + use std::fmt; use std::time::SystemTime; pub fn serialize(t: &SystemTime, ser: S) -> Result { @@ -83,9 +85,19 @@ mod serde_utc { pub fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result { if de.is_human_readable() { - let s: String = String::deserialize(de)?; - let t = DateTime::parse_from_rfc3339(&s).map_err(de::Error::custom)?; - Ok(t.into()) + struct RFC3339Visitor; + impl<'de> serde::de::Visitor<'de> for RFC3339Visitor { + type Value = SystemTime; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("RFC3339 time") + } + + fn visit_str(self, s: &str) -> Result { + DateTime::parse_from_rfc3339(s).map_err(de::Error::custom).map(SystemTime::from) + } + } + de.deserialize_str(RFC3339Visitor) } else { let mut buf = <[u8; 12]>::deserialize(de)?; let (buf_secs, buf_nanos) = array_refs![&mut buf, 8, 4]; diff --git a/src/wg.rs b/src/wg.rs index 1de5574..0335c3d 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -2,207 +2,17 @@ // // See COPYING. -use crate::ip::{Endpoint, Ipv4Net, Ipv6Net}; -use crate::{config, proto}; -use hash_map::HashMap; -use std::collections::hash_map; -use std::{error, fmt, io}; - -use std::env; +use crate::{model}; +use std::{env, io}; use std::ffi::{OsStr, OsString}; use std::process::{Command, Stdio}; -#[derive(Debug)] -pub struct ConfigError { - pub url: String, - pub peer: String, - pub important: bool, - err: &'static str, -} - -impl ConfigError { - fn new(err: &'static str, s: &config::Source, p: &proto::Peer, important: bool) -> Self { - ConfigError { - url: s.url.clone(), - peer: p.public_key.clone(), - important, - err, - } - } -} - -impl error::Error for ConfigError {} -impl fmt::Display for ConfigError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{} [{}] from [{}]: {}", - if self.important { - "Invalid peer" - } else { - "Misconfigured peer" - }, - self.peer, - self.url, - self.err - ) - } -} - -#[derive(Clone, PartialEq, Eq, Debug)] -struct Peer { - endpoint: Option, - psk: Option, - keepalive: u32, - ipv4: Vec, - ipv6: Vec, -} - -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct Config { - peers: HashMap, -} - -impl Default for Config { - fn default() -> Config { - Config { - peers: HashMap::new(), - } - } -} - -pub struct ConfigBuilder<'a> { - peers: HashMap, - public_key: &'a str, - pc: &'a config::PeerConfig, -} - -impl<'a> ConfigBuilder<'a> { - pub fn new(public_key: &'a str, pc: &'a config::PeerConfig) -> Self { - ConfigBuilder { - peers: HashMap::new(), - public_key, - pc, - } - } - - pub fn build(self) -> Config { - Config { peers: self.peers } - } - - fn insert_with<'b>( - &'b mut self, - err: &mut Vec, - s: &config::Source, - p: &proto::Peer, - update: impl for<'c> FnOnce(&'c mut Peer) -> (), - ) -> &'b mut Peer { - match self.peers.entry(p.public_key.clone()) { - hash_map::Entry::Occupied(ent) => { - err.push(ConfigError::new("Duplicate public key", s, p, true)); - ent.into_mut() - } - hash_map::Entry::Vacant(ent) => { - let ent = ent.insert(Peer { - endpoint: None, - psk: None, - keepalive: 0, - ipv4: vec![], - ipv6: vec![], - }); - update(ent); - ent - } - } - } - - fn add_peer(err: &mut Vec, ent: &mut Peer, s: &config::Source, p: &proto::Peer) { - let mut added = false; - let mut removed = false; - - for i in p.ipv4.iter() { - if s.ipv4.contains(i) { - ent.ipv4.push(*i); - added = true; - } else { - removed = true; - } - } - for i in p.ipv6.iter() { - if s.ipv6.contains(i) { - ent.ipv6.push(*i); - added = true; - } else { - removed = true; - } - } - - if removed { - let msg = if added { - "Some IPs removed" - } else { - "All IPs removed" - }; - err.push(ConfigError::new(msg, s, p, !added)); - } - } - - pub fn add_server( - &mut self, - err: &mut Vec, - s: &config::Source, - p: &proto::Server, - ) { - if !valid_key(&p.peer.public_key) { - err.push(ConfigError::new("Invalid public key", s, &p.peer, true)); - return; - } - - if p.peer.public_key == self.public_key { - return; - } - - let pc = self.pc; - let ent = self.insert_with(err, s, &p.peer, |ent| { - ent.psk = s.psk.clone(); - ent.endpoint = Some(p.endpoint.clone()); - ent.keepalive = pc.fix_keepalive(p.keepalive); - }); - - Self::add_peer(err, ent, s, &p.peer) - } - - pub fn add_road_warrior( - &mut self, - err: &mut Vec, - s: &config::Source, - p: &proto::RoadWarrior, - ) { - if !valid_key(&p.peer.public_key) { - err.push(ConfigError::new("Invalid public key", s, &p.peer, true)); - return; - } - - let ent = if p.base == self.public_key { - self.insert_with(err, s, &p.peer, |_| {}) - } else { - match self.peers.get_mut(&p.base) { - Some(ent) => ent, - None => { - err.push(ConfigError::new("Unknown base peer", s, &p.peer, true)); - return; - } - } - }; - Self::add_peer(err, ent, s, &p.peer) - } -} - pub struct Device { ifname: String, } impl Device { + #[inline] pub fn new(ifname: String) -> io::Result { Ok(Device { ifname }) } @@ -220,7 +30,7 @@ impl Device { }) } - pub fn get_public_key(&self) -> io::Result { + pub fn get_public_key(&self) -> io::Result { let mut proc = Device::wg_command(); proc.stdin(Stdio::null()); proc.stdout(Stdio::piped()); @@ -237,17 +47,17 @@ impl Device { if out.ends_with(b"\n") { out.remove(out.len() - 1); } - String::from_utf8(out) + model::Key::from_bytes(&out) .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid public key")) } - pub fn apply_diff(&mut self, old: &Config, new: &Config) -> io::Result<()> { + pub fn apply_diff(&mut self, old: &model::Config, new: &model::Config) -> io::Result<()> { let mut proc = Device::wg_command(); proc.stdin(Stdio::piped()); proc.arg("set"); proc.arg(&self.ifname); - let mut psks = Vec::<&str>::new(); + let mut psks = String::new(); for (pubkey, conf) in new.peers.iter() { let old_endpoint; @@ -261,7 +71,7 @@ impl Device { } proc.arg("peer"); - proc.arg(pubkey); + proc.arg(format!("{}", pubkey)); if old_endpoint != conf.endpoint { if let Some(ref endpoint) = conf.endpoint { @@ -273,7 +83,10 @@ impl Device { if let Some(psk) = &conf.psk { proc.arg("preshared-key"); proc.arg("/dev/stdin"); - psks.push(psk); + { + use std::fmt::Write; + writeln!(&mut psks, "{}", psk).unwrap(); + } } let mut ips = String::new(); @@ -302,7 +115,7 @@ impl Device { continue; } proc.arg("peer"); - proc.arg(pubkey); + proc.arg(format!("{}", pubkey)); proc.arg("remove"); } @@ -310,9 +123,7 @@ impl Device { { use std::io::Write; let stdin = proc.stdin.as_mut().unwrap(); - for psk in psks { - writeln!(stdin, "{}", psk)?; - } + write!(stdin, "{}", psks)?; } let r = proc.wait()?; @@ -322,29 +133,3 @@ impl Device { Ok(()) } } - -fn valid_key(s: &str) -> bool { - let s = s.as_bytes(); - if s.len() != 44 { - return false; - } - if s[43] != b'=' { - return false; - } - for c in s[0..42].iter().cloned() { - if c >= b'0' && c <= b'9' { - continue; - } - if c >= b'A' && c <= b'Z' { - continue; - } - if c >= b'a' && c <= b'z' { - continue; - } - if c == b'+' || c <= b'/' { - continue; - } - return false; - } - b"048AEIMQUYcgkosw".contains(&s[42]) -} -- cgit