Ivan Habunek 2023-06-22 17:23:05 +02:00
rodzic 835f789145
commit dc376f67d8
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: F5F0623FF5EBCB3D
7 zmienionych plików z 299 dodań i 78 usunięć

Wyświetl plik

@ -13,6 +13,7 @@ export TOOT_TEST_DATABASE_DSN="dbname=mastodon_development"
``` ```
""" """
import asyncio
import re import re
import os import os
import psycopg2 import psycopg2
@ -95,7 +96,7 @@ def run(app, user, capsys):
# The try/catch duplicates logic from console.main to convert exceptions # The try/catch duplicates logic from console.main to convert exceptions
# to printed error messages. TODO: could be deduped # to printed error messages. TODO: could be deduped
try: try:
run_command(app, as_user or user, command, params or []) asyncio.run(run_command(app, as_user or user, command, params or []))
except (ConsoleError, ApiError) as e: except (ConsoleError, ApiError) as e:
print_out(str(e)) print_out(str(e))
@ -108,7 +109,7 @@ def run(app, user, capsys):
@pytest.fixture @pytest.fixture
def run_anon(capsys): def run_anon(capsys):
def _run(command, *params): def _run(command, *params):
run_command(None, None, command, params or []) asyncio.run(run_command(None, None, command, params or []))
out, err = capsys.readouterr() out, err = capsys.readouterr()
assert err == "" assert err == ""
return strip_ansi(out) return strip_ansi(out)

Wyświetl plik

@ -1,3 +1,4 @@
import asyncio
import io import io
import pytest import pytest
import re import re
@ -16,6 +17,10 @@ user = User('habunek.com', 'ivan@habunek.com', 'xxx')
MockUuid = namedtuple("MockUuid", ["hex"]) MockUuid = namedtuple("MockUuid", ["hex"])
def run_command(app, user, name, args):
return asyncio.run(console.run_command(app, user, name, args))
def uncolorize(text): def uncolorize(text):
"""Remove ANSI color sequences from a string""" """Remove ANSI color sequences from a string"""
return re.sub(r'\x1b[^m]*m', '', text) return re.sub(r'\x1b[^m]*m', '', text)
@ -35,7 +40,7 @@ def test_post_defaults(mock_post, mock_uuid, capsys):
'url': 'https://habunek.com/@ihabunek/1234567890' 'url': 'https://habunek.com/@ihabunek/1234567890'
}) })
console.run_command(app, user, 'post', ['Hello world']) run_command(app, user, 'post', ['Hello world'])
mock_post.assert_called_once_with(app, user, '/api/v1/statuses', json={ mock_post.assert_called_once_with(app, user, '/api/v1/statuses', json={
'status': 'Hello world', 'status': 'Hello world',
@ -67,7 +72,7 @@ def test_post_with_options(mock_post, mock_uuid, capsys):
'url': 'https://habunek.com/@ihabunek/1234567890' 'url': 'https://habunek.com/@ihabunek/1234567890'
}) })
console.run_command(app, user, 'post', args) run_command(app, user, 'post', args)
mock_post.assert_called_once_with(app, user, '/api/v1/statuses', json={ mock_post.assert_called_once_with(app, user, '/api/v1/statuses', json={
'status': 'Hello world', 'status': 'Hello world',
@ -89,7 +94,7 @@ def test_post_invalid_visibility(capsys):
args = ['Hello world', '--visibility', 'foo'] args = ['Hello world', '--visibility', 'foo']
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
console.run_command(app, user, 'post', args) run_command(app, user, 'post', args)
out, err = capsys.readouterr() out, err = capsys.readouterr()
assert "invalid visibility value: 'foo'" in err assert "invalid visibility value: 'foo'" in err
@ -99,7 +104,7 @@ def test_post_invalid_media(capsys):
args = ['Hello world', '--media', 'does_not_exist.jpg'] args = ['Hello world', '--media', 'does_not_exist.jpg']
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
console.run_command(app, user, 'post', args) run_command(app, user, 'post', args)
out, err = capsys.readouterr() out, err = capsys.readouterr()
assert "can't open 'does_not_exist.jpg'" in err assert "can't open 'does_not_exist.jpg'" in err
@ -107,7 +112,7 @@ def test_post_invalid_media(capsys):
@mock.patch('toot.http.delete') @mock.patch('toot.http.delete')
def test_delete(mock_delete, capsys): def test_delete(mock_delete, capsys):
console.run_command(app, user, 'delete', ['12321']) run_command(app, user, 'delete', ['12321'])
mock_delete.assert_called_once_with(app, user, '/api/v1/statuses/12321') mock_delete.assert_called_once_with(app, user, '/api/v1/statuses/12321')
@ -131,7 +136,7 @@ def test_timeline(mock_get, monkeypatch, capsys):
'media_attachments': [], 'media_attachments': [],
}]) }])
console.run_command(app, user, 'timeline', ['--once']) run_command(app, user, 'timeline', ['--once'])
mock_get.assert_called_once_with(app, user, '/api/v1/timelines/home', {'limit': 10}) mock_get.assert_called_once_with(app, user, '/api/v1/timelines/home', {'limit': 10})
@ -173,7 +178,7 @@ def test_timeline_with_re(mock_get, monkeypatch, capsys):
'media_attachments': [], 'media_attachments': [],
}]) }])
console.run_command(app, user, 'timeline', ['--once']) run_command(app, user, 'timeline', ['--once'])
mock_get.assert_called_once_with(app, user, '/api/v1/timelines/home', {'limit': 10}) mock_get.assert_called_once_with(app, user, '/api/v1/timelines/home', {'limit': 10})
@ -235,7 +240,7 @@ def test_thread(mock_get, monkeypatch, capsys):
}), }),
] ]
console.run_command(app, user, 'thread', ['111111111111111111']) run_command(app, user, 'thread', ['111111111111111111'])
calls = [ calls = [
mock.call(app, user, '/api/v1/statuses/111111111111111111'), mock.call(app, user, '/api/v1/statuses/111111111111111111'),
@ -259,6 +264,7 @@ def test_thread(mock_get, monkeypatch, capsys):
assert "111111111111111111" in out assert "111111111111111111" in out
assert "In reply to" in out assert "In reply to" in out
@mock.patch('toot.http.get') @mock.patch('toot.http.get')
def test_reblogged_by(mock_get, monkeypatch, capsys): def test_reblogged_by(mock_get, monkeypatch, capsys):
mock_get.return_value = MockResponse([{ mock_get.return_value = MockResponse([{
@ -269,7 +275,7 @@ def test_reblogged_by(mock_get, monkeypatch, capsys):
'acct': 'dweezil@zappafamily.social', 'acct': 'dweezil@zappafamily.social',
}]) }])
console.run_command(app, user, 'reblogged_by', ['111111111111111111']) run_command(app, user, 'reblogged_by', ['111111111111111111'])
calls = [ calls = [
mock.call(app, user, '/api/v1/statuses/111111111111111111/reblogged_by'), mock.call(app, user, '/api/v1/statuses/111111111111111111/reblogged_by'),
@ -298,7 +304,7 @@ def test_upload(mock_post, capsys):
'type': 'image', 'type': 'image',
}) })
console.run_command(app, user, 'upload', [__file__]) run_command(app, user, 'upload', [__file__])
assert mock_post.call_count == 1 assert mock_post.call_count == 1
@ -341,7 +347,7 @@ def test_search(mock_get, capsys):
'statuses': [], 'statuses': [],
}) })
console.run_command(app, user, 'search', ['freddy']) run_command(app, user, 'search', ['freddy'])
mock_get.assert_called_once_with(app, user, '/api/v2/search', { mock_get.assert_called_once_with(app, user, '/api/v2/search', {
'q': 'freddy', 'q': 'freddy',
@ -368,7 +374,7 @@ def test_follow(mock_get, mock_post, capsys):
}) })
mock_post.return_value = MockResponse() mock_post.return_value = MockResponse()
console.run_command(app, user, 'follow', ['blixa']) run_command(app, user, 'follow', ['blixa'])
mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True}) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True})
mock_post.assert_called_once_with(app, user, '/api/v1/accounts/321/follow') mock_post.assert_called_once_with(app, user, '/api/v1/accounts/321/follow')
@ -388,7 +394,7 @@ def test_follow_case_insensitive(mock_get, mock_post, capsys):
}) })
mock_post.return_value = MockResponse() mock_post.return_value = MockResponse()
console.run_command(app, user, 'follow', ['bLiXa@oThEr.aCc']) run_command(app, user, 'follow', ['bLiXa@oThEr.aCc'])
mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'bLiXa@oThEr.aCc', 'type': 'accounts', 'resolve': True}) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'bLiXa@oThEr.aCc', 'type': 'accounts', 'resolve': True})
mock_post.assert_called_once_with(app, user, '/api/v1/accounts/123/follow') mock_post.assert_called_once_with(app, user, '/api/v1/accounts/123/follow')
@ -402,7 +408,7 @@ def test_follow_not_found(mock_get, capsys):
mock_get.return_value = MockResponse({"accounts": []}) mock_get.return_value = MockResponse({"accounts": []})
with pytest.raises(ConsoleError) as ex: with pytest.raises(ConsoleError) as ex:
console.run_command(app, user, 'follow', ['blixa']) run_command(app, user, 'follow', ['blixa'])
mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True}) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True})
assert "Account not found" == str(ex.value) assert "Account not found" == str(ex.value)
@ -420,7 +426,7 @@ def test_unfollow(mock_get, mock_post, capsys):
mock_post.return_value = MockResponse() mock_post.return_value = MockResponse()
console.run_command(app, user, 'unfollow', ['blixa']) run_command(app, user, 'unfollow', ['blixa'])
mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True}) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True})
mock_post.assert_called_once_with(app, user, '/api/v1/accounts/321/unfollow') mock_post.assert_called_once_with(app, user, '/api/v1/accounts/321/unfollow')
@ -434,51 +440,13 @@ def test_unfollow_not_found(mock_get, capsys):
mock_get.return_value = MockResponse({"accounts": []}) mock_get.return_value = MockResponse({"accounts": []})
with pytest.raises(ConsoleError) as ex: with pytest.raises(ConsoleError) as ex:
console.run_command(app, user, 'unfollow', ['blixa']) run_command(app, user, 'unfollow', ['blixa'])
mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True}) mock_get.assert_called_once_with(app, user, '/api/v2/search', {'q': 'blixa', 'type': 'accounts', 'resolve': True})
assert "Account not found" == str(ex.value) assert "Account not found" == str(ex.value)
@mock.patch('toot.http.get')
def test_whoami(mock_get, capsys):
mock_get.return_value = MockResponse({
'acct': 'ihabunek',
'avatar': 'https://files.mastodon.social/accounts/avatars/000/046/103/original/6a1304e135cac514.jpg?1491312434',
'avatar_static': 'https://files.mastodon.social/accounts/avatars/000/046/103/original/6a1304e135cac514.jpg?1491312434',
'created_at': '2017-04-04T13:23:09.777Z',
'display_name': 'Ivan Habunek',
'followers_count': 5,
'following_count': 9,
'header': '/headers/original/missing.png',
'header_static': '/headers/original/missing.png',
'id': 46103,
'locked': False,
'note': 'A developer.',
'statuses_count': 19,
'url': 'https://mastodon.social/@ihabunek',
'username': 'ihabunek',
'fields': []
})
console.run_command(app, user, 'whoami', [])
mock_get.assert_called_once_with(app, user, '/api/v1/accounts/verify_credentials')
out, err = capsys.readouterr()
out = uncolorize(out)
assert "@ihabunek Ivan Habunek" in out
assert "A developer." in out
assert "https://mastodon.social/@ihabunek" in out
assert "ID: 46103" in out
assert "Since: 2017-04-04" in out
assert "Followers: 5" in out
assert "Following: 9" in out
assert "Statuses: 19" in out
@mock.patch('toot.http.get') @mock.patch('toot.http.get')
def test_notifications(mock_get, capsys): def test_notifications(mock_get, capsys):
mock_get.return_value = MockResponse([{ mock_get.return_value = MockResponse([{
@ -551,7 +519,7 @@ def test_notifications(mock_get, capsys):
}, },
}]) }])
console.run_command(app, user, 'notifications', []) run_command(app, user, 'notifications', [])
mock_get.assert_called_once_with(app, user, '/api/v1/notifications', {'exclude_types[]': [], 'limit': 20}) mock_get.assert_called_once_with(app, user, '/api/v1/notifications', {'exclude_types[]': [], 'limit': 20})
@ -592,7 +560,7 @@ def test_notifications(mock_get, capsys):
def test_notifications_empty(mock_get, capsys): def test_notifications_empty(mock_get, capsys):
mock_get.return_value = MockResponse([]) mock_get.return_value = MockResponse([])
console.run_command(app, user, 'notifications', []) run_command(app, user, 'notifications', [])
mock_get.assert_called_once_with(app, user, '/api/v1/notifications', {'exclude_types[]': [], 'limit': 20}) mock_get.assert_called_once_with(app, user, '/api/v1/notifications', {'exclude_types[]': [], 'limit': 20})
@ -605,7 +573,7 @@ def test_notifications_empty(mock_get, capsys):
@mock.patch('toot.http.post') @mock.patch('toot.http.post')
def test_notifications_clear(mock_post, capsys): def test_notifications_clear(mock_post, capsys):
console.run_command(app, user, 'notifications', ['--clear']) run_command(app, user, 'notifications', ['--clear'])
out, err = capsys.readouterr() out, err = capsys.readouterr()
out = uncolorize(out) out = uncolorize(out)
@ -634,7 +602,7 @@ def test_logout(mock_load, mock_save, capsys):
"active_user": "king@gizzard.social", "active_user": "king@gizzard.social",
} }
console.run_command(app, user, "logout", ["king@gizzard.social"]) run_command(app, user, "logout", ["king@gizzard.social"])
mock_save.assert_called_once_with({ mock_save.assert_called_once_with({
'users': { 'users': {
@ -658,7 +626,7 @@ def test_activate(mock_load, mock_save, capsys):
"active_user": "king@gizzard.social", "active_user": "king@gizzard.social",
} }
console.run_command(app, user, "activate", ["lizard@wizard.social"]) run_command(app, user, "activate", ["lizard@wizard.social"])
mock_save.assert_called_once_with({ mock_save.assert_called_once_with({
'users': { 'users': {

Wyświetl plik

@ -1,4 +1,7 @@
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass
from aiohttp import ClientSession
__version__ = '0.36.0' __version__ = '0.36.0'
@ -9,3 +12,10 @@ DEFAULT_INSTANCE = 'https://mastodon.social'
CLIENT_NAME = 'toot - a Mastodon CLI client' CLIENT_NAME = 'toot - a Mastodon CLI client'
CLIENT_WEBSITE = 'https://github.com/ihabunek/toot' CLIENT_WEBSITE = 'https://github.com/ihabunek/toot'
@dataclass
class Context:
app: App
user: User
session: ClientSession

118
toot/aapi.py 100644
Wyświetl plik

@ -0,0 +1,118 @@
from typing import Optional
from uuid import uuid4
from toot import Context
from toot.ahttp import Response, request
from toot.exceptions import ConsoleError
from toot.utils import drop_empty_values, str_bool
async def find_account(ctx: Context, account_name: str):
if not account_name:
raise ConsoleError("Empty account name given")
normalized_name = account_name.lstrip("@").lower()
# Strip @<instance_name> from accounts on the local instance. The `acct`
# field in account object contains the qualified name for users of other
# instances, but only the username for users of the local instance. This is
# required in order to match the account name below.
if "@" in normalized_name:
[username, instance] = normalized_name.split("@", maxsplit=1)
if instance == ctx.app.instance:
normalized_name = username
response = await search(ctx, account_name, type="accounts", resolve=True)
accounts = response.json["accounts"]
for account in accounts:
if account["acct"].lower() == normalized_name:
return account
raise ConsoleError("Account not found")
# ------------------------------------------------------------------------------
# Accounts
# https://docs.joinmastodon.org/methods/accounts/
# ------------------------------------------------------------------------------
async def verify_credentials(ctx: Context) -> Response:
"""
Test to make sure that the user token works.
https://docs.joinmastodon.org/methods/accounts/#verify_credentials
"""
return await request(ctx, "GET", "/api/v1/accounts/verify_credentials")
# ------------------------------------------------------------------------------
# Search
# https://docs.joinmastodon.org/methods/search/
# ------------------------------------------------------------------------------
async def search(ctx: Context, query: str, resolve: bool = False, type: Optional[str] = None):
"""
Perform a search.
https://docs.joinmastodon.org/methods/search/#v2
"""
return await request(ctx, "GET", "/api/v2/search", params={
"q": query,
"resolve": str_bool(resolve),
"type": type
})
# ------------------------------------------------------------------------------
# Statuses
# https://docs.joinmastodon.org/methods/statuses/
# ------------------------------------------------------------------------------
async def post_status(
ctx: Context,
status,
visibility='public',
media_ids=None,
sensitive=False,
spoiler_text=None,
in_reply_to_id=None,
language=None,
scheduled_at=None,
content_type=None,
poll_options=None,
poll_expires_in=None,
poll_multiple=None,
poll_hide_totals=None,
):
"""
Publish a new status.
https://docs.joinmastodon.org/methods/statuses/#create
"""
# Idempotency key assures the same status is not posted multiple times
# if the request is retried.
headers = {"Idempotency-Key": uuid4().hex}
# Strip keys for which value is None
# Sending null values doesn't bother Mastodon, but it breaks Pleroma
data = drop_empty_values({
"status": status,
"media_ids": media_ids,
"visibility": visibility,
"sensitive": sensitive,
"in_reply_to_id": in_reply_to_id,
"language": language,
"scheduled_at": scheduled_at,
"content_type": content_type,
"spoiler_text": spoiler_text,
})
if poll_options:
data["poll"] = {
"options": poll_options,
"expires_in": poll_expires_in,
"multiple": poll_multiple,
"hide_totals": poll_hide_totals,
}
return await request(ctx, "POST", "/api/v1/statuses", json=data, headers=headers)

81
toot/ahttp.py 100644
Wyświetl plik

@ -0,0 +1,81 @@
import asyncio
import logging
import json
from aiohttp import ClientResponse, TraceConfig
from dataclasses import dataclass
from functools import lru_cache
from http import HTTPStatus
from toot import Context
from typing import Any, Mapping, Dict, Optional, Tuple
logger = logging.getLogger(__name__)
Params = Dict[str, str]
Headers = Dict[str, str]
Json = Dict[str, Any]
@dataclass(frozen=True)
class Response():
body: str
headers: Mapping[str, str]
@property
# @lru_cache
def json(self) -> Json:
return json.loads(self.body)
class ResponseError(Exception):
"""Raised when the API retruns a response with status code >= 400."""
def __init__(self, status_code, error, description):
self.status_code = status_code
self.error = error
self.description = description
status_message = HTTPStatus(status_code).phrase
msg = f"HTTP {status_code} {status_message}"
msg += f". Error: {error}" if error else ""
msg += f". Description: {description}" if description else ""
super().__init__(msg)
async def request(ctx: Context, method: str, url: str, **kwargs) -> Response:
async with ctx.session.request(method, url, **kwargs) as response:
if not response.ok:
error, description = await get_error(response)
raise ResponseError(response.status, error, description)
body = await response.text()
return Response(body, response.headers)
async def get_error(response: ClientResponse) -> Tuple[Optional[str], Optional[str]]:
"""Attempt to extract the error and error description from response body.
See: https://docs.joinmastodon.org/entities/error/
"""
try:
data = await response.json()
return data.get("error"), data.get("error_description")
except Exception:
pass
return None, None
def logger_trace_config() -> TraceConfig:
async def on_request_start(session, context, params):
context.start = asyncio.get_event_loop().time()
logger.debug(f"--> {params.method} {params.url}")
async def on_request_end(session, context, params):
elapsed = round(100 * (asyncio.get_event_loop().time() - context.start))
logger.debug(f"<-- {params.method} {params.url} HTTP {params.response.status} {elapsed}ms")
trace_config = TraceConfig()
trace_config.on_request_start.append(on_request_start)
trace_config.on_request_end.append(on_request_end)
return trace_config

Wyświetl plik

@ -4,7 +4,7 @@ import platform
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from time import sleep, time from time import sleep, time
from toot import api, config, __version__ from toot import api, aapi, config, __version__, Context
from toot.auth import login_interactive, login_browser_interactive, create_app_interactive from toot.auth import login_interactive, login_browser_interactive, create_app_interactive
from toot.entities import Instance, Notification, Status, from_dict from toot.entities import Instance, Notification, Status, from_dict
from toot.exceptions import ApiError, ConsoleError from toot.exceptions import ApiError, ConsoleError
@ -84,22 +84,26 @@ def thread(app, user, args):
print_timeline(statuses) print_timeline(statuses)
def post(app, user, args): async def post(ctx, args):
if args.editor and not sys.stdin.isatty(): if args.editor and not sys.stdin.isatty():
raise ConsoleError("Cannot run editor if not in tty.") raise ConsoleError("Cannot run editor if not in tty.")
if args.media and len(args.media) > 4: if args.media and len(args.media) > 4:
raise ConsoleError("Cannot attach more than 4 files.") raise ConsoleError("Cannot attach more than 4 files.")
media_ids = _upload_media(app, user, args) # TODO!
# media_ids = _upload_media(app, user, args)
media_ids = []
status_text = _get_status_text(args.text, args.editor, args.media) status_text = _get_status_text(args.text, args.editor, args.media)
scheduled_at = _get_scheduled_at(args.scheduled_at, args.scheduled_in) scheduled_at = _get_scheduled_at(args.scheduled_at, args.scheduled_in)
if not status_text and not media_ids: if not status_text and not media_ids:
raise ConsoleError("You must specify either text or media to post.") raise ConsoleError("You must specify either text or media to post.")
response = api.post_status( response = await aapi.post_status(
app, user, status_text, ctx,
status_text,
visibility=args.visibility, visibility=args.visibility,
media_ids=media_ids, media_ids=media_ids,
sensitive=args.sensitive, sensitive=args.sensitive,
@ -114,12 +118,14 @@ def post(app, user, args):
poll_hide_totals=args.poll_hide_totals, poll_hide_totals=args.poll_hide_totals,
) )
if "scheduled_at" in response: data = response.json
scheduled_at = parse_datetime(response["scheduled_at"])
if "scheduled_at" in data:
scheduled_at = parse_datetime(data["scheduled_at"])
scheduled_at = datetime.strftime(scheduled_at, "%Y-%m-%d %H:%M:%S%z") scheduled_at = datetime.strftime(scheduled_at, "%Y-%m-%d %H:%M:%S%z")
print_out(f"Toot scheduled for: <green>{scheduled_at}</green>") print_out(f"Toot scheduled for: <green>{scheduled_at}</green>")
else: else:
print_out(f"Toot posted: <green>{response['url']}") print_out(f"Toot posted: <green>{data['url']}")
delete_tmp_status_file() delete_tmp_status_file()
@ -499,13 +505,17 @@ def unblock(app, user, args):
print_out("<green>✓ {} is no longer blocked</green>".format(args.account)) print_out("<green>✓ {} is no longer blocked</green>".format(args.account))
def whoami(app, user, args): async def whoami(ctx: Context, args):
account = api.verify_credentials(app, user) response = await aapi.verify_credentials(ctx)
print_account(account) if args.json:
print_out(response.body)
else:
print(response.json)
print_account(response.json)
def whois(app, user, args): async def whois(ctx: Context, args):
account = api.find_account(app, user, args.account) account = await aapi.find_account(ctx, args.account)
print_account(account) print_account(account)

Wyświetl plik

@ -1,3 +1,4 @@
import asyncio
import logging import logging
import os import os
import re import re
@ -7,7 +8,10 @@ import sys
from argparse import ArgumentParser, FileType, ArgumentTypeError, Action from argparse import ArgumentParser, FileType, ArgumentTypeError, Action
from collections import namedtuple from collections import namedtuple
from itertools import chain from itertools import chain
from toot import config, commands, CLIENT_NAME, CLIENT_WEBSITE, __version__
from aiohttp import ClientSession
from toot import App, Context, User, config, commands, CLIENT_NAME, CLIENT_WEBSITE, __version__
from toot.ahttp import ResponseError, logger_trace_config
from toot.exceptions import ApiError, ConsoleError from toot.exceptions import ApiError, ConsoleError
from toot.output import print_out, print_err from toot.output import print_out, print_err
@ -178,6 +182,11 @@ common_args = [
"action": 'store_true', "action": 'store_true',
"default": False, "default": False,
}), }),
(["--json"], {
"help": "display output as JSON (experimental, may not work everywhere)",
"action": 'store_true',
"default": False,
}),
] ]
# Arguments added to commands which require authentication # Arguments added to commands which require authentication
@ -878,7 +887,7 @@ def get_argument_parser(name, command):
return parser return parser
def run_command(app, user, name, args): async def run_command(app, user, name, args):
command = next((c for c in COMMANDS if c.name == name), None) command = next((c for c in COMMANDS if c.name == name), None)
if not command: if not command:
@ -905,7 +914,25 @@ def run_command(app, user, name, args):
if not fn: if not fn:
raise NotImplementedError("Command '{}' does not have an implementation.".format(name)) raise NotImplementedError("Command '{}' does not have an implementation.".format(name))
return fn(app, user, parsed_args) if asyncio.iscoroutinefunction(fn):
async with make_session(app, user, parsed_args.debug) as session:
ctx = Context(app, user, session)
return await fn(ctx, parsed_args)
else:
return fn(app, user, parsed_args)
def make_session(app: App, user: User, debug: bool) -> ClientSession:
headers = {"User-Agent": f"toot/{__version__}"}
if user:
headers["Authorization"] = f"Bearer {user.access_token}"
trace_configs = [logger_trace_config()] if debug else []
return ClientSession(
headers=headers,
base_url=app.base_url,
trace_configs=trace_configs,
)
def main(): def main():
@ -924,9 +951,15 @@ def main():
user, app = config.get_active_user_app() user, app = config.get_active_user_app()
try: try:
run_command(app, user, command_name, args) asyncio.run(run_command(app, user, command_name, args))
except (ConsoleError, ApiError) as e: except (ConsoleError, ApiError) as e:
print_err(str(e)) print_err(str(e))
sys.exit(1) sys.exit(1)
except ResponseError as e:
if e.error:
print_err(e.error)
if e.description:
print_err(e.description)
sys.exit(1)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass