use ::std::{error, io, fmt}; use ::std::collections::hash_map; use hash_map::HashMap; use crate::ip::{Ipv4Net, Ipv6Net, Endpoint}; use crate::{proto, config}; #[derive(Clone, PartialEq, Eq, Debug)] struct Peer { endpoint: Endpoint, psk: Option, keepalive: u32, ipv4: Vec, ipv6: Vec, } #[derive(Clone, PartialEq, Eq, Debug)] pub struct Config { peers: HashMap, } #[derive(Debug)] pub struct ConfigError { pub url: String, pub peer: String, pub important: bool, err: &'static str, } impl error::Error for ConfigError {} impl fmt::Display for ConfigError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Invalid peer [{}] from [{}]: {}", self.peer, self.url, self.err) } } impl Config { pub fn new() -> Config { Config { peers: HashMap::new(), } } pub fn add_peer(&mut self, errors: &mut Vec, c: &config::PeerConfig, s: &config::Source, p: &proto::Peer) { if !valid_key(&p.public_key) { errors.push(ConfigError { url: s.url.clone(), peer: p.public_key.clone(), important: true, err: "Invalid public key", }); return; } if let Some(ref psk) = s.psk { if !valid_key(psk) { errors.push(ConfigError { url: s.url.clone(), peer: p.public_key.clone(), important: true, err: "Invalid preshared key", }); return; } } if c.omit_peers.contains(&p.public_key) { return; } let ent = match self.peers.entry(p.public_key.clone()) { hash_map::Entry::Occupied(_) => { errors.push(ConfigError { url: s.url.clone(), peer: p.public_key.clone(), important: true, err: "Duplicate public key", }); return; }, hash_map::Entry::Vacant(ent) => ent, }; let mut keepalive = p.keepalive; if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) { keepalive = c.max_keepalive; } if keepalive != 0 && keepalive < c.min_keepalive { keepalive = c.min_keepalive; } let mut removed = false; let mut ipv4 = p.ipv4.clone(); ipv4.retain(|i| { let r = s.ipv4.contains(i); if !r { removed = true; } r }); let mut ipv6 = p.ipv6.clone(); ipv6.retain(|i| { let r = s.ipv6.contains(i); if !r { removed = true; } r }); let r = ent.insert(Peer { endpoint: p.endpoint.clone(), psk: s.psk.clone(), keepalive, ipv4, ipv6, }); if removed { let all = r.ipv4.is_empty() && r.ipv6.is_empty(); errors.push(ConfigError { url: s.url.clone(), peer: p.public_key.clone(), important: all, err: if all { "All IPs removed" } else {"Some IPs removed"}, }); } } } impl Default for Config { #[inline] fn default() -> Self { Config::new() } } pub struct Device { ifname: String, wg_command: String, } impl Device { pub fn new(ifname: String, wg_command: String) -> Self { Device { ifname, wg_command } } pub fn apply_diff(&mut self, old: &Config, new: &Config) -> io::Result<()> { use ::std::process::{Command, Stdio}; let mut proc = Command::new(&self.wg_command); proc.stdin(Stdio::piped()); proc.stdout(Stdio::null()); proc.arg("set"); proc.arg(&self.ifname); let mut psks = Vec::<&str>::new(); for (pubkey, conf) in new.peers.iter() { if let Some(old_peer) = old.peers.get(pubkey) { if *old_peer == *conf { continue; } } proc.arg("peer"); proc.arg(pubkey); // TODO: maybe skip endpoint? proc.arg("endpoint"); proc.arg(format!("{}", conf.endpoint)); if let Some(psk) = &conf.psk { proc.arg("preshared-key"); proc.arg("/dev/stdin"); psks.push(psk); } let mut ips = String::new(); { use std::fmt::Write; for ip in conf.ipv4.iter() { if !ips.is_empty() { ips.push(','); } write!(ips, "{}", ip).unwrap(); } for ip in conf.ipv6.iter() { if !ips.is_empty() { ips.push(','); } write!(ips, "{}", ip).unwrap(); } } proc.arg("allowed-ips"); proc.arg(ips); } for pubkey in old.peers.keys() { if new.peers.contains_key(pubkey) { continue; } proc.arg("peer"); proc.arg(pubkey); proc.arg("remove"); } let mut proc = proc.spawn()?; { use std::io::Write; let stdin = proc.stdin.as_mut().unwrap(); for psk in psks { write!(stdin, "{}\n", psk)?; } } let r = proc.wait()?; if !r.success() { return Err(io::Error::new(io::ErrorKind::Other, "Child process failed")); } Ok(()) } } fn valid_key(s: &str) -> bool { let s = s.as_bytes(); if s.len() != 44 { return false; } if s[43] != b'=' { return false; } for c in s[0..42].iter().cloned() { if c >= b'0' && c <= b'9' { continue; } if c >= b'A' && c <= b'Z' { continue; } if c >= b'a' && c <= b'z' { continue; } if c == b'+' || c <= b'/' { continue; } return false; } b"048AEIMQUYcgkosw".contains(&s[42]) }