test_geoip2.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. import ipaddress
  2. import itertools
  3. import pathlib
  4. from unittest import mock, skipUnless
  5. from django.conf import settings
  6. from django.contrib.gis.geoip2 import HAS_GEOIP2
  7. from django.contrib.gis.geos import GEOSGeometry
  8. from django.test import SimpleTestCase, override_settings
  9. if HAS_GEOIP2:
  10. import geoip2
  11. from django.contrib.gis.geoip2 import GeoIP2, GeoIP2Exception
  def build_geoip_path(*parts):
      """Return the resolved path of *parts* under this module's data/geoip2 dir.

      Called with no arguments, it returns the data/geoip2 directory itself.
      """
      geoip_dir = pathlib.Path(__file__).parent / "data" / "geoip2"
      return geoip_dir.joinpath(*parts).resolve()
  @skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
  @override_settings(
      GEOIP_CITY="GeoLite2-City-Test.mmdb",
      GEOIP_COUNTRY="GeoLite2-Country-Test.mmdb",
  )
  class GeoLite2Test(SimpleTestCase):
      """Lookups against the GeoLite2 test databases.

      Subclasses in this module rerun the same tests against other database
      files by overriding GEOIP_CITY/GEOIP_COUNTRY and, where the data
      differs, ``expected_city``.
      """

      fqdn = "sky.uk"
      ipv4_str = "2.125.160.216"
      # IPv4-mapped IPv6 spelling of ipv4_str (0x02.0x7d.0xa0.0xd8).
      ipv6_str = "::ffff:027d:a0d8"
      ipv4_addr = ipaddress.ip_address(ipv4_str)
      ipv6_addr = ipaddress.ip_address(ipv6_str)
      # One of each query form: a hostname (resolved through the mocked
      # socket.gethostbyname), address strings, and ipaddress objects.
      query_values = (fqdn, ipv4_str, ipv6_str, ipv4_addr, ipv6_addr)
      expected_city = {
          "accuracy_radius": 100,
          "city": "Boxford",
          "continent_code": "EU",
          "continent_name": "Europe",
          "country_code": "GB",
          "country_name": "United Kingdom",
          "is_in_european_union": False,
          "latitude": 51.75,
          "longitude": -1.25,
          "metro_code": None,
          "postal_code": "OX1",
          "region_code": "ENG",
          "region_name": "England",
          "time_zone": "Europe/London",
          # Kept for backward compatibility.
          "dma_code": None,
          "region": "ENG",
      }
      expected_country = {
          "continent_code": "EU",
          "continent_name": "Europe",
          "country_code": "GB",
          "country_name": "United Kingdom",
          "is_in_european_union": False,
      }

      @classmethod
      def setUpClass(cls):
          # The data directory is computed here, not in the class body, so
          # that __file__ is not referenced at module import time.
          geoip_dir = build_geoip_path()
          cls.enterClassContext(override_settings(GEOIP_PATH=geoip_dir))
          # Pin hostname resolution so the tests don't depend on live DNS.
          resolver_patch = mock.patch(
              "socket.gethostbyname", return_value=cls.ipv4_str
          )
          cls.enterClassContext(resolver_patch)
          super().setUpClass()

      def _assert_country_lookups(self, geoip, query):
          """Check country(), country_code() and country_name() for *query*."""
          expected = self.expected_country
          self.assertEqual(geoip.country(query), expected)
          self.assertEqual(geoip.country_code(query), expected["country_code"])
          self.assertEqual(geoip.country_name(query), expected["country_name"])

      def test_init(self):
          readers = (
              # Everything inferred from the GEOIP_PATH setting.
              GeoIP2(),
              # Directory passed explicitly, together with a mode.
              GeoIP2(settings.GEOIP_PATH, GeoIP2.MODE_AUTO),
              # Directory given as a plain string.
              GeoIP2(str(settings.GEOIP_PATH)),
              # A single database file, with the other database disabled.
              GeoIP2(settings.GEOIP_PATH / settings.GEOIP_CITY, country=""),
              GeoIP2(settings.GEOIP_PATH / settings.GEOIP_COUNTRY, city=""),
          )
          for geoip in readers:
              self.assertTrue(geoip._reader)
          # Each value is an invalid cache setting. As a path, a string is
          # rejected with GeoIP2Exception and any other type with TypeError.
          for bad in (23, "foo", 15.23):
              with self.assertRaises(GeoIP2Exception):
                  GeoIP2(cache=bad)
              expected_exc = GeoIP2Exception if isinstance(bad, str) else TypeError
              with self.assertRaises(expected_exc):
                  GeoIP2(bad, GeoIP2.MODE_AUTO)

      def test_no_database_file(self):
          data_dir = pathlib.Path(__file__).parent / "data"
          invalid_path = (data_dir / "invalid").resolve()
          msg = "Path must be a valid database or directory containing databases."
          with self.assertRaisesMessage(GeoIP2Exception, msg):
              GeoIP2(invalid_path)

      def test_bad_query(self):
          geoip = GeoIP2(city="<invalid>")
          # With the city database given an invalid name, every city-based
          # lookup raises GeoIP2Exception.
          city_lookups = (geoip.city, geoip.geos, geoip.lat_lon, geoip.lon_lat)
          city_msg = "Invalid GeoIP city data file: "
          for lookup in city_lookups:
              with self.subTest(function=lookup.__qualname__):
                  with self.assertRaisesMessage(GeoIP2Exception, city_msg):
                      lookup("example.com")
          # Unsupported query types are rejected by every lookup method.
          all_lookups = city_lookups + (
              geoip.country,
              geoip.country_code,
              geoip.country_name,
          )
          bad_values = (123, 123.45, b"", (), [], {}, set(), frozenset(), GeoIP2)
          type_msg = (
              "GeoIP query must be a string or instance of IPv4Address or IPv6Address, "
              "not type"
          )
          for lookup, value in itertools.product(all_lookups, bad_values):
              with self.subTest(function=lookup.__qualname__, type=type(value)):
                  with self.assertRaisesMessage(TypeError, type_msg):
                      lookup(value)

      def test_country(self):
          geoip = GeoIP2(city="<invalid>")
          self.assertIs(geoip.is_city, False)
          self.assertIs(geoip.is_country, True)
          for query in self.query_values:
              with self.subTest(query=query):
                  self._assert_country_lookups(geoip, query)

      def test_country_using_city_database(self):
          geoip = GeoIP2(country="<invalid>")
          self.assertIs(geoip.is_city, True)
          self.assertIs(geoip.is_country, False)
          for query in self.query_values:
              with self.subTest(query=query):
                  self._assert_country_lookups(geoip, query)

      def test_city(self):
          geoip = GeoIP2(country="<invalid>")
          self.assertIs(geoip.is_city, True)
          self.assertIs(geoip.is_country, False)
          lat = self.expected_city["latitude"]
          lon = self.expected_city["longitude"]
          for query in self.query_values:
              with self.subTest(query=query):
                  self.assertEqual(geoip.city(query), self.expected_city)
                  point = geoip.geos(query)
                  self.assertIsInstance(point, GEOSGeometry)
                  self.assertEqual(point.srid, 4326)
                  # The geometry tuple is ordered (longitude, latitude).
                  self.assertEqual(point.tuple, (lon, lat))
                  self.assertEqual(geoip.lat_lon(query), (lat, lon))
                  self.assertEqual(geoip.lon_lat(query), (lon, lat))
                  # A city database must also answer country queries.
                  self._assert_country_lookups(geoip, query)

      def test_not_found(self):
          country_only = GeoIP2(city="<invalid>")
          city_only = GeoIP2(country="<invalid>")
          lookups = (country_only.country, city_only.city)
          loopbacks = ("127.0.0.1", "::1")
          for lookup, query in itertools.product(lookups, loopbacks):
              with self.subTest(function=lookup.__qualname__, query=query):
                  msg = f"The address {query} is not in the database."
                  with self.assertRaisesMessage(
                      geoip2.errors.AddressNotFoundError, msg
                  ):
                      lookup(query)

      def test_del(self):
          geoip = GeoIP2()
          reader = geoip._reader
          self.assertIs(reader._db_reader.closed, False)
          # Deleting the GeoIP2 object is expected to close its reader.
          del geoip
          self.assertIs(reader._db_reader.closed, True)

      def test_repr(self):
          geoip = GeoIP2()
          m = geoip._metadata
          version = f"{m.binary_format_major_version}.{m.binary_format_minor_version}"
          expected = f"<GeoIP2 [v{version}] _path='{geoip._path}'>"
          self.assertEqual(repr(geoip), expected)
  @skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
  @override_settings(
      GEOIP_CITY="GeoIP2-City-Test.mmdb",
      GEOIP_COUNTRY="GeoIP2-Country-Test.mmdb",
  )
  class GeoIP2Test(GeoLite2Test):
      """Rerun every GeoLite2Test test against the non-free GeoIP2 databases."""
  @skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
  @override_settings(
      GEOIP_CITY="dbip-city-lite-test.mmdb",
      GEOIP_COUNTRY="dbip-country-lite-test.mmdb",
  )
  class DBIPLiteTest(GeoLite2Test):
      """Rerun every GeoLite2Test test against the DB-IP Lite databases.

      DB-IP Lite reports a different city for the test address and no value
      for several fields, so ``expected_city`` is the GeoLite2 record with
      those entries overridden. ``expected_country`` is inherited unchanged.
      """

      expected_city = {
          **GeoLite2Test.expected_city,
          # Entries below replace the inherited GeoLite2 values.
          "accuracy_radius": None,
          "city": "London (Shadwell)",
          "latitude": 51.5181,
          "longitude": -0.0714189,
          "postal_code": None,
          "region_code": None,
          "time_zone": None,
          # Kept for backward compatibility.
          "region": None,
      }
  @skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
  class ErrorTest(SimpleTestCase):
      """Configuration errors raised when constructing GeoIP2()."""

      def test_missing_path(self):
          msg = "GeoIP path must be provided via parameter or the GEOIP_PATH setting."
          # Settings are overridden first, then the exception check is entered.
          with self.settings(GEOIP_PATH=None), self.assertRaisesMessage(
              GeoIP2Exception, msg
          ):
              GeoIP2()

      def test_unsupported_database(self):
          msg = "Unable to handle database edition: GeoLite2-ASN"
          # Point GEOIP_PATH straight at an ASN database file.
          asn_db = build_geoip_path("GeoLite2-ASN-Test.mmdb")
          with self.settings(GEOIP_PATH=asn_db), self.assertRaisesMessage(
              GeoIP2Exception, msg
          ):
              GeoIP2()