added @login_required() decorator for easier security

This commit is contained in:
iou1name 2018-06-07 12:24:53 -04:00
parent d3c744e072
commit 671e7a107d

View File

@ -4,9 +4,11 @@ Simple file host using Flask.
"""
import os
import time
import atexit
import string
import secrets
import sqlite3
import functools
import threading
from datetime import datetime
@ -64,7 +66,7 @@ class CronThread(threading.Thread):
for filename, delete_date in records:
if time.time() >= delete_date:
delete_file(filename)
time.sleep(60)
self.stop.wait(60)
app = Flask(__name__)
@ -188,6 +190,31 @@ def delete_file(filename):
return True
def login_required(url=None):
"""
A decorator function to protect certain endpoints by requiring the user
to either pass a valid session cookie, or pass thier username and
password along with the request to login.
"""
def actual_decorator(func):
@functools.wraps(func)
def _nop(*args, **kwargs):
username = session.get("username")
if verify_username(username):
return func(*args, **kwargs)
username = request.form.get("user")
password = request.form.get("pass")
if verify_password(username, password):
return func(*args, **kwargs)
if url:
return redirect(url_for(url))
else:
abort(401)
return _nop
return actual_decorator
@app.route("/delete_file", methods=["POST"])
def deleteFile():
@ -236,7 +263,7 @@ def addUser():
return "Username already exists."
@app.route("/logout", methods=["POST", "GET"])
@app.route("/logout", methods=["GET"])
def logout():
"""
Logs the user out and removes his session cookie.
@ -246,17 +273,15 @@ def logout():
@app.route("/change_password", methods=["POST", "GET"])
@login_required()
def change_password():
"""
Allows the user to change their password.
"""
username = session.get("username")
if not verify_username(username):
abort(401)
if request.method == "GET":
return render_template("change_password.html")
username = session.get("username")
current_password = request.form.get("current_password")
new_password = request.form.get("new_password")
new_password_verify = request.form.get("new_password_verify")
@ -293,14 +318,12 @@ def login():
@app.route("/manage_uploads", methods=["POST", "GET"])
@login_required()
def manage_uploads():
"""
Allows the user to view and/or delete uploads they've made.
"""
username = session.get("username")
if not verify_username(username):
abort(401)
if request.method == "GET":
uploads = db_execute(
"SELECT filename, uploaded_date FROM uploads WHERE uploaded_by = ?",
@ -345,26 +368,17 @@ def gallery(username):
@app.route("/", methods=["POST", "GET"])
@login_required("login")
def index():
"""
Saves the uploaded file and returns a URL pointing to it.
"""
if not session.get("username"):
if request.method == "GET":
return redirect(url_for("login"))
username = request.form.get("user")
password = request.form.get("pass")
if not verify_password(username, password):
abort(401)
else:
username = session.get("username")
if not verify_username(username):
abort(401)
if request.method == "GET":
return render_template("index.html")
username = session.get("username")
if not username:
username = request.form.get("user")
urls = []
for file in request.files.getlist("file"):
fname = secure_filename(file.filename)
@ -415,6 +429,7 @@ def get_rand_chars(n):
init()
atexit.register(app.config["CRON_THREAD"].stop.set)
if __name__ == "__main__":
import sys
if len(sys.argv) > 1: