test_utils.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from typing import List # pragma: no cover
  2. import atheris # pragma: no cover
  3. @atheris.instrument_func
  4. def is_expected_exception(
  5. error_message_list: List[str], exception: Exception
  6. ): # pragma: no cover
  7. """Checks if the message of a given exception matches any of the expected error messages.
  8. Args:
  9. error_message_list (List[str]): A list of error message substrings to check against the exception's message.
  10. exception (Exception): The exception object raised during execution.
  11. Returns:
  12. bool: True if the exception's message contains any of the substrings from the error_message_list, otherwise False.
  13. """
  14. for error in error_message_list:
  15. if error in str(exception):
  16. return True
  17. return False
  18. class EnhancedFuzzedDataProvider(atheris.FuzzedDataProvider): # pragma: no cover
  19. """Extends atheris.FuzzedDataProvider to offer additional methods to make fuzz testing slightly more DRY."""
  20. def __init__(self, data):
  21. """Initializes the EnhancedFuzzedDataProvider with fuzzing data from the argument provided to TestOneInput.
  22. Args:
  23. data (bytes): The binary data used for fuzzing.
  24. """
  25. super().__init__(data)
  26. def ConsumeRemainingBytes(self) -> bytes:
  27. """Consume the remaining bytes in the bytes container.
  28. Returns:
  29. bytes: Zero or more bytes.
  30. """
  31. return self.ConsumeBytes(self.remaining_bytes())
  32. def ConsumeRandomBytes(self, max_length=None) -> bytes:
  33. """Consume a random count of bytes from the bytes container.
  34. Args:
  35. max_length (int, optional): The maximum length of the string. Defaults to the number of remaining bytes.
  36. Returns:
  37. bytes: Zero or more bytes.
  38. """
  39. if max_length is None:
  40. max_length = self.remaining_bytes()
  41. else:
  42. max_length = min(max_length, self.remaining_bytes())
  43. return self.ConsumeBytes(self.ConsumeIntInRange(0, max_length))
  44. def ConsumeRandomString(self, max_length=None, without_surrogates=False) -> str:
  45. """Consume bytes to produce a Unicode string.
  46. Args:
  47. max_length (int, optional): The maximum length of the string. Defaults to the number of remaining bytes.
  48. without_surrogates (bool, optional): If True, never generate surrogate pair characters. Defaults to False.
  49. Returns:
  50. str: A Unicode string.
  51. """
  52. if max_length is None:
  53. max_length = self.remaining_bytes()
  54. else:
  55. max_length = min(max_length, self.remaining_bytes())
  56. count = self.ConsumeIntInRange(0, max_length)
  57. if without_surrogates:
  58. return self.ConsumeUnicodeNoSurrogates(count)
  59. else:
  60. return self.ConsumeUnicode(count)
  61. def ConsumeRandomInt(self, minimum=0, maximum=1234567890) -> int:
  62. """Consume bytes to produce an integer.
  63. Args:
  64. minimum (int, optional): The minimum value of the integer. Defaults to 0.
  65. maximum (int, optional): The maximum value of the integer. Defaults to 1234567890.
  66. Returns:
  67. int: An integer.
  68. """
  69. return self.ConsumeIntInRange(minimum, maximum)