diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000..d040b9c
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,634 @@
+[MAIN]
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Clear in-memory caches upon conclusion of linting. Useful if running pylint
+# in a server-like mode.
+clear-cache-post-run=no
+
+# Load and enable all available extensions. Use --list-extensions to see a list
+# all available extensions.
+#enable-all-extensions=
+
+# In error mode, messages with a category besides ERROR or FATAL are
+# suppressed, and no reports are done by default. Error mode is compatible with
+# disabling specific errors.
+#errors-only=
+
+# Always return a 0 (non-error) status code, even if lint errors are found.
+# This is primarily useful in continuous integration scripts.
+#exit-zero=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loaded into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-allow-list=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loaded into the active Python interpreter and may
+# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
+# for backward compatibility.)
+extension-pkg-whitelist=
+
+# Return non-zero exit code if any of these messages/categories are detected,
+# even if score is above --fail-under value. Syntax same as enable. Messages
+# specified are enabled, while categories only check already-enabled messages.
+fail-on=
+
+# Specify a score threshold under which the program will exit with error.
+fail-under=10
+
+# Interpret the stdin as a python script, whose filename needs to be passed as
+# the module_or_package argument.
+#from-stdin=
+
+# Files or directories to be skipped. They should be base names, not paths.
+ignore=CVS
+
+# Add files or directories matching the regular expressions patterns to the
+# ignore-list. The regex matches against paths and can be in Posix or Windows
+# format. Because '\\' represents the directory delimiter on Windows systems,
+# it can't be used as an escape character.
+ignore-paths=
+
+# Files or directories matching the regular expression patterns are skipped.
+# The regex matches against base names, not paths. The default value ignores
+# Emacs file locks
+ignore-patterns=^\.#
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use, and will cap the count on Windows to
+# avoid hangs.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Minimum Python version to use for version dependent checks. Will default to
+# the version used to run pylint.
+py-version=3.10
+
+# Discover python modules and packages in the file system subtree.
+recursive=no
+
+# Add paths to the list of the source roots. Supports globbing patterns. The
+# source root is an absolute path or a path relative to the current working
+# directory used to determine a package namespace for modules located under the
+# source root.
+source-roots=
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+# In verbose mode, extra non-checker-related info will be displayed.
+#verbose=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style. If left empty, argument names will be checked with the set
+# naming style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style. If left empty, attribute names will be checked with the set naming
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+ bar,
+ baz,
+ toto,
+ tutu,
+ tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style. If left empty, class attribute names will be checked
+# with the set naming style.
+#class-attribute-rgx=
+
+# Naming style matching correct class constant names.
+class-const-naming-style=UPPER_CASE
+
+# Regular expression matching correct class constant names. Overrides class-
+# const-naming-style. If left empty, class constant names will be checked with
+# the set naming style.
+#class-const-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style. If left empty, class names will be checked with the set naming style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style. If left empty, constant names will be checked with the set naming
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style. If left empty, function names will be checked with the set
+# naming style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+ j,
+ k,
+ ex,
+ Run,
+ _
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style. If left empty, inline iteration names will be checked
+# with the set naming style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style. If left empty, method names will be checked with the set naming style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style. If left empty, module names will be checked with the set naming style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Regular expression matching correct type alias names. If left empty, type
+# alias names will be checked with the set naming style.
+#typealias-rgx=
+
+# Regular expression matching correct type variable names. If left empty, type
+# variable names will be checked with the set naming style.
+#typevar-rgx=
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style. If left empty, variable names will be checked with the set
+# naming style.
+#variable-rgx=
+
+
+[CLASSES]
+
+# Warn about protected attribute access inside special methods
+check-protected-access-in-special-methods=no
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+ __new__,
+ setUp,
+ asyncSetUp,
+ __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[DESIGN]
+
+# List of regular expressions of class ancestor names to ignore when counting
+# public methods (see R0903)
+exclude-too-few-public-methods=
+
+# List of qualified class names to ignore when counting class parents (see
+# R0901)
+ignored-parents=
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when caught.
+overgeneral-exceptions=builtins.BaseException,builtins.Exception
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
+# tab).
+indent-string=' '
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Maximum number of lines in a module.
+max-module-lines=1000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow explicit reexports by alias from a package __init__.
+allow-reexport-from-package=no
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=
+
+# Output a graph (.gv or any supported image format) of external dependencies
+# to the given file (report RP0402 must not be disabled).
+ext-import-graph=
+
+# Output a graph (.gv or any supported image format) of all (i.e. internal and
+# external) dependencies to the given file (report RP0402 must not be
+# disabled).
+import-graph=
+
+# Output a graph (.gv or any supported image format) of internal dependencies
+# to the given file (report RP0402 must not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
+# UNDEFINED.
+confidence=HIGH,
+ CONTROL_FLOW,
+ INFERENCE,
+ INFERENCE_FAILURE,
+ UNDEFINED
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then re-enable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=raw-checker-failed,
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ use-symbolic-message-instead,
+ use-implicit-booleaness-not-comparison-to-string,
+ use-implicit-booleaness-not-comparison-to-zero
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifiers separated by comma (,) or put this option
+# multiple times (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=
+
+
+[METHOD_ARGS]
+
+# List of qualified names (i.e., library.method) which require a timeout
+# parameter e.g. 'requests.api.get,requests.api.post'
+timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+ XXX,
+ TODO
+
+# Regular expression of note tags to take in consideration.
+notes-rgx=
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit,argparse.parse_error
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'fatal', 'error', 'warning', 'refactor',
+# 'convention', and 'info' which contain the number of messages in each
+# category, as well as 'statement' which is the total number of statements
+# analyzed. This score is used by the global evaluation report (RP0004).
+evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+msg-template=
+
+# Set the output format. Available formats are: text, parseable, colorized,
+# json2 (improved json format), json (old json format) and msvs (visual
+# studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+#output-format=
+
+# Tells whether to display a full report or only the messages.
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[SIMILARITIES]
+
+# Comments are removed from the similarity computation
+ignore-comments=yes
+
+# Docstrings are removed from the similarity computation
+ignore-docstrings=yes
+
+# Imports are removed from the similarity computation
+ignore-imports=yes
+
+# Signatures are removed from the similarity computation
+ignore-signatures=yes
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. No available dictionaries: you need to install
+# both the python package and the system dependency for enchant to work.
+spelling-dict=
+
+# List of comma separated words that should be considered directives if they
+# appear at the beginning of a comment and should not be checked.
+spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of symbolic message names to ignore for Mixin members.
+ignored-checks-for-mixins=no-member,
+ not-async-context-manager,
+ not-context-manager,
+ attribute-defined-outside-init
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# Regex pattern to define which classes are considered mixins.
+mixin-class-rgx=.*[Mm]ixin
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of names allowed to shadow builtins
+allowed-redefined-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+ _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
diff --git a/examples/inference.ipynb b/examples/inference.ipynb
index ef1ae03..ff0ee81 100644
--- a/examples/inference.ipynb
+++ b/examples/inference.ipynb
@@ -37,10 +37,7 @@
"extension: The extension of the tiles. The default is 'png', which is the standard extension for Web Map Tile Services.\n",
"\n",
"dump_percent: The percent of segmentation images to dump to the subdirectory 'segmentation/seg_results': 100 means all, 0 means none. The default is 0.\n"
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "code",
diff --git a/src/tile2net/tileseg/inference/__init__.py b/src/tile2net/tileseg/inference/__init__.py
index b6c73ab..de8340c 100644
--- a/src/tile2net/tileseg/inference/__init__.py
+++ b/src/tile2net/tileseg/inference/__init__.py
@@ -3,9 +3,8 @@
import concurrent.futures
from typing import Optional
-import time
-import numpy
+
"""
Copyright 2020 Nvidia Corporation
@@ -31,22 +30,19 @@
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""
-
+import os
+import numpy
from geopandas import GeoDataFrame, GeoSeries
-import itertools
-import torchvision.transforms as standard_transforms
-import torchvision.utils as vutils
+
import pandas as pd
import geopandas as gpd
-import os
import sys
import argh
import torch
from torch.utils.data import DataLoader
import torch.distributed as dist
-from torch.cuda import amp
from runx.logx import logx
import tile2net.tileseg.network.ocrnet
@@ -98,12 +94,12 @@ def inference_(args: Namespace):
):
weights.path.mkdir(parents=True, exist_ok=True)
logging.info(
- "Downloading weights for segmentation, this may take a while..."
+ "Downloading weights for segmentation, this may take a while..."
)
weights.download()
logging.info("Weights downloaded.")
- args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
+ args.best_record = {'epoch' : -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
# Enable CUDNN Benchmarking optimization
@@ -120,58 +116,34 @@ def inference_(args: Namespace):
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
- # Distributed training setup
-
- args.world_size = int(os.environ.get('WORLD_SIZE', num_gpus))
- dist.init_process_group(backend='nccl', init_method='env://')
- args.local_rank = dist.get_rank()
- torch.cuda.set_device(args.local_rank)
- args.distributed = True
- args.global_rank = int(os.environ['RANK'])
- # print(f'Using distributed training with {args.world_size} GPUs.')
- logger.info(f'Using distributed training with {args.world_size} GPUs.')
+ if args.eval == 'test':
+ # Single GPU setup
+ logger.info('Using a single GPU.')
+ args.local_rank = 0
+ torch.cuda.set_device(args.local_rank)
+ else:
+ # Distributed training setup
+ if "RANK" not in os.environ:
+ raise ValueError("You need to launch the process with torch.distributed.launch to \
+ set RANK environment variable")
+ args.world_size = int(os.environ.get('WORLD_SIZE', num_gpus))
+ dist.init_process_group(backend='nccl', init_method='env://')
+ args.local_rank = dist.get_rank()
+ torch.cuda.set_device(args.local_rank)
+ args.distributed = True
+ args.global_rank = int(os.environ['RANK'])
+ logger.info(f'Using distributed training with {args.world_size} GPUs.')
elif num_gpus == 1:
# Single GPU setup
- logger.info('Using a single GPU.')
args.local_rank = 0
torch.cuda.set_device(args.local_rank)
+ logger.info('Using a single GPU.')
else:
# CPU setup
# print('Using CPU.')
- logger.info('Using CPU.')
+ logger.info('Using CPU. This is not recommended for inference.')
args.local_rank = -1 # Indicating CPU usage
- # if 'WORLD_SIZE' in os.environ and args.model.apex:
- # # args.model.apex = int(os.environ['WORLD_SIZE']) > 1
- # args.world_size = int(os.environ['WORLD_SIZE'])
- # args.global_rank = int(os.environ['RANK'])
-
- # if args.model.apex:
- # print('Global Rank: {} Local Rank: {}'.format(
- # args.global_rank, args.local_rank))
- # torch.cuda.set_device(args.local_rank)
- # torch.distributed.init_process_group(backend='nccl',
- # init_method='env://')
-
- # def check_termination(epoch):
- # if AutoResume:
- # shouldterminate = AutoResume.termination_requested()
- # if shouldterminate:
- # if args.global_rank == 0:
- # progress = "Progress %d%% (epoch %d of %d)" % (
- # (epoch * 100 / args.max_epoch),
- # epoch,
- # args.max_epoch
- # )
- # AutoResume.request_resume(
- # user_dict={"RESUME_FILE": logx.save_ckpt_fn,
- # "TENSORBOARD_DIR": args.result_dir,
- # "EPOCH": str(epoch)
- # }, message=progress)
- # return 1
- # else:
- # return 1
- # return 0
def run_inference(args=args, rasterfactory=None):
"""
@@ -304,14 +276,14 @@ def validate(val_loader, net, criterion, optim, epoch,
# pred.update({img_names[0]: dict(zip(values, counts))})
dumper.dump(
- {'gt_images': labels, 'input_images': input_images, 'img_names': img_names,
- 'assets': assets},
- val_idx, testing=True, grid=grid)
+ {'gt_images': labels, 'input_images': input_images, 'img_names': img_names,
+ 'assets' : assets},
+ val_idx, testing=True, grid=grid)
else:
- dumper.dump({'gt_images': labels,
+ dumper.dump({'gt_images' : labels,
'input_images': input_images,
- 'img_names': img_names,
- 'assets': assets}, val_idx)
+ 'img_names' : img_names,
+ 'assets' : assets}, val_idx)
if val_idx > 5 and args.options.test_mode:
break
@@ -325,8 +297,8 @@ def validate(val_loader, net, criterion, optim, epoch,
polys = grid.ntw_poly
# net = PedNet(polys, grid.project)
net = PedNet(
- poly=polys,
- project=grid.project,
+ poly=polys,
+ project=grid.project,
)
net.convert_whole_poly2line()
@@ -408,9 +380,9 @@ def inference(self, rasterfactory=None):
args = self.args
logx.initialize(
- logdir=str(args.result_dir),
- tensorboard=True, hparams=vars(args),
- global_rank=args.global_rank
+ logdir=str(args.result_dir),
+ tensorboard=True, hparams=vars(args),
+ global_rank=args.global_rank
)
assert_and_infer_cfg(args)
@@ -454,20 +426,20 @@ def inference(self, rasterfactory=None):
match args.model.eval:
case 'test':
self.validate(
- val_loader, net, criterion=None, optim=None, epoch=0,
- calc_metrics=False, dump_assets=args.dump_assets,
- dump_all_images=True, testing=True, grid=city_data,
- args=args,
+ val_loader, net, criterion=None, optim=None, epoch=0,
+ calc_metrics=False, dump_assets=args.dump_assets,
+ dump_all_images=True, testing=True, grid=city_data,
+ args=args,
)
return 0
case 'folder':
# Using a folder for evaluation means to not calculate metrics
self.validate(
- val_loader, net, criterion=criterion_val, optim=optim, epoch=0,
- calc_metrics=False, dump_assets=args.dump_assets,
- dump_all_images=True,
- args=args,
+ val_loader, net, criterion=criterion_val, optim=optim, epoch=0,
+ calc_metrics=False, dump_assets=args.dump_assets,
+ dump_all_images=True,
+ args=args,
)
return 0
@@ -503,10 +475,10 @@ def validate(
gdfs: list[GeoDataFrame] = []
self.dumper = dumper = self.Dumper(
- val_len=len(val_loader),
- dump_all_images=dump_all_images,
- dump_assets=dump_assets,
- args=args,
+ val_len=len(val_loader),
+ dump_all_images=dump_all_images,
+ dump_assets=dump_assets,
+ args=args,
)
net.eval()
@@ -519,16 +491,16 @@ def validate(
# Run network
assets, _iou_acc = eval_minibatch(
- data, net, criterion, val_loss, calc_metrics, args, val_idx,
+ data, net, criterion, val_loss, calc_metrics, args, val_idx,
)
iou_acc += _iou_acc
input_images, labels, img_names, _ = data
dumpdict = dict(
- gt_images=labels,
- input_images=input_images,
- img_names=img_names,
- assets=assets,
+ gt_images=labels,
+ input_images=input_images,
+ img_names=img_names,
+ assets=assets,
)
if testing:
# prediction = assets['predictions'][0]
@@ -559,7 +531,7 @@ def validate(
if not gdfs:
poly_network = gpd.GeoDataFrame()
logging.warning(
- f'No polygons were dumped'
+ f'No polygons were dumped'
)
else:
poly_network = pd.concat(gdfs)
@@ -698,7 +670,7 @@ def map_features(
# write the segmentation to assets
future = self.threads.submit(np.save, tile.segmentation, src_img)
self.futures.append(future)
- result = super().map_features(tile, src_img, img_array=img_array)
+ result = super().map_features(tile, src_img, img_array=img_array)
for future in self.futures:
future.result()
return result
@@ -727,15 +699,15 @@ def inference(self, rasterfactory=None):
it_exists = threads.map(os.path.exists, paths)
predictions = threads.map(np.load, paths)
it_polygons = (
- self.Dumper.map_features(tile, prediction, img_array=True)
- for tile, prediction, exists in zip(grid.tiles.ravel(), predictions, it_exists)
- if exists
+ self.Dumper.map_features(tile, prediction, img_array=True)
+ for tile, prediction, exists in zip(grid.tiles.ravel(), predictions, it_exists)
+ if exists
)
gdfs = [
- polygons
- for polygons in it_polygons
- if polygons is not None
+ polygons
+ for polygons in it_polygons
+ if polygons is not None
]
logger.debug(f'{len(gdfs)} polygons dumped')
if not len(gdfs):
@@ -744,7 +716,7 @@ def inference(self, rasterfactory=None):
if not gdfs:
poly_network = gpd.GeoDataFrame()
logging.warning(
- f'No polygons were dumped'
+ f'No polygons were dumped'
)
else:
poly_network = pd.concat(gdfs)