From 87aa7ff470e93e412502cd5341f110fa74de03a1 Mon Sep 17 00:00:00 2001 From: Hristo Venev Date: Mon, 18 Mar 2019 13:42:37 +0200 Subject: Implement road warrior support. --- src/config.rs | 16 ++++- src/main.rs | 30 ++++++--- src/proto.rs | 26 ++++++- src/wg.rs | 213 ++++++++++++++++++++++++++++++++++------------------------ 4 files changed, 182 insertions(+), 103 deletions(-) diff --git a/src/config.rs b/src/config.rs index 833b546..9124902 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,5 @@ use crate::ip::{Ipv4Set, Ipv6Set}; use serde_derive; -use std::collections::HashSet; #[serde(deny_unknown_fields)] #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] @@ -14,12 +13,11 @@ pub struct Source { #[serde(deny_unknown_fields)] #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct PeerConfig { + pub own_public_key: String, #[serde(default = "default_min_keepalive")] pub min_keepalive: u32, #[serde(default = "default_max_keepalive")] pub max_keepalive: u32, - - pub omit_peers: HashSet, } #[serde(deny_unknown_fields)] @@ -46,6 +44,18 @@ pub struct Config { pub sources: Vec, } +impl PeerConfig { + pub fn fix_keepalive(&self, mut k: u32) -> u32 { + if self.max_keepalive != 0 && (k == 0 || k > self.max_keepalive) { + k = self.max_keepalive; + } + if k != 0 && k < self.min_keepalive { + k = self.min_keepalive; + } + k + } +} + fn default_wg_command() -> String { "wg".to_owned() } diff --git a/src/main.rs b/src/main.rs index 4ad5f39..deea336 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,21 +38,22 @@ pub struct Device { impl Device { pub fn new(c: config::Config) -> Device { + let dev = wg::Device::new(c.ifname, c.wg_command); + let current = wg::ConfigBuilder::new(&c.peers).build(); Device { - dev: wg::Device::new(c.ifname, c.wg_command), + dev, peer_config: c.peers, update_config: c.update, sources: c.sources.into_iter().map(Source::new).collect(), - current: wg::Config::new(), + current, } } fn make_config(&self, ts: SystemTime) -> (wg::Config, Vec, SystemTime) { - let mut cfg = wg::Config::new(); let mut next_update = ts + Duration::from_secs(3600); - let mut errs = vec![]; + let mut sources: Vec<(&Source, &proto::SourceConfig)> = vec![]; for src in self.sources.iter() { - if let Some(data) = &src.data { + if let Some(ref data) = src.data { let sc = data .next .as_ref() @@ -65,11 +66,24 @@ impl Device { } }) .unwrap_or(&data.config); - for peer in sc.peers.iter() { - cfg.add_peer(&mut errs, &self.peer_config, &src.config, peer); - } + sources.push((src, sc)); } } + + let mut cfg = wg::ConfigBuilder::new(&self.peer_config); + let mut errs = vec![]; + for (src, sc) in sources.iter() { + for peer in sc.servers.iter() { + cfg.add_server(&mut errs, &src.config, peer); + } + } + for (src, sc) in sources.iter() { + for peer in sc.road_warriors.iter() { + cfg.add_road_warrior(&mut errs, &src.config, peer); + } + } + + let cfg = cfg.build(); (cfg, errs, next_update) } diff --git a/src/proto.rs b/src/proto.rs index a3ee6e6..609dbd9 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -7,11 +7,28 @@ use crate::ip::{Endpoint, Ipv4Net, Ipv6Net}; #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct Peer { pub public_key: String, + #[serde(default = "Vec::new")] + pub ipv4: Vec, + #[serde(default = "Vec::new")] + pub ipv6: Vec, +} + +#[serde(deny_unknown_fields)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] +pub struct Server { + #[serde(flatten)] + pub peer: Peer, pub endpoint: Endpoint, #[serde(default = "default_peer_keepalive")] pub keepalive: u32, - pub ipv4: Vec, - pub ipv6: Vec, +} + +#[serde(deny_unknown_fields)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] +pub struct RoadWarrior { + #[serde(flatten)] + pub peer: Peer, + pub base: String, } fn default_peer_keepalive() -> u32 { @@ -20,7 +37,10 @@ fn default_peer_keepalive() -> u32 { #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct SourceConfig { - pub peers: Vec, + #[serde(default = "Vec::new")] + pub servers: Vec, + #[serde(default = "Vec::new")] + pub road_warriors: Vec, } #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] diff --git a/src/wg.rs b/src/wg.rs index 910e5e6..650aef9 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -4,20 +4,6 @@ use hash_map::HashMap; use std::collections::hash_map; use std::{error, fmt, io}; -#[derive(Clone, PartialEq, Eq, Debug)] -struct Peer { - endpoint: Endpoint, - psk: Option, - keepalive: u32, - ipv4: Vec, - ipv6: Vec, -} - -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct Config { - peers: HashMap, -} - #[derive(Debug)] pub struct ConfigError { pub url: String, @@ -26,6 +12,17 @@ pub struct ConfigError { err: &'static str, } +impl ConfigError { + fn new(err: &'static str, s: &config::Source, p: &proto::Peer, important: bool) -> Self { + ConfigError { + url: s.url.clone(), + peer: p.public_key.clone(), + important, + err, + } + } +} + impl error::Error for ConfigError {} impl fmt::Display for ConfigError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -37,80 +34,71 @@ impl fmt::Display for ConfigError { } else { "Misconfigured peer" }, - self.peer, self.url, self.err + self.peer, + self.url, + self.err ) } } -impl Config { - pub fn new() -> Config { - Config { +#[derive(Clone, PartialEq, Eq, Debug)] +struct Peer { + endpoint: Option, + psk: Option, + keepalive: u32, + ipv4: Vec, + ipv6: Vec, +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct Config { + peers: HashMap, +} + +pub struct ConfigBuilder<'a> { + peers: HashMap, + pc: &'a config::PeerConfig, +} + +impl<'a> ConfigBuilder<'a> { + pub fn new(pc: &'a config::PeerConfig) -> Self { + ConfigBuilder { peers: HashMap::new(), + pc, } } - pub fn add_peer( - &mut self, - errors: &mut Vec, - c: &config::PeerConfig, + pub fn build(self) -> Config { + Config { peers: self.peers } + } + + fn insert_with<'b>( + &'b mut self, + err: &mut Vec, s: &config::Source, p: &proto::Peer, - ) { - if !valid_key(&p.public_key) { - errors.push(ConfigError { - url: s.url.clone(), - peer: p.public_key.clone(), - important: true, - err: "Invalid public key", - }); - return; - } - - if let Some(ref psk) = s.psk { - if !valid_key(psk) { - errors.push(ConfigError { - url: s.url.clone(), - peer: p.public_key.clone(), - important: true, - err: "Invalid preshared key", - }); - return; - } - } - - if c.omit_peers.contains(&p.public_key) { - return; - } - - let ent = match self.peers.entry(p.public_key.clone()) { + update: impl for<'c> FnOnce(&'c mut Peer) -> (), + ) -> &'b mut Peer { + match self.peers.entry(p.public_key.clone()) { hash_map::Entry::Occupied(ent) => { - errors.push(ConfigError { - url: s.url.clone(), - peer: p.public_key.clone(), - important: true, - err: "Duplicate public key", - }); + err.push(ConfigError::new("Duplicate public key", s, p, true)); ent.into_mut() } hash_map::Entry::Vacant(ent) => { - let mut keepalive = p.keepalive; - if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) { - keepalive = c.max_keepalive; - } - if keepalive != 0 && keepalive < c.min_keepalive { - keepalive = c.min_keepalive; - } - - ent.insert(Peer { - endpoint: p.endpoint.clone(), - psk: s.psk.clone(), - keepalive, + let ent = ent.insert(Peer { + endpoint: None, + psk: None, + keepalive: 0, ipv4: vec![], ipv6: vec![], - }) - }, - }; + }); + update(ent); + ent + } + } + } + fn add_peer(err: &mut Vec, ent: &mut Peer, s: &config::Source, p: &proto::Peer) { let mut added = false; let mut removed = false; @@ -132,24 +120,63 @@ impl Config { } if removed { - errors.push(ConfigError { - url: s.url.clone(), - peer: p.public_key.clone(), - important: !added, - err: if added { - "Some IPs removed" - } else { - "All IPs removed" - }, - }); + let msg = if added { + "Some IPs removed" + } else { + "All IPs removed" + }; + err.push(ConfigError::new(msg, s, p, !added)); } } -} -impl Default for Config { - #[inline] - fn default() -> Self { - Config::new() + pub fn add_server( + &mut self, + err: &mut Vec, + s: &config::Source, + p: &proto::Server, + ) { + if !valid_key(&p.peer.public_key) { + err.push(ConfigError::new("Invalid public key", s, &p.peer, true)); + return; + } + + if p.peer.public_key == self.pc.own_public_key { + return; + } + + let pc = self.pc; + let ent = self.insert_with(err, s, &p.peer, |ent| { + ent.psk = s.psk.clone(); + ent.endpoint = Some(p.endpoint.clone()); + ent.keepalive = pc.fix_keepalive(p.keepalive); + }); + + Self::add_peer(err, ent, s, &p.peer) + } + + pub fn add_road_warrior( + &mut self, + err: &mut Vec, + s: &config::Source, + p: &proto::RoadWarrior, + ) { + if !valid_key(&p.peer.public_key) { + err.push(ConfigError::new("Invalid public key", s, &p.peer, true)); + return; + } + + let ent = if p.base == self.pc.own_public_key { + self.insert_with(err, s, &p.peer, |_| {}) + } else { + match self.peers.get_mut(&p.base) { + Some(ent) => ent, + None => { + err.push(ConfigError::new("Unknown base peer", s, &p.peer, true)); + return; + } + } + }; + Self::add_peer(err, ent, s, &p.peer) } } @@ -174,17 +201,25 @@ impl Device { let mut psks = Vec::<&str>::new(); for (pubkey, conf) in new.peers.iter() { + let old_endpoint; if let Some(old_peer) = old.peers.get(pubkey) { if *old_peer == *conf { continue; } + old_endpoint = old_peer.endpoint.clone(); + } else { + old_endpoint = None; } + proc.arg("peer"); proc.arg(pubkey); - // TODO: maybe skip endpoint? - proc.arg("endpoint"); - proc.arg(format!("{}", conf.endpoint)); + if old_endpoint != conf.endpoint { + if let Some(ref endpoint) = conf.endpoint { + proc.arg("endpoint"); + proc.arg(format!("{}", endpoint)); + } + } if let Some(psk) = &conf.psk { proc.arg("preshared-key"); -- cgit