aboutsummaryrefslogtreecommitdiff
path: root/test.py
blob: b67d26c734d4aeb35ac60e1593de8e7540017fab (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import random
import re
import select
import shlex
import shutil
import socket
import subprocess
import tarfile
import tempfile
import threading
import time
import unittest

def unzst(data):
    """Decompress zstd-compressed bytes *data* using the ``zstd`` command.

    Returns the decompressed bytes.  Raises ``subprocess.CalledProcessError``
    if ``zstd`` rejects the input, instead of silently returning empty or
    truncated output that would only fail later with a misleading error.
    """
    return subprocess.run(['zstd', '-d'], stdout=subprocess.PIPE, input=data, check=True).stdout

class Test(unittest.TestCase):
    """End-to-end stress test for the ``./pgbak`` executable.

    A helper thread fills a fake database directory with WAL segments and
    archives each one via ``pgbak wal``.  Stub ``pg_basebackup`` and
    ``backup`` scripts talk back to this process over a socketpair that is
    handed to ``pgbak`` as stdin:

    * ``r`` (sent by the ``pg_basebackup`` stub): write a ``base`` file tagged
      with the current WAL number and reply with a random delay;
    * ``w`` (sent by the ``backup`` stub): validate the tarball the stub just
      wrote and reply ``.``.
    """

    def setUp(self):
        # Set once shutdown begins (or a helper thread fails); all waits bail out.
        self._done = False
        # First exception raised by a helper thread; re-raised from _wal_wait/tearDown.
        self._err = None

        # _wal_k: highest WAL segment written so far.
        # _wal_i: number of segments released for archiving (the WAL thread
        #         archives segment i once _wal_i > i).
        # _wal_req wakes the WAL thread; _wal_done wakes callers of _wal_wait.
        self._wal_lock = threading.Lock()
        self._wal_k = -1
        self._wal_i = 0
        self._wal_req = threading.Condition(self._wal_lock)
        self._wal_done = threading.Condition(self._wal_lock)

        # When True, _bak_read tolerates a backup missing its first WALs.
        self._bak_relax = False
        # Last contiguous WAL number covered by each validated backup.
        self._bak_list = []

        self._prog = os.path.realpath('./pgbak')
        self._sock,self._sock_prog = socket.socketpair(socket.AF_UNIX)

        self._tmpdir = tempfile.mkdtemp()
        self._bin_path = os.path.join(self._tmpdir, 'bin')
        os.mkdir(self._bin_path)
        self._bak_path = os.path.join(self._tmpdir, 'bak')
        os.mkdir(self._bak_path)
        os.mkdir(os.path.join(self._bak_path, 'scripts'))
        self._db_path = os.path.join(self._tmpdir, 'db')
        os.mkdir(self._db_path)
        os.mkdir(os.path.join(self._db_path, 'pg_wal'))

        # Both stubs rely on bash's $RANDOM (and `==` inside `[`).  Under a
        # POSIX sh such as dash, $RANDOM is unset and evaluates to 0, so the
        # pg_basebackup stub would always fail.  Hence the bash shebang.
        with open(os.path.join(self._bin_path, 'pg_basebackup'), 'w') as f:
            f.write('#!/usr/bin/env bash\necho -n r >&0; read t; cat base; sleep "$t"; [ "$((RANDOM%200))" != 0 ]\n')
            os.fchmod(f.fileno(), 0o755)

        with open(os.path.join(self._bak_path, 'scripts', 'backup'), 'w') as f:
            shp = shlex.quote(self._tmpdir)
            f.write(f'#!/usr/bin/env bash\nset -e; sleep 0.1; [ "$((RANDOM%200))" == 0 ] && exit 1; tar -c . > {shp}/bak-dir.tar; echo -n w >&0; read t\n')
            os.fchmod(f.fileno(), 0o755)

        self._wal_thr = threading.Thread(target=self._guard, args=(self._run_wal,))
        self._sock_thr = threading.Thread(target=self._guard, args=(self._run_sock,))

        self._wal_thr.start()
        self._sock_thr.start()

        # Wait until WAL segment 0 exists.
        self._wal_wait(-1)

    def _guard(self, target):
        """Run helper-thread *target*; on failure record it and wake every waiter.

        Without this, an exception in a helper thread would only be printed,
        leaving the test thread blocked forever in _wal_wait.
        """
        try:
            target()
        except BaseException as e:
            with self._wal_lock:
                if self._err is None:
                    self._err = e
                self._done = True
                self._wal_req.notify_all()
                self._wal_done.notify_all()
            raise  # keep the traceback printed by threading.excepthook

    def _run_wal(self):
        """WAL thread: write segment i, wait until it is released, archive it."""
        i = 0
        while True:
            with open(os.path.join(self._db_path, 'pg_wal', f'{i}'), 'wb') as f:
                f.write(f'wal {i}\n'.encode() + os.urandom(4096))

            with self._wal_lock:
                self._wal_k = i
                self._wal_done.notify_all()
                # Check _done in the predicate: a notify sent while this thread
                # was outside the lock (e.g. in _run) would otherwise be lost.
                while self._wal_i == i and not self._done:
                    self._wal_req.wait()
                if self._done:
                    return

            time.sleep(random.random() * 0.2)
            self._run('wal', f'pg_wal/{i}')
            i += 1

    def _run_sock(self):
        """Socket thread: serve ``r``/``w`` requests from the stub scripts until EOF."""
        try:
            while True:
                op = self._sock.recv(1)
                if not op:
                    break

                if op == b'w':
                    self._bak_read()
                    self._sock.send(b'.\n')
                    continue

                if op == b'r':
                    i = self._wal_make()
                    with open(os.path.join(self._db_path, 'base'), 'wb') as f:
                        f.write(f'up to {i}\n'.encode() + os.urandom(20480))
                    t = random.random() * 0.2
                    self._wal_make()
                    self._sock.send(f'{t}\n'.encode())
                    continue

                raise RuntimeError(f'bad command {op!r}')
        except BaseException:
            # No reply will ever come; let stubs blocked in `read` see EOF
            # instead of hanging.
            try:
                self._sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
            raise

    def _wal_wait(self, i):
        """Block until a WAL segment newer than *i* has been written.

        For i >= 0 this means segment i has been archived.  Returns early once
        shutdown has begun.  Raises RuntimeError if a helper thread failed.
        """
        with self._wal_lock:
            while self._wal_k <= i and not self._done:
                self._wal_done.wait()
            if self._err is not None:
                raise RuntimeError('helper thread failed') from self._err

    def _wal_gen(self):
        """Release the next WAL segment for archiving and return its number (no wait)."""
        with self._wal_lock:
            self._wal_i = (i := self._wal_i) + 1
            self._wal_req.notify()
        return i

    def tearDown(self):
        with self._wal_lock:
            self._done = True
            self._wal_req.notify_all()
            # Also release any thread blocked in _wal_wait (e.g. the socket thread).
            self._wal_done.notify_all()
        # Closing our end lets the socket thread's recv() see EOF.
        self._sock_prog.close()

        self._wal_thr.join()
        self._sock_thr.join()

        self._sock.close()
        shutil.rmtree(self._tmpdir)

        # Surface helper-thread failures that the test thread never observed.
        if self._err is not None:
            raise RuntimeError('helper thread failed') from self._err

    def _wal_make(self):
        """Release the next WAL segment and return its number once it is archived."""
        i = self._wal_gen()
        self._wal_wait(i)
        return i

    def _bak_read(self):
        """Validate ``bak-dir.tar`` written by the backup stub and record its coverage.

        Raises ValueError for malformed contents.  Raises RuntimeError when
        WAL segments `begin` and `begin + 1` are not both present (with
        `begin` taken from the base file), unless ``self._bak_relax`` is set.
        """
        with tarfile.TarFile(os.path.join(self._tmpdir, 'bak-dir.tar')) as tf:
            data = unzst(tf.extractfile('./base.tzst').read()).split(b'\n')[0].decode()
            m = re.fullmatch(r'up to ([0-9]+)', data)
            if m is None:
                raise ValueError(f'bad base data {data!r}')
            begin = int(m.group(1))

            wals = set()
            for i in tf.getmembers():
                if i.name in {'.', './base.tzst', './pg_wal'}:
                    continue
                m = re.fullmatch(r'\./pg_wal/(0|[1-9][0-9]*)\.zst', i.name)
                if m is None:
                    raise ValueError(f'bad archive member {i.name!r}')
                data = unzst(tf.extractfile(i).read()).split(b'\n')[0].decode()
                m = re.fullmatch(r'wal ([0-9]+)', data)
                if m is None:
                    raise ValueError(f'bad WAL data in {i.name!r}')
                wals.add(int(m.group(1)))

        print('>>> got backup', begin,sorted(wals))

        # end = last WAL of the contiguous run starting at `begin`.
        end = begin
        while end in wals:
            end += 1
        end -= 1
        if end < begin + 1 and not self._bak_relax:
            raise RuntimeError('missing first WAL')

        self._bak_list.append(end)

    def _run(self, *args):
        """Run ``pgbak *args`` in the fake DB dir with the stubs first on PATH."""
        env = {**os.environ}
        env['PATH'] = f"{self._bin_path}:{env['PATH']}"
        env['PGBAK'] = self._bak_path
        subprocess.run([self._prog, *args], env=env, cwd=self._db_path, stdin=self._sock_prog, check=True)

    def testArchiving(self):
        for _ in range(1000):
            i = self._wal_make()
            if random.randrange(20) == 0:
                self._run('wait')
                if not self._bak_relax:
                    # Fail with a clear message rather than IndexError.
                    self.assertTrue(self._bak_list, 'no backup was recorded')
                    self.assertGreaterEqual(self._bak_list[-1], i)
                self._bak_relax = False

            if random.randrange(30) == 0:
                # Relax the backup consistency requirement until the end of the next sync
                self._bak_relax = True
                try:
                    os.unlink(os.path.join(self._bak_path, 'current'))
                except FileNotFoundError:
                    pass
                else:
                    print('>>> forcing full backup')

if __name__ == '__main__':
    # Discover and run every TestCase defined in this script.
    unittest.main(module='__main__')