diff --git a/app/api/externalauth.py b/app/api/externalauth.py index 8a77c71a..5f9c564c 100644 --- a/app/api/externalauth.py +++ b/app/api/externalauth.py @@ -37,31 +37,3 @@ class ExternalTokenAuth(APIView): except Exception as e: return Response({'error': str(e)}) -# TODO: move to simple http server -# class TestExternalAuth(APIView): -# permission_classes = (permissions.AllowAny,) -# parser_classes = (parsers.JSONParser, parsers.FormParser,) - -# def post(self, request): -# print("YO!!!") -# if settings.EXTERNAL_AUTH_ENDPOINT == '': -# return Response({'message': 'Disabled'}) - -# username = request.data.get("username") -# password = request.data.get("password") - -# print("HERE", username) - -# if username == "extuser1" and password == "test1234": -# return Response({ -# 'user_id': 100, -# 'username': 'extuser1', -# 'maxQuota': 500, -# 'token': 'test', -# 'node': { -# 'hostname': 'localhost', -# 'port': 4444 -# } -# }) -# else: -# return Response({'message': "Invalid credentials"}) \ No newline at end of file diff --git a/app/auth/backends.py b/app/auth/backends.py index 22c7a65f..c3c13ea1 100644 --- a/app/auth/backends.py +++ b/app/auth/backends.py @@ -2,7 +2,7 @@ import requests from django.contrib.auth.backends import ModelBackend from django.contrib.auth.models import User from nodeodm.models import ProcessingNode -from webodm.settings import EXTERNAL_AUTH_ENDPOINT +from webodm import settings from guardian.shortcuts import assign_perm import logging @@ -15,45 +15,48 @@ def get_user_from_external_auth_response(res): if 'user_id' in res and 'username' in res: try: user = User.objects.get(pk=res['user_id']) - - # Update user info - if user.username != res['username']: - user.username = res['username'] - user.save() - - # Update quotas - maxQuota = -1 - if 'maxQuota' in res: - maxQuota = res['maxQuota'] - if 'node' in res and 'limits' in res['node'] and 'maxQuota' in res['node']['limits']: - maxQuota = res['node']['limits']['maxQuota'] - - if user.profile.quota != maxQuota: - user.profile.quota = maxQuota - user.save() except User.DoesNotExist: - user = User(pk=res['user_id'], username=username) + user = User(pk=res['user_id'], username=res['username']) + user.save() + + # Update user info + if user.username != res['username']: + user.username = res['username'] + user.save() + + maxQuota = -1 + if 'maxQuota' in res: + maxQuota = res['maxQuota'] + if 'node' in res and 'limits' in res['node'] and 'maxQuota' in res['node']['limits']: + maxQuota = res['node']['limits']['maxQuota'] + + # Update quotas + if user.profile.quota != maxQuota: + user.profile.quota = maxQuota user.save() # Setup/update processing node - if ('api_key' in res or 'token' in res) and 'node' in res: + if 'node' in res and 'hostname' in res['node'] and 'port' in res['node']: hostname = res['node']['hostname'] port = res['node']['port'] - token = res['api_key'] if 'api_key' in res else res['token'] + token = res['node'].get('token', '') - try: - node = ProcessingNode.objects.get(token=token) - if node.hostname != hostname or node.port != port: - node.hostname = hostname - node.port = port + # Only add/update if a token is provided, since we use + # tokens as unique identifiers for hostname/port updates + if token != "": + try: + node = ProcessingNode.objects.get(token=token) + if node.hostname != hostname or node.port != port: + node.hostname = hostname + node.port = port + node.save() + + except ProcessingNode.DoesNotExist: + node = ProcessingNode(hostname=hostname, port=port, token=token) node.save() - except ProcessingNode.DoesNotExist: - node = ProcessingNode(hostname=hostname, port=port, token=token) - node.save() - - if not user.has_perm('view_processingnode', node): - assign_perm('view_processingnode', user, node) + if not user.has_perm('view_processingnode', node): + assign_perm('view_processingnode', user, node) return user else: @@ -61,11 +64,11 @@ def get_user_from_external_auth_response(res): class ExternalBackend(ModelBackend): def authenticate(self, request, username=None, password=None): - if EXTERNAL_AUTH_ENDPOINT == "": + if settings.EXTERNAL_AUTH_ENDPOINT == "": return None try: - r = requests.post(EXTERNAL_AUTH_ENDPOINT, { + r = requests.post(settings.EXTERNAL_AUTH_ENDPOINT, { 'username': username, 'password': password }, headers={'Accept': 'application/json'}) @@ -76,7 +79,7 @@ class ExternalBackend(ModelBackend): return None def get_user(self, user_id): - if EXTERNAL_AUTH_ENDPOINT == "": + if settings.EXTERNAL_AUTH_ENDPOINT == "": return None try: diff --git a/app/tests/scripts/simple_auth_server.py b/app/tests/scripts/simple_auth_server.py new file mode 100644 index 00000000..690b8636 --- /dev/null +++ b/app/tests/scripts/simple_auth_server.py @@ -0,0 +1,97 @@ +import http.server +from http.server import SimpleHTTPRequestHandler +import socketserver +import sys +import threading +from time import sleep +import json + +class MyHandler(SimpleHTTPRequestHandler): + def do_GET(self): + self.send_response(200) + self.send_header('Content-type','text/html') + self.end_headers() + self.wfile.write(bytes("Simple auth server is running", encoding="utf-8")) + + + def send_error(self, code, error): + self.send_json(code, {"error": error}) + + def send_json(self, code, data): + response = bytes(json.dumps(data), encoding="utf-8") + + self.send_response(200) + self.send_header('Content-type','application/json') + self.send_header('Content-length', len(response)) + self.end_headers() + self.wfile.write(response) + + def do_POST(self): + if self.path == '/auth': + if not 'Content-Length' in self.headers: + self.send_error(403, "Missing form data") + return + + content_length = int(self.headers['Content-Length']) + post_data_str = self.rfile.read(content_length).decode("utf-8") + post_data = {} + for item in post_data_str.split('&'): + k,v = item.split('=') + post_data[k] = v + + username = post_data.get("username") + password = post_data.get("password") + + print("Login request for " + username) + + if username == "extuser1" and password == "test1234": + print("Granted") + self.send_json(200, { + 'user_id': 100, + 'username': 'extuser1', + 'maxQuota': 500, + 'node': { + 'hostname': 'localhost', + 'port': 4444, + 'token': 'test' + } + }) + else: + print("Unauthorized") + return self.send_error(401, "unauthorized") + else: + self.send_error(404, "not found") + +class WebServer(threading.Thread): + def __init__(self): + super().__init__() + self.host = "0.0.0.0" + self.port = int(sys.argv[1]) if len(sys.argv) >= 2 else 8080 + self.ws = socketserver.TCPServer((self.host, self.port), MyHandler) + + def run(self): + print("WebServer started at Port:", self.port) + self.ws.serve_forever() + + def shutdown(self): + # set the two flags needed to shutdown the HTTP server manually + # self.ws._BaseServer__is_shut_down.set() + # self.ws.__shutdown_request = True + + print('Shutting down server.') + # call it anyway, for good measure... + self.ws.shutdown() + print('Closing server.') + self.ws.server_close() + self.join() + +if __name__=='__main__': + webServer = WebServer() + webServer.start() + while True: + try: + sleep(0.5) + except KeyboardInterrupt: + print('Keyboard Interrupt sent.') + webServer.shutdown() + exit(0) \ No newline at end of file diff --git a/app/tests/test_external_auth.py b/app/tests/test_external_auth.py index 90dcda3f..2928bebd 100644 --- a/app/tests/test_external_auth.py +++ b/app/tests/test_external_auth.py @@ -1,8 +1,10 @@ from django.contrib.auth.models import User, Group +from nodeodm.models import ProcessingNode from rest_framework import status from rest_framework.test import APIClient from .classes import BootTestCase +from .utils import start_simple_auth_server from webodm import settings class TestAuth(BootTestCase): @@ -19,50 +21,28 @@ class TestAuth(BootTestCase): settings.EXTERNAL_AUTH_ENDPOINT = '' # Try to log-in - user = client.login(username='extuser1', password='test1234') - self.assertFalse(user) + ok = client.login(username='extuser1', password='test1234') + self.assertFalse(ok) # Enable - settings.EXTERNAL_AUTH_ENDPOINT = 'http://0.0.0.0:5555' + settings.EXTERNAL_AUTH_ENDPOINT = 'http://0.0.0.0:5555/auth' - # TODO: start simplehttp auth server - - user = client.login(username='extuser1', password='test1234') - # self.assertEqual(user.username, 'extuser1') - # self.assertEqual(user.id, 100) - - - # client.login(username="testuser", password="test1234") - - # user = User.objects.get(username="testuser") - - # # Cannot list profiles (not admin) - # res = client.get('/api/admin/profiles/') - # self.assertEqual(res.status_code, status.HTTP_403_FORBIDDEN) - - # res = client.get('/api/admin/profiles/%s/' % user.id) - # self.assertEqual(res.status_code, status.HTTP_403_FORBIDDEN) - - # # Cannot update quota deadlines - # res = client.post('/api/admin/profiles/%s/update_quota_deadline/' % user.id, data={'hours': 1}) - # self.assertEqual(res.status_code, status.HTTP_403_FORBIDDEN) - - # # Admin can - # client.login(username="testsuperuser", password="test1234") - - # res = client.get('/api/admin/profiles/') - # self.assertEqual(res.status_code, status.HTTP_200_OK) - # self.assertTrue(len(res.data) > 0) - - # res = client.get('/api/admin/profiles/%s/' % user.id) - # self.assertEqual(res.status_code, status.HTTP_200_OK) - # self.assertTrue('quota' in res.data) - # self.assertTrue('user' in res.data) - - # # User is the primary key (not profile id) - # self.assertEqual(res.data['user'], user.id) - - # # There should be no quota by default - # self.assertEqual(res.data['quota'], -1) - - \ No newline at end of file + with start_simple_auth_server(["5555"]): + ok = client.login(username='extuser1', password='invalid') + self.assertFalse(ok) + self.assertFalse(User.objects.filter(username="extuser1").exists()) + ok = client.login(username='extuser1', password='test1234') + self.assertTrue(ok) + user = User.objects.get(username="extuser1") + self.assertEqual(user.id, 100) + self.assertEqual(user.profile.quota, 500) + pnode = ProcessingNode.objects.get(token='test') + self.assertEqual(pnode.hostname, 'localhost') + self.assertEqual(pnode.port, 4444) + self.assertTrue(user.has_perm('view_processingnode', pnode)) + self.assertFalse(user.has_perm('delete_processingnode', pnode)) + self.assertFalse(user.has_perm('change_processingnode', pnode)) + + # Re-test login + ok = client.login(username='extuser1', password='test1234') + self.assertTrue(ok) diff --git a/app/tests/test_quota.py b/app/tests/test_quota.py new file mode 100644 index 00000000..0e43b33a --- /dev/null +++ b/app/tests/test_quota.py @@ -0,0 +1,58 @@ +from django.contrib.auth.models import User, Group +from rest_framework import status +from rest_framework.test import APIClient +from app.models import Task, Project +from .classes import BootTestCase + +class TestQuota(BootTestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_quota(self): + c = APIClient() + c.login(username="testuser", password="test1234") + + user = User.objects.get(username="testuser") + self.assertEqual(user.profile.quota, -1) + + # There should be no quota panel + res = c.get('/dashboard/', follow=True) + body = res.content.decode("utf-8") + + # There should be no quota panel + self.assertFalse('