# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later

import atheris  # pragma: no cover


@atheris.instrument_func
def is_expected_exception(
    error_message_list: list[str], exception: Exception
) -> bool:  # pragma: no cover
    """Report whether an exception's message contains any expected substring.

    Args:
        error_message_list (list[str]): Substrings to search for in the exception's message.
        exception (Exception): The exception whose ``str()`` form is searched.

    Returns:
        bool: True if at least one entry of error_message_list occurs in ``str(exception)``,
            otherwise False (which includes the case of an empty list).
    """
    # any() stops at the first match, mirroring an early return from a loop.
    return any(expected in str(exception) for expected in error_message_list)


class EnhancedFuzzedDataProvider(atheris.FuzzedDataProvider):  # pragma: no cover
    """atheris.FuzzedDataProvider subclass with convenience consumers for common fuzz-harness patterns."""

    def __init__(self, data) -> None:
        """Create the provider from the fuzzing input handed to TestOneInput.

        Args:
            data (bytes): The binary data used for fuzzing.
        """
        super().__init__(data)

    def _consume_length(self, max_length):
        """Consume an integer between 0 and a clamped upper bound, for use as a length.

        Args:
            max_length (int, optional): Requested upper bound. None means "use the number of remaining bytes";
                any other value is capped at the number of remaining bytes.

        Returns:
            int: The result of ``ConsumeIntInRange(0, bound)``.
        """
        available = self.remaining_bytes()
        bound = available if max_length is None else min(max_length, available)
        return self.ConsumeIntInRange(0, bound)

    def ConsumeRemainingBytes(self) -> bytes:
        """Consume everything that is left in the provider.

        Returns:
            bytes: Zero or more bytes.
        """
        leftover = self.remaining_bytes()
        return self.ConsumeBytes(leftover)

    def ConsumeRandomBytes(self, max_length=None) -> bytes:
        """Consume a fuzzer-chosen number of bytes.

        Args:
            max_length (int, optional): The maximum number of bytes to request. Defaults to the number of
                remaining bytes, and is capped at that number otherwise.

        Returns:
            bytes: Zero or more bytes.
        """
        # The length itself is drawn from the fuzz data before the payload is read.
        size = self._consume_length(max_length)
        return self.ConsumeBytes(size)

    def ConsumeRandomString(self, max_length=None, without_surrogates=False) -> str:
        """Consume bytes to produce a Unicode string of fuzzer-chosen length.

        Args:
            max_length (int, optional): The maximum length passed to the Unicode consumer. Defaults to the
                number of remaining bytes, and is capped at that number otherwise.
            without_surrogates (bool, optional): If True, use ConsumeUnicodeNoSurrogates instead of
                ConsumeUnicode. Defaults to False.

        Returns:
            str: A Unicode string.
        """
        size = self._consume_length(max_length)
        produce = self.ConsumeUnicodeNoSurrogates if without_surrogates else self.ConsumeUnicode
        return produce(size)

    def ConsumeRandomInt(self, minimum=0, maximum=1234567890) -> int:
        """Consume bytes to produce an integer via ConsumeIntInRange.

        Args:
            minimum (int, optional): The lower bound passed to ConsumeIntInRange. Defaults to 0.
            maximum (int, optional): The upper bound passed to ConsumeIntInRange. Defaults to 1234567890.

        Returns:
            int: An integer.
        """
        value = self.ConsumeIntInRange(minimum, maximum)
        return value