Skip to content

Commit

Permalink
Mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
natelust committed Apr 12, 2023
1 parent ffadae8 commit f4ba318
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def getOutputSchema(self) -> KeyedDataSchema:

def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
results = {}
highMaskKey = f'{self.identity.lower() or ""}HighSNMask'
highMaskKey = f'{(self.identity or "").lower()}HighSNMask'
results[highMaskKey] = self.highSNSelector(data, **kwargs)

lowMaskKey = f'{self.identity.lower() or ""}LowSNMask'
lowMaskKey = f'{(self.identity or "").lower()}LowSNMask'
results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

prefix = f"{band}_" if (band := kwargs.get("band")) else ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def getOutputSchema(self) -> KeyedDataSchema:
return (
(self.name_mask, Vector),
(self.name_median, Scalar),
(self.name_sig_mad, Scalar),
(self.name_sigmaMad, Scalar),
(self.name_count, Scalar),
(self.name_select_maximum, Scalar),
(self.name_select_median, Scalar),
Expand Down
6 changes: 2 additions & 4 deletions python/lsst/analysis/tools/actions/vector/ellipticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,13 @@ class CalcEDiff(VectorAction):
the returned quantity therefore corresponds to |e|*exp(j*theta).
"""

colA = ConfigurableActionField(
colA = ConfigurableActionField[VectorAction](
doc="Ellipticity to subtract from",
dtype=VectorAction,
default=CalcE,
)

colB = ConfigurableActionField(
colB = ConfigurableActionField[VectorAction](
doc="Ellipticity to subtract",
dtype=VectorAction,
default=CalcE,
)

Expand Down
24 changes: 13 additions & 11 deletions python/lsst/analysis/tools/actions/vector/vectorActions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ class DownselectVector(VectorAction):

vectorKey = Field[str](doc="column key to load from KeyedData")

selector = ConfigurableActionField(doc="Action which returns a selection mask", default=VectorSelector)
selector = ConfigurableActionField[VectorAction](
doc="Action which returns a selection mask", default=VectorSelector
)

def getInputSchema(self) -> KeyedDataSchema:
yield (self.vectorKey, Vector)
Expand All @@ -84,7 +86,7 @@ class MultiCriteriaDownselectVector(VectorAction):
def getInputSchema(self) -> KeyedDataSchema:
yield (self.vectorKey, Vector)
for action in self.selectors:
yield from cast(VectorAction, action).getInputSchema()
yield from action.getInputSchema()

def __call__(self, data: KeyedData, **kwargs) -> Vector:
mask: Optional[Vector] = None
Expand Down Expand Up @@ -118,8 +120,8 @@ def __call__(self, data: KeyedData, **kwargs) -> Vector:
class FractionalDifference(VectorAction):
"""Calculate (A-B)/B"""

actionA = ConfigurableActionField(doc="Action which supplies vector A", dtype=VectorAction)
actionB = ConfigurableActionField(doc="Action which supplies vector B", dtype=VectorAction)
actionA = ConfigurableActionField[VectorAction](doc="Action which supplies vector A")
actionB = ConfigurableActionField[VectorAction](doc="Action which supplies vector B")

def getInputSchema(self) -> KeyedDataSchema:
yield from self.actionA.getInputSchema() # type: ignore
Expand Down Expand Up @@ -176,8 +178,8 @@ def __call__(self, data: KeyedData, **kwargs) -> Vector:
class SubtractVector(VectorAction):
"""Calculate (A-B)"""

actionA = ConfigurableActionField(doc="Action which supplies vector A", dtype=VectorAction)
actionB = ConfigurableActionField(doc="Action which supplies vector B", dtype=VectorAction)
actionA = ConfigurableActionField[VectorAction](doc="Action which supplies vector A")
actionB = ConfigurableActionField[VectorAction](doc="Action which supplies vector B")

def getInputSchema(self) -> KeyedDataSchema:
yield from self.actionA.getInputSchema() # type: ignore
Expand All @@ -192,8 +194,8 @@ def __call__(self, data: KeyedData, **kwargs) -> Vector:
class DivideVector(VectorAction):
"""Calculate (A/B)"""

actionA = ConfigurableActionField(doc="Action which supplies vector A", dtype=VectorAction)
actionB = ConfigurableActionField(doc="Action which supplies vector B", dtype=VectorAction)
actionA = ConfigurableActionField[VectorAction](doc="Action which supplies vector A")
actionB = ConfigurableActionField[VectorAction](doc="Action which supplies vector B")

def getInputSchema(self) -> KeyedDataSchema:
yield from self.actionA.getInputSchema() # type: ignore
Expand Down Expand Up @@ -289,8 +291,8 @@ class ExtinctionCorrectedMagDiff(VectorAction):
If band1 and band2 are supplied, the flux column names are ignored.
"""

magDiff = ConfigurableActionField(
doc="Action that returns a difference in magnitudes", default=MagDiff, dtype=VectorAction
magDiff = ConfigurableActionField[VectorAction](
doc="Action that returns a difference in magnitudes", default=MagDiff
)
ebvCol = Field[str](doc="E(B-V) Column Name", default="ebv")
band1 = Field[str](
Expand Down Expand Up @@ -395,7 +397,7 @@ class PerGroupStatistic(VectorAction):
"""

groupKey = Field[str](doc="Column key to use for forming groups", default="obj_index")
buildAction = ConfigurableActionField(doc="Action to build vector", default=LoadVector)
buildAction = ConfigurableActionField[VectorAction](doc="Action to build vector", default=LoadVector)
func = Field[str](doc="Name of function to be applied per group")

def getInputSchema(self) -> KeyedDataSchema:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _setActions(self) -> None:
# Need to pass a mapping of new names so the default names get the
# band prepended. Otherwise, each subsequent band's metric will
# overwrite the current one.
self.produce.newNames = {
self.produce.newNames = { # type: ignore
"validFracColumn": "{band}_validFracColumn",
"nanFracColumn": "{band}_nanFracColumn",
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _setActions(self) -> None:
# overwrite the current one (e.g., running with g, r bands without
# this, you would get "meanSky," "meanSky"; with it: "g_meanSky,"
# "r_meanSky").
self.produce.newNames = {
self.produce.newNames = { # type: ignore
"medianSky": "{band}_medianSky",
"meanSky": "{band}_meanSky",
"stdevSky": "{band}_stdevSky",
Expand Down
8 changes: 4 additions & 4 deletions python/lsst/analysis/tools/analysisParts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def getInputSchema(self) -> KeyedDataSchema:
else:
filterOutputSchema[fieldName] = Vector

for action in self.calculateActions:
for name, typ in action.getInputSchema():
for calcAction in self.calculateActions:
for name, typ in calcAction.getInputSchema():
if name not in buildOutputSchema and name not in filterOutputSchema:
inputSchema[name] = typ
return ((name, typ) for name, typ in inputSchema.items())
Expand Down Expand Up @@ -117,8 +117,8 @@ def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
results[name] = item

view2 = data | results
for name, action in self.calculateActions.items():
match action(view2, **kwargs):
for name, calcAction in self.calculateActions.items():
match calcAction(view2, **kwargs):
case abc.Mapping() as item:
for key, result in item.items():
results[key] = result
Expand Down
6 changes: 2 additions & 4 deletions python/lsst/analysis/tools/analysisParts/diffMatched.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@
class MatchedCoaddVisitConfig(Config):
"""Config for tools that can be applied in coadd and visit contexts."""

context = ChoiceField(
context = ChoiceField[str](
doc="The analysis context for this class",
dtype=str,
allowed={"coadd": "Coadded images", "visit": "Single visit images"},
optional=False,
)
Expand All @@ -61,9 +60,8 @@ class MatchedRefCoaddTool(AnalysisTool):
appropriate candidates (and stores a match_candidate column).
"""

context = ChoiceField(
context = ChoiceField[str](
doc="The type of metric to compute",
dtype=str,
allowed={"diff": "Measured - ref value", "chi": "(Measured - ref value)/sigma (a.k.a. chi/pull)"},
optional=False,
)
Expand Down
10 changes: 6 additions & 4 deletions python/lsst/analysis/tools/analysisPlots/analysisPlots.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ def _setActions(self) -> None:
self.process.buildActions.astromDiff = AstromDiff(
col1=f"coord_{coordStr}_target", col2=f"coord_{coordStr}_ref"
)
self.produce.yAxisLabel = f"${self.coordinate}_{{target}} - {self.coordinate}_{{ref}}$ (marcsec)"
tmpString = f"${self.coordinate}_{{target}} - {self.coordinate}_{{ref}}$ (marcsec)"
self.produce.yAxisLabel = tmpString # type: ignore

def setDefaults(self):
super().setDefaults()
Expand Down Expand Up @@ -381,11 +382,12 @@ def _setActions(self) -> None:
else:
raise ValueError(f"Unsupported {self.context=}")
coordStr = self.coordinate.lower()
self.process.buildActions.zStars = AstromDiff(
self.process.buildActions.zStars = AstromDiff( # type: ignore
col1=f"coord_{coordStr}_target", col2=f"coord_{coordStr}_ref"
)
self.produce.plotName = f"astromDiffSky_{self.coordinate}"
self.produce.zAxisLabel = f"${self.coordinate}_{{target}} - {self.coordinate}_{{ref}}$ (marcsec)"
self.produce.plotName = f"astromDiffSky_{self.coordinate}" # type: ignore
tmpString = f"${self.coordinate}_{{target}} - {self.coordinate}_{{ref}}$ (marcsec)"
self.produce.zAxisLabel = tmpString # type: ignore

def setDefaults(self):
super().setDefaults()
Expand Down
8 changes: 5 additions & 3 deletions python/lsst/analysis/tools/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ class AnalysisMetric(AnalysisTool):
setDefaults) to be `BaseMetricAction`.
"""

produce = ConfigurableActionField[MetricAction](doc="Action which returns a calculated Metric")
produce = ConfigurableActionField[MetricAction](
doc="Action which returns a calculated Metric"
) # type: ignore

def setDefaults(self):
super().setDefaults()
Expand All @@ -378,7 +380,7 @@ class AnalysisPlot(AnalysisTool):
it expects to be assigned to a `PlotAction`.
"""

produce = ConfigurableActionField[PlotAction](doc="Action which returns a plot")
produce = ConfigurableActionField[PlotAction](doc="Action which returns a plot") # type: ignore

def getOutputNames(self) -> Iterable[str]:
"""Return the names of the plots produced by this action.
Expand All @@ -392,7 +394,7 @@ def getOutputNames(self) -> Iterable[str]:
result : `tuple` of `str`
Names for each plot produced by this action.
"""
outNames = tuple(self.produce.getOutputNames())
outNames = tuple(self.produce.getOutputNames()) # type: ignore
if outNames:
return (f"{self.identity or ''}_{name}" for name in outNames)
else:
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/analysis/tools/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def collectInputNames(self) -> Iterable[str]:
for name, action in self.config.plots.items():
for column, dataType in action.getFormattedInputSchema(band=band):
inputs.add(column)
for name, action in self.config.metrics.items():
for column, dataType in action.getFormattedInputSchema(band=band):
for name, metricAction in self.config.metrics.items():
for column, dataType in metricAction.getFormattedInputSchema(band=band):
inputs.add(column)
return inputs
15 changes: 8 additions & 7 deletions python/lsst/analysis/tools/tasks/catalogMatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
StarSelector,
VisitPlotFlagSelector,
)
from ..interfaces import VectorAction


class AstropyMatchConfig(pexConfig.Config):
Expand Down Expand Up @@ -160,19 +161,19 @@ class CatalogMatchConfig(pipeBase.PipelineTaskConfig, pipelineConnections=Catalo
doc="Band to use when selecting objects, primarily for extendedness", default="i"
)

selectorActions = ConfigurableActionStructField(
selectorActions = ConfigurableActionStructField[VectorAction](
doc="Which selectors to use to narrow down the data for QA plotting.",
default={"flagSelector": CoaddPlotFlagSelector},
default={"flagSelector": CoaddPlotFlagSelector()},
)

sourceSelectorActions = ConfigurableActionStructField(
sourceSelectorActions = ConfigurableActionStructField[VectorAction](
doc="What types of sources to use.",
default={"sourceSelector": StarSelector},
default={"sourceSelector": StarSelector()},
)

extraColumnSelectors = ConfigurableActionStructField(
extraColumnSelectors = ConfigurableActionStructField[VectorAction](
doc="Other selectors that are not used in this task, but whose columns" "may be needed downstream",
default={"selector1": SnSelector, "selector2": GalaxySelector},
default={"selector1": SnSelector(), "selector2": GalaxySelector()},
)

extraColumns = pexConfig.ListField[str](
Expand Down Expand Up @@ -350,7 +351,7 @@ class CatalogMatchVisitConnections(
class CatalogMatchVisitConfig(CatalogMatchConfig, pipelineConnections=CatalogMatchVisitConnections):
selectorActions = ConfigurableActionStructField(
doc="Which selectors to use to narrow down the data for QA plotting.",
default={"flagSelector": VisitPlotFlagSelector},
default={"flagSelector": VisitPlotFlagSelector()},
)

extraColumns = pexConfig.ListField[str](
Expand Down

0 comments on commit f4ba318

Please sign in to comment.