123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658 |
- from __future__ import unicode_literals
- from collections import defaultdict
- from django.contrib.contenttypes.models import ContentType
- from django.core import checks
- from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
- from django.db import DEFAULT_DB_ALIAS, models, router, transaction
- from django.db.models import DO_NOTHING, signals
- from django.db.models.base import ModelBase, make_foreign_order_accessors
- from django.db.models.fields.related import (
- ForeignObject, ForeignObjectRel, ReverseManyToOneDescriptor,
- lazy_related_operation,
- )
- from django.db.models.query_utils import PathInfo
- from django.utils.encoding import python_2_unicode_compatible, smart_text
- from django.utils.functional import cached_property
@python_2_unicode_compatible
class GenericForeignKey(object):
    """
    Provide a generic many-to-one relation through the ``content_type`` and
    ``object_id`` fields.

    This class also doubles as an accessor to the related object (similar to
    ForwardManyToOneDescriptor) by adding itself as a model attribute.
    """

    # Field flags
    auto_created = False
    concrete = False
    editable = False
    hidden = False

    is_relation = True
    many_to_many = False
    many_to_one = True
    one_to_many = False
    one_to_one = False
    related_model = None
    remote_field = None

    def __init__(self, ct_field='content_type', fk_field='object_id', for_concrete_model=True):
        """
        Remember the names of the two model fields backing this relation.

        :param ct_field: name of the ForeignKey-to-ContentType field.
        :param fk_field: name of the field holding the related object's
            primary key value.
        :param for_concrete_model: forwarded as ``for_concrete_model`` to
            ``ContentType.objects.get_for_model()`` when the content type of an
            assigned object is resolved.
        """
        self.ct_field = ct_field
        self.fk_field = fk_field
        self.for_concrete_model = for_concrete_model
        self.editable = False
        self.rel = None
        self.column = None

    def contribute_to_class(self, cls, name, **kwargs):
        """
        Install this GFK on model ``cls`` under attribute ``name``.

        Registers it as a virtual field on ``cls._meta``, connects the
        ``pre_init`` handler (non-abstract models only) and sets ``self`` as
        the class attribute so that ``__get__``/``__set__`` act as a descriptor.
        """
        self.name = name
        self.model = cls
        # Per-instance attribute where __get__/__set__ cache the related object.
        self.cache_attr = "_%s_cache" % name
        cls._meta.add_field(self, virtual=True)

        # Only run pre-initialization field assignment on non-abstract models
        if not cls._meta.abstract:
            signals.pre_init.connect(self.instance_pre_init, sender=cls)

        setattr(cls, name, self)

    def get_filter_kwargs_for_object(self, obj):
        """See corresponding method on Field"""
        return {
            self.fk_field: getattr(obj, self.fk_field),
            self.ct_field: getattr(obj, self.ct_field),
        }

    def get_forward_related_filter(self, obj):
        """See corresponding method on RelatedField"""
        return {
            self.fk_field: obj.pk,
            self.ct_field: ContentType.objects.get_for_model(obj).pk,
        }

    def __str__(self):
        # Rendered as "app_label.ModelName.field_name".
        model = self.model
        app = model._meta.app_label
        return '%s.%s.%s' % (app, model._meta.object_name, self.name)

    def check(self, **kwargs):
        """Run this field's system checks and return the list of errors found."""
        errors = []
        errors.extend(self._check_field_name())
        errors.extend(self._check_object_id_field())
        errors.extend(self._check_content_type_field())
        return errors

    def _check_field_name(self):
        """Report fields.E001 if the attribute name ends with an underscore."""
        if self.name.endswith("_"):
            return [
                checks.Error(
                    'Field names must not end with an underscore.',
                    obj=self,
                    id='fields.E001',
                )
            ]
        else:
            return []

    def _check_object_id_field(self):
        """Report contenttypes.E001 if ``self.fk_field`` is not a field of the model."""
        try:
            self.model._meta.get_field(self.fk_field)
        except FieldDoesNotExist:
            return [
                checks.Error(
                    "The GenericForeignKey object ID references the non-existent field '%s'." % self.fk_field,
                    obj=self,
                    id='contenttypes.E001',
                )
            ]
        else:
            return []

    def _check_content_type_field(self):
        """
        Check that the field named ``self.ct_field`` exists on ``self.model``
        and is a ForeignKey to ContentType.

        Reports contenttypes.E002 when the field is missing, E003 when it is
        not a ForeignKey, and E004 when it is a ForeignKey to another model.
        """
        try:
            field = self.model._meta.get_field(self.ct_field)
        except FieldDoesNotExist:
            return [
                checks.Error(
                    "The GenericForeignKey content type references the non-existent field '%s.%s'." % (
                        self.model._meta.object_name, self.ct_field
                    ),
                    obj=self,
                    id='contenttypes.E002',
                )
            ]
        else:
            if not isinstance(field, models.ForeignKey):
                return [
                    checks.Error(
                        "'%s.%s' is not a ForeignKey." % (
                            self.model._meta.object_name, self.ct_field
                        ),
                        hint=(
                            "GenericForeignKeys must use a ForeignKey to "
                            "'contenttypes.ContentType' as the 'content_type' field."
                        ),
                        obj=self,
                        id='contenttypes.E003',
                    )
                ]
            elif field.remote_field.model != ContentType:
                return [
                    checks.Error(
                        "'%s.%s' is not a ForeignKey to 'contenttypes.ContentType'." % (
                            self.model._meta.object_name, self.ct_field
                        ),
                        hint=(
                            "GenericForeignKeys must use a ForeignKey to "
                            "'contenttypes.ContentType' as the 'content_type' field."
                        ),
                        obj=self,
                        id='contenttypes.E004',
                    )
                ]
            else:
                return []

    def instance_pre_init(self, signal, sender, args, kwargs, **_kwargs):
        """
        Handle initializing an object with the generic FK instead of
        content_type and object_id fields.

        Rewrites ``kwargs`` in place: the ``self.name`` entry is popped and
        replaced by ``ct_field``/``fk_field`` entries (both None when the
        given value is None).
        """
        if self.name in kwargs:
            value = kwargs.pop(self.name)
            if value is not None:
                kwargs[self.ct_field] = self.get_content_type(obj=value)
                kwargs[self.fk_field] = value._get_pk_val()
            else:
                kwargs[self.ct_field] = None
                kwargs[self.fk_field] = None

    def get_content_type(self, obj=None, id=None, using=None):
        """
        Return the ContentType for ``obj`` (looked up on ``obj``'s database)
        or, if ``obj`` is None, the ContentType with primary key ``id`` (looked
        up on database ``using``).

        Raises ``Exception`` when neither ``obj`` nor ``id`` is given.
        """
        if obj is not None:
            return ContentType.objects.db_manager(obj._state.db).get_for_model(
                obj, for_concrete_model=self.for_concrete_model)
        elif id is not None:
            return ContentType.objects.db_manager(using).get_for_id(id)
        else:
            # Every caller passes either ``obj`` or ``id``; reaching this
            # branch means the method was invoked with neither.
            raise Exception("Impossible arguments to GFK.get_content_type!")

    def get_prefetch_queryset(self, instances, queryset=None):
        """
        Fetch the related objects of ``instances`` for prefetch_related().

        Returns ``(objects, rel_obj_key, instance_key, True, cache_attr)``
        where both key functions produce ``(pk_value, model_class)`` pairs so
        objects can be matched to instances in Python. A custom ``queryset``
        is not supported and raises ValueError.
        """
        if queryset is not None:
            raise ValueError("Custom queryset can't be used for this lookup.")

        # For efficiency, group the instances by content type and then do one
        # query per model
        fk_dict = defaultdict(set)
        # We need one instance for each group in order to get the right db:
        instance_dict = {}
        ct_attname = self.model._meta.get_field(self.ct_field).get_attname()
        for instance in instances:
            # We avoid looking for values if either ct_id or fkey value is None
            ct_id = getattr(instance, ct_attname)
            if ct_id is not None:
                fk_val = getattr(instance, self.fk_field)
                if fk_val is not None:
                    fk_dict[ct_id].add(fk_val)
                    instance_dict[ct_id] = instance

        ret_val = []
        for ct_id, fkeys in fk_dict.items():
            instance = instance_dict[ct_id]
            ct = self.get_content_type(id=ct_id, using=instance._state.db)
            ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))

        # For doing the join in Python, we have to match both the FK val and the
        # content type, so we use a callable that returns a (fk, class) pair.
        def gfk_key(obj):
            ct_id = getattr(obj, ct_attname)
            if ct_id is None:
                return None
            else:
                model = self.get_content_type(id=ct_id,
                                              using=obj._state.db).model_class()
                # Convert the stored id to the target pk's prep type so it
                # compares equal to the fetched object's pk.
                return (model._meta.pk.get_prep_value(getattr(obj, self.fk_field)),
                        model)

        return (ret_val,
                lambda obj: (obj._get_pk_val(), obj.__class__),
                gfk_key,
                True,
                self.cache_attr)

    def is_cached(self, instance):
        """Return True if the related object is already cached on ``instance``."""
        return hasattr(instance, self.cache_attr)

    def __get__(self, instance, cls=None):
        """
        Return the related object, querying and caching it on first access.

        Class access returns the descriptor itself. The result is None when
        the content type is unset or the referenced object does not exist;
        that None is cached too.
        """
        if instance is None:
            return self

        try:
            return getattr(instance, self.cache_attr)
        except AttributeError:
            rel_obj = None

            # Make sure to use ContentType.objects.get_for_id() to ensure that
            # lookups are cached (see ticket #5570). This takes more code than
            # the naive ``getattr(instance, self.ct_field)``, but has better
            # performance when dealing with GFKs in loops and such.
            f = self.model._meta.get_field(self.ct_field)
            ct_id = getattr(instance, f.get_attname(), None)
            if ct_id is not None:
                ct = self.get_content_type(id=ct_id, using=instance._state.db)
                try:
                    rel_obj = ct.get_object_for_this_type(pk=getattr(instance, self.fk_field))
                except ObjectDoesNotExist:
                    pass
            setattr(instance, self.cache_attr, rel_obj)
            return rel_obj

    def __set__(self, instance, value):
        """
        Point ``instance`` at ``value``: set the content type and object id
        fields from it (both None when ``value`` is None) and cache ``value``.
        """
        ct = None
        fk = None
        if value is not None:
            ct = self.get_content_type(obj=value)
            fk = value._get_pk_val()

        setattr(instance, self.ct_field, ct)
        setattr(instance, self.fk_field, fk)
        setattr(instance, self.cache_attr, value)
class GenericRel(ForeignObjectRel):
    """
    Relation object that a GenericRelation uses to record details of the
    generic relation it represents.
    """

    def __init__(self, field, to, related_name=None, related_query_name=None, limit_choices_to=None):
        # The ``related_name`` argument is accepted but not forwarded: the
        # parent receives ``related_query_name`` as its related_name, or '+'
        # (Django's "no reverse accessor" marker) when no query name is set.
        effective_related_name = related_query_name or '+'
        super(GenericRel, self).__init__(
            field,
            to,
            related_name=effective_related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
            # Always DO_NOTHING, independent of any caller input.
            on_delete=DO_NOTHING,
        )
class GenericRelation(ForeignObject):
    """
    Provide a reverse to a relation created by a GenericForeignKey.
    """

    # Field flags
    auto_created = False

    many_to_many = False
    many_to_one = False
    one_to_many = True
    one_to_one = False

    rel_class = GenericRel

    def __init__(self, to, object_id_field='object_id', content_type_field='content_type',
                 for_concrete_model=True, related_query_name=None, limit_choices_to=None, **kwargs):
        """
        :param to: model holding the GenericForeignKey.
        :param object_id_field: name of the object id field on ``to``.
        :param content_type_field: name of the content type field on ``to``.
        :param for_concrete_model: forwarded to ContentType lookups for this
            field's model.

        ``blank``, ``on_delete``, ``editable`` and ``serialize`` are always
        overridden, whatever the caller passes in ``kwargs``.
        """
        kwargs['rel'] = self.rel_class(
            self, to,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
        )

        kwargs['blank'] = True
        kwargs['on_delete'] = models.CASCADE
        kwargs['editable'] = False
        kwargs['serialize'] = False

        # This construct is somewhat of an abuse of ForeignObject. This field
        # represents a relation from pk to object_id field. But, this relation
        # isn't direct, the join is generated reverse along foreign key. So,
        # the from_field is object_id field, to_field is pk because of the
        # reverse join.
        super(GenericRelation, self).__init__(
            to, from_fields=[object_id_field], to_fields=[], **kwargs)

        self.object_id_field_name = object_id_field
        self.content_type_field_name = content_type_field
        self.for_concrete_model = for_concrete_model

    def check(self, **kwargs):
        """Run the inherited checks plus the GenericForeignKey existence check."""
        errors = super(GenericRelation, self).check(**kwargs)
        errors.extend(self._check_generic_foreign_key_existence())
        return errors

    def _is_matching_generic_foreign_key(self, field):
        """
        Return True if ``field`` is a GenericForeignKey whose content type and
        object id field names match this relation's.
        """
        return (
            isinstance(field, GenericForeignKey) and
            field.ct_field == self.content_type_field_name and
            field.fk_field == self.object_id_field_name
        )

    def _check_generic_foreign_key_existence(self):
        """
        Report an error if the (already resolved) target model has no matching
        GenericForeignKey. Unresolved (string) targets are skipped.
        """
        target = self.remote_field.model
        if isinstance(target, ModelBase):
            fields = target._meta.virtual_fields
            if any(self._is_matching_generic_foreign_key(field) for field in fields):
                return []
            else:
                return [
                    checks.Error(
                        "The GenericRelation defines a relation with the model "
                        "'%s.%s', but that model does not have a GenericForeignKey." % (
                            target._meta.app_label, target._meta.object_name
                        ),
                        obj=self,
                        # NOTE(review): GenericForeignKey's "not a ForeignKey to
                        # ContentType" check uses this same id; kept unchanged
                        # because callers may filter checks by id.
                        id='contenttypes.E004',
                    )
                ]
        else:
            return []

    def resolve_related_fields(self):
        """
        Point ``to_fields`` at this model's pk and return the single
        (remote object id field, local pk) pair.
        """
        self.to_fields = [self.model._meta.pk.name]
        return [(self.remote_field.model._meta.get_field(self.object_id_field_name), self.model._meta.pk)]

    def _get_path_info_with_parent(self):
        """
        Return the path that joins the current model through any parent models.
        The idea is that if you have a GFK defined on a parent model then we
        need to join the parent model first, then the child model.
        """
        # With an inheritance chain ChildTag -> Tag and Tag defines the
        # GenericForeignKey, and a TaggedItem model has a GenericRelation to
        # ChildTag, then we need to generate a join from TaggedItem to Tag
        # (as Tag.object_id == TaggedItem.pk), and another join from Tag to
        # ChildTag (as that is where the relation is to). Do this by first
        # generating a join to the parent model, then generating joins to the
        # child models.
        path = []
        opts = self.remote_field.model._meta
        parent_opts = opts.get_field(self.object_id_field_name).model._meta
        target = parent_opts.pk
        path.append(PathInfo(self.model._meta, parent_opts, (target,), self.remote_field, True, False))
        # Collect joins needed for the parent -> child chain. This is easiest
        # to do if we collect joins for the child -> parent chain and then
        # reverse the direction (call to reverse() and use of
        # field.remote_field.get_path_info()).
        parent_field_chain = []
        while parent_opts != opts:
            field = opts.get_ancestor_link(parent_opts.model)
            parent_field_chain.append(field)
            opts = field.remote_field.model._meta
        parent_field_chain.reverse()
        for field in parent_field_chain:
            path.extend(field.remote_field.get_path_info())
        return path

    def get_path_info(self):
        """
        Return the join path from this model to the remote model, going
        through parent models when the object id field is inherited.
        """
        opts = self.remote_field.model._meta
        object_id_field = opts.get_field(self.object_id_field_name)
        if object_id_field.model != opts.model:
            return self._get_path_info_with_parent()
        else:
            target = opts.pk
            return [PathInfo(self.model._meta, opts, (target,), self.remote_field, True, False)]

    def get_reverse_path_info(self):
        """Return the single-step join path from the remote model back to this one."""
        opts = self.model._meta
        from_opts = self.remote_field.model._meta
        return [PathInfo(from_opts, opts, (opts.pk,), self, not self.unique, False)]

    def get_choices_default(self):
        return super(GenericRelation, self).get_choices(include_blank=False)

    def value_to_string(self, obj):
        """Return the text form of the list of related objects' primary keys."""
        qs = getattr(obj, self.name).all()
        return smart_text([instance._get_pk_val() for instance in qs])

    def contribute_to_class(self, cls, name, **kwargs):
        """
        Register as a virtual-only field on ``cls`` and install the reverse
        manager descriptor under ``self.name``.
        """
        kwargs['virtual_only'] = True
        super(GenericRelation, self).contribute_to_class(cls, name, **kwargs)
        self.model = cls
        setattr(cls, self.name, ReverseGenericManyToOneDescriptor(self.remote_field))

        # Add get_RELATED_order() and set_RELATED_order() methods if the model
        # on the other end of this relation is ordered with respect to this.
        def make_generic_foreign_order_accessors(related_model, model):
            if self._is_matching_generic_foreign_key(model._meta.order_with_respect_to):
                make_foreign_order_accessors(model, related_model)

        lazy_related_operation(make_generic_foreign_order_accessors, self.model, self.remote_field.model)

    def set_attributes_from_rel(self):
        # No-op override of the parent hook.
        pass

    def get_internal_type(self):
        # NOTE(review): reports itself as a ManyToManyField; presumably so code
        # dispatching on the internal type treats it as multi-valued — confirm.
        return "ManyToManyField"

    def get_content_type(self):
        """
        Return the content type associated with this field's model.
        """
        return ContentType.objects.get_for_model(self.model,
                                                 for_concrete_model=self.for_concrete_model)

    def get_extra_restriction(self, where_class, alias, remote_alias):
        """
        Return a condition restricting the remote content type column (at
        ``remote_alias``) to this model's ContentType pk.
        """
        field = self.remote_field.model._meta.get_field(self.content_type_field_name)
        contenttype_pk = self.get_content_type().pk
        cond = where_class()
        lookup = field.get_lookup('exact')(field.get_col(remote_alias), contenttype_pk)
        cond.add(lookup, 'AND')
        return cond

    def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
        """
        Return all objects related to ``objs`` via this ``GenericRelation``.
        """
        return self.remote_field.model._base_manager.db_manager(using).filter(**{
            "%s__pk" % self.content_type_field_name: ContentType.objects.db_manager(using).get_for_model(
                self.model, for_concrete_model=self.for_concrete_model).pk,
            "%s__in" % self.object_id_field_name: [obj.pk for obj in objs]
        })
class ReverseGenericManyToOneDescriptor(ReverseManyToOneDescriptor):
    """
    Accessor for the related-objects manager on the one-to-many side of a
    relation created by GenericRelation.

    Given::

        class Post(Model):
            comments = GenericRelation(Comment)

    ``Post.comments`` is a ReverseGenericManyToOneDescriptor instance; the
    manager class used for ``post.comments`` comes from
    ``related_manager_cls``.
    """

    @cached_property
    def related_manager_cls(self):
        # Built once per descriptor: a generic-relation manager class that
        # subclasses the related model's default manager class.
        relation = self.rel
        base_manager_class = relation.model._default_manager.__class__
        return create_generic_related_manager(base_manager_class, relation)
def create_generic_related_manager(superclass, rel):
    """
    Factory function to create a manager that subclasses another manager
    (generally the default manager of a given model) and adds behaviors
    specific to generic relations.

    :param superclass: manager class to subclass.
    :param rel: the GenericRel describing the relation; its ``model`` and
        ``field`` supply the related model and the field names used below.
    :returns: the new ``GenericRelatedObjectManager`` class.
    """

    class GenericRelatedObjectManager(superclass):
        def __init__(self, instance=None):
            """
            Bind the manager to ``instance``, the object on the
            GenericRelation side.

            NOTE(review): ``instance`` defaults to None but is dereferenced
            immediately, so a real model instance is required in practice.
            """
            super(GenericRelatedObjectManager, self).__init__()

            self.instance = instance

            self.model = rel.model

            content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(
                instance, for_concrete_model=rel.field.for_concrete_model)
            self.content_type = content_type
            self.content_type_field_name = rel.field.content_type_field_name
            self.object_id_field_name = rel.field.object_id_field_name
            self.prefetch_cache_name = rel.field.attname
            self.pk_val = instance._get_pk_val()

            # Filters selecting the related objects that point at ``instance``.
            self.core_filters = {
                '%s__pk' % self.content_type_field_name: content_type.id,
                self.object_id_field_name: self.pk_val,
            }

        def __call__(self, **kwargs):
            """
            Return an equivalent manager bound to the same instance but built
            on the model manager named by the ``manager`` keyword.
            """
            # We use **kwargs rather than a kwarg argument to enforce the
            # `manager='manager_name'` syntax.
            manager = getattr(self.model, kwargs.pop('manager'))
            manager_class = create_generic_related_manager(manager.__class__, rel)
            return manager_class(instance=self.instance)
        do_not_call_in_templates = True

        def __str__(self):
            return repr(self)

        def get_queryset(self):
            """
            Return prefetched results when available, otherwise a queryset of
            the related objects filtered by ``core_filters``.
            """
            try:
                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
            except (AttributeError, KeyError):
                db = self._db or router.db_for_read(self.model, instance=self.instance)
                return super(GenericRelatedObjectManager, self).get_queryset().using(db).filter(**self.core_filters)

        def get_prefetch_queryset(self, instances, queryset=None):
            """
            Return ``(queryset, rel_obj_key, instance_key, False, cache_name)``
            for prefetch_related(); ``instances`` is indexed at [0], so it must
            be non-empty.
            """
            if queryset is None:
                queryset = super(GenericRelatedObjectManager, self).get_queryset()

            queryset._add_hints(instance=instances[0])
            queryset = queryset.using(queryset._db or self._db)

            query = {
                '%s__pk' % self.content_type_field_name: self.content_type.id,
                '%s__in' % self.object_id_field_name: set(obj._get_pk_val() for obj in instances)
            }

            # We (possibly) need to convert object IDs to the type of the
            # instances' PK in order to match up instances:
            object_id_converter = instances[0]._meta.pk.to_python
            return (queryset.filter(**query),
                    lambda relobj: object_id_converter(getattr(relobj, self.object_id_field_name)),
                    lambda obj: obj._get_pk_val(),
                    False,
                    self.prefetch_cache_name)

        def add(self, *objs, **kwargs):
            """
            Point each of ``objs`` at the bound instance.

            With ``bulk=True`` (default) a single UPDATE is issued, so every
            object must already be saved to the write database; otherwise
            ValueError is raised. With ``bulk=False`` each object is saved
            individually inside one transaction. Objects of the wrong model
            raise TypeError.
            """
            bulk = kwargs.pop('bulk', True)
            db = router.db_for_write(self.model, instance=self.instance)

            def check_and_update_obj(obj):
                if not isinstance(obj, self.model):
                    raise TypeError("'%s' instance expected, got %r" % (
                        self.model._meta.object_name, obj
                    ))
                setattr(obj, self.content_type_field_name, self.content_type)
                setattr(obj, self.object_id_field_name, self.pk_val)

            if bulk:
                pks = []
                for obj in objs:
                    if obj._state.adding or obj._state.db != db:
                        raise ValueError(
                            "%r instance isn't saved. Use bulk=False or save "
                            "the object first." % obj
                        )
                    check_and_update_obj(obj)
                    pks.append(obj.pk)

                self.model._base_manager.using(db).filter(pk__in=pks).update(**{
                    self.content_type_field_name: self.content_type,
                    self.object_id_field_name: self.pk_val,
                })
            else:
                with transaction.atomic(using=db, savepoint=False):
                    for obj in objs:
                        check_and_update_obj(obj)
                        obj.save()
        add.alters_data = True

        def remove(self, *objs, **kwargs):
            """Delete the given related objects (no-op when none are passed)."""
            if not objs:
                return
            bulk = kwargs.pop('bulk', True)
            self._clear(self.filter(pk__in=[o.pk for o in objs]), bulk)
        remove.alters_data = True

        def clear(self, **kwargs):
            """Delete every related object of the bound instance."""
            bulk = kwargs.pop('bulk', True)
            self._clear(self, bulk)
        clear.alters_data = True

        def _clear(self, queryset, bulk):
            """Delete ``queryset`` in one call (bulk) or object by object."""
            db = router.db_for_write(self.model, instance=self.instance)
            queryset = queryset.using(db)
            if bulk:
                # `QuerySet.delete()` creates its own atomic block which
                # contains the `pre_delete` and `post_delete` signal handlers.
                queryset.delete()
            else:
                with transaction.atomic(using=db, savepoint=False):
                    for obj in queryset:
                        obj.delete()
        _clear.alters_data = True

        def set(self, objs, **kwargs):
            """
            Make ``objs`` the complete set of related objects.

            With ``clear=True`` everything is removed first, then ``objs`` are
            added; otherwise only the difference is removed/added.
            """
            # Force evaluation of `objs` in case it's a queryset whose value
            # could be affected by `manager.clear()`. Refs #19816.
            objs = tuple(objs)

            bulk = kwargs.pop('bulk', True)
            clear = kwargs.pop('clear', False)

            db = router.db_for_write(self.model, instance=self.instance)
            with transaction.atomic(using=db, savepoint=False):
                if clear:
                    self.clear()
                    self.add(*objs, bulk=bulk)
                else:
                    old_objs = set(self.using(db).all())
                    new_objs = []
                    for obj in objs:
                        if obj in old_objs:
                            old_objs.remove(obj)
                        else:
                            new_objs.append(obj)

                    self.remove(*old_objs)
                    self.add(*new_objs, bulk=bulk)
        set.alters_data = True

        def create(self, **kwargs):
            """Create a related object already pointing at the bound instance."""
            kwargs[self.content_type_field_name] = self.content_type
            kwargs[self.object_id_field_name] = self.pk_val
            db = router.db_for_write(self.model, instance=self.instance)
            return super(GenericRelatedObjectManager, self).using(db).create(**kwargs)
        create.alters_data = True

        def get_or_create(self, **kwargs):
            """get_or_create() scoped to the bound instance's related objects."""
            kwargs[self.content_type_field_name] = self.content_type
            kwargs[self.object_id_field_name] = self.pk_val
            db = router.db_for_write(self.model, instance=self.instance)
            return super(GenericRelatedObjectManager, self).using(db).get_or_create(**kwargs)
        get_or_create.alters_data = True

        def update_or_create(self, **kwargs):
            """update_or_create() scoped to the bound instance's related objects."""
            kwargs[self.content_type_field_name] = self.content_type
            kwargs[self.object_id_field_name] = self.pk_val
            db = router.db_for_write(self.model, instance=self.instance)
            return super(GenericRelatedObjectManager, self).using(db).update_or_create(**kwargs)
        update_or_create.alters_data = True

    return GenericRelatedObjectManager
|