summary refs log blame commit diff
path: root/export_series.py
blob: a4ebc8fa38bae69b7912bc6a3bb8d0ad4faf1ca8 (plain) (tree)

























































































































































                                                                          
import patchstate as ps
import patchtheory as pt
import os, sys

class PatchSeq:
    """Fold a patch series into a single combined diff.

    Patches are fed in series order through :meth:`push`.  Judging from how
    :meth:`flush` and :meth:`finish` commute diffs, the series is kept as the
    composition::

        upstreamed + diff + reverted[0][1] + reverted[1][1] + ... + late

    where ``diff`` collects the patches that end up in the export,
    ``upstreamed`` collects patches that are already upstream (seeded with
    the negated rebase patch), ``reverted`` holds ``[pid, diff]`` entries that
    must later be cancelled by a ``reverts`` patch (``pid`` not None) or are
    simply dropped (``pid`` None), and ``late`` buffers consecutive
    ``apply``/``upstreamed`` patches until :meth:`flush` moves them into place.

    NOTE(review): this reading assumes ``pt.Diff.commute(a, b)`` returns
    ``(b2, a2)`` with ``a + b == b2 + a2`` -- confirm against patchtheory.
    """

    # ``late`` is intentionally left unbound in __init__: an unbound slot
    # (AttributeError on access) means "nothing buffered".
    __slots__ = ('upstreamed', 'diff', 'reverted', 'late')

    def __init__(self, *, rebase=None):
        """Start an empty sequence.

        :param rebase: optional diff whose negation seeds ``upstreamed``;
            ``None`` (the default) means an empty ``pt.Diff()``.
        """
        # Build the default per call rather than sharing one object created
        # when the class was defined.
        if rebase is None:
            rebase = pt.Diff()
        self.upstreamed = -rebase
        self.diff = pt.Diff()
        self.reverted = []

    def flush(self):
        """Move buffered ``late`` patches back into their final position.

        The buffered diffs are summed, commuted backwards past every
        ``reverted`` entry (newest first), and then either appended to
        ``diff`` (kind ``apply``) or also commuted past ``diff`` and appended
        to ``upstreamed`` (kind ``upstreamed``).  No-op when nothing is
        buffered.

        :raises AssertionError: on an internal, unexpected buffered kind.
        """
        try:
            late_kind, late = self.late
        except AttributeError:
            return
        # Check the kind before mutating anything so a broken invariant
        # cannot leave the sequence half-updated.
        if late_kind not in ('apply', 'upstreamed'):
            raise AssertionError(f'unexpected late patch kind: {late_kind!r}')
        del self.late

        print('Flushing', file=sys.stderr)

        late = sum(late, pt.Diff())

        # Move the buffered diff in front of each reverted entry, walking
        # from the most recent entry back to the oldest.
        for rev in reversed(self.reverted):
            (late, rev[1]) = pt.Diff.commute(rev[1], late)

        if late_kind == 'apply':
            self.diff += late
            return

        # 'upstreamed': additionally move it in front of the exported diff.
        (late, self.diff) = pt.Diff.commute(self.diff, late)
        self.upstreamed += late

    def _push_late(self, kind, diff):
        """Buffer ``diff`` under ``kind``, flushing first if the kind changes."""
        try:
            lkind, late = self.late
        except AttributeError:
            self.late = (kind, (late := []))
        else:
            if lkind != kind:
                self.flush()
                self.late = (kind, (late := []))
        late.append(diff)

    def _push_reverted(self, by_pid, diff):
        """Append ``diff`` to ``reverted`` under key ``by_pid``.

        ``by_pid`` is the id a later ``reverts`` patch will use to find this
        entry, or ``None`` for an entry that is simply dropped.
        """
        self.flush()
        self.reverted.append([by_pid, diff])

    def _push_revert(self, pid, diff):
        """Cancel the reverted entry keyed ``pid`` against ``diff``.

        The first entry with key ``pid`` is removed from ``reverted``, its
        diff is commuted forward past all later entries, and the result
        combined with ``diff`` must be empty (falsy).

        :raises KeyError: if no entry has key ``pid``, or if ``diff`` does not
            exactly undo it.
        """
        self.flush()

        for i, rev in enumerate(reverted := self.reverted):
            if rev[0] == pid:
                break
        else:
            raise KeyError(f'patch {pid} not marked as reverted')

        if reverted.pop(i) is not rev:
            raise AssertionError

        # Move the removed diff forward to the end of the sequence so it sits
        # directly in front of the reverting patch.
        rdiff = rev[1]
        for later in reverted[i:]:
            later[1], rdiff = pt.Diff.commute(rdiff, later[1])

        if rdiff + diff:
            raise KeyError(f'revert of {pid} not exact')

    def push(self, patch_info, patch):
        """Add one patch of the series.

        ``patch_info.mode`` is split at the first space into a kind and an
        optional argument:

        * ``apply`` / ``upstreamed`` -- buffered and later moved into
          ``diff`` / ``upstreamed``;
        * ``subst FILE`` -- replaced by the diff parsed from ``FILE``;
        * ``reverted`` -- parked under ``patch_info.patch_id`` until a
          matching ``reverts`` patch arrives;
        * ``skip`` / ``replaced`` -- dropped;
        * ``reverts PID`` -- must exactly undo the entry parked as ``PID``.

        :raises TypeError: for an unknown kind, or when ``subst``/``reverts``
            lacks its argument.
        """
        diff = patch.diff
        kind, *rest = patch_info.mode.split(' ', 1)
        arg = rest[0] if rest else None

        if kind in {'apply', 'upstreamed'}:
            self._push_late(kind, diff)
            return

        # These kinds need an argument; report a missing one clearly rather
        # than failing later with a bare IndexError.
        if kind in {'subst', 'reverts'} and arg is None:
            raise TypeError(f'patch kind {kind!r} requires an argument')

        if kind == 'subst':
            # NOTE(review): a relative FILE resolves against the process CWD,
            # not the repository directory -- confirm this is intended.
            with open(arg, 'rb') as f:
                data = f.read()
            diff2 = pt.Diff.parse(data)
            # Apply the replacement and park (-replacement + original) as an
            # anonymous entry, which finish() drops.
            self._push_late('apply', diff2)
            self._push_reverted(None, (-diff2) + diff)
            return

        if kind == 'reverted':
            self._push_reverted(patch_info.patch_id, diff)
            return

        if kind in {'skip', 'replaced'}:
            self._push_reverted(None, diff)
            return

        if kind == 'reverts':
            self._push_revert(arg, diff)
            return

        raise TypeError(f'bad patch kind: {kind!r}')

    def finish(self, prb):
        """Flush pending patches and return the final combined diff.

        Anonymous ``reverted`` entries (key ``None``) are discarded.  The
        result is ``prb`` followed by ``diff`` after commuting ``diff`` in
        front of ``-prb + upstreamed``.

        :param prb: diff placed at the start of the result.
        :raises RuntimeError: if a patch marked ``reverted`` was never
            cancelled by a ``reverts`` patch.
        """
        self.flush()

        for pid, _ in self.reverted:
            if pid is not None:
                raise RuntimeError(f'missing revert for {pid!r}')

        return prb + pt.Diff.commute(-prb + self.upstreamed, self.diff)[0]

def _read_patch(path):
    """Read the file at ``path`` in binary mode and parse it with ``pt.Diff.parse``."""
    with open(path, 'rb') as f:
        data = f.read()
    return pt.Diff.parse(data)


def main(args):
    """Export a repository's patch series as one combined diff on stdout.

    ``args`` is argv-style: program name, repository path, then optionally a
    rebase patch file and a PRB patch file (both default to an empty
    ``pt.Diff()``).  The series is read from the ``series`` file inside the
    repository.  Progress messages go to stderr and the resulting diff is
    written to ``sys.stdout.buffer``.

    :returns: 0 on success.
    :raises RuntimeError: on an invalid, missing or extra argument.
    """
    args = iter(args)
    # Consume the program name so argument parsing only sees real arguments.
    prog = next(args, 'export_series.py')
    usage = f'usage: {prog} REPO [REBASE_PATCH [PRB_PATCH]]'

    @ps.argparse_all(args)
    def path(arg):
        raise RuntimeError(f'Invalid argument: {arg!r}')

    # Validate explicitly (not with assert, which -O strips).
    if not path:
        raise RuntimeError(f'Missing argument: repository path; {usage}')
    repo_path = path.pop(0)

    rebase_patch = _read_patch(path.pop(0)) if path else pt.Diff()
    prb_patch = _read_patch(path.pop(0)) if path else pt.Diff()

    if path:
        raise RuntimeError(f'Unexpected arguments: {list(path)!r}; {usage}')

    repo = ps.Repository(repo_path)

    with open(os.path.join(repo.path, 'series'), 'r') as f:
        series = ps.Series.parse(f.read())

    seq = PatchSeq(rebase=rebase_patch)
    for info in series.info:
        patch = info.get_patch(repo)
        print('Applying', patch.title, file=sys.stderr)
        seq.push(info, patch)
    seq.finish(prb_patch).write(sys.stdout.buffer.write)
    return 0

if __name__ == '__main__':
    # Run the exporter on the real command line; its return value is the
    # process exit status (SystemExit is what sys.exit raises).
    raise SystemExit(main(sys.argv))