Skip to content

Commit a5554d6

Browse files
JCoxwellbckohan
andcommitted
Add create_from_super method and test
Co-authored-by: Joshua <[email protected]> Co-authored-by: Brian Kohan <[email protected]>
1 parent b934093 commit a5554d6

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

src/polymorphic/managers.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
The manager class for use in the models.
33
"""
44

5-
from django.db import models
5+
import inspect
6+
7+
from django.contrib.contenttypes.models import ContentType
8+
from django.db import DEFAULT_DB_ALIAS, models
69

710
from polymorphic.query import PolymorphicQuerySet
811

@@ -49,3 +52,44 @@ def not_instance_of(self, *args):
4952

5053
def get_real_instances(self, base_result_objects=None):
5154
return self.all().get_real_instances(base_result_objects=base_result_objects)
55+
56+
def create_from_super(self, obj, **kwargs):
57+
"""Creates an instance of self.model (cls) from existing super class.
58+
The new subclass will be the same object with same database id
59+
and data as obj, but will be an instance of cls.
60+
61+
obj must be an instance of the direct superclass of cls.
62+
kwargs should contain all required fields of the subclass (cls).
63+
64+
returns obj as an instance of cls.
65+
"""
66+
cls = self.model
67+
68+
scls = inspect.getmro(cls)[1]
69+
if scls is not type(obj):
70+
raise TypeError(
71+
"create_from_super can only be used if obj is one level of inheritance up from cls"
72+
)
73+
74+
parent_link_field = None
75+
for parent, field in cls._meta.parents.items():
76+
if parent is scls:
77+
parent_link_field = field
78+
break
79+
if parent_link_field is None:
80+
raise TypeError(f"Could not find parent link field for {scls.__name__}")
81+
kwargs[parent_link_field.get_attname()] = obj.id
82+
83+
# create the new base class with only fields that apply to it.
84+
nobj = cls(**kwargs)
85+
nobj.save_base(raw=True)
86+
# force update the content type, but first we need to
87+
# retrieve a clean copy from the db to fill in the null
88+
# fields otherwise they would be overwritten.
89+
nobj = obj.__class__.objects.using(obj._state.db or DEFAULT_DB_ALIAS).get(pk=obj.pk)
90+
nobj.polymorphic_ctype = ContentType.objects.db_manager(
91+
using=(obj._state.db or DEFAULT_DB_ALIAS)
92+
).get_for_model(cls)
93+
nobj.save()
94+
95+
return nobj.get_real_instance() # cast to cls
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from django.test import TransactionTestCase
2+
from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D
3+
4+
5+
class CreateFromSuperTests(TransactionTestCase):
6+
def test_create_from_super(self):
7+
# run create test 3 times because initial implementation
8+
# would fail after first success.
9+
for i in range(3):
10+
mc = Model2C.objects.create(
11+
field1="C1{}".format(i), field2="C2{}".format(i), field3="C3{}".format(i)
12+
)
13+
mc.save()
14+
field4 = "D4{}".format(i)
15+
md = Model2D.objects.create_from_super(mc, field4=field4)
16+
self.assertEqual(mc.id, md.id)
17+
self.assertEqual(mc.field1, md.field1)
18+
self.assertEqual(mc.field2, md.field2)
19+
self.assertEqual(mc.field3, md.field3)
20+
self.assertEqual(md.field4, field4)
21+
ma = Model2A.objects.create(field1="A1e")
22+
self.assertRaises(Exception, Model2D.objects.create_from_super, ma, field4="D4e")
23+
mb = Model2B.objects.create(field1="B1e", field2="B2e")
24+
self.assertRaises(Exception, Model2D.objects.create_from_super, mb, field4="D4e")

0 commit comments

Comments
 (0)