From b06338ec1d282a762440ad72c935717e404badca Mon Sep 17 00:00:00 2001 From: Hristo Venev Date: Tue, 2 Apr 2019 15:56:06 +0300 Subject: Reorg, sources have names. --- src/bin.rs | 59 ---------- src/config.rs | 19 ++-- src/ip.rs | 344 -------------------------------------------------------- src/main.rs | 293 +++++++++++++---------------------------------- src/manager.rs | 332 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/model.rs | 9 +- src/model/ip.rs | 344 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/proto.rs | 35 +++--- src/wg.rs | 8 +- 9 files changed, 796 insertions(+), 647 deletions(-) delete mode 100644 src/bin.rs delete mode 100644 src/ip.rs create mode 100644 src/manager.rs create mode 100644 src/model/ip.rs (limited to 'src') diff --git a/src/bin.rs b/src/bin.rs deleted file mode 100644 index bd1cf87..0000000 --- a/src/bin.rs +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2019 Hristo Venev -// -// See COPYING. - -#[inline] -pub fn i64_to_be(v: i64) -> [u8; 8] { - u64_to_be(v as u64) -} - -#[inline] -pub fn i64_from_be(v: [u8; 8]) -> i64 { - u64_from_be(v) as i64 -} - -#[inline] -pub fn u64_to_be(v: u64) -> [u8; 8] { - [ - (v >> 56) as u8, - (v >> 48) as u8, - (v >> 40) as u8, - (v >> 32) as u8, - (v >> 24) as u8, - (v >> 16) as u8, - (v >> 8) as u8, - v as u8, - ] -} - -#[inline] -pub fn u64_from_be(v: [u8; 8]) -> u64 { - (u64::from(v[0]) << 56) - | (u64::from(v[1]) << 48) - | (u64::from(v[2]) << 40) - | (u64::from(v[3]) << 32) - | (u64::from(v[4]) << 24) - | (u64::from(v[5]) << 16) - | (u64::from(v[6]) << 8) - | u64::from(v[7]) -} - -#[inline] -pub fn u32_to_be(v: u32) -> [u8; 4] { - [(v >> 24) as u8, (v >> 16) as u8, (v >> 8) as u8, v as u8] -} - -#[inline] -pub fn u32_from_be(v: [u8; 4]) -> u32 { - (u32::from(v[0]) << 24) | (u32::from(v[1]) << 16) | (u32::from(v[2]) << 8) | u32::from(v[3]) -} - -#[inline] -pub fn u16_to_be(v: u16) -> [u8; 2] { - [(v >> 8) as u8, v as u8] -} - -#[inline] -pub fn u16_from_be(v: [u8; 2]) -> u16 { - (u16::from(v[0]) << 8) | u16::from(v[1]) -} diff --git a/src/config.rs b/src/config.rs index e8dcb05..6eb3761 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,8 +2,9 @@ // // See COPYING. -use crate::ip::{Ipv4Set, Ipv6Set}; -use crate::model::Key; +use crate::model::{Ipv4Set, Ipv6Set, Key}; +use std::path::PathBuf; +use std::collections::HashMap; use serde_derive; #[serde(deny_unknown_fields)] @@ -13,6 +14,8 @@ pub struct Source { pub psk: Option, pub ipv4: Ipv4Set, pub ipv6: Ipv6Set, + #[serde(default)] + pub required: bool, } #[serde(deny_unknown_fields)] @@ -20,7 +23,7 @@ pub struct Source { pub struct PeerConfig { #[serde(default = "default_min_keepalive")] pub min_keepalive: u32, - #[serde(default = "default_max_keepalive")] + #[serde(default)] pub max_keepalive: u32, } @@ -35,6 +38,9 @@ pub struct UpdateConfig { #[serde(deny_unknown_fields)] #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, Debug)] pub struct Config { + pub cache_directory: Option, + pub runtime_directory: Option, + #[serde(flatten)] pub peer_config: PeerConfig, @@ -42,7 +48,7 @@ pub struct Config { pub update_config: UpdateConfig, #[serde(rename = "source")] - pub sources: Vec, + pub sources: HashMap, } impl PeerConfig { @@ -62,11 +68,6 @@ fn default_min_keepalive() -> u32 { 10 } -#[inline] -fn default_max_keepalive() -> u32 { - 0 -} - #[inline] fn default_refresh_sec() -> u32 { 1200 diff --git a/src/ip.rs b/src/ip.rs deleted file mode 100644 index d2ad17e..0000000 --- a/src/ip.rs +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright 2019 Hristo Venev -// -// See COPYING. - -use serde; -use std::iter::{FromIterator, IntoIterator}; -pub use std::net::{Ipv4Addr, Ipv6Addr}; -use std::str::FromStr; -use std::{error, fmt, iter}; - -#[derive(Debug)] -pub struct NetParseError; - -impl error::Error for NetParseError {} -impl fmt::Display for NetParseError { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Invalid address") - } -} - -macro_rules! per_proto { - ($nett:ident ($addrt:ident; $expecting:expr); $intt:ident($bytes:expr); $sett:ident) => { - #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] - pub struct $nett { - pub address: $addrt, - pub prefix_len: u8, - } - - impl $nett { - const BITS: u8 = $bytes * 8; - - pub fn contains(&self, other: &$nett) -> bool { - if self.prefix_len > other.prefix_len { - return false; - } - if self.prefix_len == other.prefix_len { - return self.address == other.address; - } - if self.prefix_len == 0 { - return true; - } - // self.prefix_len < other.prefix_len = BITS - let shift = Self::BITS - self.prefix_len; - let v1: $intt = self.address.into(); - let v2: $intt = other.address.into(); - v1 >> shift == v2 >> shift - } - } - - impl fmt::Display for $nett { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}/{}", self.address, self.prefix_len) - } - } - - impl FromStr for $nett { - type Err = NetParseError; - fn from_str(s: &str) -> Result<$nett, NetParseError> { - let (addr, pfx) = pfx_split(s)?; - let addr = $addrt::from_str(addr).map_err(|_| NetParseError)?; - - let r = $nett { - address: addr, - prefix_len: pfx, - }; - if !r.is_valid() { - return Err(NetParseError); - } - Ok(r) - } - } - - impl serde::Serialize for $nett { - fn serialize(&self, ser: S) -> Result { - if ser.is_human_readable() { - ser.collect_str(self) - } else { - let mut buf = [0u8; $bytes + 1]; - *array_mut_ref![&mut buf, 0, $bytes] = self.address.octets(); - buf[$bytes] = self.prefix_len; - ser.serialize_bytes(&buf) - } - } - } - - impl<'de> serde::Deserialize<'de> for $nett { - fn deserialize>(de: D) -> Result { - if de.is_human_readable() { - struct NetVisitor; - impl<'de> serde::de::Visitor<'de> for NetVisitor { - type Value = $nett; - - #[inline] - fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str($expecting) - } - - #[inline] - fn visit_str(self, s: &str) -> Result { - s.parse().map_err(E::custom) - } - } - de.deserialize_str(NetVisitor) - } else { - let buf = <[u8; $bytes + 1] as serde::Deserialize>::deserialize(de)?; - let r = $nett { - address: (*array_ref![&buf, 0, $bytes]).into(), - prefix_len: buf[$bytes], - }; - if r.is_valid() { - return Err(serde::de::Error::custom(NetParseError)); - } - Ok(r) - } - } - } - - #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] - pub struct $sett { - nets: Vec<$nett>, - } - - impl Default for $sett { - #[inline] - fn default() -> Self { - $sett::new() - } - } - - impl $sett { - #[inline] - pub fn new() -> Self { - $sett { nets: vec![] } - } - - #[inline] - fn siblings(a: &$nett, b: &$nett) -> bool { - let pfx = a.prefix_len; - if b.prefix_len != pfx || pfx == 0 { - return false; - } - let a: $intt = a.address.into(); - let b: $intt = b.address.into(); - a ^ b == 1 << ($nett::BITS - pfx) - } - - pub fn insert(&mut self, mut net: $nett) { - let mut i = match self.nets.binary_search(&net) { - Err(i) => i, - Ok(_) => { - return; - } - }; - let mut j = i; - if i != 0 && self.nets[i - 1].contains(&net) { - net = self.nets[i - 1]; - i -= 1; - } - while j < self.nets.len() && net.contains(&self.nets[j]) { - j += 1; - } - loop { - if j < self.nets.len() && Self::siblings(&net, &self.nets[j]) { - j += 1; - } else if i != 0 && Self::siblings(&self.nets[i - 1], &net) { - net = self.nets[i - 1]; - i -= 1; - } else { - break; - } - net.prefix_len -= 1; - } - self.nets.splice(i..j, iter::once(net)); - } - - pub fn contains(&self, net: &$nett) -> bool { - match self.nets.binary_search(&net) { - Err(i) => { - if i == 0 { - return false; - } - self.nets[i - 1].contains(&net) - } - Ok(_) => true, - } - } - - #[inline] - pub fn iter(&self) -> std::slice::Iter<$nett> { - self.nets.iter() - } - } - - impl IntoIterator for $sett { - type Item = $nett; - type IntoIter = std::vec::IntoIter<$nett>; - - #[inline] - fn into_iter(self) -> Self::IntoIter { - self.nets.into_iter() - } - } - - impl FromIterator<$nett> for $sett { - #[inline] - fn from_iter>(it: I) -> $sett { - let mut r = $sett::new(); - for net in it { - r.insert(net); - } - r - } - } - - impl<'a> From<$nett> for $sett { - #[inline] - fn from(v: $nett) -> $sett { - $sett { nets: vec![v] } - } - } - - impl<'a> From<[$nett; 1]> for $sett { - #[inline] - fn from(v: [$nett; 1]) -> $sett { - $sett { nets: vec![v[0]] } - } - } - - impl From<$sett> for Vec<$nett> { - fn from(v: $sett) -> Vec<$nett> { - v.nets - } - } - - impl From> for $sett { - fn from(nets: Vec<$nett>) -> $sett { - let mut s = $sett { nets }; - let len = s.nets.len(); - if len == 0 { - return s; - } - s.nets.sort(); - let mut i = 1; - for j in 1..len { - let mut net = s.nets[j]; - if s.nets[i - 1].contains(&net) { - net = s.nets[i - 1]; - i -= 1; - } - while i != 0 && Self::siblings(&s.nets[i - 1], &net) { - net = s.nets[i - 1]; - net.prefix_len -= 1; - i -= 1; - } - s.nets[i] = net; - i += 1; - } - s.nets.splice(i.., iter::empty()); - s - } - } - - impl<'a> From<&'a [$nett]> for $sett { - #[inline] - fn from(nets: &'a [$nett]) -> $sett { - Vec::from(nets).into() - } - } - - impl<'a> From<&'a mut [$nett]> for $sett { - #[inline] - fn from(nets: &'a mut [$nett]) -> $sett { - Vec::from(nets).into() - } - } - - impl serde::Serialize for $sett { - #[inline] - fn serialize(&self, ser: S) -> Result { - as serde::Serialize>::serialize(&self.nets, ser) - } - } - - impl<'de> serde::Deserialize<'de> for $sett { - #[inline] - fn deserialize>(de: D) -> Result { - as serde::Deserialize>::deserialize(de).map($sett::from) - } - } - }; -} - -per_proto!(Ipv4Net(Ipv4Addr; "IPv4 network"); u32(4); Ipv4Set); -per_proto!(Ipv6Net(Ipv6Addr; "IPv6 network"); u128(16); Ipv6Set); - -impl Ipv4Net { - pub fn is_valid(&self) -> bool { - let pfx = self.prefix_len; - if pfx > 32 { - return false; - } - if pfx == 32 { - return true; - } - let val: u32 = self.address.into(); - val & (u32::max_value() >> pfx) == 0 - } -} - -impl Ipv6Net { - pub fn is_valid(&self) -> bool { - let pfx = self.prefix_len; - if pfx > 128 { - return false; - } - if pfx == 128 { - return true; - } - - let val: u128 = self.address.into(); - let val: [u64; 2] = [(val >> 64) as u64, val as u64]; - if pfx >= 64 { - return val[1] & (u64::max_value() >> (pfx - 64)) == 0; - } - if val[1] != 0 { - return false; - } - val[0] & (u64::max_value() >> pfx) == 0 - } -} - -fn pfx_split(s: &str) -> Result<(&str, u8), NetParseError> { - let i = match s.find('/') { - Some(i) => i, - None => { - return Err(NetParseError); - } - }; - let (addr, pfx) = s.split_at(i); - let pfx = u8::from_str(&pfx[1..]).map_err(|_| NetParseError)?; - Ok((addr, pfx)) -} diff --git a/src/main.rs b/src/main.rs index 9ad0f9b..dfadaf4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,212 +5,19 @@ #[macro_use] extern crate arrayref; -use std::io; -use std::time::{Duration, Instant, SystemTime}; +use std::{env, fs, io, process, thread}; +use std::time::Instant; +use std::ffi::{OsStr, OsString}; +use toml; -mod bin; mod builder; -mod config; -mod ip; mod model; +mod config; mod proto; mod wg; +mod manager; -struct Source { - config: config::Source, - data: Option, - next_update: Instant, - backoff: Option, -} - -impl Source { - #[inline] - fn new(config: config::Source) -> Source { - Source { - config, - data: None, - next_update: Instant::now(), - backoff: None, - } - } -} - -pub struct Device { - dev: wg::Device, - peer_config: config::PeerConfig, - update_config: config::UpdateConfig, - sources: Vec, - current: model::Config, -} - -impl Device { - pub fn new(ifname: String, c: config::Config) -> io::Result { - let dev = wg::Device::new(ifname)?; - let _ = dev.get_public_key()?; - - Ok(Device { - dev, - peer_config: c.peer_config, - update_config: c.update_config, - sources: c.sources.into_iter().map(Source::new).collect(), - current: model::Config::default(), - }) - } - - fn make_config( - &self, - public_key: model::Key, - ts: SystemTime, - ) -> (model::Config, Vec, SystemTime) { - let mut t_cfg = ts + Duration::from_secs(1 << 30); - let mut sources: Vec<(&Source, &proto::SourceConfig)> = vec![]; - for src in self.sources.iter() { - if let Some(ref data) = src.data { - let sc = data - .next - .as_ref() - .and_then(|next| { - if ts >= next.update_at { - Some(&next.config) - } else { - t_cfg = t_cfg.min(next.update_at); - None - } - }) - .unwrap_or(&data.config); - sources.push((src, sc)); - } - } - - let mut cfg = builder::ConfigBuilder::new(public_key, &self.peer_config); - - for (src, sc) in sources.iter() { - for peer in sc.servers.iter() { - cfg.add_server(&src.config, peer); - } - } - - for (src, sc) in sources.iter() { - for peer in sc.road_warriors.iter() { - cfg.add_road_warrior(&src.config, peer); - } - } - - let (cfg, errs) = cfg.build(); - (cfg, errs, t_cfg) - } - - pub fn update(&mut self) -> io::Result { - let refresh = Duration::from_secs(u64::from(self.update_config.refresh_sec)); - let mut now = Instant::now(); - let mut t_refresh = now + refresh; - - for src in self.sources.iter_mut() { - if now < src.next_update { - t_refresh = t_refresh.min(src.next_update); - continue; - } - - let r = fetch_source(&src.config.url); - now = Instant::now(); - let r = match r { - Ok(r) => { - eprintln!("<6>Updated [{}]", &src.config.url); - src.data = Some(r); - src.backoff = None; - src.next_update = now + refresh; - continue; - } - Err(r) => r, - }; - - let b = src.backoff.unwrap_or(if src.data.is_some() { - refresh / 3 - } else { - Duration::from_secs(10).min(refresh / 10) - }); - src.next_update = now + b; - t_refresh = t_refresh.min(src.next_update); - - eprintln!("<3>Failed to update [{}], retrying after {:.1?}: {}", &src.config.url, b, &r); - - src.backoff = Some((b + b / 3).min(refresh)); - } - - let now = Instant::now(); - let sysnow = SystemTime::now(); - let public_key = self.dev.get_public_key()?; - let (config, errors, t_cfg) = self.make_config(public_key, sysnow); - let time_to_cfg = t_cfg - .duration_since(sysnow) - .unwrap_or(Duration::from_secs(0)); - let t_cfg = now + time_to_cfg; - - if config != self.current { - eprintln!("<5>Applying configuration update"); - for err in errors.iter() { - eprintln!("<{}>{}", if err.important { '4' } else { '5' }, err); - } - self.dev.apply_diff(&self.current, &config)?; - self.current = config; - } - - Ok(if t_cfg < t_refresh { - eprintln!("<6>Next configuration update after {:.1?}", time_to_cfg); - t_cfg - } else if t_refresh > now { - t_refresh - } else { - eprintln!("<4>Next refresh immediately?"); - now - }) - } -} - -fn fetch_source(url: &str) -> io::Result { - use std::env; - use std::ffi::{OsStr, OsString}; - use std::process::{Command, Stdio}; - - let curl = match env::var_os("CURL") { - None => OsString::new(), - Some(v) => v, - }; - let mut proc = Command::new(if curl.is_empty() { - OsStr::new("curl") - } else { - curl.as_os_str() - }); - - proc.stdin(Stdio::null()); - proc.stdout(Stdio::piped()); - proc.stderr(Stdio::piped()); - proc.arg("-gsSfL"); - proc.arg("--fail-early"); - proc.arg("--max-time"); - proc.arg("10"); - proc.arg("--max-filesize"); - proc.arg("1M"); - proc.arg("--"); - proc.arg(url); - - let out = proc.output()?; - - if !out.status.success() { - let msg = String::from_utf8_lossy(&out.stderr); - let msg = msg.replace('\n', "; "); - return Err(io::Error::new(io::ErrorKind::Other, msg)); - } - - let mut de = serde_json::Deserializer::from_slice(&out.stdout); - let r = serde::Deserialize::deserialize(&mut de)?; - Ok(r) -} - -fn load_config(path: &str) -> io::Result { - use std::fs; - use toml; - +fn load_config(path: &OsStr) -> io::Result { let mut data = String::new(); { use io::Read; @@ -222,24 +29,36 @@ fn load_config(path: &str) -> io::Result { .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) } -fn main() { - use std::{env, process, thread}; +fn usage(argv0: &str) -> i32 { + eprintln!("<1>Invalid arguments. See `{} --help` for more information", argv0); + 1 +} + +fn help(argv0: &str) -> i32 { + println!("Usage:"); + println!(" {} IFNAME CONFIG - run daemon on iterface", argv0); + println!(" {} --check-source PATH - validate source JSON", argv0); + 1 +} - let args: Vec = env::args().collect(); - if args.len() != 3 { - let arg0 = if !args.is_empty() { &args[0] } else { "wgconf" }; - eprintln!("<1>Usage:"); - eprintln!("<1> {} IFNAME CONFIG", arg0); - process::exit(1); +fn maybe_get_var(out: &mut Option>, var: impl AsRef) { + let var = var.as_ref(); + if let Some(s) = env::var_os(var) { + env::remove_var(var); + *out = Some(s.into()); } +} +fn run_daemon(argv0: String, args: Vec) -> i32 { + if args.len() != 2 { + return usage(&argv0); + } let mut args = args.into_iter(); - let _ = args.next().unwrap(); let ifname = args.next().unwrap(); let config_path = args.next().unwrap(); assert!(args.next().is_none()); - let config = match load_config(&config_path) { + let mut config = match load_config(&config_path) { Ok(c) => c, Err(e) => { eprintln!("<1>Failed to load config: {}", e); @@ -247,8 +66,11 @@ fn main() { } }; - let mut dev = match Device::new(ifname, config) { - Ok(dev) => dev, + maybe_get_var(&mut config.cache_directory, "CACHE_DIRECTORY"); + maybe_get_var(&mut config.runtime_directory, "RUNTIME_DIRECTORY"); + + let mut m = match manager::Manager::new(ifname, config) { + Ok(m) => m, Err(e) => { eprintln!("<1>Failed to open device: {}", e); process::exit(1); @@ -256,7 +78,7 @@ fn main() { }; loop { - let tm = match dev.update() { + let tm = match m.update() { Ok(t) => t, Err(e) => { eprintln!("<1>{}", e); @@ -270,3 +92,48 @@ fn main() { } } } + +fn run_check_source(argv0: String, args: Vec) -> i32 { + if args.len() != 1 { + usage(&argv0); + } + let mut args = args.into_iter(); + let path = args.next().unwrap(); + assert!(args.next().is_none()); + + match manager::load_source(&path) { + Ok(_) => { + println!("OK"); + 0 + } + Err(e) => { + println!("{}", e); + 1 + } + } +} + +fn main() -> () { + let mut iter_args = env::args_os(); + let argv0 = iter_args.next().unwrap().to_string_lossy().into_owned(); + + let mut args = Vec::new(); + let mut run: for<'a> fn(String, Vec) -> i32 = run_daemon; + let mut parse_args = true; + for arg in iter_args { + if !parse_args || !arg.to_string_lossy().starts_with('-') { + args.push(arg); + } else if arg == "--" { + parse_args = false; + } else if arg == "-h" || arg == "--help" { + process::exit(help(&argv0)); + } else if arg == "--check-source" { + run = run_check_source; + parse_args = false; + } else { + usage(&argv0); + } + } + + process::exit(run(argv0, args)); +} diff --git a/src/manager.rs b/src/manager.rs new file mode 100644 index 0000000..20d7f50 --- /dev/null +++ b/src/manager.rs @@ -0,0 +1,332 @@ +// Copyright 2019 Hristo Venev +// +// See COPYING. + +use std::{io, fs}; +use crate::{builder, config, model, proto, wg}; +use std::time::{Duration, Instant, SystemTime}; +use std::ffi::{OsStr, OsString}; +use std::path::PathBuf; + +struct Source { + name: String, + config: config::Source, + data: proto::Source, + next_update: Instant, + backoff: Option, +} + +struct Updater { + config: config::UpdateConfig, + cache_directory: Option, +} + +impl Updater { + fn cache_path(&self, s: &Source) -> Option { + if let Some(ref dir) = self.cache_directory { + let mut p = dir.clone(); + p.push(&s.name); + Some(p) + } else { + None + } + } + + fn cache_update(&self, src: &Source) -> io::Result { + let path = match self.cache_path(src) { + Some(path) => path, + None => { + return Ok(false); + } + }; + + let mut tmp_path = OsString::from(path.clone()); + tmp_path.push(".tmp"); + let tmp_path = PathBuf::from(tmp_path); + + let data = serde_json::to_vec(&src.data).unwrap(); + + let mut file = fs::File::create(&tmp_path)?; + match io::Write::write_all(&mut file, &data) + .and_then(|_| file.sync_data()) + .and_then(|_| fs::rename(&tmp_path, &path)) + { + Ok(()) => {} + Err(e) => { + fs::remove_file(&tmp_path).unwrap_or_else(|e2| { + eprintln!("<3>Failed to clean up [{}]: {}", tmp_path.display(), e2); + }); + return Err(e); + } + } + + Ok(true) + } + + fn cache_load(&self, src: &mut Source) -> bool { + let path = match self.cache_path(src) { + Some(path) => path, + None => { + return false; + } + }; + + let mut file = match fs::File::open(&path) { + Ok(file) => file, + Err(_) => { + return false; + } + }; + + let mut data = Vec::new(); + match io::Read::read_to_end(&mut file, &mut data) { + Ok(_) => {} + Err(e) => { + eprintln!("<3>Failed to read [{}] from cache: {}", src.config.url, e); + return false; + } + }; + + let mut de = serde_json::Deserializer::from_slice(&data); + src.data = match serde::Deserialize::deserialize(&mut de) { + Ok(r) => r, + Err(e) => { + eprintln!("<3>Failed to load [{}] from cache: {}", src.config.url, e); + return false; + } + }; + + true + } + + fn update(&self, src: &mut Source) -> (bool, Instant) { + let refresh = Duration::from_secs(u64::from(self.config.refresh_sec)); + + let r = fetch_source(&src.config.url); + let now = Instant::now(); + let r = match r { + Ok(r) => { + eprintln!("<6>Updated [{}]", &src.config.url); + src.data = r; + src.backoff = None; + src.next_update = now + refresh; + match self.cache_update(src) { + Ok(_) => {} + Err(e) => { + eprintln!("<4>Failed to cache [{}]: {}", &src.config.url, e); + } + } + return (true, now); + } + Err(r) => r, + }; + + let b = src.backoff.unwrap_or_else(|| Duration::from_secs(10).min(refresh / 10)); + src.next_update = now + b; + src.backoff = Some((b + b / 3).min(refresh / 3)); + eprintln!("<3>Failed to update [{}], retrying after {:.1?}: {}", &src.config.url, b, &r); + (false, now) + } + +} + +pub struct Manager { + dev: wg::Device, + peer_config: config::PeerConfig, + sources: Vec, + current: model::Config, + updater: Updater, +} + +impl Manager { + pub fn new(ifname: OsString, c: config::Config) -> io::Result { + let mut m = Manager { + dev: wg::Device::new(ifname)?, + peer_config: c.peer_config, + sources: vec![], + current: model::Config::default(), + updater: Updater { + config: c.update_config, + cache_directory: c.cache_directory, + }, + }; + + for (name, cfg) in c.sources.into_iter() { + m.add_source(name, cfg)?; + } + + Ok(m) + } + + fn add_source(&mut self, name: String, config: config::Source) -> io::Result<()> { + let mut s = Source { + name, + config, + data: proto::Source::empty(), + next_update: Instant::now(), + backoff: None, + }; + + self.init_source(&mut s)?; + self.sources.push(s); + Ok(()) + } + + fn init_source(&mut self, s: &mut Source) -> io::Result<()> { + if self.updater.update(s).0 { + return Ok(()); + } + if self.updater.cache_load(s) { + return Ok(()); + } + if !s.config.required { + return Ok(()); + } + if self.updater.update(s).0 { + return Ok(()); + } + if self.updater.update(s).0 { + return Ok(()); + } + Err(io::Error::new(io::ErrorKind::Other, format!("Failed to update required source [{}]", &s.config.url))) + } + + fn make_config( + &self, + public_key: model::Key, + ts: SystemTime, + ) -> (model::Config, Vec, SystemTime) { + let mut t_cfg = ts + Duration::from_secs(1 << 30); + let mut sources: Vec<(&Source, &proto::SourceConfig)> = vec![]; + for src in self.sources.iter() { + let sc = src.data.next + .as_ref() + .and_then(|next| { + if ts >= next.update_at { + Some(&next.config) + } else { + t_cfg = t_cfg.min(next.update_at); + None + } + }) + .unwrap_or(&src.data.config); + sources.push((src, sc)); + } + + let mut cfg = builder::ConfigBuilder::new(public_key, &self.peer_config); + + for (src, sc) in sources.iter() { + for peer in sc.servers.iter() { + cfg.add_server(&src.config, peer); + } + } + + for (src, sc) in sources.iter() { + for peer in sc.road_warriors.iter() { + cfg.add_road_warrior(&src.config, peer); + } + } + + let (cfg, errs) = cfg.build(); + (cfg, errs, t_cfg) + } + + fn refresh(&mut self) -> io::Result { + let refresh = Duration::from_secs(u64::from(self.updater.config.refresh_sec)); + let mut now = Instant::now(); + let mut t_refresh = now + refresh; + + for src in self.sources.iter_mut() { + if now >= src.next_update { + now = self.updater.update(src).1; + } + t_refresh = t_refresh.min(src.next_update); + } + + Ok(t_refresh) + } + + pub fn update(&mut self) -> io::Result { + let t_refresh = self.refresh()?; + + let public_key = self.dev.get_public_key()?; + let now = Instant::now(); + let sysnow = SystemTime::now(); + let (config, errors, t_cfg) = self.make_config(public_key, sysnow); + let time_to_cfg = t_cfg + .duration_since(sysnow) + .unwrap_or(Duration::from_secs(0)); + let t_cfg = now + time_to_cfg; + + if config != self.current { + eprintln!("<5>Applying configuration update"); + for err in errors.iter() { + eprintln!("<{}>{}", if err.important { '4' } else { '5' }, err); + } + self.dev.apply_diff(&self.current, &config)?; + self.current = config; + } + + Ok(if t_cfg < t_refresh { + eprintln!("<6>Next configuration update after {:.1?}", time_to_cfg); + t_cfg + } else if t_refresh > now { + t_refresh + } else { + eprintln!("<4>Next refresh immediately?"); + now + }) + } +} + +pub fn fetch_source(url: &str) -> io::Result { + use std::env; + use std::process::{Command, Stdio}; + + let curl = match env::var_os("CURL") { + None => OsString::new(), + Some(v) => v, + }; + let mut proc = Command::new(if curl.is_empty() { + OsStr::new("curl") + } else { + curl.as_os_str() + }); + + proc.stdin(Stdio::null()); + proc.stdout(Stdio::piped()); + proc.stderr(Stdio::piped()); + proc.arg("-gsSfL"); + proc.arg("--fail-early"); + proc.arg("--max-time"); + proc.arg("10"); + proc.arg("--max-filesize"); + proc.arg("1M"); + proc.arg("--"); + proc.arg(url); + + let out = proc.output()?; + + if !out.status.success() { + let msg = String::from_utf8_lossy(&out.stderr); + let msg = msg.replace('\n', "; "); + return Err(io::Error::new(io::ErrorKind::Other, msg)); + } + + let mut de = serde_json::Deserializer::from_slice(&out.stdout); + let r = serde::Deserialize::deserialize(&mut de)?; + Ok(r) +} + +pub fn load_source(path: &OsStr) -> io::Result { + let mut data = Vec::new(); + { + use std::io::Read; + let mut f = fs::File::open(&path)?; + f.read_to_end(&mut data)?; + } + + let mut de = serde_json::Deserializer::from_slice(&data); + let r = serde::Deserialize::deserialize(&mut de)?; + Ok(r) +} diff --git a/src/model.rs b/src/model.rs index f8b76b5..bc800b2 100644 --- a/src/model.rs +++ b/src/model.rs @@ -2,13 +2,14 @@ // // See COPYING. -use crate::bin; -use crate::ip::{Ipv4Addr, Ipv4Net, Ipv6Addr, Ipv6Net, NetParseError}; use base64; use std::collections::HashMap; use std::fmt; use std::str::FromStr; +mod ip; +pub use ip::*; + pub type KeyParseError = base64::DecodeError; #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] @@ -142,7 +143,7 @@ impl serde::Serialize for Endpoint { let mut buf = [0u8; 16 + 2]; let (buf_addr, buf_port) = mut_array_refs![&mut buf, 16, 2]; *buf_addr = self.address.octets(); - *buf_port = crate::bin::u16_to_be(self.port); + *buf_port = self.port.to_be_bytes(); ser.serialize_bytes(&buf) } } @@ -171,7 +172,7 @@ impl<'de> serde::Deserialize<'de> for Endpoint { let (buf_addr, buf_port) = array_refs![&buf, 16, 2]; Ok(Endpoint { address: (*buf_addr).into(), - port: bin::u16_from_be(*buf_port), + port: u16::from_be_bytes(*buf_port), }) } } diff --git a/src/model/ip.rs b/src/model/ip.rs new file mode 100644 index 0000000..0ada314 --- /dev/null +++ b/src/model/ip.rs @@ -0,0 +1,344 @@ +// Copyright 2019 Hristo Venev +// +// See COPYING. + +use serde; +use std::iter::{FromIterator, IntoIterator}; +pub use std::net::{Ipv4Addr, Ipv6Addr}; +use std::str::FromStr; +use std::{error, fmt, iter}; + +#[derive(Debug)] +pub struct NetParseError; + +impl error::Error for NetParseError {} +impl fmt::Display for NetParseError { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Invalid address") + } +} + +macro_rules! per_proto { + ($nett:ident ($addrt:ident; $expecting:expr); $intt:ident($bytes:expr); $sett:ident) => { + #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] + pub struct $nett { + pub address: $addrt, + pub prefix_len: u8, + } + + impl $nett { + const BITS: u8 = $bytes * 8; + + pub fn contains(&self, other: &$nett) -> bool { + if self.prefix_len > other.prefix_len { + return false; + } + if self.prefix_len == other.prefix_len { + return self.address == other.address; + } + if self.prefix_len == 0 { + return true; + } + // self.prefix_len < other.prefix_len = BITS + let shift = Self::BITS - self.prefix_len; + let v1: $intt = self.address.into(); + let v2: $intt = other.address.into(); + v1 >> shift == v2 >> shift + } + } + + impl fmt::Display for $nett { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}/{}", self.address, self.prefix_len) + } + } + + impl FromStr for $nett { + type Err = NetParseError; + fn from_str(s: &str) -> Result<$nett, NetParseError> { + let (addr, pfx) = pfx_split(s)?; + let addr = $addrt::from_str(addr).map_err(|_| NetParseError)?; + + let r = $nett { + address: addr, + prefix_len: pfx, + }; + if !r.is_valid() { + return Err(NetParseError); + } + Ok(r) + } + } + + impl serde::Serialize for $nett { + fn serialize(&self, ser: S) -> Result { + if ser.is_human_readable() { + ser.collect_str(self) + } else { + let mut buf = [0u8; $bytes + 1]; + *array_mut_ref![&mut buf, 0, $bytes] = self.address.octets(); + buf[$bytes] = self.prefix_len; + ser.serialize_bytes(&buf) + } + } + } + + impl<'de> serde::Deserialize<'de> for $nett { + fn deserialize>(de: D) -> Result { + if de.is_human_readable() { + struct NetVisitor; + impl<'de> serde::de::Visitor<'de> for NetVisitor { + type Value = $nett; + + #[inline] + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str($expecting) + } + + #[inline] + fn visit_str(self, s: &str) -> Result { + s.parse().map_err(E::custom) + } + } + de.deserialize_str(NetVisitor) + } else { + let buf = <[u8; $bytes + 1] as serde::Deserialize>::deserialize(de)?; + let r = $nett { + address: (*array_ref![&buf, 0, $bytes]).into(), + prefix_len: buf[$bytes], + }; + if r.is_valid() { + return Err(serde::de::Error::custom(NetParseError)); + } + Ok(r) + } + } + } + + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] + pub struct $sett { + nets: Vec<$nett>, + } + + impl Default for $sett { + #[inline] + fn default() -> Self { + $sett::new() + } + } + + impl $sett { + #[inline] + pub fn new() -> Self { + $sett { nets: vec![] } + } + + #[inline] + fn siblings(a: $nett, b: $nett) -> bool { + let pfx = a.prefix_len; + if b.prefix_len != pfx || pfx == 0 { + return false; + } + let a: $intt = a.address.into(); + let b: $intt = b.address.into(); + a ^ b == 1 << ($nett::BITS - pfx) + } + + pub fn insert(&mut self, mut net: $nett) { + let mut i = match self.nets.binary_search(&net) { + Err(i) => i, + Ok(_) => { + return; + } + }; + let mut j = i; + if i != 0 && self.nets[i - 1].contains(&net) { + net = self.nets[i - 1]; + i -= 1; + } + while j < self.nets.len() && net.contains(&self.nets[j]) { + j += 1; + } + loop { + if j < self.nets.len() && Self::siblings(net, self.nets[j]) { + j += 1; + } else if i != 0 && Self::siblings(self.nets[i - 1], net) { + net = self.nets[i - 1]; + i -= 1; + } else { + break; + } + net.prefix_len -= 1; + } + self.nets.splice(i..j, iter::once(net)); + } + + pub fn contains(&self, net: &$nett) -> bool { + match self.nets.binary_search(&net) { + Err(i) => { + if i == 0 { + return false; + } + self.nets[i - 1].contains(&net) + } + Ok(_) => true, + } + } + + #[inline] + pub fn iter(&self) -> std::slice::Iter<$nett> { + self.nets.iter() + } + } + + impl IntoIterator for $sett { + type Item = $nett; + type IntoIter = std::vec::IntoIter<$nett>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.nets.into_iter() + } + } + + impl FromIterator<$nett> for $sett { + #[inline] + fn from_iter>(it: I) -> $sett { + let mut r = $sett::new(); + for net in it { + r.insert(net); + } + r + } + } + + impl<'a> From<$nett> for $sett { + #[inline] + fn from(v: $nett) -> $sett { + $sett { nets: vec![v] } + } + } + + impl<'a> From<[$nett; 1]> for $sett { + #[inline] + fn from(v: [$nett; 1]) -> $sett { + $sett { nets: vec![v[0]] } + } + } + + impl From<$sett> for Vec<$nett> { + fn from(v: $sett) -> Vec<$nett> { + v.nets + } + } + + impl From> for $sett { + fn from(nets: Vec<$nett>) -> $sett { + let mut s = $sett { nets }; + let len = s.nets.len(); + if len == 0 { + return s; + } + s.nets.sort(); + let mut i = 1; + for j in 1..len { + let mut net = s.nets[j]; + if s.nets[i - 1].contains(&net) { + net = s.nets[i - 1]; + i -= 1; + } + while i != 0 && Self::siblings(s.nets[i - 1], net) { + net = s.nets[i - 1]; + net.prefix_len -= 1; + i -= 1; + } + s.nets[i] = net; + i += 1; + } + s.nets.splice(i.., iter::empty()); + s + } + } + + impl<'a> From<&'a [$nett]> for $sett { + #[inline] + fn from(nets: &'a [$nett]) -> $sett { + Vec::from(nets).into() + } + } + + impl<'a> From<&'a mut [$nett]> for $sett { + #[inline] + fn from(nets: &'a mut [$nett]) -> $sett { + Vec::from(nets).into() + } + } + + impl serde::Serialize for $sett { + #[inline] + fn serialize(&self, ser: S) -> Result { + as serde::Serialize>::serialize(&self.nets, ser) + } + } + + impl<'de> serde::Deserialize<'de> for $sett { + #[inline] + fn deserialize>(de: D) -> Result { + as serde::Deserialize>::deserialize(de).map($sett::from) + } + } + }; +} + +per_proto!(Ipv4Net(Ipv4Addr; "IPv4 network"); u32(4); Ipv4Set); +per_proto!(Ipv6Net(Ipv6Addr; "IPv6 network"); u128(16); Ipv6Set); + +impl Ipv4Net { + pub fn is_valid(self) -> bool { + let pfx = self.prefix_len; + if pfx > 32 { + return false; + } + if pfx == 32 { + return true; + } + let val: u32 = self.address.into(); + val & (u32::max_value() >> pfx) == 0 + } +} + +impl Ipv6Net { + pub fn is_valid(self) -> bool { + let pfx = self.prefix_len; + if pfx > 128 { + return false; + } + if pfx == 128 { + return true; + } + + let val: u128 = self.address.into(); + let val: [u64; 2] = [(val >> 64) as u64, val as u64]; + if pfx >= 64 { + return val[1] & (u64::max_value() >> (pfx - 64)) == 0; + } + if val[1] != 0 { + return false; + } + val[0] & (u64::max_value() >> pfx) == 0 + } +} + +fn pfx_split(s: &str) -> Result<(&str, u8), NetParseError> { + let i = match s.find('/') { + Some(i) => i, + None => { + return Err(NetParseError); + } + }; + let (addr, pfx) = s.split_at(i); + let pfx = u8::from_str(&pfx[1..]).map_err(|_| NetParseError)?; + Ok((addr, pfx)) +} diff --git a/src/proto.rs b/src/proto.rs index 86c7eee..9c98fff 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -2,8 +2,7 @@ // // See COPYING. -use crate::ip::{Ipv4Net, Ipv6Net}; -use crate::model::{Endpoint, Key}; +use crate::model::{Endpoint, Ipv4Net, Ipv6Net, Key}; use serde_derive; use std::time::SystemTime; @@ -11,9 +10,9 @@ use std::time::SystemTime; #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct Peer { pub public_key: Key, - #[serde(default = "Vec::new")] + #[serde(default)] pub ipv4: Vec, - #[serde(default = "Vec::new")] + #[serde(default)] pub ipv6: Vec, } @@ -23,7 +22,7 @@ pub struct Server { #[serde(flatten)] pub peer: Peer, pub endpoint: Endpoint, - #[serde(default = "default_peer_keepalive")] + #[serde(default)] pub keepalive: u32, } @@ -37,9 +36,9 @@ pub struct RoadWarrior { #[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct SourceConfig { - #[serde(default = "Vec::new")] + #[serde(default)] pub servers: Vec, - #[serde(default = "Vec::new")] + #[serde(default)] pub road_warriors: Vec, } @@ -58,13 +57,19 @@ pub struct Source { pub next: Option, } -#[inline] -fn default_peer_keepalive() -> u32 { - 0 +impl Source { + pub fn empty() -> Source { + Source { + config: SourceConfig { + servers: vec![], + road_warriors: vec![], + }, + next: None, + } + } } mod serde_utc { - use crate::bin; use chrono::{DateTime, SecondsFormat, TimeZone, Utc}; use serde::*; use std::fmt; @@ -77,8 +82,8 @@ mod serde_utc { } else { let mut buf = [0u8; 12]; let (buf_secs, buf_nanos) = mut_array_refs![&mut buf, 8, 4]; - *buf_secs = bin::i64_to_be(t.timestamp()); - *buf_nanos = bin::u32_to_be(t.timestamp_subsec_nanos()); + *buf_secs = t.timestamp().to_be_bytes(); + *buf_nanos = t.timestamp_subsec_nanos().to_be_bytes(); ser.serialize_bytes(&buf) } } @@ -103,8 +108,8 @@ mod serde_utc { } else { let mut buf = <[u8; 12]>::deserialize(de)?; let (buf_secs, buf_nanos) = array_refs![&mut buf, 8, 4]; - let secs = bin::i64_from_be(*buf_secs); - let nanos = bin::u32_from_be(*buf_nanos); + let secs = i64::from_be_bytes(*buf_secs); + let nanos = u32::from_be_bytes(*buf_nanos); Ok(Utc.timestamp(secs, nanos).into()) } } diff --git a/src/wg.rs b/src/wg.rs index b3ec451..745c7b4 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -8,13 +8,15 @@ use std::process::{Command, Stdio}; use std::{env, io}; pub struct Device { - ifname: String, + ifname: OsString, } impl Device { #[inline] - pub fn new(ifname: String) -> io::Result { - Ok(Device { ifname }) + pub fn new(ifname: OsString) -> io::Result { + let dev = Device { ifname }; + let _ = dev.get_public_key()?; + Ok(dev) } pub fn wg_command() -> Command { -- cgit