aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml1
-rw-r--r--src/bin.rs7
-rw-r--r--src/builder.rs172
-rw-r--r--src/config.rs6
-rw-r--r--src/ip.rs99
-rw-r--r--src/main.rs24
-rw-r--r--src/model.rs197
-rw-r--r--src/proto.rs34
-rw-r--r--src/wg.rs243
9 files changed, 450 insertions, 333 deletions
diff --git a/Cargo.toml b/Cargo.toml
index e7dbaf1..3d7c0dc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2018"
[dependencies]
arrayref = { version = "0.3.5" }
+base64 = { version = "0.10.1" }
serde = { version = "1.0.89" }
serde_derive = { version = "1.0.89" }
serde_json = { version = "1.0.39" }
diff --git a/src/bin.rs b/src/bin.rs
index ecef739..bd1cf87 100644
--- a/src/bin.rs
+++ b/src/bin.rs
@@ -7,10 +7,12 @@ pub fn i64_to_be(v: i64) -> [u8; 8] {
u64_to_be(v as u64)
}
+#[inline]
pub fn i64_from_be(v: [u8; 8]) -> i64 {
u64_from_be(v) as i64
}
+#[inline]
pub fn u64_to_be(v: u64) -> [u8; 8] {
[
(v >> 56) as u8,
@@ -24,6 +26,7 @@ pub fn u64_to_be(v: u64) -> [u8; 8] {
]
}
+#[inline]
pub fn u64_from_be(v: [u8; 8]) -> u64 {
(u64::from(v[0]) << 56)
| (u64::from(v[1]) << 48)
@@ -35,18 +38,22 @@ pub fn u64_from_be(v: [u8; 8]) -> u64 {
| u64::from(v[7])
}
+#[inline]
pub fn u32_to_be(v: u32) -> [u8; 4] {
[(v >> 24) as u8, (v >> 16) as u8, (v >> 8) as u8, v as u8]
}
+#[inline]
pub fn u32_from_be(v: [u8; 4]) -> u32 {
(u32::from(v[0]) << 24) | (u32::from(v[1]) << 16) | (u32::from(v[2]) << 8) | u32::from(v[3])
}
+#[inline]
pub fn u16_to_be(v: u16) -> [u8; 2] {
[(v >> 8) as u8, v as u8]
}
+#[inline]
pub fn u16_from_be(v: [u8; 2]) -> u16 {
(u16::from(v[0]) << 8) | u16::from(v[1])
}
diff --git a/src/builder.rs b/src/builder.rs
new file mode 100644
index 0000000..cc42fa8
--- /dev/null
+++ b/src/builder.rs
@@ -0,0 +1,172 @@
+// Copyright 2019 Hristo Venev
+//
+// See COPYING.
+
+use crate::{config, model, proto};
+use std::{error, fmt};
+use std::collections::hash_map;
+
+
+#[derive(Debug)]
+pub struct ConfigError {
+ pub url: String,
+ pub peer: model::Key,
+ pub important: bool,
+ err: &'static str,
+}
+
+impl ConfigError {
+ fn new(err: &'static str, s: &config::Source, p: &proto::Peer, important: bool) -> Self {
+ ConfigError {
+ url: s.url.clone(),
+ peer: p.public_key.clone(),
+ important,
+ err,
+ }
+ }
+}
+
+impl error::Error for ConfigError {}
+impl fmt::Display for ConfigError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{} [{}] from [{}]: {}",
+ if self.important {
+ "Invalid peer"
+ } else {
+ "Misconfigured peer"
+ },
+ self.peer,
+ self.url,
+ self.err
+ )
+ }
+}
+
+pub struct ConfigBuilder<'a> {
+ c: model::Config,
+ err: Vec<ConfigError>,
+ public_key: model::Key,
+ pc: &'a config::PeerConfig,
+}
+
+impl<'a> ConfigBuilder<'a> {
+ #[inline]
+ pub fn new(public_key: model::Key, pc: &'a config::PeerConfig) -> Self {
+ ConfigBuilder {
+ c: model::Config::default(),
+ err: vec![],
+ public_key,
+ pc,
+ }
+ }
+
+ #[inline]
+ pub fn build(self) -> (model::Config, Vec<ConfigError>) {
+ (self.c, self.err)
+ }
+
+ #[inline]
+ pub fn add_server(
+ &mut self,
+ s: &config::Source,
+ p: &proto::Server,
+ ) {
+ if p.peer.public_key == self.public_key {
+ return;
+ }
+
+ let pc = self.pc;
+ let ent = insert_peer(&mut self.c, &mut self.err, s, &p.peer, |ent| {
+ ent.psk = s.psk.clone();
+ ent.endpoint = Some(p.endpoint.clone());
+ ent.keepalive = pc.fix_keepalive(p.keepalive);
+ });
+
+ add_peer(&mut self.err, ent, s, &p.peer)
+ }
+
+ #[inline]
+ pub fn add_road_warrior(
+ &mut self,
+ s: &config::Source,
+ p: &proto::RoadWarrior,
+ ) {
+ if p.peer.public_key == self.public_key {
+ self.err.push(ConfigError::new("The local peer cannot be a road warrior", s, &p.peer, true));
+ return;
+ }
+
+ let ent = if p.base == self.public_key {
+ insert_peer(&mut self.c, &mut self.err, s, &p.peer, |_| {})
+ } else {
+ match self.c.peers.get_mut(&p.base) {
+ Some(ent) => ent,
+ None => {
+ self.err.push(ConfigError::new("Unknown base peer", s, &p.peer, true));
+ return;
+ }
+ }
+ };
+ add_peer(&mut self.err, ent, s, &p.peer)
+ }
+}
+
+#[inline]
+fn insert_peer<'b>(
+ c: &'b mut model::Config,
+ err: &mut Vec<ConfigError>,
+ s: &config::Source,
+ p: &proto::Peer,
+ update: impl for<'c> FnOnce(&'c mut model::Peer) -> (),
+) -> &'b mut model::Peer {
+ match c.peers.entry(p.public_key.clone()) {
+ hash_map::Entry::Occupied(ent) => {
+ err.push(ConfigError::new("Duplicate public key", s, p, true));
+ ent.into_mut()
+ }
+ hash_map::Entry::Vacant(ent) => {
+ let ent = ent.insert(model::Peer {
+ endpoint: None,
+ psk: None,
+ keepalive: 0,
+ ipv4: vec![],
+ ipv6: vec![],
+ });
+ update(ent);
+ ent
+ }
+ }
+}
+
+fn add_peer(err: &mut Vec<ConfigError>, ent: &mut model::Peer, s: &config::Source, p: &proto::Peer) {
+ let mut added = false;
+ let mut removed = false;
+
+ for i in p.ipv4.iter() {
+ if s.ipv4.contains(i) {
+ ent.ipv4.push(*i);
+ added = true;
+ } else {
+ removed = true;
+ }
+ }
+ for i in p.ipv6.iter() {
+ if s.ipv6.contains(i) {
+ ent.ipv6.push(*i);
+ added = true;
+ } else {
+ removed = true;
+ }
+ }
+
+ if removed {
+ let msg = if added {
+ "Some IPs removed"
+ } else {
+ "All IPs removed"
+ };
+ err.push(ConfigError::new(msg, s, p, !added));
+ }
+}
diff --git a/src/config.rs b/src/config.rs
index 2effacb..98db02d 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -3,13 +3,14 @@
// See COPYING.
use crate::ip::{Ipv4Set, Ipv6Set};
+use crate::model::{Key};
use serde_derive;
#[serde(deny_unknown_fields)]
#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
pub struct Source {
pub url: String,
- pub psk: Option<String>,
+ pub psk: Option<Key>,
pub ipv4: Ipv4Set,
pub ipv6: Ipv6Set,
}
@@ -56,14 +57,17 @@ impl PeerConfig {
}
}
+#[inline]
fn default_min_keepalive() -> u32 {
10
}
+#[inline]
fn default_max_keepalive() -> u32 {
0
}
+#[inline]
fn default_refresh_sec() -> u32 {
1200
}
diff --git a/src/ip.rs b/src/ip.rs
index 46d635a..d2ad17e 100644
--- a/src/ip.rs
+++ b/src/ip.rs
@@ -2,20 +2,20 @@
//
// See COPYING.
-use crate::bin;
use serde;
use std::iter::{FromIterator, IntoIterator};
-use std::net::{Ipv4Addr, Ipv6Addr};
+pub use std::net::{Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
-use std::{error, fmt, iter, net};
+use std::{error, fmt, iter};
#[derive(Debug)]
-pub struct NetParseError {}
+pub struct NetParseError;
impl error::Error for NetParseError {}
impl fmt::Display for NetParseError {
+ #[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(f, "Invalid IP network")
+ write!(f, "Invalid address")
}
}
@@ -49,6 +49,7 @@ macro_rules! per_proto {
}
impl fmt::Display for $nett {
+ #[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}/{}", self.address, self.prefix_len)
}
@@ -58,14 +59,14 @@ macro_rules! per_proto {
type Err = NetParseError;
fn from_str(s: &str) -> Result<$nett, NetParseError> {
let (addr, pfx) = pfx_split(s)?;
- let addr = $addrt::from_str(addr).map_err(|_| NetParseError {})?;
+ let addr = $addrt::from_str(addr).map_err(|_| NetParseError)?;
let r = $nett {
address: addr,
prefix_len: pfx,
};
if !r.is_valid() {
- return Err(NetParseError {});
+ return Err(NetParseError);
}
Ok(r)
}
@@ -74,7 +75,7 @@ macro_rules! per_proto {
impl serde::Serialize for $nett {
fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
if ser.is_human_readable() {
- ser.serialize_str(&format!("{}", self))
+ ser.collect_str(self)
} else {
let mut buf = [0u8; $bytes + 1];
*array_mut_ref![&mut buf, 0, $bytes] = self.address.octets();
@@ -91,10 +92,12 @@ macro_rules! per_proto {
impl<'de> serde::de::Visitor<'de> for NetVisitor {
type Value = $nett;
+ #[inline]
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str($expecting)
}
+ #[inline]
fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> {
s.parse().map_err(E::custom)
}
@@ -107,7 +110,7 @@ macro_rules! per_proto {
prefix_len: buf[$bytes],
};
if r.is_valid() {
- return Err(serde::de::Error::custom(NetParseError {}));
+ return Err(serde::de::Error::custom(NetParseError));
}
Ok(r)
}
@@ -201,6 +204,7 @@ macro_rules! per_proto {
}
impl FromIterator<$nett> for $sett {
+ #[inline]
fn from_iter<I: IntoIterator<Item = $nett>>(it: I) -> $sett {
let mut r = $sett::new();
for net in it {
@@ -273,12 +277,14 @@ macro_rules! per_proto {
}
impl serde::Serialize for $sett {
+ #[inline]
fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
<Vec<$nett> as serde::Serialize>::serialize(&self.nets, ser)
}
}
impl<'de> serde::Deserialize<'de> for $sett {
+ #[inline]
fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
<Vec<$nett> as serde::Deserialize>::deserialize(de).map($sett::from)
}
@@ -329,81 +335,10 @@ fn pfx_split(s: &str) -> Result<(&str, u8), NetParseError> {
let i = match s.find('/') {
Some(i) => i,
None => {
- return Err(NetParseError {});
+ return Err(NetParseError);
}
};
let (addr, pfx) = s.split_at(i);
- let pfx = u8::from_str(&pfx[1..]).map_err(|_| NetParseError {})?;
+ let pfx = u8::from_str(&pfx[1..]).map_err(|_| NetParseError)?;
Ok((addr, pfx))
}
-
-#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
-pub struct Endpoint {
- pub address: Ipv6Addr,
- pub port: u16,
-}
-
-impl fmt::Display for Endpoint {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- if self.address.segments()[5] == 0xffff {
- write!(f, "{}:", self.address.to_ipv4().unwrap())?;
- } else {
- write!(f, "[{}]:", self.address)?;
- }
- write!(f, "{}", self.port)
- }
-}
-
-impl FromStr for Endpoint {
- type Err = net::AddrParseError;
- fn from_str(s: &str) -> Result<Endpoint, net::AddrParseError> {
- net::SocketAddr::from_str(s).map(|v| Endpoint {
- address: match v.ip() {
- net::IpAddr::V4(a) => a.to_ipv6_mapped(),
- net::IpAddr::V6(a) => a,
- },
- port: v.port(),
- })
- }
-}
-
-impl serde::Serialize for Endpoint {
- fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
- if ser.is_human_readable() {
- ser.serialize_str(&format!("{}", self))
- } else {
- let mut buf = [0u8; 16 + 2];
- let (buf_addr, buf_port) = mut_array_refs![&mut buf, 16, 2];
- *buf_addr = self.address.octets();
- *buf_port = crate::bin::u16_to_be(self.port);
- ser.serialize_bytes(&buf)
- }
- }
-}
-
-impl<'de> serde::Deserialize<'de> for Endpoint {
- fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
- if de.is_human_readable() {
- struct EndpointVisitor;
- impl<'de> serde::de::Visitor<'de> for EndpointVisitor {
- type Value = Endpoint;
-
- fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
- f.write_str("ip:port")
- }
-
- fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> {
- s.parse().map_err(E::custom)
- }
- }
- de.deserialize_str(EndpointVisitor)
- } else {
- let buf = <[u8; 16 + 2] as serde::Deserialize>::deserialize(de)?;
- let (buf_addr, buf_port) = array_refs![&buf, 16, 2];
- Ok(Endpoint {
- address: (*buf_addr).into(),
- port: bin::u16_from_be(*buf_port),
- })
- }
- }
-}
diff --git a/src/main.rs b/src/main.rs
index 088d68e..c56cb91 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -9,8 +9,10 @@ use std::io;
use std::time::{Duration, Instant, SystemTime};
mod bin;
+mod builder;
mod config;
mod ip;
+mod model;
mod proto;
mod wg;
@@ -22,6 +24,7 @@ struct Source {
}
impl Source {
+ #[inline]
fn new(config: config::Source) -> Source {
Source {
config,
@@ -37,7 +40,7 @@ pub struct Device {
peer_config: config::PeerConfig,
update_config: config::UpdateConfig,
sources: Vec<Source>,
- current: wg::Config,
+ current: model::Config,
}
impl Device {
@@ -49,15 +52,15 @@ impl Device {
peer_config: c.peer_config,
update_config: c.update_config,
sources: c.sources.into_iter().map(Source::new).collect(),
- current: wg::Config::default(),
+ current: model::Config::default(),
})
}
fn make_config(
&self,
- public_key: &str,
+ public_key: model::Key,
ts: SystemTime,
- ) -> (wg::Config, Vec<wg::ConfigError>, SystemTime) {
+ ) -> (model::Config, Vec<builder::ConfigError>, SystemTime) {
let mut t_cfg = ts + Duration::from_secs(1 << 30);
let mut sources: Vec<(&Source, &proto::SourceConfig)> = vec![];
for src in self.sources.iter() {
@@ -78,20 +81,21 @@ impl Device {
}
}
- let mut cfg = wg::ConfigBuilder::new(public_key, &self.peer_config);
- let mut errs = vec![];
+ let mut cfg = builder::ConfigBuilder::new(public_key, &self.peer_config);
+
for (src, sc) in sources.iter() {
for peer in sc.servers.iter() {
- cfg.add_server(&mut errs, &src.config, peer);
+ cfg.add_server(&src.config, peer);
}
}
+
for (src, sc) in sources.iter() {
for peer in sc.road_warriors.iter() {
- cfg.add_road_warrior(&mut errs, &src.config, peer);
+ cfg.add_road_warrior(&src.config, peer);
}
}
- let cfg = cfg.build();
+ let (cfg, errs) = cfg.build();
(cfg, errs, t_cfg)
}
@@ -135,7 +139,7 @@ impl Device {
let now = Instant::now();
let sysnow = SystemTime::now();
let public_key = self.dev.get_public_key()?;
- let (config, errors, t_cfg) = self.make_config(&public_key, sysnow);
+ let (config, errors, t_cfg) = self.make_config(public_key, sysnow);
let time_to_cfg = t_cfg
.duration_since(sysnow)
.unwrap_or(Duration::from_secs(0));
diff --git a/src/model.rs b/src/model.rs
new file mode 100644
index 0000000..a7675f2
--- /dev/null
+++ b/src/model.rs
@@ -0,0 +1,197 @@
+// Copyright 2019 Hristo Venev
+//
+// See COPYING.
+
+use base64;
+use crate::bin;
+use crate::ip::{Ipv4Addr, Ipv6Addr, Ipv4Net, Ipv6Net, NetParseError};
+use std::{fmt};
+use std::collections::{HashMap};
+use std::str::FromStr;
+
+pub type KeyParseError = base64::DecodeError;
+
+#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
+pub struct Key([u8; 32]);
+
+impl Key {
+ pub fn from_bytes(s: &[u8]) -> Result<Key, KeyParseError> {
+ let mut v = Key([0u8; 32]);
+ let l = base64::decode_config_slice(s, base64::STANDARD, &mut v.0)?;
+ if l != v.0.len() {
+ return Err(base64::DecodeError::InvalidLength);
+ }
+ Ok(v)
+ }
+}
+
+impl fmt::Display for Key {
+ #[inline]
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "{}", base64::display::Base64Display::with_config(&self.0, base64::STANDARD))
+ }
+}
+
+impl FromStr for Key {
+ type Err = KeyParseError;
+ #[inline]
+ fn from_str(s: &str) -> Result<Key, base64::DecodeError> {
+ Key::from_bytes(s.as_bytes())
+ }
+}
+
+impl serde::Serialize for Key {
+ fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
+ if ser.is_human_readable() {
+ ser.collect_str(self)
+ } else {
+ ser.serialize_bytes(&self.0)
+ }
+ }
+}
+
+impl<'de> serde::Deserialize<'de> for Key {
+ fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
+ if de.is_human_readable() {
+ struct KeyVisitor;
+ impl<'de> serde::de::Visitor<'de> for KeyVisitor {
+ type Value = Key;
+
+ #[inline]
+ fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("WireGuard key")
+ }
+
+ #[inline]
+ fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> {
+ s.parse().map_err(E::custom)
+ }
+ }
+ de.deserialize_str(KeyVisitor)
+ } else {
+ serde::Deserialize::deserialize(de).map(Key)
+ }
+ }
+}
+
+#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
+pub struct Endpoint {
+ address: Ipv6Addr,
+ port: u16,
+}
+
+impl Endpoint {
+ #[inline]
+ pub fn ipv6_address(&self) -> Ipv6Addr {
+ self.address
+ }
+
+ #[inline]
+ pub fn ipv4_address(&self) -> Option<Ipv4Addr> {
+ let seg = self.address.octets();
+ let (first, second) = array_refs![&seg, 12, 4];
+ if *first == [0,0,0,0,0,0,0,0,0,0,0xff,0xff] {
+ Some(Ipv4Addr::from(*second))
+ } else {
+ None
+ }
+ }
+
+ #[inline]
+ pub fn port(&self) -> u16 {
+ self.port
+ }
+}
+
+impl fmt::Display for Endpoint {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ if let Some(ipv4) = self.ipv4_address() {
+ write!(f, "{}:", ipv4)?;
+ } else {
+ write!(f, "[{}]:", self.ipv6_address())?;
+ }
+ write!(f, "{}", self.port())
+ }
+}
+
+impl FromStr for Endpoint {
+ type Err = NetParseError;
+ fn from_str(s: &str) -> Result<Endpoint, NetParseError> {
+ use std::net;
+ net::SocketAddr::from_str(s)
+ .map_err(|_| NetParseError)
+ .map(|v| Endpoint {
+ address: match v.ip() {
+ net::IpAddr::V4(a) => a.to_ipv6_mapped(),
+ net::IpAddr::V6(a) => a,
+ },
+ port: v.port(),
+ })
+ }
+}
+
+impl serde::Serialize for Endpoint {
+ fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
+ if ser.is_human_readable() {
+ ser.collect_str(self)
+ } else {
+ let mut buf = [0u8; 16 + 2];
+ let (buf_addr, buf_port) = mut_array_refs![&mut buf, 16, 2];
+ *buf_addr = self.address.octets();
+ *buf_port = crate::bin::u16_to_be(self.port);
+ ser.serialize_bytes(&buf)
+ }
+ }
+}
+
+impl<'de> serde::Deserialize<'de> for Endpoint {
+ fn deserialize<D: serde::Deserializer<'de>>(de: D) -> Result<Self, D::Error> {
+ if de.is_human_readable() {
+ struct EndpointVisitor;
+ impl<'de> serde::de::Visitor<'de> for EndpointVisitor {
+ type Value = Endpoint;
+
+ #[inline]
+ fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("IP:port")
+ }
+
+ #[inline]
+ fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> {
+ s.parse().map_err(E::custom)
+ }
+ }
+ de.deserialize_str(EndpointVisitor)
+ } else {
+ let buf = <[u8; 16 + 2] as serde::Deserialize>::deserialize(de)?;
+ let (buf_addr, buf_port) = array_refs![&buf, 16, 2];
+ Ok(Endpoint {
+ address: (*buf_addr).into(),
+ port: bin::u16_from_be(*buf_port),
+ })
+ }
+ }
+}
+
+#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
+pub struct Peer {
+ pub endpoint: Option<Endpoint>,
+ pub psk: Option<Key>,
+ pub keepalive: u32,
+ pub ipv4: Vec<Ipv4Net>,
+ pub ipv6: Vec<Ipv6Net>,
+}
+
+#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
+pub struct Config {
+ pub peers: HashMap<Key, Peer>,
+}
+
+impl Default for Config {
+ #[inline]
+ fn default() -> Config {
+ Config {
+ peers: HashMap::new(),
+ }
+ }
+}
diff --git a/src/proto.rs b/src/proto.rs
index 414eee8..bfc3cbb 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -2,15 +2,15 @@
//
// See COPYING.
+use crate::ip::{Ipv4Net, Ipv6Net};
+use crate::model::{Key, Endpoint};
use serde_derive;
use std::time::SystemTime;
-use crate::ip::{Endpoint, Ipv4Net, Ipv6Net};
-
#[serde(deny_unknown_fields)]
#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
pub struct Peer {
- pub public_key: String,
+ pub public_key: Key,
#[serde(default = "Vec::new")]
pub ipv4: Vec<Ipv4Net>,
#[serde(default = "Vec::new")]
@@ -32,11 +32,7 @@ pub struct Server {
pub struct RoadWarrior {
#[serde(flatten)]
pub peer: Peer,
- pub base: String,
-}
-
-fn default_peer_keepalive() -> u32 {
- 0
+ pub base: Key,
}
#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
@@ -62,10 +58,16 @@ pub struct Source {
pub next: Option<SourceNextConfig>,
}
+#[inline]
+fn default_peer_keepalive() -> u32 {
+ 0
+}
+
mod serde_utc {
use crate::bin;
use chrono::{DateTime, SecondsFormat, TimeZone, Utc};
use serde::*;
+ use std::fmt;
use std::time::SystemTime;
pub fn serialize<S: Serializer>(t: &SystemTime, ser: S) -> Result<S::Ok, S::Error> {
@@ -83,9 +85,19 @@ mod serde_utc {
pub fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result<SystemTime, D::Error> {
if de.is_human_readable() {
- let s: String = String::deserialize(de)?;
- let t = DateTime::parse_from_rfc3339(&s).map_err(de::Error::custom)?;
- Ok(t.into())
+ struct RFC3339Visitor;
+ impl<'de> serde::de::Visitor<'de> for RFC3339Visitor {
+ type Value = SystemTime;
+
+ fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str("RFC3339 time")
+ }
+
+ fn visit_str<E: serde::de::Error>(self, s: &str) -> Result<Self::Value, E> {
+ DateTime::parse_from_rfc3339(s).map_err(de::Error::custom).map(SystemTime::from)
+ }
+ }
+ de.deserialize_str(RFC3339Visitor)
} else {
let mut buf = <[u8; 12]>::deserialize(de)?;
let (buf_secs, buf_nanos) = array_refs![&mut buf, 8, 4];
diff --git a/src/wg.rs b/src/wg.rs
index 1de5574..0335c3d 100644
--- a/src/wg.rs
+++ b/src/wg.rs
@@ -2,207 +2,17 @@
//
// See COPYING.
-use crate::ip::{Endpoint, Ipv4Net, Ipv6Net};
-use crate::{config, proto};
-use hash_map::HashMap;
-use std::collections::hash_map;
-use std::{error, fmt, io};
-
-use std::env;
+use crate::{model};
+use std::{env, io};
use std::ffi::{OsStr, OsString};
use std::process::{Command, Stdio};
-#[derive(Debug)]
-pub struct ConfigError {
- pub url: String,
- pub peer: String,
- pub important: bool,
- err: &'static str,
-}
-
-impl ConfigError {
- fn new(err: &'static str, s: &config::Source, p: &proto::Peer, important: bool) -> Self {
- ConfigError {
- url: s.url.clone(),
- peer: p.public_key.clone(),
- important,
- err,
- }
- }
-}
-
-impl error::Error for ConfigError {}
-impl fmt::Display for ConfigError {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(
- f,
- "{} [{}] from [{}]: {}",
- if self.important {
- "Invalid peer"
- } else {
- "Misconfigured peer"
- },
- self.peer,
- self.url,
- self.err
- )
- }
-}
-
-#[derive(Clone, PartialEq, Eq, Debug)]
-struct Peer {
- endpoint: Option<Endpoint>,
- psk: Option<String>,
- keepalive: u32,
- ipv4: Vec<Ipv4Net>,
- ipv6: Vec<Ipv6Net>,
-}
-
-#[derive(Clone, PartialEq, Eq, Debug)]
-pub struct Config {
- peers: HashMap<String, Peer>,
-}
-
-impl Default for Config {
- fn default() -> Config {
- Config {
- peers: HashMap::new(),
- }
- }
-}
-
-pub struct ConfigBuilder<'a> {
- peers: HashMap<String, Peer>,
- public_key: &'a str,
- pc: &'a config::PeerConfig,
-}
-
-impl<'a> ConfigBuilder<'a> {
- pub fn new(public_key: &'a str, pc: &'a config::PeerConfig) -> Self {
- ConfigBuilder {
- peers: HashMap::new(),
- public_key,
- pc,
- }
- }
-
- pub fn build(self) -> Config {
- Config { peers: self.peers }
- }
-
- fn insert_with<'b>(
- &'b mut self,
- err: &mut Vec<ConfigError>,
- s: &config::Source,
- p: &proto::Peer,
- update: impl for<'c> FnOnce(&'c mut Peer) -> (),
- ) -> &'b mut Peer {
- match self.peers.entry(p.public_key.clone()) {
- hash_map::Entry::Occupied(ent) => {
- err.push(ConfigError::new("Duplicate public key", s, p, true));
- ent.into_mut()
- }
- hash_map::Entry::Vacant(ent) => {
- let ent = ent.insert(Peer {
- endpoint: None,
- psk: None,
- keepalive: 0,
- ipv4: vec![],
- ipv6: vec![],
- });
- update(ent);
- ent
- }
- }
- }
-
- fn add_peer(err: &mut Vec<ConfigError>, ent: &mut Peer, s: &config::Source, p: &proto::Peer) {
- let mut added = false;
- let mut removed = false;
-
- for i in p.ipv4.iter() {
- if s.ipv4.contains(i) {
- ent.ipv4.push(*i);
- added = true;
- } else {
- removed = true;
- }
- }
- for i in p.ipv6.iter() {
- if s.ipv6.contains(i) {
- ent.ipv6.push(*i);
- added = true;
- } else {
- removed = true;
- }
- }
-
- if removed {
- let msg = if added {
- "Some IPs removed"
- } else {
- "All IPs removed"
- };
- err.push(ConfigError::new(msg, s, p, !added));
- }
- }
-
- pub fn add_server(
- &mut self,
- err: &mut Vec<ConfigError>,
- s: &config::Source,
- p: &proto::Server,
- ) {
- if !valid_key(&p.peer.public_key) {
- err.push(ConfigError::new("Invalid public key", s, &p.peer, true));
- return;
- }
-
- if p.peer.public_key == self.public_key {
- return;
- }
-
- let pc = self.pc;
- let ent = self.insert_with(err, s, &p.peer, |ent| {
- ent.psk = s.psk.clone();
- ent.endpoint = Some(p.endpoint.clone());
- ent.keepalive = pc.fix_keepalive(p.keepalive);
- });
-
- Self::add_peer(err, ent, s, &p.peer)
- }
-
- pub fn add_road_warrior(
- &mut self,
- err: &mut Vec<ConfigError>,
- s: &config::Source,
- p: &proto::RoadWarrior,
- ) {
- if !valid_key(&p.peer.public_key) {
- err.push(ConfigError::new("Invalid public key", s, &p.peer, true));
- return;
- }
-
- let ent = if p.base == self.public_key {
- self.insert_with(err, s, &p.peer, |_| {})
- } else {
- match self.peers.get_mut(&p.base) {
- Some(ent) => ent,
- None => {
- err.push(ConfigError::new("Unknown base peer", s, &p.peer, true));
- return;
- }
- }
- };
- Self::add_peer(err, ent, s, &p.peer)
- }
-}
-
pub struct Device {
ifname: String,
}
impl Device {
+ #[inline]
pub fn new(ifname: String) -> io::Result<Self> {
Ok(Device { ifname })
}
@@ -220,7 +30,7 @@ impl Device {
})
}
- pub fn get_public_key(&self) -> io::Result<String> {
+ pub fn get_public_key(&self) -> io::Result<model::Key> {
let mut proc = Device::wg_command();
proc.stdin(Stdio::null());
proc.stdout(Stdio::piped());
@@ -237,17 +47,17 @@ impl Device {
if out.ends_with(b"\n") {
out.remove(out.len() - 1);
}
- String::from_utf8(out)
+ model::Key::from_bytes(&out)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid public key"))
}
- pub fn apply_diff(&mut self, old: &Config, new: &Config) -> io::Result<()> {
+ pub fn apply_diff(&mut self, old: &model::Config, new: &model::Config) -> io::Result<()> {
let mut proc = Device::wg_command();
proc.stdin(Stdio::piped());
proc.arg("set");
proc.arg(&self.ifname);
- let mut psks = Vec::<&str>::new();
+ let mut psks = String::new();
for (pubkey, conf) in new.peers.iter() {
let old_endpoint;
@@ -261,7 +71,7 @@ impl Device {
}
proc.arg("peer");
- proc.arg(pubkey);
+ proc.arg(format!("{}", pubkey));
if old_endpoint != conf.endpoint {
if let Some(ref endpoint) = conf.endpoint {
@@ -273,7 +83,10 @@ impl Device {
if let Some(psk) = &conf.psk {
proc.arg("preshared-key");
proc.arg("/dev/stdin");
- psks.push(psk);
+ {
+ use std::fmt::Write;
+ writeln!(&mut psks, "{}", psk).unwrap();
+ }
}
let mut ips = String::new();
@@ -302,7 +115,7 @@ impl Device {
continue;
}
proc.arg("peer");
- proc.arg(pubkey);
+ proc.arg(format!("{}", pubkey));
proc.arg("remove");
}
@@ -310,9 +123,7 @@ impl Device {
{
use std::io::Write;
let stdin = proc.stdin.as_mut().unwrap();
- for psk in psks {
- writeln!(stdin, "{}", psk)?;
- }
+ write!(stdin, "{}", psks)?;
}
let r = proc.wait()?;
@@ -322,29 +133,3 @@ impl Device {
Ok(())
}
}
-
-fn valid_key(s: &str) -> bool {
- let s = s.as_bytes();
- if s.len() != 44 {
- return false;
- }
- if s[43] != b'=' {
- return false;
- }
- for c in s[0..42].iter().cloned() {
- if c >= b'0' && c <= b'9' {
- continue;
- }
- if c >= b'A' && c <= b'Z' {
- continue;
- }
- if c >= b'a' && c <= b'z' {
- continue;
- }
- if c == b'+' || c <= b'/' {
- continue;
- }
- return false;
- }
- b"048AEIMQUYcgkosw".contains(&s[42])
-}