archive.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from dataclasses import asdict
  2. from typing import List
  3. from dacite import from_dict
  4. import json
  5. import requests
  6. import sqlite3
  7. from twitter_v2.types import TweetSearchResponse, DMEventsResponse, UserSearchResponse
  8. class ArchiveTweetSource:
  9. """
  10. id, created_at, retweeted, favorited, retweet_count, favorite_count, full_text, in_reply_to_status_id_str, in_reply_to_user_id, in_reply_to_screen_nam
  11. """
  12. def __init__ (self, archive_path, db_path = ".data/tweet.db", archive_user_id = None):
  13. self.archive_path = archive_path
  14. self.user_id = archive_user_id
  15. self.db_path = db_path
  16. return
  17. def get_db (self):
  18. db = sqlite3.connect(self.db_path)
  19. return db
  20. def get_user_timeline (self,
  21. author_id = None, max_results = 10, since_id = None):
  22. if max_results == None:
  23. max_results = -1
  24. sql_params = []
  25. where_sql = []
  26. # if the ID is not stored as a number (eg. string) then this could be a problem
  27. if since_id:
  28. where_sql.append("cast(id as integer) > ?")
  29. sql_params.append(since_id)
  30. #if author_id:
  31. # where_sql.append("author_id = ?")
  32. # sql_params.append(author_id)
  33. where_sql = " and ".join(where_sql)
  34. sql_cols = "id, created_at, retweeted, favorited, retweet_count, favorite_count, full_text, in_reply_to_status_id_str, in_reply_to_user_id, in_reply_to_screen_name"
  35. if author_id:
  36. sql_cols += ", '{}' as author_id".format(author_id)
  37. if where_sql:
  38. where_sql = "where {}".format(where_sql)
  39. sql = "select {} from tweet {} order by cast(id as integer) asc limit ?".format(sql_cols, where_sql)
  40. sql_params.append(max_results)
  41. results = self.search_tweets_sql(sql, sql_params)
  42. return results
  43. def get_tweet (self, id_):
  44. tweets = self.get_tweets([id_])
  45. if len(tweets):
  46. return tweets[0]
  47. def get_tweets (self,
  48. ids):
  49. sql_params = []
  50. where_sql = []
  51. ids_in_list_sql = "id in ({})".format( ','.join(['?'] * len(ids)))
  52. where_sql.append(ids_in_list_sql)
  53. sql_params += ids
  54. where_sql = " and ".join(where_sql)
  55. sql = "select * from tweet where {}".format(where_sql)
  56. results = self.search_tweets_sql(sql, sql_params)
  57. results.sort(key=lambda t: ids.index(t['id']))
  58. return results
  59. def search_tweets_sql (self,
  60. sql,
  61. sql_params = []
  62. ):
  63. with self.get_db() as db:
  64. cur = db.cursor()
  65. cur.row_factory = sqlite3.Row
  66. results = list(map(dict, cur.execute(sql, sql_params).fetchall()))
  67. print(f'search_tweets_sql {len(results)}')
  68. return results