summaryrefslogblamecommitdiff
path: root/rebiss.py
blob: 51dbf301941bfeeaeef8f8d2d0b156fe796b8bc3 (plain) (tree)
1
2
3
4
5
6
7
8







                                       





                           
             

























































































































































































































                                                                                                                                     







                                      
                                   

                          
                                   





















                                                                                                                  
                       





                                                  
                          























                                                                 

                                                          





                                                              
                                                            











                                                                   

                                              

                 
                                            






                                                                   
                                                                       
                    
                                               





































                                                                                                                        
                                          


                                   
                                                






























































































                                                                                                              
# SPDX-License-Identifier: GPL-3.0-only

import base64
import logging
import getpass
import hashlib
import json
import os
import unicodedata
import time
import datetime

import OpenSSL.crypto as cr

import pkcs11


def load_ts(ts):
    """Parse an ASCII ``YYYYMMDDHHMMSSZ`` byte string into an aware UTC datetime.

    *ts* is decoded as ASCII and must match ``%Y%m%d%H%M%SZ`` exactly;
    anything else raises whatever ``time.strptime`` / ``datetime`` raise.
    """
    parsed = time.strptime(ts.decode('ascii'), '%Y%m%d%H%M%SZ')
    # The first six struct_time fields are year, month, day, hour, minute, second.
    year, month, day, hour, minute, second = parsed[:6]
    return datetime.datetime(
        year, month, day, hour, minute, second,
        tzinfo=datetime.timezone.utc,
    )

def check_selector(cert, selector, valid_only):
    """Tell whether *cert* satisfies the request's certificate *selector*.

    When *valid_only* is true the current UTC time must lie within the
    certificate's notBefore/notAfter window.  When *selector* carries an
    ``'issuers'`` entry, ``CN=<first issuer CN>`` (or ``CN=`` when the issuer
    has no CN component) must be one of its members.
    """
    if valid_only:
        not_before = load_ts(cert.get_notBefore())
        not_after = load_ts(cert.get_notAfter())
        now = datetime.datetime.now(datetime.timezone.utc)
        if now < not_before or now > not_after:
            return False

    allowed_issuers = selector.get('issuers')
    if allowed_issuers is None:
        # No issuer restriction requested.
        return True

    # Only the first CN component of the issuer name is considered.
    first_cn = next(
        (
            value.decode('utf-8', 'surrogateescape')
            for key, value in cert.get_issuer().get_components()
            if key == b'CN'
        ),
        '',
    )
    if f'CN={first_cn}' not in allowed_issuers:
        return False
    return True

def describe_name(val):
    """Render the components of *val* as ``KEY=value, KEY=value, ...``.

    *val* must offer ``get_components()`` yielding ``(bytes, bytes)`` pairs.
    Values are decoded as UTF-8 with ``surrogateescape`` and passed through
    ``escape_bad`` with ``,`` and ``=`` added to the escaped set.
    """
    rendered = []
    for key, raw in val.get_components():
        label = key.decode('ascii')
        text = escape_bad(raw.decode('utf-8', 'surrogateescape'), ',=')
        rendered.append(f'{label}={text}')
    return ', '.join(rendered)

def describe_cert(cert):
    """Return a four-line, newline-terminated summary of *cert*.

    The lines are the issuer name, the subject name and the start and end
    of the validity period, the latter two formatted as UTC timestamps.
    """
    not_before = load_ts(cert.get_notBefore())
    not_after = load_ts(cert.get_notAfter())
    issuer = describe_name(cert.get_issuer())
    subject = describe_name(cert.get_subject())
    summary = (
        f'Issuer: {issuer}',
        f'Subject: {subject}',
        f'Valid from: {not_before:%Y-%m-%d %H:%M:%S} UTC',
        f'Valid until: {not_after:%Y-%m-%d %H:%M:%S} UTC',
    )
    return ''.join(f'{line}\n' for line in summary)

def escape_bad(l, esc=''):
    r"""Return *l* with unprintable or ambiguous characters backslash-escaped.

    Escapes are applied per character, first match wins:

    * U+DC80..U+DCFF (bytes that failed to decode and were smuggled through
      with ``errors='surrogateescape'``) become ``\xNN`` of the original byte;
    * a backslash is doubled;
    * ASCII control characters, DEL, and ASCII characters listed in *esc*
      become ``\xNN``;
    * characters of Unicode category Cc, Cf or Cs, separators (category Z*)
      other than a plain space, and non-ASCII characters listed in *esc*
      become ``\uNNNN`` (or ``\UNNNNNNNN`` above the BMP).

    Every other character is copied through unchanged.
    """
    r = []
    for c in l:
        o = ord(c)
        # surrogateescape only ever yields U+DC80..U+DCFF; the upper bound is
        # 0xdcff so that U+DD00 is not rendered as a bogus three-digit "\x100".
        if 0xdc80 <= o <= 0xdcff:
            r.append(f'\\x{o - 0xdc00:02x}')
            continue
        if c == '\\':
            r.append('\\\\')
            continue
        if o < 32 or o == 127 or (c in esc and o < 128):
            r.append(f'\\x{o:02x}')
            continue
        cat = unicodedata.category(c)
        if cat in {'Cc', 'Cf', 'Cs'} or (cat.startswith('Z') and c != ' ') or c in esc:
            if o < 0x10000:
                r.append(f'\\u{o:04x}')
            else:
                r.append(f'\\U{o:08x}')
            continue
        r.append(c)
    return ''.join(r)

def print_blob(name, data, /):
    """Print *data* under a ``name (kind):`` heading, one ``  > `` line per row.

    ``str`` is shown as-is (kind ``text``); ``bytes`` is split into lines,
    decoded as UTF-8 with ``surrogateescape`` and run through ``escape_bad``
    (kind ``blob``); any other object is shown via ``repr`` (kind ``repr``).
    Only the exact types ``str`` and ``bytes`` are recognised, not subclasses.
    """
    data_type = type(data)
    if data_type is str:
        kind, text = 'text', data
    elif data_type is bytes:
        kind = 'blob'
        text = '\n'.join(
            escape_bad(raw.decode('utf-8', 'surrogateescape'))
            for raw in data.splitlines()
        )
    else:
        kind, text = 'repr', repr(data)

    print(f'{name} ({kind}):')
    for row in text.splitlines():
        print(f'  > {row!s}')

def prompt(txt, /, default=False):
    """Ask the yes/no question *txt* on stdin until a usable answer arrives.

    ``y``/``Y`` returns True, ``n``/``N`` returns False and an empty line
    returns *default*.  Any other answer repeats the same question.  The
    bracketed hint capitalises the choice that an empty line selects.
    """
    hint = 'Y/n' if default else 'y/N'
    while True:
        # Keep the answer in its own variable: reusing ``txt`` would replace
        # the question with the user's typo on the next round.
        answer = input(f'{txt} [{hint}]? ')
        if answer in {'y', 'Y'}:
            return True
        if answer in {'n', 'N'}:
            return False
        if not answer:
            return default

# Reason phrase for each HTTP status code this module knows how to name.
STATUS_TEXT = {
    # 2xx
    200: 'OK',
    201: 'Created',
    202: 'Accepted',
    203: 'Non-Authoritative Information',
    204: 'No Content',
    205: 'Reset Content',
    206: 'Partial Content',
    207: 'Multi-Status',
    208: 'Already Reported',
    226: 'IM Used',
    # 3xx
    300: 'Multiple Choices',
    301: 'Moved Permanently',
    302: 'Found',
    303: 'See Other',
    304: 'Not Modified',
    305: 'Use Proxy',
    306: 'Switch Proxy',
    307: 'Temporary Redirect',
    308: 'Permanent Redirect',
    # 4xx
    400: 'Bad Request',
    401: 'Unauthorized',
    402: 'Payment Required',
    403: 'Forbidden',
    404: 'Not Found',
    405: 'Method Not Allowed',
    406: 'Not Acceptable',
    407: 'Proxy Authentication Required',
    408: 'Request Timeout',
    409: 'Conflict',
    410: 'Gone',
    411: 'Length Required',
    412: 'Precondition Failed',
    413: 'Payload Too Large',
    414: 'URI Too Long',
    415: 'Unsupported Media Type',
    416: 'Range Not Satisfiable',
    417: 'Expectation Failed',
    418: "I'm a teapot",
    421: 'Misdirected Request',
    422: 'Unprocessable Entity',
    423: 'Locked',
    424: 'Failed Dependency',
    425: 'Too Early',
    426: 'Upgrade Required',
    428: 'Precondition Required',
    429: 'Too Many Requests',
    431: 'Request Header Fields Too Large',
    451: 'Unavailable For Legal Reasons',
    # 5xx
    500: 'Internal Server Error',
    501: 'Not Implemented',
    502: 'Bad Gateway',
    503: 'Service Unavailable',
    504: 'Gateway Timeout',
    505: 'HTTP Version Not Supported',
    506: 'Variant Also Negotiates',
    507: 'Insufficient Storage',
    508: 'Loop Detected',
    510: 'Not Extended',
    511: 'Network Authentication Required',
}

# Full status lines ("404 Not Found") keyed by the numeric code.
STATUS_FULL = dict(
    (code, f'{code} {reason}') for code, reason in STATUS_TEXT.items()
)

class Req:
    """One incoming WSGI request plus helpers for building its response.

    Every ``resp*`` helper returns a ``(status, headers, body)`` tuple where
    *status* is an int, *headers* a list of ``(str, str)`` pairs and *body*
    a ``bytes`` object.
    """

    __slots__ = ('path', 'query_string', 'method', 'headers', 'input')

    def __init__(self, env):
        """Extract path, query string, method, body stream and headers from *env*."""
        self.path = env['PATH_INFO']
        self.query_string = env.get('QUERY_STRING')
        self.method = env['REQUEST_METHOD']
        self.input = env['wsgi.input']
        # 'HTTP_FOO_BAR' is stored under 'FooBar': the parts are title-cased
        # and joined without a separator.
        # NOTE(review): multi-word names therefore lack the usual hyphen
        # ('UserAgent', not 'User-Agent') -- confirm this is intended.
        self.headers = {
            ''.join(part.title() for part in name[5:].split('_')): value
            for name, value in env.items()
            if name.startswith('HTTP_')
        }

    def read_data(self, /):
        """Read and return the whole request body; usable only once.

        The ``input`` attribute is deleted before reading, so a second call
        raises ``AttributeError``.
        """
        stream = self.input
        del self.input
        return stream.read()

    def resp(self, /, status, headers, data, *, content_type):
        """Build a response tuple, appending the CORS and caching headers.

        *headers* is copied, never mutated.  The request's ``Origin`` header,
        when present, is echoed back as ``Access-Control-Allow-Origin``.
        ``Content-Type`` and ``Content-Length`` are omitted for OPTIONS
        requests.
        """
        assert type(status) is int
        assert type(data) is bytes
        assert type(content_type) is str
        out = list(headers)
        assert all(type(k) is str and type(v) is str for k, v in out)

        out += [
            ('Access-Control-Allow-Methods', 'GET, POST'),
            ('Access-Control-Allow-Headers', 'Accept, Content-Type'),
        ]
        if 'Origin' in self.headers:
            out.append(('Access-Control-Allow-Origin', self.headers['Origin']))
        if self.method != 'OPTIONS':
            out += [
                ('Content-Type', content_type),
                ('Content-Length', f'{len(data)}'),
            ]
        out.append(('Cache-Control', 'no-cache'))

        return (status, out, data)

    def resp_json(self, /, data, *, status=200, content_type='application/json;charset=UTF-8', headers=()):
        """Serialise *data* as JSON (UTF-8) and wrap it with ``resp``."""
        payload = json.dumps(data).encode('utf-8')
        return self.resp(status, headers, payload, content_type=content_type)

    def resp_error(self, status, /):
        """Return a plain-text response whose body names the HTTP *status*."""
        reason = STATUS_TEXT.get(status, f'HTTP {status}')
        return self.resp(
            status,
            (),
            reason.encode('utf-8'),
            content_type='text/plain',
        )

    def resp_404(self, /):
        """Shorthand for ``resp_error(404)``."""
        return self.resp_error(404)

    def resp_405(self, /):
        """Shorthand for ``resp_error(405)``."""
        return self.resp_error(405)

    def resp_biss_error(self, /, status, message):
        """Return a JSON failure envelope; *status* is also the HTTP status."""
        envelope = {
            'status': 'failed',
            'reasonCode': status,
            'reasonText': message,
        }
        return self.resp_json(envelope, status=status)

    def resp_biss_ok(self, /, **kw):
        """Return a JSON success envelope extended (or overridden) by *kw*."""
        envelope = {
            'status': 'ok',
            'reasonCode': 200,
            'reasonText': 'response.success',
        }
        envelope.update(kw)
        return self.resp_json(envelope)


# Add a NO_CHECK_TIME constant to pyOpenSSL's X509StoreFlags so that it can be
# OR-ed into the flag set below.
# NOTE(review): 0x200000 presumably mirrors OpenSSL's X509_V_FLAG_NO_CHECK_TIME
# -- confirm against the OpenSSL headers in use.
cr.X509StoreFlags.NO_CHECK_TIME = 0x200000
# Trust store that req_post_sign uses to verify server certificates.
store = cr.X509Store()
store.set_flags(cr.X509StoreFlags.CRL_CHECK_ALL | cr.X509StoreFlags.X509_STRICT | cr.X509StoreFlags.NO_CHECK_TIME)
# Loaded at import time from a path relative to the working directory.
store.load_locations('cacerts.pem')

# Snapshot of whatever pkcs11.list() yields, taken once at import time.
keys = [*pkcs11.list()]

def req_options(req):
    """Handle OPTIONS: reply 200 with an empty body and an empty content type."""
    no_headers = ()
    return req.resp(200, no_headers, b'', content_type='')

def req_get_version(req):
    """Handle GET /version: report the advertised version and capabilities."""
    capabilities = {
        'version': '3.15',
        'httpMethods': 'GET, POST',
        'contentTypes': 'data',
        'signatureTypes': 'signature',
        'selectorAvailable': True,
        'hashAlgorithms': 'SHA256, SHA384, SHA512',
    }
    return req.resp_json(capabilities)

def req_get_status(req):
    """Handle GET /status: reply with an empty JSON string."""
    payload = ''
    return req.resp_json(payload)

def req_post_getsigner(req):
    """Handle POST /getsigner: return the first key matching the selector.

    The JSON body may carry ``showValidCerts`` (a bool, or the strings
    ``'true'`` / ``'false'``) and ``selector``.  Replies 400 when the flag
    has any other type, 403 when no key matches, and otherwise a success
    envelope whose ``chain`` holds the matching certificate, base64-encoded.
    """
    body = json.loads(req.read_data().decode('utf-8'))

    valid_only = body.get('showValidCerts', False)
    # WAT: the flag is also accepted spelled out as a string.
    if valid_only == 'true':
        valid_only = True
    elif valid_only == 'false':
        valid_only = False
    if type(valid_only) is not bool:
        return req.resp_biss_error(400, 'error.request.bad-type')

    selector = body.get('selector', {})

    chosen = next(
        (key for key in keys if check_selector(key.cert, selector, valid_only)),
        None,
    )
    if chosen is None:
        return req.resp_biss_error(403, 'error.user-canceled')

    encoded_cert = base64.b64encode(chosen.cert_data).decode('ascii')
    return req.resp_biss_ok(chain=[encoded_cert])

def req_post_sign(req):
    """Handle POST /sign: verify, display and sign the submitted messages.

    Steps, in order:

    1. Find the local key whose certificate equals ``signerCertificateB64``.
    2. Resolve ``hashAlgorithm`` (default ``SHA256``) through
       ``pkcs11.HASH_ALG``.
    3. For every ``(contents, signedContents, signedContentsCert)`` triple,
       check the server's signature and validate the server certificate
       against the module trust store; an invalid certificate may be
       accepted interactively.
    4. Show the confirmation texts and each message on the terminal, ask
       for confirmation and for the PIN.
    5. Sign every message with the selected key.

    Returns a ``(status, headers, body)`` tuple: a success envelope with
    the base64 ``signatures`` list, or a 400/403 failure envelope.
    """
    data = json.loads(req.read_data().decode('utf-8'))

    try:
        own_cert_data = data['signerCertificateB64']
        own_cert_data = base64.b64decode(own_cert_data)
    except (KeyError, IndexError):
        return req.resp_biss_error(400, 'error.wsp-cert-not-found')

    for own_key in keys:
        if own_key.cert_data == own_cert_data:
            break
    else:
        logging.error('Unknown certificate')
        return req.resp_biss_error(400, 'error.request.bad-type')

    # BISS ignores it
    # if data['contentType'] != 'data':
    #     return req.resp_biss_error(400, 'error.request.bad-type')

    try:
        hash_alg = pkcs11.HASH_ALG[data.get('hashAlgorithm', 'SHA256')]
    except KeyError:
        logging.error('Unknown hash algorithm')
        return req.resp_biss_error(400, 'error.request.bad-type')

    msgs = []
    for msg, sig, server_cert in zip(data['contents'], data['signedContents'], data['signedContentsCert']):
        msg = base64.b64decode(msg)
        sig = base64.b64decode(sig)
        server_cert = cr.load_certificate(cr.FILETYPE_ASN1, base64.b64decode(server_cert))

        try:
            # WAT? The signature is checked over the digest of the message,
            # and verify() hashes that digest once more.
            cr.verify(server_cert, sig, hashlib.new(hash_alg.hashlib_name, msg).digest(), hash_alg.hashlib_name)
        except cr.Error:
            return req.resp_biss_error(400, 'error.request.signature-not-val')

        verify_ctx = cr.X509StoreContext(store, server_cert, None)
        try:
            verify_ctx.verify_certificate()
        except cr.X509StoreContextError as exc:
            # The operator may choose to proceed with an untrusted server.
            if not prompt(f'Server certificate invalid ({exc!s}); continue'):
                return req.resp_biss_error(403, 'error.request.certificate-not-valid')

        msgs.append((msg, server_cert))

    t = data.get('confirmText')
    if t is not None:
        print_blob('Confirmation text', t)

    t = data.get('additionalConfirmText')
    if t is not None:
        print_blob('Additional confirmation text', t)

    # Number the messages from this loop's own counter; an index left over
    # from the verification loop would label every message with the last one.
    for number, (msg, server_cert) in enumerate(msgs, 1):
        print_blob(f'Message #{number}', msg)
        print_blob('Server name', describe_cert(server_cert))

    if not prompt('Sign'):
        return req.resp_biss_error(403, 'error.user-canceled')

    own_key.pin = getpass.getpass('PIN: ')

    sigs = [own_key.sign(msg, hash_alg) for msg, _server_cert in msgs]

    return req.resp_biss_ok(
        signatures=[base64.b64encode(sig).decode('ascii') for sig in sigs],
    )

def req_post_chain_found(req):
    """Handle POST /chainFound (stub): always report that a chain was found."""
    # The body is still read and parsed, so malformed JSON raises as before;
    # the parsed value itself is not used.
    json.loads(req.read_data().decode('utf-8'))
    return req.resp_biss_ok(chainFound=True)


def fix_methods(hnd):
    """Return a copy of the method table *hnd* completed with HEAD and OPTIONS.

    A HEAD handler is derived from GET (same status and headers, body
    replaced by ``None``) unless one is already present.  OPTIONS is always
    set to ``req_options``, replacing any existing entry.  *hnd* itself is
    left untouched.
    """
    assert type(hnd) is dict and all(type(k) is str for k in hnd)
    table = dict(hnd)

    needs_head = 'GET' in table and 'HEAD' not in table
    if needs_head:
        on_get = table['GET']

        def on_head(req):
            status, headers, _body = on_get(req)
            return (status, headers, None)

        table['HEAD'] = on_head

    table['OPTIONS'] = req_options
    return table

def resp_exc(env, resp, exc):
    """Start a generic 500 response via *resp* and return its HTML body.

    *env* and *exc* are accepted but not used; the page never includes
    details of the exception.
    """
    page = (
        '<!doctype html><html><head><meta charset="utf-8" /><title>Internal Server Error</title></head><body>'
        '<h1>Internal Server Error</h1>'
        '</body></html>'
    ).encode()
    response_headers = [
        ('Content-Type', 'text/html;charset=utf-8'),
        ('Content-Length', f'{len(page)}'),
    ]
    resp('500 Internal Server Error', response_headers)
    return page


# Routing table: request path -> HTTP method -> handler.  Each per-path table
# is passed through fix_methods() before being stored.
handlers = {
    path: fix_methods(by_method)
    for path, by_method in {
        '/version': {'GET': req_get_version},
        '/status': {'GET': req_get_status},
        '/getsigner': {'POST': req_post_getsigner},
        '/sign': {'POST': req_post_sign},
        '/chainFound': {'POST': req_post_chain_found},
    }.items()
}

def application(env, resp):
    """WSGI entry point: route the request and return the response body.

    Unknown paths are answered by ``Req.resp_404`` and unsupported methods
    by ``Req.resp_405``.  Any exception raised while parsing the request or
    running the handler is logged and turned into a generic 500 page.
    """
    try:
        req = Req(env)
    except Exception as exc:
        logging.exception('Failed to parse request')
        return resp_exc(env, resp, exc)

    by_method = handlers.get(req.path)
    if by_method is None:
        handler = Req.resp_404
    else:
        handler = by_method.get(req.method, Req.resp_405)

    try:
        status, headers, body = handler(req)
    except Exception as exc:
        logging.exception('Request handler failed')
        return resp_exc(env, resp, exc)

    assert type(status) is int

    # Codes without a known reason phrase are sent as the bare number.
    status_line = STATUS_FULL.get(status, f'{status}')

    resp(status_line, headers)
    return body

# Re-point stdin (fd 0) and stdout (fd 1) at /dev/tty at import time, so the
# input()/print() calls made by the handlers use that terminal.
# NOTE(review): presumably needed because the WSGI server owns the original
# stdio -- confirm.  The extra descriptor in ``tty`` is left open.
tty = os.open('/dev/tty', os.O_RDWR)
os.dup2(tty, 0)
os.dup2(tty, 1)