aboutsummaryrefslogblamecommitdiff
path: root/test.py
blob: ec0c5b85745defacf1aa21cbdc141a601afd116f (plain) (tree)


























                                                                                    
                                 














                                                                           
                                                                                                                                                  



                                                                               
                                                                                                                                                                                   



































                                                                                




                                               


                          
                                            


                                                                          






                                                      


























































                                                                                       
                           

                                                   

                                                                         











                                                                                                          





                                                                 


                          
import os
import random
import re
import select
import shlex
import shutil
import socket
import subprocess
import tarfile
import tempfile
import threading
import time
import unittest

def unzst(data):
    """Decompress zstd-compressed bytes using the ``zstd`` command-line tool.

    :param data: the compressed bytes, fed to ``zstd -d`` on stdin.
    :returns: the decompressed bytes read from ``zstd``'s stdout.
    :raises subprocess.CalledProcessError: if ``zstd`` exits non-zero
        (e.g. corrupt or non-zstd input). Without this the caller would
        silently get empty/truncated output and fail later with a
        misleading "bad data" error.
    """
    return subprocess.run(
        ['zstd', '-d'], stdout=subprocess.PIPE, input=data, check=True
    ).stdout

class Test(unittest.TestCase):
    """End-to-end test of the ``pgbak`` binary against stub PostgreSQL tools.

    ``setUp`` builds a scratch tree holding a stub ``pg_basebackup`` (put
    first on ``PATH``), a stub ``scripts/backup`` hook and a fake data
    directory. Every ``pgbak`` invocation gets one end of a Unix socket
    pair as its stdin; the stubs (presumably inheriting that stdin) write
    a one-byte request to fd 0 and read a one-line reply, which
    ``_run_sock`` serves:

    * ``r`` from ``pg_basebackup``: write a fresh ``base`` snapshot and
      reply ``k`` (the stub SIGKILLs its parent) or a sleep duration.
    * ``w`` from ``scripts/backup``: validate the just-written
      ``bak-dir.tar`` and reply ``k`` (kill the parent) or ``.``.

    A background thread (``_run_wal``) writes numbered WAL segments and
    archives each one with ``pgbak wal`` when the test requests it.
    """

    def setUp(self):
        # Set by tearDown to ask the helper threads to stop.
        self._done = False

        # WAL bookkeeping, all guarded by _wal_lock:
        #   _wal_k - highest WAL segment number written to pg_wal so far
        #   _wal_i - number of WAL segments requested (via _wal_gen) so far
        self._wal_lock = threading.Lock()
        self._wal_k = -1
        self._wal_i = 0
        # Signalled when a new segment is requested, and on shutdown.
        self._wal_req = threading.Condition(self._wal_lock)
        # Signalled when _wal_k advances, and on shutdown.
        self._wal_done = threading.Condition(self._wal_lock)

        # (begin, end) WAL ranges of every validated backup; the seed entry
        # means "no backup yet" and is skipped by the continuity check.
        self._bak_list = [(0, 0)]

        self._prog = os.path.realpath('./pgbak')
        self._sock,self._sock_prog = socket.socketpair(socket.AF_UNIX)

        self._tmpdir = tempfile.mkdtemp()
        self._bin_path = os.path.join(self._tmpdir, 'bin')
        os.mkdir(self._bin_path)
        self._bak_path = os.path.join(self._tmpdir, 'bak')
        os.mkdir(self._bak_path)
        os.mkdir(os.path.join(self._bak_path, 'scripts'))
        self._db_path = os.path.join(self._tmpdir, 'db')
        os.mkdir(self._db_path)
        os.mkdir(os.path.join(self._db_path, 'pg_wal'))

        # Both stubs use bash's $RANDOM for their 1-in-200 random failures;
        # under a POSIX sh such as dash, RANDOM is unset (always 0), so the
        # pg_basebackup stub would always fail. Hence the explicit bash.
        with open(os.path.join(self._bin_path, 'pg_basebackup'), 'w') as f:
            f.write('#!/usr/bin/env bash\necho -n r >&0; read t; [ "$t" != k ] || exec kill -9 "$PPID"; cat base; sleep "$t"; [ "$((RANDOM%200))" != 0 ]\n')
            os.fchmod(f.fileno(), 0o755)

        with open(os.path.join(self._bak_path, 'scripts', 'backup'), 'w') as f:
            shp = shlex.quote(self._tmpdir)
            f.write(f'#!/usr/bin/env bash\nset -e; sleep 0.1; [ "$((RANDOM%200))" = 0 ] && exit 1; tar -c . > {shp}/bak-dir.tar; echo -n w >&0; read t; [ "$t" != k ] || kill -9 "$PPID"\n')
            os.fchmod(f.fileno(), 0o755)

        self._wal_thr = threading.Thread(target=self._run_wal)
        self._sock_thr = threading.Thread(target=self._run_sock)

        self._wal_thr.start()
        self._sock_thr.start()

        # Don't start the test until WAL segment 0 exists.
        self._wal_wait(-1)

    def _run_wal(self):
        """WAL thread: write segment i, wait for a request, archive it, repeat.

        Returns once ``_done`` is set and no further segment is being
        waited for.
        """
        i = 0
        while True:
            with open(os.path.join(self._db_path, 'pg_wal', f'{i}'), 'wb') as f:
                f.write(f'wal {i}\n'.encode() + os.urandom(4096))

            with self._wal_lock:
                self._wal_k = i
                self._wal_done.notify_all()
                # _done must be checked BEFORE every wait, not only after
                # one: if tearDown set it while we were busy archiving, its
                # notify has already been consumed and wait() would block
                # forever, hanging tearDown's join().
                while self._wal_i == i and not self._done:
                    self._wal_req.wait()
                if self._done:
                    return

            time.sleep(random.random() * 0.2)
            self._run('wal', f'pg_wal/{i}')
            i += 1

    def _run_sock(self):
        """Serve requests from the stub scripts until the socket reaches EOF."""
        while True:
            op = self._sock.recv(1)
            if not op:
                # EOF: every holder of the pgbak end has closed it.
                break

            if op == b'w':
                # NOTE(review): if _bak_read raises, this thread dies without
                # replying, and the backup stub blocks in `read t`.
                self._bak_read()
                if random.randrange(10) == 0:
                    print('>>> killing backup')
                    self._sock.send(b'k\n')
                else:
                    self._sock.send(b'.\n')
                continue

            if op == b'r':
                print('>>> making snapshot')
                # The snapshot claims to be "up to" segment i, which has
                # just been archived.
                i = self._wal_make()
                with open(os.path.join(self._db_path, 'base'), 'wb') as f:
                    f.write(f'up to {i}\n'.encode() + os.urandom(20480))
                if random.randrange(10) == 0:
                    print('>>> killing snapshot')
                    self._sock.send(b'k\n')
                else:
                    t = random.random() * 0.2
                    # Archive segment i + 1 too, so the backup can contain
                    # the segments _bak_read requires (begin and begin + 1).
                    self._wal_make()
                    self._sock.send(f'{t}\n'.encode())
                continue

            raise RuntimeError(f'bad command {op!r}')

    def _wal_wait(self, i):
        """Block until a WAL segment newer than ``i`` exists, or shutdown."""
        with self._wal_lock:
            # Also stop on _done: after the WAL thread has exited nobody
            # would ever advance _wal_k, and this wait would never end.
            while self._wal_k <= i and not self._done:
                self._wal_done.wait()

    def _wal_gen(self):
        """Request archiving of one more WAL segment and return its number."""
        with self._wal_lock:
            i = self._wal_i
            self._wal_i = i + 1
            self._wal_req.notify()
        return i

    def tearDown(self):
        with self._wal_lock:
            self._done = True
            # Wake the WAL thread and anyone blocked in _wal_wait.
            self._wal_req.notify_all()
            self._wal_done.notify_all()
        # Drop our copy of the pgbak end so _run_sock's recv() can see EOF.
        self._sock_prog.close()

        self._wal_thr.join()
        self._sock_thr.join()

        self._sock.close()
        shutil.rmtree(self._tmpdir)

    def _wal_make(self):
        """Request the next WAL segment, wait until it is archived, return its number.

        During shutdown (``_done`` set) this may return before the segment
        has actually been archived.
        """
        i = self._wal_gen()
        self._wal_wait(i)
        return i

    def _bak_read(self):
        """Validate the backup just written to ``bak-dir.tar``.

        The archive may only contain ``.``, ``./pg_wal``, ``./base.tzst``
        (zstd data whose first line is ``up to <begin>``) and
        ``./pg_wal/<k>.zst`` members (zstd data whose first line is
        ``wal <k>``). The unbroken run of WAL segments starting at
        ``begin`` must include at least ``begin`` and ``begin + 1``. After
        the first backup, ``begin`` must not lie past the previous backup's
        end. On success ``(begin, end)`` is appended to ``_bak_list``.

        :raises ValueError: on unexpected members or contents.
        :raises RuntimeError: if the WAL run is too short or does not
            continue from the previous backup.
        """
        with tarfile.TarFile(os.path.join(self._tmpdir, 'bak-dir.tar')) as tf:
            data = unzst(tf.extractfile('./base.tzst').read()).split(b'\n')[0].decode()
            m = re.fullmatch(r'up to ([0-9]+)', data)
            if m is None:
                raise ValueError('bad base data')
            begin = int(m.group(1))

            wals = set()
            for member in tf.getmembers():
                if member.name in {'.', './base.tzst', './pg_wal'}:
                    continue
                # Raw string: '\.' in a plain literal is an invalid escape.
                m = re.fullmatch(r'\./pg_wal/(0|[1-9][0-9]*)\.zst', member.name)
                if m is None:
                    raise ValueError(f'bad archive member {member.name!r}')
                data = unzst(tf.extractfile(member).read()).split(b'\n')[0].decode()
                m = re.fullmatch(r'wal ([0-9]+)', data)
                if m is None:
                    raise ValueError(f'bad WAL data in {member.name!r}')
                wals.add(int(m.group(1)))

        print('>>> got backup', begin, sorted(wals))

        # end = last segment of the unbroken run that starts at begin.
        end = begin
        while end in wals:
            end += 1
        end -= 1
        if end < begin + 1:
            raise RuntimeError('missing first WAL')

        # Explicit check rather than `assert`, which is stripped under -O.
        prev_end = self._bak_list[-1][1]
        if len(self._bak_list) > 1 and begin > prev_end:
            raise RuntimeError(
                f'backup starts at WAL {begin}, after previous backup ended at {prev_end}')
        self._bak_list.append((begin, end))

    def _run(self, *args):
        """Run ``pgbak *args`` in the data dir with the stub tools on PATH.

        stdin is the pgbak end of the socket pair served by ``_run_sock``.

        :raises subprocess.CalledProcessError: if pgbak exits non-zero.
        """
        env = {**os.environ}
        env['PATH'] = f"{self._bin_path}:{env['PATH']}"
        env['PGBAK'] = self._bak_path
        subprocess.run([self._prog, *args], env=env, cwd=self._db_path, stdin=self._sock_prog, check=True)

    def testArchiving(self):
        """Archive 1000 WAL segments, occasionally syncing and checking coverage."""
        for _ in range(1000):
            i = self._wal_make()
            if random.randrange(20) == 0:
                self._run('wait')
                self.assertGreaterEqual(self._bak_list[-1][1], i)

            elif random.randrange(30) == 0:
                print('>>> forcing full backup')
                self._run('full-sync')
                self.assertGreaterEqual(self._bak_list[-1][1], i)

if '__main__' == __name__:
    # Only when executed as a script: hand control to the stdlib runner,
    # which finds the TestCase classes in this module and reads sys.argv.
    unittest.main(argv=None)