diff options
Diffstat (limited to 'src/wg.rs')
-rw-r--r-- | src/wg.rs | 243 |
1 files changed, 14 insertions, 229 deletions
@@ -2,207 +2,17 @@ // // See COPYING. -use crate::ip::{Endpoint, Ipv4Net, Ipv6Net}; -use crate::{config, proto}; -use hash_map::HashMap; -use std::collections::hash_map; -use std::{error, fmt, io}; - -use std::env; +use crate::{model}; +use std::{env, io}; use std::ffi::{OsStr, OsString}; use std::process::{Command, Stdio}; -#[derive(Debug)] -pub struct ConfigError { - pub url: String, - pub peer: String, - pub important: bool, - err: &'static str, -} - -impl ConfigError { - fn new(err: &'static str, s: &config::Source, p: &proto::Peer, important: bool) -> Self { - ConfigError { - url: s.url.clone(), - peer: p.public_key.clone(), - important, - err, - } - } -} - -impl error::Error for ConfigError {} -impl fmt::Display for ConfigError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{} [{}] from [{}]: {}", - if self.important { - "Invalid peer" - } else { - "Misconfigured peer" - }, - self.peer, - self.url, - self.err - ) - } -} - -#[derive(Clone, PartialEq, Eq, Debug)] -struct Peer { - endpoint: Option<Endpoint>, - psk: Option<String>, - keepalive: u32, - ipv4: Vec<Ipv4Net>, - ipv6: Vec<Ipv6Net>, -} - -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct Config { - peers: HashMap<String, Peer>, -} - -impl Default for Config { - fn default() -> Config { - Config { - peers: HashMap::new(), - } - } -} - -pub struct ConfigBuilder<'a> { - peers: HashMap<String, Peer>, - public_key: &'a str, - pc: &'a config::PeerConfig, -} - -impl<'a> ConfigBuilder<'a> { - pub fn new(public_key: &'a str, pc: &'a config::PeerConfig) -> Self { - ConfigBuilder { - peers: HashMap::new(), - public_key, - pc, - } - } - - pub fn build(self) -> Config { - Config { peers: self.peers } - } - - fn insert_with<'b>( - &'b mut self, - err: &mut Vec<ConfigError>, - s: &config::Source, - p: &proto::Peer, - update: impl for<'c> FnOnce(&'c mut Peer) -> (), - ) -> &'b mut Peer { - match self.peers.entry(p.public_key.clone()) { - hash_map::Entry::Occupied(ent) => { - err.push(ConfigError::new("Duplicate public key", s, p, true)); - ent.into_mut() - } - hash_map::Entry::Vacant(ent) => { - let ent = ent.insert(Peer { - endpoint: None, - psk: None, - keepalive: 0, - ipv4: vec![], - ipv6: vec![], - }); - update(ent); - ent - } - } - } - - fn add_peer(err: &mut Vec<ConfigError>, ent: &mut Peer, s: &config::Source, p: &proto::Peer) { - let mut added = false; - let mut removed = false; - - for i in p.ipv4.iter() { - if s.ipv4.contains(i) { - ent.ipv4.push(*i); - added = true; - } else { - removed = true; - } - } - for i in p.ipv6.iter() { - if s.ipv6.contains(i) { - ent.ipv6.push(*i); - added = true; - } else { - removed = true; - } - } - - if removed { - let msg = if added { - "Some IPs removed" - } else { - "All IPs removed" - }; - err.push(ConfigError::new(msg, s, p, !added)); - } - } - - pub fn add_server( - &mut self, - err: &mut Vec<ConfigError>, - s: &config::Source, - p: &proto::Server, - ) { - if !valid_key(&p.peer.public_key) { - err.push(ConfigError::new("Invalid public key", s, &p.peer, true)); - return; - } - - if p.peer.public_key == self.public_key { - return; - } - - let pc = self.pc; - let ent = self.insert_with(err, s, &p.peer, |ent| { - ent.psk = s.psk.clone(); - ent.endpoint = Some(p.endpoint.clone()); - ent.keepalive = pc.fix_keepalive(p.keepalive); - }); - - Self::add_peer(err, ent, s, &p.peer) - } - - pub fn add_road_warrior( - &mut self, - err: &mut Vec<ConfigError>, - s: &config::Source, - p: &proto::RoadWarrior, - ) { - if !valid_key(&p.peer.public_key) { - err.push(ConfigError::new("Invalid public key", s, &p.peer, true)); - return; - } - - let ent = if p.base == self.public_key { - self.insert_with(err, s, &p.peer, |_| {}) - } else { - match self.peers.get_mut(&p.base) { - Some(ent) => ent, - None => { - err.push(ConfigError::new("Unknown base peer", s, &p.peer, true)); - return; - } - } - }; - Self::add_peer(err, ent, s, &p.peer) - } -} - pub struct Device { ifname: String, } impl Device { + #[inline] pub fn new(ifname: String) -> io::Result<Self> { Ok(Device { ifname }) } @@ -220,7 +30,7 @@ impl Device { }) } - pub fn get_public_key(&self) -> io::Result<String> { + pub fn get_public_key(&self) -> io::Result<model::Key> { let mut proc = Device::wg_command(); proc.stdin(Stdio::null()); proc.stdout(Stdio::piped()); @@ -237,17 +47,17 @@ impl Device { if out.ends_with(b"\n") { out.remove(out.len() - 1); } - String::from_utf8(out) + model::Key::from_bytes(&out) .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid public key")) } - pub fn apply_diff(&mut self, old: &Config, new: &Config) -> io::Result<()> { + pub fn apply_diff(&mut self, old: &model::Config, new: &model::Config) -> io::Result<()> { let mut proc = Device::wg_command(); proc.stdin(Stdio::piped()); proc.arg("set"); proc.arg(&self.ifname); - let mut psks = Vec::<&str>::new(); + let mut psks = String::new(); for (pubkey, conf) in new.peers.iter() { let old_endpoint; @@ -261,7 +71,7 @@ impl Device { } proc.arg("peer"); - proc.arg(pubkey); + proc.arg(format!("{}", pubkey)); if old_endpoint != conf.endpoint { if let Some(ref endpoint) = conf.endpoint { @@ -273,7 +83,10 @@ impl Device { if let Some(psk) = &conf.psk { proc.arg("preshared-key"); proc.arg("/dev/stdin"); - psks.push(psk); + { + use std::fmt::Write; + writeln!(&mut psks, "{}", psk).unwrap(); + } } let mut ips = String::new(); @@ -302,7 +115,7 @@ impl Device { continue; } proc.arg("peer"); - proc.arg(pubkey); + proc.arg(format!("{}", pubkey)); proc.arg("remove"); } @@ -310,9 +123,7 @@ impl Device { { use std::io::Write; let stdin = proc.stdin.as_mut().unwrap(); - for psk in psks { - writeln!(stdin, "{}", psk)?; - } + write!(stdin, "{}", psks)?; } let r = proc.wait()?; @@ -322,29 +133,3 @@ impl Device { Ok(()) } } - -fn valid_key(s: &str) -> bool { - let s = s.as_bytes(); - if s.len() != 44 { - return false; - } - if s[43] != b'=' { - return false; - } - for c in s[0..42].iter().cloned() { - if c >= b'0' && c <= b'9' { - continue; - } - if c >= b'A' && c <= b'Z' { - continue; - } - if c >= b'a' && c <= b'z' { - continue; - } - if c == b'+' || c <= b'/' { - continue; - } - return false; - } - b"048AEIMQUYcgkosw".contains(&s[42]) -} |