aboutsummaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'test.py')
-rw-r--r--test.py187
1 files changed, 187 insertions, 0 deletions
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..b67d26c
--- /dev/null
+++ b/test.py
@@ -0,0 +1,187 @@
+import os
+import random
+import re
+import select
+import shlex
+import shutil
+import socket
+import subprocess
+import tarfile
+import tempfile
+import threading
+import time
+import unittest
+
+def unzst(data):
+ return subprocess.run(['zstd', '-d'], stdout=subprocess.PIPE, input=data).stdout
+
+class Test(unittest.TestCase):
+ def setUp(self):
+ self._done = False
+
+ self._wal_lock = threading.Lock()
+ self._wal_k = -1
+ self._wal_i = 0
+ self._wal_req = threading.Condition(self._wal_lock)
+ self._wal_done = threading.Condition(self._wal_lock)
+
+ self._bak_relax = False
+ self._bak_list = []
+
+ self._prog = os.path.realpath('./pgbak')
+ self._sock,self._sock_prog = socket.socketpair(socket.AF_UNIX)
+
+ self._tmpdir = tempfile.mkdtemp()
+ self._bin_path = os.path.join(self._tmpdir, 'bin')
+ os.mkdir(self._bin_path)
+ self._bak_path = os.path.join(self._tmpdir, 'bak')
+ os.mkdir(self._bak_path)
+ os.mkdir(os.path.join(self._bak_path, 'scripts'))
+ self._db_path = os.path.join(self._tmpdir, 'db')
+ os.mkdir(self._db_path)
+ os.mkdir(os.path.join(self._db_path, 'pg_wal'))
+
+ with open(os.path.join(self._bin_path, 'pg_basebackup'), 'w') as f:
+ f.write('#!/bin/sh\necho -n r >&0; read t; cat base; sleep "$t"; [ "$((RANDOM%200))" != 0 ]\n')
+ os.fchmod(f.fileno(), 0o755)
+
+ with open(os.path.join(self._bak_path, 'scripts', 'backup'), 'w') as f:
+ shp = shlex.quote(self._tmpdir)
+ f.write(f'#!/bin/sh\nset -e; sleep 0.1; [ "$((RANDOM%200))" == 0 ] && exit 1; tar -c . > {shp}/bak-dir.tar; echo -n w >&0; read t\n')
+ os.fchmod(f.fileno(), 0o755)
+
+ self._wal_thr = threading.Thread(target=self._run_wal)
+ self._sock_thr = threading.Thread(target=self._run_sock)
+
+ self._wal_thr.start()
+ self._sock_thr.start()
+
+ self._wal_wait(-1)
+
+ def _run_wal(self):
+ i = 0
+ while True:
+ with open(os.path.join(self._db_path, 'pg_wal', f'{i}'), 'wb') as f:
+ f.write(f'wal {i}\n'.encode() + os.urandom(4096))
+
+ with self._wal_lock:
+ self._wal_k = i
+ self._wal_done.notify_all()
+ while self._wal_i == i:
+ self._wal_req.wait()
+ if self._done:
+ return
+
+ time.sleep(random.random() * 0.2)
+ self._run('wal', f'pg_wal/{i}')
+ i += 1
+
+ def _run_sock(self):
+ while True:
+ op = self._sock.recv(1)
+ if not op:
+ break
+
+ if op == b'w':
+ self._bak_read()
+ self._sock.send(b'.\n')
+ continue
+
+ if op == b'r':
+ i = self._wal_make()
+ with open(os.path.join(self._db_path, 'base'), 'wb') as f:
+ f.write(f'up to {i}\n'.encode() + os.urandom(20480))
+ t = random.random() * 0.2
+ self._wal_make()
+ self._sock.send(f'{t}\n'.encode())
+ continue
+
+ raise RuntimeError(f'bad command {op!r}')
+
+ def _wal_wait(self, i):
+ with self._wal_lock:
+ while self._wal_k <= i:
+ self._wal_done.wait()
+
+ def _wal_gen(self):
+ with self._wal_lock:
+ self._wal_i = (i := self._wal_i) + 1
+ self._wal_req.notify()
+ return i
+
+ def tearDown(self):
+ with self._wal_lock:
+ self._done = True
+ self._wal_req.notify()
+ self._sock_prog.close()
+
+ self._wal_thr.join()
+ self._sock_thr.join()
+
+ self._sock.close()
+ shutil.rmtree(self._tmpdir)
+
+ def _wal_make(self):
+ i = self._wal_gen()
+ self._wal_wait(i)
+ return i
+
+ def _bak_read(self):
+ with tarfile.TarFile(os.path.join(self._tmpdir, 'bak-dir.tar')) as tf:
+ data = unzst(tf.extractfile('./base.tzst').read()).split(b'\n')[0].decode()
+ m = re.fullmatch('up to ([0-9]+)', data)
+ if m is None:
+ raise ValueError(f'bad base data')
+ begin = int(m.group(1))
+
+ wals = set()
+ for i in tf.getmembers():
+ if i.name in {'.', './base.tzst', './pg_wal'}:
+ continue
+ m = re.fullmatch('\./pg_wal/(0|[1-9][0-9]*)\.zst', i.name)
+ if m is None:
+ raise ValueError(f'bad archive member {i.name!r}')
+ data = unzst(tf.extractfile(i).read()).split(b'\n')[0].decode()
+ m = re.fullmatch('wal ([0-9]+)', data)
+ if m is None:
+ raise ValueError(f'bad WAL data in {i.name!r}')
+ wals.add(int(m.group(1)))
+
+ print('>>> got backup', begin,sorted(wals))
+
+ end = begin
+ while end in wals:
+ end += 1
+ end -= 1
+ if end < begin + 1 and not self._bak_relax:
+ raise RuntimeError('missing first WAL')
+
+ self._bak_list.append(end)
+
+ def _run(self, *args):
+ env = {**os.environ}
+ env['PATH'] = f"{self._bin_path}:{env['PATH']}"
+ env['PGBAK'] = self._bak_path
+ subprocess.run([self._prog, *args], env=env, cwd=self._db_path, stdin=self._sock_prog, check=True)
+
+ def testArchiving(self):
+ for _ in range(1000):
+ i = self._wal_make()
+ if random.randrange(20) == 0:
+ self._run('wait')
+ if not self._bak_relax:
+ self.assertGreaterEqual(self._bak_list[-1], i)
+ self._bak_relax = False
+
+ if random.randrange(30) == 0:
+ # Relax the backup consistency requirement until the end of the next sync
+ self._bak_relax = True
+ try:
+ os.unlink(os.path.join(self._bak_path, 'current'))
+ except FileNotFoundError:
+ pass
+ else:
+ print('>>> forcing full backup')
+
+if __name__ == '__main__':
+ unittest.main()