added @login_required() decorator for easier security

This commit is contained in:
iou1name 2018-06-07 12:24:53 -04:00
parent d3c744e072
commit 671e7a107d

View File

@ -4,9 +4,11 @@ Simple file host using Flask.
""" """
import os import os
import time import time
import atexit
import string import string
import secrets import secrets
import sqlite3 import sqlite3
import functools
import threading import threading
from datetime import datetime from datetime import datetime
@ -64,7 +66,7 @@ class CronThread(threading.Thread):
for filename, delete_date in records: for filename, delete_date in records:
if time.time() >= delete_date: if time.time() >= delete_date:
delete_file(filename) delete_file(filename)
time.sleep(60) self.stop.wait(60)
app = Flask(__name__) app = Flask(__name__)
@ -188,6 +190,31 @@ def delete_file(filename):
return True return True
def login_required(url=None):
"""
A decorator function to protect certain endpoints by requiring the user
to either pass a valid session cookie, or pass thier username and
password along with the request to login.
"""
def actual_decorator(func):
@functools.wraps(func)
def _nop(*args, **kwargs):
username = session.get("username")
if verify_username(username):
return func(*args, **kwargs)
username = request.form.get("user")
password = request.form.get("pass")
if verify_password(username, password):
return func(*args, **kwargs)
if url:
return redirect(url_for(url))
else:
abort(401)
return _nop
return actual_decorator
@app.route("/delete_file", methods=["POST"]) @app.route("/delete_file", methods=["POST"])
def deleteFile(): def deleteFile():
@ -236,7 +263,7 @@ def addUser():
return "Username already exists." return "Username already exists."
@app.route("/logout", methods=["POST", "GET"]) @app.route("/logout", methods=["GET"])
def logout(): def logout():
""" """
Logs the user out and removes his session cookie. Logs the user out and removes his session cookie.
@ -246,17 +273,15 @@ def logout():
@app.route("/change_password", methods=["POST", "GET"]) @app.route("/change_password", methods=["POST", "GET"])
@login_required()
def change_password(): def change_password():
""" """
Allows the user to change their password. Allows the user to change their password.
""" """
username = session.get("username")
if not verify_username(username):
abort(401)
if request.method == "GET": if request.method == "GET":
return render_template("change_password.html") return render_template("change_password.html")
username = session.get("username")
current_password = request.form.get("current_password") current_password = request.form.get("current_password")
new_password = request.form.get("new_password") new_password = request.form.get("new_password")
new_password_verify = request.form.get("new_password_verify") new_password_verify = request.form.get("new_password_verify")
@ -293,14 +318,12 @@ def login():
@app.route("/manage_uploads", methods=["POST", "GET"]) @app.route("/manage_uploads", methods=["POST", "GET"])
@login_required()
def manage_uploads(): def manage_uploads():
""" """
Allows the user to view and/or delete uploads they've made. Allows the user to view and/or delete uploads they've made.
""" """
username = session.get("username") username = session.get("username")
if not verify_username(username):
abort(401)
if request.method == "GET": if request.method == "GET":
uploads = db_execute( uploads = db_execute(
"SELECT filename, uploaded_date FROM uploads WHERE uploaded_by = ?", "SELECT filename, uploaded_date FROM uploads WHERE uploaded_by = ?",
@ -345,26 +368,17 @@ def gallery(username):
@app.route("/", methods=["POST", "GET"]) @app.route("/", methods=["POST", "GET"])
@login_required("login")
def index(): def index():
""" """
Saves the uploaded file and returns a URL pointing to it. Saves the uploaded file and returns a URL pointing to it.
""" """
if not session.get("username"):
if request.method == "GET":
return redirect(url_for("login"))
username = request.form.get("user")
password = request.form.get("pass")
if not verify_password(username, password):
abort(401)
else:
username = session.get("username")
if not verify_username(username):
abort(401)
if request.method == "GET": if request.method == "GET":
return render_template("index.html") return render_template("index.html")
username = session.get("username")
if not username:
username = request.form.get("user")
urls = [] urls = []
for file in request.files.getlist("file"): for file in request.files.getlist("file"):
fname = secure_filename(file.filename) fname = secure_filename(file.filename)
@ -415,6 +429,7 @@ def get_rand_chars(n):
init() init()
atexit.register(app.config["CRON_THREAD"].stop.set)
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
if len(sys.argv) > 1: if len(sys.argv) > 1: