From 10023d17fd53b749882d2c7059f30c052aec2bfe Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Sun, 21 Apr 2024 12:18:12 -0700 Subject: [PATCH] Protocol.enable_protocol: create copy user if necessary --- activitypub.py | 2 +- ids.py | 2 +- models.py | 7 ++++++- protocol.py | 4 +++- tests/test_atproto.py | 3 +-- tests/test_integrations.py | 1 + tests/test_models.py | 8 ++++++-- tests/test_protocol.py | 10 ++++++++++ tests/testutil.py | 10 +++++++--- 9 files changed, 36 insertions(+), 11 deletions(-) diff --git a/activitypub.py b/activitypub.py index 802d873..2a20a96 100644 --- a/activitypub.py +++ b/activitypub.py @@ -59,7 +59,7 @@ FEDI_URL_RE = re.compile(r'https://[^/]+/(@|users/)([^/@]+)(@[^/@]+)?(/(?:status # can't use translate_user_id because Web.owns_id checks valid_domain, which # doesn't allow our protocol subdomains -BOT_ACTOR_IDS = [f'https://{domain}/{domain}' for domain in PROTOCOL_DOMAINS] +BOT_ACTOR_IDS = tuple(f'https://{domain}/{domain}' for domain in PROTOCOL_DOMAINS) def instance_actor(): diff --git a/ids.py b/ids.py index a0f628a..14d6181 100644 --- a/ids.py +++ b/ids.py @@ -17,7 +17,7 @@ import models logger = logging.getLogger(__name__) # Protocols to check User.copies and Object.copies before translating -COPIES_PROTOCOLS = ('atproto', 'fake', 'other') +COPIES_PROTOCOLS = ('atproto',) # Web user domains whose AP actor ids are on fed.brid.gy, not web.brid.gy, for # historical compatibility. Loaded on first call to web_ap_subdomain(). diff --git a/models.py b/models.py index 9d3cf32..10f14d7 100644 --- a/models.py +++ b/models.py @@ -362,6 +362,8 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): """ user = self.key.get() add(user.enabled_protocols, to_proto.LABEL) + if not user.get_copy(to_proto): + to_proto.create_for(user) user.put() add(self.enabled_protocols, to_proto.LABEL) @@ -375,6 +377,8 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): """ user = self.key.get() remove(user.enabled_protocols, to_proto.LABEL) + # TODO: delete copy user + # https://github.com/snarfed/bridgy-fed/issues/783 user.put() remove(self.enabled_protocols, to_proto.LABEL) @@ -523,7 +527,8 @@ class User(StringIdModel, metaclass=ProtocolUserMeta): Returns: str: """ - if isinstance(self, proto): + # don't use isinstance because the testutil Fake protocol has subclasses + if self.LABEL == proto.LABEL: return self.key.id() for copy in self.copies: diff --git a/protocol.py b/protocol.py index 78f0178..5726999 100644 --- a/protocol.py +++ b/protocol.py @@ -445,9 +445,11 @@ class Protocol: def create_for(cls, user): """Creates a copy user in this protocol. + Should add the copy user to :attr:`copies`. + Args: user (models.User): original source user. Shouldn't already have a - copy user for this protocol in ``copies``. + copy user for this protocol in :attr:`copies`. """ raise NotImplementedError() diff --git a/tests/test_atproto.py b/tests/test_atproto.py index 0a771f0..3144375 100644 --- a/tests/test_atproto.py +++ b/tests/test_atproto.py @@ -60,6 +60,7 @@ NOTE_AS = { } +@patch('ids.COPIES_PROTOCOLS', ['atproto']) class ATProtoTest(TestCase): def setUp(self): @@ -700,7 +701,6 @@ class ATProtoTest(TestCase): mock_create_task.assert_called() - @patch('ids.COPIES_PROTOCOLS', ['atproto']) @patch('google.cloud.dns.client.ManagedZone', autospec=True) @patch.object(tasks_client, 'create_task', return_value=Task(name='my task')) @patch('requests.post', @@ -765,7 +765,6 @@ class ATProtoTest(TestCase): self.assert_task(mock_create_task, 'atproto-commit', '/queue/atproto-commit') - @patch('ids.COPIES_PROTOCOLS', ['atproto']) @patch('requests.get', return_value=requests_response( 'blob contents', content_type='image/png')) # image blob fetch @patch('google.cloud.dns.client.ManagedZone', autospec=True) diff --git a/tests/test_integrations.py b/tests/test_integrations.py index 298becc..8c6d4cc 100644 --- a/tests/test_integrations.py +++ b/tests/test_integrations.py @@ -32,6 +32,7 @@ PROFILE_GETRECORD = { } +@patch('ids.COPIES_PROTOCOLS', ['atproto']) class IntegrationTests(TestCase): @patch('requests.post') diff --git a/tests/test_models.py b/tests/test_models.py index 3e4cc8d..4a76f42 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -66,11 +66,11 @@ class UserTest(TestCase): user.direct = True self.assert_entities_equal(same, user, ignore=['updated']) + @patch('ids.COPIES_PROTOCOLS', ['fake', 'other']) def test_get_or_create_propagate_fake_other(self): user = Fake.get_or_create('fake:user', propagate=True) self.assertEqual(['fake:user'], OtherFake.created_for) - @patch('ids.COPIES_PROTOCOLS', ['fake', 'other', 'atproto']) @patch.object(tasks_client, 'create_task', return_value=Task(name='my task')) @patch('requests.post', return_value=requests_response('OK')) # create DID on PLC @@ -112,6 +112,7 @@ class UserTest(TestCase): mock_create_task.assert_called() + @patch('ids.COPIES_PROTOCOLS', ['eefake', 'atproto']) @patch.object(tasks_client, 'create_task') @patch('requests.post') @patch('requests.get') @@ -128,7 +129,6 @@ class UserTest(TestCase): self.assertEqual([], user.copies) self.assertEqual(0, AtpRepo.query().count()) - def test_get_or_create_use_instead(self): user = Fake.get_or_create('a.b') user.use_instead = self.user.key @@ -283,9 +283,13 @@ class UserTest(TestCase): user.copies.append(Target(uri='fake:foo', protocol='fake')) self.assertIsNone(user.get_copy(OtherFake)) + self.assertIsNone(user.get_copy(OtherFake)) user.copies = [Target(uri='other:foo', protocol='other')] self.assertEqual('other:foo', user.get_copy(OtherFake)) + self.assertIsNone(OtherFake().get_copy(Fake)) + + def test_count_followers(self): self.assertEqual((0, 0), self.user.count_followers()) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8b6d9f4..3763635 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1800,12 +1800,14 @@ class ProtocolReceiveTest(TestCase): self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(block)) user = user.key.get() self.assertEqual([], user.enabled_protocols) + self.assertEqual([], Fake.created_for) # follow should add to enabled_protocols with self.assertRaises(NoContent): ExplicitEnableFake.receive_as1(follow) user = user.key.get() self.assertEqual(['fake'], user.enabled_protocols) + self.assertEqual(['eefake:user'], Fake.created_for) self.assertTrue(ExplicitEnableFake.is_enabled_to(Fake, user)) self.assertEqual([ ('https://fa.brid.gy//followers#accept-eefake:follow', @@ -1814,16 +1816,19 @@ class ProtocolReceiveTest(TestCase): # another follow should be a noop follow['id'] += '2' + Fake.created_for = [] with self.assertRaises(NoContent): ExplicitEnableFake.receive_as1(follow) user = user.key.get() self.assertEqual(['fake'], user.enabled_protocols) + self.assertEqual([], Fake.created_for) # block should remove from enabled_protocols block['id'] += '2' self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(block)) user = user.key.get() self.assertEqual([], user.enabled_protocols) + self.assertEqual([], Fake.created_for) self.assertFalse(ExplicitEnableFake.is_enabled_to(Fake, user)) def test_dm_no_yes_sets_enabled_protocols(self): @@ -1842,6 +1847,7 @@ class ProtocolReceiveTest(TestCase): self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm)) user = user.key.get() self.assertEqual([], user.enabled_protocols) + self.assertEqual([], Fake.created_for) # yes DM should add to enabled_protocols dm['id'] += '2' @@ -1849,13 +1855,16 @@ class ProtocolReceiveTest(TestCase): self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm)) user = user.key.get() self.assertEqual(['fake'], user.enabled_protocols) + self.assertEqual(['eefake:user'], Fake.created_for) self.assertTrue(ExplicitEnableFake.is_enabled_to(Fake, user)) # another yes DM should be a noop dm['id'] += '3' + Fake.created_for = [] self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm)) user = user.key.get() self.assertEqual(['fake'], user.enabled_protocols) + self.assertEqual([], Fake.created_for) # block should remove from enabled_protocols dm['id'] += '4' @@ -1863,6 +1872,7 @@ class ProtocolReceiveTest(TestCase): self.assertEqual(('OK', 200), ExplicitEnableFake.receive_as1(dm)) user = user.key.get() self.assertEqual([], user.enabled_protocols) + self.assertEqual([], Fake.created_for) self.assertFalse(ExplicitEnableFake.is_enabled_to(Fake, user)) def test_receive_task_handler(self): diff --git a/tests/testutil.py b/tests/testutil.py index 0e08bac..e5eb32d 100644 --- a/tests/testutil.py +++ b/tests/testutil.py @@ -34,7 +34,7 @@ import requests # other modules are imported _after_ Fake etc classes is defined so that it's in # PROTOCOLS when URL routes are registered. -from common import long_to_base64, TASKS_LOCATION +from common import add, long_to_base64, TASKS_LOCATION import ids import models from models import KEY_BITS, Object, PROTOCOLS, Target, User @@ -90,7 +90,11 @@ class Fake(User, protocol.Protocol): @classmethod def create_for(cls, user): - cls.created_for.append(user.key.id()) + assert not user.get_copy(cls) + id = user.key.id() + cls.created_for.append(id) + add(user.copies, Target(uri=ids.translate_user_id(id=id, from_=user, to=cls), + protocol=cls.LABEL)) @classmethod def owns_id(cls, id): @@ -227,7 +231,7 @@ class TestCase(unittest.TestCase, testutil.Asserts): common.OTHER_DOMAINS += ('fake.brid.gy',) common.DOMAINS += ('fake.brid.gy',) - ids.COPIES_PROTOCOLS = ['fake', 'other'] + ids.COPIES_PROTOCOLS = ('atproto', 'fake', 'other') # make random test data deterministic arroba.util._clockid = 17