kopia lustrzana https://github.com/snarfed/bridgy-fed
improve domain validation for Web key ids, normalize to lower case
rodzic
0f19654eb2
commit
7f6cc61683
10
common.py
10
common.py
|
@ -20,7 +20,15 @@ from oauth_dropins.webutil.util import json_dumps, json_loads
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DOMAIN_RE = r'[^/:]+\.[^/:]+'
|
||||
# Regexp for a domain name with at least one dot, eg user.com.
#
# allow hostname chars (a-z, 0-9, -), allow arbitrary unicode (eg ☃.net), don't
# allow specific chars that we'll often see in webfinger, AP handles, etc. (@, :)
# https://stackoverflow.com/questions/10306690/what-is-a-regular-expression-which-will-match-a-valid-domain-name-without-a-subd
#
# both sides of the dot exclude the same set of chars, so that eg a stray ; after
# the TLD is rejected just like one before the dot.
#
# uses $ at end but not ^ at the beginning so that it can be used to match just
# part of a URL path segment, eg for /acct:user.com in webfinger.py.
#
# TODO: preprocess with domain2idna, then narrow this to just [a-z0-9-]
DOMAIN_RE = r'[^/:;@_?!\']+\.[^/:;@_?!\']+$'
|
||||
TLD_BLOCKLIST = ('7z', 'asp', 'aspx', 'gif', 'html', 'ico', 'jpg', 'jpeg', 'js',
|
||||
'json', 'php', 'png', 'rar', 'txt', 'yaml', 'yml', 'zip')
|
||||
|
||||
|
|
|
@ -21,12 +21,12 @@ class CommonTest(TestCase):
|
|||
def setUpClass(cls):
|
||||
with appengine_config.ndb_client.context():
|
||||
# do this in setUpClass since generating RSA keys is slow
|
||||
cls.user = cls.make_user('site')
|
||||
cls.user = cls.make_user('user.com')
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.request_context.push()
|
||||
g.user = Fake(id='site')
|
||||
g.user = Fake(id='user.com')
|
||||
|
||||
def tearDown(self):
|
||||
self.request_context.pop()
|
||||
|
@ -48,8 +48,8 @@ class CommonTest(TestCase):
|
|||
common.pretty_link('http://foo'))
|
||||
|
||||
self.assertEqual(
|
||||
'<a class="h-card u-author" href="/fake/site"><img src="" class="profile"> site</a>',
|
||||
common.pretty_link('https://site/'))
|
||||
'<a class="h-card u-author" href="/fake/user.com"><img src="" class="profile"> user.com</a>',
|
||||
common.pretty_link('https://user.com/'))
|
||||
|
||||
def test_redirect_wrap_empty(self):
|
||||
self.assertIsNone(common.redirect_wrap(None))
|
||||
|
|
|
@ -63,7 +63,7 @@ class RemoteFollowTest(TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.make_user('me')
|
||||
self.make_user('user.com')
|
||||
|
||||
def test_no_domain(self, _):
|
||||
got = self.client.post('/remote-follow?address=@foo@bar&protocol=web')
|
||||
|
@ -74,11 +74,11 @@ class RemoteFollowTest(TestCase):
|
|||
self.assertEqual(400, got.status_code)
|
||||
|
||||
def test_no_protocol(self, _):
|
||||
got = self.client.post('/remote-follow?address=@foo@bar&domain=me')
|
||||
got = self.client.post('/remote-follow?address=@foo@bar&domain=user.com')
|
||||
self.assertEqual(400, got.status_code)
|
||||
|
||||
def test_unknown_protocol(self, _):
|
||||
got = self.client.post('/remote-follow?address=@foo@bar&domain=me&protocol=foo')
|
||||
got = self.client.post('/remote-follow?address=@foo@bar&domain=user.com&protocol=foo')
|
||||
self.assertEqual(400, got.status_code)
|
||||
|
||||
def test_no_user(self, _):
|
||||
|
@ -87,9 +87,9 @@ class RemoteFollowTest(TestCase):
|
|||
|
||||
def test(self, mock_get):
|
||||
mock_get.return_value = WEBFINGER
|
||||
got = self.client.post('/remote-follow?address=@foo@bar&domain=me&protocol=web')
|
||||
got = self.client.post('/remote-follow?address=@foo@bar&domain=user.com&protocol=web')
|
||||
self.assertEqual(302, got.status_code)
|
||||
self.assertEqual('https://bar/follow?uri=@me@me',
|
||||
self.assertEqual('https://bar/follow?uri=@user.com@user.com',
|
||||
got.headers['Location'])
|
||||
|
||||
mock_get.assert_has_calls((
|
||||
|
@ -98,9 +98,9 @@ class RemoteFollowTest(TestCase):
|
|||
|
||||
def test_url(self, mock_get):
|
||||
mock_get.return_value = WEBFINGER
|
||||
got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web')
|
||||
got = self.client.post('/remote-follow?address=https://bar/foo&domain=user.com&protocol=web')
|
||||
self.assertEqual(302, got.status_code)
|
||||
self.assertEqual('https://bar/follow?uri=@me@me', got.headers['Location'])
|
||||
self.assertEqual('https://bar/follow?uri=@user.com@user.com', got.headers['Location'])
|
||||
|
||||
mock_get.assert_has_calls((
|
||||
self.req('https://bar/.well-known/webfinger?resource=https://bar/foo'),
|
||||
|
@ -112,23 +112,23 @@ class RemoteFollowTest(TestCase):
|
|||
'links': [{'rel': 'other', 'template': 'meh'}],
|
||||
})
|
||||
|
||||
got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web')
|
||||
got = self.client.post('/remote-follow?address=https://bar/foo&domain=user.com&protocol=web')
|
||||
self.assertEqual(302, got.status_code)
|
||||
self.assertEqual('/web/me', got.headers['Location'])
|
||||
self.assertEqual('/web/user.com', got.headers['Location'])
|
||||
|
||||
def test_webfinger_error(self, mock_get):
|
||||
mock_get.return_value = requests_response(status=500)
|
||||
|
||||
got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web')
|
||||
got = self.client.post('/remote-follow?address=https://bar/foo&domain=user.com&protocol=web')
|
||||
self.assertEqual(302, got.status_code)
|
||||
self.assertEqual('/web/me', got.headers['Location'])
|
||||
self.assertEqual('/web/user.com', got.headers['Location'])
|
||||
|
||||
def test_webfinger_returns_not_json(self, mock_get):
|
||||
mock_get.return_value = requests_response('<html>not json</html>')
|
||||
|
||||
got = self.client.post('/remote-follow?address=https://bar/foo&domain=me&protocol=web')
|
||||
got = self.client.post('/remote-follow?address=https://bar/foo&domain=user.com&protocol=web')
|
||||
self.assertEqual(302, got.status_code)
|
||||
self.assertEqual('/web/me', got.headers['Location'])
|
||||
self.assertEqual('/web/user.com', got.headers['Location'])
|
||||
|
||||
|
||||
@patch('requests.post')
|
||||
|
|
|
@ -407,6 +407,30 @@ class WebTest(TestCase):
|
|||
def assert_object(self, id, **props):
    """Calls the superclass's assert_object with delivered_protocol='activitypub'.

    All other props are passed through unchanged, and the superclass's
    result is returned.
    """
    check = super().assert_object
    return check(id, delivered_protocol='activitypub', **props)
|
||||
|
||||
def test_put_validates_domain_id(self, *_):
    """Web.put() should raise AssertionError for each of these invalid ids."""
    invalid_ids = [
        'AbC.cOm',
        'foo',
        '@user.com',
        '@user.com@user.com',
        'acct:user.com',
        'acct:@user.com@user.com',
        'acc:me@user.com',
    ]

    for invalid in invalid_ids:
        with self.assertRaises(AssertionError):
            Web(id=invalid).put()
|
||||
|
||||
def test_get_or_create_lower_cases_domain(self, *_):
    """get_or_create with a mixed case domain should use the lower case key id."""
    created = Web.get_or_create('AbC.oRg')
    self.assertEqual('abc.org', created.key.id())

    # only the lower cased id should be retrievable afterward
    stored = Web.get_by_id('abc.org')
    self.assert_entities_equal(created, stored)
    self.assertIsNone(Web.get_by_id('AbC.oRg'))
|
||||
|
||||
def test_get_or_create_unicode_domain(self, *_):
    """get_or_create should accept a unicode domain and keep it as the key id."""
    created = Web.get_or_create('☃.net')
    self.assertEqual('☃.net', created.key.id())

    stored = Web.get_by_id('☃.net')
    self.assert_entities_equal(created, stored)
|
||||
|
||||
def test_bad_source_url(self, mock_get, mock_post):
|
||||
for data in b'', {'source': 'bad'}, {'source': 'https://'}:
|
||||
got = self.client.post('/webmention', data=data)
|
||||
|
@ -1581,6 +1605,29 @@ http://this/404s
|
|||
self.assertEqual('Person', user.actor_as2['type'])
|
||||
self.assertEqual('http://localhost/user.com', user.actor_as2['id'])
|
||||
|
||||
def test_check_web_site_unicode_domain(self, mock_get, _):
    """POSTing a unicode domain URL to /web-site should redirect to its user page."""
    mock_get.side_effect = (requests_response(''), requests_response(''))

    resp = self.client.post('/web-site', data={'url': 'https://☃.net/'})

    self.assert_equals(302, resp.status_code)
    # the domain is percent-encoded in the redirect, but not in the key id
    self.assert_equals('/web/%E2%98%83.net', resp.headers['Location'])
    self.assertIsNotNone(Web.get_by_id('☃.net'))
|
||||
|
||||
def test_check_web_site_lower_cases_domain(self, mock_get, _):
    """POSTing a mixed case domain URL to /web-site should use the lower case domain."""
    mock_get.side_effect = (requests_response(''), requests_response(''))

    resp = self.client.post('/web-site', data={'url': 'https://AbC.oRg/'})

    self.assert_equals(302, resp.status_code)
    self.assert_equals('/web/abc.org', resp.headers['Location'])

    # only the lower cased key id should exist
    self.assertIsNotNone(Web.get_by_id('abc.org'))
    self.assertIsNone(Web.get_by_id('AbC.oRg'))
|
||||
|
||||
def test_check_web_site_bad_url(self, _, __):
|
||||
got = self.client.post('/web-site', data={'url': '!!!'})
|
||||
self.assert_equals(200, got.status_code)
|
||||
|
@ -1594,10 +1641,10 @@ http://this/404s
|
|||
requests_response('', status=503),
|
||||
)
|
||||
|
||||
got = self.client.post('/web-site', data={'url': 'https://orig/'})
|
||||
got = self.client.post('/web-site', data={'url': 'https://orig.co/'})
|
||||
self.assert_equals(200, got.status_code, got.headers)
|
||||
self.assertTrue(get_flashed_messages()[0].startswith(
|
||||
"Couldn't connect to https://orig/: "))
|
||||
"Couldn't connect to https://orig.co/: "))
|
||||
|
||||
|
||||
@patch('requests.post')
|
||||
|
|
14
web.py
14
web.py
|
@ -2,6 +2,7 @@
|
|||
import datetime
|
||||
import difflib
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from urllib.parse import urlencode, urljoin, urlparse
|
||||
|
||||
|
@ -62,6 +63,18 @@ class Web(User, Protocol):
|
|||
if username != self.key.id():
|
||||
return util.domain_from_link(username, minimize=False)
|
||||
|
||||
def put(self, *args, **kwargs):
    """Validates the domain key id, then passes through to the superclass's put.

    The key id must match :const:`common.DOMAIN_RE` and must already be lower
    case. Upper case and invalid characters are not allowed.

    Args:
      args, kwargs: passed through unchanged to the superclass's put

    Returns:
      whatever the superclass's put returns

    Raises:
      AssertionError: if the key id doesn't match :const:`common.DOMAIN_RE`
        or contains upper case characters
    """
    domain = self.key.id()
    assert re.match(common.DOMAIN_RE, domain), \
        f'invalid domain in Web key id: {domain}'
    assert domain.lower() == domain, \
        f'upper case is not allowed in Web key id: {domain}'
    return super().put(*args, **kwargs)
|
||||
|
||||
@classmethod
def get_or_create(cls, id, **kwargs):
    """Lower cases the domain id before delegating to the superclass.

    The lower cased id and all keyword args are handed to the superclass's
    ``get_or_create``, and its result is returned unchanged.
    """
    domain = id.lower()
    return super().get_or_create(domain, **kwargs)
|
||||
|
||||
def web_url(self):
    """Returns this user's home page URL, eg 'https://foo.com/'.

    Built from the key id, with an https scheme and a trailing slash.
    """
    domain = self.key.id()
    return f'https://{domain}/'
|
||||
|
@ -325,6 +338,7 @@ def enter_web_site():
|
|||
@app.post('/web-site')
|
||||
def check_web_site():
|
||||
url = request.values['url']
|
||||
# this normalizes and lower cases domain
|
||||
domain = util.domain_from_link(url, minimize=False)
|
||||
if not domain:
|
||||
flash(f'No domain found in {url}')
|
||||
|
|
|
@ -228,6 +228,7 @@ def fetch(addr):
|
|||
return data
|
||||
|
||||
|
||||
# TODO: why do we serve this URL? should we drop it?
|
||||
app.add_url_rule(f'/acct:<regex("{common.DOMAIN_RE}"):domain>',
|
||||
view_func=Actor.as_view('actor_acct'))
|
||||
app.add_url_rule('/.well-known/webfinger', view_func=Webfinger.as_view('webfinger'))
|
||||
|
|
Ładowanie…
Reference in New Issue