import os
import sys

import patchstate as ps
import patchtheory as pt


def _read_diff(filename):
    """Read *filename* as bytes and parse it with ``pt.Diff.parse``."""
    with open(filename, 'rb') as f:
        data = f.read()
    return pt.Diff.parse(data)


class PatchSeq:
    """Fold a series of classified patches into an upstreamed part and a kept part.

    Patches are fed in series order through push().  Their diffs end up in:

    * ``upstreamed`` -- starts as the negated rebase diff; 'upstreamed'
      patches are added here.
    * ``diff`` -- 'apply' patches (and the replacement side of 'subst') are
      added here.
    * ``reverted`` -- list of ``[pid, diff]`` entries.  'reverted' patches are
      stored under their patch id until a matching 'reverts' cancels them;
      'skip', 'replaced' and the removed side of 'subst' are stored with
      pid ``None``.  These entries are not part of finish()'s result.

    Consecutive 'apply'/'upstreamed' patches of the same kind are batched in
    ``late`` (a ``(kind, [diffs])`` tuple; the slot is left unset while no
    batch is pending).  flush() merges the batch into its destination,
    commuting it backwards past every ``reverted`` entry first.

    NOTE(review): ``pt.Diff.commute(a, b)`` is used throughout as if it
    returns ``(b2, a2)`` with ``a + b`` equivalent to ``b2 + a2``.  This is
    inferred from the call sites in this class; confirm against patchtheory.
    """

    __slots__ = ('upstreamed', 'diff', 'reverted', 'late')

    def __init__(self, *, rebase=None):
        """Start an empty sequence.

        rebase: diff whose negation seeds ``upstreamed``.  ``None`` (the
            default) means an empty ``pt.Diff()``.  A ``None`` sentinel is
            used so that no Diff instance is shared between instances.
        """
        if rebase is None:
            rebase = pt.Diff()
        self.upstreamed = -rebase
        self.diff = pt.Diff()
        self.reverted = []

    def flush(self):
        """Merge the pending ``late`` batch, if any, into its destination.

        The batch is summed and commuted backwards past every ``reverted``
        entry, and those entries are updated in place.  For 'apply' it is then
        added to ``diff``.  For 'upstreamed' it is additionally commuted past
        ``diff`` and added to ``upstreamed``.  Does nothing when no batch is
        pending.
        """
        try:
            late_kind, late = self.late
        except AttributeError:
            # The 'late' slot is only set while a batch is pending.
            return
        del self.late
        print('Flushing', file=sys.stderr)
        late = sum(late, pt.Diff())
        # Newest entry first, so the batch moves back through all of them.
        for rev in reversed(self.reverted):
            late, rev[1] = pt.Diff.commute(rev[1], late)
        if late_kind == 'apply':
            self.diff += late
            return
        late, self.diff = pt.Diff.commute(self.diff, late)
        if late_kind == 'upstreamed':
            self.upstreamed += late
            return
        # push() only queues 'apply'/'upstreamed' batches.
        raise AssertionError(f'unexpected late kind: {late_kind!r}')

    def _push_late(self, kind, diff):
        """Queue *diff* in the pending batch.

        If the pending batch has a different kind, it is flushed first.
        """
        try:
            pending_kind, batch = self.late
        except AttributeError:
            pending_kind = batch = None
        if pending_kind != kind:
            self.flush()  # no-op when nothing is pending
            batch = []
            self.late = (kind, batch)
        batch.append(diff)

    def _push_reverted(self, by_pid, diff):
        """Flush pending work, then record *diff* as a reverted entry for *by_pid*."""
        self.flush()
        self.reverted.append([by_pid, diff])

    def _push_revert(self, pid, diff):
        """Cancel the reverted entry recorded for *pid* against *diff*.

        The matching entry is removed from ``reverted``.  Its diff is then
        commuted forward past every later entry, which are updated in place.
        The result followed by *diff* must be empty (falsy).

        Raises KeyError if no entry for *pid* exists or the revert is not exact.
        """
        self.flush()
        reverted = self.reverted
        for i, rev in enumerate(reverted):
            if rev[0] == pid:
                break
        else:
            raise KeyError(f'patch {pid} not marked as reverted')
        if reverted.pop(i) is not rev:
            raise AssertionError
        rdiff = rev[1]
        for later in reverted[i:]:
            later[1], rdiff = pt.Diff.commute(rdiff, later[1])
        # A non-empty combination means *diff* does not exactly undo the entry.
        if rdiff + diff:
            raise KeyError(f'revert of {pid} not exact')

    def push(self, patch_info, patch):
        """Add the next patch of the series.

        ``patch_info.mode`` is ``"<kind>"`` or ``"<kind> <argument>"``:

        * ``apply`` / ``upstreamed`` -- queue ``patch.diff`` for ``diff`` /
          ``upstreamed``.
        * ``subst <file>`` -- the diff parsed from *file* is applied instead.
          ``-file_diff + patch.diff`` is recorded as an anonymous reverted
          entry.
        * ``reverted`` -- record ``patch.diff`` under ``patch_info.patch_id``
          so that a later ``reverts`` can cancel it.
        * ``skip`` / ``replaced`` -- record ``patch.diff`` as an anonymous
          reverted entry.
        * ``reverts <patch id>`` -- cancel the entry recorded for that id.

        Raises TypeError for an unknown kind, or when ``subst``/``reverts`` is
        given without its argument.
        """
        diff = patch.diff
        mode = patch_info.mode.split(' ', 1)
        kind = mode[0]
        arg = mode[1] if len(mode) > 1 else None
        if kind in {'subst', 'reverts'} and arg is None:
            raise TypeError(f'patch kind {kind!r} requires an argument')

        if kind in {'apply', 'upstreamed'}:
            self._push_late(kind, diff)
        elif kind == 'subst':
            replacement = _read_diff(arg)
            self._push_late('apply', replacement)
            self._push_reverted(None, (-replacement) + diff)
        elif kind == 'reverted':
            self._push_reverted(patch_info.patch_id, diff)
        elif kind in {'skip', 'replaced'}:
            self._push_reverted(None, diff)
        elif kind == 'reverts':
            self._push_revert(arg, diff)
        else:
            raise TypeError(f'bad patch kind: {kind!r}')

    def finish(self, prb):
        """Flush pending patches and return the final diff.

        The result is *prb* followed by the first element of
        ``pt.Diff.commute(-prb + upstreamed, diff)``.

        Raises RuntimeError if a 'reverted' patch was never cancelled by a
        matching 'reverts'.
        """
        self.flush()
        for pid, _ in self.reverted:
            if pid is not None:
                raise RuntimeError(f'missing revert for {pid!r}')
        return prb + pt.Diff.commute(-prb + self.upstreamed, self.diff)[0]


def main(args):
    """Command-line entry point.

    args: argv-style sequence.  It holds the program name and the repository
        path, then optionally a rebase patch file and a prb patch file.

    Reads the repository's ``series`` file and pushes every listed patch
    through a PatchSeq.  The resulting diff is written to stdout, and
    progress messages go to stderr.  Returns 0 on success.
    """
    args = iter(args)
    next(args, None)  # skip the program name (argv[0])

    # NOTE(review): ps.argparse_all is used as a decorator that yields a list
    # of positional arguments, calling the decorated function for arguments
    # it rejects.  This is inferred from usage; confirm against patchstate.
    @ps.argparse_all(args)
    def path(arg):
        raise RuntimeError(f'Invalid argument: {arg!r}')

    if not path:
        raise RuntimeError('Missing repository path argument')
    repo_path = path.pop(0)
    rebase_patch = _read_diff(path.pop(0)) if path else pt.Diff()
    prb_patch = _read_diff(path.pop(0)) if path else pt.Diff()
    # Explicit check rather than ``assert``, which ``python -O`` strips.
    if path:
        raise RuntimeError(f'Unexpected extra arguments: {path!r}')

    repo = ps.Repository(repo_path)
    with open(os.path.join(repo.path, 'series'), 'r') as f:
        series = ps.Series.parse(f.read())

    seq = PatchSeq(rebase=rebase_patch)
    for info in series.info:
        patch = info.get_patch(repo)
        print('Applying', patch.title, file=sys.stderr)
        seq.push(info, patch)
    seq.finish(prb_patch).write(sys.stdout.buffer.write)
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))