123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246 |
- import ipaddress
- import itertools
- import pathlib
- from unittest import mock, skipUnless
- from django.conf import settings
- from django.contrib.gis.geoip2 import HAS_GEOIP2
- from django.contrib.gis.geos import GEOSGeometry
- from django.test import SimpleTestCase, override_settings
- if HAS_GEOIP2:
- import geoip2
- from django.contrib.gis.geoip2 import GeoIP2, GeoIP2Exception
def build_geoip_path(*parts):
    """
    Return the resolved path of a GeoIP2 test fixture.

    ``parts`` are joined, in order, beneath the ``data/geoip2`` directory that
    sits next to this module. With no parts, the directory itself is returned.
    """
    fixtures_dir = pathlib.Path(__file__).parent / "data/geoip2"
    return fixtures_dir.joinpath(*parts).resolve()
@skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
@override_settings(
    GEOIP_CITY="GeoLite2-City-Test.mmdb",
    GEOIP_COUNTRY="GeoLite2-Country-Test.mmdb",
)
class GeoLite2Test(SimpleTestCase):
    """
    Exercise the GeoIP2 wrapper against the GeoLite2 test databases.

    Subclasses re-run the same tests against other database files by
    overriding the GEOIP_CITY / GEOIP_COUNTRY settings and, where the data
    differs, the ``expected_*`` dictionaries.
    """

    fqdn = "sky.uk"
    # ipv6_str is the IPv4-mapped IPv6 form of ipv4_str
    # (0x02.0x7d.0xa0.0xd8 == 2.125.160.216), so both must name the same host.
    ipv4_str = "2.125.160.216"
    ipv6_str = "::ffff:027d:a0d8"
    ipv4_addr = ipaddress.ip_address(ipv4_str)
    ipv6_addr = ipaddress.ip_address(ipv6_str)
    # Every supported way of expressing the same query.
    query_values = (fqdn, ipv4_str, ipv6_str, ipv4_addr, ipv6_addr)

    expected_city = {
        "accuracy_radius": 100,
        "city": "Boxford",
        "continent_code": "EU",
        "continent_name": "Europe",
        "country_code": "GB",
        "country_name": "United Kingdom",
        "is_in_european_union": False,
        "latitude": 51.75,
        "longitude": -1.25,
        "metro_code": None,
        "postal_code": "OX1",
        "region_code": "ENG",
        "region_name": "England",
        "time_zone": "Europe/London",
        # Kept for backward compatibility.
        "dma_code": None,
        "region": "ENG",
    }
    expected_country = {
        "continent_code": "EU",
        "continent_name": "Europe",
        "country_code": "GB",
        "country_name": "United Kingdom",
        "is_in_european_union": False,
    }

    @classmethod
    def setUpClass(cls):
        """Point GEOIP_PATH at the fixtures and stub out hostname resolution."""
        # Avoid referencing __file__ at module level.
        cls.enterClassContext(override_settings(GEOIP_PATH=build_geoip_path()))
        # Always mock host lookup to avoid test breakage if DNS changes.
        cls.enterClassContext(
            mock.patch("socket.gethostbyname", return_value=cls.ipv4_str)
        )
        super().setUpClass()

    def test_init(self):
        """GeoIP2 accepts several path forms and rejects improper parameters."""
        # Everything inferred from GeoIP path.
        g1 = GeoIP2()
        # Path passed explicitly.
        g2 = GeoIP2(settings.GEOIP_PATH, GeoIP2.MODE_AUTO)
        # Path provided as a string.
        g3 = GeoIP2(str(settings.GEOIP_PATH))
        # Only passing in the location of one database.
        g4 = GeoIP2(settings.GEOIP_PATH / settings.GEOIP_CITY, country="")
        g5 = GeoIP2(settings.GEOIP_PATH / settings.GEOIP_COUNTRY, city="")
        for g in (g1, g2, g3, g4, g5):
            self.assertTrue(g._reader)

        # Improper parameters.
        bad_params = (23, "foo", 15.23)
        for bad in bad_params:
            with self.subTest(bad=bad):
                with self.assertRaises(GeoIP2Exception):
                    GeoIP2(cache=bad)
                # A string is treated as a (nonexistent) path; any other type
                # is rejected with a TypeError.
                expected = GeoIP2Exception if isinstance(bad, str) else TypeError
                with self.assertRaises(expected):
                    GeoIP2(bad, GeoIP2.MODE_AUTO)

    def test_no_database_file(self):
        """A path that holds no database raises GeoIP2Exception."""
        invalid_path = pathlib.Path(__file__).parent.joinpath("data/invalid").resolve()
        msg = "Path must be a valid database or directory containing databases."
        with self.assertRaisesMessage(GeoIP2Exception, msg):
            GeoIP2(invalid_path)

    def test_bad_query(self):
        """City lookups need a city database; queries must be str or IP objects."""
        g = GeoIP2(city="<invalid>")

        # Only the country database was loaded, so city-based lookups fail.
        functions = (g.city, g.geos, g.lat_lon, g.lon_lat)
        msg = "Invalid GeoIP city data file: "
        for function in functions:
            with self.subTest(function=function.__qualname__):
                with self.assertRaisesMessage(GeoIP2Exception, msg):
                    function("example.com")

        # Every lookup function rejects queries of an unsupported type.
        functions += (g.country, g.country_code, g.country_name)
        values = (123, 123.45, b"", (), [], {}, set(), frozenset(), GeoIP2)
        msg = (
            "GeoIP query must be a string or instance of IPv4Address or IPv6Address, "
            "not type"
        )
        for function, value in itertools.product(functions, values):
            with self.subTest(function=function.__qualname__, type=type(value)):
                with self.assertRaisesMessage(TypeError, msg):
                    function(value)

    def test_country(self):
        """Country lookups work when only the country database is loaded."""
        g = GeoIP2(city="<invalid>")
        self.assertIs(g.is_city, False)
        self.assertIs(g.is_country, True)
        for query in self.query_values:
            with self.subTest(query=query):
                self.assertEqual(g.country(query), self.expected_country)
                self.assertEqual(
                    g.country_code(query), self.expected_country["country_code"]
                )
                self.assertEqual(
                    g.country_name(query), self.expected_country["country_name"]
                )

    def test_country_using_city_database(self):
        """Country lookups work when only the city database is loaded."""
        g = GeoIP2(country="<invalid>")
        self.assertIs(g.is_city, True)
        self.assertIs(g.is_country, False)
        for query in self.query_values:
            with self.subTest(query=query):
                self.assertEqual(g.country(query), self.expected_country)
                self.assertEqual(
                    g.country_code(query), self.expected_country["country_code"]
                )
                self.assertEqual(
                    g.country_name(query), self.expected_country["country_name"]
                )

    def test_city(self):
        """City, geometry and coordinate lookups agree with expected_city."""
        g = GeoIP2(country="<invalid>")
        self.assertIs(g.is_city, True)
        self.assertIs(g.is_country, False)
        for query in self.query_values:
            with self.subTest(query=query):
                self.assertEqual(g.city(query), self.expected_city)

                geom = g.geos(query)
                self.assertIsInstance(geom, GEOSGeometry)
                self.assertEqual(geom.srid, 4326)

                expected_lat = self.expected_city["latitude"]
                expected_lon = self.expected_city["longitude"]
                self.assertEqual(geom.tuple, (expected_lon, expected_lat))
                self.assertEqual(g.lat_lon(query), (expected_lat, expected_lon))
                self.assertEqual(g.lon_lat(query), (expected_lon, expected_lat))

                # Country queries should still work.
                self.assertEqual(g.country(query), self.expected_country)
                self.assertEqual(
                    g.country_code(query), self.expected_country["country_code"]
                )
                self.assertEqual(
                    g.country_name(query), self.expected_country["country_name"]
                )

    def test_not_found(self):
        """Addresses absent from the database raise AddressNotFoundError."""
        g1 = GeoIP2(city="<invalid>")
        g2 = GeoIP2(country="<invalid>")
        # Loopback addresses (IPv4 and IPv6) are used as the missing entries.
        for function, query in itertools.product(
            (g1.country, g2.city), ("127.0.0.1", "::1")
        ):
            with self.subTest(function=function.__qualname__, query=query):
                msg = f"The address {query} is not in the database."
                with self.assertRaisesMessage(geoip2.errors.AddressNotFoundError, msg):
                    function(query)

    def test_del(self):
        """Deleting the GeoIP2 object closes its underlying database reader."""
        g = GeoIP2()
        # Keep our own reference so the reader can be inspected after `del g`.
        reader = g._reader
        self.assertIs(reader._db_reader.closed, False)
        del g
        self.assertIs(reader._db_reader.closed, True)

    def test_repr(self):
        """repr() reports the binary format version and the database path."""
        g = GeoIP2()
        m = g._metadata
        version = f"{m.binary_format_major_version}.{m.binary_format_minor_version}"
        self.assertEqual(repr(g), f"<GeoIP2 [v{version}] _path='{g._path}'>")
@skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
@override_settings(
    GEOIP_CITY="GeoIP2-City-Test.mmdb",
    GEOIP_COUNTRY="GeoIP2-Country-Test.mmdb",
)
class GeoIP2Test(GeoLite2Test):
    """Run the inherited GeoLite2 tests against the non-free GeoIP2 databases."""
@skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
@override_settings(
    GEOIP_CITY="dbip-city-lite-test.mmdb",
    GEOIP_COUNTRY="dbip-country-lite-test.mmdb",
)
class DBIPLiteTest(GeoLite2Test):
    """Run the inherited GeoLite2 tests against the DB-IP Lite databases."""

    # Start from the GeoLite2 expectations and override only the city fields
    # whose expected values differ for these databases.
    expected_city = {
        **GeoLite2Test.expected_city,
        "accuracy_radius": None,
        "city": "London (Shadwell)",
        "latitude": 51.5181,
        "longitude": -0.0714189,
        "postal_code": None,
        "region_code": None,
        "time_zone": None,
        # Kept for backward compatibility.
        "region": None,
    }
@skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
class ErrorTest(SimpleTestCase):
    """GeoIP2() raises GeoIP2Exception for unusable GEOIP_PATH settings."""

    def test_missing_path(self):
        """No path argument and GEOIP_PATH=None is reported as an error."""
        msg = "GeoIP path must be provided via parameter or the GEOIP_PATH setting."
        with (
            self.settings(GEOIP_PATH=None),
            self.assertRaisesMessage(GeoIP2Exception, msg),
        ):
            GeoIP2()

    def test_unsupported_database(self):
        """Pointing GEOIP_PATH at the ASN test database is reported as an error."""
        msg = "Unable to handle database edition: GeoLite2-ASN"
        asn_database = build_geoip_path("GeoLite2-ASN-Test.mmdb")
        with (
            self.settings(GEOIP_PATH=asn_database),
            self.assertRaisesMessage(GeoIP2Exception, msg),
        ):
            GeoIP2()