summaryrefslogtreecommitdiff
path: root/rebiss.py
diff options
context:
space:
mode:
Diffstat (limited to 'rebiss.py')
-rw-r--r--rebiss.py522
1 files changed, 522 insertions, 0 deletions
diff --git a/rebiss.py b/rebiss.py
new file mode 100644
index 0000000..28a3e44
--- /dev/null
+++ b/rebiss.py
@@ -0,0 +1,522 @@
+# SPDX-License-Identifier: GPL-3.0-only
+
+import base64
+import logging
+import getpass
+import hashlib
+import json
+import os
+import re
+import subprocess
+import unicodedata
+import time
+import datetime
+
+import OpenSSL.crypto as cr
+
+
+class HashAlg:
+ __slots__ = ('len', 'hashlib_name', 'mech')
+
+ def __init__(self, len, hashlib_name, mech):
+ self.len = len
+ self.hashlib_name = hashlib_name
+ self.mech = mech
+
+HASH_ALG = {
+ 'SHA256': HashAlg(32, 'sha256', 'SHA256-RSA-PKCS'),
+ 'SHA384': HashAlg(48, 'sha384', 'SHA384-RSA-PKCS'),
+ 'SHA512': HashAlg(64, 'sha512', 'SHA512-RSA-PKCS'),
+}
+
+
+def pkcs11_list():
+ lines = subprocess.check_output(['pkcs11-tool', '--list-slots'], text=True).splitlines()
+ readers = []
+ for line in lines:
+ m = re.fullmatch('Slot ([0-9]+) \\(.*\\): .*', line)
+ if m is None:
+ continue
+ readers.append(m.group(1))
+
+ certs = {}
+ for reader in readers:
+ lines = subprocess.check_output(['pkcs15-tool', '--reader', reader, '-c'], text=True).splitlines()
+ for line in lines:
+ m = re.fullmatch('\tID *: *([0-9a-f]+)$', line)
+ if m is None:
+ continue
+ cert_id = m.group(1)
+ cert_data = subprocess.check_output(['pkcs15-tool', '--reader', reader, '--read-certificate', cert_id])
+ cert = cr.load_certificate(cr.FILETYPE_PEM, cert_data)
+ certs[reader,cert_id] = cert
+
+ return certs
+
+def pkcs11_sign(key, msg, hash_alg, pin):
+ hash_val = hashlib.new(hash_alg.hashlib_name, msg).digest()
+
+ reader,kid = key
+ env = os.environ.copy()
+ env['PIN'] = pin
+ proc = subprocess.run(['pkcs11-tool', '--slot', reader, '--id', kid, '-m', hash_alg.mech, '--pin', 'env:PIN', '--sign'], input=hash_val, stdout=subprocess.PIPE, env=env, check=True)
+
+ return proc.stdout
+
+
+def load_ts(ts):
+ t = time.strptime(ts.decode('ascii'), '%Y%m%d%H%M%SZ')
+ return datetime.datetime(
+ t.tm_year,
+ t.tm_mon,
+ t.tm_mday,
+ t.tm_hour,
+ t.tm_min,
+ t.tm_sec,
+ tzinfo=datetime.timezone.utc,
+ )
+
+def check_selector(cert, selector, valid_only):
+ if valid_only:
+ valid_from = load_ts(cert.get_notBefore())
+ valid_until = load_ts(cert.get_notAfter())
+ if not (valid_from <= datetime.datetime.now(datetime.timezone.utc) <= valid_until):
+ return False
+
+ vals = selector.get('issuers')
+ if vals is not None:
+ issuer_cn = [val.decode('utf-8', 'surrogateescape') for name,val in cert.get_issuer().get_components() if name == b'CN']
+ if issuer_cn:
+ issuer_cn = issuer_cn[0]
+ else:
+ issuer_cn = ''
+ if f'CN={issuer_cn}' not in vals:
+ return False
+
+ return True
+
+def describe_name(val):
+ return ', '.join(f'{k.decode("ascii")}={escape_bad(v.decode("utf-8", "surrogateescape"), ",=")}' for k,v in val.get_components())
+
+def describe_cert(cert):
+ valid_from = load_ts(cert.get_notBefore())
+ valid_until = load_ts(cert.get_notAfter())
+ return (
+ f'Issuer: {describe_name(cert.get_issuer())}\n'
+ f'Subject: {describe_name(cert.get_subject())}\n'
+ f'Valid from: {valid_from:%Y-%m-%d %H:%M:%S} UTC\n'
+ f'Valid until: {valid_until:%Y-%m-%d %H:%M:%S} UTC\n'
+ )
+
+def escape_bad(l, esc=''):
+ r = []
+ for c in l:
+ o = ord(c)
+ if 0xdc80 <= o <= 0xdd00:
+ r.append(f'\\x{ord(c)-0xdc00:02x}')
+ continue
+ if c == '\\':
+ r.append('\\\\')
+ continue
+ if o < 32 or o == 127 or (c in esc and o < 128):
+ r.append(f'\\x{ord(c):02x}')
+ continue
+ cat = unicodedata.category(c)
+ if cat in {'Cc', 'Cf', 'Cs'} or (cat.startswith('Z') and c != ' ') or c in esc:
+ if o < 0x10000:
+ r.append(f'\\u{o:04x}')
+ else:
+ r.append(f'\\U{o:08x}')
+ continue
+ r.append(c)
+ return ''.join(r)
+
+def print_blob(name, data, /):
+ if type(data) is str:
+ kind = 'text'
+ elif type(data) is bytes:
+ kind = 'blob'
+ data = '\n'.join(escape_bad(l.decode('utf-8', 'surrogateescape')) for l in data.splitlines())
+ else:
+ kind = 'repr'
+ data = repr(data)
+ print(f'{name} ({kind}):')
+
+ for line in data.splitlines():
+ print(f' > {line!s}')
+
+def prompt(txt, /, default=False):
+ while True:
+ if default:
+ d = 'Y/n'
+ else:
+ d = 'y/N'
+ txt = input(f'{txt} [{d}]? ')
+ if txt in {'y', 'Y'}:
+ return True
+ if txt in {'n', 'N'}:
+ return False
+ if not txt:
+ return default
+
+STATUS_TEXT = {
+ 200: 'OK',
+ 201: 'Created',
+ 202: 'Accepted',
+ 203: 'Non-Authoritative Information',
+ 204: 'No Content',
+ 205: 'Reset Content',
+ 206: 'Partial Content',
+ 207: 'Multi-Status',
+ 208: 'Already Reported',
+ 226: 'IM Used',
+ 300: 'Multiple Choices',
+ 301: 'Moved Permanently',
+ 302: 'Found',
+ 303: 'See Other',
+ 304: 'Not Modified',
+ 305: 'Use Proxy',
+ 306: 'Switch Proxy',
+ 307: 'Temporary Redirect',
+ 308: 'Permanent Redirect',
+ 400: 'Bad Request',
+ 401: 'Unauthorized',
+ 402: 'Payment Required',
+ 403: 'Forbidden',
+ 404: 'Not Found',
+ 405: 'Method Not Allowed',
+ 406: 'Not Acceptable',
+ 407: 'Proxy Authentication Required',
+ 408: 'Request Timeout',
+ 409: 'Conflict',
+ 410: 'Gone',
+ 411: 'Length Required',
+ 412: 'Precondition Failed',
+ 413: 'Payload Too Large',
+ 414: 'URI Too Long',
+ 415: 'Unsupported Media Type',
+ 416: 'Range Not Satisfiable',
+ 417: 'Expectation Failed',
+ 418: 'I\'m a teapot',
+ 421: 'Misdirected Request',
+ 422: 'Unprocessable Entity',
+ 423: 'Locked',
+ 424: 'Failed Dependency',
+ 425: 'Too Early',
+ 426: 'Upgrade Required',
+ 428: 'Precondition Required',
+ 429: 'Too Many Requests',
+ 431: 'Request Header Fields Too Large',
+ 451: 'Unavailable For Legal Reasons',
+ 500: 'Internal Server Error',
+ 501: 'Not Implemented',
+ 502: 'Bad Gateway',
+ 503: 'Service Unavailable',
+ 504: 'Gateway Timeout',
+ 505: 'HTTP Version Not Supported',
+ 506: 'Variant Also Negotiates',
+ 507: 'Insufficient Storage',
+ 508: 'Loop Detected',
+ 510: 'Not Extended',
+ 511: 'Network Authentication Required',
+}
+
+STATUS_FULL = {status: f'{status} {text}' for status,text in STATUS_TEXT.items()}
+
+class Req:
+ __slots__ = ('path', 'query_string', 'method', 'headers', 'input')
+
+ def __init__(self, env):
+ self.path = env['PATH_INFO']
+ self.query_string = env.get('QUERY_STRING')
+ self.method = env['REQUEST_METHOD']
+ self.input = env['wsgi.input']
+ headers = self.headers = {}
+ for k,v in env.items():
+ if not k.startswith('HTTP_'):
+ continue
+ k = ''.join(map(str.title, k[5:].split('_')))
+ headers[k] = v
+
+ def read_data(self, /):
+ d = self.input
+ del self.input
+ return d.read()
+
+ def resp(self, /, status, headers, data, *, content_type):
+ assert type(status) is int
+ assert type(data) is bytes
+ assert type(content_type) is str
+ headers = [*headers]
+ assert all(type(k) is str and type(v) is str for k,v in headers)
+ hdr = headers.append
+
+ hdr(('Access-Control-Allow-Methods', 'GET, POST'))
+ hdr(('Access-Control-Allow-Headers', 'Accept, Content-Type'))
+
+ try:
+ hdr(('Access-Control-Allow-Origin', self.headers['Origin']))
+ except KeyError:
+ pass
+
+ if self.method != 'OPTIONS':
+ hdr(('Content-Type', content_type))
+ hdr(('Content-Length', f'{len(data)}'))
+
+ hdr(('Cache-Control', 'no-cache'))
+
+ return (status, headers, data)
+
+ def resp_json(self, /, data, *, status=200, content_type='application/json;charset=UTF-8', headers=()):
+ return self.resp(
+ status,
+ headers,
+ json.dumps(data).encode('utf-8'),
+ content_type=content_type,
+ )
+
+ def resp_error(self, status, /):
+ try:
+ text = STATUS_TEXT[status]
+ except KeyError:
+ text = f'HTTP {status}'
+ if message is not None:
+ text = f'{text}: {message}'
+ return self.resp(
+ status,
+ (),
+ text.encode('utf-8'),
+ content_type='text/plain',
+ )
+
+ def resp_404(self, /):
+ return self.resp_error(404, **kw)
+
+ def resp_405(self, /):
+ return self.resp_error(405, **kw)
+
+ def resp_biss_error(self, /, status, message):
+ return self.resp_json({
+ 'status': 'failed',
+ 'reasonCode': status,
+ 'reasonText': message,
+ }, status=status)
+
+ def resp_biss_ok(self, /, **kw):
+ return self.resp_json({
+ 'status': 'ok',
+ 'reasonCode': 200,
+ 'reasonText': 'response.success',
+ **kw,
+ })
+
+
+cr.X509StoreFlags.NO_CHECK_TIME = 0x200000
+store = cr.X509Store()
+store.set_flags(cr.X509StoreFlags.CRL_CHECK_ALL | cr.X509StoreFlags.X509_STRICT | cr.X509StoreFlags.NO_CHECK_TIME)
+store.load_locations('cacerts.pem')
+
+certs = pkcs11_list()
+
+def req_options(req):
+ return req.resp(200, (), b'', content_type='')
+
+def req_get_version(req):
+ return req.resp_json({
+ 'version': '2.30',
+ 'httpMethods': 'GET, POST',
+ 'contentTypes': 'data',
+ 'signatureTypes': 'signature',
+ 'selectorAvailable': True,
+ 'hashAlgorithms': 'SHA256, SHA384, SHA512',
+ })
+
+def req_get_status(req):
+ return req.resp_json('')
+
+def req_post_getsigner(req):
+ data = json.loads(req.read_data().decode('utf-8'))
+
+ valid_only = data.get('showValidCerts', False)
+ # WAT
+ if valid_only == 'true':
+ valid_only = True
+ if valid_only == 'false':
+ valid_only = False
+ if type(valid_only) is not bool:
+ return req.resp_biss_error(400, 'error.request.bad-type')
+
+ selector = data.get('selector', {})
+
+ for cert in certs.values():
+ if check_selector(cert, selector, valid_only):
+ break
+ else:
+ return req.resp_biss_error(403, 'error.user-canceled')
+
+ return req.resp_biss_ok(
+ chain = [
+ base64.b64encode(cr.dump_certificate(cr.FILETYPE_ASN1, cert)).decode('ascii'),
+ ],
+ )
+
+def req_post_sign(req):
+ data = json.loads(req.read_data().decode('utf-8'))
+
+ try:
+ own_cert_data = data['signerCertificateB64']
+ own_cert_data = base64.b64decode(own_cert_data)
+ except (KeyError, IndexError):
+ return req.resp_biss_error(400, 'error.wsp-cert-not-found')
+
+ for own_key,own_cert in certs.items():
+ if cr.dump_certificate(cr.FILETYPE_ASN1, own_cert) == own_cert_data:
+ break
+ else:
+ return req.resp_biss_error(400, 'error.request.bad-type')
+
+ # BISS ignores it
+ # if data['contentType'] != 'data':
+ # return req.resp_biss_error(400, 'error.request.bad-type')
+
+ try:
+ hash_alg = HASH_ALG[data.get('hashAlgorithm', 'SHA256')]
+ except KeyError:
+ return req.resp_biss_error(400, 'error.request.bad-type')
+
+ msgs = []
+ for i,(msg,sig,server_cert) in enumerate(zip(data['contents'], data['signedContents'], data['signedContentsCert'])):
+ msg = base64.b64decode(msg)
+ sig = base64.b64decode(sig)
+ server_cert = cr.load_certificate(cr.FILETYPE_ASN1, base64.b64decode(server_cert))
+
+ try:
+ # WAT?
+ cr.verify(server_cert, sig, hashlib.new(hash_alg.hashlib_name, msg).digest(), hash_alg.hashlib_name)
+ except cr.Error:
+ return req.resp_biss_error(400, 'error.request.signature-not-val')
+
+ verify_ctx = cr.X509StoreContext(store, server_cert, None)
+ try:
+ verify_ctx.verify_certificate()
+ except cr.X509StoreContextError as exc:
+ if not prompt(f'Server certificate invalid ({exc!s}); continue'):
+ return req.resp_biss_error(403, 'error.request.certificate-not-valid')
+
+ msgs.append((msg, server_cert))
+
+ t = data.get('confirmText')
+ if t is not None:
+ print_blob('Confirmation text', t)
+
+ t = data.get('additionalConfirmText')
+ if t is not None:
+ print_blob('Additional confirmation text', t)
+
+ for (msg,server_cert) in msgs:
+ print_blob(f'Message #{1+i}', msg)
+ print_blob('Server name', describe_cert(server_cert))
+
+ if not prompt('Sign'):
+ return req.resp_biss_error(403, 'error.user-canceled')
+
+ pin = getpass.getpass('PIN: ')
+
+ sigs = []
+ for (msg, server_cert) in msgs:
+ sigs.append(pkcs11_sign(own_key, msg, hash_alg, pin))
+
+ return req.resp_biss_ok(
+ signatures = [base64.b64encode(sig).decode('ascii') for sig in sigs],
+ )
+
+def req_post_chain_found(req):
+ # stub
+ data = json.loads(req.read_data().decode('utf-8'))
+ return req.resp_biss_ok(chainFound=True)
+
+
+def fix_methods(hnd):
+ assert type(hnd) is dict and all(type(k) is str for k in hnd)
+ hnd = hnd.copy()
+
+ if 'GET' in hnd and 'HEAD' not in hnd:
+ on_get = hnd['GET']
+ def on_head(req):
+ (status, headers, data) = on_get(req)
+ return (status, headers, None)
+ hnd['HEAD'] = on_head
+
+ hnd['OPTIONS'] = req_options
+
+ return hnd
+
+def resp_exc(env, resp, exc):
+ rdata = (
+ '<!doctype html><html><head><meta charset="utf-8" /><title>Internal Server Error</title></head><body>'
+ '<h1>Internal Server Error</h1>'
+ '</body></html>'
+ ).encode()
+ resp('500 Internal Server Error', [
+ ('Content-Type', 'text/html;charset=utf-8'),
+ ('Content-Length', f'{len(rdata)}'),
+ ])
+ return rdata
+
+
+handlers = {
+ '/version': {
+ 'GET': req_get_version,
+ },
+ '/status': {
+ 'GET': req_get_status,
+ },
+ '/getsigner': {
+ 'POST': req_post_getsigner,
+ },
+ '/sign': {
+ 'POST': req_post_sign,
+ },
+ '/chainFound': {
+ 'POST': req_post_chain_found,
+ },
+}
+
+handlers = {path: fix_methods(sub) for path,sub in handlers.items()}
+
+def application(env, resp):
+ try:
+ req = Req(env)
+ except Exception as exc:
+ logging.exception('Failed to parse request')
+ return resp_exc(env, resp, exc)
+
+ try:
+ h = handlers[req.path]
+ except KeyError:
+ h = Req.resp_404
+ else:
+ try:
+ h = h[req.method]
+ except KeyError:
+ h = Req.resp_405
+
+ try:
+ rcode, rhdr, rdata = h(req)
+ except Exception as exc:
+ logging.exception('Request handler failed')
+ return resp_exc(env, resp, exc)
+
+ assert type(rcode) is int
+
+ try:
+ rcode = STATUS_FULL[rcode]
+ except KeyError:
+ rcode = f'{rcode}'
+
+ resp(rcode, rhdr)
+ return rdata
+
+tty = os.open('/dev/tty', os.O_RDWR)
+os.dup2(tty, 0)
+os.dup2(tty, 1)