From 1168ac0458c7e92f22d1c77bf31daea4e23ac750 Mon Sep 17 00:00:00 2001 From: Hristo Venev Date: Fri, 7 Feb 2020 20:10:31 +0000 Subject: Load the preshared keys on startup. --- src/config.rs | 10 +++++++ src/fileutil.rs | 17 +++-------- src/main.rs | 10 ++----- src/manager/mod.rs | 9 +++--- src/manager/updater.rs | 9 +++--- src/model.rs | 76 ++++++++++++++++++++++++++++++++++---------------- src/wg.rs | 32 ++++++++++++++++----- 7 files changed, 102 insertions(+), 61 deletions(-) diff --git a/src/config.rs b/src/config.rs index a1dff3e..36e5572 100644 --- a/src/config.rs +++ b/src/config.rs @@ -12,6 +12,7 @@ use std::path::PathBuf; pub struct Source { pub name: String, pub url: String, + #[serde(default, deserialize_with = "deserialize_key_from_file")] pub psk: Option, pub ipv4: Ipv4Set, pub ipv6: Ipv6Set, @@ -26,6 +27,7 @@ pub struct Source { pub struct Peer { pub source: Option, pub endpoint: Option, + #[serde(default, deserialize_with = "deserialize_key_from_file")] pub psk: Option, pub keepalive: Option, } @@ -153,3 +155,11 @@ const fn default_max_keepalive() -> u32 { const fn default_refresh_sec() -> u32 { 1200 } + +fn deserialize_key_from_file<'de, D>(d: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + let path = >::deserialize(d)?; + Secret::from_file(&path).map_err(|e| ::custom(e.to_string())) +} diff --git a/src/fileutil.rs b/src/fileutil.rs index 2c322c4..190a7d0 100644 --- a/src/fileutil.rs +++ b/src/fileutil.rs @@ -82,22 +82,13 @@ pub fn update(path: &Path, data: &[u8]) -> io::Result<()> { } #[inline] -pub fn load(path: &impl AsRef) -> io::Result>> { +pub fn load(path: &impl AsRef) -> io::Result> { _load(path.as_ref()) } -fn _load(path: &Path) -> io::Result>> { - let mut file = match fs::File::open(&path) { - Ok(file) => file, - Err(e) => { - if e.kind() == io::ErrorKind::NotFound { - return Ok(None); - } - return Err(e); - } - }; - +fn _load(path: &Path) -> io::Result> { + let mut file = fs::File::open(&path)?; let mut data = Vec::new(); io::Read::read_to_end(&mut file, &mut data)?; - Ok(Some(data)) + Ok(data) } diff --git a/src/main.rs b/src/main.rs index cbc532b..11536fb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -37,7 +37,7 @@ fn cli_config(mut args: impl Iterator) -> Option { if key == "psk" { arg = args.next()?; - s.psk = Some(model::Secret::new(arg.into())); + s.psk = model::Secret::from_file(&arg).ok()?; continue; } if key == "ipv4" { @@ -82,7 +82,7 @@ fn cli_config(mut args: impl Iterator) -> Option) -> i32 { let data = fileutil::load(&path); mem::drop(path); let data = match data { - Ok(Some(v)) => v, - Ok(None) => { - eprintln!("<1>Configuration file not found"); - return 1; - } + Ok(v) => v, Err(e) => { eprintln!("<1>Failed to load config file: {}", e); return 1; diff --git a/src/manager/mod.rs b/src/manager/mod.rs index 39e486c..646349a 100644 --- a/src/manager/mod.rs +++ b/src/manager/mod.rs @@ -58,12 +58,11 @@ impl Manager { fn current_load(&mut self) -> bool { let data = match fileutil::load(&self.state_path) { - Ok(Some(data)) => data, - Ok(None) => { - return false; - } + Ok(data) => data, Err(e) => { - eprintln!("<3>Failed to read interface state: {}", e); + if e.kind() != io::ErrorKind::NotFound { + eprintln!("<3>Failed to read interface state: {}", e); + } return false; } }; diff --git a/src/manager/updater.rs b/src/manager/updater.rs index db24d6e..ca04ce9 100644 --- a/src/manager/updater.rs +++ b/src/manager/updater.rs @@ -46,12 +46,11 @@ impl Updater { }; let data = match fileutil::load(&path) { - Ok(Some(data)) => data, - Ok(None) => { - return false; - } + Ok(data) => data, Err(e) => { - eprintln!("<3>Failed to read [{}] from cache: {}", &src.config.name, e); + if e.kind() != io::ErrorKind::NotFound { + eprintln!("<3>Failed to read [{}] from cache: {}", &src.config.name, e); + } return false; } }; diff --git a/src/model.rs b/src/model.rs index 5797f0e..dfb35e0 100644 --- a/src/model.rs +++ b/src/model.rs @@ -2,11 +2,12 @@ // // Copyright 2019 Hristo Venev +use crate::fileutil; use base64; use std::collections::HashMap; -use std::fmt; -use std::path::{Path, PathBuf}; +use std::path::Path; use std::str::FromStr; +use std::{fmt, io}; mod ip; pub use ip::*; @@ -17,7 +18,7 @@ pub type KeyParseError = base64::DecodeError; pub struct Key([u8; 32]); impl Key { - pub fn from_bytes(s: &[u8]) -> Result { + pub fn from_base64(s: &[u8]) -> Result { let mut v = Self([0; 32]); let l = base64::decode_config_slice(s, base64::STANDARD, &mut v.0)?; if l != v.0.len() { @@ -27,29 +28,10 @@ impl Key { } } -#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)] -pub struct Secret(PathBuf); - -impl Secret { - #[inline] - pub fn new(path: PathBuf) -> Self { - Self(path) - } - - #[inline] - pub fn path(&self) -> &Path { - &self.0 - } -} - impl fmt::Display for Key { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}", - base64::display::Base64Display::with_config(&self.0, base64::STANDARD) - ) + base64::display::Base64Display::with_config(&self.0, base64::STANDARD).fmt(f) } } @@ -57,7 +39,7 @@ impl FromStr for Key { type Err = KeyParseError; #[inline] fn from_str(s: &str) -> Result { - Self::from_bytes(s.as_bytes()) + Self::from_base64(s.as_bytes()) } } @@ -95,6 +77,52 @@ impl<'de> serde::Deserialize<'de> for Key { } } +#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq)] +pub struct Secret(Key); + +impl Secret { + #[inline] + pub fn from_file(path: &impl AsRef) -> io::Result> { + Self::_from_file(path.as_ref()) + } + + fn _from_file(path: &Path) -> io::Result> { + let mut data = fileutil::load(&path)?; + if data.last().copied() == Some(b'\n') { + data.pop(); + } + + if data.is_empty() { + return Ok(None); + } + + let k = match Key::from_base64(&data) { + Ok(v) => v, + Err(e) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to parse key: {}", e), + )) + } + }; + Ok(Some(Self(k))) + } +} + +impl fmt::Display for Secret { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Debug for Secret { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt("", f) + } +} + #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub struct Endpoint { address: Ipv6Addr, diff --git a/src/wg.rs b/src/wg.rs index 879251b..766f18f 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -5,7 +5,7 @@ use crate::model; use std::ffi::{OsStr, OsString}; use std::process::{Command, Stdio}; -use std::{env, io}; +use std::{env, fmt, io}; pub struct Device { ifname: OsString, @@ -46,27 +46,33 @@ impl Device { } let mut out = r.stdout; - if out.ends_with(b"\n") { - out.remove(out.len() - 1); + if out.last().copied() == Some(b'\n') { + out.pop(); } - model::Key::from_bytes(&out) + model::Key::from_base64(&out) .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid public key")) } pub fn apply_diff(&mut self, old: &model::Config, new: &model::Config) -> io::Result<()> { let mut proc = Self::wg_command(); + proc.stdin(Stdio::piped()); + proc.stdout(Stdio::null()); proc.arg("set"); proc.arg(&self.ifname); + let mut stdin = String::new(); for (pubkey, conf) in &new.peers { let old_endpoint; + let old_psk; if let Some(old_peer) = old.peers.get(pubkey) { if *old_peer == *conf { continue; } old_endpoint = old_peer.endpoint; + old_psk = old_peer.psk.as_ref(); } else { old_endpoint = None; + old_psk = None; } proc.arg("peer"); @@ -82,9 +88,15 @@ impl Device { } } - if let Some(psk) = &conf.psk { + if old_psk != conf.psk.as_ref() { proc.arg("preshared-key"); - proc.arg(psk.path()); + proc.arg("-"); + if let Some(psk) = conf.psk.as_ref() { + use fmt::Write; + writeln!(stdin, "{}", psk).unwrap(); + } else { + stdin.push('\n'); + } } let mut ips = String::new(); @@ -117,7 +129,13 @@ impl Device { proc.arg("remove"); } - let r = proc.status()?; + let mut proc = proc.spawn()?; + { + use io::Write; + proc.stdin.as_mut().unwrap().write_all(stdin.as_bytes())?; + } + + let r = proc.wait()?; if !r.success() { return Err(io::Error::new(io::ErrorKind::Other, "child process failed")); } -- cgit