import re, operator, io, sys

# Unified-diff hunk header, e.g. b'@@ -12,5 +12,6 @@ def foo():'.
# Groups: old start, old count (optional), new start, new count (optional).
rem_range = re.compile(b'@@ -([0-9]+)(?:,([0-9]+))? \\+([0-9]+)(?:,([0-9]+))? @@(?: .*)?').fullmatch
# Decoded 'diff --git a/<src> b/<dst>' header; names cannot contain spaces.
rem_diffgit = re.compile('diff --git a/([^ ]*) b/([^ ]*)').fullmatch

# Position used by commute_hunks for an exhausted hunk stream; larger than
# any real line number.
INFINITY = 2**64

# Suffix stored on a line that had no trailing newline, so that writing the
# line back reproduces git's "No newline at end of file" marker line.
_NO_EOL = b'\n\\ No newline at end of file'


class PatchError(Exception):
    """Base class for errors raised while combining two patches.

    ``a`` and ``b`` are the two offending hunks (objects with a
    ``write(w)`` method, or None).  ``args[0]`` is the affected file name,
    or None when unknown.
    """
    __slots__ = ('a', 'b')

    def __init__(self, a, b, /, fn=None):
        super().__init__(fn)
        self.a = a
        self.b = b

    def __str__(self, /):
        fn = self.args[0]
        r = [f' in {fn!r}:' if fn is not None else ':']

        def tabmangle(p):
            # Render a hunk and indent each of its lines for the message.
            if p is None:
                return [' <none>']
            o = io.BytesIO()
            p.write(o.write)
            p = o.getvalue().decode('utf-8', 'replace').splitlines()
            return [' ' + l for l in p]

        r.append('')
        r += tabmangle(self.a)
        r.append('')
        r += tabmangle(self.b)
        return '\n'.join(r)


class ConflictError(PatchError):
    """Raised by commute_hunks when the changed regions of two hunks overlap."""
    __slots__ = ()

    def __str__(self, /):
        return 'Conflict' + PatchError.__str__(self)


class MismatchError(PatchError):
    """Raised by join_hunks when a hunk's expected text differs from the other patch's output."""
    __slots__ = ()

    def __str__(self, /):
        return 'Mismatch' + PatchError.__str__(self)


class Hunk:
    """One hunk of a unified diff.

    Lines are bytes without their ' ', '-' or '+' prefix.  A line that had
    no trailing newline carries git's "No newline at end of file" marker
    appended after an embedded newline, so write() reproduces it verbatim.
    ``lhs_line``/``rhs_line`` are the 1-based first lines of the hunk
    (pre_context included) on the old and new sides.
    """
    __slots__ = ('lhs_line', 'rhs_line', 'pre_context', 'lhs', 'rhs', 'post_context')

    def __init__(self, /, *, lhs_line, rhs_line, pre_context, lhs, rhs, post_context):
        assert lhs_line > 0
        assert rhs_line > 0
        self.lhs_line = lhs_line
        self.rhs_line = rhs_line
        self.pre_context = pre_context
        self.lhs = lhs
        self.rhs = rhs
        self.post_context = post_context

    @property
    def lhs_size(self, /):
        """Number of old-side lines spanned by the hunk, context included."""
        return len(self.pre_context) + len(self.lhs) + len(self.post_context)

    @property
    def rhs_size(self, /):
        """Number of new-side lines spanned by the hunk, context included."""
        return len(self.pre_context) + len(self.rhs) + len(self.post_context)

    def __neg__(self, /):
        """Return the inverse hunk (old and new sides swapped)."""
        return type(self)(
            lhs_line = self.rhs_line,
            rhs_line = self.lhs_line,
            pre_context = self.pre_context,
            lhs = self.rhs,
            rhs = self.lhs,
            post_context = self.post_context,
        )

    def write(self, w, /, munge=False, *, lhs_line=None, rhs_line=None, lhs_nl=True, rhs_nl=True):
        """Emit the hunk in unified-diff form through ``w(bytes)``.

        ``munge`` replaces the start line numbers with 'xxx'/'yyy'.
        ``lhs_line``/``rhs_line`` override the start numbers in the header
        (e.g. 0 for /dev/null).  ``lhs_nl``/``rhs_nl`` are accepted for
        compatibility and are not used.
        """
        if lhs_line is None:
            lhs_line = self.lhs_line
        if rhs_line is None:
            rhs_line = self.rhs_line
        llen = self.lhs_size
        rlen = self.rhs_size
        if munge:
            w(f'@@ -xxx,{llen} +yyy,{rlen} @@'.encode('utf-8'))
        else:
            w(f'@@ -{lhs_line},{llen} +{rhs_line},{rlen} @@'.encode('utf-8'))
        for prefix, lines in (
            (b'\n ', self.pre_context),
            (b'\n-', self.lhs),
            (b'\n+', self.rhs),
            (b'\n ', self.post_context),
        ):
            for l in lines:
                w(prefix)
                w(l)
        w(b'\n')

    @classmethod
    def parse_l(tp, l, line, /):
        """Parse one hunk whose '@@ ...' header line is ``l``.

        ``line`` is a zero-argument callable returning the next input line
        (bytes, without its newline).  Returns ``(next_line, hunk)``, where
        ``next_line`` is the first line after the hunk.

        Raises ValueError on malformed input.
        """
        m = rem_range(l)
        if m is None:
            raise ValueError(f'Invalid patch: bad range info {l!r}')
        lhs_line, n2, rhs_line, n4 = m.groups()
        # An empty side is written with start line 0; Hunk requires >= 1.
        lhs_line = 1 if lhs_line == b'0' else int(lhs_line)
        rhs_line = 1 if rhs_line == b'0' else int(rhs_line)
        # An omitted count means exactly one line.
        n2 = int(n2) if n2 else 1
        n4 = int(n4) if n4 else 1
        hunk = []
        # Kind of the previous line: -1 context, 0 removed, 1 added,
        # -2 a "No newline" marker (which may not be repeated).
        prev = -1
        lhs_nl = True
        rhs_nl = True
        post_nl = True
        while True:
            l = line()
            if l.startswith(b'\\ '):
                if prev == 0:
                    lhs_nl = False
                elif prev == 1:
                    rhs_nl = False
                elif prev == -1:
                    post_nl = False
                else:
                    raise ValueError('Invalid patch: bad "No newline at end of file" flag')
                prev = -2
                continue
            if not n2 and not n4:
                break
            if not l:
                # Input ended (or was blank) before the announced line counts.
                raise ValueError('Invalid patch: unexpected empty line in hunk')
            k = l[0]
            if k == b' '[0]:
                if not post_nl or not lhs_nl or not rhs_nl:
                    raise ValueError('Invalid patch: bad "No newline at end of file" flag')
                if not n2 or not n4:
                    raise ValueError('Invalid patch: bad line count')
                n2 -= 1
                n4 -= 1
                prev = -1
            elif k == b'-'[0]:
                if not post_nl or not lhs_nl:
                    raise ValueError('Invalid patch: bad "No newline at end of file" flag')
                if not n2:
                    raise ValueError('Invalid patch: bad line count')
                n2 -= 1
                prev = 0
            elif k == b'+'[0]:
                if not post_nl or not rhs_nl:
                    raise ValueError('Invalid patch: bad "No newline at end of file" flag')
                if not n4:
                    raise ValueError('Invalid patch: bad line count')
                n4 -= 1
                prev = 1
            else:
                raise ValueError(f'Invalid patch: bad prefix character {chr(k)!r}')
            hunk.append(l)
        strip1 = operator.itemgetter(slice(1, None))
        # Split off leading context lines.
        for i, h in enumerate(hunk):
            if h[0] != b' '[0]:
                break
        else:
            raise ValueError('Invalid patch: no changed lines in hunk')
        pre_context = (*map(strip1, hunk[:i]),)
        del hunk[:i]
        # Split off trailing context; hunk[0] is a change line, so j >= 1.
        j = len(hunk)
        while hunk[j - 1][0] == b' '[0]:
            j -= 1
        if not post_nl:
            # The marker followed the final context line of the hunk.
            hunk[-1] += _NO_EOL
        post_context = (*map(strip1, hunk[j:]),)
        del hunk[j:]
        # Interior context lines belong to both sides.
        lhs = []
        rhs = []
        for h in hunk:
            if h[0] != b'+'[0]:
                lhs.append(h)
            if h[0] != b'-'[0]:
                rhs.append(h)
        if not lhs_nl:
            lhs[-1] += _NO_EOL
        if not rhs_nl:
            rhs[-1] += _NO_EOL
        return l, tp(
            lhs_line = lhs_line,
            rhs_line = rhs_line,
            pre_context = pre_context,
            lhs = (*map(strip1, lhs),),
            rhs = (*map(strip1, rhs),),
            post_context = post_context,
        )

    def reduce_context(self, n=3):
        """Return a normalized copy of this hunk, or None if it changes nothing.

        Lines shared at the start/end of ``lhs`` and ``rhs`` are moved into
        the context.  Context is then trimmed to ``n`` lines on each side
        (``n=None`` keeps all context, ``n=0`` drops it), adjusting the
        start line numbers accordingly.
        """
        pre = self.pre_context
        lhs = self.lhs
        rhs = self.rhs
        post = self.post_context
        # i: length of the common prefix; j: negative index of the next
        # suffix position to compare.  i - j <= m keeps prefix and suffix
        # from overlapping within the shorter side.
        i = 0
        j = -1
        m = min(len(lhs), len(rhs))
        while i - j <= m:
            if lhs[i] == rhs[i]:
                i += 1
                if i - j > m:
                    break
                if lhs[j] == rhs[j]:
                    j -= 1
            else:
                if lhs[j] != rhs[j]:
                    break
                j -= 1
        j += 1
        jl = len(lhs) + j
        jr = len(rhs) + j
        assert lhs[:i] == rhs[:i]
        assert lhs[jl:] == rhs[jr:]
        pre = pre + lhs[:i]
        post = lhs[jl:] + post
        lhs = lhs[i:jl]
        rhs = rhs[i:jr]
        if not lhs and not rhs:
            return None
        assert lhs != rhs
        assert not lhs or not rhs or lhs[0] != rhs[0] and lhs[-1] != rhs[-1]
        lhs_line = self.lhs_line
        rhs_line = self.rhs_line
        if n is None:
            pass
        elif not n:
            d = len(pre)
            lhs_line += d
            rhs_line += d
            pre = ()
            post = ()
        else:
            if len(pre) > n:
                d = len(pre) - n
                lhs_line += d
                rhs_line += d
                pre = pre[d:]
            if len(post) > n:
                post = post[:n]
        return Hunk(
            lhs_line = lhs_line,
            rhs_line = rhs_line,
            pre_context = pre,
            lhs = lhs,
            rhs = rhs,
            post_context = post,
        )


def join_hunks(hs_a, hs_b):
    """Compose two hunk sequences for the same file.

    ``hs_a`` maps the original text to an intermediate text and ``hs_b``
    maps that intermediate text to the final text.  Both are presumably
    sorted by position (they are consumed in order).  Returns a tuple of
    hunks mapping original to final directly, reduced to 3 context lines.

    Raises MismatchError when one hunk's expected text disagrees with the
    text produced by the other sequence.
    """
    hs_a = iter(hs_a)
    hs_b = iter(hs_b)
    need_a = True
    need_b = True
    have_cur = False

    def make_cur(n):
        # Snapshot the hunk under construction (None if there is none).
        if not have_cur:
            return None
        return Hunk(
            lhs_line = cur_lhs_line,
            rhs_line = cur_rhs_line,
            pre_context = (),
            lhs = tuple(cur_lhs),
            rhs = tuple(cur_rhs),
            post_context = (),
        ).reduce_context(n)

    # Net line-count change of the composed hunks emitted so far.
    delta = 0
    r = []
    while True:
        if need_a:
            need_a = False
            a = next(hs_a, None)
        if need_b:
            need_b = False
            b = next(hs_b, None)
        if not have_cur:
            # Start a new composed hunk from whichever input comes first.
            if a is not None and (b is None or a.rhs_line <= b.lhs_line):
                cur_lhs_line = a.lhs_line
                cur_rhs_line = a.lhs_line + delta
                cur_lhs = [*a.pre_context, *a.lhs, *a.post_context]
                cur_rhs = [*a.pre_context, *a.rhs, *a.post_context]
                need_a = True
            elif b is not None:
                cur_lhs_line = b.rhs_line - delta
                cur_rhs_line = b.rhs_line
                cur_lhs = [*b.pre_context, *b.lhs, *b.post_context]
                cur_rhs = [*b.pre_context, *b.rhs, *b.post_context]
                need_b = True
            else:
                return (*r,)
            have_cur = True
            continue
        cur_lhs_end = cur_lhs_line + len(cur_lhs)
        cur_rhs_end = cur_rhs_line + len(cur_rhs)
        if a is not None and a.lhs_line <= cur_lhs_end:
            # Fold an overlapping/adjacent a-hunk into the left side: the
            # intermediate text it produced is replaced by its original text.
            lhs = [*a.pre_context, *a.lhs, *a.post_context]
            rhs = [*a.pre_context, *a.rhs, *a.post_context]
            i = a.lhs_line - cur_lhs_line
            k = i + len(rhs)
            if len(cur_lhs) >= k:
                if cur_lhs[i:k] != rhs:
                    raise MismatchError(make_cur(None), a)
                cur_lhs[i:k] = lhs
            else:
                j = len(cur_lhs) - i
                if cur_lhs[i:] != rhs[:j]:
                    raise MismatchError(make_cur(None), a)
                cur_lhs[i:] = lhs
                cur_rhs.extend(rhs[j:])
            need_a = True
            continue
        if b is not None and b.rhs_line <= cur_rhs_end:
            # Fold an overlapping/adjacent b-hunk into the right side.
            lhs = [*b.pre_context, *b.lhs, *b.post_context]
            rhs = [*b.pre_context, *b.rhs, *b.post_context]
            i = b.rhs_line - cur_rhs_line
            k = i + len(lhs)
            if len(cur_rhs) >= k:
                if cur_rhs[i:k] != lhs:
                    raise MismatchError(make_cur(None), b)
                cur_rhs[i:k] = rhs
            else:
                j = len(cur_rhs) - i
                if cur_rhs[i:] != lhs[:j]:
                    raise MismatchError(make_cur(None), b)
                cur_rhs[i:] = rhs
                cur_lhs.extend(lhs[j:])
            need_b = True
            continue
        cur = make_cur(3)
        have_cur = False
        if cur is None:
            continue
        r.append(cur)
        delta += len(cur.rhs) - len(cur.lhs)


def commute_hunks(hs_a, hs_b):
    """Swap the order of two hunk sequences for the same file.

    ``hs_a`` is applied first and ``hs_b`` second.  Returns ``(b2, a2)``:
    ``b2`` holds b's hunks re-based onto the original text and ``a2`` holds
    a's hunks re-based to apply after ``b2``.  Context lines that overlap
    the other patch's changed region are rewritten accordingly.

    Raises ConflictError when the changed regions of an a-hunk and a
    b-hunk overlap.
    """
    hs_a = iter(hs_a)
    hs_b = iter(hs_b)
    need_a = True
    need_b = True
    # Both start as None so the exhaustion checks below never read an
    # unbound name, even if hs_a is empty on the very first iteration.
    a = None
    b = None
    delta_a = 0
    delta_b = 0
    ra = []
    rb = []
    wa = ra.append
    wb = rb.append
    while True:
        if need_a:
            need_a = False
            try:
                a = next(hs_a)
            except StopIteration:
                if b is None and not need_b:
                    break
                a = None
                a_begin = INFINITY
                a_pre = ()
                a_post = ()
                a_end = INFINITY
            else:
                a_pre = a.pre_context
                a_post = a.post_context
                a_begin = a.rhs_line + len(a_pre)
                a_end = a_begin + len(a.rhs)
        if need_b:
            need_b = False
            try:
                b = next(hs_b)
            except StopIteration:
                if a is None:
                    break
                b = None
                b_begin = INFINITY
                b_pre = ()
                b_post = ()
                b_end = INFINITY
            else:
                b_pre = b.pre_context
                b_post = b.post_context
                b_begin = b.lhs_line + len(b_pre)
                b_end = b_begin + len(b.lhs)
        if a_end <= b_begin:
            # a's change precedes b's: emit a (it now applies after b).
            gap = b_begin - a_end
            if (n := len(a_post) + len(b_pre) - gap) > 0:
                assert (a_pre + a.rhs + a_post)[-n:] == (b_pre + b.lhs + b_post)[:n]
            if (n := len(a_post) - gap) > 0:
                a_post = a_post[:gap] + (b.rhs + b_post)[:n]
            if (n := len(b_pre) - gap) > 0:
                b_pre = (a_pre + a.lhs)[-n:] + b_pre[n:]
            wa(Hunk(
                lhs_line = a.lhs_line + delta_b,
                rhs_line = a.rhs_line + delta_b,
                pre_context = a_pre,
                lhs = a.lhs,
                rhs = a.rhs,
                post_context = a_post,
            ))
            delta_a += len(a.rhs) - len(a.lhs)
            need_a = True
            continue
        if b_end <= a_begin:
            # b's change precedes a's: emit b (it now applies before a).
            gap = a_begin - b_end
            if (n := len(b_post) + len(a_pre) - gap) > 0:
                assert (b_pre + b.lhs + b_post)[-n:] == (a_pre + a.rhs + a_post)[:n]
            if (n := len(b_post) - gap) > 0:
                b_post = b_post[:gap] + (a.lhs + a_post)[:n]
            if (n := len(a_pre) - gap) > 0:
                a_pre = (b_pre + b.rhs)[-n:] + a_pre[n:]
            wb(Hunk(
                lhs_line = b.lhs_line - delta_a,
                rhs_line = b.rhs_line - delta_a,
                pre_context = b_pre,
                lhs = b.lhs,
                rhs = b.rhs,
                post_context = b_post,
            ))
            delta_b += len(b.rhs) - len(b.lhs)
            need_b = True
            continue
        raise ConflictError(a, b)
    return (*rb,), (*ra,)


class FileDiff:
    """All changes a diff makes to a single path.

    ``src_name``/``dst_name`` are the old and new paths (None when the
    file is created or deleted).  ``src_mode``/``dst_mode`` are git mode
    strings as read from the patch header, None when that side does not
    exist, or ``...`` when the mode is unchanged; equal modes are
    normalized to ``...``.  ``hunks`` is a sequence of Hunk objects.
    """
    __slots__ = ('src_name', 'dst_name', 'src_mode', 'dst_mode', 'hunks',)

    def __init__(self, /, *, src_name, dst_name, src_mode, dst_mode, hunks):
        assert (src_mode is ...) == (dst_mode is ...)
        assert (src_name is None) == (src_mode is None)
        assert (dst_name is None) == (dst_mode is None)
        self.src_name = src_name
        self.dst_name = dst_name
        if src_mode == dst_mode:
            self.src_mode = ...
            self.dst_mode = ...
        else:
            self.src_mode = src_mode
            self.dst_mode = dst_mode
        self.hunks = hunks

    @property
    def op_kind_name(self):
        """Short human-readable label for the kind of change, e.g. 'rename'."""
        if self.src_name is None:
            return 'create'
        if self.dst_name is None:
            return 'delete'
        is_chmod = self.src_mode != self.dst_mode
        if self.src_name != self.dst_name:
            return 'chmod+rename' if is_chmod else 'rename'
        if self.hunks:
            return 'chmod+modify' if is_chmod else 'modify'
        return 'chmod' if is_chmod else 'noop'

    def write(self, w):
        """Emit this file's section of a git-style diff through ``w(bytes)``."""
        src_name = self.src_name
        dst_name = self.dst_name
        assert src_name is not None or dst_name is not None
        # Creations/deletions name the surviving path on both sides.
        header_src = dst_name if src_name is None else src_name
        header_dst = src_name if dst_name is None else dst_name
        w(f'diff --git a/{header_src} b/{header_dst}\n'.encode('utf-8'))
        if src_name is not None and dst_name is not None and src_name != dst_name:
            w(f'rename from {src_name}\nrename to {dst_name}\n'.encode('utf-8'))
        m = self.src_mode
        if dst_name is None:
            assert m is not None
            w(f'deleted file mode {m}\n'.encode('utf-8'))
        elif m not in (None, ...):
            w(f'old mode {m}\n'.encode('utf-8'))
        m = self.dst_mode
        if src_name is None:
            assert m is not None
            w(f'new file mode {m}\n'.encode('utf-8'))
        elif m not in (None, ...):
            w(f'new mode {m}\n'.encode('utf-8'))
        hunks = self.hunks
        if hunks:
            # A missing side is written as /dev/null with start line 0.
            if src_name is None:
                src_name = '/dev/null'
                lhs_line = 0
            else:
                src_name = f'a/{src_name}'
                lhs_line = None
            if dst_name is None:
                dst_name = '/dev/null'
                rhs_line = 0
            else:
                dst_name = f'b/{dst_name}'
                rhs_line = None
            w(f'--- {src_name}\n+++ {dst_name}\n'.encode('utf-8'))
            for hunk in hunks:
                hunk.write(w, lhs_line=lhs_line, rhs_line=rhs_line)

    def __neg__(self, /):
        """Return the inverse file diff (old and new sides swapped)."""
        return type(self)(
            src_name = self.dst_name,
            dst_name = self.src_name,
            src_mode = self.dst_mode,
            dst_mode = self.src_mode,
            hunks = (*(-h for h in self.hunks),),
        )


class Diff:
    """A whole git diff: a tuple of FileDiff objects sorted by path."""
    __slots__ = ('files',)

    def __init__(self):
        self.files = ()

    def __bool__(self):
        return bool(self.files)

    @classmethod
    def _sort_files(tp, files):
        # Build an instance directly from an iterable of FileDiff objects.
        r = object.__new__(tp)
        r.files = (*sorted(files, key=lambda f: (f.src_name or '', f.dst_name or '')),)
        return r

    @classmethod
    def parse_l(tp, l, line):
        """Parse a diff whose first line is ``l``.

        ``line`` is a zero-argument callable returning the next line
        (bytes, without its newline) and raising StopIteration at the end.
        Parsing stops at an empty line or a '-- ' signature separator.  The
        separator may be followed by a version line starting with a digit
        and then only blank lines.

        Raises ValueError/RuntimeError on malformed input.
        """
        files = []
        while True:
            l = l.decode('utf-8').rstrip()
            if l == '--' or not l:
                break
            m = rem_diffgit(l)
            if m is None:
                raise ValueError(f'Invalid patch: expected \'diff --git ...\' header, got {l!r}')
            src_name, dst_name = m.groups()
            # Tri-state: ... = not yet known, True/False = side exists or not.
            src_exists = ...
            dst_exists = ...
            has_diff = False
            has_chmod = False
            has_rename = False
            dst_mode = ...
            src_mode = ...
            rename_from = None
            rename_to = None
            # Extended header lines, up to the first hunk or next file.
            while True:
                l = line()
                if l.startswith(b'--- a/'):
                    if src_exists is False:
                        raise RuntimeError('wtf')
                    src_exists = True
                    if l[6:].decode('utf-8') != src_name:
                        raise RuntimeError('wtf')
                    has_diff = True
                    continue
                if l == b'--- /dev/null':
                    if src_exists is True:
                        raise RuntimeError('wtf')
                    src_exists = False
                    has_diff = True
                    continue
                if l.startswith(b'+++ b/'):
                    if dst_exists is False:
                        raise RuntimeError('wtf')
                    dst_exists = True
                    if l[6:].decode('utf-8') != dst_name:
                        raise RuntimeError('wtf')
                    has_diff = True
                    continue
                if l == b'+++ /dev/null':
                    if dst_exists is True:
                        raise RuntimeError('wtf')
                    dst_exists = False
                    has_diff = True
                    continue
                if l.startswith(b'index '):
                    continue
                if l.startswith(b'new file mode '):
                    if src_exists is True:
                        raise RuntimeError('wtf')
                    if dst_exists is False:
                        raise RuntimeError('wtf')
                    src_exists = False
                    dst_exists = True
                    dst_mode = l[14:].decode('ascii')
                    src_mode = None
                    continue
                if l.startswith(b'new mode '):
                    dst_mode = l[9:].decode('ascii')
                    has_chmod = True
                    continue
                if l.startswith(b'deleted file mode '):
                    if src_exists is False:
                        raise RuntimeError('wtf')
                    if dst_exists is True:
                        raise RuntimeError('wtf')
                    src_exists = True
                    dst_exists = False
                    src_mode = l[18:].decode('ascii')
                    dst_mode = None
                    continue
                if l.startswith(b'old mode '):
                    if dst_exists is False:
                        raise RuntimeError('File deleted')
                    src_mode = l[9:].decode('ascii')
                    has_chmod = True
                    dst_exists = True
                    continue
                if l.startswith(b'similarity index '):
                    has_rename = True
                    continue
                if l.startswith(b'rename from '):
                    has_rename = True
                    rename_from = l[12:].decode('utf-8')
                    continue
                if l.startswith(b'rename to '):
                    rename_to = l[10:].decode('utf-8')
                    has_rename = True
                    continue
                break
            if has_chmod or has_rename:
                if src_exists is False:
                    raise RuntimeError('wtf')
                if dst_exists is False:
                    raise RuntimeError('wtf')
                src_exists = True
                dst_exists = True
            if has_chmod:
                assert src_mode is not ...
                assert dst_mode is not ...
            assert src_exists is not ...
            if not src_exists:
                assert dst_mode is not ...
            assert dst_exists is not ...
            if not dst_exists:
                assert src_mode is not ...
            if has_rename:
                if src_name == dst_name or not src_exists or not dst_exists:
                    raise RuntimeError(f'Bad rename: {src_name!r} -> {dst_name!r}')
                assert rename_from == src_name
                assert rename_to == dst_name
            if not src_exists:
                src_name = None
            if not dst_exists:
                dst_name = None
            assert (src_mode is ...) == (dst_mode is ...)
            # Creating or deleting an empty file legitimately has no hunks.
            assert has_diff or has_chmod or has_rename or not src_exists or not dst_exists
            if has_diff:
                hunks = []
                while True:
                    l, hunk = Hunk.parse_l(l, line)
                    hunk = hunk.reduce_context()
                    if hunk is not None:
                        hunks.append(hunk)
                    if not l.startswith(b'@@ '):
                        break
                if src_name is None:
                    assert len(hunks) == 1
                    assert len(hunks[0].pre_context) == 0
                    assert len(hunks[0].post_context) == 0
                    assert len(hunks[0].lhs) == 0
                    assert hunks[0].lhs_line == 1
                    hunks[0].lhs_line = 1
                if dst_name is None:
                    assert len(hunks) == 1
                    assert len(hunks[0].pre_context) == 0
                    assert len(hunks[0].post_context) == 0
                    assert len(hunks[0].rhs) == 0
                    assert hunks[0].rhs_line == 1
                hunks = (*hunks,)
            else:
                hunks = ()
            files.append(FileDiff(
                src_name = src_name,
                dst_name = dst_name,
                src_mode = src_mode,
                dst_mode = dst_mode,
                hunks = hunks,
            ))
        # Optional trailer: a version line, then only blank lines.
        try:
            l = line().rstrip()
        except StopIteration:
            pass
        else:
            if l and not l[:1].isdigit():
                l = l.decode('utf-8')
                raise ValueError(f'Invalid patch: expected git version, got {l!r}')
            while True:
                try:
                    l = line()
                except StopIteration:
                    break
                if l.strip():
                    l = l.decode('utf-8')
                    raise ValueError(f'Invalid patch: got trailing garbage {l!r}')
        return tp._sort_files(files)

    @classmethod
    def parse(tp, body):
        """Parse a complete diff given as bytes."""
        line = iter(body.split(b'\n')).__next__
        return tp.parse_l(line(), line)

    def write(self, w, /):
        """Emit the diff in git format through ``w(bytes)``."""
        for f in self.files:
            f.write(w)

    def write_munged(self, w, /):
        """Emit a line-number-free rendering through ``w(bytes)``, for comparisons."""
        for f in self.files:
            w((f.src_name or '').encode('utf-8'))
            w(b'\n')
            w((f.dst_name or '').encode('utf-8'))
            w(b'\n')
            for h in f.hunks:
                h.write(w, True)
            w(b'\n')

    def __neg__(self):
        """Return the inverse diff."""
        return type(self)._sort_files(-f for f in self.files)

    def __add__(a, b):
        """Compose ``a`` followed by ``b`` into a single diff.

        Raises ValueError when the two diffs disagree about a file.
        """
        tp = type(a)
        if type(b) is not tp:
            return NotImplemented
        r = []
        w = r.append

        def join(fa, fb):
            # Merge the two FileDiffs touching the same path; drop no-ops.
            src_name = fa.src_name
            dst_name = fb.dst_name
            src_mode = fa.src_mode
            dst_mode = fb.dst_mode
            if src_mode is ...:
                if dst_mode is not ...:
                    src_mode = fb.src_mode
            elif dst_mode is ...:
                dst_mode = fa.dst_mode
            elif fa.dst_mode != fb.src_mode:
                raise ValueError('mismatch')
            hunks = (*join_hunks(fa.hunks, fb.hunks),)
            if src_name is None and dst_name is None:
                if hunks:
                    raise ValueError('mismatch')
                return
            if src_name == dst_name and src_mode == dst_mode and not hunks:
                return
            w(FileDiff(
                src_name = src_name,
                dst_name = dst_name,
                src_mode = src_mode,
                dst_mode = dst_mode,
                hunks = hunks,
            ))

        _pair_diffs(a, b, on_lhs=w, on_rhs=w, on_pair=join, on_re=join)
        return tp._sort_files(r)

    @classmethod
    def commute(tp, a, b):
        """Reorder two consecutive diffs.

        Given ``a`` applied before ``b``, return ``(b2, a2)``, meant to be
        applied ``b2`` first and ``a2`` second.  A rename or chmod stays
        with the diff that made it.

        Raises TypeError for non-Diff arguments.  Raises ValueError when
        file operations cannot be reordered.  Raises PatchError (``args``
        set to the three file names) when hunks collide.
        """
        if type(a) is not tp or type(b) is not tp:
            raise TypeError
        ra = []
        rb = []
        wa = ra.append
        wb = rb.append

        def join(fa, fb):
            # name1 -> name2 -> name3 is the path through the new order
            # (b2 first, then a2).
            name1 = fa.src_name
            name2 = fa.dst_name
            name3 = fb.dst_name
            if name1 == name2:
                name2 = name3
            elif name2 == name3:
                name2 = name1
            else:
                raise ValueError(f'cannot commute {name2!r}: {fa.op_kind_name} / {fb.op_kind_name}')
            mode1 = fa.src_mode
            mode3 = fb.dst_mode
            if mode1 is ...:
                if mode3 is ...:
                    mode2 = ...
                else:
                    mode1 = fb.src_mode
                    mode2 = mode3
            elif mode3 is ...:
                mode3 = fa.dst_mode
                mode2 = mode1
            else:
                mode2 = fb.dst_mode
                if fa.src_mode != mode2:
                    raise ValueError('mismatch')
                if mode1 == mode2:
                    mode2 = mode3
                elif mode2 == mode3:
                    mode2 = mode1
                else:
                    raise ValueError(f'cannot commute {name2!r}: chmod/chmod conflict')
            try:
                hb, ha = commute_hunks(fa.hunks, fb.hunks)
            except PatchError as err:
                err.args = (name1, name2, name3)
                raise
            # b2 (first step) carries b's hunks re-based onto the original;
            # a2 (second step) carries a's hunks re-based to follow b2.
            wb(FileDiff(
                src_name = name1,
                dst_name = name2,
                src_mode = mode1,
                dst_mode = mode2,
                hunks = hb,
            ))
            wa(FileDiff(
                src_name = name2,
                dst_name = name3,
                src_mode = mode2,
                dst_mode = mode3,
                hunks = ha,
            ))

        _pair_diffs(a, b, on_lhs=wa, on_rhs=wb, on_re=join, on_pair=join)
        return tp._sort_files(rb), tp._sort_files(ra)


def _pair_diffs(a, b, *, on_lhs, on_rhs, on_pair, on_re):
    """Match the FileDiffs of consecutive diffs ``a`` and ``b`` by path.

    Calls ``on_pair(fa, fb)`` when ``fb`` starts from the path ``fa`` ends
    at.  Calls ``on_re(fa, fb)`` when ``fa`` deletes a path that ``fb``
    re-creates.  Calls ``on_lhs``/``on_rhs`` for files only one side touches.

    Raises ValueError when an unpaired ``a`` destination is also a ``b``
    destination.
    """
    a_del = {}
    b_new = {}
    a_dst = {}
    b_dst = set()
    for f in a.files:
        dst = f.dst_name
        if dst is None:
            a_del[f.src_name] = f
            continue
        a_dst[dst] = f
    take_a = a_dst.pop
    for f in b.files:
        dst = f.dst_name
        if dst is not None:
            b_dst.add(dst)
        src = f.src_name
        if src is None:
            b_new[dst] = f
            continue
        try:
            f2 = take_a(src)
        except KeyError:
            on_rhs(f)
        else:
            on_pair(f2, f)
    for dst in a_dst.keys() & b_dst:
        raise ValueError(f'mismatch for {dst!r}')
    take_b = b_new.pop
    for src, f in a_del.items():
        try:
            f2 = take_b(src)
        except KeyError:
            on_lhs(f)
        else:
            on_re(f, f2)
    for f in b_new.values():
        on_rhs(f)
    for f in a_dst.values():
        on_lhs(f)