summaryrefslogtreecommitdiff
path: root/patchtheory.py
diff options
context:
space:
mode:
Diffstat (limited to 'patchtheory.py')
-rw-r--r--patchtheory.py960
1 files changed, 960 insertions, 0 deletions
diff --git a/patchtheory.py b/patchtheory.py
new file mode 100644
index 0000000..21b3e29
--- /dev/null
+++ b/patchtheory.py
@@ -0,0 +1,960 @@
+import re, operator, io, sys
+
+rem_range = re.compile(b'@@ -([0-9]+)(?:,([0-9]+))? \\+([0-9]+)(?:,([0-9]+))? @@(?: .*)?').fullmatch
+rem_diffgit = re.compile('diff --git a/([^ ]*) b/([^ ]*)').fullmatch
+
+INFINITY = 2**64
+
+class PatchError(Exception):
+ __slots__ = ('a', 'b')
+ def __init__(self, a, b, /, fn=None):
+ super().__init__(fn)
+ self.a = a
+ self.b = b
+
+ def __str__(self, /):
+ fn = self.args[0]
+ r = [f' in {fn!r}:' if fn is not None else ':']
+
+ def tabmangle(p):
+ o = io.BytesIO()
+ p.write(o.write)
+ p = o.getvalue().decode('utf-8', 'replace').splitlines()
+ return [' ' + l for l in p]
+
+ r.append('')
+ r += tabmangle(self.a)
+
+ r.append('')
+ r += tabmangle(self.b)
+
+ return '\n'.join(r)
+
+class ConflictError(PatchError):
+ __slots__ = ()
+
+ def __str__(self, /):
+ return 'Conflict' + PatchError.__str__(self)
+
+class MismatchError(PatchError):
+ __slots__ = ()
+
+ def __str__(self, /):
+ return 'Mismatch' + PatchError.__str__(self)
+
+class Hunk:
+ __slots__ = ('lhs_line', 'rhs_line', 'pre_context', 'lhs', 'rhs', 'post_context')
+
+ def __init__(self, /, *, lhs_line, rhs_line, pre_context, lhs, rhs, post_context):
+ assert lhs_line > 0
+ assert rhs_line > 0
+
+ self.lhs_line = lhs_line
+ self.rhs_line = rhs_line
+ self.pre_context = pre_context
+ self.lhs = lhs
+ self.rhs = rhs
+ self.post_context = post_context
+
+ @property
+ def lhs_size(self, /):
+ return len(self.pre_context) + len(self.lhs) + len(self.post_context)
+
+ @property
+ def rhs_size(self, /):
+ return len(self.pre_context) + len(self.rhs) + len(self.post_context)
+
+ def write(self, w, /, munge=False, *, lhs_line=None, rhs_line=None, lhs_nl=True, rhs_nl=True):
+ if lhs_line is None:
+ lhs_line = self.lhs_line
+ if rhs_line is None:
+ rhs_line = self.rhs_line
+ pre_context = self.pre_context
+ lhs = self.lhs
+ rhs = self.rhs
+ post_context = self.post_context
+ clen = len(pre_context) + len(post_context)
+ llen = len(lhs) + clen
+ rlen = len(rhs) + clen
+
+ if munge:
+ w(f'@@ -xxx,{llen} +yyy,{rlen} @@'.encode('utf-8'))
+ else:
+ w(f'@@ -{lhs_line},{llen} +{rhs_line},{rlen} @@'.encode('utf-8'))
+
+ for l in pre_context:
+ w(b'\n ')
+ w(l)
+
+ for l in lhs:
+ w(b'\n-')
+ w(l)
+
+ for l in rhs:
+ w(b'\n+')
+ w(l)
+
+ for l in post_context:
+ w(b'\n ')
+ w(l)
+
+ w(b'\n')
+
+ @classmethod
+ def parse_l(tp, l, line, /):
+ m = rem_range(l)
+ if m is None:
+ raise ValueError(f'Invalid patch: bad range info {l!r}')
+ lhs_line,n2,rhs_line,n4 = m.groups()
+ lhs_line = 1 if lhs_line == b'0' else int(lhs_line)
+ rhs_line = 1 if rhs_line == b'0' else int(rhs_line)
+ n2 = int(n2) if n2 else 1
+ n4 = int(n4) if n4 else 1
+
+ hunk = []
+ prev = -1
+ lhs_nl = True
+ rhs_nl = True
+ post_nl = True
+ while True:
+ l = line()
+ if l.startswith(b'\\ '):
+ if prev == 0:
+ lhs_nl = False
+ elif prev == 1:
+ rhs_nl = False
+ elif prev == -1:
+ post_nl = False
+ else:
+ raise ValueError('Invalid patch: bad "No newline at end of file" flag')
+ prev = -2
+ continue
+ if not n2 and not n4:
+ break
+ k = l[0]
+ if k == b' '[0]:
+ if not post_nl or not lhs_nl or not rhs_nl:
+ raise ValueError('Invalid patch: bad "No newline at end of file" flag')
+ if not n2 or not n4:
+ raise ValueError('Invalid patch: bad line count')
+ n2 -= 1
+ n4 -= 1
+ prev = -1
+ elif k == b'-'[0]:
+ if not post_nl or not lhs_nl:
+ raise ValueError('Invalid patch: bad "No newline at end of file" flag')
+ if not n2:
+ raise ValueError('Invalid patch: bad line count')
+ n2 -= 1
+ prev = 0
+ elif k == b'+'[0]:
+ if not post_nl or not rhs_nl:
+ raise ValueError('Invalid patch: bad "No newline at end of file" flag')
+ if not n4:
+ raise ValueError('Invalid patch: bad line count')
+ n4 -= 1
+ prev = 1
+ else:
+ raise ValueError(f'Invalid patch: bad prefix character {chr(k)!r}')
+ assert False
+ hunk.append(l)
+ lret = l
+
+ strip1 = operator.itemgetter(slice(1,None))
+
+ for i,l in enumerate(hunk):
+ if l[0] != b' '[0]:
+ break
+ else:
+ raise ValueError('Invalid patch: no changed lines in hunk')
+ pre_context = (*map(strip1, hunk[:i]),)
+ del hunk[:i]
+
+ j = -1
+ while True:
+ l = hunk[j]
+ if l[0] != b' '[0]:
+ break
+ j -= 1
+ j += len(hunk) + 1
+ if not post_nl:
+ hunk[j] += b'\n\\ No newline at end of file'
+ post_context = (*map(strip1, hunk[j:]),)
+ del hunk[j:]
+
+ lhs = []
+ rhs = []
+ for l in hunk:
+ l0 = l[0]
+ if l0 != b'+'[0]:
+ lhs.append(l)
+ if l0 != b'-'[0]:
+ rhs.append(l)
+ if not lhs_nl:
+ lhs[-1] += b'\n\\ No newline at end of file'
+ if not rhs_nl:
+ rhs[-1] += b'\n\\ No newline at end of file'
+
+ return lret, tp(
+ lhs_line = lhs_line,
+ rhs_line = rhs_line,
+ pre_context = pre_context,
+ lhs = (*map(strip1, lhs),),
+ rhs = (*map(strip1, rhs),),
+ post_context = post_context,
+ )
+
+ def reduce_context(self, n=3):
+ pre = self.pre_context
+ lhs = self.lhs
+ rhs = self.rhs
+ post = self.post_context
+
+ i = 0
+ j = -1
+ m = min(len(lhs), len(rhs))
+ while i - j <= m:
+ if lhs[i] == rhs[i]:
+ i += 1
+ if i - j > m:
+ break
+ if lhs[j] == rhs[j]:
+ j -= 1
+ else:
+ if lhs[j] != rhs[j]:
+ break
+ j -= 1
+ j += 1
+ jl = len(lhs) + j
+ jr = len(rhs) + j
+
+ assert lhs[:i] == rhs[:i]
+ assert lhs[jl:] == rhs[jr:]
+ pre = pre + lhs[:i]
+ post = lhs[jl:] + post
+ lhs = lhs[i:jl]
+ rhs = rhs[i:jr]
+
+ if not lhs and not rhs:
+ return None
+
+ assert lhs != rhs
+ assert not lhs or not rhs or lhs[0] != rhs[0] and lhs[-1] != rhs[-1]
+
+ lhs_line = self.lhs_line
+ rhs_line = self.rhs_line
+
+ if n is None:
+ pass
+ elif not n:
+ d = len(pre)
+ lhs_line += d
+ rhs_line += d
+ pre = ()
+ post = ()
+ else:
+ if len(pre) > n:
+ d = len(pre) - n
+ lhs_line += d
+ rhs_line += d
+ pre = pre[d:]
+
+ if len(post) > n:
+ post = post[:n]
+
+ return Hunk(
+ lhs_line = lhs_line,
+ rhs_line = rhs_line,
+ pre_context = pre,
+ lhs = lhs,
+ rhs = rhs,
+ post_context = post,
+ )
+
+def join_hunks(hs_a, hs_b):
+ hs_a = iter(hs_a)
+ hs_b = iter(hs_b)
+
+ need_a = True
+ need_b = True
+ have_cur = False
+
+ def make_cur(n):
+ if not have_cur:
+ return None
+ return Hunk(
+ lhs_line = cur_lhs_line,
+ rhs_line = cur_rhs_line,
+ pre_context = (),
+ lhs = tuple(cur_lhs),
+ rhs = tuple(cur_rhs),
+ post_context = (),
+ ).reduce_context(n)
+
+ delta = 0
+ r = []
+
+ while True:
+ if need_a:
+ need_a = False
+ a = next(hs_a, None)
+
+ if need_b:
+ need_b = False
+ b = next(hs_b, None)
+
+ if not have_cur:
+ if a is not None and (b is None or a.rhs_line <= b.lhs_line):
+ cur_lhs_line = a.lhs_line
+ cur_rhs_line = a.lhs_line + delta
+ cur_lhs = [*a.pre_context, *a.lhs, *a.post_context]
+ cur_rhs = [*a.pre_context, *a.rhs, *a.post_context]
+ need_a = True
+ elif b is not None:
+ cur_lhs_line = b.rhs_line - delta
+ cur_rhs_line = b.rhs_line
+ cur_lhs = [*b.pre_context, *b.lhs, *b.post_context]
+ cur_rhs = [*b.pre_context, *b.rhs, *b.post_context]
+ need_b = True
+ else:
+ return (*r,)
+ have_cur = True
+ continue
+
+ cur_lhs_end = cur_lhs_line + len(cur_lhs)
+ cur_rhs_end = cur_rhs_line + len(cur_rhs)
+
+ if a is not None and a.lhs_line <= cur_lhs_end:
+ lhs = [*a.pre_context, *a.lhs, *a.post_context]
+ rhs = [*a.pre_context, *a.rhs, *a.post_context]
+
+ i = a.lhs_line - cur_lhs_line
+ k = i + len(rhs)
+
+ if len(cur_lhs) >= k:
+ if cur_lhs[i:k] != rhs:
+ raise MismatchError(make_cur(None), a)
+ cur_lhs[i:k] = lhs
+ else:
+ j = len(cur_lhs) - i
+ if cur_lhs[i:] != rhs[:j]:
+ raise MismatchError(make_cur(None), a)
+ cur_lhs[i:] = lhs
+ cur_rhs.extend(rhs[j:])
+ need_a = True
+ continue
+
+ if b is not None and b.rhs_line <= cur_rhs_end:
+ lhs = [*b.pre_context, *b.lhs, *b.post_context]
+ rhs = [*b.pre_context, *b.rhs, *b.post_context]
+ i = b.rhs_line - cur_rhs_line
+ k = i + len(lhs)
+
+ if len(cur_rhs) >= k:
+ if cur_rhs[i:k] != lhs:
+ raise MismatchError(make_cur(None), b)
+ cur_rhs[i:k] = rhs
+ else:
+ j = len(cur_rhs) - i
+ if cur_rhs[i:] != lhs[:j]:
+ raise MismatchError(make_cur(None), b)
+ cur_rhs[i:] = rhs
+ cur_lhs.extend(lhs[j:])
+ need_b = True
+ continue
+
+ cur = make_cur(3)
+ have_cur = False
+ if cur is None:
+ continue
+ r.append(cur)
+ delta += len(cur.rhs) - len(cur.lhs)
+
+def commute_hunks(hs_a, hs_b):
+ hs_a = iter(hs_a)
+ hs_b = iter(hs_b)
+
+ need_a = True
+ need_b = True
+
+ delta_a = 0
+ delta_b = 0
+ ra = []
+ rb = []
+ wa = ra.append
+ wb = rb.append
+
+ while True:
+ if need_a:
+ need_a = False
+ try:
+ a = next(hs_a)
+ except StopIteration:
+ if b is None and not need_b:
+ break
+ a = None
+ a_begin = INFINITY
+ a_pre = ()
+ a_post = ()
+ a_end = INFINITY
+ else:
+ a_pre = a.pre_context
+ a_post = a.post_context
+ a_begin = a.rhs_line + len(a_pre)
+ a_end = a_begin + len(a.rhs)
+
+ if need_b:
+ need_b = False
+ try:
+ b = next(hs_b)
+ except StopIteration:
+ if a is None:
+ break
+ b = None
+ b_begin = INFINITY
+ b_pre = ()
+ b_post = ()
+ b_end = INFINITY
+ else:
+ b_pre = b.pre_context
+ b_post = b.post_context
+ b_begin = b.lhs_line + len(b_pre)
+ b_end = b_begin + len(b.lhs)
+
+ if a_end <= b_begin:
+ gap = b_begin - a_end
+
+ if (n := len(a_post) + len(b_pre) - gap) > 0:
+ assert (a_pre + a.rhs + a_post)[-n:] == (b_pre + b.lhs + b_post)[:n]
+ if (n := len(a_post) - gap) > 0:
+ a_post = a_post[:gap] + (b.rhs + b_post)[:n]
+ if (n := len(b_pre) - gap) > 0:
+ b_pre = (a_pre + a.lhs)[-n:] + b_pre[n:]
+
+ wa(Hunk(
+ lhs_line = a.lhs_line + delta_b,
+ rhs_line = a.rhs_line + delta_b,
+ pre_context = a_pre,
+ lhs = a.lhs,
+ rhs = a.rhs,
+ post_context = a_post,
+ ))
+ delta_a += len(a.rhs) - len(a.lhs)
+ need_a = True
+ continue
+
+ if b_end <= a_begin:
+ gap = a_begin - b_end
+
+ if (n := len(b_post) + len(a_pre) - gap) > 0:
+ assert (b_pre + b.lhs + b_post)[-n:] == (a_pre + a.rhs + a_post)[:n]
+ if (n := len(b_post) - gap) > 0:
+ b_post = b_post[:gap] + (a.lhs + a_post)[:n]
+ if (n := len(a_pre) - gap) > 0:
+ a_pre = (b_pre + b.rhs)[-n:] + a_pre[n:]
+
+ wb(Hunk(
+ lhs_line = b.lhs_line - delta_a,
+ rhs_line = b.rhs_line - delta_a,
+ pre_context = b_pre,
+ lhs = b.lhs,
+ rhs = b.rhs,
+ post_context = b_post,
+ ))
+ delta_b += len(b.rhs) - len(b.lhs)
+ need_b = True
+ continue
+
+ raise ConflictError(a, b)
+
+ return (*rb,), (*ra,)
+
+class FileDiff:
+ __slots__ = ('src_name', 'dst_name', 'src_mode', 'dst_mode', 'hunks',)
+
+ def __init__(self, /, *, src_name, dst_name, src_mode, dst_mode, hunks):
+ assert (src_mode is ...) == (dst_mode is ...)
+ assert (src_name is None) == (src_mode is None)
+ assert (dst_name is None) == (dst_mode is None)
+
+ self.src_name = src_name
+ self.dst_name = dst_name
+ if src_mode == dst_mode:
+ self.src_mode = ...
+ self.dst_mode = ...
+ else:
+ self.src_mode = src_mode
+ self.dst_mode = dst_mode
+ self.hunks = hunks
+
+ @property
+ def op_kind_name(self):
+ if self.src_name is None:
+ return 'create'
+ if self.dst_name is None:
+ return 'delete'
+ is_chmod = self.dst_mode != self.dst_mode
+ if self.dst_name != self.dst_name:
+ return 'chmod+rename' if is_chmod else 'rename'
+ if self.hunks:
+ return 'chmod+modify' if is_chmod else 'modify'
+ return 'chmod' if is_chmod else 'noop'
+
+ def write(self, w):
+ src_name = self.src_name
+ dst_name = self.dst_name
+
+ if src_name is None:
+ assert dst_name is not None
+ w(f'diff --git a/{dst_name} b/{dst_name}\n'.encode('utf-8'))
+ elif dst_name is None:
+ w(f'diff --git a/{src_name} b/{src_name}\n'.encode('utf-8'))
+ else:
+ w(f'diff --git a/{src_name} b/{dst_name}\n'.encode('utf-8'))
+
+ if src_name is not None and dst_name is not None and src_name != dst_name:
+ w(f'rename from {src_name}\nrename to {dst_name}\n'.encode('utf-8'))
+
+ m = self.src_mode
+ if dst_name is None:
+ assert m is not None
+ w(f'deleted file mode {m}\n'.encode('utf-8'))
+ elif m not in (None, ...):
+ w(f'old mode {m}\n'.encode('utf-8'))
+
+ m = self.dst_mode
+ if src_name is None:
+ assert m is not None
+ w(f'new file mode {m}\n'.encode('utf-8'))
+ elif m not in (None, ...):
+ w(f'new mode {m}\n'.encode('utf-8'))
+
+ h = self.hunks
+ if h:
+ if src_name is None:
+ src_name = '/dev/null'
+ lhs_line = 0
+ else:
+ src_name = f'a/{src_name}'
+ lhs_line = None
+ if dst_name is None:
+ dst_name = '/dev/null'
+ rhs_line = 0
+ else:
+ dst_name = f'b/{dst_name}'
+ rhs_line = None
+
+ w(f'--- {src_name}\n+++ {dst_name}\n'.encode('utf-8'))
+
+ for h in h:
+ h.write(w, lhs_line=lhs_line, rhs_line=rhs_line)
+
+ def __neg__(self, /):
+ tp = type(self)
+
+class Diff:
+ __slots__ = ('files',)
+
+ def __init__(self):
+ self.files = ()
+
+ def __bool__(self):
+ return not not self.files
+
+ @classmethod
+ def _sort_files(tp, files):
+ r = object.__new__(tp)
+ r.files = (*sorted(files, key=lambda f: (f.src_name or '', f.dst_name or '')),)
+ return r
+
+ @classmethod
+ def parse_l(tp, l, line):
+ files = []
+
+ while True:
+ l = l.decode('utf-8').rstrip()
+ if l == '--' or not l:
+ break
+ m = rem_diffgit(l)
+ if m is None:
+ raise ValueError(f'Invalid patch: expected \'diff --git ...\' header, got {l!r}')
+
+ header = l
+
+ src_name,dst_name = m.groups()
+ src_exists = ...
+ dst_exists = ...
+
+ if not src_exists and not dst_exists:
+ raise RuntimeError('wtf')
+ if (not src_exists or not dst_exists) and src_name != dst_name:
+ raise RuntimeError('wtf')
+
+ has_diff = False
+ has_chmod = False
+ has_rename = False
+ dst_mode = ...
+ src_mode = ...
+
+ while True:
+ l = line()
+ if l.startswith(b'--- a/'):
+ if src_exists is False: raise RuntimeError('wtf')
+ src_exists = True
+ if l[6:].decode('utf-8') != src_name:
+ raise RuntimeError('wtf')
+ has_diff = True
+ continue
+ if l == b'--- /dev/null':
+ if src_exists is True: raise RuntimeError('wtf')
+ src_exists = False
+ has_diff = True
+ continue
+ if l.startswith(b'+++ b/'):
+ if dst_exists is False: raise RuntimeError('wtf')
+ dst_exists = True
+ if l[6:].decode('utf-8') != dst_name:
+ raise RuntimeError('wtf')
+ has_diff = True
+ continue
+ if l == b'+++ /dev/null':
+ if dst_exists is True: raise RuntimeError('wtf')
+ dst_exists = False
+ has_diff = True
+ continue
+ if l.startswith(b'index '):
+ continue
+ if l.startswith(b'new file mode '):
+ if src_exists is True: raise RuntimeError('wtf')
+ if dst_exists is False: raise RuntimeError('wtf')
+ src_exists = False
+ dst_exists=True
+ dst_mode = l[14:].decode('ascii')
+ src_mode = None
+ continue
+ if l.startswith(b'new mode '):
+ dst_mode = l[9:].decode('ascii')
+ has_chmod = True
+ continue
+ if l.startswith(b'deleted file mode '):
+ if src_exists is False: raise RuntimeError('wtf')
+ if dst_exists is True: raise RuntimeError('wtf')
+ src_exists = True
+ dst_exists = False
+ src_mode = l[18:].decode('ascii')
+ dst_mode = None
+ continue
+ if l.startswith(b'old mode '):
+ if dst_exists is False:
+ raise RuntimeError('File deleted')
+ src_mode = l[9:].decode('ascii')
+ has_chmod = True
+ dst_exists = True
+ continue
+ if l.startswith(b'similarity index '):
+ has_rename = True
+ continue
+ if l.startswith(b'rename from '):
+ has_rename = True
+ rename_from = l[12:].decode('utf-8')
+ continue
+ if l.startswith(b'rename to '):
+ rename_to = l[10:].decode('utf-8')
+ has_rename = True
+ continue
+ break
+
+ if has_chmod or has_rename:
+ if src_exists is False: raise RuntimeError('wtf')
+ if dst_exists is False: raise RuntimeError('wtf')
+ src_exists = True
+ dst_exists = True
+
+ if has_chmod:
+ assert src_mode is not ...
+ assert dst_mode is not ...
+
+ assert src_exists is not ...
+ if not src_exists:
+ assert dst_mode is not ...
+
+ assert dst_exists is not ...
+ if not dst_exists:
+ assert src_mode is not ...
+
+ if has_rename:
+ if src_name == dst_name or not src_exists or not dst_exists:
+ raise RuntimeError(f'Bad rename: {src_name!r} -> {dst_name!r}')
+ assert rename_from == src_name
+ assert rename_to == dst_name
+
+ if not src_exists: src_name = None
+ if not dst_exists: dst_name = None
+
+ assert (src_mode is ...) == (dst_mode is ...)
+ assert has_diff or has_chmod or has_rename or not dst_exists
+
+ #want_header = fmt_diff_git(src_name, dst_name)
+ #if header != want_header:
+ # raise ValueError(f'Invalid patch: expected {want_header!r} header, got {header!r}')
+
+ if has_diff:
+ hunks = []
+
+ while True:
+ l,hunk = Hunk.parse_l(l, line)
+ hunk = hunk.reduce_context()
+ if hunk is not None:
+ hunks.append(hunk)
+ if not l.startswith(b'@@ '):
+ break
+
+ if src_name is None:
+ assert len(hunks) == 1
+ assert len(hunks[0].pre_context) == 0
+ assert len(hunks[0].post_context) == 0
+ assert len(hunks[0].lhs) == 0
+ assert hunks[0].lhs_line == 1
+ hunks[0].lhs_line = 1
+
+ if dst_name is None:
+ assert len(hunks) == 1
+ assert len(hunks[0].pre_context) == 0
+ assert len(hunks[0].post_context) == 0
+ assert len(hunks[0].rhs) == 0
+ assert hunks[0].rhs_line == 1
+
+ hunks = (*hunks,)
+
+ else:
+ hunks = ()
+
+ files.append(FileDiff(
+ src_name = src_name,
+ dst_name = dst_name,
+ src_mode = src_mode,
+ dst_mode = dst_mode,
+ hunks = hunks,
+ ))
+
+ try:
+ l = line().rstrip()
+ except StopIteration:
+ pass
+ else:
+ if l[0] not in b'0123456789':
+ l = l.decode('utf-8')
+ raise ValueError('Invalid patch: expected git version, got {l!r}')
+
+
+ while True:
+ try:
+ l = line()
+ except StopIteration:
+ break
+ if l.strip():
+ l = l.decode('utf-8')
+ raise ValueError('Invalid patch: got trailing garbage {l!r}')
+
+ return tp._sort_files(files)
+
+ @classmethod
+ def parse(tp, body):
+ line = iter(body.split(b'\n')).__next__
+ return tp.parse_l(line(), line)
+
+ def write(self, w, /):
+ for f in self.files:
+ f.write(w)
+
+ def write_munged(self, w, /):
+ for f in self.files:
+ w((f.src_name or '').encode('utf-8'))
+ w(b'\n')
+ w((f.dst_name or '').encode('utf-8'))
+ w(b'\n')
+ for h in f.hunks:
+ h.write(w, True)
+ w(b'\n')
+
+ def __neg__(self):
+ return Diff._sort_files((*(FileDiff(
+ src_name = f.dst_name,
+ dst_name = f.src_name,
+ src_mode = f.dst_mode,
+ dst_mode = f.src_mode,
+ hunks = (*(Hunk(
+ lhs_line = h.rhs_line,
+ rhs_line = h.lhs_line,
+ pre_context = h.pre_context,
+ lhs = h.rhs,
+ rhs = h.lhs,
+ post_context = h.post_context,
+ ) for h in f.hunks),),
+ ) for f in self.files),))
+
+ def __add__(a, b):
+ tp = type(a)
+ if type(b) is not tp:
+ raise TypeError
+
+ r = []
+ w = r.append
+
+ fdiff = FileDiff
+ def join(fa, fb):
+ src_name = fa.src_name
+ dst_name = fb.dst_name
+ src_mode = fa.src_mode
+ dst_mode = fb.dst_mode
+ if src_mode is ...:
+ if dst_mode is not ...:
+ src_mode = fb.src_mode
+ elif dst_mode is ...:
+ dst_mode = fa.dst_mode
+ elif fa.dst_mode != fb.src_mode:
+ raise ValueError('mismatch')
+
+ hunks = (*join_hunks(fa.hunks, fb.hunks),)
+
+ if src_name is None and dst_name is None:
+ if hunks:
+ raise ValueError('mismatch')
+ return
+
+ if src_name == dst_name and src_mode == dst_mode and not hunks:
+ return
+
+ w(fdiff(
+ src_name = src_name,
+ dst_name = dst_name,
+ src_mode = src_mode,
+ dst_mode = dst_mode,
+ hunks = hunks,
+ ))
+
+ _pair_diffs(a, b, on_lhs=w, on_rhs=w, on_pair=join, on_re=join)
+ return tp._sort_files(r)
+
+ @classmethod
+ def commute(tp, a, b):
+ if type(a) is not tp or type(b) is not tp:
+ raise TypeError
+
+ ra = []
+ rb = []
+ wa = ra.append
+ wb = rb.append
+ fdiff = FileDiff
+
+ def join(fa, fb):
+ name1 = fa.src_name
+ name2 = fa.dst_name
+ name3 = fb.dst_name
+ if name1 == name2:
+ name2 = name3
+ elif name2 == name3:
+ name2 = name1
+ else:
+ raise ValueError('cannot commute {name2!r}: {fa.op_kind_name} / {fb.op_kind_name}')
+
+ mode1 = fa.src_mode
+ mode3 = fb.dst_mode
+ if mode1 is ...:
+ if mode3 is ...:
+ mode2 = ...
+ else:
+ mode1 = fb.src_mode
+ mode2 = mode3
+
+ elif mode3 is ...:
+ mode3 = fa.dst_mode
+ mode2 = mode1
+
+ else:
+ mode2 = fb.dst_mode
+ if fa.src_mode != mode2:
+ raise ValueError('mismatch')
+ if mode1 == mode2:
+ mode2 = mode3
+ elif mode2 == mode3:
+ mode2 = mode1
+ else:
+ raise ValueError('cannot commute {name2!r}: chmod/chmod conflict')
+
+ try:
+ hb,ha = commute_hunks(fa.hunks, fb.hunks)
+ except PatchError as err:
+ err.args = (name1, name2, name3)
+ raise
+
+ wa(fdiff(
+ src_name = name1,
+ dst_name = name2,
+ src_mode = mode1,
+ dst_mode = mode2,
+ hunks = ha,
+ ))
+ wb(fdiff(
+ src_name = name2,
+ dst_name = name3,
+ src_mode = mode2,
+ dst_mode = mode3,
+ hunks = hb,
+ ))
+
+ _pair_diffs(a, b, on_lhs=wa, on_rhs=wb, on_re=join, on_pair=join)
+ return tp._sort_files(rb), tp._sort_files(ra)
+
+def _pair_diffs(a, b, *, on_lhs, on_rhs, on_pair, on_re):
+ a_del = {}
+ b_new = {}
+
+ a_dst = {}
+ b_dst = set()
+ for f in a.files:
+ dst = f.dst_name
+ if dst is None:
+ a_del[f.src_name] = f
+ continue
+ a_dst[dst] = f
+
+ a = a_dst.pop
+ for f in b.files:
+ dst = f.dst_name
+ if dst is not None:
+ b_dst.add(dst)
+
+ src = f.src_name
+ if src is None:
+ b_new[dst] = f
+ continue
+
+ try:
+ f2 = a(src)
+ except KeyError:
+ pass
+ else:
+ on_pair(f2, f)
+ continue
+ on_rhs(f)
+
+ for dst in a_dst.keys() & b_dst:
+ raise ValueError(f'mismatch for {dst!r}')
+
+ b = b_new.pop
+ for src,f in a_del.items():
+ try:
+ f2 = b(src)
+ except KeyError:
+ pass
+ else:
+ on_re(f, f2)
+ continue
+ on_lhs(f)
+
+ for f in b_new.values():
+ on_rhs(f)
+
+ for f in a_dst.values():
+ on_lhs(f)