Skip to content

Commit

Permalink
Add progress bar to companies
Browse files Browse the repository at this point in the history
  • Loading branch information
cuducos committed Nov 17, 2019
1 parent 6fe22c0 commit a13471d
Showing 1 changed file with 8 additions and 23 deletions.
31 changes: 8 additions & 23 deletions serenata_toolbox/companies/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import aiohttp
import numpy as np
import pandas as pd
from tqdm import tqdm

from serenata_toolbox import log
from serenata_toolbox.companies.db import Database
Expand Down Expand Up @@ -76,10 +77,7 @@ def __init__(self, path="data", header="cnpj_cpf"):
self.output = self.path / f"{date.today()}-companies.csv.xz"
self.header = header
self.db = Database(path)

self.last_count_at = datetime.now()
self.count = 0
self._datasets = None # cache
self._datasets, self._documents = None, None # cache

@property
def datasets(self):
Expand All @@ -102,6 +100,9 @@ def is_cnpj(number):

@property
def documents(self):
if self._documents:
return self._documents

numbers = set()
for dataset in self.datasets:
log.info(f"Reading {dataset}…")
Expand All @@ -122,7 +123,8 @@ def documents(self):
if self.is_cnpj(number):
numbers.add(number)

yield from numbers
self._documents = tuple(numbers)
return self._documents

def translate_dict_keys(self, obj, translations=None):
translations = translations or self.TRANSLATION
Expand All @@ -148,33 +150,16 @@ async def companies(self):
semaphore = asyncio.Semaphore(2 ** 12)

async with semaphore, aiohttp.ClientSession() as session:
for cnpj in self.documents:
for cnpj in tqdm(self.documents, unit="companies"):
company = await self.db.get_company(session, cnpj)
if not company:
continue

company = self.serialize(company)
self.count += 1
companies.append(company)

if self.count % 100 == 0:
self.log_count()

if self.count % 100 != 0:
self.log_count()

return companies

def log_count(self):
    """Log how many companies were fetched so far and the recent fetch rate.

    The rate covers only the companies fetched since the previous call,
    divided by the time elapsed since that call. ``self.last_count_at`` is
    reset on every call, so dividing the cumulative ``self.count`` by that
    interval would inflate the reported rate as the run progresses.
    """
    now = datetime.now()
    elapsed = (now - self.last_count_at).total_seconds()

    # The previous total is stored lazily on the instance; on the first call
    # the attribute does not exist yet, so the interval starts from zero.
    fetched = self.count - getattr(self, "_last_logged_count", 0)

    # Two calls can land on the same clock reading; report 0.00 rather than
    # raising ZeroDivisionError in the middle of a long fetch.
    ratio = fetched / elapsed if elapsed > 0 else 0.0

    msg = f"{self.count:,} companies fetched ({ratio:.2f} companies/s)"
    log.info(msg)

    self.last_count_at = now
    self._last_logged_count = self.count

def __call__(self):
companies = asyncio.run(self.companies())
df = pd.DataFrame(companies)
Expand Down

0 comments on commit a13471d

Please sign in to comment.