Saddle/saddle.py

263 lines
6.7 KiB
Python

#!/usr/bin/env python3
"""
A file hosting service similar to Pomf and Uguu but without the public nature.
"""
import asyncio
import datetime
import logging
import os
import random
import re
import secrets
import shutil
import string
import time
import urllib.parse

import aiohttp_jinja2
import asyncpg
import jinja2
import requests
import uvloop
from aiohttp import web
from aiohttp_jinja2 import render_template

import buckler_aiohttp
import config
# Install uvloop as the asyncio event loop implementation before any loop
# is created.
uvloop.install()
# Route table populated by the @routes.* decorators below and attached to
# the application in init_app().
routes = web.RouteTableDef()
@routes.get('/', name='index')
@routes.post('/', name='index')
async def index(request):
    """Serve the upload form (GET) and process uploads (POST).

    A POST may carry any mix of ``files`` (multipart file fields), ``url``
    (a remote file to fetch) and ``text`` (saved as a text file, optionally
    named by ``text_filename``). When ``delete_this`` is set, ``delete_num``
    (1-59) and ``delete_type`` (minutes/hours/days/weeks) give the lifetime
    of the uploads. ``response_type`` selects the reply format: ``plain``
    (default), ``html`` or ``json``.

    Returns:
        The rendered form on GET; on POST the list of resulting URLs in the
        requested format, or a 400 response when the form is invalid.

    Note:
        ``locals()`` is passed to the templates, so the local names in this
        function (``urls``, ``files``, ``expiration_date``, ...) are part of
        the template context.
    """
    if request.method == 'GET':
        return render_template("index.html", request, locals())

    data = await request.post()
    rand_name = bool(data.get('rand_name'))
    response_type = data.get('response_type', 'plain')
    # Reject bad input before anything is written to disk so a failed
    # request cannot leave orphaned files behind.
    if response_type not in ('html', 'plain', 'json'):
        return web.Response(status=400, text='invalid response_type')

    if data.get('delete_this'):
        delete_num = data.get('delete_num', '')
        delete_type = data.get('delete_type', '')
        try:
            delete_num = int(delete_num)
        except (TypeError, ValueError):
            return web.Response(status=400, text='invalid delete_num')
        # Explicit checks instead of assert: asserts vanish under ``-O``.
        if not 1 <= delete_num <= 59:
            return web.Response(status=400, text='invalid delete_num')
        if delete_type not in ('minutes', 'hours', 'days', 'weeks'):
            return web.Response(status=400, text='invalid delete_type')
        delta = datetime.timedelta(**{delete_type: delete_num})
        # NOTE(review): naive local time, later compared against the
        # database's NOW() by the cleaner -- confirm the column type and
        # time zones agree.
        expiration_date = datetime.datetime.now() + delta
    else:
        expiration_date = None

    # The handle_* helpers do blocking disk and network I/O, so run them in
    # the default thread pool rather than stalling the event loop.
    loop = asyncio.get_running_loop()
    files = []
    for filefield in data.getall('files', []):
        if not filefield:
            continue
        files.append(await loop.run_in_executor(
            None, handle_filefield, filefield, rand_name))
    if data.get('url'):
        fetched = await loop.run_in_executor(
            None, handle_url, data.get('url'), rand_name)
        # handle_url() returns None when the download failed.
        if fetched is not None:
            files.append(fetched)
    if data.get('text'):
        files.append(await loop.run_in_executor(
            None, handle_text,
            data.get('text'), data.get('text_filename', ''), rand_name))

    files_insert = []
    if files:
        user_id = int(request.cookies.get('userid'))
        files_insert = [
            (user_id, stored[0], stored[1], expiration_date)
            for stored in files
        ]
        async with request.app['pool'].acquire() as conn:
            await conn.executemany(
                "INSERT INTO upload (user_id, id, filename, expiration_date) "
                "VALUES ($1, $2, $3, $4)",
                files_insert)

    urls = [config.upload_url + f[1] for f in files]
    if response_type == 'html':
        return render_template("result.html", request, locals())
    if response_type == 'json':
        return web.json_response(urls)
    return web.Response(body='\n'.join(urls))
@routes.get('/gallery', name='gallery')
async def gallery(request):
    """Render ``gallery.html`` with the requesting user's uploads.

    The user is identified by the ``userid`` cookie; their rows in the
    ``upload`` table are fetched newest first.
    """
    # int() raises if the cookie is absent or not numeric.
    # NOTE(review): presumably the session middleware guarantees the cookie
    # exists -- confirm.
    user_id = int(request.cookies.get('userid'))
    async with request.app['pool'].acquire() as conn:
        uploads = await conn.fetch(
            "SELECT * FROM upload WHERE user_id = $1 ORDER BY upload_date DESC",
            user_id)
    # Not used in this function; it is bound so that locals() below exposes
    # it to the template alongside the other local names.
    upload_url = config.upload_url
    return render_template("gallery.html", request, locals())
def handle_filefield(filefield, rand_name=True):
    """Save an uploaded file field into ``config.upload_dir``.

    Args:
        filefield: Upload object exposing ``filename`` and a readable binary
            ``file`` attribute.
        rand_name: When true (or when the sanitized original name is empty)
            the stored name is the random prefix plus the original
            extension; otherwise it is ``<prefix>_<original name>``.

    Returns:
        A ``(prefix, filename)`` tuple: the random prefix and the name the
        file was stored under.
    """
    original_name = safe_filename(filefield.filename or '')
    prefix = get_rand_chars()
    if rand_name or not original_name:
        filename = prefix + os.path.splitext(original_name)[1]
    else:
        filename = prefix + '_' + original_name
    with open(os.path.join(config.upload_dir, filename), 'wb') as file:
        # Stream in chunks instead of read(): uploads can be large and
        # should not be held in memory in one piece.
        shutil.copyfileobj(filefield.file, file)
    return (prefix, filename)
def handle_url(url, rand_name=True):
    """Download ``url`` and save the result into ``config.upload_dir``.

    Args:
        url: Address of the remote file to fetch.
        rand_name: When true (or when no usable remote name is found) the
            stored name is the random prefix plus the remote extension;
            otherwise it is ``<prefix>_<remote name>``.

    Returns:
        A ``(prefix, filename)`` tuple, or ``None`` when the download failed
        (size or time limit exceeded, or a network/HTTP-level error).
    """
    try:
        filename, data = download_file(url)
    except (ValueError, requests.RequestException):
        # ValueError: limits enforced by download_file().
        # RequestException: connection, timeout and malformed-URL errors.
        return None
    original_name = safe_filename(filename)
    prefix = get_rand_chars()
    if rand_name or not original_name:
        filename = prefix + os.path.splitext(original_name)[1]
    else:
        filename = prefix + '_' + original_name
    with open(os.path.join(config.upload_dir, filename), 'wb') as file:
        file.write(data)
    return (prefix, filename)
def handle_text(text, filename, rand_name=True):
    """Save posted text as a ``.txt`` file in ``config.upload_dir``.

    Args:
        text: The text to store.
        filename: Name suggested by the user; may be empty.
        rand_name: When true (or when the sanitized name is empty) the
            stored name is the random prefix plus the suggested extension;
            otherwise it is ``<prefix>_<suggested name>``.

    Returns:
        A ``(prefix, filename)`` tuple: the random prefix and the name the
        file was stored under. The stored name always ends in ``.txt``.
    """
    original_name = safe_filename(filename)
    prefix = get_rand_chars()
    if rand_name or not original_name:
        filename = prefix + os.path.splitext(original_name)[1]
    else:
        filename = prefix + '_' + original_name
    # Guarantee a .txt suffix without doubling it ("notes.txt.txt").
    if not filename.lower().endswith('.txt'):
        filename += '.txt'
    # Explicit UTF-8 so non-ASCII text does not depend on the server locale;
    # newline='' writes the submitted line endings unchanged.
    path = os.path.join(config.upload_dir, filename)
    with open(path, 'w', encoding='utf-8', newline='') as file:
        file.write(text)
    return (prefix, filename)
def safe_filename(filename=''):
    """Return `filename` reduced to a conservative, filesystem-safe form.

    Only ASCII letters, digits, ``.``, ``_`` and spaces are kept; any of
    ``.``, ``_`` and space are then trimmed from both ends. The result may
    be an empty string.
    """
    allowed = frozenset(string.ascii_letters + string.digits + '._ ')
    kept = ''.join(filter(allowed.__contains__, filename))
    return kept.strip('._ ')
def get_rand_chars(n=8):
    """Return a string of `n` random ASCII letters and digits.

    The characters come from the ``secrets`` module rather than ``random``:
    the result is the unguessable part of an upload's URL, so it must not be
    predictable.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(n))
def download_file(url, timeout=10, max_file_size=config.client_max_size):
    """Download `url` while enforcing a size limit and a time limit.

    Args:
        url: Address to fetch.
        timeout: Seconds allowed for connecting/reading a chunk (passed to
            ``requests``) and also for the body transfer as a whole.
        max_file_size: Largest body, in bytes, that will be accepted.

    Returns:
        A ``(filename, data)`` tuple. ``filename`` comes from the
        ``Content-Disposition`` header when it carries one, otherwise from
        the last path segment of the URL (possibly empty); it is NOT
        sanitized. ``data`` is the body as ``bytes``.

    Raises:
        ValueError: The body is larger than `max_file_size` or the transfer
            took longer than `timeout`.
        requests.RequestException: Any connection-level failure.
    """
    # NOTE(review): this fetches a caller-supplied URL from the server side;
    # consider restricting internal/private addresses.
    requests_kwargs = {
        'stream': True,
        'headers': {'User-Agent': "Steelbea.me LTD needs YOUR files."},
        'timeout': timeout,
        'verify': True
    }
    chunks = []
    received = 0
    with requests.get(url, **requests_kwargs) as r:
        # Refuse early when the server already announces an oversized body.
        declared = r.headers.get('Content-Length', '')
        if declared.isdigit() and int(declared) > max_file_size:
            raise ValueError('response too large')
        # monotonic() is immune to wall-clock adjustments.
        start_time = time.monotonic()
        for chunk in r.iter_content(102400):
            if time.monotonic() - start_time > timeout:
                raise ValueError('timeout reached')
            # Count the chunk before keeping it so the final chunk is
            # subject to the limit as well.
            received += len(chunk)
            if received > max_file_size:
                raise ValueError('response too large')
            chunks.append(chunk)
        disposition = r.headers.get('Content-Disposition') or ''
    # Accept quoted and unquoted filename parameters; fall back to the URL
    # when the header is missing or carries no filename.
    match = re.search(r'filename="?([^";]+)"?', disposition, re.IGNORECASE)
    if match:
        fname = match.group(1)
    else:
        # urlsplit() separates the query/fragment properly, so a "/" inside
        # the query string cannot be mistaken for a path separator.
        fname = os.path.basename(urllib.parse.urlsplit(url).path)
    return (fname, b''.join(chunks))
async def cleaner(app):
    """Delete expired uploads from disk and from the database.

    Selects every ``upload`` row whose ``expiration_date`` is in the past,
    removes the corresponding file from ``config.upload_dir`` and then
    deletes the rows.
    """
    async with app['pool'].acquire() as conn:
        expired = await conn.fetch(
            "SELECT * FROM upload WHERE expiration_date < NOW()")
        if not expired:
            return
        for record in expired:
            try:
                os.remove(os.path.join(config.upload_dir, record['filename']))
            except FileNotFoundError:
                # Already gone; the row must still be deleted, otherwise it
                # would be selected again (and fail again) on every run.
                pass
        await conn.executemany(
            "DELETE FROM upload WHERE id = $1",
            [(record['id'],) for record in expired])
async def cleaner_loop(app):
    """Run cleaner() once a minute until the task is cancelled.

    A failing run is logged and retried on the next cycle instead of
    terminating the loop, so one bad run cannot stop cleanup for good.
    """
    log = logging.getLogger(__name__)
    try:
        while True:
            try:
                await cleaner(app)
            except asyncio.CancelledError:
                # Listed first: on Python 3.7 CancelledError is a subclass
                # of Exception and must not be swallowed below.
                raise
            except Exception:
                log.exception("cleaner run failed; retrying in 60 seconds")
            await asyncio.sleep(60)
    except asyncio.CancelledError:
        return
async def start_background_tasks(app):
    """Startup hook: launch the cleaner loop and keep its task on the app."""
    cleaner_task = asyncio.create_task(cleaner_loop(app))
    app['cleaner'] = cleaner_task
async def cleanup_background_tasks(app):
    """Cleanup hook: cancel the cleaner task and wait for it to finish."""
    task = app['cleaner']
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Raised when the task was cancelled before its coroutine could
        # handle the cancellation itself (e.g. it had not started yet);
        # that is the expected outcome here, not an error.
        pass
async def init_app():
    """Build and return the application.

    Creates the inner application (session middleware, Jinja2 templates
    loaded from ``templates``, database pool, background cleaner), executes
    ``saddle.sql`` against the database, and mounts the inner application
    under ``config.url_prefix`` in a wrapper application that enforces
    ``config.client_max_size``.

    Returns:
        The wrapper ``web.Application``.
    """
    async def close_pool(app):
        """Release the database connections on shutdown."""
        await app['pool'].close()

    app = web.Application(middlewares=[buckler_aiohttp.buckler_session])
    aiohttp_jinja2.setup(app, loader=jinja2.FileSystemLoader('templates'))
    app['pool'] = await asyncpg.create_pool(**config.db)
    app.on_startup.append(start_background_tasks)
    # Cleanup hooks run in registration order: stop the cleaner (which uses
    # the pool) first, then close the pool.
    app.on_cleanup.append(cleanup_background_tasks)
    app.on_cleanup.append(close_pool)
    async with app['pool'].acquire() as conn:
        with open('saddle.sql', 'r') as file:
            await conn.execute(file.read())
    app.router.add_routes(routes)
    app_wrap = web.Application(client_max_size=config.client_max_size)
    app_wrap.add_subapp(config.url_prefix, app)
    return app_wrap
if __name__ == "__main__":
    # run_app() accepts the init_app() coroutine and awaits it itself;
    # listen on every interface, port 5000.
    web.run_app(init_app(), host='0.0.0.0', port=5000)