import os
import random
import re
import select
import shlex
import shutil
import socket
import subprocess
import tarfile
import tempfile
import threading
import time
import unittest
def unzst(data):
    """Decompress zstd-compressed bytes by piping them through ``zstd -d``.

    ``data`` is fed to the child's stdin and whatever the child writes to
    stdout is returned as bytes.  The exit status of ``zstd`` is not checked.
    """
    proc = subprocess.run(['zstd', '-d'], input=data, stdout=subprocess.PIPE)
    return proc.stdout
class Test(unittest.TestCase):
    """Randomized end-to-end test that drives the ``./pgbak`` executable.

    The fixture builds a scratch directory containing a fake database
    (``db/`` with a ``pg_wal`` subdirectory), a backup directory (``bak/``)
    and a ``bin/`` directory holding a stub ``pg_basebackup``.  Two helper
    threads run for the duration of each test:

    * the WAL thread (``_run_wal``) creates numbered WAL segment files on
      request and hands each finished one to ``pgbak wal``;
    * the socket thread (``_run_sock``) answers one-byte commands that the
      stub shell scripts write to the socket passed as stdin to ``pgbak``.
    """

    def setUp(self):
        """Create the scratch tree, install the stub scripts, start the threads."""
        # --- state shared with the WAL thread (guarded by _wal_lock) ---
        self._done = False
        self._wal_lock = threading.Lock()
        self._wal_k = -1  # index of the newest segment file written so far
        self._wal_i = 0   # number of segments requested via _wal_gen()
        self._wal_req = threading.Condition(self._wal_lock)
        self._wal_done = threading.Condition(self._wal_lock)

        # --- backup verification state ---
        self._bak_relax = False
        self._bak_list = []

        self._prog = os.path.realpath('./pgbak')
        # _sock is served by the socket thread; _sock_prog becomes the
        # stdin of every pgbak invocation (see _run).
        self._sock, self._sock_prog = socket.socketpair(socket.AF_UNIX)

        # --- directory layout ---
        self._tmpdir = tempfile.mkdtemp()
        self._bin_path = os.path.join(self._tmpdir, 'bin')
        os.mkdir(self._bin_path)
        self._bak_path = os.path.join(self._tmpdir, 'bak')
        os.mkdir(self._bak_path)
        os.mkdir(os.path.join(self._bak_path, 'scripts'))
        self._db_path = os.path.join(self._tmpdir, 'db')
        os.mkdir(self._db_path)
        os.mkdir(os.path.join(self._db_path, 'pg_wal'))

        # Stub pg_basebackup: writes "r" to fd 0, reads a delay from it,
        # emits the file "base", sleeps for that delay, then fails whenever
        # RANDOM % 200 is 0.
        # NOTE(review): $RANDOM, `echo -n` and `==` are not POSIX sh; this
        # presumably assumes /bin/sh is bash-compatible -- confirm.
        with open(os.path.join(self._bin_path, 'pg_basebackup'), 'w') as f:
            f.write('#!/bin/sh\necho -n r >&0; read t; cat base; sleep "$t"; [ "$((RANDOM%200))" != 0 ]\n')
            os.fchmod(f.fileno(), 0o755)

        # Backup hook: occasionally fails, otherwise tars its working
        # directory into <tmpdir>/bak-dir.tar, writes "w" to fd 0 and waits
        # for a reply line.
        with open(os.path.join(self._bak_path, 'scripts', 'backup'), 'w') as f:
            shp = shlex.quote(self._tmpdir)
            f.write(f'#!/bin/sh\nset -e; sleep 0.1; [ "$((RANDOM%200))" == 0 ] && exit 1; tar -c . > {shp}/bak-dir.tar; echo -n w >&0; read t\n')
            os.fchmod(f.fileno(), 0o755)

        self._wal_thr = threading.Thread(target=self._run_wal)
        self._sock_thr = threading.Thread(target=self._run_sock)
        self._wal_thr.start()
        self._sock_thr.start()
        # Block until the WAL thread has written segment 0.
        self._wal_wait(-1)

    def _run_wal(self):
        """WAL thread body: write segment ``i``, wait for a request, archive it.

        Each iteration writes ``pg_wal/<i>`` (a ``wal <i>`` header line plus
        4096 random bytes), publishes ``_wal_k = i``, then sleeps on
        ``_wal_req`` until ``_wal_i`` moves past ``i`` or shutdown is
        requested.  After a short random delay the segment is passed to
        ``pgbak wal``.
        """
        i = 0
        while True:
            with open(os.path.join(self._db_path, 'pg_wal', f'{i}'), 'wb') as f:
                f.write(f'wal {i}\n'.encode() + os.urandom(4096))
            with self._wal_lock:
                self._wal_k = i
                self._wal_done.notify_all()
                while self._wal_i == i:
                    # Check the shutdown flag before blocking as well: if
                    # tearDown() set _done and notified while this thread was
                    # still busy outside the lock, that notification is gone
                    # and waiting here would never return.
                    if self._done:
                        return
                    self._wal_req.wait()
                    if self._done:
                        return
            time.sleep(random.random() * 0.2)
            self._run('wal', f'pg_wal/{i}')
            i += 1

    def _run_sock(self):
        """Socket thread body: answer one-byte commands until EOF.

        ``w``: verify the tarball just written by the backup hook, then
               reply ``.``.
        ``r``: produce a fresh ``base`` snapshot bracketed by two WAL
               switches and reply with a random delay in seconds.
        Any other byte raises ``RuntimeError``.
        """
        while True:
            op = self._sock.recv(1)
            if not op:
                break
            if op == b'w':
                self._bak_read()
                self._sock.send(b'.\n')
                continue
            if op == b'r':
                print('>>> making snapshot')
                i = self._wal_make()
                with open(os.path.join(self._db_path, 'base'), 'wb') as f:
                    f.write(f'up to {i}\n'.encode() + os.urandom(20480))
                t = random.random() * 0.2
                self._wal_make()
                self._sock.send(f'{t}\n'.encode())
                continue
            raise RuntimeError(f'bad command {op!r}')

    def _wal_wait(self, i):
        """Block until the WAL thread has published an index greater than ``i``."""
        with self._wal_lock:
            while self._wal_k <= i:
                self._wal_done.wait()

    def _wal_gen(self):
        """Request one more WAL segment; return the request counter's prior value."""
        with self._wal_lock:
            i = self._wal_i
            self._wal_i = i + 1
            self._wal_req.notify()
        return i

    def tearDown(self):
        """Stop both helper threads and remove the scratch directory."""
        with self._wal_lock:
            self._done = True
            self._wal_req.notify()
        # Closing our copy of the program-side socket lets the socket thread
        # see EOF once no other holder of that end remains.
        self._sock_prog.close()
        self._wal_thr.join()
        self._sock_thr.join()
        self._sock.close()
        shutil.rmtree(self._tmpdir)

    def _wal_make(self):
        """Request a WAL segment and wait until a newer one has been published.

        Returns the index ``i`` obtained from ``_wal_gen``; on return the WAL
        thread has set ``_wal_k`` to a value greater than ``i``.
        """
        i = self._wal_gen()
        self._wal_wait(i)
        return i

    def _bak_read(self):
        """Validate ``bak-dir.tar`` and record the last contiguous WAL it holds.

        The archive must contain ``./base.tzst`` whose decompressed first line
        is ``up to <N>``; every other member (besides ``.`` and ``./pg_wal``)
        must be ``./pg_wal/<n>.zst`` with a decompressed first line
        ``wal <n>``.  The end of the unbroken run of WAL numbers starting at
        ``N`` is appended to ``_bak_list``.

        Raises:
            ValueError: on unexpected base data, member names or WAL data.
            RuntimeError: if fewer than two consecutive WALs starting at ``N``
                are present and ``_bak_relax`` is not set.
        """
        with tarfile.TarFile(os.path.join(self._tmpdir, 'bak-dir.tar')) as tf:
            data = unzst(tf.extractfile('./base.tzst').read()).split(b'\n')[0].decode()
            m = re.fullmatch(r'up to ([0-9]+)', data)
            if m is None:
                raise ValueError('bad base data')
            begin = int(m.group(1))

            wals = set()
            for member in tf.getmembers():
                if member.name in {'.', './base.tzst', './pg_wal'}:
                    continue
                # Raw string: "\." in a normal literal is an invalid escape
                # sequence and triggers a warning on current Python versions.
                m = re.fullmatch(r'\./pg_wal/(0|[1-9][0-9]*)\.zst', member.name)
                if m is None:
                    raise ValueError(f'bad archive member {member.name!r}')
                data = unzst(tf.extractfile(member).read()).split(b'\n')[0].decode()
                m = re.fullmatch(r'wal ([0-9]+)', data)
                if m is None:
                    raise ValueError(f'bad WAL data in {member.name!r}')
                wals.add(int(m.group(1)))

        print('>>> got backup', begin, sorted(wals))

        # Walk forward from `begin` while the sequence is unbroken; `end` is
        # the last WAL number of that run (begin - 1 if `begin` is absent).
        end = begin
        while end in wals:
            end += 1
        end -= 1
        if end < begin + 1 and not self._bak_relax:
            raise RuntimeError('missing first WAL')
        self._bak_list.append(end)

    def _run(self, *args):
        """Run ``pgbak`` with ``args`` inside the fake database directory.

        ``bin/`` is put first on ``PATH`` so the stub ``pg_basebackup`` is
        found, ``PGBAK`` points at the backup directory and stdin is the
        program side of the socket pair.  A non-zero exit status raises
        ``subprocess.CalledProcessError``.
        """
        env = {**os.environ}
        env['PATH'] = f"{self._bin_path}:{env['PATH']}"
        env['PGBAK'] = self._bak_path
        subprocess.run([self._prog, *args], env=env, cwd=self._db_path, stdin=self._sock_prog, check=True)

    def testArchiving(self):
        """Generate 1000 WAL segments with random waits and forced full backups."""
        for _ in range(1000):
            i = self._wal_make()
            if random.randrange(20) == 0:
                self._run('wait')
                if not self._bak_relax:
                    # The most recently verified backup must reach at least
                    # the segment that was just completed.
                    self.assertGreaterEqual(self._bak_list[-1], i)
                self._bak_relax = False
            if random.randrange(30) == 0:
                # Relax the backup consistency requirement until the end of the next sync
                self._bak_relax = True
                try:
                    os.unlink(os.path.join(self._bak_path, 'current'))
                except FileNotFoundError:
                    pass
                else:
                    print('>>> forcing full backup')
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()