kopia lustrzana https://github.com/snarfed/bridgy-fed
Protocol.enable_protocol: create copy user if necessary
rodzic
6b597c90c3
commit
10023d17fd
|
@ -59,7 +59,7 @@ FEDI_URL_RE = re.compile(r'https://[^/]+/(@|users/)([^/@]+)(@[^/@]+)?(/(?:status
|
||||||
|
|
||||||
# can't use translate_user_id because Web.owns_id checks valid_domain, which
|
# can't use translate_user_id because Web.owns_id checks valid_domain, which
|
||||||
# doesn't allow our protocol subdomains
|
# doesn't allow our protocol subdomains
|
||||||
BOT_ACTOR_IDS = [f'https://{domain}/{domain}' for domain in PROTOCOL_DOMAINS]
|
BOT_ACTOR_IDS = tuple(f'https://{domain}/{domain}' for domain in PROTOCOL_DOMAINS)
|
||||||
|
|
||||||
|
|
||||||
def instance_actor():
|
def instance_actor():
|
||||||
|
|
2
ids.py
2
ids.py
|
@ -17,7 +17,7 @@ import models
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Protocols to check User.copies and Object.copies before translating
|
# Protocols to check User.copies and Object.copies before translating
|
||||||
COPIES_PROTOCOLS = ('atproto', 'fake', 'other')
|
COPIES_PROTOCOLS = ('atproto',)
|
||||||
|
|
||||||
# Web user domains whose AP actor ids are on fed.brid.gy, not web.brid.gy, for
|
# Web user domains whose AP actor ids are on fed.brid.gy, not web.brid.gy, for
|
||||||
# historical compatibility. Loaded on first call to web_ap_subdomain().
|
# historical compatibility. Loaded on first call to web_ap_subdomain().
|
||||||
|
|
|
@ -362,6 +362,8 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
|
||||||
"""
|
"""
|
||||||
user = self.key.get()
|
user = self.key.get()
|
||||||
add(user.enabled_protocols, to_proto.LABEL)
|
add(user.enabled_protocols, to_proto.LABEL)
|
||||||
|
if not user.get_copy(to_proto):
|
||||||
|
to_proto.create_for(user)
|
||||||
user.put()
|
user.put()
|
||||||
|
|
||||||
add(self.enabled_protocols, to_proto.LABEL)
|
add(self.enabled_protocols, to_proto.LABEL)
|
||||||
|
@ -375,6 +377,8 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
|
||||||
"""
|
"""
|
||||||
user = self.key.get()
|
user = self.key.get()
|
||||||
remove(user.enabled_protocols, to_proto.LABEL)
|
remove(user.enabled_protocols, to_proto.LABEL)
|
||||||
|
# TODO: delete copy user
|
||||||
|
# https://github.com/snarfed/bridgy-fed/issues/783
|
||||||
user.put()
|
user.put()
|
||||||
|
|
||||||
remove(self.enabled_protocols, to_proto.LABEL)
|
remove(self.enabled_protocols, to_proto.LABEL)
|
||||||
|
@ -523,7 +527,8 @@ class User(StringIdModel, metaclass=ProtocolUserMeta):
|
||||||
Returns:
|
Returns:
|
||||||
str:
|
str:
|
||||||
"""
|
"""
|
||||||
if isinstance(self, proto):
|
# don't use isinstance because the testutil Fake protocol has subclasses
|
||||||
|
if self.LABEL == proto.LABEL:
|
||||||
return self.key.id()
|
return self.key.id()
|
||||||
|
|
||||||
for copy in self.copies:
|
for copy in self.copies:
|
||||||
|
|
|
@ -445,9 +445,11 @@ class Protocol:
|
||||||
def create_for(cls, user):
|
def create_for(cls, user):
|
||||||
"""Creates a copy user in this protocol.
|
"""Creates a copy user in this protocol.
|
||||||
|
|
||||||
|
Should add the copy user to :attr:`copies`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user (models.User): original source user. Shouldn't already have a
|
user (models.User): original source user. Shouldn't already have a
|
||||||
copy user for this protocol in ``copies``.
|
copy user for this protocol in :attr:`copies`.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
|
@ -60,6 +60,7 @@ NOTE_AS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@patch('ids.COPIES_PROTOCOLS', ['atproto'])
|
||||||
class ATProtoTest(TestCase):
|
class ATProtoTest(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -700,7 +701,6 @@ class ATProtoTest(TestCase):
|
||||||
|
|
||||||
mock_create_task.assert_called()
|
mock_create_task.assert_called()
|
||||||
|
|
||||||
@patch('ids.COPIES_PROTOCOLS', ['atproto'])
|
|
||||||
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
||||||
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
||||||
@patch('requests.post',
|
@patch('requests.post',
|
||||||
|
@ -765,7 +765,6 @@ class ATProtoTest(TestCase):
|
||||||
self.assert_task(mock_create_task, 'atproto-commit',
|
self.assert_task(mock_create_task, 'atproto-commit',
|
||||||
'/queue/atproto-commit')
|
'/queue/atproto-commit')
|
||||||
|
|
||||||
@patch('ids.COPIES_PROTOCOLS', ['atproto'])
|
|
||||||
@patch('requests.get', return_value=requests_response(
|
@patch('requests.get', return_value=requests_response(
|
||||||
'blob contents', content_type='image/png')) # image blob fetch
|
'blob contents', content_type='image/png')) # image blob fetch
|
||||||
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
@patch('google.cloud.dns.client.ManagedZone', autospec=True)
|
||||||
|
|
|
@ -32,6 +32,7 @@ PROFILE_GETRECORD = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@patch('ids.COPIES_PROTOCOLS', ['atproto'])
|
||||||
class IntegrationTests(TestCase):
|
class IntegrationTests(TestCase):
|
||||||
|
|
||||||
@patch('requests.post')
|
@patch('requests.post')
|
||||||
|
|
|
@ -66,11 +66,11 @@ class UserTest(TestCase):
|
||||||
user.direct = True
|
user.direct = True
|
||||||
self.assert_entities_equal(same, user, ignore=['updated'])
|
self.assert_entities_equal(same, user, ignore=['updated'])
|
||||||
|
|
||||||
|
@patch('ids.COPIES_PROTOCOLS', ['fake', 'other'])
|
||||||
def test_get_or_create_propagate_fake_other(self):
|
def test_get_or_create_propagate_fake_other(self):
|
||||||
user = Fake.get_or_create('fake:user', propagate=True)
|
user = Fake.get_or_create('fake:user', propagate=True)
|
||||||
self.assertEqual(['fake:user'], OtherFake.created_for)
|
self.assertEqual(['fake:user'], OtherFake.created_for)
|
||||||
|
|
||||||
@patch('ids.COPIES_PROTOCOLS', ['fake', 'other', 'atproto'])
|
|
||||||
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
@patch.object(tasks_client, 'create_task', return_value=Task(name='my task'))
|
||||||
@patch('requests.post',
|
@patch('requests.post',
|
||||||
return_value=requests_response('OK')) # create DID on PLC
|
return_value=requests_response('OK')) # create DID on PLC
|
||||||
|
@ -112,6 +112,7 @@ class UserTest(TestCase):
|
||||||
|
|
||||||
mock_create_task.assert_called()
|
mock_create_task.assert_called()
|
||||||
|
|
||||||
|
@patch('ids.COPIES_PROTOCOLS', ['eefake', 'atproto'])
|
||||||
@patch.object(tasks_client, 'create_task')
|
@patch.object(tasks_client, 'create_task')
|
||||||
@patch('requests.post')
|
@patch('requests.post')
|
||||||
@patch('requests.get')
|
@patch('requests.get')
|
||||||
|
@ -128,7 +129,6 @@ class UserTest(TestCase):
|
||||||
self.assertEqual([], user.copies)
|
self.assertEqual([], user.copies)
|
||||||
self.assertEqual(0, AtpRepo.query().count())
|
self.assertEqual(0, AtpRepo.query().count())
|
||||||
|
|
||||||
|
|
||||||
def test_get_or_create_use_instead(self):
|
def test_get_or_create_use_instead(self):
|
||||||
user = Fake.get_or_create('a.b')
|
user = Fake.get_or_create('a.b')
|
||||||
user.use_instead = self.user.key
|
user.use_instead = self.user.key
|
||||||
|
@ -283,9 +283,13 @@ class UserTest(TestCase):
|
||||||
user.copies.append(Target(uri='fake:foo', protocol='fake'))
|
user.copies.append(Target(uri='fake:foo', protocol='fake'))
|
||||||
self.assertIsNone(user.get_copy(OtherFake))
|
self.assertIsNone(user.get_copy(OtherFake))
|
||||||
|
|
||||||
|
self.assertIsNone(user.get_copy(OtherFake))
|
||||||
user.copies = [Target(uri='other:foo', protocol='other')]
|
user.copies = [Target(uri='other:foo', protocol='other')]
|
||||||
self.assertEqual('other:foo', user.get_copy(OtherFake))
|
self.assertEqual('other:foo', user.get_copy(OtherFake))
|
||||||
|
|
||||||
|
self.assertIsNone(OtherFake().get_copy(Fake))
|
||||||
|
|
||||||
|
|
||||||
def test_count_followers(self):
|
def test_count_followers(self):
|
||||||
self.assertEqual((0, 0), self.user.count_followers())
|
self.assertEqual((0, 0), self.user.count_followers())
|
||||||
|
|
||||||
|
|
|
@ -1800,12 +1800,14 @@ class ProtocolReceiveTest(TestCase):
|
||||||
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(block))
|
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(block))
|
||||||
user = user.key.get()
|
user = user.key.get()
|
||||||
self.assertEqual([], user.enabled_protocols)
|
self.assertEqual([], user.enabled_protocols)
|
||||||
|
self.assertEqual([], Fake.created_for)
|
||||||
|
|
||||||
# follow should add to enabled_protocols
|
# follow should add to enabled_protocols
|
||||||
with self.assertRaises(NoContent):
|
with self.assertRaises(NoContent):
|
||||||
ExplicitEnableFake.receive_as1(follow)
|
ExplicitEnableFake.receive_as1(follow)
|
||||||
user = user.key.get()
|
user = user.key.get()
|
||||||
self.assertEqual(['fake'], user.enabled_protocols)
|
self.assertEqual(['fake'], user.enabled_protocols)
|
||||||
|
self.assertEqual(['eefake:user'], Fake.created_for)
|
||||||
self.assertTrue(ExplicitEnableFake.is_enabled_to(Fake, user))
|
self.assertTrue(ExplicitEnableFake.is_enabled_to(Fake, user))
|
||||||
self.assertEqual([
|
self.assertEqual([
|
||||||
('https://fa.brid.gy//followers#accept-eefake:follow',
|
('https://fa.brid.gy//followers#accept-eefake:follow',
|
||||||
|
@ -1814,16 +1816,19 @@ class ProtocolReceiveTest(TestCase):
|
||||||
|
|
||||||
# another follow should be a noop
|
# another follow should be a noop
|
||||||
follow['id'] += '2'
|
follow['id'] += '2'
|
||||||
|
Fake.created_for = []
|
||||||
with self.assertRaises(NoContent):
|
with self.assertRaises(NoContent):
|
||||||
ExplicitEnableFake.receive_as1(follow)
|
ExplicitEnableFake.receive_as1(follow)
|
||||||
user = user.key.get()
|
user = user.key.get()
|
||||||
self.assertEqual(['fake'], user.enabled_protocols)
|
self.assertEqual(['fake'], user.enabled_protocols)
|
||||||
|
self.assertEqual([], Fake.created_for)
|
||||||
|
|
||||||
# block should remove from enabled_protocols
|
# block should remove from enabled_protocols
|
||||||
block['id'] += '2'
|
block['id'] += '2'
|
||||||
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(block))
|
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(block))
|
||||||
user = user.key.get()
|
user = user.key.get()
|
||||||
self.assertEqual([], user.enabled_protocols)
|
self.assertEqual([], user.enabled_protocols)
|
||||||
|
self.assertEqual([], Fake.created_for)
|
||||||
self.assertFalse(ExplicitEnableFake.is_enabled_to(Fake, user))
|
self.assertFalse(ExplicitEnableFake.is_enabled_to(Fake, user))
|
||||||
|
|
||||||
def test_dm_no_yes_sets_enabled_protocols(self):
|
def test_dm_no_yes_sets_enabled_protocols(self):
|
||||||
|
@ -1842,6 +1847,7 @@ class ProtocolReceiveTest(TestCase):
|
||||||
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm))
|
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm))
|
||||||
user = user.key.get()
|
user = user.key.get()
|
||||||
self.assertEqual([], user.enabled_protocols)
|
self.assertEqual([], user.enabled_protocols)
|
||||||
|
self.assertEqual([], Fake.created_for)
|
||||||
|
|
||||||
# yes DM should add to enabled_protocols
|
# yes DM should add to enabled_protocols
|
||||||
dm['id'] += '2'
|
dm['id'] += '2'
|
||||||
|
@ -1849,13 +1855,16 @@ class ProtocolReceiveTest(TestCase):
|
||||||
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm))
|
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm))
|
||||||
user = user.key.get()
|
user = user.key.get()
|
||||||
self.assertEqual(['fake'], user.enabled_protocols)
|
self.assertEqual(['fake'], user.enabled_protocols)
|
||||||
|
self.assertEqual(['eefake:user'], Fake.created_for)
|
||||||
self.assertTrue(ExplicitEnableFake.is_enabled_to(Fake, user))
|
self.assertTrue(ExplicitEnableFake.is_enabled_to(Fake, user))
|
||||||
|
|
||||||
# another yes DM should be a noop
|
# another yes DM should be a noop
|
||||||
dm['id'] += '3'
|
dm['id'] += '3'
|
||||||
|
Fake.created_for = []
|
||||||
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm))
|
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm))
|
||||||
user = user.key.get()
|
user = user.key.get()
|
||||||
self.assertEqual(['fake'], user.enabled_protocols)
|
self.assertEqual(['fake'], user.enabled_protocols)
|
||||||
|
self.assertEqual([], Fake.created_for)
|
||||||
|
|
||||||
# block should remove from enabled_protocols
|
# block should remove from enabled_protocols
|
||||||
dm['id'] += '4'
|
dm['id'] += '4'
|
||||||
|
@ -1863,6 +1872,7 @@ class ProtocolReceiveTest(TestCase):
|
||||||
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm))
|
self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm))
|
||||||
user = user.key.get()
|
user = user.key.get()
|
||||||
self.assertEqual([], user.enabled_protocols)
|
self.assertEqual([], user.enabled_protocols)
|
||||||
|
self.assertEqual([], Fake.created_for)
|
||||||
self.assertFalse(ExplicitEnableFake.is_enabled_to(Fake, user))
|
self.assertFalse(ExplicitEnableFake.is_enabled_to(Fake, user))
|
||||||
|
|
||||||
def test_receive_task_handler(self):
|
def test_receive_task_handler(self):
|
||||||
|
|
|
@ -34,7 +34,7 @@ import requests
|
||||||
|
|
||||||
# other modules are imported _after_ Fake etc classes is defined so that it's in
|
# other modules are imported _after_ Fake etc classes is defined so that it's in
|
||||||
# PROTOCOLS when URL routes are registered.
|
# PROTOCOLS when URL routes are registered.
|
||||||
from common import long_to_base64, TASKS_LOCATION
|
from common import add, long_to_base64, TASKS_LOCATION
|
||||||
import ids
|
import ids
|
||||||
import models
|
import models
|
||||||
from models import KEY_BITS, Object, PROTOCOLS, Target, User
|
from models import KEY_BITS, Object, PROTOCOLS, Target, User
|
||||||
|
@ -90,7 +90,11 @@ class Fake(User, protocol.Protocol):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_for(cls, user):
|
def create_for(cls, user):
|
||||||
cls.created_for.append(user.key.id())
|
assert not user.get_copy(cls)
|
||||||
|
id = user.key.id()
|
||||||
|
cls.created_for.append(id)
|
||||||
|
add(user.copies, Target(uri=ids.translate_user_id(id=id, from_=user, to=cls),
|
||||||
|
protocol=cls.LABEL))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def owns_id(cls, id):
|
def owns_id(cls, id):
|
||||||
|
@ -227,7 +231,7 @@ class TestCase(unittest.TestCase, testutil.Asserts):
|
||||||
|
|
||||||
common.OTHER_DOMAINS += ('fake.brid.gy',)
|
common.OTHER_DOMAINS += ('fake.brid.gy',)
|
||||||
common.DOMAINS += ('fake.brid.gy',)
|
common.DOMAINS += ('fake.brid.gy',)
|
||||||
ids.COPIES_PROTOCOLS = ['fake', 'other']
|
ids.COPIES_PROTOCOLS = ('atproto', 'fake', 'other')
|
||||||
|
|
||||||
# make random test data deterministic
|
# make random test data deterministic
|
||||||
arroba.util._clockid = 17
|
arroba.util._clockid = 17
|
||||||
|
|
Ładowanie…
Reference in New Issue