From a2b2fe7da78378dc6a64f58c14e907496c9adea9 Mon Sep 17 00:00:00 2001 From: Hristo Venev Date: Thu, 26 Aug 2021 17:30:35 +0300 Subject: Initial commit --- patchtheory.py | 960 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 960 insertions(+) create mode 100644 patchtheory.py (limited to 'patchtheory.py') diff --git a/patchtheory.py b/patchtheory.py new file mode 100644 index 0000000..21b3e29 --- /dev/null +++ b/patchtheory.py @@ -0,0 +1,960 @@ +import re, operator, io, sys + +rem_range = re.compile(b'@@ -([0-9]+)(?:,([0-9]+))? \\+([0-9]+)(?:,([0-9]+))? @@(?: .*)?').fullmatch +rem_diffgit = re.compile('diff --git a/([^ ]*) b/([^ ]*)').fullmatch + +INFINITY = 2**64 + +class PatchError(Exception): + __slots__ = ('a', 'b') + def __init__(self, a, b, /, fn=None): + super().__init__(fn) + self.a = a + self.b = b + + def __str__(self, /): + fn = self.args[0] + r = [f' in {fn!r}:' if fn is not None else ':'] + + def tabmangle(p): + o = io.BytesIO() + p.write(o.write) + p = o.getvalue().decode('utf-8', 'replace').splitlines() + return [' ' + l for l in p] + + r.append('') + r += tabmangle(self.a) + + r.append('') + r += tabmangle(self.b) + + return '\n'.join(r) + +class ConflictError(PatchError): + __slots__ = () + + def __str__(self, /): + return 'Conflict' + PatchError.__str__(self) + +class MismatchError(PatchError): + __slots__ = () + + def __str__(self, /): + return 'Mismatch' + PatchError.__str__(self) + +class Hunk: + __slots__ = ('lhs_line', 'rhs_line', 'pre_context', 'lhs', 'rhs', 'post_context') + + def __init__(self, /, *, lhs_line, rhs_line, pre_context, lhs, rhs, post_context): + assert lhs_line > 0 + assert rhs_line > 0 + + self.lhs_line = lhs_line + self.rhs_line = rhs_line + self.pre_context = pre_context + self.lhs = lhs + self.rhs = rhs + self.post_context = post_context + + @property + def lhs_size(self, /): + return len(self.pre_context) + len(self.lhs) + len(self.post_context) + + @property + def rhs_size(self, /): + return len(self.pre_context) + len(self.rhs) + len(self.post_context) + + def write(self, w, /, munge=False, *, lhs_line=None, rhs_line=None, lhs_nl=True, rhs_nl=True): + if lhs_line is None: + lhs_line = self.lhs_line + if rhs_line is None: + rhs_line = self.rhs_line + pre_context = self.pre_context + lhs = self.lhs + rhs = self.rhs + post_context = self.post_context + clen = len(pre_context) + len(post_context) + llen = len(lhs) + clen + rlen = len(rhs) + clen + + if munge: + w(f'@@ -xxx,{llen} +yyy,{rlen} @@'.encode('utf-8')) + else: + w(f'@@ -{lhs_line},{llen} +{rhs_line},{rlen} @@'.encode('utf-8')) + + for l in pre_context: + w(b'\n ') + w(l) + + for l in lhs: + w(b'\n-') + w(l) + + for l in rhs: + w(b'\n+') + w(l) + + for l in post_context: + w(b'\n ') + w(l) + + w(b'\n') + + @classmethod + def parse_l(tp, l, line, /): + m = rem_range(l) + if m is None: + raise ValueError(f'Invalid patch: bad range info {l!r}') + lhs_line,n2,rhs_line,n4 = m.groups() + lhs_line = 1 if lhs_line == b'0' else int(lhs_line) + rhs_line = 1 if rhs_line == b'0' else int(rhs_line) + n2 = int(n2) if n2 else 1 + n4 = int(n4) if n4 else 1 + + hunk = [] + prev = -1 + lhs_nl = True + rhs_nl = True + post_nl = True + while True: + l = line() + if l.startswith(b'\\ '): + if prev == 0: + lhs_nl = False + elif prev == 1: + rhs_nl = False + elif prev == -1: + post_nl = False + else: + raise ValueError('Invalid patch: bad "No newline at end of file" flag') + prev = -2 + continue + if not n2 and not n4: + break + k = l[0] + if k == b' '[0]: + if not post_nl or not lhs_nl or not rhs_nl: + raise ValueError('Invalid patch: bad "No newline at end of file" flag') + if not n2 or not n4: + raise ValueError('Invalid patch: bad line count') + n2 -= 1 + n4 -= 1 + prev = -1 + elif k == b'-'[0]: + if not post_nl or not lhs_nl: + raise ValueError('Invalid patch: bad "No newline at end of file" flag') + if not n2: + raise ValueError('Invalid patch: bad line count') + n2 -= 1 + prev = 0 + elif k == b'+'[0]: + if not post_nl or not rhs_nl: + raise ValueError('Invalid patch: bad "No newline at end of file" flag') + if not n4: + raise ValueError('Invalid patch: bad line count') + n4 -= 1 + prev = 1 + else: + raise ValueError(f'Invalid patch: bad prefix character {chr(k)!r}') + assert False + hunk.append(l) + lret = l + + strip1 = operator.itemgetter(slice(1,None)) + + for i,l in enumerate(hunk): + if l[0] != b' '[0]: + break + else: + raise ValueError('Invalid patch: no changed lines in hunk') + pre_context = (*map(strip1, hunk[:i]),) + del hunk[:i] + + j = -1 + while True: + l = hunk[j] + if l[0] != b' '[0]: + break + j -= 1 + j += len(hunk) + 1 + if not post_nl: + hunk[j] += b'\n\\ No newline at end of file' + post_context = (*map(strip1, hunk[j:]),) + del hunk[j:] + + lhs = [] + rhs = [] + for l in hunk: + l0 = l[0] + if l0 != b'+'[0]: + lhs.append(l) + if l0 != b'-'[0]: + rhs.append(l) + if not lhs_nl: + lhs[-1] += b'\n\\ No newline at end of file' + if not rhs_nl: + rhs[-1] += b'\n\\ No newline at end of file' + + return lret, tp( + lhs_line = lhs_line, + rhs_line = rhs_line, + pre_context = pre_context, + lhs = (*map(strip1, lhs),), + rhs = (*map(strip1, rhs),), + post_context = post_context, + ) + + def reduce_context(self, n=3): + pre = self.pre_context + lhs = self.lhs + rhs = self.rhs + post = self.post_context + + i = 0 + j = -1 + m = min(len(lhs), len(rhs)) + while i - j <= m: + if lhs[i] == rhs[i]: + i += 1 + if i - j > m: + break + if lhs[j] == rhs[j]: + j -= 1 + else: + if lhs[j] != rhs[j]: + break + j -= 1 + j += 1 + jl = len(lhs) + j + jr = len(rhs) + j + + assert lhs[:i] == rhs[:i] + assert lhs[jl:] == rhs[jr:] + pre = pre + lhs[:i] + post = lhs[jl:] + post + lhs = lhs[i:jl] + rhs = rhs[i:jr] + + if not lhs and not rhs: + return None + + assert lhs != rhs + assert not lhs or not rhs or lhs[0] != rhs[0] and lhs[-1] != rhs[-1] + + lhs_line = self.lhs_line + rhs_line = self.rhs_line + + if n is None: + pass + elif not n: + d = len(pre) + lhs_line += d + rhs_line += d + pre = () + post = () + else: + if len(pre) > n: + d = len(pre) - n + lhs_line += d + rhs_line += d + pre = pre[d:] + + if len(post) > n: + post = post[:n] + + return Hunk( + lhs_line = lhs_line, + rhs_line = rhs_line, + pre_context = pre, + lhs = lhs, + rhs = rhs, + post_context = post, + ) + +def join_hunks(hs_a, hs_b): + hs_a = iter(hs_a) + hs_b = iter(hs_b) + + need_a = True + need_b = True + have_cur = False + + def make_cur(n): + if not have_cur: + return None + return Hunk( + lhs_line = cur_lhs_line, + rhs_line = cur_rhs_line, + pre_context = (), + lhs = tuple(cur_lhs), + rhs = tuple(cur_rhs), + post_context = (), + ).reduce_context(n) + + delta = 0 + r = [] + + while True: + if need_a: + need_a = False + a = next(hs_a, None) + + if need_b: + need_b = False + b = next(hs_b, None) + + if not have_cur: + if a is not None and (b is None or a.rhs_line <= b.lhs_line): + cur_lhs_line = a.lhs_line + cur_rhs_line = a.lhs_line + delta + cur_lhs = [*a.pre_context, *a.lhs, *a.post_context] + cur_rhs = [*a.pre_context, *a.rhs, *a.post_context] + need_a = True + elif b is not None: + cur_lhs_line = b.rhs_line - delta + cur_rhs_line = b.rhs_line + cur_lhs = [*b.pre_context, *b.lhs, *b.post_context] + cur_rhs = [*b.pre_context, *b.rhs, *b.post_context] + need_b = True + else: + return (*r,) + have_cur = True + continue + + cur_lhs_end = cur_lhs_line + len(cur_lhs) + cur_rhs_end = cur_rhs_line + len(cur_rhs) + + if a is not None and a.lhs_line <= cur_lhs_end: + lhs = [*a.pre_context, *a.lhs, *a.post_context] + rhs = [*a.pre_context, *a.rhs, *a.post_context] + + i = a.lhs_line - cur_lhs_line + k = i + len(rhs) + + if len(cur_lhs) >= k: + if cur_lhs[i:k] != rhs: + raise MismatchError(make_cur(None), a) + cur_lhs[i:k] = lhs + else: + j = len(cur_lhs) - i + if cur_lhs[i:] != rhs[:j]: + raise MismatchError(make_cur(None), a) + cur_lhs[i:] = lhs + cur_rhs.extend(rhs[j:]) + need_a = True + continue + + if b is not None and b.rhs_line <= cur_rhs_end: + lhs = [*b.pre_context, *b.lhs, *b.post_context] + rhs = [*b.pre_context, *b.rhs, *b.post_context] + i = b.rhs_line - cur_rhs_line + k = i + len(lhs) + + if len(cur_rhs) >= k: + if cur_rhs[i:k] != lhs: + raise MismatchError(make_cur(None), b) + cur_rhs[i:k] = rhs + else: + j = len(cur_rhs) - i + if cur_rhs[i:] != lhs[:j]: + raise MismatchError(make_cur(None), b) + cur_rhs[i:] = rhs + cur_lhs.extend(lhs[j:]) + need_b = True + continue + + cur = make_cur(3) + have_cur = False + if cur is None: + continue + r.append(cur) + delta += len(cur.rhs) - len(cur.lhs) + +def commute_hunks(hs_a, hs_b): + hs_a = iter(hs_a) + hs_b = iter(hs_b) + + need_a = True + need_b = True + + delta_a = 0 + delta_b = 0 + ra = [] + rb = [] + wa = ra.append + wb = rb.append + + while True: + if need_a: + need_a = False + try: + a = next(hs_a) + except StopIteration: + if b is None and not need_b: + break + a = None + a_begin = INFINITY + a_pre = () + a_post = () + a_end = INFINITY + else: + a_pre = a.pre_context + a_post = a.post_context + a_begin = a.rhs_line + len(a_pre) + a_end = a_begin + len(a.rhs) + + if need_b: + need_b = False + try: + b = next(hs_b) + except StopIteration: + if a is None: + break + b = None + b_begin = INFINITY + b_pre = () + b_post = () + b_end = INFINITY + else: + b_pre = b.pre_context + b_post = b.post_context + b_begin = b.lhs_line + len(b_pre) + b_end = b_begin + len(b.lhs) + + if a_end <= b_begin: + gap = b_begin - a_end + + if (n := len(a_post) + len(b_pre) - gap) > 0: + assert (a_pre + a.rhs + a_post)[-n:] == (b_pre + b.lhs + b_post)[:n] + if (n := len(a_post) - gap) > 0: + a_post = a_post[:gap] + (b.rhs + b_post)[:n] + if (n := len(b_pre) - gap) > 0: + b_pre = (a_pre + a.lhs)[-n:] + b_pre[n:] + + wa(Hunk( + lhs_line = a.lhs_line + delta_b, + rhs_line = a.rhs_line + delta_b, + pre_context = a_pre, + lhs = a.lhs, + rhs = a.rhs, + post_context = a_post, + )) + delta_a += len(a.rhs) - len(a.lhs) + need_a = True + continue + + if b_end <= a_begin: + gap = a_begin - b_end + + if (n := len(b_post) + len(a_pre) - gap) > 0: + assert (b_pre + b.lhs + b_post)[-n:] == (a_pre + a.rhs + a_post)[:n] + if (n := len(b_post) - gap) > 0: + b_post = b_post[:gap] + (a.lhs + a_post)[:n] + if (n := len(a_pre) - gap) > 0: + a_pre = (b_pre + b.rhs)[-n:] + a_pre[n:] + + wb(Hunk( + lhs_line = b.lhs_line - delta_a, + rhs_line = b.rhs_line - delta_a, + pre_context = b_pre, + lhs = b.lhs, + rhs = b.rhs, + post_context = b_post, + )) + delta_b += len(b.rhs) - len(b.lhs) + need_b = True + continue + + raise ConflictError(a, b) + + return (*rb,), (*ra,) + +class FileDiff: + __slots__ = ('src_name', 'dst_name', 'src_mode', 'dst_mode', 'hunks',) + + def __init__(self, /, *, src_name, dst_name, src_mode, dst_mode, hunks): + assert (src_mode is ...) == (dst_mode is ...) + assert (src_name is None) == (src_mode is None) + assert (dst_name is None) == (dst_mode is None) + + self.src_name = src_name + self.dst_name = dst_name + if src_mode == dst_mode: + self.src_mode = ... + self.dst_mode = ... + else: + self.src_mode = src_mode + self.dst_mode = dst_mode + self.hunks = hunks + + @property + def op_kind_name(self): + if self.src_name is None: + return 'create' + if self.dst_name is None: + return 'delete' + is_chmod = self.dst_mode != self.dst_mode + if self.dst_name != self.dst_name: + return 'chmod+rename' if is_chmod else 'rename' + if self.hunks: + return 'chmod+modify' if is_chmod else 'modify' + return 'chmod' if is_chmod else 'noop' + + def write(self, w): + src_name = self.src_name + dst_name = self.dst_name + + if src_name is None: + assert dst_name is not None + w(f'diff --git a/{dst_name} b/{dst_name}\n'.encode('utf-8')) + elif dst_name is None: + w(f'diff --git a/{src_name} b/{src_name}\n'.encode('utf-8')) + else: + w(f'diff --git a/{src_name} b/{dst_name}\n'.encode('utf-8')) + + if src_name is not None and dst_name is not None and src_name != dst_name: + w(f'rename from {src_name}\nrename to {dst_name}\n'.encode('utf-8')) + + m = self.src_mode + if dst_name is None: + assert m is not None + w(f'deleted file mode {m}\n'.encode('utf-8')) + elif m not in (None, ...): + w(f'old mode {m}\n'.encode('utf-8')) + + m = self.dst_mode + if src_name is None: + assert m is not None + w(f'new file mode {m}\n'.encode('utf-8')) + elif m not in (None, ...): + w(f'new mode {m}\n'.encode('utf-8')) + + h = self.hunks + if h: + if src_name is None: + src_name = '/dev/null' + lhs_line = 0 + else: + src_name = f'a/{src_name}' + lhs_line = None + if dst_name is None: + dst_name = '/dev/null' + rhs_line = 0 + else: + dst_name = f'b/{dst_name}' + rhs_line = None + + w(f'--- {src_name}\n+++ {dst_name}\n'.encode('utf-8')) + + for h in h: + h.write(w, lhs_line=lhs_line, rhs_line=rhs_line) + + def __neg__(self, /): + tp = type(self) + +class Diff: + __slots__ = ('files',) + + def __init__(self): + self.files = () + + def __bool__(self): + return not not self.files + + @classmethod + def _sort_files(tp, files): + r = object.__new__(tp) + r.files = (*sorted(files, key=lambda f: (f.src_name or '', f.dst_name or '')),) + return r + + @classmethod + def parse_l(tp, l, line): + files = [] + + while True: + l = l.decode('utf-8').rstrip() + if l == '--' or not l: + break + m = rem_diffgit(l) + if m is None: + raise ValueError(f'Invalid patch: expected \'diff --git ...\' header, got {l!r}') + + header = l + + src_name,dst_name = m.groups() + src_exists = ... + dst_exists = ... + + if not src_exists and not dst_exists: + raise RuntimeError('wtf') + if (not src_exists or not dst_exists) and src_name != dst_name: + raise RuntimeError('wtf') + + has_diff = False + has_chmod = False + has_rename = False + dst_mode = ... + src_mode = ... + + while True: + l = line() + if l.startswith(b'--- a/'): + if src_exists is False: raise RuntimeError('wtf') + src_exists = True + if l[6:].decode('utf-8') != src_name: + raise RuntimeError('wtf') + has_diff = True + continue + if l == b'--- /dev/null': + if src_exists is True: raise RuntimeError('wtf') + src_exists = False + has_diff = True + continue + if l.startswith(b'+++ b/'): + if dst_exists is False: raise RuntimeError('wtf') + dst_exists = True + if l[6:].decode('utf-8') != dst_name: + raise RuntimeError('wtf') + has_diff = True + continue + if l == b'+++ /dev/null': + if dst_exists is True: raise RuntimeError('wtf') + dst_exists = False + has_diff = True + continue + if l.startswith(b'index '): + continue + if l.startswith(b'new file mode '): + if src_exists is True: raise RuntimeError('wtf') + if dst_exists is False: raise RuntimeError('wtf') + src_exists = False + dst_exists=True + dst_mode = l[14:].decode('ascii') + src_mode = None + continue + if l.startswith(b'new mode '): + dst_mode = l[9:].decode('ascii') + has_chmod = True + continue + if l.startswith(b'deleted file mode '): + if src_exists is False: raise RuntimeError('wtf') + if dst_exists is True: raise RuntimeError('wtf') + src_exists = True + dst_exists = False + src_mode = l[18:].decode('ascii') + dst_mode = None + continue + if l.startswith(b'old mode '): + if dst_exists is False: + raise RuntimeError('File deleted') + src_mode = l[9:].decode('ascii') + has_chmod = True + dst_exists = True + continue + if l.startswith(b'similarity index '): + has_rename = True + continue + if l.startswith(b'rename from '): + has_rename = True + rename_from = l[12:].decode('utf-8') + continue + if l.startswith(b'rename to '): + rename_to = l[10:].decode('utf-8') + has_rename = True + continue + break + + if has_chmod or has_rename: + if src_exists is False: raise RuntimeError('wtf') + if dst_exists is False: raise RuntimeError('wtf') + src_exists = True + dst_exists = True + + if has_chmod: + assert src_mode is not ... + assert dst_mode is not ... + + assert src_exists is not ... + if not src_exists: + assert dst_mode is not ... + + assert dst_exists is not ... + if not dst_exists: + assert src_mode is not ... + + if has_rename: + if src_name == dst_name or not src_exists or not dst_exists: + raise RuntimeError(f'Bad rename: {src_name!r} -> {dst_name!r}') + assert rename_from == src_name + assert rename_to == dst_name + + if not src_exists: src_name = None + if not dst_exists: dst_name = None + + assert (src_mode is ...) == (dst_mode is ...) + assert has_diff or has_chmod or has_rename or not dst_exists + + #want_header = fmt_diff_git(src_name, dst_name) + #if header != want_header: + # raise ValueError(f'Invalid patch: expected {want_header!r} header, got {header!r}') + + if has_diff: + hunks = [] + + while True: + l,hunk = Hunk.parse_l(l, line) + hunk = hunk.reduce_context() + if hunk is not None: + hunks.append(hunk) + if not l.startswith(b'@@ '): + break + + if src_name is None: + assert len(hunks) == 1 + assert len(hunks[0].pre_context) == 0 + assert len(hunks[0].post_context) == 0 + assert len(hunks[0].lhs) == 0 + assert hunks[0].lhs_line == 1 + hunks[0].lhs_line = 1 + + if dst_name is None: + assert len(hunks) == 1 + assert len(hunks[0].pre_context) == 0 + assert len(hunks[0].post_context) == 0 + assert len(hunks[0].rhs) == 0 + assert hunks[0].rhs_line == 1 + + hunks = (*hunks,) + + else: + hunks = () + + files.append(FileDiff( + src_name = src_name, + dst_name = dst_name, + src_mode = src_mode, + dst_mode = dst_mode, + hunks = hunks, + )) + + try: + l = line().rstrip() + except StopIteration: + pass + else: + if l[0] not in b'0123456789': + l = l.decode('utf-8') + raise ValueError('Invalid patch: expected git version, got {l!r}') + + + while True: + try: + l = line() + except StopIteration: + break + if l.strip(): + l = l.decode('utf-8') + raise ValueError('Invalid patch: got trailing garbage {l!r}') + + return tp._sort_files(files) + + @classmethod + def parse(tp, body): + line = iter(body.split(b'\n')).__next__ + return tp.parse_l(line(), line) + + def write(self, w, /): + for f in self.files: + f.write(w) + + def write_munged(self, w, /): + for f in self.files: + w((f.src_name or '').encode('utf-8')) + w(b'\n') + w((f.dst_name or '').encode('utf-8')) + w(b'\n') + for h in f.hunks: + h.write(w, True) + w(b'\n') + + def __neg__(self): + return Diff._sort_files((*(FileDiff( + src_name = f.dst_name, + dst_name = f.src_name, + src_mode = f.dst_mode, + dst_mode = f.src_mode, + hunks = (*(Hunk( + lhs_line = h.rhs_line, + rhs_line = h.lhs_line, + pre_context = h.pre_context, + lhs = h.rhs, + rhs = h.lhs, + post_context = h.post_context, + ) for h in f.hunks),), + ) for f in self.files),)) + + def __add__(a, b): + tp = type(a) + if type(b) is not tp: + raise TypeError + + r = [] + w = r.append + + fdiff = FileDiff + def join(fa, fb): + src_name = fa.src_name + dst_name = fb.dst_name + src_mode = fa.src_mode + dst_mode = fb.dst_mode + if src_mode is ...: + if dst_mode is not ...: + src_mode = fb.src_mode + elif dst_mode is ...: + dst_mode = fa.dst_mode + elif fa.dst_mode != fb.src_mode: + raise ValueError('mismatch') + + hunks = (*join_hunks(fa.hunks, fb.hunks),) + + if src_name is None and dst_name is None: + if hunks: + raise ValueError('mismatch') + return + + if src_name == dst_name and src_mode == dst_mode and not hunks: + return + + w(fdiff( + src_name = src_name, + dst_name = dst_name, + src_mode = src_mode, + dst_mode = dst_mode, + hunks = hunks, + )) + + _pair_diffs(a, b, on_lhs=w, on_rhs=w, on_pair=join, on_re=join) + return tp._sort_files(r) + + @classmethod + def commute(tp, a, b): + if type(a) is not tp or type(b) is not tp: + raise TypeError + + ra = [] + rb = [] + wa = ra.append + wb = rb.append + fdiff = FileDiff + + def join(fa, fb): + name1 = fa.src_name + name2 = fa.dst_name + name3 = fb.dst_name + if name1 == name2: + name2 = name3 + elif name2 == name3: + name2 = name1 + else: + raise ValueError('cannot commute {name2!r}: {fa.op_kind_name} / {fb.op_kind_name}') + + mode1 = fa.src_mode + mode3 = fb.dst_mode + if mode1 is ...: + if mode3 is ...: + mode2 = ... + else: + mode1 = fb.src_mode + mode2 = mode3 + + elif mode3 is ...: + mode3 = fa.dst_mode + mode2 = mode1 + + else: + mode2 = fb.dst_mode + if fa.src_mode != mode2: + raise ValueError('mismatch') + if mode1 == mode2: + mode2 = mode3 + elif mode2 == mode3: + mode2 = mode1 + else: + raise ValueError('cannot commute {name2!r}: chmod/chmod conflict') + + try: + hb,ha = commute_hunks(fa.hunks, fb.hunks) + except PatchError as err: + err.args = (name1, name2, name3) + raise + + wa(fdiff( + src_name = name1, + dst_name = name2, + src_mode = mode1, + dst_mode = mode2, + hunks = ha, + )) + wb(fdiff( + src_name = name2, + dst_name = name3, + src_mode = mode2, + dst_mode = mode3, + hunks = hb, + )) + + _pair_diffs(a, b, on_lhs=wa, on_rhs=wb, on_re=join, on_pair=join) + return tp._sort_files(rb), tp._sort_files(ra) + +def _pair_diffs(a, b, *, on_lhs, on_rhs, on_pair, on_re): + a_del = {} + b_new = {} + + a_dst = {} + b_dst = set() + for f in a.files: + dst = f.dst_name + if dst is None: + a_del[f.src_name] = f + continue + a_dst[dst] = f + + a = a_dst.pop + for f in b.files: + dst = f.dst_name + if dst is not None: + b_dst.add(dst) + + src = f.src_name + if src is None: + b_new[dst] = f + continue + + try: + f2 = a(src) + except KeyError: + pass + else: + on_pair(f2, f) + continue + on_rhs(f) + + for dst in a_dst.keys() & b_dst: + raise ValueError(f'mismatch for {dst!r}') + + b = b_new.pop + for src,f in a_del.items(): + try: + f2 = b(src) + except KeyError: + pass + else: + on_re(f, f2) + continue + on_lhs(f) + + for f in b_new.values(): + on_rhs(f) + + for f in a_dst.values(): + on_lhs(f) -- cgit