Format code with black
This commit is contained in:
parent
1b64ae294a
commit
69ebfb1418
78
src/main.py
78
src/main.py
|
@ -10,47 +10,47 @@ from redis.exceptions import ConnectionError as RedisConnectionError
|
|||
from .utils import strtobool
|
||||
|
||||
|
||||
NO_SSL = bool(strtobool(os.environ.get('NO_SSL', 'False')))
|
||||
URL_PREFIX = os.environ.get('URL_PREFIX', None)
|
||||
HOST_OVERRIDE = os.environ.get('HOST_OVERRIDE', None)
|
||||
TOKEN_SEPARATOR = '~'
|
||||
NO_SSL = bool(strtobool(os.environ.get("NO_SSL", "False")))
|
||||
URL_PREFIX = os.environ.get("URL_PREFIX", None)
|
||||
HOST_OVERRIDE = os.environ.get("HOST_OVERRIDE", None)
|
||||
TOKEN_SEPARATOR = "~"
|
||||
|
||||
|
||||
# Initialize Flask Application
|
||||
app = Flask(__name__)
|
||||
if os.environ.get('DEBUG'):
|
||||
if os.environ.get("DEBUG"):
|
||||
app.debug = True
|
||||
app.secret_key = os.environ.get('SECRET_KEY', 'Secret Key')
|
||||
app.config.update(
|
||||
dict(STATIC_URL=os.environ.get('STATIC_URL', 'static')))
|
||||
app.secret_key = os.environ.get("SECRET_KEY", "Secret Key")
|
||||
app.config.update(dict(STATIC_URL=os.environ.get("STATIC_URL", "static")))
|
||||
|
||||
# Initialize Redis
|
||||
if os.environ.get('MOCK_REDIS'):
|
||||
if os.environ.get("MOCK_REDIS"):
|
||||
from fakeredis import FakeStrictRedis
|
||||
|
||||
redis_client = FakeStrictRedis()
|
||||
elif os.environ.get('REDIS_URL'):
|
||||
redis_client = redis.StrictRedis.from_url(os.environ.get('REDIS_URL'))
|
||||
elif os.environ.get("REDIS_URL"):
|
||||
redis_client = redis.StrictRedis.from_url(os.environ.get("REDIS_URL"))
|
||||
else:
|
||||
redis_host = os.environ.get('REDIS_HOST', 'localhost')
|
||||
redis_port = os.environ.get('REDIS_PORT', 6379)
|
||||
redis_db = os.environ.get('SNAPPASS_REDIS_DB', 0)
|
||||
redis_client = redis.StrictRedis(
|
||||
host=redis_host, port=redis_port, db=redis_db)
|
||||
REDIS_PREFIX = os.environ.get('REDIS_PREFIX', 'snappass')
|
||||
redis_host = os.environ.get("REDIS_HOST", "localhost")
|
||||
redis_port = os.environ.get("REDIS_PORT", 6379)
|
||||
redis_db = os.environ.get("SNAPPASS_REDIS_DB", 0)
|
||||
redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
|
||||
REDIS_PREFIX = os.environ.get("REDIS_PREFIX", "snappass")
|
||||
|
||||
|
||||
def check_redis_alive(fn):
|
||||
def inner(*args, **kwargs):
|
||||
try:
|
||||
if fn.__name__ == 'main':
|
||||
if fn.__name__ == "main":
|
||||
redis_client.ping()
|
||||
return fn(*args, **kwargs)
|
||||
except RedisConnectionError as error:
|
||||
print('Failed to connect to redis! %s' % error)
|
||||
if fn.__name__ == 'main':
|
||||
print("Failed to connect to redis! %s" % error)
|
||||
if fn.__name__ == "main":
|
||||
sys.exit(0)
|
||||
else:
|
||||
return abort(500)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
|
@ -61,7 +61,7 @@ def encrypt(password):
|
|||
"""
|
||||
encryption_key = Fernet.generate_key()
|
||||
fernet = Fernet(encryption_key)
|
||||
encrypted_password = fernet.encrypt(password.encode('utf-8'))
|
||||
encrypted_password = fernet.encrypt(password.encode("utf-8"))
|
||||
return encrypted_password, encryption_key
|
||||
|
||||
|
||||
|
@ -79,7 +79,7 @@ def parse_token(token):
|
|||
storage_key = token_fragments[0]
|
||||
|
||||
try:
|
||||
decryption_key = token_fragments[1].encode('utf-8')
|
||||
decryption_key = token_fragments[1].encode("utf-8")
|
||||
except IndexError:
|
||||
decryption_key = None
|
||||
|
||||
|
@ -97,7 +97,7 @@ def set_password(password, ttl):
|
|||
storage_key = REDIS_PREFIX + uuid.uuid4().hex
|
||||
encrypted_password, encryption_key = encrypt(password)
|
||||
redis_client.setex(storage_key, ttl, encrypted_password)
|
||||
encryption_key = encryption_key.decode('utf-8')
|
||||
encryption_key = encryption_key.decode("utf-8")
|
||||
token = TOKEN_SEPARATOR.join([storage_key, encryption_key])
|
||||
return token
|
||||
|
||||
|
@ -115,11 +115,10 @@ def get_password(token):
|
|||
redis_client.delete(storage_key)
|
||||
|
||||
if password is not None:
|
||||
|
||||
if decryption_key is not None:
|
||||
password = decrypt(password, decryption_key)
|
||||
|
||||
return password.decode('utf-8')
|
||||
return password.decode("utf-8")
|
||||
|
||||
|
||||
@check_redis_alive
|
||||
|
@ -138,20 +137,23 @@ def clean_input():
|
|||
Make sure we're not getting bad data from the front end,
|
||||
format data to be machine readable
|
||||
"""
|
||||
if empty(request.form.get('password', '')):
|
||||
if empty(request.form.get("password", "")):
|
||||
abort(400)
|
||||
|
||||
if empty(request.form.get('ttl', None)):
|
||||
request.form.set('ttl', 604800)
|
||||
if not isinstance(request.form.get("ttl"), int):
|
||||
abort(400)
|
||||
|
||||
if request.form.get('ttl') > 2419200:
|
||||
abort(400, 'TTL must be less than 2419200 seconds (4 weeks)')
|
||||
if empty(request.form.get("ttl", None)):
|
||||
request.form.set("ttl", 604800)
|
||||
|
||||
if request.form['password']:
|
||||
return request.form['ttl'], request.form['password']
|
||||
if request.form.get("ttl") > 2419200:
|
||||
abort(400, "TTL must be less than 2419200 seconds (4 weeks)")
|
||||
|
||||
if request.form["password"]:
|
||||
return request.form["ttl"], request.form["password"]
|
||||
|
||||
|
||||
@app.route('/', methods=['POST'])
|
||||
@app.route("/", methods=["POST"])
|
||||
def handle_password():
|
||||
if request.is_json:
|
||||
request.form = request.get_json()
|
||||
|
@ -160,13 +162,13 @@ def handle_password():
|
|||
return jsonify(key=token)
|
||||
|
||||
|
||||
@app.route('/get-secret', methods=['POST'])
|
||||
@app.route("/get-secret", methods=["POST"])
|
||||
def show_password():
|
||||
if request.is_json:
|
||||
request.form = request.get_json()
|
||||
if empty(request.form.get('key', '')):
|
||||
if empty(request.form.get("key", "")):
|
||||
abort(400)
|
||||
password_key = request.form['key']
|
||||
password_key = request.form["key"]
|
||||
password = get_password(password_key)
|
||||
if not password:
|
||||
return abort(404)
|
||||
|
@ -175,8 +177,8 @@ def show_password():
|
|||
|
||||
@check_redis_alive
|
||||
def main():
|
||||
app.run(host='0.0.0.0')
|
||||
app.run(host="0.0.0.0")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue