Format code with black

This commit is contained in:
Andreas Zweili 2023-12-29 19:59:25 +01:00
parent 1b64ae294a
commit 69ebfb1418
1 changed files with 40 additions and 38 deletions

View File

@ -10,47 +10,47 @@ from redis.exceptions import ConnectionError as RedisConnectionError
from .utils import strtobool
NO_SSL = bool(strtobool(os.environ.get('NO_SSL', 'False')))
URL_PREFIX = os.environ.get('URL_PREFIX', None)
HOST_OVERRIDE = os.environ.get('HOST_OVERRIDE', None)
TOKEN_SEPARATOR = '~'
NO_SSL = bool(strtobool(os.environ.get("NO_SSL", "False")))
URL_PREFIX = os.environ.get("URL_PREFIX", None)
HOST_OVERRIDE = os.environ.get("HOST_OVERRIDE", None)
TOKEN_SEPARATOR = "~"
# Initialize Flask Application
app = Flask(__name__)
if os.environ.get('DEBUG'):
if os.environ.get("DEBUG"):
app.debug = True
app.secret_key = os.environ.get('SECRET_KEY', 'Secret Key')
app.config.update(
dict(STATIC_URL=os.environ.get('STATIC_URL', 'static')))
app.secret_key = os.environ.get("SECRET_KEY", "Secret Key")
app.config.update(dict(STATIC_URL=os.environ.get("STATIC_URL", "static")))
# Initialize Redis
if os.environ.get('MOCK_REDIS'):
if os.environ.get("MOCK_REDIS"):
from fakeredis import FakeStrictRedis
redis_client = FakeStrictRedis()
elif os.environ.get('REDIS_URL'):
redis_client = redis.StrictRedis.from_url(os.environ.get('REDIS_URL'))
elif os.environ.get("REDIS_URL"):
redis_client = redis.StrictRedis.from_url(os.environ.get("REDIS_URL"))
else:
redis_host = os.environ.get('REDIS_HOST', 'localhost')
redis_port = os.environ.get('REDIS_PORT', 6379)
redis_db = os.environ.get('SNAPPASS_REDIS_DB', 0)
redis_client = redis.StrictRedis(
host=redis_host, port=redis_port, db=redis_db)
REDIS_PREFIX = os.environ.get('REDIS_PREFIX', 'snappass')
redis_host = os.environ.get("REDIS_HOST", "localhost")
redis_port = os.environ.get("REDIS_PORT", 6379)
redis_db = os.environ.get("SNAPPASS_REDIS_DB", 0)
redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
REDIS_PREFIX = os.environ.get("REDIS_PREFIX", "snappass")
def check_redis_alive(fn):
def inner(*args, **kwargs):
try:
if fn.__name__ == 'main':
if fn.__name__ == "main":
redis_client.ping()
return fn(*args, **kwargs)
except RedisConnectionError as error:
print('Failed to connect to redis! %s' % error)
if fn.__name__ == 'main':
print("Failed to connect to redis! %s" % error)
if fn.__name__ == "main":
sys.exit(0)
else:
return abort(500)
return inner
@ -61,7 +61,7 @@ def encrypt(password):
"""
encryption_key = Fernet.generate_key()
fernet = Fernet(encryption_key)
encrypted_password = fernet.encrypt(password.encode('utf-8'))
encrypted_password = fernet.encrypt(password.encode("utf-8"))
return encrypted_password, encryption_key
@ -79,7 +79,7 @@ def parse_token(token):
storage_key = token_fragments[0]
try:
decryption_key = token_fragments[1].encode('utf-8')
decryption_key = token_fragments[1].encode("utf-8")
except IndexError:
decryption_key = None
@ -97,7 +97,7 @@ def set_password(password, ttl):
storage_key = REDIS_PREFIX + uuid.uuid4().hex
encrypted_password, encryption_key = encrypt(password)
redis_client.setex(storage_key, ttl, encrypted_password)
encryption_key = encryption_key.decode('utf-8')
encryption_key = encryption_key.decode("utf-8")
token = TOKEN_SEPARATOR.join([storage_key, encryption_key])
return token
@ -115,11 +115,10 @@ def get_password(token):
redis_client.delete(storage_key)
if password is not None:
if decryption_key is not None:
password = decrypt(password, decryption_key)
return password.decode('utf-8')
return password.decode("utf-8")
@check_redis_alive
@ -138,20 +137,23 @@ def clean_input():
Make sure we're not getting bad data from the front end,
format data to be machine readable
"""
if empty(request.form.get('password', '')):
if empty(request.form.get("password", "")):
abort(400)
if empty(request.form.get('ttl', None)):
request.form.set('ttl', 604800)
if not isinstance(request.form.get("ttl"), int):
abort(400)
if request.form.get('ttl') > 2419200:
abort(400, 'TTL must be less than 2419200 seconds (4 weeks)')
if empty(request.form.get("ttl", None)):
request.form.set("ttl", 604800)
if request.form['password']:
return request.form['ttl'], request.form['password']
if request.form.get("ttl") > 2419200:
abort(400, "TTL must be less than 2419200 seconds (4 weeks)")
if request.form["password"]:
return request.form["ttl"], request.form["password"]
@app.route('/', methods=['POST'])
@app.route("/", methods=["POST"])
def handle_password():
if request.is_json:
request.form = request.get_json()
@ -160,13 +162,13 @@ def handle_password():
return jsonify(key=token)
@app.route('/get-secret', methods=['POST'])
@app.route("/get-secret", methods=["POST"])
def show_password():
if request.is_json:
request.form = request.get_json()
if empty(request.form.get('key', '')):
if empty(request.form.get("key", "")):
abort(400)
password_key = request.form['key']
password_key = request.form["key"]
password = get_password(password_key)
if not password:
return abort(404)
@ -175,8 +177,8 @@ def show_password():
@check_redis_alive
def main():
app.run(host='0.0.0.0')
app.run(host="0.0.0.0")
if __name__ == '__main__':
if __name__ == "__main__":
main()