Skip to content

Commit

Permalink
Add typing to dependencies and improve type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Larralde committed Mar 23, 2018
1 parent 52ac603 commit 1b1e624
Show file tree
Hide file tree
Showing 11 changed files with 265 additions and 147 deletions.
11 changes: 9 additions & 2 deletions instalooter/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from __future__ import absolute_import
from __future__ import unicode_literals

import typing

try:
import ujson as json
import simplejson as json
except ImportError:
import json
import json # type: ignore

try:
import PIL.Image
Expand All @@ -20,6 +22,7 @@
from operator import length_hint
except ImportError:
def length_hint(obj, default=0):
# type: (typing.Any, int) -> int
"""Return an estimate of the number of items in obj.
This is useful for presizing containers when building from an
Expand All @@ -31,6 +34,7 @@ def length_hint(obj, default=0):
See Also:
`PEP 424 <https://www.python.org/dev/peps/pep-0424/>`_
"""
try:
return len(obj)
Expand All @@ -51,3 +55,6 @@ def length_hint(obj, default=0):
if hint < 0:
raise ValueError("__length_hint__() should return >= 0")
return hint


__all__ = ["PIL", "piexif", "json", "length_hint"]
17 changes: 13 additions & 4 deletions instalooter/_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
# coding: utf-8
"""Internal utility classes and functions.
"""
from __future__ import absolute_import
from __future__ import unicode_literals

import datetime
import itertools
import operator
import os
from typing import Any, Dict, Mapping, Optional, Text

import six


class NameGenerator(object):
"""Generator for filenames using a template.
"""Generator for filenames using a templitertoolsitertoolsitertoolsate.
"""

@classmethod
def _get_info(cls, media):
# type: (Mapping[Text, Any]) -> Mapping[Text, Any]

info = {
'id': media['id'],
'code': media['shortcode'],
Expand All @@ -26,7 +30,7 @@ def _get_info(cls, media):
'likescount': media['edge_media_preview_like']['count'],
'width': media.get('dimensions', {}).get('width'),
'height': media.get('dimensions', {}).get('height'),
}
} # type: Dict[Text, Any]

timestamp = media.get('date') or media.get('taken_at_timestamp')
if timestamp is not None:
Expand All @@ -35,20 +39,25 @@ def _get_info(cls, media):
"h{0.minute}m{0.second}s{0.microsecond}").format(dt)
info['date'] = datetime.date.fromtimestamp(timestamp)

return dict(six.moves.filter(operator.itemgetter(1), six.iteritems(info)))
return dict(six.moves.filter(
operator.itemgetter(1), six.iteritems(info)))

def __init__(self, template="{id}"):
# type: (Text) -> None
self.template = template

def base(self, media):
# type: (Mapping[Text, Any]) -> Text
info = self._get_info(media)
return self.template.format(**info)

def file(self, media, ext=None):
# type: (Mapping[Text, Any], Optional[Text]) -> Text
ext = ext or ("mp4" if media['is_video'] else "jpg")
return os.path.extsep.join([self.base(media), ext])

def needs_extended(self, media):
# type: (Mapping[Text, Any]) -> bool
try:
self.base(media)
return False
Expand Down
98 changes: 72 additions & 26 deletions instalooter/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,18 @@
from __future__ import absolute_import
from __future__ import unicode_literals

import io
import getpass
import logging
import os
import typing
from typing import Any, Mapping, Optional, Text, Type, Union

import requests
import six
from requests import Session

from .looters import HashtagLooter, InstaLooter, ProfileLooter
from .looters import InstaLooter, HashtagLooter, ProfileLooter
from .pbar import TqdmProgressBar

# mypy annotations
if False:
import io
from typing import Any, Dict, Type, Union


#: The module logger
logger = logging.getLogger(__name__)
Expand All @@ -31,55 +28,103 @@ class BatchRunner(object):
_CLS_MAP = {
'users': ProfileLooter,
'hashtag': HashtagLooter,
} # type: Dict[str, Type[InstaLooter]]
} # type: Mapping[Text, Union[Type[ProfileLooter], Type[HashtagLooter]]]

def __init__(self,
handle, # type: Any
args=None # Dict[str, Any]
):
# type: (...) -> None
def __init__(self, handle, args=None):
# type: (Any, Optional[Mapping[Text, Any]]) -> None

close_handle = False
if isinstance(handle, six.binary_type):
handle = handle.decode('utf-8')
if isinstance(handle, six.text_type):
handle = open(handle)
fp = open(handle) # type: typing.IO
close_handle = True
else:
fp = handle

try:
self.args = args or {}
self.parser = six.moves.configparser.ConfigParser()
getattr(self.parser, "readfp" if six.PY2 else "read_file")(handle)
getattr(self.parser, "readfp" if six.PY2 else "read_file")(fp)
finally:
if close_handle:
handle.close()

@typing.overload
def _getboolean(self, section_id, key, default):
# type: (Text, Text, bool) -> bool
pass

@typing.overload
def _getboolean(self, section_id, key):
# type: (Text, Text) -> Optional[bool]
pass

@typing.overload
def _getboolean(self, section_id, key, default):
# type: (Text, Text, None) -> Optional[bool]
pass

def _getboolean(self, section_id, key, default=None):
# type: (Text, Text, Optional[bool]) -> Optional[bool]
if self.parser.has_option(section_id, key):
return self.parser.getboolean(section_id, key)
return default

@typing.overload
def _getint(self, section_id, key, default):
# type: (Text, Text, None) -> Optional[int]
pass

@typing.overload
def _getint(self, section_id, key):
# type: (Text, Text) -> Optional[int]
pass

@typing.overload
def _getint(self, section_id, key, default):
# type: (Text, Text, int) -> int
pass

def _getint(self, section_id, key, default=None):
# type: (Text, Text, Optional[int]) -> Optional[int]
if self.parser.has_option(section_id, key):
return self.parser.getint(section_id, key)
return default

@typing.overload
def _get(self, section_id, key, default):
# type: (Text, Text, None) -> Optional[Text]
pass

@typing.overload
def _get(self, section_id, key):
# type: (Text, Text) -> Optional[Text]
pass

@typing.overload
def _get(self, section_id, key, default):
# type: (Text, Text, Text) -> Text
pass

def _get(self, section_id, key, default=None):
# type: (Text, Text, Optional[Text]) -> Optional[Text]
if self.parser.has_option(section_id, key):
return self.parser.get(section_id, key)
return default

def runAll(self):
# type: (...) -> None
"""Run all the jobs specified in the configuration file.
"""

logger.debug("Creating batch session")
session = requests.Session()
session = Session()

for section_id in self.parser.sections():
self.runJob(section_id, session=session)

def runJob(self, section_id, session=None):
# type: (Text, Optional[Session]) -> None
"""Run a job as described in the section named ``section_id``.
Raises:
Expand All @@ -89,7 +134,7 @@ def runJob(self, section_id, session=None):
if not self.parser.has_section(section_id):
raise KeyError('section not found: {}'.format(section_id))

session = session or requests.Session()
session = session or Session()

for name, looter_cls in six.iteritems(self._CLS_MAP):

Expand All @@ -115,24 +160,25 @@ def runJob(self, section_id, session=None):
extended_dump=self._getboolean(section_id, 'extended-dump', False),
session=session)

# if self.parser.has_option(section_id, 'username'):
# looter.logout()
# username = self._get(section_id, 'username')
# password = self._get(section_id, 'password') or \
# getpass.getpass('Password for "{}": '.format(username))
# looter.login(username, password)
if self.parser.has_option(section_id, 'username'):
looter.logout()
username = self._get(section_id, 'username')
password = self._get(section_id, 'password') or \
getpass.getpass('Password for "{}": '.format(username))
looter.login(username, password) # type: ignore

n = looter.download(
directory,
media_count=self._getint(section_id, 'num-to-dl'),
timeframe=self._get(section_id, 'timeframe'),
# FIXME: timeframe=self._get(section_id, 'timeframe'),
new_only=self._getboolean(section_id, 'new', False),
pgpbar_cls=None if quiet else TqdmProgressBar,
dlpbar_cls=None if quiet else TqdmProgressBar)

logger.log(35, "Downloaded {} medias !".format(n))

def getTargets(self, raw_string):
# type: (Optional[Text]) -> Mapping[Text, Text]
"""Extract targets from a string in 'key: value' format.
"""
targets = {}
Expand Down
4 changes: 3 additions & 1 deletion instalooter/cli/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import functools
import logging
import warnings
from typing import Callable

logging.SPAM = 5
logging.NOTICE = 25
Expand All @@ -18,6 +19,7 @@


def warn_logging(logger):
# type: (logging.Logger) -> Callable
"""Create a `showwarning` function that uses the given logger.
Arguments:
Expand All @@ -38,7 +40,7 @@ def wrap_warnings(logger):
Arguments:
logger (~logging.logger): the logger to wrap warnings with when
the decorated function is called
the decorated function is called.
Returns:
`function`: a decorator function.
Expand Down
Loading

0 comments on commit 1b1e624

Please sign in to comment.