import re, operator, io, sys
# Full-line matcher for a unified-diff hunk header given as bytes, e.g.
# b'@@ -12,5 +12,7 @@ some context'.  Groups are: old start, optional old
# length, new start, optional new length.
rem_range = re.compile(b'@@ -([0-9]+)(?:,([0-9]+))? \\+([0-9]+)(?:,([0-9]+))? @@(?: .*)?').fullmatch
# Full-line matcher for a 'diff --git a/<src> b/<dst>' header given as str.
# Neither captured name may contain a space.
rem_diffgit = re.compile('diff --git a/([^ ]*) b/([^ ]*)').fullmatch
# Sentinel line position that compares greater than any real line number;
# commute_hunks uses it for the begin/end of an exhausted hunk stream.
INFINITY = 2**64
class PatchError(Exception):
    """Base class for errors that involve two irreconcilable diff pieces.

    ``a`` and ``b`` are the two offending objects.  Each must offer a
    ``write(w)`` method that emits bytes through the callable ``w``, or be
    ``None`` when there is nothing to show for that side.  ``fn`` is stored
    as the first exception argument and, when not ``None``, is quoted in the
    message.
    """
    __slots__ = ('a', 'b')

    def __init__(self, a, b, /, fn=None):
        super().__init__(fn)
        self.a = a
        self.b = b

    def __str__(self, /):
        """Return the header line followed by both pieces, each indented."""
        fn = self.args[0]

        def render(piece):
            # A side may legitimately be None (e.g. a merged hunk that
            # reduced to nothing); formatting the error must not fail then.
            if piece is None:
                return [' (empty)']
            buf = io.BytesIO()
            piece.write(buf.write)
            text = buf.getvalue().decode('utf-8', 'replace')
            return [' ' + ln for ln in text.splitlines()]

        out = [f' in {fn!r}:' if fn is not None else ':']
        for piece in (self.a, self.b):
            out.append('')
            out += render(piece)
        return '\n'.join(out)
class ConflictError(PatchError):
    """PatchError whose message is prefixed with the word 'Conflict'."""
    __slots__ = ()

    def __str__(self, /):
        details = super().__str__()
        return f'Conflict{details}'
class MismatchError(PatchError):
    """PatchError whose message is prefixed with the word 'Mismatch'."""
    __slots__ = ()

    def __str__(self, /):
        details = super().__str__()
        return f'Mismatch{details}'
class Hunk:
    """A single hunk of a unified diff.

    Lines are stored as bytes without their one-character diff prefix.
    ``lhs_line`` / ``rhs_line`` are the positive start lines of the hunk in
    the old and new file, counted from the first line of ``pre_context``.
    ``pre_context`` / ``post_context`` hold the unchanged lines around the
    change, ``lhs`` the removed lines and ``rhs`` the added lines.
    """
    __slots__ = ('lhs_line', 'rhs_line', 'pre_context', 'lhs', 'rhs', 'post_context')

    def __init__(self, /, *, lhs_line, rhs_line, pre_context, lhs, rhs, post_context):
        assert lhs_line > 0
        assert rhs_line > 0
        self.lhs_line = lhs_line
        self.rhs_line = rhs_line
        self.pre_context = pre_context
        self.lhs = lhs
        self.rhs = rhs
        self.post_context = post_context

    @property
    def lhs_size(self, /):
        """Number of old-file lines covered by the hunk, context included."""
        return len(self.pre_context) + len(self.lhs) + len(self.post_context)

    @property
    def rhs_size(self, /):
        """Number of new-file lines covered by the hunk, context included."""
        return len(self.pre_context) + len(self.rhs) + len(self.post_context)

    def write(self, w, /, munge=False, *, lhs_line=None, rhs_line=None, lhs_nl=True, rhs_nl=True):
        """Emit the hunk in unified-diff form through the bytes sink ``w``.

        With ``munge`` true the start lines in the header are replaced by the
        placeholders ``xxx`` / ``yyy``.  ``lhs_line`` / ``rhs_line`` override
        the stored start lines when given.  ``lhs_nl`` / ``rhs_nl`` are
        accepted for interface compatibility but are not used.
        """
        if lhs_line is None:
            lhs_line = self.lhs_line
        if rhs_line is None:
            rhs_line = self.rhs_line
        pre_context = self.pre_context
        lhs = self.lhs
        rhs = self.rhs
        post_context = self.post_context
        clen = len(pre_context) + len(post_context)
        llen = len(lhs) + clen
        rlen = len(rhs) + clen
        if munge:
            w(f'@@ -xxx,{llen} +yyy,{rlen} @@'.encode('utf-8'))
        else:
            w(f'@@ -{lhs_line},{llen} +{rhs_line},{rlen} @@'.encode('utf-8'))
        # Every body line is written as newline + prefix + content, so the
        # header needs no trailing newline of its own.
        for l in pre_context:
            w(b'\n ')
            w(l)
        for l in lhs:
            w(b'\n-')
            w(l)
        for l in rhs:
            w(b'\n+')
            w(l)
        for l in post_context:
            w(b'\n ')
            w(l)
        w(b'\n')

    @classmethod
    def parse_l(tp, l, line, /):
        """Parse one hunk whose '@@' header line is ``l``.

        ``line`` is a zero-argument callable returning the next input line as
        bytes.  Returns ``(next_line, hunk)`` where ``next_line`` is the first
        line read after the hunk body.

        Raises ValueError for a malformed header, a bad prefix character, an
        empty line inside the body, a wrong line count, a misplaced
        "No newline at end of file" flag, or a hunk without changed lines.
        """
        m = rem_range(l)
        if m is None:
            raise ValueError(f'Invalid patch: bad range info {l!r}')
        lhs_line, n2, rhs_line, n4 = m.groups()
        # A start line of 0 (used for empty sides) is normalised to 1.
        lhs_line = 1 if lhs_line == b'0' else int(lhs_line)
        rhs_line = 1 if rhs_line == b'0' else int(rhs_line)
        # n2 / n4: old / new lines still expected; an omitted length means 1.
        n2 = int(n2) if n2 else 1
        n4 = int(n4) if n4 else 1
        hunk = []
        # prev: kind of the previous body line (-1 context or start,
        # 0 removed, 1 added, -2 a "no newline" flag).
        prev = -1
        lhs_nl = True
        rhs_nl = True
        post_nl = True
        while True:
            l = line()
            if l.startswith(b'\\ '):
                if prev == 0:
                    lhs_nl = False
                elif prev == 1:
                    rhs_nl = False
                elif prev == -1:
                    post_nl = False
                else:
                    raise ValueError('Invalid patch: bad "No newline at end of file" flag')
                prev = -2
                continue
            if not n2 and not n4:
                break
            if not l:
                # Without this guard l[0] below would raise IndexError.
                raise ValueError('Invalid patch: empty line in hunk')
            k = l[0]
            if k == b' '[0]:
                if not post_nl or not lhs_nl or not rhs_nl:
                    raise ValueError('Invalid patch: bad "No newline at end of file" flag')
                if not n2 or not n4:
                    raise ValueError('Invalid patch: bad line count')
                n2 -= 1
                n4 -= 1
                prev = -1
            elif k == b'-'[0]:
                if not post_nl or not lhs_nl:
                    raise ValueError('Invalid patch: bad "No newline at end of file" flag')
                if not n2:
                    raise ValueError('Invalid patch: bad line count')
                n2 -= 1
                prev = 0
            elif k == b'+'[0]:
                if not post_nl or not rhs_nl:
                    raise ValueError('Invalid patch: bad "No newline at end of file" flag')
                if not n4:
                    raise ValueError('Invalid patch: bad line count')
                n4 -= 1
                prev = 1
            else:
                raise ValueError(f'Invalid patch: bad prefix character {chr(k)!r}')
            hunk.append(l)
        lret = l
        strip1 = operator.itemgetter(slice(1, None))
        # Split off the leading context: i is the index of the first changed line.
        for i, l in enumerate(hunk):
            if l[0] != b' '[0]:
                break
        else:
            raise ValueError('Invalid patch: no changed lines in hunk')
        pre_context = (*map(strip1, hunk[:i]),)
        del hunk[:i]
        # Split off the trailing context: after the loop j indexes (from the
        # end) the last changed line; it is then turned into the index of the
        # first trailing context line.
        j = -1
        while True:
            l = hunk[j]
            if l[0] != b' '[0]:
                break
            j -= 1
        j += len(hunk) + 1
        if not post_nl:
            # The flag followed the final context line of the hunk, so the
            # marker belongs on the last line, not on the first context line.
            hunk[-1] += b'\n\\ No newline at end of file'
        post_context = (*map(strip1, hunk[j:]),)
        del hunk[j:]
        # Context lines that sit between changed lines belong to both sides.
        lhs = []
        rhs = []
        for l in hunk:
            l0 = l[0]
            if l0 != b'+'[0]:
                lhs.append(l)
            if l0 != b'-'[0]:
                rhs.append(l)
        if not lhs_nl:
            lhs[-1] += b'\n\\ No newline at end of file'
        if not rhs_nl:
            rhs[-1] += b'\n\\ No newline at end of file'
        return lret, tp(
            lhs_line = lhs_line,
            rhs_line = rhs_line,
            pre_context = pre_context,
            lhs = (*map(strip1, lhs),),
            rhs = (*map(strip1, rhs),),
            post_context = post_context,
        )

    def reduce_context(self, n=3):
        """Return a normalised copy of the hunk, or None if it changes nothing.

        Lines common to the start or end of ``lhs`` and ``rhs`` are moved
        into the context.  Then the context is limited to ``n`` lines on each
        side (``None`` keeps it all, ``0`` drops it), adjusting the start
        lines for any leading context that is removed.
        """
        pre = self.pre_context
        lhs = self.lhs
        rhs = self.rhs
        post = self.post_context
        # i counts matching leading lines; j is a negative index walking
        # over matching trailing lines.  Together they never consume more
        # than the shorter side.
        i = 0
        j = -1
        m = min(len(lhs), len(rhs))
        while i - j <= m:
            if lhs[i] == rhs[i]:
                i += 1
                if i - j > m:
                    break
                if lhs[j] == rhs[j]:
                    j -= 1
            else:
                if lhs[j] != rhs[j]:
                    break
                j -= 1
        j += 1
        jl = len(lhs) + j
        jr = len(rhs) + j
        assert lhs[:i] == rhs[:i]
        assert lhs[jl:] == rhs[jr:]
        pre = pre + lhs[:i]
        post = lhs[jl:] + post
        lhs = lhs[i:jl]
        rhs = rhs[i:jr]
        if not lhs and not rhs:
            return None
        assert lhs != rhs
        assert not lhs or not rhs or lhs[0] != rhs[0] and lhs[-1] != rhs[-1]
        lhs_line = self.lhs_line
        rhs_line = self.rhs_line
        if n is None:
            pass
        elif not n:
            d = len(pre)
            lhs_line += d
            rhs_line += d
            pre = ()
            post = ()
        else:
            if len(pre) > n:
                d = len(pre) - n
                lhs_line += d
                rhs_line += d
                pre = pre[d:]
            if len(post) > n:
                post = post[:n]
        return Hunk(
            lhs_line = lhs_line,
            rhs_line = rhs_line,
            pre_context = pre,
            lhs = lhs,
            rhs = rhs,
            post_context = post,
        )
def join_hunks(hs_a, hs_b):
    """Compose two hunk sequences into one and return it as a tuple.

    ``hs_a`` and ``hs_b`` are iterables of Hunk objects; the hunks of
    ``hs_b`` are positioned relative to the result of ``hs_a``.  Hunks that
    touch or overlap are merged into a single accumulated region, which is
    emitted (reduced to 3 lines of context) once no further hunk reaches it.

    Raises MismatchError when an overlapping hunk does not agree with the
    content accumulated so far.
    """
    hs_a = iter(hs_a)
    hs_b = iter(hs_b)
    # need_a / need_b: the current a / b has been consumed, fetch the next.
    need_a = True
    need_b = True
    # have_cur: an accumulated region (cur_lhs_line, cur_rhs_line, cur_lhs,
    # cur_rhs) is currently open.
    have_cur = False
    def make_cur(n):
        # Build a Hunk from the open region (None if there is none, or if
        # reduce_context finds that both sides are identical).
        if not have_cur:
            return None
        return Hunk(
            lhs_line = cur_lhs_line,
            rhs_line = cur_rhs_line,
            pre_context = (),
            lhs = tuple(cur_lhs),
            rhs = tuple(cur_rhs),
            post_context = (),
        ).reduce_context(n)
    # delta: net line-count change of all hunks emitted so far.
    delta = 0
    r = []
    while True:
        if need_a:
            need_a = False
            a = next(hs_a, None)
        if need_b:
            need_b = False
            b = next(hs_b, None)
        if not have_cur:
            # Open a new region from whichever pending hunk starts first.
            if a is not None and (b is None or a.rhs_line <= b.lhs_line):
                cur_lhs_line = a.lhs_line
                cur_rhs_line = a.lhs_line + delta
                cur_lhs = [*a.pre_context, *a.lhs, *a.post_context]
                cur_rhs = [*a.pre_context, *a.rhs, *a.post_context]
                need_a = True
            elif b is not None:
                cur_lhs_line = b.rhs_line - delta
                cur_rhs_line = b.rhs_line
                cur_lhs = [*b.pre_context, *b.lhs, *b.post_context]
                cur_rhs = [*b.pre_context, *b.rhs, *b.post_context]
                need_b = True
            else:
                # Both inputs are exhausted and nothing is open.
                return (*r,)
            have_cur = True
            continue
        cur_lhs_end = cur_lhs_line + len(cur_lhs)
        cur_rhs_end = cur_rhs_line + len(cur_rhs)
        if a is not None and a.lhs_line <= cur_lhs_end:
            # a reaches the open region: its new-side text must match what is
            # on the region's left side, which is then replaced by its
            # old-side text.
            lhs = [*a.pre_context, *a.lhs, *a.post_context]
            rhs = [*a.pre_context, *a.rhs, *a.post_context]
            i = a.lhs_line - cur_lhs_line
            k = i + len(rhs)
            if len(cur_lhs) >= k:
                if cur_lhs[i:k] != rhs:
                    raise MismatchError(make_cur(None), a)
                cur_lhs[i:k] = lhs
            else:
                # a extends past the region: only the first j lines overlap;
                # the remainder of its new-side text is appended on the right.
                j = len(cur_lhs) - i
                if cur_lhs[i:] != rhs[:j]:
                    raise MismatchError(make_cur(None), a)
                cur_lhs[i:] = lhs
                cur_rhs.extend(rhs[j:])
            need_a = True
            continue
        if b is not None and b.rhs_line <= cur_rhs_end:
            # Mirror image of the branch above, applying b to the right side.
            lhs = [*b.pre_context, *b.lhs, *b.post_context]
            rhs = [*b.pre_context, *b.rhs, *b.post_context]
            i = b.rhs_line - cur_rhs_line
            k = i + len(lhs)
            if len(cur_rhs) >= k:
                if cur_rhs[i:k] != lhs:
                    raise MismatchError(make_cur(None), b)
                cur_rhs[i:k] = rhs
            else:
                j = len(cur_rhs) - i
                if cur_rhs[i:] != lhs[:j]:
                    raise MismatchError(make_cur(None), b)
                cur_rhs[i:] = rhs
                cur_lhs.extend(lhs[j:])
            need_b = True
            continue
        # Neither pending hunk reaches the region: close and emit it.
        cur = make_cur(3)
        have_cur = False
        if cur is None:
            # The merged changes cancelled each other out.
            continue
        r.append(cur)
        delta += len(cur.rhs) - len(cur.lhs)
def commute_hunks(hs_a, hs_b):
    """Swap the order of two consecutive hunk sequences.

    ``hs_a`` is applied first and ``hs_b`` second.  Returns ``(b2, a2)``, two
    tuples of Hunk objects: the hunks of ``hs_b`` with their line numbers
    shifted back by the net size change of the preceding ``hs_a`` hunks, and
    the hunks of ``hs_a`` shifted forward by that of the preceding ``hs_b``
    hunks.  Context lines that overlap the other sequence's changed lines
    are rewritten accordingly.

    Raises ConflictError when the changed regions of two hunks overlap.
    """
    hs_a = iter(hs_a)
    hs_b = iter(hs_b)
    need_a = True
    need_b = True
    # Both must exist before the first iteration: the exhaustion check for
    # hs_a reads b, which would otherwise be unbound when hs_a is empty.
    a = None
    b = None
    # delta_a / delta_b: net line-count change of the a / b hunks emitted so far.
    delta_a = 0
    delta_b = 0
    ra = []
    rb = []
    wa = ra.append
    wb = rb.append
    while True:
        if need_a:
            need_a = False
            try:
                a = next(hs_a)
            except StopIteration:
                if b is None and not need_b:
                    break
                # Park the exhausted side beyond every real line.
                a = None
                a_begin = INFINITY
                a_pre = ()
                a_post = ()
                a_end = INFINITY
            else:
                a_pre = a.pre_context
                a_post = a.post_context
                # Changed region of a, in the coordinates between a and b.
                a_begin = a.rhs_line + len(a_pre)
                a_end = a_begin + len(a.rhs)
        if need_b:
            need_b = False
            try:
                b = next(hs_b)
            except StopIteration:
                if a is None:
                    break
                b = None
                b_begin = INFINITY
                b_pre = ()
                b_post = ()
                b_end = INFINITY
            else:
                b_pre = b.pre_context
                b_post = b.post_context
                b_begin = b.lhs_line + len(b_pre)
                b_end = b_begin + len(b.lhs)
        if a_end <= b_begin:
            # a lies entirely before b: emit a.
            gap = b_begin - a_end
            if (n := len(a_post) + len(b_pre) - gap) > 0:
                # The two contexts overlap and must describe the same lines.
                assert (a_pre + a.rhs + a_post)[-n:] == (b_pre + b.lhs + b_post)[:n]
            if (n := len(a_post) - gap) > 0:
                # a's trailing context reaches into b's change: it must show
                # b's new text once b is applied first.
                a_post = a_post[:gap] + (b.rhs + b_post)[:n]
            if (n := len(b_pre) - gap) > 0:
                # b's leading context reaches into a's change: it must show
                # a's old text once a is applied afterwards.
                b_pre = (a_pre + a.lhs)[-n:] + b_pre[n:]
            wa(Hunk(
                lhs_line = a.lhs_line + delta_b,
                rhs_line = a.rhs_line + delta_b,
                pre_context = a_pre,
                lhs = a.lhs,
                rhs = a.rhs,
                post_context = a_post,
            ))
            delta_a += len(a.rhs) - len(a.lhs)
            need_a = True
            continue
        if b_end <= a_begin:
            # b lies entirely before a: emit b (mirror image of the above).
            gap = a_begin - b_end
            if (n := len(b_post) + len(a_pre) - gap) > 0:
                assert (b_pre + b.lhs + b_post)[-n:] == (a_pre + a.rhs + a_post)[:n]
            if (n := len(b_post) - gap) > 0:
                b_post = b_post[:gap] + (a.lhs + a_post)[:n]
            if (n := len(a_pre) - gap) > 0:
                a_pre = (b_pre + b.rhs)[-n:] + a_pre[n:]
            wb(Hunk(
                lhs_line = b.lhs_line - delta_a,
                rhs_line = b.rhs_line - delta_a,
                pre_context = b_pre,
                lhs = b.lhs,
                rhs = b.rhs,
                post_context = b_post,
            ))
            delta_b += len(b.rhs) - len(b.lhs)
            need_b = True
            continue
        raise ConflictError(a, b)
    return (*rb,), (*ra,)
class FileDiff:
    """The change made to one file: names, modes and content hunks.

    ``src_name`` / ``dst_name`` are the file names before and after; ``None``
    means the file does not exist on that side.  ``src_mode`` / ``dst_mode``
    are ``None`` exactly when the matching name is ``None``; both are
    Ellipsis when the mode is not changed (equal modes are stored as
    Ellipsis).  ``hunks`` is the sequence of Hunk objects.
    """
    __slots__ = ('src_name', 'dst_name', 'src_mode', 'dst_mode', 'hunks',)

    def __init__(self, /, *, src_name, dst_name, src_mode, dst_mode, hunks):
        assert (src_mode is ...) == (dst_mode is ...)
        assert (src_name is None) == (src_mode is None)
        assert (dst_name is None) == (dst_mode is None)
        self.src_name = src_name
        self.dst_name = dst_name
        if src_mode == dst_mode:
            self.src_mode = ...
            self.dst_mode = ...
        else:
            self.src_mode = src_mode
            self.dst_mode = dst_mode
        self.hunks = hunks

    @property
    def op_kind_name(self):
        """Short label for the operation: 'create', 'delete', 'rename',
        'modify', 'chmod', a 'chmod+...' combination, or 'noop'."""
        if self.src_name is None:
            return 'create'
        if self.dst_name is None:
            return 'delete'
        # Compare source against destination; equal modes are normalised to
        # Ellipsis in __init__, so inequality means a real mode change.
        is_chmod = self.src_mode != self.dst_mode
        if self.src_name != self.dst_name:
            return 'chmod+rename' if is_chmod else 'rename'
        if self.hunks:
            return 'chmod+modify' if is_chmod else 'modify'
        return 'chmod' if is_chmod else 'noop'

    def write(self, w):
        """Emit the file diff in git format through the bytes sink ``w``."""
        src_name = self.src_name
        dst_name = self.dst_name
        # For a created or deleted file the header repeats the existing name.
        if src_name is None:
            assert dst_name is not None
            w(f'diff --git a/{dst_name} b/{dst_name}\n'.encode('utf-8'))
        elif dst_name is None:
            w(f'diff --git a/{src_name} b/{src_name}\n'.encode('utf-8'))
        else:
            w(f'diff --git a/{src_name} b/{dst_name}\n'.encode('utf-8'))
            if src_name != dst_name:
                w(f'rename from {src_name}\nrename to {dst_name}\n'.encode('utf-8'))
        m = self.src_mode
        if dst_name is None:
            assert m is not None
            w(f'deleted file mode {m}\n'.encode('utf-8'))
        elif m not in (None, ...):
            w(f'old mode {m}\n'.encode('utf-8'))
        m = self.dst_mode
        if src_name is None:
            assert m is not None
            w(f'new file mode {m}\n'.encode('utf-8'))
        elif m not in (None, ...):
            w(f'new mode {m}\n'.encode('utf-8'))
        hunks = self.hunks
        if hunks:
            # A missing side is written as /dev/null with start line 0;
            # None lets each hunk use its own stored start line.
            if src_name is None:
                src_name = '/dev/null'
                lhs_line = 0
            else:
                src_name = f'a/{src_name}'
                lhs_line = None
            if dst_name is None:
                dst_name = '/dev/null'
                rhs_line = 0
            else:
                dst_name = f'b/{dst_name}'
                rhs_line = None
            w(f'--- {src_name}\n+++ {dst_name}\n'.encode('utf-8'))
            for hunk in hunks:
                hunk.write(w, lhs_line=lhs_line, rhs_line=rhs_line)

    def __neg__(self, /):
        """Return the inverse change: both sides and every hunk swapped."""
        tp = type(self)
        return tp(
            src_name = self.dst_name,
            dst_name = self.src_name,
            src_mode = self.dst_mode,
            dst_mode = self.src_mode,
            hunks = (*(Hunk(
                lhs_line = h.rhs_line,
                rhs_line = h.lhs_line,
                pre_context = h.pre_context,
                lhs = h.rhs,
                rhs = h.lhs,
                post_context = h.post_context,
            ) for h in self.hunks),),
        )
class Diff:
    """A whole patch: a tuple of FileDiff objects kept in sorted order."""
    __slots__ = ('files',)

    def __init__(self):
        self.files = ()

    def __bool__(self):
        return not not self.files

    @classmethod
    def _sort_files(tp, files):
        """Build an instance from ``files``, sorted by (source, destination)
        name with a missing name ordered as the empty string."""
        r = object.__new__(tp)
        r.files = (*sorted(files, key=lambda f: (f.src_name or '', f.dst_name or '')),)
        return r

    @classmethod
    def parse_l(tp, l, line):
        """Parse a git-format patch body.

        ``l`` is the first line (bytes, expected to be a 'diff --git' header,
        '--', or empty) and ``line`` a zero-argument callable returning the
        following lines as bytes; it signals the end of input by raising
        StopIteration.

        Raises ValueError for a malformed header, a bad version line or
        trailing garbage, and RuntimeError for contradictory file headers.
        """
        files = []
        while True:
            l = l.decode('utf-8').rstrip()
            if l == '--' or not l:
                break
            m = rem_diffgit(l)
            if m is None:
                raise ValueError(f'Invalid patch: expected \'diff --git ...\' header, got {l!r}')
            header = l
            src_name, dst_name = m.groups()
            # Ellipsis means "not yet known" until a header line decides it.
            src_exists = ...
            dst_exists = ...
            has_diff = False
            has_chmod = False
            has_rename = False
            dst_mode = ...
            src_mode = ...
            # Reset per file so a stale value from a previous file (or an
            # unbound name) cannot reach the rename consistency checks.
            rename_from = None
            rename_to = None
            while True:
                l = line()
                if l.startswith(b'--- a/'):
                    if src_exists is False:
                        raise RuntimeError('wtf')
                    src_exists = True
                    if l[6:].decode('utf-8') != src_name:
                        raise RuntimeError('wtf')
                    has_diff = True
                    continue
                if l == b'--- /dev/null':
                    if src_exists is True:
                        raise RuntimeError('wtf')
                    src_exists = False
                    has_diff = True
                    continue
                if l.startswith(b'+++ b/'):
                    if dst_exists is False:
                        raise RuntimeError('wtf')
                    dst_exists = True
                    if l[6:].decode('utf-8') != dst_name:
                        raise RuntimeError('wtf')
                    has_diff = True
                    continue
                if l == b'+++ /dev/null':
                    if dst_exists is True:
                        raise RuntimeError('wtf')
                    dst_exists = False
                    has_diff = True
                    continue
                if l.startswith(b'index '):
                    continue
                if l.startswith(b'new file mode '):
                    if src_exists is True:
                        raise RuntimeError('wtf')
                    if dst_exists is False:
                        raise RuntimeError('wtf')
                    src_exists = False
                    dst_exists = True
                    dst_mode = l[14:].decode('ascii')
                    src_mode = None
                    continue
                if l.startswith(b'new mode '):
                    dst_mode = l[9:].decode('ascii')
                    has_chmod = True
                    continue
                if l.startswith(b'deleted file mode '):
                    if src_exists is False:
                        raise RuntimeError('wtf')
                    if dst_exists is True:
                        raise RuntimeError('wtf')
                    src_exists = True
                    dst_exists = False
                    src_mode = l[18:].decode('ascii')
                    dst_mode = None
                    continue
                if l.startswith(b'old mode '):
                    if dst_exists is False:
                        raise RuntimeError('File deleted')
                    src_mode = l[9:].decode('ascii')
                    has_chmod = True
                    dst_exists = True
                    continue
                if l.startswith(b'similarity index '):
                    has_rename = True
                    continue
                if l.startswith(b'rename from '):
                    has_rename = True
                    rename_from = l[12:].decode('utf-8')
                    continue
                if l.startswith(b'rename to '):
                    rename_to = l[10:].decode('utf-8')
                    has_rename = True
                    continue
                # First line that is not a file header: hunks or next file.
                break
            if has_chmod or has_rename:
                if src_exists is False:
                    raise RuntimeError('wtf')
                if dst_exists is False:
                    raise RuntimeError('wtf')
                src_exists = True
                dst_exists = True
            if has_chmod:
                assert src_mode is not ...
                assert dst_mode is not ...
            assert src_exists is not ...
            if not src_exists:
                assert dst_mode is not ...
            assert dst_exists is not ...
            if not dst_exists:
                assert src_mode is not ...
            if has_rename:
                if src_name == dst_name or not src_exists or not dst_exists:
                    raise RuntimeError(f'Bad rename: {src_name!r} -> {dst_name!r}')
                assert rename_from == src_name
                assert rename_to == dst_name
            if not src_exists:
                src_name = None
            if not dst_exists:
                dst_name = None
            assert (src_mode is ...) == (dst_mode is ...)
            assert has_diff or has_chmod or has_rename or not dst_exists
            #want_header = fmt_diff_git(src_name, dst_name)
            #if header != want_header:
            #    raise ValueError(f'Invalid patch: expected {want_header!r} header, got {header!r}')
            if has_diff:
                hunks = []
                while True:
                    l, hunk = Hunk.parse_l(l, line)
                    hunk = hunk.reduce_context()
                    if hunk is not None:
                        hunks.append(hunk)
                    if not l.startswith(b'@@ '):
                        break
                # A created / deleted file must consist of one context-free
                # hunk that starts at line 1 and has an empty missing side.
                if src_name is None:
                    assert len(hunks) == 1
                    assert len(hunks[0].pre_context) == 0
                    assert len(hunks[0].post_context) == 0
                    assert len(hunks[0].lhs) == 0
                    assert hunks[0].lhs_line == 1
                    hunks[0].lhs_line = 1
                if dst_name is None:
                    assert len(hunks) == 1
                    assert len(hunks[0].pre_context) == 0
                    assert len(hunks[0].post_context) == 0
                    assert len(hunks[0].rhs) == 0
                    assert hunks[0].rhs_line == 1
                hunks = (*hunks,)
            else:
                hunks = ()
            files.append(FileDiff(
                src_name = src_name,
                dst_name = dst_name,
                src_mode = src_mode,
                dst_mode = dst_mode,
                hunks = hunks,
            ))
        # After the '--' separator an optional version line may follow; it
        # has to start with a digit.
        try:
            l = line().rstrip()
        except StopIteration:
            pass
        else:
            if not l or l[0] not in b'0123456789':
                l = l.decode('utf-8')
                raise ValueError(f'Invalid patch: expected git version, got {l!r}')
        # Only blank lines may remain.
        while True:
            try:
                l = line()
            except StopIteration:
                break
            if l.strip():
                l = l.decode('utf-8')
                raise ValueError(f'Invalid patch: got trailing garbage {l!r}')
        return tp._sort_files(files)

    @classmethod
    def parse(tp, body):
        """Parse a complete patch given as one bytes object."""
        line = iter(body.split(b'\n')).__next__
        return tp.parse_l(line(), line)

    def write(self, w, /):
        """Emit every file diff in git format through the bytes sink ``w``."""
        for f in self.files:
            f.write(w)

    def write_munged(self, w, /):
        """Emit a line-number-free form: for each file its two names, its
        hunks with placeholder start lines, and a blank line."""
        for f in self.files:
            w((f.src_name or '').encode('utf-8'))
            w(b'\n')
            w((f.dst_name or '').encode('utf-8'))
            w(b'\n')
            for h in f.hunks:
                h.write(w, True)
            w(b'\n')

    def __neg__(self):
        """Return the inverse patch: every file and hunk with sides swapped."""
        return Diff._sort_files((*(FileDiff(
            src_name = f.dst_name,
            dst_name = f.src_name,
            src_mode = f.dst_mode,
            dst_mode = f.src_mode,
            hunks = (*(Hunk(
                lhs_line = h.rhs_line,
                rhs_line = h.lhs_line,
                pre_context = h.pre_context,
                lhs = h.rhs,
                rhs = h.lhs,
                post_context = h.post_context,
            ) for h in f.hunks),),
        ) for f in self.files),))

    def __add__(a, b):
        """Compose two patches: the result of applying ``a`` and then ``b``.

        Raises TypeError if ``b`` is of a different type and ValueError when
        the two patches do not fit together.
        """
        tp = type(a)
        if type(b) is not tp:
            raise TypeError
        r = []
        w = r.append
        fdiff = FileDiff

        def join(fa, fb):
            # Merge the change fa (from a) with the change fb (from b) that
            # continues on the same file.
            src_name = fa.src_name
            dst_name = fb.dst_name
            src_mode = fa.src_mode
            dst_mode = fb.dst_mode
            # Ellipsis means "mode unchanged": take the other side's view.
            if src_mode is ...:
                if dst_mode is not ...:
                    src_mode = fb.src_mode
            elif dst_mode is ...:
                dst_mode = fa.dst_mode
            elif fa.dst_mode != fb.src_mode:
                raise ValueError('mismatch')
            hunks = (*join_hunks(fa.hunks, fb.hunks),)
            if src_name is None and dst_name is None:
                # Created by a and deleted again by b: nothing remains.
                if hunks:
                    raise ValueError('mismatch')
                return
            if src_name == dst_name and src_mode == dst_mode and not hunks:
                return
            w(fdiff(
                src_name = src_name,
                dst_name = dst_name,
                src_mode = src_mode,
                dst_mode = dst_mode,
                hunks = hunks,
            ))

        _pair_diffs(a, b, on_lhs=w, on_rhs=w, on_pair=join, on_re=join)
        return tp._sort_files(r)

    @classmethod
    def commute(tp, a, b):
        """Reorder two consecutive patches.

        ``a`` is applied before ``b``.  Returns ``(b2, a2)`` built from the
        per-file results of commute_hunks.

        Raises TypeError for arguments of another type, ValueError when
        names or modes cannot be commuted, and PatchError (with the three
        file names as its arguments) when hunks conflict.
        """
        if type(a) is not tp or type(b) is not tp:
            raise TypeError
        ra = []
        rb = []
        wa = ra.append
        wb = rb.append
        fdiff = FileDiff

        def join(fa, fb):
            # name1 -> name2 -> name3 is the file's name history; at most one
            # of the two steps may be a rename, which then changes sides.
            name1 = fa.src_name
            name2 = fa.dst_name
            name3 = fb.dst_name
            if name1 == name2:
                name2 = name3
            elif name2 == name3:
                name2 = name1
            else:
                raise ValueError(f'cannot commute {name2!r}: {fa.op_kind_name} / {fb.op_kind_name}')
            mode1 = fa.src_mode
            mode3 = fb.dst_mode
            if mode1 is ...:
                if mode3 is ...:
                    mode2 = ...
                else:
                    mode1 = fb.src_mode
                    mode2 = mode3
            elif mode3 is ...:
                mode3 = fa.dst_mode
                mode2 = mode1
            else:
                # NOTE(review): this takes mode2 from fb.dst_mode and compares
                # it with fa.src_mode, whereas __add__ checks fa.dst_mode
                # against fb.src_mode - confirm which pairing is intended.
                mode2 = fb.dst_mode
                if fa.src_mode != mode2:
                    raise ValueError('mismatch')
                if mode1 == mode2:
                    mode2 = mode3
                elif mode2 == mode3:
                    mode2 = mode1
                else:
                    raise ValueError(f'cannot commute {name2!r}: chmod/chmod conflict')
            try:
                hb, ha = commute_hunks(fa.hunks, fb.hunks)
            except PatchError as err:
                err.args = (name1, name2, name3)
                raise
            wa(fdiff(
                src_name = name1,
                dst_name = name2,
                src_mode = mode1,
                dst_mode = mode2,
                hunks = ha,
            ))
            wb(fdiff(
                src_name = name2,
                dst_name = name3,
                src_mode = mode2,
                dst_mode = mode3,
                hunks = hb,
            ))

        _pair_diffs(a, b, on_lhs=wa, on_rhs=wb, on_re=join, on_pair=join)
        return tp._sort_files(rb), tp._sort_files(ra)
def _pair_diffs(a, b, *, on_lhs, on_rhs, on_pair, on_re):
a_del = {}
b_new = {}
a_dst = {}
b_dst = set()
for f in a.files:
dst = f.dst_name
if dst is None:
a_del[f.src_name] = f
continue
a_dst[dst] = f
a = a_dst.pop
for f in b.files:
dst = f.dst_name
if dst is not None:
b_dst.add(dst)
src = f.src_name
if src is None:
b_new[dst] = f
continue
try:
f2 = a(src)
except KeyError:
pass
else:
on_pair(f2, f)
continue
on_rhs(f)
for dst in a_dst.keys() & b_dst:
raise ValueError(f'mismatch for {dst!r}')
b = b_new.pop
for src,f in a_del.items():
try:
f2 = b(src)
except KeyError:
pass
else:
on_re(f, f2)
continue
on_lhs(f)
for f in b_new.values():
on_rhs(f)
for f in a_dst.values():
on_lhs(f)