diff options
-rw-r--r-- | src/bin.rs | 28 | ||||
-rw-r--r-- | src/config.rs | 16 | ||||
-rw-r--r-- | src/ip.rs | 56 | ||||
-rw-r--r-- | src/main.rs | 49 | ||||
-rw-r--r-- | src/proto.rs | 24 | ||||
-rw-r--r-- | src/wg.rs | 135 |
6 files changed, 167 insertions, 141 deletions
@@ -21,30 +21,22 @@ pub fn u64_to_be(v: u64) -> [u8; 8] { } pub fn u64_from_be(v: [u8; 8]) -> u64 { - (u64::from(v[0]) << 56) | - (u64::from(v[1]) << 48) | - (u64::from(v[2]) << 40) | - (u64::from(v[3]) << 32) | - (u64::from(v[4]) << 24) | - (u64::from(v[5]) << 16) | - (u64::from(v[6]) << 8) | - u64::from(v[7]) + (u64::from(v[0]) << 56) + | (u64::from(v[1]) << 48) + | (u64::from(v[2]) << 40) + | (u64::from(v[3]) << 32) + | (u64::from(v[4]) << 24) + | (u64::from(v[5]) << 16) + | (u64::from(v[6]) << 8) + | u64::from(v[7]) } pub fn u32_to_be(v: u32) -> [u8; 4] { - [ - (v >> 24) as u8, - (v >> 16) as u8, - (v >> 8) as u8, - v as u8, - ] + [(v >> 24) as u8, (v >> 16) as u8, (v >> 8) as u8, v as u8] } pub fn u32_from_be(v: [u8; 4]) -> u32 { - (u32::from(v[0]) << 24) | - (u32::from(v[1]) << 16) | - (u32::from(v[2]) << 8) | - u32::from(v[3]) + (u32::from(v[0]) << 24) | (u32::from(v[1]) << 16) | (u32::from(v[2]) << 8) | u32::from(v[3]) } pub fn u16_to_be(v: u16) -> [u8; 2] { diff --git a/src/config.rs b/src/config.rs index 6411b3a..833b546 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,10 +1,9 @@ -use ::std::collections::HashSet; -use ::serde_derive; use crate::ip::{Ipv4Set, Ipv6Set}; +use serde_derive; +use std::collections::HashSet; #[serde(deny_unknown_fields)] -#[derive(serde_derive::Serialize, serde_derive::Deserialize)] -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct Source { pub url: String, pub psk: Option<String>, @@ -13,8 +12,7 @@ pub struct Source { } #[serde(deny_unknown_fields)] -#[derive(serde_derive::Serialize, serde_derive::Deserialize)] -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct PeerConfig { #[serde(default = "default_min_keepalive")] pub min_keepalive: u32, @@ -25,8 +23,7 @@ pub struct PeerConfig { } #[serde(deny_unknown_fields)] -#[derive(serde_derive::Serialize, serde_derive::Deserialize)] -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct UpdateConfig { // Number of seconds between regular updates. #[serde(default = "default_refresh")] @@ -34,8 +31,7 @@ pub struct UpdateConfig { } #[serde(deny_unknown_fields)] -#[derive(serde_derive::Serialize, serde_derive::Deserialize)] -#[derive(Clone, Debug)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, Debug)] pub struct Config { pub ifname: String, #[serde(default = "default_wg_command")] @@ -1,9 +1,9 @@ -use ::std::{error, fmt, iter, net}; -use ::std::net::{Ipv4Addr, Ipv6Addr}; -use ::std::iter::{IntoIterator, FromIterator}; -use ::std::str::{FromStr}; -use ::serde; use crate::bin; +use serde; +use std::iter::{FromIterator, IntoIterator}; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::str::FromStr; +use std::{error, fmt, iter, net}; #[derive(Debug)] pub struct NetParseError {} @@ -33,6 +33,9 @@ macro_rules! per_proto { if self.prefix_len == other.prefix_len { return self.address == other.address; } + if self.prefix_len == 0 { + return true; + } // self.prefix_len < other.prefix_len = BITS let shift = Self::BITS - self.prefix_len; let v1: $intt = self.address.into(); @@ -46,13 +49,13 @@ macro_rules! per_proto { write!(f, "{}/{}", self.address, self.prefix_len) } } - + impl FromStr for $nett { type Err = NetParseError; fn from_str(s: &str) -> Result<$nett, NetParseError> { let (addr, pfx) = pfx_split(s)?; let addr = $addrt::from_str(addr).map_err(|_| NetParseError {})?; - + let r = $nett { address: addr, prefix_len: pfx, @@ -94,7 +97,7 @@ macro_rules! per_proto { } de.deserialize_str(NetVisitor) } else { - let buf = <[u8; $bytes+1] as serde::Deserialize>::deserialize(de)?; + let buf = <[u8; $bytes + 1] as serde::Deserialize>::deserialize(de)?; let r = $nett { address: (*array_ref![&buf, 0, $bytes]).into(), prefix_len: buf[$bytes], @@ -144,8 +147,8 @@ macro_rules! per_proto { } }; let mut j = i; - if i != 0 && self.nets[i-1].contains(&net) { - net = self.nets[i-1]; + if i != 0 && self.nets[i - 1].contains(&net) { + net = self.nets[i - 1]; i -= 1; } while j < self.nets.len() && net.contains(&self.nets[j]) { @@ -154,8 +157,8 @@ macro_rules! per_proto { loop { if j < self.nets.len() && Self::siblings(&net, &self.nets[j]) { j += 1; - } else if i != 0 && Self::siblings(&self.nets[i-1], &net) { - net = self.nets[i-1]; + } else if i != 0 && Self::siblings(&self.nets[i - 1], &net) { + net = self.nets[i - 1]; i -= 1; } else { break; @@ -171,7 +174,7 @@ macro_rules! per_proto { if i == 0 { return false; } - self.nets[i-1].contains(&net) + self.nets[i - 1].contains(&net) } Ok(_) => true, } @@ -194,7 +197,7 @@ macro_rules! per_proto { } impl FromIterator<$nett> for $sett { - fn from_iter<I: IntoIterator<Item=$nett>>(it: I) -> $sett { + fn from_iter<I: IntoIterator<Item = $nett>>(it: I) -> $sett { let mut r = $sett::new(); for net in it { r.insert(net); @@ -234,12 +237,12 @@ macro_rules! per_proto { let mut i = 1; for j in 1..len { let mut net = s.nets[j]; - if i != 0 && s.nets[i-1].contains(&net) { - net = s.nets[i-1]; + if s.nets[i - 1].contains(&net) { + net = s.nets[i - 1]; i -= 1; } - while i != 0 && Self::siblings(&s.nets[i-1], &net) { - net = s.nets[i-1]; + while i != 0 && Self::siblings(&s.nets[i - 1], &net) { + net = s.nets[i - 1]; net.prefix_len -= 1; i -= 1; } @@ -319,7 +322,7 @@ impl Ipv6Net { } fn pfx_split(s: &str) -> Result<(&str, u8), NetParseError> { - let i = match s.find("/") { + let i = match s.find('/') { Some(i) => i, None => { return Err(NetParseError {}); @@ -350,14 +353,13 @@ impl fmt::Display for Endpoint { impl FromStr for Endpoint { type Err = net::AddrParseError; fn from_str(s: &str) -> Result<Endpoint, net::AddrParseError> { - net::SocketAddr::from_str(s) - .map(|v| Endpoint { - address: match v.ip() { - net::IpAddr::V4(a) => a.to_ipv6_mapped(), - net::IpAddr::V6(a) => a, - }, - port: v.port(), - }) + net::SocketAddr::from_str(s).map(|v| Endpoint { + address: match v.ip() { + net::IpAddr::V4(a) => a.to_ipv6_mapped(), + net::IpAddr::V6(a) => a, + }, + port: v.port(), + }) } } diff --git a/src/main.rs b/src/main.rs index ad57895..4ad5f39 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,13 @@ #[macro_use] extern crate arrayref; -use ::std::io; -use ::std::time::{Duration, SystemTime, Instant}; +use std::io; +use std::time::{Duration, Instant, SystemTime}; mod bin; +mod config; mod ip; mod proto; -mod config; mod wg; struct Source { @@ -53,14 +53,18 @@ impl Device { let mut errs = vec![]; for src in self.sources.iter() { if let Some(data) = &src.data { - let sc = data.next.as_ref().and_then(|next| { - if ts >= next.update_at { - Some(&next.config) - } else { - next_update = next_update.min(next.update_at); - None - } - }).unwrap_or(&data.config); + let sc = data + .next + .as_ref() + .and_then(|next| { + if ts >= next.update_at { + Some(&next.config) + } else { + next_update = next_update.min(next.update_at); + None + } + }) + .unwrap_or(&data.config); for peer in sc.peers.iter() { cfg.add_peer(&mut errs, &self.peer_config, &src.config, peer); } @@ -107,7 +111,9 @@ impl Device { let sysnow = SystemTime::now(); let (config, errors, upd_time) = self.make_config(sysnow); - let time_to_upd = upd_time.duration_since(sysnow).unwrap_or(Duration::from_secs(0)); + let time_to_upd = upd_time + .duration_since(sysnow) + .unwrap_or(Duration::from_secs(0)); next_update = next_update.min(now + time_to_upd); if config != self.current { @@ -125,7 +131,7 @@ impl Device { } fn fetch_source(url: &str) -> io::Result<proto::Source> { - use ::curl::easy::Easy; + use curl::easy::Easy; let mut res = Vec::<u8>::new(); @@ -144,7 +150,10 @@ fn fetch_source(url: &str) -> io::Result<proto::Source> { let code = req.response_code()?; if code != 0 && code != 200 { - return Err(io::Error::new(io::ErrorKind::Other, format!("HTTP error {}", code))); + return Err(io::Error::new( + io::ErrorKind::Other, + format!("HTTP error {}", code), + )); } } @@ -154,8 +163,8 @@ fn fetch_source(url: &str) -> io::Result<proto::Source> { } fn load_config(path: &str) -> io::Result<config::Config> { + use serde_json; use std::fs; - use ::serde_json; let config_file = fs::File::open(path)?; let rd = io::BufReader::new(config_file); @@ -164,15 +173,11 @@ fn load_config(path: &str) -> io::Result<config::Config> { } fn main() { - use ::std::{env, thread, process}; + use std::{env, process, thread}; - let args: Vec<String> = env::args().into_iter().collect(); + let args: Vec<String> = env::args().collect(); if args.len() != 2 { - let arg0 = if args.len() >= 1 { - &args[0] - } else { - "wgconf" - }; + let arg0 = if !args.is_empty() { &args[0] } else { "wgconf" }; eprintln!("<1>Usage:"); eprintln!("<1> {} CONFIG", arg0); process::exit(1); diff --git a/src/proto.rs b/src/proto.rs index e6759a1..a3ee6e6 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,11 +1,10 @@ -use ::std::time::SystemTime; -use ::serde_derive; +use serde_derive; +use std::time::SystemTime; -use crate::ip::{Ipv4Net, Ipv6Net, Endpoint}; +use crate::ip::{Endpoint, Ipv4Net, Ipv6Net}; #[serde(deny_unknown_fields)] -#[derive(serde_derive::Serialize, serde_derive::Deserialize)] -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct Peer { pub public_key: String, pub endpoint: Endpoint, @@ -19,14 +18,12 @@ fn default_peer_keepalive() -> u32 { 0 } -#[derive(serde_derive::Serialize, serde_derive::Deserialize)] -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct SourceConfig { pub peers: Vec<Peer>, } -#[derive(serde_derive::Serialize, serde_derive::Deserialize)] -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct SourceNextConfig { #[serde(with = "serde_utc")] pub update_at: SystemTime, @@ -34,8 +31,7 @@ pub struct SourceNextConfig { pub config: SourceConfig, } -#[derive(serde_derive::Serialize, serde_derive::Deserialize)] -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] pub struct Source { #[serde(flatten)] pub config: SourceConfig, @@ -43,10 +39,10 @@ pub struct Source { } mod serde_utc { - use ::std::time::SystemTime; - use ::chrono::{DateTime, TimeZone, Utc, SecondsFormat}; - use ::serde::*; use crate::bin; + use chrono::{DateTime, SecondsFormat, TimeZone, Utc}; + use serde::*; + use std::time::SystemTime; pub fn serialize<S: Serializer>(t: &SystemTime, ser: S) -> Result<S::Ok, S::Error> { let t = DateTime::<Utc>::from(*t); @@ -1,8 +1,8 @@ -use ::std::{error, io, fmt}; -use ::std::collections::hash_map; +use crate::ip::{Endpoint, Ipv4Net, Ipv6Net}; +use crate::{config, proto}; use hash_map::HashMap; -use crate::ip::{Ipv4Net, Ipv6Net, Endpoint}; -use crate::{proto, config}; +use std::collections::hash_map; +use std::{error, fmt, io}; #[derive(Clone, PartialEq, Eq, Debug)] struct Peer { @@ -29,10 +29,19 @@ pub struct ConfigError { impl error::Error for ConfigError {} impl fmt::Display for ConfigError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Invalid peer [{}] from [{}]: {}", self.peer, self.url, self.err) + write!( + f, + "{} [{}] from [{}]: {}", + if self.important { + "Invalid peer" + } else { + "Misconfigured peer" + }, + self.peer, self.url, self.err + ) } } - + impl Config { pub fn new() -> Config { Config { @@ -40,7 +49,13 @@ impl Config { } } - pub fn add_peer(&mut self, errors: &mut Vec<ConfigError>, c: &config::PeerConfig, s: &config::Source, p: &proto::Peer) { + pub fn add_peer( + &mut self, + errors: &mut Vec<ConfigError>, + c: &config::PeerConfig, + s: &config::Source, + p: &proto::Peer, + ) { if !valid_key(&p.public_key) { errors.push(ConfigError { url: s.url.clone(), @@ -68,55 +83,64 @@ impl Config { } let ent = match self.peers.entry(p.public_key.clone()) { - hash_map::Entry::Occupied(_) => { + hash_map::Entry::Occupied(ent) => { errors.push(ConfigError { url: s.url.clone(), peer: p.public_key.clone(), important: true, err: "Duplicate public key", }); - return; + ent.into_mut() + } + hash_map::Entry::Vacant(ent) => { + let mut keepalive = p.keepalive; + if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) { + keepalive = c.max_keepalive; + } + if keepalive != 0 && keepalive < c.min_keepalive { + keepalive = c.min_keepalive; + } + + ent.insert(Peer { + endpoint: p.endpoint.clone(), + psk: s.psk.clone(), + keepalive, + ipv4: vec![], + ipv6: vec![], + }) }, - hash_map::Entry::Vacant(ent) => ent, }; - let mut keepalive = p.keepalive; - if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) { - keepalive = c.max_keepalive; - } - if keepalive != 0 && keepalive < c.min_keepalive { - keepalive = c.min_keepalive; - } - + let mut added = false; let mut removed = false; - let mut ipv4 = p.ipv4.clone(); - ipv4.retain(|i| { - let r = s.ipv4.contains(i); - if !r { removed = true; } - r - }); - - let mut ipv6 = p.ipv6.clone(); - ipv6.retain(|i| { - let r = s.ipv6.contains(i); - if !r { removed = true; } - r - }); - - let r = ent.insert(Peer { - endpoint: p.endpoint.clone(), - psk: s.psk.clone(), - keepalive, ipv4, ipv6, - }); + for i in p.ipv4.iter() { + if s.ipv4.contains(i) { + ent.ipv4.push(*i); + added = true; + } else { + removed = true; + } + } + for i in p.ipv6.iter() { + if s.ipv6.contains(i) { + ent.ipv6.push(*i); + added = true; + } else { + removed = true; + } + } if removed { - let all = r.ipv4.is_empty() && r.ipv6.is_empty(); errors.push(ConfigError { url: s.url.clone(), peer: p.public_key.clone(), - important: all, - err: if all { "All IPs removed" } else {"Some IPs removed"}, + important: !added, + err: if added { + "Some IPs removed" + } else { + "All IPs removed" + }, }); } } @@ -140,11 +164,10 @@ impl Device { } pub fn apply_diff(&mut self, old: &Config, new: &Config) -> io::Result<()> { - use ::std::process::{Command, Stdio}; + use std::process::{Command, Stdio}; let mut proc = Command::new(&self.wg_command); proc.stdin(Stdio::piped()); - proc.stdout(Stdio::null()); proc.arg("set"); proc.arg(&self.ifname); @@ -158,7 +181,7 @@ impl Device { } proc.arg("peer"); proc.arg(pubkey); - + // TODO: maybe skip endpoint? proc.arg("endpoint"); proc.arg(format!("{}", conf.endpoint)); @@ -173,11 +196,15 @@ impl Device { { use std::fmt::Write; for ip in conf.ipv4.iter() { - if !ips.is_empty() { ips.push(','); } + if !ips.is_empty() { + ips.push(','); + } write!(ips, "{}", ip).unwrap(); } for ip in conf.ipv6.iter() { - if !ips.is_empty() { ips.push(','); } + if !ips.is_empty() { + ips.push(','); + } write!(ips, "{}", ip).unwrap(); } } @@ -200,7 +227,7 @@ impl Device { use std::io::Write; let stdin = proc.stdin.as_mut().unwrap(); for psk in psks { - write!(stdin, "{}\n", psk)?; + writeln!(stdin, "{}", psk)?; } } @@ -221,10 +248,18 @@ fn valid_key(s: &str) -> bool { return false; } for c in s[0..42].iter().cloned() { - if c >= b'0' && c <= b'9' { continue; } - if c >= b'A' && c <= b'Z' { continue; } - if c >= b'a' && c <= b'z' { continue; } - if c == b'+' || c <= b'/' { continue; } + if c >= b'0' && c <= b'9' { + continue; + } + if c >= b'A' && c <= b'Z' { + continue; + } + if c >= b'a' && c <= b'z' { + continue; + } + if c == b'+' || c <= b'/' { + continue; + } return false; } b"048AEIMQUYcgkosw".contains(&s[42]) |