"""Randomised stress test for the ``./pgbak`` program.

The test drives ``./pgbak`` with generated WAL segment files while fake
``pg_basebackup`` and ``scripts/backup`` shell scripts report back to the
test over a socket, so that every backup tarball can be inspected.
"""

import os
import random
import re
import select
import shlex
import shutil
import socket
import subprocess
import tarfile
import tempfile
import threading
import time
import unittest


def unzst(data):
    """Return *data* decompressed by the external ``zstd -d`` command.

    Raises ``subprocess.CalledProcessError`` if ``zstd`` exits non-zero, so
    a corrupt archive is reported as such instead of surfacing later as an
    unrelated parse error on empty output.
    """
    return subprocess.run(
        ['zstd', '-d'], stdout=subprocess.PIPE, input=data, check=True,
    ).stdout


class Test(unittest.TestCase):
    """End-to-end exercise of ``./pgbak`` against a scratch directory tree.

    Threads and shared state:

    * ``_run_wal`` (thread) writes WAL files ``pg_wal/<n>`` and passes each
      one to ``pgbak wal`` when asked to.
    * ``_run_sock`` (thread) answers one-byte commands that the fake shell
      scripts write to their stdin, which is one end of a socket pair.
    * ``_wal_i`` counts WAL requests issued so far; ``_wal_k`` is the index
      of the newest WAL file written.  Both are guarded by ``_wal_lock``.
    * ``_bak_list`` collects, per verified backup, the last WAL index that
      the backup covers contiguously.
    """

    def setUp(self):
        """Create the scratch tree, the fake helper scripts and the threads."""
        self._prog = os.path.realpath('./pgbak')
        # Check prerequisites before allocating anything: without them a
        # worker thread dies on FileNotFoundError and the main thread would
        # block forever waiting on a condition nobody signals.
        if not (os.path.isfile(self._prog) and os.access(self._prog, os.X_OK)):
            self.skipTest(f'{self._prog} is missing or not executable')
        if shutil.which('zstd') is None:
            self.skipTest('the zstd command is not available')

        self._done = False
        self._wal_lock = threading.Lock()
        self._wal_k = -1
        self._wal_i = 0
        self._wal_req = threading.Condition(self._wal_lock)
        self._wal_done = threading.Condition(self._wal_lock)
        self._bak_relax = False
        self._bak_list = []
        # _sock stays with the test; _sock_prog becomes the child's stdin.
        self._sock, self._sock_prog = socket.socketpair(socket.AF_UNIX)

        self._tmpdir = tempfile.mkdtemp()
        self._bin_path = os.path.join(self._tmpdir, 'bin')
        os.mkdir(self._bin_path)
        self._bak_path = os.path.join(self._tmpdir, 'bak')
        os.mkdir(self._bak_path)
        os.mkdir(os.path.join(self._bak_path, 'scripts'))
        self._db_path = os.path.join(self._tmpdir, 'db')
        os.mkdir(self._db_path)
        os.mkdir(os.path.join(self._db_path, 'pg_wal'))

        # Fake pg_basebackup: sends 'r' on fd 0, reads a delay, emits the
        # file "base", sleeps, and fails when RANDOM%200 happens to be 0.
        with open(os.path.join(self._bin_path, 'pg_basebackup'), 'w') as f:
            f.write('#!/bin/sh\necho -n r >&0; read t; cat base; sleep "$t"; [ "$((RANDOM%200))" != 0 ]\n')
            os.fchmod(f.fileno(), 0o755)

        # Fake backup script: randomly fails, otherwise tars its working
        # directory into the scratch dir and sends 'w' on fd 0.
        with open(os.path.join(self._bak_path, 'scripts', 'backup'), 'w') as f:
            shp = shlex.quote(self._tmpdir)
            f.write(f'#!/bin/sh\nset -e; sleep 0.1; [ "$((RANDOM%200))" == 0 ] && exit 1; tar -c . > {shp}/bak-dir.tar; echo -n w >&0; read t\n')
            os.fchmod(f.fileno(), 0o755)

        self._wal_thr = threading.Thread(target=self._run_wal)
        self._sock_thr = threading.Thread(target=self._run_sock)
        self._wal_thr.start()
        self._sock_thr.start()
        # Block until WAL file 0 exists.
        self._wal_wait(-1)

    def _run_wal(self):
        """Thread body: write WAL file ``i``, then archive it on request.

        Each round writes ``pg_wal/<i>``, publishes ``_wal_k = i`` and waits
        until ``_wal_i`` moves past ``i``; it then sleeps a random interval
        and runs ``pgbak wal pg_wal/<i>``.  Returns once ``_done`` is set
        and no request is pending.
        """
        i = 0
        while True:
            wal_path = os.path.join(self._db_path, 'pg_wal', f'{i}')
            with open(wal_path, 'wb') as f:
                f.write(f'wal {i}\n'.encode() + os.urandom(4096))
            with self._wal_lock:
                self._wal_k = i
                self._wal_done.notify_all()
                while self._wal_i == i:
                    # Test the flag *before* waiting: tearDown may have set
                    # it (and notified) while this thread was busy, and a
                    # notification sent with no waiter is lost.
                    if self._done:
                        return
                    self._wal_req.wait()
            time.sleep(random.random() * 0.2)
            self._run('wal', f'pg_wal/{i}')
            i += 1

    def _run_sock(self):
        """Thread body: serve one-byte commands arriving on ``_sock``.

        ``b'w'`` verifies the tarball via ``_bak_read`` and replies
        ``b'.\\n'``.  ``b'r'`` forces a WAL, writes the ``base`` snapshot
        file tagged with that WAL's index, forces another WAL and replies
        with a random delay.  An empty read ends the loop.

        Raises ``RuntimeError`` on any other command byte.
        """
        while True:
            op = self._sock.recv(1)
            if not op:
                break
            if op == b'w':
                self._bak_read()
                self._sock.send(b'.\n')
                continue
            if op == b'r':
                print('>>> making snapshot')
                i = self._wal_make()
                with open(os.path.join(self._db_path, 'base'), 'wb') as f:
                    f.write(f'up to {i}\n'.encode() + os.urandom(20480))
                t = random.random() * 0.2
                self._wal_make()
                self._sock.send(f'{t}\n'.encode())
                continue
            raise RuntimeError(f'bad command {op!r}')

    def _wal_wait(self, i):
        """Block until a WAL file with index greater than *i* was written."""
        with self._wal_lock:
            while self._wal_k <= i:
                self._wal_done.wait()

    def _wal_gen(self):
        """Issue one WAL request and return its index (the old ``_wal_i``)."""
        with self._wal_lock:
            self._wal_i = (i := self._wal_i) + 1
            self._wal_req.notify()
        return i

    def tearDown(self):
        """Stop both threads and delete the scratch directory."""
        with self._wal_lock:
            self._done = True
            self._wal_req.notify()
        # Closing the program-side socket lets recv() in _run_sock see EOF.
        self._sock_prog.close()
        self._wal_thr.join()
        self._sock_thr.join()
        self._sock.close()
        shutil.rmtree(self._tmpdir)

    def _wal_make(self):
        """Request a WAL, wait until the next WAL file exists, return its index."""
        i = self._wal_gen()
        self._wal_wait(i)
        return i

    def _bak_read(self):
        """Validate ``bak-dir.tar`` and record how far the backup reaches.

        Reads the snapshot index from ``./base.tzst`` and the WAL index
        from every ``./pg_wal/<n>.zst`` member, then appends to
        ``_bak_list`` the last index of the contiguous WAL run starting at
        the snapshot index.

        Raises ``ValueError`` for unexpected members or contents, and
        ``RuntimeError`` if the run does not reach snapshot index + 1 while
        ``_bak_relax`` is false.
        """
        with tarfile.TarFile(os.path.join(self._tmpdir, 'bak-dir.tar')) as tf:
            data = unzst(tf.extractfile('./base.tzst').read()).split(b'\n')[0].decode()
            m = re.fullmatch(r'up to ([0-9]+)', data)
            if m is None:
                raise ValueError('bad base data')
            begin = int(m.group(1))

            wals = set()
            for member in tf.getmembers():
                if member.name in {'.', './base.tzst', './pg_wal'}:
                    continue
                # Raw string: '\.' is an invalid escape in a normal literal.
                m = re.fullmatch(r'\./pg_wal/(0|[1-9][0-9]*)\.zst', member.name)
                if m is None:
                    raise ValueError(f'bad archive member {member.name!r}')
                data = unzst(tf.extractfile(member).read()).split(b'\n')[0].decode()
                m = re.fullmatch(r'wal ([0-9]+)', data)
                if m is None:
                    raise ValueError(f'bad WAL data in {member.name!r}')
                wals.add(int(m.group(1)))

        print('>>> got backup', begin, sorted(wals))

        # Walk the contiguous run begin, begin+1, ...; end is its last index
        # (begin - 1 when even the first WAL is absent).
        end = begin
        while end in wals:
            end += 1
        end -= 1
        if end < begin + 1 and not self._bak_relax:
            raise RuntimeError('missing first WAL')
        self._bak_list.append(end)

    def _run(self, *args):
        """Run ``pgbak *args`` in the database directory.

        The fake ``bin`` directory is prepended to ``PATH``, ``PGBAK``
        points at the backup directory and stdin is the program side of the
        socket pair.  Raises ``subprocess.CalledProcessError`` on a
        non-zero exit status.
        """
        env = dict(os.environ)
        env['PATH'] = f"{self._bin_path}:{env['PATH']}"
        env['PGBAK'] = self._bak_path
        subprocess.run([self._prog, *args], env=env, cwd=self._db_path,
                       stdin=self._sock_prog, check=True)

    def testArchiving(self):
        """Archive 1000 WALs with random ``wait`` calls and forced full backups."""
        for _ in range(1000):
            i = self._wal_make()
            if random.randrange(20) == 0:
                self._run('wait')
                if not self._bak_relax:
                    # Fail with a message rather than an IndexError when
                    # no backup has been recorded yet.
                    self.assertTrue(self._bak_list, 'no backup was recorded')
                    self.assertGreaterEqual(self._bak_list[-1], i)
                self._bak_relax = False
            if random.randrange(30) == 0:
                # Relax the backup consistency requirement until the end of the next sync
                self._bak_relax = True
                try:
                    os.unlink(os.path.join(self._bak_path, 'current'))
                except FileNotFoundError:
                    pass
                else:
                    print('>>> forcing full backup')


if __name__ == '__main__':
    unittest.main()