aboutsummaryrefslogtreecommitdiff
path: root/src/wg.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/wg.rs')
-rw-r--r--src/wg.rs231
1 files changed, 231 insertions, 0 deletions
diff --git a/src/wg.rs b/src/wg.rs
new file mode 100644
index 0000000..d5a03ff
--- /dev/null
+++ b/src/wg.rs
@@ -0,0 +1,231 @@
+use ::std::{error, io, fmt};
+use ::std::collections::hash_map;
+use hash_map::HashMap;
+use crate::ip::{Ipv4Net, Ipv6Net, Endpoint};
+use crate::{proto, config};
+
+#[derive(Clone, PartialEq, Eq, Debug)]
+struct Peer {
+ endpoint: Endpoint,
+ psk: Option<String>,
+ keepalive: u32,
+ ipv4: Vec<Ipv4Net>,
+ ipv6: Vec<Ipv6Net>,
+}
+
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub struct Config {
+ peers: HashMap<String, Peer>,
+}
+
+#[derive(Debug)]
+pub struct ConfigError {
+ pub url: String,
+ pub peer: String,
+ pub important: bool,
+ err: &'static str,
+}
+
+impl error::Error for ConfigError {}
+impl fmt::Display for ConfigError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "Invalid peer [{}] from [{}]: {}", self.peer, self.url, self.err)
+ }
+}
+
+impl Config {
+ pub fn new() -> Config {
+ Config {
+ peers: HashMap::new(),
+ }
+ }
+
+ pub fn add_peer(&mut self, errors: &mut Vec<ConfigError>, c: &config::PeerConfig, s: &config::Source, p: &proto::Peer) {
+ if !valid_key(&p.public_key) {
+ errors.push(ConfigError {
+ url: s.url.clone(),
+ peer: p.public_key.clone(),
+ important: true,
+ err: "Invalid public key",
+ });
+ return;
+ }
+
+ if let Some(ref psk) = s.psk {
+ if !valid_key(psk) {
+ errors.push(ConfigError {
+ url: s.url.clone(),
+ peer: p.public_key.clone(),
+ important: true,
+ err: "Invalid preshared key",
+ });
+ return;
+ }
+ }
+
+ if c.omit_peers.contains(&p.public_key) {
+ return;
+ }
+
+ let ent = match self.peers.entry(p.public_key.clone()) {
+ hash_map::Entry::Occupied(_) => {
+ errors.push(ConfigError {
+ url: s.url.clone(),
+ peer: p.public_key.clone(),
+ important: true,
+ err: "Duplicate public key",
+ });
+ return;
+ },
+ hash_map::Entry::Vacant(ent) => ent,
+ };
+
+ let mut keepalive = p.keepalive;
+ if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) {
+ keepalive = c.max_keepalive;
+ }
+ if keepalive != 0 && keepalive < c.min_keepalive {
+ keepalive = c.min_keepalive;
+ }
+
+ let mut removed = false;
+
+ let mut ipv4 = p.ipv4.clone();
+ ipv4.retain(|i| {
+ let r = s.ipv4.contains(i);
+ if !r { removed = true; }
+ r
+ });
+
+ let mut ipv6 = p.ipv6.clone();
+ ipv6.retain(|i| {
+ let r = s.ipv6.contains(i);
+ if !r { removed = true; }
+ r
+ });
+
+ let r = ent.insert(Peer {
+ endpoint: p.endpoint.clone(),
+ psk: s.psk.clone(),
+ keepalive, ipv4, ipv6,
+ });
+
+ if removed {
+ let all = r.ipv4.is_empty() && r.ipv6.is_empty();
+ errors.push(ConfigError {
+ url: s.url.clone(),
+ peer: p.public_key.clone(),
+ important: all,
+ err: if all { "All IPs removed" } else {"Some IPs removed"},
+ });
+ }
+ }
+}
+
+impl Default for Config {
+ #[inline]
+ fn default() -> Self {
+ Config::new()
+ }
+}
+
+pub struct Device {
+ ifname: String,
+ wg_command: String,
+}
+
+impl Device {
+ pub fn new(ifname: String, wg_command: String) -> Self {
+ Device { ifname, wg_command }
+ }
+
+ pub fn apply_diff(&mut self, old: &Config, new: &Config) -> io::Result<()> {
+ use ::std::process::{Command, Stdio};
+
+ let mut proc = Command::new(&self.wg_command);
+ proc.stdin(Stdio::piped());
+ proc.stdout(Stdio::null());
+ proc.arg("set");
+ proc.arg(&self.ifname);
+
+ let mut psks = Vec::<&str>::new();
+
+ for (pubkey, conf) in new.peers.iter() {
+ if let Some(old_peer) = old.peers.get(pubkey) {
+ if *old_peer == *conf {
+ continue;
+ }
+ }
+ proc.arg("peer");
+ proc.arg(pubkey);
+
+ // TODO: maybe skip endpoint?
+ proc.arg("endpoint");
+ proc.arg(format!("{}", conf.endpoint));
+
+ if let Some(psk) = &conf.psk {
+ proc.arg("preshared-key");
+ proc.arg("/dev/stdin");
+ psks.push(psk);
+ }
+
+ let mut ips = String::new();
+ {
+ use std::fmt::Write;
+ for ip in conf.ipv4.iter() {
+ if !ips.is_empty() { ips.push(','); }
+ write!(ips, "{}", ip).unwrap();
+ }
+ for ip in conf.ipv6.iter() {
+ if !ips.is_empty() { ips.push(','); }
+ write!(ips, "{}", ip).unwrap();
+ }
+ }
+
+ proc.arg("allowed-ips");
+ proc.arg(ips);
+ }
+
+ for pubkey in old.peers.keys() {
+ if new.peers.contains_key(pubkey) {
+ continue;
+ }
+ proc.arg("peer");
+ proc.arg(pubkey);
+ proc.arg("remove");
+ }
+
+ let mut proc = proc.spawn()?;
+ {
+ use std::io::Write;
+ let stdin = proc.stdin.as_mut().unwrap();
+ for psk in psks {
+ write!(stdin, "{}\n", psk)?;
+ }
+ }
+
+ let r = proc.wait()?;
+ if !r.success() {
+ return Err(io::Error::new(io::ErrorKind::Other, "Child process failed"));
+ }
+ Ok(())
+ }
+}
+
+fn valid_key(s: &str) -> bool {
+ let s = s.as_bytes();
+ if s.len() != 44 {
+ return false;
+ }
+ if s[43] != b'=' {
+ return false;
+ }
+ for c in s[0..42].iter().cloned() {
+ if c >= b'0' && c <= b'9' { continue; }
+ if c >= b'A' && c <= b'Z' { continue; }
+ if c >= b'a' && c <= b'z' { continue; }
+ if c == b'+' || c <= b'/' { continue; }
+ return false;
+ }
+ b"048AEIMQUYcgkosw".contains(&s[42])
+}