aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/config.rs10
-rw-r--r--src/fileutil.rs17
-rw-r--r--src/main.rs10
-rw-r--r--src/manager/mod.rs9
-rw-r--r--src/manager/updater.rs9
-rw-r--r--src/model.rs76
-rw-r--r--src/wg.rs32
7 files changed, 102 insertions, 61 deletions
diff --git a/src/config.rs b/src/config.rs
index a1dff3e..36e5572 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -12,6 +12,7 @@ use std::path::PathBuf;
pub struct Source {
pub name: String,
pub url: String,
+ #[serde(default, deserialize_with = "deserialize_key_from_file")]
pub psk: Option<Secret>,
pub ipv4: Ipv4Set,
pub ipv6: Ipv6Set,
@@ -26,6 +27,7 @@ pub struct Source {
pub struct Peer {
pub source: Option<String>,
pub endpoint: Option<Endpoint>,
+ #[serde(default, deserialize_with = "deserialize_key_from_file")]
pub psk: Option<Secret>,
pub keepalive: Option<u32>,
}
@@ -153,3 +155,11 @@ const fn default_max_keepalive() -> u32 {
const fn default_refresh_sec() -> u32 {
1200
}
+
+fn deserialize_key_from_file<'de, D>(d: D) -> Result<Option<Secret>, D::Error>
+where
+ D: serde::Deserializer<'de>,
+{
+ let path = <PathBuf as serde::Deserialize<'de>>::deserialize(d)?;
+ Secret::from_file(&path).map_err(|e| <D::Error as serde::de::Error>::custom(e.to_string()))
+}
diff --git a/src/fileutil.rs b/src/fileutil.rs
index 2c322c4..190a7d0 100644
--- a/src/fileutil.rs
+++ b/src/fileutil.rs
@@ -82,22 +82,13 @@ pub fn update(path: &Path, data: &[u8]) -> io::Result<()> {
}
#[inline]
-pub fn load(path: &impl AsRef<Path>) -> io::Result<Option<Vec<u8>>> {
+pub fn load(path: &impl AsRef<Path>) -> io::Result<Vec<u8>> {
_load(path.as_ref())
}
-fn _load(path: &Path) -> io::Result<Option<Vec<u8>>> {
- let mut file = match fs::File::open(&path) {
- Ok(file) => file,
- Err(e) => {
- if e.kind() == io::ErrorKind::NotFound {
- return Ok(None);
- }
- return Err(e);
- }
- };
-
+fn _load(path: &Path) -> io::Result<Vec<u8>> {
+ let mut file = fs::File::open(&path)?;
let mut data = Vec::new();
io::Read::read_to_end(&mut file, &mut data)?;
- Ok(Some(data))
+ Ok(data)
}
diff --git a/src/main.rs b/src/main.rs
index cbc532b..11536fb 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -37,7 +37,7 @@ fn cli_config(mut args: impl Iterator<Item = OsString>) -> Option<config::Config
State::Source(ref mut s) => {
if key == "psk" {
arg = args.next()?;
- s.psk = Some(model::Secret::new(arg.into()));
+ s.psk = model::Secret::from_file(&arg).ok()?;
continue;
}
if key == "ipv4" {
@@ -82,7 +82,7 @@ fn cli_config(mut args: impl Iterator<Item = OsString>) -> Option<config::Config
}
if key == "psk" {
arg = args.next()?;
- p.psk = Some(model::Secret::new(arg.into()));
+ p.psk = model::Secret::from_file(&arg).ok()?;
continue;
}
if key == "keepalive" {
@@ -199,11 +199,7 @@ fn run_with_file(argv0: &str, args: Vec<OsString>) -> i32 {
let data = fileutil::load(&path);
mem::drop(path);
let data = match data {
- Ok(Some(v)) => v,
- Ok(None) => {
- eprintln!("<1>Configuration file not found");
- return 1;
- }
+ Ok(v) => v,
Err(e) => {
eprintln!("<1>Failed to load config file: {}", e);
return 1;
diff --git a/src/manager/mod.rs b/src/manager/mod.rs
index 39e486c..646349a 100644
--- a/src/manager/mod.rs
+++ b/src/manager/mod.rs
@@ -58,12 +58,11 @@ impl Manager {
fn current_load(&mut self) -> bool {
let data = match fileutil::load(&self.state_path) {
- Ok(Some(data)) => data,
- Ok(None) => {
- return false;
- }
+ Ok(data) => data,
Err(e) => {
- eprintln!("<3>Failed to read interface state: {}", e);
+ if e.kind() != io::ErrorKind::NotFound {
+ eprintln!("<3>Failed to read interface state: {}", e);
+ }
return false;
}
};
diff --git a/src/manager/updater.rs b/src/manager/updater.rs
index db24d6e..ca04ce9 100644
--- a/src/manager/updater.rs
+++ b/src/manager/updater.rs
@@ -46,12 +46,11 @@ impl Updater {
};
let data = match fileutil::load(&path) {
- Ok(Some(data)) => data,
- Ok(None) => {
- return false;
- }
+ Ok(data) => data,
Err(e) => {
- eprintln!("<3>Failed to read [{}] from cache: {}", &src.config.name, e);
+ if e.kind() != io::ErrorKind::NotFound {
+ eprintln!("<3>Failed to read [{}] from cache: {}", &src.config.name, e);
+ }
return false;
}
};
diff --git a/src/model.rs b/src/model.rs
index 5797f0e..dfb35e0 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -2,11 +2,12 @@
//
// Copyright 2019 Hristo Venev
+use crate::fileutil;
use base64;
use std::collections::HashMap;
-use std::fmt;
-use std::path::{Path, PathBuf};
+use std::path::Path;
use std::str::FromStr;
+use std::{fmt, io};
mod ip;
pub use ip::*;
@@ -17,7 +18,7 @@ pub type KeyParseError = base64::DecodeError;
pub struct Key([u8; 32]);
impl Key {
- pub fn from_bytes(s: &[u8]) -> Result<Self, KeyParseError> {
+ pub fn from_base64(s: &[u8]) -> Result<Self, KeyParseError> {
let mut v = Self([0; 32]);
let l = base64::decode_config_slice(s, base64::STANDARD, &mut v.0)?;
if l != v.0.len() {
@@ -27,29 +28,10 @@ impl Key {
}
}
-#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
-pub struct Secret(PathBuf);
-
-impl Secret {
- #[inline]
- pub fn new(path: PathBuf) -> Self {
- Self(path)
- }
-
- #[inline]
- pub fn path(&self) -> &Path {
- &self.0
- }
-}
-
impl fmt::Display for Key {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
- f,
- "{}",
- base64::display::Base64Display::with_config(&self.0, base64::STANDARD)
- )
+ base64::display::Base64Display::with_config(&self.0, base64::STANDARD).fmt(f)
}
}
@@ -57,7 +39,7 @@ impl FromStr for Key {
type Err = KeyParseError;
#[inline]
fn from_str(s: &str) -> Result<Self, base64::DecodeError> {
- Self::from_bytes(s.as_bytes())
+ Self::from_base64(s.as_bytes())
}
}
@@ -95,6 +77,52 @@ impl<'de> serde::Deserialize<'de> for Key {
}
}
+#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq)]
+pub struct Secret(Key);
+
+impl Secret {
+ #[inline]
+ pub fn from_file(path: &impl AsRef<Path>) -> io::Result<Option<Self>> {
+ Self::_from_file(path.as_ref())
+ }
+
+ fn _from_file(path: &Path) -> io::Result<Option<Self>> {
+ let mut data = fileutil::load(&path)?;
+ if data.last().copied() == Some(b'\n') {
+ data.pop();
+ }
+
+ if data.is_empty() {
+ return Ok(None);
+ }
+
+ let k = match Key::from_base64(&data) {
+ Ok(v) => v,
+ Err(e) => {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("failed to parse key: {}", e),
+ ))
+ }
+ };
+ Ok(Some(Self(k)))
+ }
+}
+
+impl fmt::Display for Secret {
+ #[inline]
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+impl fmt::Debug for Secret {
+ #[inline]
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ <str as fmt::Display>::fmt("<secret key>", f)
+ }
+}
+
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Endpoint {
address: Ipv6Addr,
diff --git a/src/wg.rs b/src/wg.rs
index 879251b..766f18f 100644
--- a/src/wg.rs
+++ b/src/wg.rs
@@ -5,7 +5,7 @@
use crate::model;
use std::ffi::{OsStr, OsString};
use std::process::{Command, Stdio};
-use std::{env, io};
+use std::{env, fmt, io};
pub struct Device {
ifname: OsString,
@@ -46,27 +46,33 @@ impl Device {
}
let mut out = r.stdout;
- if out.ends_with(b"\n") {
- out.remove(out.len() - 1);
+ if out.last().copied() == Some(b'\n') {
+ out.pop();
}
- model::Key::from_bytes(&out)
+ model::Key::from_base64(&out)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid public key"))
}
pub fn apply_diff(&mut self, old: &model::Config, new: &model::Config) -> io::Result<()> {
let mut proc = Self::wg_command();
+ proc.stdin(Stdio::piped());
+ proc.stdout(Stdio::null());
proc.arg("set");
proc.arg(&self.ifname);
+ let mut stdin = String::new();
for (pubkey, conf) in &new.peers {
let old_endpoint;
+ let old_psk;
if let Some(old_peer) = old.peers.get(pubkey) {
if *old_peer == *conf {
continue;
}
old_endpoint = old_peer.endpoint;
+ old_psk = old_peer.psk.as_ref();
} else {
old_endpoint = None;
+ old_psk = None;
}
proc.arg("peer");
@@ -82,9 +88,15 @@ impl Device {
}
}
- if let Some(psk) = &conf.psk {
+ if old_psk != conf.psk.as_ref() {
proc.arg("preshared-key");
- proc.arg(psk.path());
+ proc.arg("-");
+ if let Some(psk) = conf.psk.as_ref() {
+ use fmt::Write;
+ writeln!(stdin, "{}", psk).unwrap();
+ } else {
+ stdin.push('\n');
+ }
}
let mut ips = String::new();
@@ -117,7 +129,13 @@ impl Device {
proc.arg("remove");
}
- let r = proc.status()?;
+ let mut proc = proc.spawn()?;
+ {
+ use io::Write;
+ proc.stdin.as_mut().unwrap().write_all(stdin.as_bytes())?;
+ }
+
+ let r = proc.wait()?;
if !r.success() {
return Err(io::Error::new(io::ErrorKind::Other, "child process failed"));
}