diff options
Diffstat (limited to 'src/wg.rs')
-rw-r--r-- | src/wg.rs | 135 |
1 files changed, 85 insertions, 50 deletions
@@ -1,8 +1,8 @@ -use ::std::{error, io, fmt}; -use ::std::collections::hash_map; +use crate::ip::{Endpoint, Ipv4Net, Ipv6Net}; +use crate::{config, proto}; use hash_map::HashMap; -use crate::ip::{Ipv4Net, Ipv6Net, Endpoint}; -use crate::{proto, config}; +use std::collections::hash_map; +use std::{error, fmt, io}; #[derive(Clone, PartialEq, Eq, Debug)] struct Peer { @@ -29,10 +29,19 @@ pub struct ConfigError { impl error::Error for ConfigError {} impl fmt::Display for ConfigError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Invalid peer [{}] from [{}]: {}", self.peer, self.url, self.err) + write!( + f, + "{} [{}] from [{}]: {}", + if self.important { + "Invalid peer" + } else { + "Misconfigured peer" + }, + self.peer, self.url, self.err + ) } } - + impl Config { pub fn new() -> Config { Config { @@ -40,7 +49,13 @@ impl Config { } } - pub fn add_peer(&mut self, errors: &mut Vec<ConfigError>, c: &config::PeerConfig, s: &config::Source, p: &proto::Peer) { + pub fn add_peer( + &mut self, + errors: &mut Vec<ConfigError>, + c: &config::PeerConfig, + s: &config::Source, + p: &proto::Peer, + ) { if !valid_key(&p.public_key) { errors.push(ConfigError { url: s.url.clone(), @@ -68,55 +83,64 @@ impl Config { } let ent = match self.peers.entry(p.public_key.clone()) { - hash_map::Entry::Occupied(_) => { + hash_map::Entry::Occupied(ent) => { errors.push(ConfigError { url: s.url.clone(), peer: p.public_key.clone(), important: true, err: "Duplicate public key", }); - return; + ent.into_mut() + } + hash_map::Entry::Vacant(ent) => { + let mut keepalive = p.keepalive; + if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) { + keepalive = c.max_keepalive; + } + if keepalive != 0 && keepalive < c.min_keepalive { + keepalive = c.min_keepalive; + } + + ent.insert(Peer { + endpoint: p.endpoint.clone(), + psk: s.psk.clone(), + keepalive, + ipv4: vec![], + ipv6: vec![], + }) }, - hash_map::Entry::Vacant(ent) => ent, }; - let mut keepalive = p.keepalive; - if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) { - keepalive = c.max_keepalive; - } - if keepalive != 0 && keepalive < c.min_keepalive { - keepalive = c.min_keepalive; - } - + let mut added = false; let mut removed = false; - let mut ipv4 = p.ipv4.clone(); - ipv4.retain(|i| { - let r = s.ipv4.contains(i); - if !r { removed = true; } - r - }); - - let mut ipv6 = p.ipv6.clone(); - ipv6.retain(|i| { - let r = s.ipv6.contains(i); - if !r { removed = true; } - r - }); - - let r = ent.insert(Peer { - endpoint: p.endpoint.clone(), - psk: s.psk.clone(), - keepalive, ipv4, ipv6, - }); + for i in p.ipv4.iter() { + if s.ipv4.contains(i) { + ent.ipv4.push(*i); + added = true; + } else { + removed = true; + } + } + for i in p.ipv6.iter() { + if s.ipv6.contains(i) { + ent.ipv6.push(*i); + added = true; + } else { + removed = true; + } + } if removed { - let all = r.ipv4.is_empty() && r.ipv6.is_empty(); errors.push(ConfigError { url: s.url.clone(), peer: p.public_key.clone(), - important: all, - err: if all { "All IPs removed" } else {"Some IPs removed"}, + important: !added, + err: if added { + "Some IPs removed" + } else { + "All IPs removed" + }, }); } } @@ -140,11 +164,10 @@ impl Device { } pub fn apply_diff(&mut self, old: &Config, new: &Config) -> io::Result<()> { - use ::std::process::{Command, Stdio}; + use std::process::{Command, Stdio}; let mut proc = Command::new(&self.wg_command); proc.stdin(Stdio::piped()); - proc.stdout(Stdio::null()); proc.arg("set"); proc.arg(&self.ifname); @@ -158,7 +181,7 @@ impl Device { } proc.arg("peer"); proc.arg(pubkey); - + // TODO: maybe skip endpoint? proc.arg("endpoint"); proc.arg(format!("{}", conf.endpoint)); @@ -173,11 +196,15 @@ impl Device { { use std::fmt::Write; for ip in conf.ipv4.iter() { - if !ips.is_empty() { ips.push(','); } + if !ips.is_empty() { + ips.push(','); + } write!(ips, "{}", ip).unwrap(); } for ip in conf.ipv6.iter() { - if !ips.is_empty() { ips.push(','); } + if !ips.is_empty() { + ips.push(','); + } write!(ips, "{}", ip).unwrap(); } } @@ -200,7 +227,7 @@ impl Device { use std::io::Write; let stdin = proc.stdin.as_mut().unwrap(); for psk in psks { - write!(stdin, "{}\n", psk)?; + writeln!(stdin, "{}", psk)?; } } @@ -221,10 +248,18 @@ fn valid_key(s: &str) -> bool { return false; } for c in s[0..42].iter().cloned() { - if c >= b'0' && c <= b'9' { continue; } - if c >= b'A' && c <= b'Z' { continue; } - if c >= b'a' && c <= b'z' { continue; } - if c == b'+' || c <= b'/' { continue; } + if c >= b'0' && c <= b'9' { + continue; + } + if c >= b'A' && c <= b'Z' { + continue; + } + if c >= b'a' && c <= b'z' { + continue; + } + if c == b'+' || c <= b'/' { + continue; + } return false; } b"048AEIMQUYcgkosw".contains(&s[42]) |