test_concat.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from unittest import skipUnless
  2. from django.db import connection
  3. from django.db.models import CharField, TextField
  4. from django.db.models import Value as V
  5. from django.db.models.functions import Concat, ConcatPair, Upper
  6. from django.test import TestCase
  7. from django.utils import timezone
  8. from ..models import Article, Author
  9. lorem_ipsum = """
  10. Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod
  11. tempor incididunt ut labore et dolore magna aliqua."""
  12. class ConcatTests(TestCase):
  13. def test_basic(self):
  14. Author.objects.create(name="Jayden")
  15. Author.objects.create(name="John Smith", alias="smithj", goes_by="John")
  16. Author.objects.create(name="Margaret", goes_by="Maggie")
  17. Author.objects.create(name="Rhonda", alias="adnohR")
  18. authors = Author.objects.annotate(joined=Concat("alias", "goes_by"))
  19. self.assertQuerySetEqual(
  20. authors.order_by("name"),
  21. [
  22. "",
  23. "smithjJohn",
  24. "Maggie",
  25. "adnohR",
  26. ],
  27. lambda a: a.joined,
  28. )
  29. def test_gt_two_expressions(self):
  30. with self.assertRaisesMessage(
  31. ValueError, "Concat must take at least two expressions"
  32. ):
  33. Author.objects.annotate(joined=Concat("alias"))
  34. def test_many(self):
  35. Author.objects.create(name="Jayden")
  36. Author.objects.create(name="John Smith", alias="smithj", goes_by="John")
  37. Author.objects.create(name="Margaret", goes_by="Maggie")
  38. Author.objects.create(name="Rhonda", alias="adnohR")
  39. authors = Author.objects.annotate(
  40. joined=Concat("name", V(" ("), "goes_by", V(")"), output_field=CharField()),
  41. )
  42. self.assertQuerySetEqual(
  43. authors.order_by("name"),
  44. [
  45. "Jayden ()",
  46. "John Smith (John)",
  47. "Margaret (Maggie)",
  48. "Rhonda ()",
  49. ],
  50. lambda a: a.joined,
  51. )
  52. def test_mixed_char_text(self):
  53. Article.objects.create(
  54. title="The Title", text=lorem_ipsum, written=timezone.now()
  55. )
  56. article = Article.objects.annotate(
  57. title_text=Concat("title", V(" - "), "text", output_field=TextField()),
  58. ).get(title="The Title")
  59. self.assertEqual(article.title + " - " + article.text, article.title_text)
  60. # Wrap the concat in something else to ensure that text is returned
  61. # rather than bytes.
  62. article = Article.objects.annotate(
  63. title_text=Upper(
  64. Concat("title", V(" - "), "text", output_field=TextField())
  65. ),
  66. ).get(title="The Title")
  67. expected = article.title + " - " + article.text
  68. self.assertEqual(expected.upper(), article.title_text)
  69. @skipUnless(
  70. connection.vendor in ("sqlite", "postgresql"),
  71. "SQLite and PostgreSQL specific implementation detail.",
  72. )
  73. def test_coalesce_idempotent(self):
  74. pair = ConcatPair(V("a"), V("b"))
  75. # Check nodes counts
  76. self.assertEqual(len(list(pair.flatten())), 3)
  77. self.assertEqual(
  78. len(list(pair.coalesce().flatten())), 7
  79. ) # + 2 Coalesce + 2 Value()
  80. self.assertEqual(len(list(pair.flatten())), 3)
  81. def test_sql_generation_idempotency(self):
  82. qs = Article.objects.annotate(description=Concat("title", V(": "), "summary"))
  83. # Multiple compilations should not alter the generated query.
  84. self.assertEqual(str(qs.query), str(qs.all().query))
  85. def test_concat_non_str(self):
  86. Author.objects.create(name="The Name", age=42)
  87. with self.assertNumQueries(1) as ctx:
  88. author = Author.objects.annotate(
  89. name_text=Concat(
  90. "name", V(":"), "alias", V(":"), "age", output_field=TextField()
  91. ),
  92. ).get()
  93. self.assertEqual(author.name_text, "The Name::42")
  94. # Only non-string columns are casted on PostgreSQL.
  95. self.assertEqual(
  96. ctx.captured_queries[0]["sql"].count("::text"),
  97. 1 if connection.vendor == "postgresql" else 0,
  98. )
  99. def test_equal(self):
  100. self.assertEqual(
  101. Concat("foo", "bar", output_field=TextField()),
  102. Concat("foo", "bar", output_field=TextField()),
  103. )
  104. self.assertNotEqual(
  105. Concat("foo", "bar", output_field=TextField()),
  106. Concat("foo", "bar", output_field=CharField()),
  107. )
  108. self.assertNotEqual(
  109. Concat("foo", "bar", output_field=TextField()),
  110. Concat("bar", "foo", output_field=TextField()),
  111. )