test_geoip2.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. import ipaddress
  2. import itertools
  3. import pathlib
  4. from unittest import mock, skipUnless
  5. from django.conf import settings
  6. from django.contrib.gis.geoip2 import HAS_GEOIP2
  7. from django.contrib.gis.geos import GEOSGeometry
  8. from django.test import SimpleTestCase, override_settings
  9. from django.utils.deprecation import RemovedInDjango60Warning
  10. if HAS_GEOIP2:
  11. import geoip2
  12. from django.contrib.gis.geoip2 import GeoIP2, GeoIP2Exception
  13. def build_geoip_path(*parts):
  14. return pathlib.Path(__file__).parent.joinpath("data/geoip2", *parts).resolve()
  15. @skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
  16. @override_settings(
  17. GEOIP_CITY="GeoLite2-City-Test.mmdb",
  18. GEOIP_COUNTRY="GeoLite2-Country-Test.mmdb",
  19. )
  20. class GeoLite2Test(SimpleTestCase):
  21. fqdn = "sky.uk"
  22. ipv4_str = "2.125.160.216"
  23. ipv6_str = "::ffff:027d:a0d8"
  24. ipv4_addr = ipaddress.ip_address(ipv4_str)
  25. ipv6_addr = ipaddress.ip_address(ipv6_str)
  26. query_values = (fqdn, ipv4_str, ipv6_str, ipv4_addr, ipv6_addr)
  27. expected_city = {
  28. "accuracy_radius": 100,
  29. "city": "Boxford",
  30. "continent_code": "EU",
  31. "continent_name": "Europe",
  32. "country_code": "GB",
  33. "country_name": "United Kingdom",
  34. "is_in_european_union": False,
  35. "latitude": 51.75,
  36. "longitude": -1.25,
  37. "metro_code": None,
  38. "postal_code": "OX1",
  39. "region_code": "ENG",
  40. "region_name": "England",
  41. "time_zone": "Europe/London",
  42. # Kept for backward compatibility.
  43. "dma_code": None,
  44. "region": "ENG",
  45. }
  46. expected_country = {
  47. "continent_code": "EU",
  48. "continent_name": "Europe",
  49. "country_code": "GB",
  50. "country_name": "United Kingdom",
  51. "is_in_european_union": False,
  52. }
  53. @classmethod
  54. def setUpClass(cls):
  55. # Avoid referencing __file__ at module level.
  56. cls.enterClassContext(override_settings(GEOIP_PATH=build_geoip_path()))
  57. # Always mock host lookup to avoid test breakage if DNS changes.
  58. cls.enterClassContext(
  59. mock.patch("socket.gethostbyname", return_value=cls.ipv4_str)
  60. )
  61. super().setUpClass()
  62. def test_init(self):
  63. # Everything inferred from GeoIP path.
  64. g1 = GeoIP2()
  65. # Path passed explicitly.
  66. g2 = GeoIP2(settings.GEOIP_PATH, GeoIP2.MODE_AUTO)
  67. # Path provided as a string.
  68. g3 = GeoIP2(str(settings.GEOIP_PATH))
  69. # Only passing in the location of one database.
  70. g4 = GeoIP2(settings.GEOIP_PATH / settings.GEOIP_CITY, country="")
  71. g5 = GeoIP2(settings.GEOIP_PATH / settings.GEOIP_COUNTRY, city="")
  72. for g in (g1, g2, g3, g4, g5):
  73. self.assertTrue(g._reader)
  74. # Improper parameters.
  75. bad_params = (23, "foo", 15.23)
  76. for bad in bad_params:
  77. with self.assertRaises(GeoIP2Exception):
  78. GeoIP2(cache=bad)
  79. if isinstance(bad, str):
  80. e = GeoIP2Exception
  81. else:
  82. e = TypeError
  83. with self.assertRaises(e):
  84. GeoIP2(bad, GeoIP2.MODE_AUTO)
  85. def test_no_database_file(self):
  86. invalid_path = pathlib.Path(__file__).parent.joinpath("data/invalid").resolve()
  87. msg = "Path must be a valid database or directory containing databases."
  88. with self.assertRaisesMessage(GeoIP2Exception, msg):
  89. GeoIP2(invalid_path)
  90. def test_bad_query(self):
  91. g = GeoIP2(city="<invalid>")
  92. functions = (g.city, g.geos, g.lat_lon, g.lon_lat)
  93. msg = "Invalid GeoIP city data file: "
  94. for function in functions:
  95. with self.subTest(function=function.__qualname__):
  96. with self.assertRaisesMessage(GeoIP2Exception, msg):
  97. function("example.com")
  98. functions += (g.country, g.country_code, g.country_name)
  99. values = (123, 123.45, b"", (), [], {}, set(), frozenset(), GeoIP2)
  100. msg = (
  101. "GeoIP query must be a string or instance of IPv4Address or IPv6Address, "
  102. "not type"
  103. )
  104. for function, value in itertools.product(functions, values):
  105. with self.subTest(function=function.__qualname__, type=type(value)):
  106. with self.assertRaisesMessage(TypeError, msg):
  107. function(value)
  108. def test_country(self):
  109. g = GeoIP2(city="<invalid>")
  110. self.assertIs(g.is_city, False)
  111. self.assertIs(g.is_country, True)
  112. for query in self.query_values:
  113. with self.subTest(query=query):
  114. self.assertEqual(g.country(query), self.expected_country)
  115. self.assertEqual(
  116. g.country_code(query), self.expected_country["country_code"]
  117. )
  118. self.assertEqual(
  119. g.country_name(query), self.expected_country["country_name"]
  120. )
  121. def test_country_using_city_database(self):
  122. g = GeoIP2(country="<invalid>")
  123. self.assertIs(g.is_city, True)
  124. self.assertIs(g.is_country, False)
  125. for query in self.query_values:
  126. with self.subTest(query=query):
  127. self.assertEqual(g.country(query), self.expected_country)
  128. self.assertEqual(
  129. g.country_code(query), self.expected_country["country_code"]
  130. )
  131. self.assertEqual(
  132. g.country_name(query), self.expected_country["country_name"]
  133. )
  134. def test_city(self):
  135. g = GeoIP2(country="<invalid>")
  136. self.assertIs(g.is_city, True)
  137. self.assertIs(g.is_country, False)
  138. for query in self.query_values:
  139. with self.subTest(query=query):
  140. self.assertEqual(g.city(query), self.expected_city)
  141. geom = g.geos(query)
  142. self.assertIsInstance(geom, GEOSGeometry)
  143. self.assertEqual(geom.srid, 4326)
  144. expected_lat = self.expected_city["latitude"]
  145. expected_lon = self.expected_city["longitude"]
  146. self.assertEqual(geom.tuple, (expected_lon, expected_lat))
  147. self.assertEqual(g.lat_lon(query), (expected_lat, expected_lon))
  148. self.assertEqual(g.lon_lat(query), (expected_lon, expected_lat))
  149. # Country queries should still work.
  150. self.assertEqual(g.country(query), self.expected_country)
  151. self.assertEqual(
  152. g.country_code(query), self.expected_country["country_code"]
  153. )
  154. self.assertEqual(
  155. g.country_name(query), self.expected_country["country_name"]
  156. )
  157. def test_not_found(self):
  158. g1 = GeoIP2(city="<invalid>")
  159. g2 = GeoIP2(country="<invalid>")
  160. for function, query in itertools.product(
  161. (g1.country, g2.city), ("127.0.0.1", "::1")
  162. ):
  163. with self.subTest(function=function.__qualname__, query=query):
  164. msg = f"The address {query} is not in the database."
  165. with self.assertRaisesMessage(geoip2.errors.AddressNotFoundError, msg):
  166. function(query)
  167. def test_del(self):
  168. g = GeoIP2()
  169. reader = g._reader
  170. self.assertIs(reader._db_reader.closed, False)
  171. del g
  172. self.assertIs(reader._db_reader.closed, True)
  173. def test_repr(self):
  174. g = GeoIP2()
  175. m = g._metadata
  176. version = f"{m.binary_format_major_version}.{m.binary_format_minor_version}"
  177. self.assertEqual(repr(g), f"<GeoIP2 [v{version}] _path='{g._path}'>")
  178. def test_open_deprecation_warning(self):
  179. msg = "GeoIP2.open() is deprecated. Use GeoIP2() instead."
  180. with self.assertWarnsMessage(RemovedInDjango60Warning, msg) as ctx:
  181. g = GeoIP2.open(settings.GEOIP_PATH, GeoIP2.MODE_AUTO)
  182. self.assertTrue(g._reader)
  183. self.assertEqual(ctx.filename, __file__)
  184. @skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
  185. @override_settings(
  186. GEOIP_CITY="GeoIP2-City-Test.mmdb",
  187. GEOIP_COUNTRY="GeoIP2-Country-Test.mmdb",
  188. )
  189. class GeoIP2Test(GeoLite2Test):
  190. """Non-free GeoIP2 databases are supported."""
  191. @skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
  192. @override_settings(
  193. GEOIP_CITY="dbip-city-lite-test.mmdb",
  194. GEOIP_COUNTRY="dbip-country-lite-test.mmdb",
  195. )
  196. class DBIPLiteTest(GeoLite2Test):
  197. """DB-IP Lite databases are supported."""
  198. expected_city = GeoLite2Test.expected_city | {
  199. "accuracy_radius": None,
  200. "city": "London (Shadwell)",
  201. "latitude": 51.5181,
  202. "longitude": -0.0714189,
  203. "postal_code": None,
  204. "region_code": None,
  205. "time_zone": None,
  206. # Kept for backward compatibility.
  207. "region": None,
  208. }
  209. @skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
  210. class ErrorTest(SimpleTestCase):
  211. def test_missing_path(self):
  212. msg = "GeoIP path must be provided via parameter or the GEOIP_PATH setting."
  213. with self.settings(GEOIP_PATH=None):
  214. with self.assertRaisesMessage(GeoIP2Exception, msg):
  215. GeoIP2()
  216. def test_unsupported_database(self):
  217. msg = "Unable to handle database edition: GeoLite2-ASN"
  218. with self.settings(GEOIP_PATH=build_geoip_path("GeoLite2-ASN-Test.mmdb")):
  219. with self.assertRaisesMessage(GeoIP2Exception, msg):
  220. GeoIP2()