kopia lustrzana https://github.com/ihabunek/toot
82 wiersze
2.4 KiB
Python
82 wiersze
2.4 KiB
Python
import asyncio
|
|
import logging
|
|
import json
|
|
|
|
from aiohttp import ClientResponse, TraceConfig
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
from http import HTTPStatus
|
|
from toot import Context
|
|
from typing import Any, Mapping, Dict, Optional, Tuple
|
|
|
|
|
|
# Logger for this module, named after the module's import path.
logger = logging.getLogger(name=__name__)

# Type aliases used in annotations.
Params = Dict[str, str]   # string-to-string mapping (presumably request parameters)
Headers = Dict[str, str]  # string-to-string mapping (presumably header name -> value)
Json = Any                # any decoded JSON value; its structure is not constrained
|
|
|
|
|
|
@dataclass(frozen=True)
class Response:
    """Immutable HTTP response: its text body and its headers.

    Attributes:
        body: the response body as text.
        headers: the response headers.
    """

    body: str
    headers: Mapping[str, str]

    @property
    def json(self) -> Json:
        """Decode ``body`` as JSON and return the result.

        Nothing is cached: the body is parsed anew on every access, so each
        access returns a fresh object.
        """
        decoded = json.loads(self.body)
        return decoded
|
|
|
|
|
|
class ResponseError(Exception):
    """Raised when the API returns a response with status code >= 400.

    The exception message has the form
    ``HTTP <code> <phrase>[. Error: <error>][. Description: <description>]``.
    The phrase is omitted when ``http.HTTPStatus`` does not know the code.

    Attributes:
        status_code: the HTTP status code of the response.
        error: error identifier reported by the server, or a falsy value if none.
        description: longer error description, or a falsy value if none.
    """

    def __init__(self, status_code, error, description):
        self.status_code = status_code
        self.error = error
        self.description = description

        # HTTPStatus() raises ValueError for codes outside the standard set
        # (e.g. 520 or 599). Fall back to the bare code so an unexpected status
        # does not replace the real HTTP error with an unrelated ValueError.
        try:
            phrase = HTTPStatus(status_code).phrase
        except ValueError:
            phrase = None

        msg = f"HTTP {status_code} {phrase}" if phrase else f"HTTP {status_code}"
        if error:
            msg += f". Error: {error}"
        if description:
            msg += f". Description: {description}"

        super().__init__(msg)
|
|
|
|
|
|
async def request(ctx: Context, method: str, url: str, **kwargs) -> Response:
    """Perform an HTTP request through ``ctx.session`` and return the result.

    Extra keyword arguments are forwarded unchanged to ``ctx.session.request``.

    Returns:
        A ``Response`` holding the text body and the headers of the response.

    Raises:
        ResponseError: if the response is not ``ok``. The error and description
            are whatever ``get_error`` manages to extract from the body.
    """
    async with ctx.session.request(method, url, **kwargs) as resp:
        if resp.ok:
            text = await resp.text()
            return Response(text, resp.headers)

        error, description = await get_error(resp)
        raise ResponseError(resp.status, error, description)
|
|
|
|
|
|
async def get_error(response: ClientResponse) -> Tuple[Optional[str], Optional[str]]:
    """Attempt to extract the error and error description from the response body.

    This is best-effort and never raises for an unusable body. If the body
    cannot be decoded as JSON, or decodes to something other than a JSON
    object, ``(None, None)`` is returned.

    See: https://docs.joinmastodon.org/entities/error/

    Returns:
        An ``(error, error_description)`` tuple. Either item is ``None`` when
        the corresponding key is absent.
    """
    try:
        data = await response.json()
    except Exception as ex:
        # Deliberately broad: a failure to read or decode the body must not
        # mask the HTTP error being reported. Record why for debugging instead
        # of discarding it silently.
        logger.debug(f"Unable to decode error response body: {ex!r}")
        return None, None

    # Only a JSON object can carry the "error"/"error_description" fields.
    if not isinstance(data, dict):
        return None, None

    return data.get("error"), data.get("error_description")
|
|
|
|
|
|
def logger_trace_config() -> TraceConfig:
    """Build an aiohttp ``TraceConfig`` that logs every request at DEBUG level.

    A ``-->`` line is logged when a request starts. A ``<--`` line is logged
    when it ends, with the response status and the elapsed time in
    milliseconds.
    """

    async def on_request_start(session, context, params):
        # Store the start time on the trace context object, which is also
        # passed to on_request_end below.
        context.start = asyncio.get_running_loop().time()
        logger.debug(f"--> {params.method} {params.url}")

    async def on_request_end(session, context, params):
        # loop.time() is measured in seconds; multiply by 1000 so the value
        # matches the "ms" suffix in the log line (was 100, off by 10x).
        elapsed = round(1000 * (asyncio.get_running_loop().time() - context.start))
        logger.debug(f"<-- {params.method} {params.url} HTTP {params.response.status} {elapsed}ms")

    trace_config = TraceConfig()
    trace_config.on_request_start.append(on_request_start)
    trace_config.on_request_end.append(on_request_end)
    return trace_config
|