diff --git a/fileHost.py b/fileHost.py index cd3e091..fc98741 100755 --- a/fileHost.py +++ b/fileHost.py @@ -4,9 +4,11 @@ Simple file host using Flask. """ import os import time +import atexit import string import secrets import sqlite3 +import functools import threading from datetime import datetime @@ -64,7 +66,7 @@ class CronThread(threading.Thread): for filename, delete_date in records: if time.time() >= delete_date: delete_file(filename) - time.sleep(60) + self.stop.wait(60) app = Flask(__name__) @@ -188,6 +190,31 @@ def delete_file(filename): return True +def login_required(url=None): + """ + A decorator function to protect certain endpoints by requiring the user + to either pass a valid session cookie, or pass thier username and + password along with the request to login. + """ + def actual_decorator(func): + @functools.wraps(func) + def _nop(*args, **kwargs): + username = session.get("username") + if verify_username(username): + return func(*args, **kwargs) + + username = request.form.get("user") + password = request.form.get("pass") + if verify_password(username, password): + return func(*args, **kwargs) + + if url: + return redirect(url_for(url)) + else: + abort(401) + return _nop + return actual_decorator + @app.route("/delete_file", methods=["POST"]) def deleteFile(): @@ -236,7 +263,7 @@ def addUser(): return "Username already exists." -@app.route("/logout", methods=["POST", "GET"]) +@app.route("/logout", methods=["GET"]) def logout(): """ Logs the user out and removes his session cookie. @@ -246,17 +273,15 @@ def logout(): @app.route("/change_password", methods=["POST", "GET"]) +@login_required() def change_password(): """ Allows the user to change their password. """ - username = session.get("username") - if not verify_username(username): - abort(401) - if request.method == "GET": return render_template("change_password.html") + username = session.get("username") current_password = request.form.get("current_password") new_password = request.form.get("new_password") new_password_verify = request.form.get("new_password_verify") @@ -293,14 +318,12 @@ def login(): @app.route("/manage_uploads", methods=["POST", "GET"]) +@login_required() def manage_uploads(): """ Allows the user to view and/or delete uploads they've made. """ username = session.get("username") - if not verify_username(username): - abort(401) - if request.method == "GET": uploads = db_execute( "SELECT filename, uploaded_date FROM uploads WHERE uploaded_by = ?", @@ -345,26 +368,17 @@ def gallery(username): @app.route("/", methods=["POST", "GET"]) +@login_required("login") def index(): """ Saves the uploaded file and returns a URL pointing to it. """ - if not session.get("username"): - if request.method == "GET": - return redirect(url_for("login")) - - username = request.form.get("user") - password = request.form.get("pass") - if not verify_password(username, password): - abort(401) - else: - username = session.get("username") - if not verify_username(username): - abort(401) - if request.method == "GET": return render_template("index.html") + username = session.get("username") + if not username: + username = request.form.get("user") urls = [] for file in request.files.getlist("file"): fname = secure_filename(file.filename) @@ -415,6 +429,7 @@ def get_rand_chars(n): init() +atexit.register(app.config["CRON_THREAD"].stop.set) if __name__ == "__main__": import sys if len(sys.argv) > 1: