aboutsummaryrefslogtreecommitdiff
path: root/src/wg.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/wg.rs')
-rw-r--r--src/wg.rs135
1 files changed, 85 insertions, 50 deletions
diff --git a/src/wg.rs b/src/wg.rs
index d5a03ff..910e5e6 100644
--- a/src/wg.rs
+++ b/src/wg.rs
@@ -1,8 +1,8 @@
-use ::std::{error, io, fmt};
-use ::std::collections::hash_map;
+use crate::ip::{Endpoint, Ipv4Net, Ipv6Net};
+use crate::{config, proto};
use hash_map::HashMap;
-use crate::ip::{Ipv4Net, Ipv6Net, Endpoint};
-use crate::{proto, config};
+use std::collections::hash_map;
+use std::{error, fmt, io};
#[derive(Clone, PartialEq, Eq, Debug)]
struct Peer {
@@ -29,10 +29,19 @@ pub struct ConfigError {
impl error::Error for ConfigError {}
impl fmt::Display for ConfigError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(f, "Invalid peer [{}] from [{}]: {}", self.peer, self.url, self.err)
+ write!(
+ f,
+ "{} [{}] from [{}]: {}",
+ if self.important {
+ "Invalid peer"
+ } else {
+ "Misconfigured peer"
+ },
+ self.peer, self.url, self.err
+ )
}
}
-
+
impl Config {
pub fn new() -> Config {
Config {
@@ -40,7 +49,13 @@ impl Config {
}
}
- pub fn add_peer(&mut self, errors: &mut Vec<ConfigError>, c: &config::PeerConfig, s: &config::Source, p: &proto::Peer) {
+ pub fn add_peer(
+ &mut self,
+ errors: &mut Vec<ConfigError>,
+ c: &config::PeerConfig,
+ s: &config::Source,
+ p: &proto::Peer,
+ ) {
if !valid_key(&p.public_key) {
errors.push(ConfigError {
url: s.url.clone(),
@@ -68,55 +83,64 @@ impl Config {
}
let ent = match self.peers.entry(p.public_key.clone()) {
- hash_map::Entry::Occupied(_) => {
+ hash_map::Entry::Occupied(ent) => {
errors.push(ConfigError {
url: s.url.clone(),
peer: p.public_key.clone(),
important: true,
err: "Duplicate public key",
});
- return;
+ ent.into_mut()
+ }
+ hash_map::Entry::Vacant(ent) => {
+ let mut keepalive = p.keepalive;
+ if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) {
+ keepalive = c.max_keepalive;
+ }
+ if keepalive != 0 && keepalive < c.min_keepalive {
+ keepalive = c.min_keepalive;
+ }
+
+ ent.insert(Peer {
+ endpoint: p.endpoint.clone(),
+ psk: s.psk.clone(),
+ keepalive,
+ ipv4: vec![],
+ ipv6: vec![],
+ })
},
- hash_map::Entry::Vacant(ent) => ent,
};
- let mut keepalive = p.keepalive;
- if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) {
- keepalive = c.max_keepalive;
- }
- if keepalive != 0 && keepalive < c.min_keepalive {
- keepalive = c.min_keepalive;
- }
-
+ let mut added = false;
let mut removed = false;
- let mut ipv4 = p.ipv4.clone();
- ipv4.retain(|i| {
- let r = s.ipv4.contains(i);
- if !r { removed = true; }
- r
- });
-
- let mut ipv6 = p.ipv6.clone();
- ipv6.retain(|i| {
- let r = s.ipv6.contains(i);
- if !r { removed = true; }
- r
- });
-
- let r = ent.insert(Peer {
- endpoint: p.endpoint.clone(),
- psk: s.psk.clone(),
- keepalive, ipv4, ipv6,
- });
+ for i in p.ipv4.iter() {
+ if s.ipv4.contains(i) {
+ ent.ipv4.push(*i);
+ added = true;
+ } else {
+ removed = true;
+ }
+ }
+ for i in p.ipv6.iter() {
+ if s.ipv6.contains(i) {
+ ent.ipv6.push(*i);
+ added = true;
+ } else {
+ removed = true;
+ }
+ }
if removed {
- let all = r.ipv4.is_empty() && r.ipv6.is_empty();
errors.push(ConfigError {
url: s.url.clone(),
peer: p.public_key.clone(),
- important: all,
- err: if all { "All IPs removed" } else {"Some IPs removed"},
+ important: !added,
+ err: if added {
+ "Some IPs removed"
+ } else {
+ "All IPs removed"
+ },
});
}
}
@@ -140,11 +164,10 @@ impl Device {
}
pub fn apply_diff(&mut self, old: &Config, new: &Config) -> io::Result<()> {
- use ::std::process::{Command, Stdio};
+ use std::process::{Command, Stdio};
let mut proc = Command::new(&self.wg_command);
proc.stdin(Stdio::piped());
- proc.stdout(Stdio::null());
proc.arg("set");
proc.arg(&self.ifname);
@@ -158,7 +181,7 @@ impl Device {
}
proc.arg("peer");
proc.arg(pubkey);
-
+
// TODO: maybe skip endpoint?
proc.arg("endpoint");
proc.arg(format!("{}", conf.endpoint));
@@ -173,11 +196,15 @@ impl Device {
{
use std::fmt::Write;
for ip in conf.ipv4.iter() {
- if !ips.is_empty() { ips.push(','); }
+ if !ips.is_empty() {
+ ips.push(',');
+ }
write!(ips, "{}", ip).unwrap();
}
for ip in conf.ipv6.iter() {
- if !ips.is_empty() { ips.push(','); }
+ if !ips.is_empty() {
+ ips.push(',');
+ }
write!(ips, "{}", ip).unwrap();
}
}
@@ -200,7 +227,7 @@ impl Device {
use std::io::Write;
let stdin = proc.stdin.as_mut().unwrap();
for psk in psks {
- write!(stdin, "{}\n", psk)?;
+ writeln!(stdin, "{}", psk)?;
}
}
@@ -221,10 +248,18 @@ fn valid_key(s: &str) -> bool {
return false;
}
for c in s[0..42].iter().cloned() {
- if c >= b'0' && c <= b'9' { continue; }
- if c >= b'A' && c <= b'Z' { continue; }
- if c >= b'a' && c <= b'z' { continue; }
- if c == b'+' || c <= b'/' { continue; }
+ if c >= b'0' && c <= b'9' {
+ continue;
+ }
+ if c >= b'A' && c <= b'Z' {
+ continue;
+ }
+ if c >= b'a' && c <= b'z' {
+ continue;
+ }
+ if c == b'+' || c <= b'/' {
+ continue;
+ }
return false;
}
b"048AEIMQUYcgkosw".contains(&s[42])