Juice/auth.py
2019-06-23 14:01:03 -04:00

306 lines
7.5 KiB
Python

#!/usr/bin/env python3
"""
Contains authentication methods for the app.
"""
import os
import re
import time
import sqlite3
import functools
from datetime import datetime
from fido2.client import ClientData
from fido2.server import Fido2Server, RelyingParty
from fido2.ctap2 import AttestationObject, AuthenticatorData, \
AttestedCredentialData
from fido2 import cbor
from flask import Blueprint, session, render_template, request, \
redirect, url_for, jsonify
from passlib.hash import argon2
from werkzeug.useragents import UserAgent
import db
import config
from tools import make_error
# Blueprint that every route in this module is registered on.
auth_views = Blueprint("auth_views", __name__)
# Relying party handed to the FIDO2 server below.
# NOTE(review): the two arguments look like the RP id (domain) and its
# display name -- confirm against the installed fido2 version.
rp = RelyingParty('steelbea.me', 'Juice')
# Shared server object used by the register/authenticate begin and
# complete endpoints in this module.
server = Fido2Server(rp)
def auth_required(func):
    """
    Decorator for views which should be protected by authentication.

    The wrapped view is only called when the session holds both a
    ``username`` and a ``token``, the user exists, and the token verifies
    against one of the user's unexpired token hashes from
    ``db.get_tokens``. The matching token is refreshed before the view
    runs. In every other case the client is redirected to the login view.
    """
    def _deny(*stale_keys):
        """Drop the given session keys and redirect to the login view."""
        for key in stale_keys:
            # A default keeps an already-missing key from raising KeyError.
            session.pop(key, None)
        return redirect(url_for('auth_views.login'))

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        username = session.get('username')
        token = session.get('token')
        if not username or not token:
            return _deny()

        user = db.get_user(username)
        if not user:
            # The session names a user that no longer exists.
            return _deny('username', 'token')

        tokens = db.get_tokens(user['id'])
        if not tokens:
            # Nothing stored can ever match, so the session token is stale.
            return _deny('token')

        for token_line in tokens:
            # Expired tokens are skipped without paying for a hash check.
            if int(time.time()) >= token_line['date_expired']:
                continue
            if argon2.verify(token, token_line['token_hash']):
                db.refresh_token(token_line['id'])
                return func(*args, **kwargs)

        # No unexpired token matched the one held in the session.
        return _deny('token')

    return wrapper
@auth_views.route('/api/register/begin', methods=['POST'])
def register_begin():
    """
    Starts the registration of a security key for the session's user.

    Responds with the CBOR-encoded registration data and keeps the
    ceremony state in the session under ``state``. Responds with
    ``'invalid'`` when the session has no username or the user is unknown.
    """
    if not config.registration_open:
        return "Registration is closed."

    username = session.get('username')
    if not username:
        return 'invalid'

    account = db.get_user(username)
    if not account:
        # The session refers to a user that does not exist; forget it.
        session.pop('username')
        return 'invalid'

    user_id = account['id']
    # Credentials the user already registered are passed along with the
    # new registration request.
    known_credentials = [
        AttestedCredentialData(row['credential'])
        for row in db.get_credentials(user_id)
    ]
    user_entity = {
        'id': str(user_id).encode('utf8'),
        'name': username,
        'displayName': username,
    }
    registration_data, state = server.register_begin(
        user_entity,
        known_credentials,
        user_verification='discouraged',
    )
    session['state'] = state
    return cbor.encode(registration_data)
@auth_views.route('/api/register/complete', methods=['POST'])
def register_complete():
    """
    Finishes the registration of a security key.

    Reads a CBOR request body carrying ``clientDataJSON``,
    ``attestationObject`` and ``security_key_nick``, completes the
    ceremony started by ``register_begin`` and stores the resulting
    credential for the session's user.

    Responds with ``ok=True`` as JSON on success, ``'invalid'`` when the
    session has no usable username, and a 400 error when the nick is
    missing or not 1-64 long or when no registration state is in the
    session.
    """
    if not config.registration_open:
        return "Registration is closed."

    username = session.get('username')
    if not username:
        return 'invalid'
    # The username is consumed here whether or not registration succeeds.
    session.pop('username')

    user = db.get_user(username)
    if not user:
        return 'invalid'
    user_id = user['id']

    data = cbor.decode(request.get_data())
    client_data = ClientData(data['clientDataJSON'])
    att_obj = AttestationObject(data['attestationObject'])

    # Explicit validation instead of ``assert``, which is stripped when
    # Python runs with -O. A missing or non-text nick is rejected too
    # rather than crashing on the lookup or on ``len()``.
    nick = data.get('security_key_nick')
    if not isinstance(nick, (str, bytes)) or not 1 <= len(nick) <= 64:
        return make_error(400, "security key nick too long/short")

    # Without a prior call to the begin endpoint there is no state to
    # complete; report that instead of raising KeyError.
    state = session.pop('state', None)
    if state is None:
        return make_error(400, "registration was not started")

    auth_data = server.register_complete(state, client_data, att_obj)
    db.set_credential(user_id, nick, auth_data.credential_data)
    return jsonify(ok=True)
@auth_views.route('/api/authenticate/begin', methods=['POST'])
def authenticate_begin():
    """
    Starts authentication for the username given in the CBOR request body.

    Remembers the username and the ceremony state in the session and
    responds with the CBOR-encoded authentication data. Responds with a
    404 error when the username is unknown.
    """
    payload = cbor.decode(request.get_data())
    username = payload.get('username')

    account = db.get_user(username)
    if not account:
        return make_error(404, "username not found")

    # The complete step reads the username back from the session.
    session['username'] = username

    registered = [
        AttestedCredentialData(row['credential'])
        for row in db.get_credentials(account['id'])
    ]
    auth_data, state = server.authenticate_begin(registered)
    session['state'] = state
    return cbor.encode(auth_data)
@auth_views.route('/api/authenticate/complete', methods=['POST'])
def authenticate_complete():
    """
    Finishes authentication and issues a login token.

    Reads a CBOR request body carrying ``credentialId``,
    ``clientDataJSON``, ``authenticatorData`` and ``signature`` and
    completes the ceremony started by ``authenticate_begin``. On success
    a random token is generated, its argon2 hash is stored through
    ``db.set_token`` and the raw token is placed in the session.

    Responds with ``ok=True`` as JSON on success, a 404 error when the
    session does not name a known user, and a 400 error when no
    authentication state is in the session.
    """
    username = session.get('username')
    user = db.get_user(username) if username else None
    if not user:
        # The key may be absent altogether (endpoint called without the
        # begin step), so pop with a default to avoid a KeyError.
        session.pop('username', None)
        return make_error(404, "username not found")
    user_id = user['id']

    credentials = [
        AttestedCredentialData(c['credential'])
        for c in db.get_credentials(user_id)
    ]

    data = cbor.decode(request.get_data())
    credential_id = data['credentialId']
    client_data = ClientData(data['clientDataJSON'])
    auth_data = AuthenticatorData(data['authenticatorData'])
    signature = data['signature']

    # Without a prior call to the begin endpoint there is no state to
    # complete; report that instead of raising KeyError.
    state = session.pop('state', None)
    if state is None:
        return make_error(400, "authentication was not started")

    server.authenticate_complete(
        state,
        credentials,
        credential_id,
        client_data,
        auth_data,
        signature
    )

    # Only the hash is persisted; the raw token lives in the session and
    # is checked against the stored hashes by ``auth_required``.
    token = os.urandom(32)
    token_hash = argon2.hash(token)
    user_agent = request.user_agent.string
    ip_address = request.headers.get("X-Real-Ip")
    db.set_token(user_id, user_agent, ip_address, token_hash)
    session['token'] = token
    return jsonify(ok=True)
@auth_views.route('/register', methods=['GET', 'POST'])
def register():
    """
    Registration page.

    GET renders the registration form. POST validates the ``username``
    (3-64 characters) and ``email`` (at most 100 characters) form fields,
    creates the user, stores ``username`` and ``user_id`` in the session
    and renders the page for registering a security key.

    Validation failures and duplicate values are answered with a 400
    error.
    """
    if not config.registration_open:
        return "Registration is closed."

    if request.method == 'GET':
        params = {
            'form_url': url_for('auth_views.register'),
            'url_prefix': config.url_prefix,
        }
        return render_template('register.html', **params)

    # A missing field is treated like a blank one so that ``len()`` never
    # sees None; a blank username then fails the length check below.
    username = request.form.get('username', '')
    email = request.form.get('email', '')

    # Explicit checks instead of ``assert``, which is stripped when Python
    # runs with -O.
    if not 3 <= len(username) <= 64:
        return make_error(400, "username too long/short")
    if len(email) > 100:
        # Same error style as the username check above.
        return make_error(400, "email too long")

    try:
        user_id = db.set_user(username, email)
    except sqlite3.IntegrityError as e:
        # The failing column is taken from the error text, which is
        # expected to mention it as ``user.<column>``.
        match = re.search(r'user\.(.*)', str(e))
        if match is None:
            # Not a constraint on the user table we can name; let the
            # original error surface instead of masking it.
            raise
        return make_error(400, f"{match.group(1)} already exists")

    session['username'] = username
    session['user_id'] = user_id

    params = {
        'url_prefix': config.url_prefix,
    }
    return render_template('register_key.html', **params)
@auth_views.route('/login')
def login():
    """
    Renders the login page, passing the configured URL prefix to it.
    """
    return render_template('login.html', url_prefix=config.url_prefix)
@auth_views.route('/manage')
@auth_required
def manage():
    """
    Page on which a user reviews their security keys and login tokens.
    """
    # NOTE: the template receives ``locals()``, so every local variable
    # name in this function doubles as a template context key.
    url_prefix = config.url_prefix
    username = session['username']
    user_id = db.get_user(username)['id']
    credentials = db.get_credentials(user_id)
    tokens = db.get_tokens(user_id)

    # Display-friendly copies of the token rows: parsed user agent and
    # timestamps rendered as YYYY-MM-DD (UTC).
    tokens_pretty = []
    for token in tokens:
        di = datetime.utcfromtimestamp(
            token['date_issued']).strftime('%Y-%m-%d')
        de = datetime.utcfromtimestamp(
            token['date_expired']).strftime('%Y-%m-%d')
        token_pretty = {
            'id': token['id'],
            'user_agent': UserAgent(token['user_agent']),
            'ip_address': token['ip_address'],
            'date_issued': di,
            'date_expired': de,
        }
        tokens_pretty.append(token_pretty)

    return render_template('manage.html', **locals())
@auth_views.route('/delete_key')
@auth_required
def delete_key():
    """
    Deletes the security key credential named by the ``key_id`` query
    parameter, provided it belongs to the logged-in user.
    """
    cred_id = request.args.get('key_id')
    owner_id = db.get_user(session['username'])['id']

    cred = db.get_credential(cred_id)
    # An unknown key and a key owned by someone else get the same answer.
    if not cred or cred['user_id'] != owner_id:
        return make_error(404, "security key not found")

    db.delete_credential(cred_id)
    return jsonify(ok=True)
@auth_views.route('/delete_token')
@auth_required
def delete_token():
    """
    Deletes the token named by the ``token_id`` query parameter, provided
    it belongs to the logged-in user.
    """
    token_id = request.args.get('token_id')
    owner_id = db.get_user(session['username'])['id']

    token = db.get_token(token_id)
    # An unknown token and a token owned by someone else get the same
    # answer.
    if not token or token['user_id'] != owner_id:
        return make_error(404, "token not found")

    db.delete_token(token_id)
    return jsonify(ok=True)