diff --git a/setup.py b/setup.py index f0c281fa..c2346251 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,10 @@ 'jwcrypto', 'redis', ], + extras_requires={ + 'libpass': ['libpass'], + 'bcrypt': ['bcrypt'] + }, zip_safe=False, entry_points={ 'console_scripts': [ diff --git a/test-requirements.txt b/test-requirements.txt index 4eeff976..b58679fd 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -3,3 +3,5 @@ nose2 six redis wrapt<=1.12.1;python_version<="3.4" +libpass;python_versions>="3.8" +bcrypt;python_versions>="3.8" diff --git a/tests/test_auth_plugins.py b/tests/test_auth_plugins.py index 4b3bfb5c..4dd86d1e 100644 --- a/tests/test_auth_plugins.py +++ b/tests/test_auth_plugins.py @@ -2,8 +2,9 @@ """ Unit tests for Authentication plugins""" -from websockify.auth_plugins import BasicHTTPAuth, AuthenticationError +from websockify.auth_plugins import BasicHTTPAuth, HtpasswdAuth, AuthenticationError import unittest +import tempfile class BasicHTTPAuthTestCase(unittest.TestCase): @@ -26,3 +27,46 @@ def test_valid_password(self): def test_garbage_auth(self): headers = {'Authorization': 'Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx'} self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') + +try: + import passlib + PASSLIB_AVAILABLE = True +except ImportError: + PASSLIB_AVAILABLE = False + +@unittest.skipUnless(PASSLIB_AVAILABLE, "passlib package is not available") +class HtpasswdAuthTestCase(unittest.TestCase): + + def setUp(self): + self._temporary_htpasswd_file = tempfile.NamedTemporaryFile(delete=False) + + #file generated with `htpasswd -c5i test_auth_plugins.htpasswd Genie <<<"""let's make some Magic!"""; htpasswd -Bi test_auth_plugins.htpasswd Aladdin <<<"""open sesame"""` + file_content = 'Genie:$6$5EsSBArrdAYDSe.j$v9mqxcSfPQgrM7btHx5wysZ28a1gei62rH75f8nYxwzPT80gbaL4qqxlkIBy.zSTnmG5VW2/RKFXQcGIgqAQq/\nAladdin:$2y$05$HK/O9w/55MSjM2FMefSIbeFKKANQbfR/hlYWk8RlDrR7Qyb5gnuzG' + + self._temporary_htpasswd_file.write(file_content.encode('utf-8')) + self._temporary_htpasswd_file.close() + + self.plugin = HtpasswdAuth(self._temporary_htpasswd_file.name) + + def test_no_auth(self): + headers = {} + self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') + + def test_invalid_password(self): + headers = {'Authorization': 'Basic QWxhZGRpbjpzZXNhbWUgc3RyZWV0'} + self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') + + def test_valid_password(self): + headers = {'Authorization': 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='} + self.plugin.authenticate(headers, 'localhost', '1234') + headers = {'Authorization': 'Basic R2VuaWU6bGV0J3MgbWFrZSBzb21lIE1hZ2ljIQ=='} + self.plugin.authenticate(headers, 'localhost', '1234') + + def test_garbage_auth(self): + headers = {'Authorization': 'Basic xxxxxxxxxxxxxxxxxxxxxxxxxxxx'} + self.assertRaises(AuthenticationError, self.plugin.authenticate, headers, 'localhost', '1234') + + def tearDown(self): + import os + os.remove(self._temporary_htpasswd_file.name) + diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py index 36fac520..4275323c 100644 --- a/websockify/auth_plugins.py +++ b/websockify/auth_plugins.py @@ -1,3 +1,12 @@ +import logging +logger = logging.getLogger(__name__) + +try: + from passlib.apache import HtpasswdFile +except ImportError as e: + HtpasswdFile: None + + class BasePlugin(): def __init__(self, src=None): self.source = src @@ -76,6 +85,29 @@ def demand_auth(self): raise AuthenticationError(response_code=401, response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) +class HtpasswdAuth(BasicHTTPAuth): + """Verifies Basic Auth headers against a htpasswd database. Specify src as the path to the htpasswd file""" + + def __init__(self, src=None): + self.src = src + if HtpasswdFile is None: + logging.error("Class ''HtpasswdFile' from libpass (passlib.apache), is not initialized, verify the availability of the module 'libpass'" ) + raise AuthenticationError(response_code=500, response_msg=f"Internal Server Error") + + def validate_creds(self, username, password): + if self.src == None: + return False + try: + htfile = HtpasswdFile(self.src, new=False, encoding="utf-8") + isvalid_hash = htfile.check_password(username, password) + if isvalid_hash == None: + logger.warning("'%s' user not found in database." % (username)) + return isvalid_hash + except (FileNotFoundError, PermissionError, OSError, ValueError) as e: + logging.error("%s: %s" % (type(e).__name__, e)) + raise AuthenticationError(response_code=500, response_msg=f"Internal Server Error") + return False + class ExpectOrigin(): def __init__(self, src=None): if src is None: