diff --git a/demos/custom_auth.py b/demos/custom_auth.py new file mode 100644 index 0000000..fd5c244 --- /dev/null +++ b/demos/custom_auth.py @@ -0,0 +1,21 @@ +import tornado.ioloop +from tornado.web import RequestHandler, Application +from tornado_http_auth import DigestAuthMixin, BasicAuthMixin, auth_required + +credentials = {"user1": "pass1"} + + +class MainHandler(BasicAuthMixin, RequestHandler): + @auth_required(realm="Protected", auth_func=credentials.get) + def get(self): + self.write("Hello %s" % self._current_user) + + +app = Application( + [ + (r"/", MainHandler), + ] +) + +app.listen(8888) +tornado.ioloop.IOLoop.current().start() diff --git a/demos/pam_auth.py b/demos/pam_auth.py new file mode 100644 index 0000000..00ebdf8 --- /dev/null +++ b/demos/pam_auth.py @@ -0,0 +1,21 @@ +import tornado.ioloop +from tornado.web import RequestHandler, Application +from tornado_http_auth import DigestAuthMixin, BasicAuthMixin, auth_required + +import pamela + + +class MainHandler(BasicAuthMixin, RequestHandler): + @auth_required(realm="Protected", auth_func=pamela.authenticate) + def get(self): + self.write("Hello %s" % self._current_user) + + +app = Application( + [ + (r"/", MainHandler), + ] +) + +app.listen(8888) +tornado.ioloop.IOLoop.current().start() diff --git a/demos/requirements.txt b/demos/requirements.txt new file mode 100644 index 0000000..77ab09d --- /dev/null +++ b/demos/requirements.txt @@ -0,0 +1 @@ +pamela diff --git a/tests/test_functional.py b/tests/test_functional.py index d03680b..449762b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -16,18 +16,28 @@ class BasicAuthHandler(BasicAuthMixin, RequestHandler): def get(self): self.write('Hello %s' % self._current_user) +class BasicAuthWithPwdHandler(BasicAuthMixin, RequestHandler): + @auth_required(realm='Protected', auth_func=lambda u, p: p == credentials[u]) + def get(self): + self.write('Hello %s' % self._current_user) class DigestAuthHandler(DigestAuthMixin, RequestHandler): @auth_required(realm='Protected', auth_func=credentials.get) def get(self): self.write('Hello %s' % self._current_user) +class DigestAuthWithPwdHandler(DigestAuthMixin, RequestHandler): + @auth_required(realm='Protected', auth_func=lambda u, p: p == credentials[u]) + def get(self): + self.write('Hello %s' % self._current_user) class AuthTest(AsyncHTTPTestCase): def get_app(self): urls = [ ('/digest', DigestAuthHandler), + ('/digestpw', DigestAuthWithPwdHandler), ('/basic', BasicAuthHandler), + ('/basicpw', BasicAuthWithPwdHandler), ] return Application(urls, http_client=self.http_client) @@ -36,6 +46,11 @@ def test_digest_auth(self): self.assertEqual(res.code, 401) # TODO: Add digest authentication to HTTPClient in order to test this. + def test_digest_auth_with_pw(self): + res = self.fetch('/digestpw') + self.assertEqual(res.code, 401) + # TODO: Add digest authentication to HTTPClient in order to test this. + def test_basic_auth(self): res = self.fetch('/basic') self.assertEqual(res.code, 401) @@ -46,3 +61,12 @@ def test_basic_auth(self): res = self.fetch('/basic', headers=hdr) self.assertEqual(res.code, 200) + def test_basic_auth_with_pw(self): + res = self.fetch('/basicpw') + self.assertEqual(res.code, 401) + + auth = '%s:%s' % ('user1', 'pass1') + auth = b64encode(auth.encode('ascii')) + hdr = {'Authorization': 'Basic %s' % auth.decode('utf8')} + res = self.fetch('/basicpw', headers=hdr) + self.assertEqual(res.code, 200) diff --git a/tornado_http_auth.py b/tornado_http_auth.py index 2c6efbb..aefb298 100644 --- a/tornado_http_auth.py +++ b/tornado_http_auth.py @@ -1,12 +1,13 @@ # Mostly everything in this module is inspired or outright copied from Twisted's cred. # https://github.com/twisted/twisted/blob/trunk/src/twisted/cred/credentials.py - +import inspect import os, re, sys import time import binascii import base64 import hashlib +import pamela from tornado.auth import AuthError @@ -212,16 +213,34 @@ def authenticate_user(self, check_credentials_func, realm): auth_data = base64.b64decode(auth_data).decode('ascii') username, password = auth_data.split(':', 1) - challenge = check_credentials_func(username) - if not challenge: - raise self.SendChallenge() + # Here we keep backward compatibility with: + # username -> password + # behavior, while also providing a + # username, password -> Bool + # behavior for stricter authentication systems (PAM for instance) + loggedin = False + + try: + challenge = check_credentials_func(username) + if not challenge: # no password -> wrong username -> try again (too nice to be safe ?) + raise self.SendChallenge() + + loggedin = challenge == password - if challenge == password: + except (TypeError, pamela.PAMError): + # PAMError can occur because password is None by default, as pamela relies on PAM conversation interface... + + try: + # password required, expect boolean return + loggedin = check_credentials_func(username, password) + + except Exception: # any exception here means wrong username / password combination -> try again + raise self.SendChallenge() + + if loggedin: self._current_user = username - return True - else: - raise self.SendChallenge() - return False + + return loggedin def auth_required(realm, auth_func):