archive.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. from dataclasses import asdict
  2. from typing import List
  3. from dacite import from_dict
  4. import json
  5. import requests
  6. import sqlite3
  7. from twitter_v2.types import TweetSearchResponse, DMEventsResponse, UserSearchResponse
  8. class ArchiveTweetSource:
  9. """
  10. id, created_at, retweeted, favorited, retweet_count, favorite_count, full_text, in_reply_to_status_id_str, in_reply_to_user_id, in_reply_to_screen_name
  11. """
  12. def __init__ (self, archive_path, db_path = ".data/tweet.db", archive_user_id = None):
  13. self.archive_path = archive_path
  14. self.user_id = archive_user_id
  15. self.db_path = db_path
  16. return
  17. def get_db (self):
  18. db = sqlite3.connect(self.db_path)
  19. return db
  20. def get_user_timeline (self,
  21. author_id = None, max_results = 10, since_id = None):
  22. if max_results == None:
  23. max_results = -1
  24. sql_params = []
  25. where_sql = []
  26. # if the ID is not stored as a number (eg. string) then this could be a problem
  27. if since_id:
  28. where_sql.append("cast(id as integer) > ?")
  29. sql_params.append(since_id)
  30. #if author_id:
  31. # where_sql.append("author_id = ?")
  32. # sql_params.append(author_id)
  33. where_sql = " and ".join(where_sql)
  34. sql_cols = "id, created_at, retweeted, favorited, retweet_count, favorite_count, full_text, in_reply_to_status_id_str, in_reply_to_user_id, in_reply_to_screen_name"
  35. if author_id:
  36. sql_cols += ", '{}' as author_id".format(author_id)
  37. if where_sql:
  38. where_sql = "where {}".format(where_sql)
  39. sql = "select {} from tweet {} order by cast(id as integer) asc limit ?".format(sql_cols, where_sql)
  40. sql_params.append(max_results)
  41. db = self.get_db()
  42. cur = db.cursor()
  43. cur.row_factory = sqlite3.Row
  44. print(sql)
  45. print(sql_params)
  46. results = list(map(dict, cur.execute(sql, sql_params).fetchall()))
  47. return results
  48. def get_tweet (self, id_):
  49. tweets = self.get_tweets([id_])
  50. if len(tweets):
  51. return tweets[0]
  52. def get_tweets (self,
  53. ids):
  54. sql_params = []
  55. where_sql = []
  56. ids_in_list_sql = "id in ({})".format( ','.join(['?'] * len(ids)))
  57. where_sql.append(ids_in_list_sql)
  58. sql_params += ids
  59. where_sql = " and ".join(where_sql)
  60. sql = "select * from tweet where {}".format(where_sql)
  61. db = self.get_db()
  62. cur = db.cursor()
  63. cur.row_factory = sqlite3.Row
  64. results = list(map(dict, cur.execute(sql, sql_params).fetchall()))
  65. results.sort(key=lambda t: ids.index(t['id']))
  66. return results
  def search_tweets(self,
                    query,
                    since_id=None,
                    max_results=10,
                    sort_order=None
                    ):
      """Placeholder for tweet search; currently does nothing.

      All arguments are accepted and ignored, and ``None`` is returned.
      """
      return None