Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6 implement vlbi fringe fit #19

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
188 changes: 188 additions & 0 deletions src/astroviper/calibration/apply_fringe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from xradio.vis.read_processing_set import read_processing_set

import dask
import numpy as np
import xarray as xa
import pandas as pd
import datetime

# I am surprised this is not some kind of standard function, but the
# pandas version scolds me for using it on a string.
#
# This implementation is O(n^2) but if that is ever an issue we are
# doing something else very wrong.
def unique(s):
"Get the unique characters in a string in the order they occur in the original"
u = []
for c in s:
if c not in u:
u.append(c)
return u

def nanCount(j):
return np.sum(np.isnan(j))

def numberCount(j):
return np.sum(~np.isnan(j))


def makeCalTable(xds):
"An attempt to make a calibration table out of coordinates"
pols_ant = unique(''.join([c for c in ''.join(xds.polarization.values)]))
coords = xa.Coordinates(coords={'time' : xa.Coordinates(coords = {'time' : 0.5*(xds.time[0]+xds.time[-1])}),
'antenna_name' : xds.antenna_xds.antenna_name,
'polarization' : pols_ant,
'parameter' : ['one', 'two', 'three', 'ah-ha']
})
cds = xa.Dataset(data_vars = dict(cals=(coords.sizes.keys(), np.zeros(tuple(coords.sizes.values()), complex))),
coords=coords)
return cds

#############################################################################
# How to make a (time, frequency) grid:
#
# np.expand_dims(dt, 1) + np.expand_dims(df, 0) => array of shape (nt, nf)
#############################################################################

class GridJonesCalculator(object):
def __init__(self, xds):
"""
"""
# We scrape all the metadata from an xds. Maybe this is wise, maybe not.
self.xds = xds
self.frequency = xds.frequency
self.time = xds.time
self.baseline_id = xds.baseline_id
self.VISIBILITY = xds.VISIBILITY
self.n_ants = xds.antenna_xds.antenna_name.size
# We expand out a copy of the baseline_antenna1_id array to have shape
# (1, n_baselines, 1, 1, 1)
self.ant1_mask = np.expand_dims(xds.baseline_antenna1_name.values, (0, 2, 3, 4))
self.ant2_mask = np.expand_dims(xds.baseline_antenna2_name.values, (0, 2, 3, 4))
self.makeAccumulatedJoneses()

def makeAccumulatedJoneses(self):
vcs = self.VISIBILITY.shape
# We are going to use 2x2 matrices for our Jones matrices because that's how they multiply
assert vcs[-1] == 4 # We'll figure other cases out later
new_shape = vcs[:-1] + (2, 2)
self.j_a1_composed = np.zeros(new_shape, complex)
self.j_a1_composed = np.identity(2)
self.j_a2_composed = np.ones(new_shape, complex)
self.j_a2_composed = np.identity(2)

def insertBaselineDimension(self, j, nbaselines):
"""We add a baseline dimension, but we also broadcast to it"""
# Insert a baseline dimension to a calibration matrix.
j_shape = j.shape
new_shape = j_shape[:1] + (nbaselines,) + j_shape[1:]
j = np.expand_dims(j, 1)
j = np.broadcast_to(j, new_shape)
return j

def calcGridJonesAnt(self, fp, df, dt):
"""Return Jones matrices for all grid points of an xds for a single set of fringefit parameters (which come in pairs one for each polarization)"""
# We now assume phi0, tau and r are 2-vectors
phi0, tau, r = fp
# We upscale the dimensions so that things broadcast nicely:
df_shaped = np.expand_dims(df, (0, 2))
dt_shaped = np.expand_dims(dt, (1, 2))
phi_shaped = np.expand_dims(phi0, (0, 1))
# Calculate phases:
phi = (phi_shaped +
2*np.pi*tau.values*df_shaped +
2*np.pi*r.values*dt_shaped)
# And then phasors.
# (I spent a long time trying to express this in a neat numpy way.
# I did not succeed, so I do it the ugly stupid way for now.)
many_jones_diags = np.exp(1J*phi, dtype=complex)
many_jones = np.zeros(many_jones_diags.shape + (2,), complex)
many_jones[:, :, 0, 0] = many_jones_diags[:, :, 0]
many_jones[:, :, 1, 1] = many_jones_diags[:, :, 1]
return many_jones

def calcGridJones(self, cal_quantum):
dt = (self.time - cal_quantum.t_ref).values
df = (self.frequency - cal_quantum.f_ref).values
# This needs fixed too
for iant, ant in enumerate(cal_quantum.coords['antenna_name'].values):
fp = cal_quantum.sel(antenna_name=ant)
params = np.sum(~np.isnan(fp.values))
print(f"{ant=} {params=}")
if params == 0:
continue
print(f"{fp.values}")
j = self.calcGridJonesAnt(fp, df, dt)
count = np.sum(~np.isnan(j))
print(f"{count=}")
print(f"{np.max(np.abs(j.flatten()))=}")
j = self.insertBaselineDimension(j, self.xds.baseline_id.size)
# Then we can make a version of our baseline jones matrices that only affects a specific ant1:
j_1_mask = np.where(self.ant1_mask != ant, j, 1)
j_2_mask = np.where(self.ant2_mask != ant, j, 1)
# The second array needs to be hermitianized on the last two axes, which we have to do by hand
j_a2 = j.transpose(0, 1, 2, 4, 3).conj()
# Then we can apply those entries to our corrected data array by multiplication:
# First antenna corrected by multiplication from the left:
print(f"{np.max(np.abs(self.j_a1_composed.flatten()))=}")
self.j_a1_composed = np.matmul(j, self.j_a1_composed)
# Second antenna in baseline corrected (by Hermitian matrix) from the right
self.j_a2_composed = np.matmul(self.j_a2_composed, j)
if False:
print(f"{numberCount(self.j_a1_composed)=}")
print(f"{nanCount(self.j_a1_composed)=}")
print(f"{numberCount(self.j_a2_composed)=}")
print(f"{nanCount(self.j_a2_composed)=}")
print(f"{np.max(np.abs(self.j_a1_composed.flatten()))=}")


# We need to consult the data for polarizations now.
ps = read_processing_set('n14c3.zarr')
ps.keys()

# Current version of this ps is split by SPW and not by field
xds = ps['n14c3_099']

# In fact, all xdses have the same polarization setup here, but whomst can say if that is always true?
# Actually, I think maybe we could?
pols_ant = unique(''.join([c for c in ''.join(xds.polarization.values)]))
quantumCoords = xa.Coordinates(coords={'antenna_name' : xds.antenna_xds.antenna_name,
'parameter' : range(3),
'polarization' : pols_ant
})
q = xa.DataArray(coords=quantumCoords)

q.attrs['f_ref'] = xds.frequency[0]
q.attrs['t_ref'] = xds.time[0]

# Note that the f_ref attr copies over a lot of metadata.
# Which I think is a good thing?
q.attrs['f_ref'].attrs['spectral_window_name']


# You can't assign to DataArrays by name, only by integer index.
q[0] = [[0.0, 0.0], [1.0e-9,-1.0e-9], [0, 0]]

gjc = GridJonesCalculator(xds)
gjc.calcGridJones(q)

def squareUpLastDimension(v):
s = v.shape[:-1] + (2,2)
v2 = np.reshape(v, s)
return v2




v = squareUpLastDimension(xds.VISIBILITY.values)
# And this is the payoff I guess
v2 = gjc.j_a1_composed @ v @ gjc.j_a2_composed

# This works, although we should do better along the polarization axis.
xds.assign({'FROBBED' : xa.DataArray( coords=(xds.time, xds.baseline_id, xds.frequency, pols_ant, pols_ant), data=v2)})

# Isn't there meant to be a nice way to get spw now?
#>>> xds.partition_info['spectral_window_name']

# We can also now do this at ps level:
ps2 = ps.sel(spw_name='spw_0')
29 changes: 29 additions & 0 deletions src/astroviper/calibration/exercise_fringe_sbd2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from xradio.vis.read_processing_set import read_processing_set
from graphviper.graph_tools.coordinate_utils import (interpolate_data_coords_onto_parallel_coords,
make_parallel_coord)
from graphviper.graph_tools.generate_dask_workflow import generate_dask_workflow
from graphviper.graph_tools.coordinate_utils import make_time_coord
from graphviper.graph_tools.coordinate_utils import make_frequency_coord

from astroviper.calibration.fringefit import fringefit_single
import dask
import xarray as xa

ps = read_processing_set('n14c3.zarr')
ps.keys()

xds = ps['n14c3_000']

#
meas = make_time_coord(time_start='2014-10-22 13:18:00', time_delta=120, n_samples=2)
parallel_coords = {}
parallel_coords['baseline_id'] = make_parallel_coord(
coord=xds.baseline_id, n_chunks=1)
parallel_coords['time'] = make_parallel_coord(meas, n_chunks=1)
node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(parallel_coords,
ps, ps_partition=['spectral_window_name'])
subsel = {'polarization': 'LL'}
res = fringefit_single(ps, node_task_data_mapping, subsel)

# print(res)

144 changes: 144 additions & 0 deletions src/astroviper/calibration/fringefit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import numpy as np
import dask
import xarray as xr
from graphviper.graph_tools.map import map
from graphviper.graph_tools.reduce import reduce
from graphviper.graph_tools.generate_dask_workflow import generate_dask_workflow
from typing import Dict, Union

from xradio.vis.read_processing_set import read_processing_set

import dask
import numpy as np
import xarray as xa
import pandas as pd
import datetime

## I should figure out where this belongs at some point
def unique(s):
"Get the unique characters in a string in the order they occur in the original"
u = []
for c in s:
if c not in u:
u.append(c)
return u




def getFourierSpacings(xds):
f = xds.frequency.values
df = (f[-1] - f[0])/(len(f)-1)
dF = len(f)*df
ddelay = 1/dF
#
t = xds.time.values
dt = (t[-1] - t[0])/(len(t)-1)
dT = len(t) * dt
drate = 1/dT
return (ddelay, drate)

def makeCalArray(xds, ref_ant):
pols_ant = unique(''.join([c for c in ''.join(xds.polarization.values)]))
quantumCoords = xa.Coordinates(coords={'antenna_name' : xds.antenna_xds.antenna_name,
'polarization' : pols_ant,
'parameter' : range(3)
})
q = xa.DataArray(coords=quantumCoords)
ref_freq = xds.frequency.reference_frequency['data']
# Should we choose this reference time?
ref_time = xds.time[0]
q.attrs['reference_frequency'] = ref_freq
q.attrs['reference_time'] = ref_time
q.attrs['reference_antenna'] = ref_ant
return q

def _fringe_node_task(input_params: Dict):
ps = input_params['ps']
data_selection = input_params['data_selection']
ref_ant = input_params['ref_ant']
# FIXME: for now we do single band
if len(data_selection.keys()) > 1:
print(f'{data_selection.keys()=}')
raise RuntimeError("We only do single xdses so far")
name = list(data_selection.keys())[0]
xds = ps[name]
q = makeCalArray(xds, ref_ant)
data_sub_selection = input_params['data_sub_selection']
pols = data_sub_selection['polarization']
# FIXME!
pol = pols[0]
xds2 = xds.isel(**data_selection[name])
xds2 = xds2.sel(polarization=pols)
ddelay, drate = getFourierSpacings(xds2)
vis = xds2.VISIBILITY
ang = np.angle(vis)
nvis = np.exp(1J*ang)
# Zero the NaNs
nvis = np.where(np.isnan(vis), 0, nvis)
fftvis = np.fft.fftshift(
np.fft.fft2(
nvis,
axes=(0,2)
),
axes=(0,2)
)
bl_slice = data_selection[name]["baseline_id"]
baselines = xds2.baseline_id[bl_slice].values
ant1s = xds2.baseline_antenna1_name.values
ant2s = xds2.baseline_antenna2_name.values
try:
for i, (bl, ant1, ant2) in enumerate(zip(baselines, ant1s, ant2s)):
if ref_ant not in [ant1, ant2]:
# print(f"Skipping {ant1}-{ant2}")
continue
if ref_ant == ant1 and ref_ant==ant2:
print("Skipping autos")
# print(f"{ant1}-{ant2}")
ant = ant1 if (ant2 == ref_ant) else ant2
spw = xds.partition_info['spectral_window_name']
t = xds.time[0].values
print(f"{ant} {spw} {t}")
a = np.abs(fftvis[:, i, :])
ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
# breakpoint()
ix, iy = ind
phi0 = np.angle(a[ind])
delay = ix*ddelay
ref_freq = xds.frequency.reference_frequency['data']
rate = iy*drate/ref_freq
q.loc[dict(antenna_name=ant, polarization=pol)] = [phi0, delay, rate]
except IndexError as e:
print(f'{xds2.baseline_antenna1_name.values}\n{baselines=}')
raise e
return q

def _fringefit_reduce(graph_inputs: xr.Dataset, input_params: Dict):
merged = {}
for e in graph_inputs:
[t] = e.keys()
rhs = e[t]
if t in merged:
merged[t].update(rhs)
else:
merged[t] = rhs
return merged


def fringefit_single(ps, node_task_data_mapping: Dict, sub_selection: Dict, ref_ant: int):
"""
TODO!
"""
input_params = {}
input_params['data_sub_selection'] = sub_selection
input_params['ps'] = ps
input_params['ref_ant'] = ref_ant
graph = map(
input_data = ps,
node_task_data_mapping = node_task_data_mapping,
node_task = _fringe_node_task,
input_params = input_params,
in_memory_compute=False)
dask_graph = generate_dask_workflow(graph)
res = dask.compute(dask_graph)
return res
Loading