funkwhale/api/funkwhale_api/common/mutations.py

172 lines
5.4 KiB
Python

import persisting_theory
from django.db import models, transaction
from rest_framework import serializers
class ConfNotFound(KeyError):
    """Raised when no mutation configuration matches a requested type/object."""
class Registry(persisting_theory.Registry):
    """Registry of mutation serializers, keyed by mutation type then model class.

    ``registry[type][klass]`` holds a configuration dict with:

    * ``"serializer_class"``: the serializer used to validate and apply the
      mutation;
    * ``"perm_checkers"``: a mapping of permission name to a checker callable.

    A ``klass`` of ``None`` acts as a fallback for objects whose exact class
    has no dedicated entry.
    """

    # presumably the module name persisting_theory autodiscovers in each app
    # -- confirm against persisting_theory's documentation
    look_into = "mutations"

    def connect(self, type, klass, perm_checkers=None):
        """Return a decorator registering a serializer for ``type``/``klass``.

        :param type: mutation type name, later passed to :meth:`get_conf`.
        :param klass: class of the objects this mutation targets, or ``None``
            to register a fallback for any class.
        :param perm_checkers: optional mapping of permission name
            (``"approve"`` / ``"suggest"``) to a callable invoked as
            ``checker(obj=..., actor=...)``.
        :returns: a decorator that records the serializer class and returns
            it unchanged. Registering the same ``type``/``klass`` twice
            overwrites the earlier entry.
        """

        def decorator(serializer_class):
            t = self.setdefault(type, {})
            t[klass] = {
                "serializer_class": serializer_class,
                "perm_checkers": perm_checkers or {},
            }
            return serializer_class

        return decorator

    def _get_serializer(self, type, obj, payload):
        """Instantiate the serializer configured for ``type``/``obj`` on ``payload``."""
        conf = self.get_conf(type, obj)
        return conf["serializer_class"](obj, data=payload)

    @transaction.atomic
    def apply(self, type, obj, payload):
        """Validate ``payload`` and apply the mutation to ``obj`` in a transaction.

        :returns: the serializer's ``get_previous_state`` result, computed
            *before* the mutation is applied.
        :raises ConfNotFound: if no configuration matches ``type``/``obj``.
        :raises rest_framework.serializers.ValidationError: if ``payload`` is
            invalid.
        """
        serializer = self._get_serializer(type, obj, payload)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        # capture the old state first so callers can record what changed
        previous_state = serializer.get_previous_state(obj, validated_data)
        serializer.apply(obj, validated_data)
        return previous_state

    def is_valid(self, type, obj, payload):
        """Validate ``payload`` for ``obj``; invalid payloads raise instead of
        returning a falsy value.

        :raises ConfNotFound: if no configuration matches ``type``/``obj``.
        """
        serializer = self._get_serializer(type, obj, payload)
        return serializer.is_valid(raise_exception=True)

    def get_validated_payload(self, type, obj, payload):
        """Validate ``payload`` and return it in its storable form.

        The serializer's ``payload_serialize`` produces the returned value.

        :raises ConfNotFound: if no configuration matches ``type``/``obj``.
        """
        serializer = self._get_serializer(type, obj, payload)
        serializer.is_valid(raise_exception=True)
        return serializer.payload_serialize(serializer.validated_data)

    def has_perm(self, perm, type, obj, actor):
        """Return whether ``actor`` holds ``perm`` on ``obj`` for mutation ``type``.

        :returns: ``False`` when no checker is registered for ``perm``;
            otherwise the checker's result.
        :raises ValueError: if ``perm`` is not ``"approve"`` or ``"suggest"``.
        :raises ConfNotFound: if no configuration matches ``type``/``obj``.
        """
        if perm not in ["approve", "suggest"]:
            raise ValueError(f"Invalid permission {perm}")
        conf = self.get_conf(type, obj)
        checker = conf["perm_checkers"].get(perm)
        if not checker:
            return False
        return checker(obj=obj, actor=actor)

    def get_conf(self, type, obj):
        """Return the configuration registered for ``type`` and ``obj``'s class.

        The exact class of ``obj`` is looked up first. If it is absent, the
        ``None`` fallback entry for ``type`` is used.

        :raises ConfNotFound: if ``type`` is unknown, or if neither a
            class-specific nor a fallback entry exists.
        """
        try:
            type_conf = self[type]
        except KeyError:
            # ``from None``: the internal KeyError adds nothing for callers
            raise ConfNotFound(f"{type} is not a registered mutation") from None
        try:
            return type_conf[obj.__class__]
        except KeyError:
            pass
        try:
            return type_conf[None]
        except KeyError:
            raise ConfNotFound(
                f"No mutation configuration found for {obj.__class__}"
            ) from None
class MutationSerializer(serializers.Serializer):
    """Base serializer for mutations dispatched through :class:`Registry`.

    Subclasses must implement :meth:`apply`; every other hook has a neutral
    default.
    """

    def apply(self, obj, validated_data):
        """Apply ``validated_data`` to ``obj``. Subclasses must override this."""
        raise NotImplementedError()

    def post_apply(self, obj, validated_data):
        """Hook for work after a mutation is applied; a no-op by default."""
        return None

    def get_previous_state(self, obj, validated_data):
        """Return a snapshot of ``obj`` taken before mutation; ``None`` by default."""
        return None

    def payload_serialize(self, data):
        """Return ``data`` ready to be stored as a payload; identity by default."""
        return data
class UpdateMutationSerializer(serializers.ModelSerializer, MutationSerializer):
    """Model serializer for partial-update mutations.

    Subclasses can override :meth:`get_serialized_relations` to return a
    mapping of field name to an attribute of the related instance
    (e.g. ``{"some_relation": "id"}``). Related model instances are then
    stored as that attribute instead of as objects.
    """

    def __init__(self, *args, **kwargs):
        # default to partial mode, because update mutations are partial
        # (a caller may still pass ``partial`` explicitly)
        kwargs.setdefault("partial", True)
        super().__init__(*args, **kwargs)

    @transaction.atomic
    def apply(self, obj, validated_data):
        """Update ``obj`` atomically, run :meth:`post_apply`, and return the
        instance returned by ``update``."""
        r = self.update(obj, validated_data)
        self.post_apply(r, validated_data)
        return r

    def validate(self, validated_data):
        """Reject empty updates, then defer to the parent validation.

        :raises rest_framework.serializers.ValidationError: if no field is set.
        """
        if not validated_data:
            raise serializers.ValidationError("You must update at least one field")
        return super().validate(validated_data)

    def db_serialize(self, validated_data):
        """Return a new dict with model instances replaced by their configured
        attribute (see :meth:`get_serialized_relations`).

        :raises KeyError: if a model-instance value has no configured relation.
        """
        serialized_relations = self.get_serialized_relations()
        data = {}
        # ensure model fields are serialized properly
        for key, value in validated_data.items():
            if isinstance(value, models.Model):
                value = getattr(value, serialized_relations[key])
            data[key] = value
        return data

    def payload_serialize(self, data):
        """Return a JSON-friendly copy of ``data`` for storing as a payload.

        Each configured relation present in ``data`` is replaced by the
        configured attribute of the related instance; ``None`` stays ``None``.
        The ``data`` argument itself is not modified.
        """
        # copy first: callers such as Registry.get_validated_payload pass the
        # serializer's own ``validated_data``, which should keep its model
        # instances rather than being overwritten with ids in place
        data = dict(super().payload_serialize(data))
        for field, attr in self.get_serialized_relations().items():
            try:
                related = data[field]
            except KeyError:
                continue
            if related is not None:
                data[field] = getattr(related, attr)
        return data

    def create(self, validated_data):
        """Create an instance after flattening relations with :meth:`db_serialize`."""
        validated_data = self.db_serialize(validated_data)
        return super().create(validated_data)

    def get_previous_state(self, obj, validated_data):
        """Snapshot, on ``obj``, the current value of every field being updated."""
        return get_update_previous_state(
            obj,
            *validated_data,
            serialized_relations=self.get_serialized_relations(),
            handlers=self.get_previous_state_handlers(),
        )

    def get_serialized_relations(self):
        """Return the field -> related-attribute mapping; none by default."""
        return {}

    def get_previous_state_handlers(self):
        """Return the field -> ``handler(obj)`` mapping used for custom
        previous-state capture; none by default."""
        return {}
def get_update_previous_state(obj, *fields, serialized_relations=None, handlers=None):
    """Return the current state of ``fields`` on ``obj``, keyed by field name.

    :param obj: object whose attributes are read.
    :param fields: attribute names to capture; at least one is required.
    :param serialized_relations: mapping of field name to the attribute of
        the related instance to store when the field's value is a model
        instance. ``None`` means no relations are configured.
    :param handlers: mapping of field name to a callable ``handler(obj)``.
        The handler's return value is stored as-is for that field.
        ``None`` means no handlers are configured.
    :returns: dict mapping each field to the handler's result or to
        ``{"value": ...}``. Related model instances also get a ``"repr"`` key.
    :raises ValueError: if no field is given.
    :raises KeyError: if a field holds a model instance but has no entry in
        ``serialized_relations``.
    """
    if not fields:
        raise ValueError("You need to provide at least one field")
    # ``None`` sentinels instead of ``{}`` defaults: avoids dicts shared
    # across calls
    if serialized_relations is None:
        serialized_relations = {}
    if handlers is None:
        handlers = {}
    state = {}
    for field in fields:
        if field in handlers:
            state[field] = handlers[field](obj)
            continue
        value = getattr(obj, field)
        if isinstance(value, models.Model):
            # we store the related object id and repr for better UX
            id_field = serialized_relations[field]
            state[field] = {"value": getattr(value, id_field), "repr": str(value)}
        else:
            state[field] = {"value": value}
    return state
# Shared module-level registry instance; serializer classes can be registered
# on it with the ``connect`` decorator and resolved via ``get_conf``.
registry = Registry()