diff --git a/README.md b/README.md index e0524fe..121c149 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ You should now be ready to start using it. ## About the tool ### Command-line arguments ``` -usage: tiktok-hashtag-analysis [-h] [--file FILE] [-d] [--number NUMBER] [-p] [-t] [--output-dir OUTPUT_DIR] [--config CONFIG] [--log LOG] [--limit LIMIT] [-v] [hashtags ...] +usage: tiktok-hashtag-analysis [-h] [--file FILE] [-d] [--number NUMBER] [-p] [-t] [--output-dir OUTPUT_DIR] [--config CONFIG] [--log LOG] [--limit LIMIT] [-v] [--headed] [hashtags ...] Analyze hashtags within posts scraped from TikTok. @@ -35,6 +35,7 @@ optional arguments: --log LOG File to write logs to --limit LIMIT Maximum number of videos to download for each hashtag -v, --verbose Increase output verbosity + --headed Don't use headless version of TikTok scraper ``` ### Structure of output data diff --git a/tests/base.py b/tests/base.py index dbc139d..df9e62b 100644 --- a/tests/base.py +++ b/tests/base.py @@ -3,7 +3,16 @@ def test_scrape(tmp_path, hashtags): downloader = TikTokDownloader(hashtags=hashtags[:1], data_dir=tmp_path) - downloader.run(limit=1000, download=True, plot=True, table=True, number=20) + downloader.run( + limit=10, download=True, plot=True, table=True, number=5, headed=True + ) + + +def test_scrape_headless(tmp_path, hashtags): + downloader = TikTokDownloader(hashtags=hashtags[:1], data_dir=tmp_path) + downloader.run( + limit=10, download=True, plot=True, table=True, number=5, headed=False + ) def test_load_hashtags_from_file(tmp_path, hashtags): diff --git a/tests/cli.py b/tests/cli.py index 6aa267e..b489828 100644 --- a/tests/cli.py +++ b/tests/cli.py @@ -20,6 +20,7 @@ ("table", True, "--table"), ("table", True, "-t"), ("verbose", True, "--verbose"), + ("headed", True, "--headed"), ("verbose", True, "-v"), ("output_dir", "/tmp/tiktok_download", "--output-dir"), ("config", "~/.tiktok", "--config"), @@ -51,6 +52,7 @@ def test_output_dir_spec_noexist_nowrite(tmp_path): specified_output_dir=specified_output_dir, parser=parser ) assert system_exit.type == SystemExit + os.chmod(tmp_path, 0o666) def test_output_dir_spec_exist_nowrite(tmp_path): @@ -63,6 +65,7 @@ def test_output_dir_spec_exist_nowrite(tmp_path): specified_output_dir=specified_output_dir, parser=parser ) assert system_exit.type == SystemExit + os.chmod(tmp_path, 0o666) def test_output_dir_unspec_nowrite(monkeypatch, tmp_path): @@ -75,6 +78,7 @@ def test_output_dir_unspec_nowrite(monkeypatch, tmp_path): result = process_output_dir(specified_output_dir=None, parser=parser) monkeypatch.chdir(cwd) assert result == DEFAULT_OUTPUT_DIR + os.chmod(tmp_path, 0o666) def test_output_dir_spec_noexist_write(tmp_path): diff --git a/tiktok_hashtag_analysis/base.py b/tiktok_hashtag_analysis/base.py index 71d7e63..bcfcfcb 100644 --- a/tiktok_hashtag_analysis/base.py +++ b/tiktok_hashtag_analysis/base.py @@ -52,11 +52,15 @@ def load_hashtags_from_file(file: str) -> List[str]: # Retry upon encountering transient playwright errors @retry(retry=retry_if_exception_type(Error), stop=stop_after_attempt(3)) -async def _fetch_hashtag_data(hashtag: str, limit: int) -> List[Dict]: +async def _fetch_hashtag_data( + hashtag: str, limit: int, headed: bool = False +) -> List[Dict]: """Fetch data for videos containing a specified hashtag, asynchronously.""" data = [] async with TikTokApi() as api: - await api.create_sessions(ms_tokens=[], num_sessions=1, sleep_after=3) + await api.create_sessions( + ms_tokens=[], num_sessions=1, sleep_after=3, headless=not headed + ) async for video in api.hashtag(name=hashtag).videos(count=limit): data.append(video.as_dict) return data @@ -157,7 +161,7 @@ def prioritize_hashtags(self): } self.hashtags.sort(key=lambda h: last_edited.get(h, 0)) - def get_hashtag_posts(self, hashtag: str, limit: int): + def get_hashtag_posts(self, hashtag: str, limit: int, headed: bool): """Fetch data about posts that used a specified hashtag and merge with existing data, if it exists.""" @@ -172,8 +176,20 @@ def get_hashtag_posts(self, hashtag: str, limit: int): already_fetched_data = [] already_fetched_ids = set(video["id"] for video in already_fetched_data) - # Scrape posts that use the specified hashtag - fetched_data = asyncio.run(_fetch_hashtag_data(hashtag=hashtag, limit=limit)) + # Scrape posts that use the specified hashag + # Attempt to be robust against TikTok's countermeasures for headless browsing + try: + fetched_data = asyncio.run( + _fetch_hashtag_data(hashtag=hashtag, limit=limit, headed=headed) + ) + except Exception as e: + logger.warning( + "Encountered error {e} when fetching data, retrying in headed mode" + ) + fetched_data = asyncio.run( + _fetch_hashtag_data(hashtag=hashtag, limit=limit, headed=True) + ) + fetched_ids = set(video["id"] for video in fetched_data) if len(fetched_data) == 0: @@ -303,13 +319,21 @@ def plot(self, hashtag: str, number: int): plt.savefig(plot_file, bbox_inches="tight", facecolor="white", dpi=300) logger.info(f"Plot saved to file: {plot_file}") - def run(self, limit: int, download: bool, plot: bool, table: bool, number: int): + def run( + self, + limit: int, + download: bool, + plot: bool, + table: bool, + number: int, + headed: bool, + ): """Execute the specified operations on all specified hashtags.""" # Scrape all specified hashtags and perform analyses, depending on if # `--table`, `--plot`, and `--download` flags are used in the command for hashtag in self.hashtags: - self.get_hashtag_posts(hashtag=hashtag, limit=limit) + self.get_hashtag_posts(hashtag=hashtag, limit=limit, headed=headed) if plot: self.plot(hashtag=hashtag, number=number) if table: diff --git a/tiktok_hashtag_analysis/cli.py b/tiktok_hashtag_analysis/cli.py index 778c28f..63342e5 100644 --- a/tiktok_hashtag_analysis/cli.py +++ b/tiktok_hashtag_analysis/cli.py @@ -77,7 +77,11 @@ def create_parser(): help="Increase output verbosity", action="store_true", ) - + parser.add_argument( + "--headed", + help="Don't use headless version of TikTok scraper", + action="store_true", + ) return parser @@ -146,6 +150,7 @@ def main(): plot=args.plot, table=args.table, number=args.number, + headed=args.headed, )