aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/config.rs16
-rw-r--r--src/main.rs30
-rw-r--r--src/proto.rs26
-rw-r--r--src/wg.rs213
4 files changed, 182 insertions, 103 deletions
diff --git a/src/config.rs b/src/config.rs
index 833b546..9124902 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,6 +1,5 @@
use crate::ip::{Ipv4Set, Ipv6Set};
use serde_derive;
-use std::collections::HashSet;
#[serde(deny_unknown_fields)]
#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
@@ -14,12 +13,11 @@ pub struct Source {
#[serde(deny_unknown_fields)]
#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
pub struct PeerConfig {
+ pub own_public_key: String,
#[serde(default = "default_min_keepalive")]
pub min_keepalive: u32,
#[serde(default = "default_max_keepalive")]
pub max_keepalive: u32,
-
- pub omit_peers: HashSet<String>,
}
#[serde(deny_unknown_fields)]
@@ -46,6 +44,18 @@ pub struct Config {
pub sources: Vec<Source>,
}
+impl PeerConfig {
+ pub fn fix_keepalive(&self, mut k: u32) -> u32 {
+ if self.max_keepalive != 0 && (k == 0 || k > self.max_keepalive) {
+ k = self.max_keepalive;
+ }
+ if k != 0 && k < self.min_keepalive {
+ k = self.min_keepalive;
+ }
+ k
+ }
+}
+
fn default_wg_command() -> String {
"wg".to_owned()
}
diff --git a/src/main.rs b/src/main.rs
index 4ad5f39..deea336 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -38,21 +38,22 @@ pub struct Device {
impl Device {
pub fn new(c: config::Config) -> Device {
+ let dev = wg::Device::new(c.ifname, c.wg_command);
+ let current = wg::ConfigBuilder::new(&c.peers).build();
Device {
- dev: wg::Device::new(c.ifname, c.wg_command),
+ dev,
peer_config: c.peers,
update_config: c.update,
sources: c.sources.into_iter().map(Source::new).collect(),
- current: wg::Config::new(),
+ current,
}
}
fn make_config(&self, ts: SystemTime) -> (wg::Config, Vec<wg::ConfigError>, SystemTime) {
- let mut cfg = wg::Config::new();
let mut next_update = ts + Duration::from_secs(3600);
- let mut errs = vec![];
+ let mut sources: Vec<(&Source, &proto::SourceConfig)> = vec![];
for src in self.sources.iter() {
- if let Some(data) = &src.data {
+ if let Some(ref data) = src.data {
let sc = data
.next
.as_ref()
@@ -65,11 +66,24 @@ impl Device {
}
})
.unwrap_or(&data.config);
- for peer in sc.peers.iter() {
- cfg.add_peer(&mut errs, &self.peer_config, &src.config, peer);
- }
+ sources.push((src, sc));
}
}
+
+ let mut cfg = wg::ConfigBuilder::new(&self.peer_config);
+ let mut errs = vec![];
+ for (src, sc) in sources.iter() {
+ for peer in sc.servers.iter() {
+ cfg.add_server(&mut errs, &src.config, peer);
+ }
+ }
+ for (src, sc) in sources.iter() {
+ for peer in sc.road_warriors.iter() {
+ cfg.add_road_warrior(&mut errs, &src.config, peer);
+ }
+ }
+
+ let cfg = cfg.build();
(cfg, errs, next_update)
}
diff --git a/src/proto.rs b/src/proto.rs
index a3ee6e6..609dbd9 100644
--- a/src/proto.rs
+++ b/src/proto.rs
@@ -7,11 +7,28 @@ use crate::ip::{Endpoint, Ipv4Net, Ipv6Net};
#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
pub struct Peer {
pub public_key: String,
+ #[serde(default = "Vec::new")]
+ pub ipv4: Vec<Ipv4Net>,
+ #[serde(default = "Vec::new")]
+ pub ipv6: Vec<Ipv6Net>,
+}
+
+#[serde(deny_unknown_fields)]
+#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
+pub struct Server {
+ #[serde(flatten)]
+ pub peer: Peer,
pub endpoint: Endpoint,
#[serde(default = "default_peer_keepalive")]
pub keepalive: u32,
- pub ipv4: Vec<Ipv4Net>,
- pub ipv6: Vec<Ipv6Net>,
+}
+
+#[serde(deny_unknown_fields)]
+#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
+pub struct RoadWarrior {
+ #[serde(flatten)]
+ pub peer: Peer,
+ pub base: String,
}
fn default_peer_keepalive() -> u32 {
@@ -20,7 +37,10 @@ fn default_peer_keepalive() -> u32 {
#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
pub struct SourceConfig {
- pub peers: Vec<Peer>,
+ #[serde(default = "Vec::new")]
+ pub servers: Vec<Server>,
+ #[serde(default = "Vec::new")]
+ pub road_warriors: Vec<RoadWarrior>,
}
#[derive(serde_derive::Serialize, serde_derive::Deserialize, Clone, PartialEq, Eq, Debug)]
diff --git a/src/wg.rs b/src/wg.rs
index 910e5e6..650aef9 100644
--- a/src/wg.rs
+++ b/src/wg.rs
@@ -4,20 +4,6 @@ use hash_map::HashMap;
use std::collections::hash_map;
use std::{error, fmt, io};
-#[derive(Clone, PartialEq, Eq, Debug)]
-struct Peer {
- endpoint: Endpoint,
- psk: Option<String>,
- keepalive: u32,
- ipv4: Vec<Ipv4Net>,
- ipv6: Vec<Ipv6Net>,
-}
-
-#[derive(Clone, PartialEq, Eq, Debug)]
-pub struct Config {
- peers: HashMap<String, Peer>,
-}
-
#[derive(Debug)]
pub struct ConfigError {
pub url: String,
@@ -26,6 +12,17 @@ pub struct ConfigError {
err: &'static str,
}
+impl ConfigError {
+ fn new(err: &'static str, s: &config::Source, p: &proto::Peer, important: bool) -> Self {
+ ConfigError {
+ url: s.url.clone(),
+ peer: p.public_key.clone(),
+ important,
+ err,
+ }
+ }
+}
+
impl error::Error for ConfigError {}
impl fmt::Display for ConfigError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -37,80 +34,71 @@ impl fmt::Display for ConfigError {
} else {
"Misconfigured peer"
},
- self.peer, self.url, self.err
+ self.peer,
+ self.url,
+ self.err
)
}
}
-impl Config {
- pub fn new() -> Config {
- Config {
+#[derive(Clone, PartialEq, Eq, Debug)]
+struct Peer {
+ endpoint: Option<Endpoint>,
+ psk: Option<String>,
+ keepalive: u32,
+ ipv4: Vec<Ipv4Net>,
+ ipv6: Vec<Ipv6Net>,
+}
+
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub struct Config {
+ peers: HashMap<String, Peer>,
+}
+
+pub struct ConfigBuilder<'a> {
+ peers: HashMap<String, Peer>,
+ pc: &'a config::PeerConfig,
+}
+
+impl<'a> ConfigBuilder<'a> {
+ pub fn new(pc: &'a config::PeerConfig) -> Self {
+ ConfigBuilder {
peers: HashMap::new(),
+ pc,
}
}
- pub fn add_peer(
- &mut self,
- errors: &mut Vec<ConfigError>,
- c: &config::PeerConfig,
+ pub fn build(self) -> Config {
+ Config { peers: self.peers }
+ }
+
+ fn insert_with<'b>(
+ &'b mut self,
+ err: &mut Vec<ConfigError>,
s: &config::Source,
p: &proto::Peer,
- ) {
- if !valid_key(&p.public_key) {
- errors.push(ConfigError {
- url: s.url.clone(),
- peer: p.public_key.clone(),
- important: true,
- err: "Invalid public key",
- });
- return;
- }
-
- if let Some(ref psk) = s.psk {
- if !valid_key(psk) {
- errors.push(ConfigError {
- url: s.url.clone(),
- peer: p.public_key.clone(),
- important: true,
- err: "Invalid preshared key",
- });
- return;
- }
- }
-
- if c.omit_peers.contains(&p.public_key) {
- return;
- }
-
- let ent = match self.peers.entry(p.public_key.clone()) {
+ update: impl for<'c> FnOnce(&'c mut Peer) -> (),
+ ) -> &'b mut Peer {
+ match self.peers.entry(p.public_key.clone()) {
hash_map::Entry::Occupied(ent) => {
- errors.push(ConfigError {
- url: s.url.clone(),
- peer: p.public_key.clone(),
- important: true,
- err: "Duplicate public key",
- });
+ err.push(ConfigError::new("Duplicate public key", s, p, true));
ent.into_mut()
}
hash_map::Entry::Vacant(ent) => {
- let mut keepalive = p.keepalive;
- if c.max_keepalive != 0 && (keepalive == 0 || keepalive > c.max_keepalive) {
- keepalive = c.max_keepalive;
- }
- if keepalive != 0 && keepalive < c.min_keepalive {
- keepalive = c.min_keepalive;
- }
-
- ent.insert(Peer {
- endpoint: p.endpoint.clone(),
- psk: s.psk.clone(),
- keepalive,
+ let ent = ent.insert(Peer {
+ endpoint: None,
+ psk: None,
+ keepalive: 0,
ipv4: vec![],
ipv6: vec![],
- })
- },
- };
+ });
+ update(ent);
+ ent
+ }
+ }
+ }
+ fn add_peer(err: &mut Vec<ConfigError>, ent: &mut Peer, s: &config::Source, p: &proto::Peer) {
let mut added = false;
let mut removed = false;
@@ -132,24 +120,63 @@ impl Config {
}
if removed {
- errors.push(ConfigError {
- url: s.url.clone(),
- peer: p.public_key.clone(),
- important: !added,
- err: if added {
- "Some IPs removed"
- } else {
- "All IPs removed"
- },
- });
+ let msg = if added {
+ "Some IPs removed"
+ } else {
+ "All IPs removed"
+ };
+ err.push(ConfigError::new(msg, s, p, !added));
}
}
-}
-impl Default for Config {
- #[inline]
- fn default() -> Self {
- Config::new()
+ pub fn add_server(
+ &mut self,
+ err: &mut Vec<ConfigError>,
+ s: &config::Source,
+ p: &proto::Server,
+ ) {
+ if !valid_key(&p.peer.public_key) {
+ err.push(ConfigError::new("Invalid public key", s, &p.peer, true));
+ return;
+ }
+
+ if p.peer.public_key == self.pc.own_public_key {
+ return;
+ }
+
+ let pc = self.pc;
+ let ent = self.insert_with(err, s, &p.peer, |ent| {
+ ent.psk = s.psk.clone();
+ ent.endpoint = Some(p.endpoint.clone());
+ ent.keepalive = pc.fix_keepalive(p.keepalive);
+ });
+
+ Self::add_peer(err, ent, s, &p.peer)
+ }
+
+ pub fn add_road_warrior(
+ &mut self,
+ err: &mut Vec<ConfigError>,
+ s: &config::Source,
+ p: &proto::RoadWarrior,
+ ) {
+ if !valid_key(&p.peer.public_key) {
+ err.push(ConfigError::new("Invalid public key", s, &p.peer, true));
+ return;
+ }
+
+ let ent = if p.base == self.pc.own_public_key {
+ self.insert_with(err, s, &p.peer, |_| {})
+ } else {
+ match self.peers.get_mut(&p.base) {
+ Some(ent) => ent,
+ None => {
+ err.push(ConfigError::new("Unknown base peer", s, &p.peer, true));
+ return;
+ }
+ }
+ };
+ Self::add_peer(err, ent, s, &p.peer)
}
}
@@ -174,17 +201,25 @@ impl Device {
let mut psks = Vec::<&str>::new();
for (pubkey, conf) in new.peers.iter() {
+ let old_endpoint;
if let Some(old_peer) = old.peers.get(pubkey) {
if *old_peer == *conf {
continue;
}
+ old_endpoint = old_peer.endpoint.clone();
+ } else {
+ old_endpoint = None;
}
+
proc.arg("peer");
proc.arg(pubkey);
- // TODO: maybe skip endpoint?
- proc.arg("endpoint");
- proc.arg(format!("{}", conf.endpoint));
+ if old_endpoint != conf.endpoint {
+ if let Some(ref endpoint) = conf.endpoint {
+ proc.arg("endpoint");
+ proc.arg(format!("{}", endpoint));
+ }
+ }
if let Some(psk) = &conf.psk {
proc.arg("preshared-key");