From dc376f67d8266e536c3db08e6c62c2fdeba005e8 Mon Sep 17 00:00:00 2001 From: Ivan Habunek Date: Thu, 22 Jun 2023 17:23:05 +0200 Subject: [PATCH] wip --- tests/integration/conftest.py | 5 +- tests/test_console.py | 86 ++++++++----------------- toot/__init__.py | 10 +++ toot/aapi.py | 118 ++++++++++++++++++++++++++++++++++ toot/ahttp.py | 81 +++++++++++++++++++++++ toot/commands.py | 36 +++++++---- toot/console.py | 41 ++++++++++-- 7 files changed, 299 insertions(+), 78 deletions(-) create mode 100644 toot/aapi.py create mode 100644 toot/ahttp.py diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d9421a8..b250ae5 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -13,6 +13,7 @@ export TOOT_TEST_DATABASE_DSN="dbname=mastodon_development" ``` """ +import asyncio import re import os import psycopg2 @@ -95,7 +96,7 @@ def run(app, user, capsys): # The try/catch duplicates logic from console.main to convert exceptions # to printed error messages. TODO: could be deduped try: - run_command(app, as_user or user, command, params or []) + asyncio.run(run_command(app, as_user or user, command, params or [])) except (ConsoleError, ApiError) as e: print_out(str(e)) @@ -108,7 +109,7 @@ def run(app, user, capsys): @pytest.fixture def run_anon(capsys): def _run(command, *params): - run_command(None, None, command, params or []) + asyncio.run(run_command(None, None, command, params or [])) out, err = capsys.readouterr() assert err == "" return strip_ansi(out) diff --git a/tests/test_console.py b/tests/test_console.py index 9f3b835..4cf8502 100644 --- a/tests/test_console.py +++ b/tests/test_console.py @@ -1,3 +1,4 @@ +import asyncio import io import pytest import re @@ -16,6 +17,10 @@ user = User('habunek.com', 'ivan@habunek.com', 'xxx') MockUuid = namedtuple("MockUuid", ["hex"]) +def run_command(app, user, name, args): + return asyncio.run(console.run_command(app, user, name, args)) + + def uncolorize(text): """Remove ANSI color sequences from a string""" return re.sub(r'\x1b[^m]*m', '', text) @@ -35,7 +40,7 @@ def test_post_defaults(mock_post, mock_uuid, capsys): 'url': 'https://habunek.com/@ihabunek/1234567890' }) - console.run_command(app, user, 'post', ['Hello world']) + run_command(app, user, 'post', ['Hello world']) mock_post.assert_called_once_with(app, user, '/api/v1/statuses', json={ 'status': 'Hello world', @@ -67,7 +72,7 @@ def test_post_with_options(mock_post, mock_uuid, capsys): 'url': 'https://habunek.com/@ihabunek/1234567890' }) - console.run_command(app, user, 'post', args) + run_command(app, user, 'post', args) mock_post.assert_called_once_with(app, user, '/api/v1/statuses', json={ 'status': 'Hello world', @@ -89,7 +94,7 @@ def test_post_invalid_visibility(capsys): args = ['Hello world', '--visibility', 'foo'] with pytest.raises(SystemExit): - console.run_command(app, user, 'post', args) + run_command(app, user, 'post', args) out, err = capsys.readouterr() assert "invalid visibility value: 'foo'" in err @@ -99,7 +104,7 @@ def test_post_invalid_media(capsys): args = ['Hello world', '--media', 'does_not_exist.jpg'] with pytest.raises(SystemExit): - console.run_command(app, user, 'post', args) + run_command(app, user, 'post', args) out, err = capsys.readouterr() assert "can't open 'does_not_exist.jpg'" in err @@ -107,7 +112,7 @@ def test_post_invalid_media(capsys): @mock.patch('toot.http.delete') def test_delete(mock_delete, capsys): - console.run_command(app, user, 'delete', ['12321']) + run_command(app, user, 'delete', ['12321']) mock_delete.assert_called_once_with(app, user, '/api/v1/statuses/12321') @@ -131,7 +136,7 @@ def test_timeline(mock_get, monkeypatch, capsys): 'media_attachments': [], }]) - console.run_command(app, user, 'timeline', ['--once']) + run_command(app, user, 'timeline', ['--once']) mock_get.assert_called_once_with(app, user, '/api/v1/timelines/home', {'limit': 10}) @@ -173,7 +178,7 @@ def test_timeline_with_re(mock_get, monkeypatch, capsys): 'media_attachments': [], }]) - console.run_command(app, user, 'timeline', ['--once']) + run_command(app, user, 'timeline', ['--once']) mock_get.assert_called_once_with(app, user, '/api/v1/timelines/home', {'limit': 10}) @@ -235,7 +240,7 @@ def test_thread(mock_get, monkeypatch, capsys): }), ] - console.run_command(app, user, 'thread', ['111111111111111111']) + run_command(app, user, 'thread', ['111111111111111111']) calls = [ mock.call(app, user, '/api/v1/statuses/111111111111111111'), @@ -259,6 +264,7 @@ def test_thread(mock_get, monkeypatch, capsys): assert "111111111111111111" in out assert "In reply to" in out + @mock.patch('toot.http.get') def test_reblogged_by(mock_get, monkeypatch, capsys): mock_get.return_value = MockResponse([{ @@ -269,7 +275,7 @@ def test_reblogged_by(mock_get, monkeypatch, capsys): 'acct': 'dweezil@zappafamily.social', }]) - console.run_command(app, user, 'reblogged_by', ['111111111111111111']) + run_command(app, user, 'reblogged_by', ['111111111111111111']) calls = [ mock.call(app, user, '/api/v1/statuses/111111111111111111/reblogged_by'), @@ -298,7 +304,7 @@ def test_upload(mock_post, capsys): 'type': 'image', }) - console.run_command(app, user, 'upload', [__file__]) + run_command(app, user, 'upload', [__file__]) assert mock_post.call_count == 1 @@ -341,7 +347,7 @@ def test_search(mock_get, capsys): 'statuses': [], }) - console.run_command(app, user, 'search', ['freddy']) + run_command(app, user, 'search', ['freddy']) mock_get.assert_called_once_with(app, user, '/api/v2/search', { 'q': 'freddy', @@ -368,7 +374,7 @@ def test_follow(mock_get, mock_post, capsys): }) mock_post.return_value = MockResponse() - console.run_command(app, user, 'follow', ['blixa']) + run_command(app, user, 'follow', ['blixa']) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True}) mock_post.assert_called_once_with(app, user, '/api/v1/accounts/321/follow') @@ -388,7 +394,7 @@ def test_follow_case_insensitive(mock_get, mock_post, capsys): }) mock_post.return_value = MockResponse() - console.run_command(app, user, 'follow', ['bLiXa@oThEr.aCc']) + run_command(app, user, 'follow', ['bLiXa@oThEr.aCc']) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'bLiXa@oThEr.aCc', 'type': 'accounts', 'resolve': True}) mock_post.assert_called_once_with(app, user, '/api/v1/accounts/123/follow') @@ -402,7 +408,7 @@ def test_follow_not_found(mock_get, capsys): mock_get.return_value = MockResponse({"accounts": []}) with pytest.raises(ConsoleError) as ex: - console.run_command(app, user, 'follow', ['blixa']) + run_command(app, user, 'follow', ['blixa']) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True}) assert "Account not found" == str(ex.value) @@ -420,7 +426,7 @@ def test_unfollow(mock_get, mock_post, capsys): mock_post.return_value = MockResponse() - console.run_command(app, user, 'unfollow', ['blixa']) + run_command(app, user, 'unfollow', ['blixa']) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True}) mock_post.assert_called_once_with(app, user, '/api/v1/accounts/321/unfollow') @@ -434,51 +440,13 @@ def test_unfollow_not_found(mock_get, capsys): mock_get.return_value = MockResponse({"accounts": []}) with pytest.raises(ConsoleError) as ex: - console.run_command(app, user, 'unfollow', ['blixa']) + run_command(app, user, 'unfollow', ['blixa']) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True}) assert "Account not found" == str(ex.value) -@mock.patch('toot.http.get') -def test_whoami(mock_get, capsys): - mock_get.return_value = MockResponse({ - 'acct': 'ihabunek', - 'avatar': 'https://files.mastodon.social/accounts/avatars/000/046/103/original/6a1304e135cac514.jpg?1491312434', - 'avatar_static': 'https://files.mastodon.social/accounts/avatars/000/046/103/original/6a1304e135cac514.jpg?1491312434', - 'created_at': '2017-04-04T13:23:09.777Z', - 'display_name': 'Ivan Habunek', - 'followers_count': 5, - 'following_count': 9, - 'header': '/headers/original/missing.png', - 'header_static': '/headers/original/missing.png', - 'id': 46103, - 'locked': False, - 'note': 'A developer.', - 'statuses_count': 19, - 'url': 'https://mastodon.social/@ihabunek', - 'username': 'ihabunek', - 'fields': [] - }) - - console.run_command(app, user, 'whoami', []) - - mock_get.assert_called_once_with(app, user, '/api/v1/accounts/verify_credentials') - - out, err = capsys.readouterr() - out = uncolorize(out) - - assert "@ihabunek Ivan Habunek" in out - assert "A developer." in out - assert "https://mastodon.social/@ihabunek" in out - assert "ID: 46103" in out - assert "Since: 2017-04-04" in out - assert "Followers: 5" in out - assert "Following: 9" in out - assert "Statuses: 19" in out - - @mock.patch('toot.http.get') def test_notifications(mock_get, capsys): mock_get.return_value = MockResponse([{ @@ -551,7 +519,7 @@ def test_notifications(mock_get, capsys): }, }]) - console.run_command(app, user, 'notifications', []) + run_command(app, user, 'notifications', []) mock_get.assert_called_once_with(app, user, '/api/v1/notifications', {'exclude_types[]': [], 'limit': 20}) @@ -592,7 +560,7 @@ def test_notifications(mock_get, capsys): def test_notifications_empty(mock_get, capsys): mock_get.return_value = MockResponse([]) - console.run_command(app, user, 'notifications', []) + run_command(app, user, 'notifications', []) mock_get.assert_called_once_with(app, user, '/api/v1/notifications', {'exclude_types[]': [], 'limit': 20}) @@ -605,7 +573,7 @@ def test_notifications_empty(mock_get, capsys): @mock.patch('toot.http.post') def test_notifications_clear(mock_post, capsys): - console.run_command(app, user, 'notifications', ['--clear']) + run_command(app, user, 'notifications', ['--clear']) out, err = capsys.readouterr() out = uncolorize(out) @@ -634,7 +602,7 @@ def test_logout(mock_load, mock_save, capsys): "active_user": "king@gizzard.social", } - console.run_command(app, user, "logout", ["king@gizzard.social"]) + run_command(app, user, "logout", ["king@gizzard.social"]) mock_save.assert_called_once_with({ 'users': { @@ -658,7 +626,7 @@ def test_activate(mock_load, mock_save, capsys): "active_user": "king@gizzard.social", } - console.run_command(app, user, "activate", ["lizard@wizard.social"]) + run_command(app, user, "activate", ["lizard@wizard.social"]) mock_save.assert_called_once_with({ 'users': { diff --git a/toot/__init__.py b/toot/__init__.py index 24be0af..3513b13 100644 --- a/toot/__init__.py +++ b/toot/__init__.py @@ -1,4 +1,7 @@ from collections import namedtuple +from dataclasses import dataclass + +from aiohttp import ClientSession __version__ = '0.36.0' @@ -9,3 +12,10 @@ DEFAULT_INSTANCE = 'https://mastodon.social' CLIENT_NAME = 'toot - a Mastodon CLI client' CLIENT_WEBSITE = 'https://github.com/ihabunek/toot' + + +@dataclass +class Context: + app: App + user: User + session: ClientSession diff --git a/toot/aapi.py b/toot/aapi.py new file mode 100644 index 0000000..6dbb945 --- /dev/null +++ b/toot/aapi.py @@ -0,0 +1,118 @@ +from typing import Optional +from uuid import uuid4 + +from toot import Context +from toot.ahttp import Response, request +from toot.exceptions import ConsoleError +from toot.utils import drop_empty_values, str_bool + + +async def find_account(ctx: Context, account_name: str): + if not account_name: + raise ConsoleError("Empty account name given") + + normalized_name = account_name.lstrip("@").lower() + + # Strip @ from accounts on the local instance. The `acct` + # field in account object contains the qualified name for users of other + # instances, but only the username for users of the local instance. This is + # required in order to match the account name below. + if "@" in normalized_name: + [username, instance] = normalized_name.split("@", maxsplit=1) + if instance == ctx.app.instance: + normalized_name = username + + response = await search(ctx, account_name, type="accounts", resolve=True) + accounts = response.json["accounts"] + + for account in accounts: + if account["acct"].lower() == normalized_name: + return account + + raise ConsoleError("Account not found") + + +# ------------------------------------------------------------------------------ +# Accounts +# https://docs.joinmastodon.org/methods/accounts/ +# ------------------------------------------------------------------------------ + + +async def verify_credentials(ctx: Context) -> Response: + """ + Test to make sure that the user token works. + https://docs.joinmastodon.org/methods/accounts/#verify_credentials + """ + return await request(ctx, "GET", "/api/v1/accounts/verify_credentials") + + +# ------------------------------------------------------------------------------ +# Search +# https://docs.joinmastodon.org/methods/search/ +# ------------------------------------------------------------------------------ + +async def search(ctx: Context, query: str, resolve: bool = False, type: Optional[str] = None): + """ + Perform a search. + https://docs.joinmastodon.org/methods/search/#v2 + """ + return await request(ctx, "GET", "/api/v2/search", params={ + "q": query, + "resolve": str_bool(resolve), + "type": type + }) + +# ------------------------------------------------------------------------------ +# Statuses +# https://docs.joinmastodon.org/methods/statuses/ +# ------------------------------------------------------------------------------ + + +async def post_status( + ctx: Context, + status, + visibility='public', + media_ids=None, + sensitive=False, + spoiler_text=None, + in_reply_to_id=None, + language=None, + scheduled_at=None, + content_type=None, + poll_options=None, + poll_expires_in=None, + poll_multiple=None, + poll_hide_totals=None, +): + """ + Publish a new status. + https://docs.joinmastodon.org/methods/statuses/#create + """ + + # Idempotency key assures the same status is not posted multiple times + # if the request is retried. + headers = {"Idempotency-Key": uuid4().hex} + + # Strip keys for which value is None + # Sending null values doesn't bother Mastodon, but it breaks Pleroma + data = drop_empty_values({ + "status": status, + "media_ids": media_ids, + "visibility": visibility, + "sensitive": sensitive, + "in_reply_to_id": in_reply_to_id, + "language": language, + "scheduled_at": scheduled_at, + "content_type": content_type, + "spoiler_text": spoiler_text, + }) + + if poll_options: + data["poll"] = { + "options": poll_options, + "expires_in": poll_expires_in, + "multiple": poll_multiple, + "hide_totals": poll_hide_totals, + } + + return await request(ctx, "POST", "/api/v1/statuses", json=data, headers=headers) diff --git a/toot/ahttp.py b/toot/ahttp.py new file mode 100644 index 0000000..d45e3da --- /dev/null +++ b/toot/ahttp.py @@ -0,0 +1,81 @@ +import asyncio +import logging +import json + +from aiohttp import ClientResponse, TraceConfig +from dataclasses import dataclass +from functools import lru_cache +from http import HTTPStatus +from toot import Context +from typing import Any, Mapping, Dict, Optional, Tuple + + +logger = logging.getLogger(__name__) + +Params = Dict[str, str] +Headers = Dict[str, str] +Json = Dict[str, Any] + + +@dataclass(frozen=True) +class Response(): + body: str + headers: Mapping[str, str] + + @property + # @lru_cache + def json(self) -> Json: + return json.loads(self.body) + + +class ResponseError(Exception): + """Raised when the API retruns a response with status code >= 400.""" + def __init__(self, status_code, error, description): + self.status_code = status_code + self.error = error + self.description = description + + status_message = HTTPStatus(status_code).phrase + msg = f"HTTP {status_code} {status_message}" + msg += f". Error: {error}" if error else "" + msg += f". Description: {description}" if description else "" + super().__init__(msg) + + +async def request(ctx: Context, method: str, url: str, **kwargs) -> Response: + async with ctx.session.request(method, url, **kwargs) as response: + if not response.ok: + error, description = await get_error(response) + raise ResponseError(response.status, error, description) + + body = await response.text() + return Response(body, response.headers) + + +async def get_error(response: ClientResponse) -> Tuple[Optional[str], Optional[str]]: + """Attempt to extract the error and error description from response body. + + See: https://docs.joinmastodon.org/entities/error/ + """ + try: + data = await response.json() + return data.get("error"), data.get("error_description") + except Exception: + pass + + return None, None + + +def logger_trace_config() -> TraceConfig: + async def on_request_start(session, context, params): + context.start = asyncio.get_event_loop().time() + logger.debug(f"--> {params.method} {params.url}") + + async def on_request_end(session, context, params): + elapsed = round(100 * (asyncio.get_event_loop().time() - context.start)) + logger.debug(f"<-- {params.method} {params.url} HTTP {params.response.status} {elapsed}ms") + + trace_config = TraceConfig() + trace_config.on_request_start.append(on_request_start) + trace_config.on_request_end.append(on_request_end) + return trace_config diff --git a/toot/commands.py b/toot/commands.py index 5dad7e8..82867c7 100644 --- a/toot/commands.py +++ b/toot/commands.py @@ -4,7 +4,7 @@ import platform from datetime import datetime, timedelta, timezone from time import sleep, time -from toot import api, config, __version__ +from toot import api, aapi, config, __version__, Context from toot.auth import login_interactive, login_browser_interactive, create_app_interactive from toot.entities import Instance, Notification, Status, from_dict from toot.exceptions import ApiError, ConsoleError @@ -84,22 +84,26 @@ def thread(app, user, args): print_timeline(statuses) -def post(app, user, args): +async def post(ctx, args): if args.editor and not sys.stdin.isatty(): raise ConsoleError("Cannot run editor if not in tty.") if args.media and len(args.media) > 4: raise ConsoleError("Cannot attach more than 4 files.") - media_ids = _upload_media(app, user, args) + # TODO! + # media_ids = _upload_media(app, user, args) + media_ids = [] + status_text = _get_status_text(args.text, args.editor, args.media) scheduled_at = _get_scheduled_at(args.scheduled_at, args.scheduled_in) if not status_text and not media_ids: raise ConsoleError("You must specify either text or media to post.") - response = api.post_status( - app, user, status_text, + response = await aapi.post_status( + ctx, + status_text, visibility=args.visibility, media_ids=media_ids, sensitive=args.sensitive, @@ -114,12 +118,14 @@ def post(app, user, args): poll_hide_totals=args.poll_hide_totals, ) - if "scheduled_at" in response: - scheduled_at = parse_datetime(response["scheduled_at"]) + data = response.json + + if "scheduled_at" in data: + scheduled_at = parse_datetime(data["scheduled_at"]) scheduled_at = datetime.strftime(scheduled_at, "%Y-%m-%d %H:%M:%S%z") print_out(f"Toot scheduled for: {scheduled_at}") else: - print_out(f"Toot posted: {response['url']}") + print_out(f"Toot posted: {data['url']}") delete_tmp_status_file() @@ -499,13 +505,17 @@ def unblock(app, user, args): print_out("✓ {} is no longer blocked".format(args.account)) -def whoami(app, user, args): - account = api.verify_credentials(app, user) - print_account(account) +async def whoami(ctx: Context, args): + response = await aapi.verify_credentials(ctx) + if args.json: + print_out(response.body) + else: + print(response.json) + print_account(response.json) -def whois(app, user, args): - account = api.find_account(app, user, args.account) +async def whois(ctx: Context, args): + account = await aapi.find_account(ctx, args.account) print_account(account) diff --git a/toot/console.py b/toot/console.py index da4f3ce..a24445a 100644 --- a/toot/console.py +++ b/toot/console.py @@ -1,3 +1,4 @@ +import asyncio import logging import os import re @@ -7,7 +8,10 @@ import sys from argparse import ArgumentParser, FileType, ArgumentTypeError, Action from collections import namedtuple from itertools import chain -from toot import config, commands, CLIENT_NAME, CLIENT_WEBSITE, __version__ + +from aiohttp import ClientSession +from toot import App, Context, User, config, commands, CLIENT_NAME, CLIENT_WEBSITE, __version__ +from toot.ahttp import ResponseError, logger_trace_config from toot.exceptions import ApiError, ConsoleError from toot.output import print_out, print_err @@ -178,6 +182,11 @@ common_args = [ "action": 'store_true', "default": False, }), + (["--json"], { + "help": "display output as JSON (experimental, may not work everywhere)", + "action": 'store_true', + "default": False, + }), ] # Arguments added to commands which require authentication @@ -878,7 +887,7 @@ def get_argument_parser(name, command): return parser -def run_command(app, user, name, args): +async def run_command(app, user, name, args): command = next((c for c in COMMANDS if c.name == name), None) if not command: @@ -905,7 +914,25 @@ def run_command(app, user, name, args): if not fn: raise NotImplementedError("Command '{}' does not have an implementation.".format(name)) - return fn(app, user, parsed_args) + if asyncio.iscoroutinefunction(fn): + async with make_session(app, user, parsed_args.debug) as session: + ctx = Context(app, user, session) + return await fn(ctx, parsed_args) + else: + return fn(app, user, parsed_args) + + +def make_session(app: App, user: User, debug: bool) -> ClientSession: + headers = {"User-Agent": f"toot/{__version__}"} + if user: + headers["Authorization"] = f"Bearer {user.access_token}" + trace_configs = [logger_trace_config()] if debug else [] + + return ClientSession( + headers=headers, + base_url=app.base_url, + trace_configs=trace_configs, + ) def main(): @@ -924,9 +951,15 @@ def main(): user, app = config.get_active_user_app() try: - run_command(app, user, command_name, args) + asyncio.run(run_command(app, user, command_name, args)) except (ConsoleError, ApiError) as e: print_err(str(e)) sys.exit(1) + except ResponseError as e: + if e.error: + print_err(e.error) + if e.description: + print_err(e.description) + sys.exit(1) except KeyboardInterrupt: pass