import patchstate as ps
import patchtheory as pt
import os, sys
class PatchSeq:
    """Accumulate a patch series into upstreamed / applied / reverted parts.

    Conceptually the series is kept as the ordered sequence
    ``upstreamed, diff, reverted[0][1], reverted[1][1], ..., late`` and
    patches are moved between these buckets with ``pt.Diff.commute``.

    NOTE(review): assumes ``pt.Diff.commute(a, b)`` returns ``(b', a')`` such
    that applying ``a`` then ``b`` equals applying ``b'`` then ``a'`` (this is
    how every call site below unpacks it) -- confirm against patchtheory.
    """
    __slots__ = ('upstreamed', 'diff', 'reverted', 'late')

    # Slot meanings:
    #   upstreamed -- negated rebase diff plus every flushed 'upstreamed' batch
    #   diff       -- sum of every flushed 'apply' batch
    #   reverted   -- list of mutable [patch_id_or_None, diff] entries; the
    #                 diffs are rewritten in place as later patches commute
    #                 past them
    #   late       -- pending (kind, [diff, ...]) batch; the slot is left
    #                 *unset* (not None) when nothing is pending

    def __init__(self, *, rebase=None):
        """Create an empty sequence.

        Args:
            rebase: diff whose negation seeds ``upstreamed``.  Defaults to an
                empty ``pt.Diff()``, created per instance rather than shared
                across all instances.
        """
        if rebase is None:
            rebase = pt.Diff()
        self.upstreamed = -rebase
        self.diff = pt.Diff()
        self.reverted = []

    def flush(self):
        """Commit the pending batch of late patches, if any.

        The batch is summed and commuted backwards past every ``reverted``
        entry (last to first, updating each entry in place).  An 'apply'
        batch is then added to ``diff``.  An 'upstreamed' batch is also
        commuted back past ``diff`` and added to ``upstreamed``.
        """
        try:
            late_kind, late = self.late
        except AttributeError:
            # Slot unset: nothing is pending.
            return
        del self.late
        print('Flushing', file=sys.stderr)
        late = sum(late, pt.Diff())
        for rev in reversed(self.reverted):
            late, rev[1] = pt.Diff.commute(rev[1], late)
        if late_kind == 'apply':
            self.diff += late
            return
        late, self.diff = pt.Diff.commute(self.diff, late)
        if late_kind == 'upstreamed':
            self.upstreamed += late
            return
        raise AssertionError(f'unexpected late patch kind: {late_kind!r}')

    def _push_late(self, kind, diff):
        """Queue *diff* in the pending late batch of *kind*.

        A pending batch of a different kind is flushed first, so each batch
        holds a single kind.
        """
        try:
            pending_kind, late = self.late
        except AttributeError:
            self.late = (kind, (late := []))
        else:
            if pending_kind != kind:
                self.flush()
                self.late = (kind, (late := []))
        late.append(diff)

    def _push_reverted(self, by_pid, diff):
        """Record *diff* as reverted, tagged with *by_pid* (or None)."""
        self.flush()
        # A list, not a tuple: flush/_push_revert rewrite entry[1] in place.
        self.reverted.append([by_pid, diff])

    def _push_revert(self, pid, diff):
        """Cancel the reverted patch tagged *pid* against its revert *diff*.

        The first ``reverted`` entry tagged *pid* is removed.  Its diff is
        commuted forward past every later entry (which are updated in place).
        The result followed by *diff* must be empty.

        Raises:
            KeyError: if no entry is tagged *pid*, or if the revert is not
                exact.  NOTE(review): in the inexact case ``reverted`` has
                already been modified when the error is raised.
        """
        self.flush()
        reverted = self.reverted
        for i, rev in enumerate(reverted):
            if rev[0] == pid:
                break
        else:
            raise KeyError(f'patch {pid} not marked as reverted')
        if reverted.pop(i) is not rev:
            raise AssertionError
        rdiff = rev[1]
        # Move the reverted diff to the end of the list, directly before the
        # reverting patch.
        for later in reverted[i:]:
            later[1], rdiff = pt.Diff.commute(rdiff, later[1])
        # Presumably a falsy Diff is an empty one -- confirm in patchtheory.
        if rdiff + diff:
            raise KeyError(f'revert of {pid} not exact')

    def push(self, patch_info, patch):
        """Add one patch of the series according to ``patch_info.mode``.

        The mode is '<kind>' or '<kind> <argument>'.  Supported kinds:

        * 'apply' / 'upstreamed' -- queue ``patch.diff`` as a late patch.
        * 'subst FILE' -- apply the diff parsed from FILE instead; the
          difference ``-FILE + patch.diff`` is recorded as reverted.
        * 'reverted' -- record ``patch.diff`` as reverted, tagged with
          ``patch_info.patch_id`` so a later 'reverts' can cancel it.
        * 'skip' / 'replaced' -- record ``patch.diff`` as reverted, untagged.
        * 'reverts PID' -- cancel the reverted patch tagged PID.

        Raises:
            TypeError: for an unknown kind, or when 'subst' / 'reverts' is
                missing its argument.
        """
        diff = patch.diff
        kind, _, arg = patch_info.mode.partition(' ')
        if kind in {'subst', 'reverts'} and not arg:
            raise TypeError(f'patch kind {kind!r} requires an argument')
        if kind in {'apply', 'upstreamed'}:
            self._push_late(kind, diff)
        elif kind == 'subst':
            with open(arg, 'rb') as f:
                data = f.read()
            diff2 = pt.Diff.parse(data)
            self._push_late('apply', diff2)
            self._push_reverted(None, (-diff2) + diff)
        elif kind == 'reverted':
            self._push_reverted(patch_info.patch_id, diff)
        elif kind in {'skip', 'replaced'}:
            self._push_reverted(None, diff)
        elif kind == 'reverts':
            self._push_revert(arg, diff)
        else:
            raise TypeError(f'bad patch kind: {kind!r}')

    def finish(self, prb):
        """Flush pending work and return the combined output diff.

        Args:
            prb: diff prepended to the result.  It is also used to re-base
                ``diff`` past ``-prb + upstreamed``.

        Raises:
            RuntimeError: if a patch tagged as reverted never got its revert.
        """
        self.flush()
        for pid, _ in self.reverted:
            if pid is not None:
                raise RuntimeError(f'missing revert for {pid!r}')
        return prb + pt.Diff.commute(-prb + self.upstreamed, self.diff)[0]
def _pop_diff_arg(path):
    """Pop the next positional argument from *path* and parse it as a diff.

    The argument is a file name.  The file is read in binary mode and passed
    to ``pt.Diff.parse``.  Returns an empty ``pt.Diff()`` when *path* has no
    arguments left.
    """
    if not path:
        return pt.Diff()
    with open(path.pop(0), 'rb') as f:
        data = f.read()
    return pt.Diff.parse(data)


def main(args):
    """Combine a repository's patch series into one diff on stdout.

    Positional arguments (after the program name):
    ``REPO [REBASE_PATCH [PRB_PATCH]]``.  Each patch file, when given, is
    parsed with ``pt.Diff.parse``; when omitted, an empty diff is used.  The
    ``series`` file inside the repository is parsed and each patch is pushed
    into a ``PatchSeq``.  The result is written to ``sys.stdout.buffer``.

    Returns:
        0 on success.

    Raises:
        RuntimeError: on an invalid option (via the ``ps.argparse_all``
            callback), a missing repository path, or too many arguments.
    """
    args = iter(args)
    next(args)  # skip the program name (argv[0])

    # NOTE(review): ps.argparse_all appears to replace ``path`` with the list
    # of positional arguments (it is used with truthiness and pop below).
    @ps.argparse_all(args)
    def path(arg):
        raise RuntimeError(f'Invalid argument: {arg!r}')

    if not path:
        raise RuntimeError('Missing repository path')
    repo_path = path.pop(0)
    rebase_patch = _pop_diff_arg(path)
    prb_patch = _pop_diff_arg(path)
    # An explicit check rather than ``assert``, so it survives ``python -O``.
    if path:
        raise RuntimeError(f'Too many arguments: {path!r}')

    repo = ps.Repository(repo_path)
    with open(os.path.join(repo.path, 'series'), 'r') as f:
        pcur = ps.Series.parse(f.read())
    seq = PatchSeq(rebase=rebase_patch)
    for pi in pcur.info:
        p = pi.get_patch(repo)
        print('Applying', p.title, file=sys.stderr)
        seq.push(pi, p)
    seq.finish(prb_patch).write(sys.stdout.buffer.write)
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    raise SystemExit(main(sys.argv))