# summary | refs | log | blame | commit | diff
# path: root/patchstate.py
# blob: 79600e070911b37ab213cb54ff1a17e74b82a396 (plain) (tree)






























































































































































































































































































































































































                                                                                                                           
import io, sys, os, operator, re, base64, hashlib
import patchtheory

def strip_subject(s):
    """Return the part of a ``[PATCH...] title`` subject after the first ``'] '``.

    *s* is a bytes subject line.  It must begin with ``[PATCH`` and contain
    the two-byte sequence ``'] '``; otherwise ``ValueError`` is raised.
    """
    if s.startswith(b'[PATCH'):
        _tag, sep, title = s.partition(b'] ')
        # An empty separator means '] ' never occurred in the subject.
        if sep:
            return title
    raise ValueError(f'Invalid subject: {s!r}')

# Pre-bound ``fullmatch`` methods: each name below is a callable that returns
# a match object when the *entire* string matches, and None otherwise.

# Series header line: "patch " followed by exactly 16 characters of [0-9a-z]
# (captured as group 1).
rem_patch = re.compile('patch ([0-9a-z]{16})').fullmatch
# Author: "name <local@domain>" where no part contains a character in the
# range U+0000..U+001F.
rem_author = re.compile('[^\0-\x1f]+ <[^\0-\x1f]+@[^\0-\x1f]+>').fullmatch
# Date: any non-empty string without characters in U+0000..U+001F.
rem_date = re.compile('[^\0-\x1f]+').fullmatch
# Title: same rule as the date (non-empty, no U+0000..U+001F characters).
rem_title = re.compile('[^\0-\x1f]+').fullmatch

class Patch:
    """A patch: author/date/title/body metadata plus a ``patchtheory.Diff``.

    Construction serialises the patch and derives two digests:

    * ``data`` -- the mail-style byte rendering produced by ``_set_data``.
    * ``hash`` -- lowercase base32 of the first 20 bytes of BLAKE2b(``data``).
    * ``id``   -- lowercase base32 of the first 10 bytes of BLAKE2b over the
      author, date, title and ``diff.write_munged`` output (the body is not
      part of the id).
    """
    __slots__ = ('author', 'date', 'title', 'body', 'diff', 'data', 'hash', 'id')

    def __init__(self, /, *, author, date, title, body, diff):
        """Store the fields and compute ``data``/``hash``/``id``.

        All arguments are keyword-only.  ``author``, ``date`` and ``title``
        must be ``str`` values accepted by ``rem_author``/``rem_date``/
        ``rem_title``; ``body`` must be a ``str`` and ``diff`` must be exactly
        a ``patchtheory.Diff``.  These are checked with ``assert`` (so an
        ``AssertionError`` is raised on violation, and the checks vanish
        under ``python -O``).
        """
        assert type(author) is str and rem_author(author) is not None
        assert type(date) is str and rem_date(date) is not None
        assert type(title) is str and rem_title(title) is not None
        assert type(body) is str
        assert type(diff) is patchtheory.Diff

        self.author = author
        self.date = date
        self.title = title
        self.body = body
        self.diff = diff

        self._set_data()
        self._set_id()

    @classmethod
    def parse(tp, body):
        """Parse a mail-formatted patch given as bytes.

        Expected layout: an optional leading ``From `` line, header lines up
        to the first empty line, the commit message, a ``---`` line
        (optionally followed by space-indented lines and blank lines), and
        then the diff starting at a ``diff --git`` line.

        Raises:
            ValueError: a header line has no ``:`` or the subject is not a
                ``[PATCH...] `` subject.
            KeyError: the ``From``, ``Date`` or ``Subject`` header is missing.
            RuntimeError: no ``diff --git`` line is found, or the message is
                not terminated by a ``---`` line.
            StopIteration: the input ends before the blank line that
                terminates the headers.
        """
        lines = iter(body.split(b'\n'))
        line = lines.__next__

        l = line()
        # Skip the mbox-style "From <...>" envelope line if present.
        if l.startswith(b'From '):
            l = line()

        headers = {}
        hdr = None
        while True:
            if not l:
                break
            if hdr is not None and l.startswith(b' '):
                # Folded header: append the continuation to the previous one.
                headers[hdr] += l
            else:
                hdr, sep, val = l.partition(b':')
                if not sep:
                    raise ValueError(f'Expected header line: {l!r}')
                headers[hdr] = val.lstrip()
            l = line()

        author = headers.pop(b'From').decode('utf-8')
        date = headers.pop(b'Date').decode('utf-8')
        title = strip_subject(headers.pop(b'Subject')).decode('utf-8')

        headers.pop(b'MIME-Version', None)
        headers.pop(b'Content-Type', None)
        headers.pop(b'Content-Transfer-Encoding', None)

        # Collect everything up to (not including) the first diff header.
        msg = []
        for l in lines:
            if l.startswith(b'diff --git'):
                break
            msg.append(l)
        else:
            raise RuntimeError('unexpected EOF')

        # Walk backwards over blank lines and space-indented lines to reach
        # the '---' separator.  The ``msg and`` guards turn a missing
        # separator into a RuntimeError instead of an IndexError on an
        # exhausted list.
        while msg and not msg[-1]:
            del msg[-1]
        while msg and msg[-1].startswith(b' '):
            del msg[-1]
        if not msg or msg[-1] != b'---':
            raise RuntimeError("expected '---' line before the diff")
        del msg[-1]
        while msg and not msg[-1]:
            del msg[-1]

        text = b'\n'.join(msg).decode('utf-8')

        # ``l`` is the 'diff --git' line; ``line`` yields the lines after it.
        # Build through ``tp`` so subclasses get instances of their own type.
        return tp(
            author = author,
            date = date,
            title = title,
            body = text,
            diff = patchtheory.Diff.parse_l(l, line),
        )

    def _set_id(self):
        """Compute ``self.id`` from author, date, title and the munged diff."""
        out = io.BytesIO()
        w = out.write
        w(self.author.encode('utf-8'))
        w(b'\n')
        w(self.date.encode('utf-8'))
        w(b'\n')
        w(self.title.encode('utf-8'))
        w(b'\n')
        self.diff.write_munged(w)

        dgst = hashlib.blake2b(out.getvalue()).digest()
        # 10 bytes -> 16 base32 characters, no padding.
        self.id = base64.b32encode(dgst[:10]).lower().decode('ascii')

    def as_filename(self):
        """Return the title's alphanumeric runs joined by ``-``, cut to 72 chars."""
        return '-'.join(re.findall('[0-9A-Za-z]+', self.title))[:72]

    def _set_data(self):
        """Serialise the patch into ``self.data`` and derive ``self.hash``."""
        out = io.BytesIO()
        w = out.write

        w(b'From: ')
        w(self.author.encode('utf-8'))
        w(b'\nDate: ')
        w(self.date.encode('utf-8'))
        w(b'\nSubject: [PATCH] ')
        w(self.title.encode('utf-8'))
        w(b'\n\n')
        w(self.body.encode('utf-8'))
        w(b'\n---\n\n')
        self.diff.write(w)
        w(b'-- \n0.0.0 patchstate\n\n')

        self.data = data = out.getvalue()

        # 20 bytes -> 32 base32 characters, no padding.
        self.hash = base64.b32encode(hashlib.blake2b(data).digest()[:20]).lower().decode('ascii')

    def revert(self):
        """Return a new patch titled ``Revert "<title>"`` carrying ``-self.diff``.

        Author, date and body are copied unchanged.  The constructor already
        computes ``data``, ``hash`` and ``id``, so nothing is recomputed here.
        """
        return Patch(
            author = self.author,
            date = self.date,
            title = f'Revert "{self.title}"',
            body = self.body,
            diff = -self.diff,
        )

class PatchInfo:
    """One series entry: a patch id plus a list of text lines.

    ``_body`` is laid out as ``[patch_hash, patch_title, mode, *info]``; the
    properties below expose those positions.
    """
    __slots__ = ('patch_id', '_body')

    def __init__(self, patch_id, info):
        """Store *patch_id* and the full body list (*info*) without copying."""
        self.patch_id = patch_id
        self._body = info

    def get_patch(self, repo):
        """Look up this entry's patch in ``repo.patches`` by hash.

        Raises ``KeyError`` if the hash is unknown, or if the stored patch's
        id or title differs from what this entry records.
        """
        p = repo.patches[self.patch_hash]
        if p.id != self.patch_id:
            raise KeyError(f'patch id mismatch: {self.patch_id!r} -> {p.id!r}')
        if p.title != self.patch_title:
            raise KeyError(f'patch title mismatch: {self.patch_title!r} -> {p.title!r}')
        return p

    @property
    def patch_hash(self):
        """First body line: the patch hash."""
        return self._body[0]

    @property
    def patch_title(self):
        """Second body line: the patch title."""
        return self._body[1]

    @property
    def mode(self):
        """Third body line: the entry's mode."""
        return self._body[2]

    @property
    def info(self):
        """All body lines after the mode (a fresh slice of ``_body``)."""
        return self._body[3:]

    def update(self, *, patch=None, patch_id=None, patch_hash=None, patch_title=None, mode=None, new_mode=None, info=None):
        """Return a new ``PatchInfo`` with the given fields replaced.

        Every field not supplied is taken from ``self``.  ``patch`` supplies
        both id and hash and therefore excludes ``patch_id``/``patch_hash``;
        its title must equal the (possibly overridden) ``patch_title``,
        otherwise ``ValueError`` is raised.  ``mode`` replaces the mode
        silently, whereas a ``new_mode`` that differs from the resolved mode
        also prepends ``(was <mode>)`` to the info lines.

        Neither ``self`` nor a caller-supplied ``info`` sequence is modified.
        """
        if patch_title is None:
            patch_title = self.patch_title

        if patch is not None:
            assert patch_id is None
            assert patch_hash is None
            if patch_title != patch.title:
                raise ValueError(f'patch title change: {patch_title!r} -> {patch.title!r}')
            patch_id = patch.id
            patch_hash = patch.hash

        else:
            if patch_id is None:
                patch_id = self.patch_id

            if patch_hash is None:
                patch_hash = self.patch_hash

        # Always work on a private list so the insert below cannot leak into
        # the caller's sequence (and so tuples are accepted too).
        info = list(self.info if info is None else info)

        if mode is None:
            mode = self.mode

        if new_mode is not None and new_mode != mode:
            info.insert(0, f'(was {mode})')
            mode = new_mode

        return PatchInfo(patch_id, [patch_hash, patch_title, mode, *info])

class Series:
    """An ordered list of ``PatchInfo`` entries with a text (de)serialisation.

    Text format, as written by ``fmt``: for each entry a ``patch <id>`` line
    followed by one tab-indented line per body element.
    """
    __slots__ = ('info',)

    def __init__(self, /, info=None):
        """Create a series from *info* (a list of entries); empty by default."""
        if info is None:
            info = []
        self.info = info

    @classmethod
    def parse(tp, data):
        """Parse series text (``str``) into a new instance.

        Blank lines between records are skipped.  Each record is a header
        line accepted by ``rem_patch`` followed by at least three
        tab-indented data lines; trailing empty data lines are dropped.

        Raises ``ValueError`` on an invalid header line or when a record has
        fewer than three data lines.
        """
        inf = []

        lines = iter(data.split('\n'))
        l = next(lines)
        while True:
            if not l:
                try:
                    l = next(lines)
                except StopIteration:
                    break
                continue
            m = rem_patch(l)
            if m is None:
                raise ValueError(f'Invalid header: {l!r}')
            pid = m.group(1)
            ldata = []
            while True:
                # Default of '' makes end-of-input terminate the record the
                # same way a blank line does, instead of leaking
                # StopIteration when the text lacks a trailing newline.
                l = next(lines, '')
                if not l.startswith('\t'):
                    break
                ldata.append(l[1:])
            if len(ldata) < 3:
                raise ValueError(f'Expected data after {f"patch {pid}"!r}')

            while ldata and not ldata[-1]:
                del ldata[-1]

            inf.append(PatchInfo(pid, ldata))

        return tp(info=inf)

    def fmt(self):
        """Render the series as text: the inverse of ``parse``."""
        out = io.StringIO()
        w = out.write

        for pi in self.info:
            w(f'patch {pi.patch_id}')
            for line in pi._body:
                w('\n\t')
                w(line)
            w('\n')

        return out.getvalue()

class Repository:
    """A directory of ``<hash>.patch`` files mirrored in memory.

    Attributes:
        path:    the directory holding the patch files.
        patches: maps patch hash -> patch object.
        by_id:   maps patch id -> list of patches in ``patches`` with that id.
        aliases: maps the stem of a file whose name did not match its
                 content hash -> the hash computed from its content.
    """
    __slots__ = ('path', 'patches', 'aliases', 'by_id')

    def __init__(self, path):
        """Load every ``*.patch`` file found directly in *path*.

        Files that fail to parse with ``ValueError`` are reported on stderr
        and deleted.  A file whose name is not ``<hash>.patch`` for its own
        content is recorded as an alias; if no file exists under the correct
        name, the patch is written there ("migrated").
        """
        self.path = path
        self.patches = {}
        self.by_id = {}

        aliases = {}
        for fn in os.listdir(path):
            if not fn.endswith('.patch'):
                continue
            fp = os.path.join(path, fn)
            with open(fp, 'rb') as f:
                data = f.read()

            try:
                p = Patch.parse(data)
            except ValueError as exc:
                print(f'Failed to parse {fn!r}: {exc}', file=sys.stderr)
                os.unlink(fp)
                continue

            h = p.hash
            if fn != f'{h}.patch':
                # fn[:-6] strips the '.patch' suffix.
                aliases[fn[:-6]] = (h,p)
            else:
                self.patches[h] = p
                self._index(p)

        for h,(h2,p) in aliases.items():
            assert h != h2
            assert h not in self.patches
            if h2 in self.patches:
                # Already loaded under its canonical name; indexing the
                # alias copy too would list the same patch twice in by_id.
                continue
            print(f'Migrating patch: {h!r} -> {h2!r}', file=sys.stderr)
            self.patches[h2] = p
            self._index(p)
            with open(os.path.join(self.path, f'{h2}.patch'), 'xb') as f:
                f.write(p.data)

        self.aliases = {h: h2 for h,(h2,p) in aliases.items()}

    def _index(self, p):
        """Record *p* under its id in ``by_id``."""
        self.by_id.setdefault(p.id, []).append(p)

    def hash_data(self, data):
        """Return lowercase base32 of the first 20 bytes of BLAKE2b(*data*)."""
        data = hashlib.blake2b(data).digest()
        return base64.b32encode(data[:20]).lower().decode()

    def insert_patch(self, patch):
        """Add *patch* unless one with the same hash exists; return the stored one.

        A newly stored patch is indexed in ``by_id`` (a message is printed to
        stderr when its id is already present) and written to
        ``<hash>.patch`` with exclusive creation, so an existing file with
        that name raises ``FileExistsError``.
        """
        p = self.patches.setdefault(patch.hash, patch)
        if p is patch:
            pid = p.id
            try:
                pd = self.by_id[pid]
            except KeyError:
                self.by_id[pid] = [p]
            else:
                print(f'Duplicate patch ID: {pid!r} {p.title}', file=sys.stderr)
                pd.append(p)

            with open(os.path.join(self.path, f'{p.hash}.patch'), 'xb') as f:
                f.write(p.data)

        return p

    def import_dir(self, path):
        """Insert every ``*.patch`` file in *path* (sorted by name).

        Returns a ``Series`` with one entry per file, each in mode ``'new'``.
        Other files are skipped with a message on stderr.
        """
        r = Series()
        r.info = inf = []

        for fn in sorted(os.listdir(path)):
            if not fn.endswith('.patch'):
                print(f'Skipping {fn!r}', file=sys.stderr)
                continue
            with open(os.path.join(path, fn), 'rb') as f:
                patch = Patch.parse(f.read())

            patch = self.insert_patch(patch)
            inf.append(PatchInfo(patch.id, [patch.hash, patch.title, 'new']))

        return r

    def gc(self, keep):
        """Drop every patch whose hash is not in *keep*, in memory and on disk."""
        for h in self.patches.keys() - keep:
            p = self.patches.pop(h)
            # Keep by_id in step with patches; otherwise a later
            # insert_patch of the same id reports a bogus duplicate.
            same_id = self.by_id.get(p.id)
            if same_id is not None:
                same_id[:] = [q for q in same_id if q.hash != h]
                if not same_id:
                    del self.by_id[p.id]
            print(f'Removing patch: {h!r}', file=sys.stderr)
            os.unlink(os.path.join(self.path, f'{h}.patch'))

def argparse_next(args, func=...):
    """Pull the next positional value out of the iterator *args*.

    Arguments are consumed one at a time:

    * ``--`` ends option processing; the argument after it (or ``None``) is
      returned.
    * anything not starting with ``-`` is returned as is.
    * anything else is passed to *func*; a non-``None`` result is returned,
      otherwise scanning continues.

    ``None`` is returned once *args* is exhausted.  When *func* is omitted,
    the scanner itself is returned so the function can be used as a
    decorator that is applied to the option handler.
    """
    exhausted = object()

    def scan(handler):
        while (token := next(args, exhausted)) is not exhausted:
            if token == '--':
                return next(args, None)
            if not token.startswith('-'):
                return token
            value = handler(token)
            if value is not None:
                return value
        return None

    return scan if func is ... else scan(func)

def argparse_all(args, func=...):
    """Collect every positional value from the iterator *args* into a list.

    Arguments are consumed until *args* is exhausted:

    * ``--`` ends option processing; everything after it is taken verbatim.
    * anything not starting with ``-`` is appended to the result.
    * anything else is passed to *func*; a non-``None`` result is an
      iterable whose items are appended to the result.

    When *func* is omitted, the collector itself is returned so the function
    can be used as a decorator that is applied to the option handler.
    """
    exhausted = object()

    def collect(handler):
        found = []
        while (token := next(args, exhausted)) is not exhausted:
            if token == '--':
                # Everything after the separator is positional.
                found.extend(args)
                break
            if token.startswith('-'):
                extra = handler(token)
                if extra is not None:
                    found.extend(extra)
            else:
                found.append(token)
        return found

    return collect if func is ... else collect(func)