aboutsummaryrefslogblamecommitdiff
path: root/test.py
blob: 163c589e27bdceea8d7c77de154f835174639f61 (plain) (tree)

























































































                                                                                                                                                 
                                            
































































































                                                                                                          
import os
import random
import re
import select
import shlex
import shutil
import socket
import subprocess
import tarfile
import tempfile
import threading
import time
import unittest

def unzst(data):
    """Decompress zstd-compressed bytes with the external ``zstd`` tool.

    Args:
        data: Compressed bytes; they are fed to ``zstd -d`` on its standard input.

    Returns:
        The bytes ``zstd -d`` wrote to its standard output.

    Raises:
        subprocess.CalledProcessError: If ``zstd`` exits with a non-zero status.
        FileNotFoundError: If no ``zstd`` executable is found on ``PATH``.
    """
    # check=True: a failed decompression must surface as an error instead of
    # being returned as empty or truncated output.
    proc = subprocess.run(['zstd', '-d'], input=data, stdout=subprocess.PIPE, check=True)
    return proc.stdout

class Test(unittest.TestCase):
    """Randomised end-to-end test of the ``./pgbak`` executable.

    The test fakes everything ``pgbak`` is run against:

    * a ``db`` directory whose ``pg_wal`` sub-directory is filled with
      numbered WAL files by a background thread, which also runs
      ``pgbak wal pg_wal/<n>`` for each file it is asked to archive;
    * a fake ``pg_basebackup`` and a ``scripts/backup`` hook.  Both are shell
      scripts that write a one-byte command to their file descriptor 0 and
      then read a reply line.  ``pgbak`` is started with one end of a socket
      pair as stdin; the scripts presumably inherit it (depends on ``pgbak``),
      and ``_run_sock`` answers on the other end.
    """

    def setUp(self):
        """Create the scratch tree and fake tools, then start the helper threads.

        The test is skipped when ``./pgbak`` is missing or not executable.
        Returns once the WAL thread has written its first WAL file.
        """
        self._prog = os.path.realpath('./pgbak')
        if not (os.path.isfile(self._prog) and os.access(self._prog, os.X_OK)):
            # Checked before anything is allocated: tearDown() does not run
            # after a skipped setUp().  Without the binary the WAL thread
            # would die and the test would block forever.
            self.skipTest(f'{self._prog} is missing or not executable')

        self._done = False

        # State shared with the WAL thread, all guarded by _wal_lock:
        #   _wal_k   - index of the newest WAL file written so far
        #   _wal_i   - number of archive requests issued (see _wal_gen)
        #   _wal_err - exception that killed the WAL thread, if any
        self._wal_lock = threading.Lock()
        self._wal_k = -1
        self._wal_i = 0
        self._wal_err = None
        self._wal_req = threading.Condition(self._wal_lock)
        self._wal_done = threading.Condition(self._wal_lock)

        # _bak_relax disables the "backup contains its first WAL" checks;
        # _bak_list collects the last contiguous WAL index of every backup read.
        self._bak_relax = False
        self._bak_list = []

        self._sock,self._sock_prog = socket.socketpair(socket.AF_UNIX)

        self._tmpdir = tempfile.mkdtemp()
        self._bin_path = os.path.join(self._tmpdir, 'bin')
        os.mkdir(self._bin_path)
        self._bak_path = os.path.join(self._tmpdir, 'bak')
        os.mkdir(self._bak_path)
        os.mkdir(os.path.join(self._bak_path, 'scripts'))
        self._db_path = os.path.join(self._tmpdir, 'db')
        os.mkdir(self._db_path)
        os.mkdir(os.path.join(self._db_path, 'pg_wal'))

        # Fake pg_basebackup: sends "r", reads a delay, dumps the "base" file,
        # sleeps for the delay and fails when $RANDOM % 200 is 0.
        # NOTE(review): $RANDOM, `echo -n` and `==` in the scripts below are
        # not POSIX sh; presumably /bin/sh supports them here - confirm.
        with open(os.path.join(self._bin_path, 'pg_basebackup'), 'w') as f:
            f.write('#!/bin/sh\necho -n r >&0; read t; cat base; sleep "$t"; [ "$((RANDOM%200))" != 0 ]\n')
            os.fchmod(f.fileno(), 0o755)

        # Backup hook: fails when $RANDOM % 200 is 0, otherwise tars its
        # working directory into <tmpdir>/bak-dir.tar, sends "w" and waits for
        # the reply that _run_sock sends after inspecting the archive.
        with open(os.path.join(self._bak_path, 'scripts', 'backup'), 'w') as f:
            shp = shlex.quote(self._tmpdir)
            f.write(f'#!/bin/sh\nset -e; sleep 0.1; [ "$((RANDOM%200))" == 0 ] && exit 1; tar -c . > {shp}/bak-dir.tar; echo -n w >&0; read t\n')
            os.fchmod(f.fileno(), 0o755)

        self._wal_thr = threading.Thread(target=self._run_wal)
        self._sock_thr = threading.Thread(target=self._run_sock)

        self._wal_thr.start()
        self._sock_thr.start()

        # Wait for WAL file 0 to exist.
        self._wal_wait(-1)

    def _run_wal(self):
        """WAL thread body: run the WAL loop and publish any failure.

        An exception escaping the loop is stored in ``_wal_err`` and the
        waiters on ``_wal_done`` are woken, so that ``_wal_wait`` raises
        instead of blocking forever.  The exception is then re-raised.
        """
        try:
            self._wal_loop()
        except BaseException as exc:
            with self._wal_lock:
                self._wal_err = exc
                self._wal_done.notify_all()
            raise

    def _wal_loop(self):
        """Write WAL file ``i``, wait for an archive request, archive it, repeat.

        Returns as soon as ``_done`` is observed under the lock.
        """
        i = 0
        while True:
            with open(os.path.join(self._db_path, 'pg_wal', f'{i}'), 'wb') as f:
                f.write(f'wal {i}\n'.encode() + os.urandom(4096))

            with self._wal_lock:
                self._wal_k = i
                self._wal_done.notify_all()
                # _done is part of the wait predicate: tearDown() may have set
                # it (and sent its only notification) while this thread was
                # outside wait(); checking it only after wait() would then
                # block forever.
                while self._wal_i == i and not self._done:
                    self._wal_req.wait()
                if self._done:
                    return

            time.sleep(random.random() * 0.2)
            self._run('wal', f'pg_wal/{i}')
            i += 1

    def _run_sock(self):
        """Socket thread body: serve one-byte commands until the peer closes.

        * ``w``: inspect the archive written by the backup hook and reply
          ``.`` followed by a newline.
        * ``r``: create a new base snapshot named after a freshly archived
          WAL, archive one more WAL, and reply with a random delay in seconds.

        Raises:
            RuntimeError: On any other command byte.
        """
        while True:
            op = self._sock.recv(1)
            if not op:
                break

            if op == b'w':
                self._bak_read()
                self._sock.send(b'.\n')
                continue

            if op == b'r':
                print('>>> making snapshot')
                i = self._wal_make()
                with open(os.path.join(self._db_path, 'base'), 'wb') as f:
                    f.write(f'up to {i}\n'.encode() + os.urandom(20480))
                t = random.random() * 0.2
                self._wal_make()
                self._sock.send(f'{t}\n'.encode())
                continue

            raise RuntimeError(f'bad command {op!r}')

    def _wal_wait(self, i):
        """Block until a WAL file with an index greater than ``i`` is written.

        Raises:
            RuntimeError: If the WAL thread has terminated with an exception
                before the file appeared (chained to that exception).
        """
        with self._wal_lock:
            while self._wal_k <= i:
                if self._wal_err is not None:
                    raise RuntimeError('WAL thread terminated unexpectedly') from self._wal_err
                self._wal_done.wait()

    def _wal_gen(self):
        """Ask the WAL thread to archive its current file; return that file's index."""
        with self._wal_lock:
            self._wal_i = (i := self._wal_i) + 1
            self._wal_req.notify()
        return i

    def tearDown(self):
        """Stop both helper threads and remove the scratch directory."""
        with self._wal_lock:
            self._done = True
            self._wal_req.notify()
        # Closing this end lets recv() in _run_sock return b'' once no other
        # process holds the socket.
        self._sock_prog.close()

        self._wal_thr.join()
        self._sock_thr.join()

        self._sock.close()
        shutil.rmtree(self._tmpdir)

    def _wal_make(self):
        """Request archiving of the current WAL and wait for the next one.

        Returns:
            The index of the WAL whose archiving was requested.  When this
            returns, ``pgbak wal`` has returned for that WAL and the following
            WAL file has been written.
        """
        i = self._wal_gen()
        self._wal_wait(i)
        return i

    def _bak_read(self):
        """Validate ``bak-dir.tar`` and record the last WAL it covers.

        The archive must contain ``./base.tzst``, whose decompressed first
        line is ``up to <n>``.  Every other member except ``.`` and
        ``./pg_wal`` must be ``./pg_wal/<k>.zst`` with a decompressed first
        line ``wal <m>``.  The end of the contiguous WAL run starting at
        ``<n>`` is appended to ``_bak_list``.

        Raises:
            ValueError: On an unexpected member or malformed content.
            RuntimeError: If WALs ``<n>`` and ``<n>+1`` are not both present
                and ``_bak_relax`` is false.
        """
        with tarfile.TarFile(os.path.join(self._tmpdir, 'bak-dir.tar')) as tf:
            data = unzst(tf.extractfile('./base.tzst').read()).split(b'\n')[0].decode()
            m = re.fullmatch('up to ([0-9]+)', data)
            if m is None:
                raise ValueError('bad base data')
            begin = int(m.group(1))

            wals = set()
            for i in tf.getmembers():
                if i.name in {'.', './base.tzst', './pg_wal'}:
                    continue
                # Raw string: '\.' is an invalid escape in a normal literal.
                m = re.fullmatch(r'\./pg_wal/(0|[1-9][0-9]*)\.zst', i.name)
                if m is None:
                    raise ValueError(f'bad archive member {i.name!r}')
                data = unzst(tf.extractfile(i).read()).split(b'\n')[0].decode()
                m = re.fullmatch('wal ([0-9]+)', data)
                if m is None:
                    raise ValueError(f'bad WAL data in {i.name!r}')
                wals.add(int(m.group(1)))

        print('>>> got backup', begin,sorted(wals))

        # end becomes the last index of the gap-free run begin, begin+1, ...
        # (begin - 1 when WAL `begin` itself is absent).
        end = begin
        while end in wals:
            end += 1
        end -= 1
        if end < begin + 1 and not self._bak_relax:
            raise RuntimeError('missing first WAL')

        self._bak_list.append(end)

    def _run(self, *args):
        """Run ``pgbak`` with ``args`` inside the fake database directory.

        The fake ``bin`` directory is put first on ``PATH``, ``PGBAK`` points
        at the backup directory and stdin is the program end of the socket
        pair.

        Raises:
            subprocess.CalledProcessError: If ``pgbak`` exits non-zero.
        """
        env = {**os.environ}
        env['PATH'] = f"{self._bin_path}:{env['PATH']}"
        env['PGBAK'] = self._bak_path
        subprocess.run([self._prog, *args], env=env, cwd=self._db_path, stdin=self._sock_prog, check=True)

    def testArchiving(self):
        """Archive 1000 WALs, randomly syncing and randomly forcing full backups."""
        for _ in range(1000):
            i = self._wal_make()
            if random.randrange(20) == 0:
                self._run('wait')
                if not self._bak_relax:
                    # After a sync the newest backup must cover WAL i.
                    self.assertGreaterEqual(self._bak_list[-1], i)
                self._bak_relax = False

            if random.randrange(30) == 0:
                # Relax the backup consistency requirement until the end of the next sync
                self._bak_relax = True
                try:
                    os.unlink(os.path.join(self._bak_path, 'current'))
                except FileNotFoundError:
                    pass
                else:
                    print('>>> forcing full backup')

# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()