Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added code to detect bad AWS workers #53

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MARS_Developer
Submodule MARS_Developer added at 08d8ef
4 changes: 2 additions & 2 deletions MARS_behavior_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "mars",
"display_name": "Python [conda env:mars_dev] *",
"language": "python",
"name": "mars"
"name": "conda-env-mars_dev-py"
},
"language_info": {
"codemirror_mode": {
Expand Down
152 changes: 86 additions & 66 deletions MARS_pose_tutorial.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion create_new_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def create_new_project(location, name, download_MARS_checkpoints=True, download_

if download_MARS_checkpoints:
ckpts_name = 'MARS_v1_8_models'
ckpts_id = '1NyAuwI6iQdMgRB2w4zX44yFAgEkux4op'
ckpts_id = '1NyAuwI6iQdMgRB2w4zX44yFAgEkux4op&confirm=t&uuid=3e6e6742-0b74-4de4-9058-f2f94a8aa7e5&at=ALgDtszLJO_HlgzepoZYqu-Tfdc6:1674927765818'
# names of the models we want to unpack:
search_keys = ['detect*black*', 'detect*white*', 'detect*resnet*', 'pose*']
# where we're unpacking them to:
Expand Down
4 changes: 2 additions & 2 deletions pose_annotation_tools/Submit_Labeling_Job.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "dlc",
"display_name": "Python [conda env:mars_dev] *",
"language": "python",
"name": "dlc"
"name": "conda-env-mars_dev-py"
},
"language_info": {
"codemirror_mode": {
Expand Down
8 changes: 4 additions & 4 deletions pose_annotation_tools/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def compute_human_PCK(project, animal_names=None, xlim=None, pixel_units=False):
Y = np.array(frame['ann_' + animal]['Y']) * D[0]['height']
trial_dists = []
for i, [pX, pY] in enumerate(zip(X, Y)):
mX = np.median(np.delete(X, i, axis=0), axis=0)
mY = np.median(np.delete(Y, i, axis=0), axis=0)
mX = np.nanmedian(np.delete(X, i, axis=0), axis=0)
mY = np.nanmedian(np.delete(Y, i, axis=0), axis=0)
trial_dists.append(np.sqrt(np.square(mX - pX) + np.square(mY - pY)))
trial_dists = np.array(trial_dists)

dMean[:, fr] = np.mean(trial_dists, axis=0)
dMedian[:, fr] = np.median(trial_dists, axis=0)
dMean[:, fr] = np.nanmean(trial_dists, axis=0)
dMedian[:, fr] = np.nanmedian(trial_dists, axis=0)
dMin[:, fr] = np.min(trial_dists, axis=0)
dMax[:, fr] = np.max(trial_dists, axis=0)

Expand Down
119 changes: 108 additions & 11 deletions pose_annotation_tools/json_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,102 @@ def apply_flip_correction(frame, meds, keypoints, pair):
if (d1[0] > d1[1]) and (d2[1] > d2[0]):
frame[[i1, i2], :, w] = frame[[i2, i1], :, w]

# re-compute the medians:
meds = np.median(frame, axis=2)
# re-compute the medians:
meds = np.nanmedian(frame, axis=2)

return frame, meds

def create_mask_bad_wks(data, workers2del, nKpts, nWorkers):
    """Build a mask flagging annotations from unreliable workers.

    Parameters
    ----------
    data : dict
        One frame's entry from the output manifest; worker IDs are read from
        data['annotatedResult']['annotationsFromAllWorkers'].
    workers2del : list
        Worker IDs whose annotations should be masked out.
    nKpts : int
        Number of keypoints per animal.
    nWorkers : int
        Number of worker slots in the annotation array for this frame.

    Returns
    -------
    np.ndarray of shape (nKpts, 2, nWorkers): 1 where the worker is in
    workers2del (annotation should be masked), 0 where the worker is valid.
    Suitable as the `mask` argument of np.ma.array.
    """
    # Start with everything masked, then clear columns for valid workers.
    mask = np.ones((2, nWorkers))

    workers_list = data['annotatedResult']['annotationsFromAllWorkers']
    for i, worker in enumerate(workers_list):
        if worker['workerId'] not in workers2del:
            mask[:, i] = 0  # valid worker: unmask both x and y rows

    # Replicate the (2, nWorkers) mask across all keypoints.
    return np.repeat(mask[np.newaxis, :, :], nKpts, axis=0)

def hold_one_out_dist(fr_df, frame_num):
    """Hold-one-out consensus distance for each worker in one frame.

    For every worker, compute the Euclidean distance between that worker's
    annotation and the median of all *other* workers' annotations, separately
    per animal color.

    Parameters
    ----------
    fr_df : pd.DataFrame
        One frame's annotations with columns ordered so that 'x':'color'
        slices to [x, y, color], plus a 'worker ID' column.
    frame_num : int
        Frame index, recorded in the output for bookkeeping.

    Returns
    -------
    pd.DataFrame with columns [color, distance, worker, frame], one row per
    (worker, color) pair. Rows for later-processed workers appear first.
    """
    cols = ['x', 'y']
    all_wkrs_dist = pd.DataFrame()
    for wk in fr_df['worker ID'].unique():
        # This worker's points, indexed by color so subtraction below aligns.
        df_worker = fr_df[fr_df['worker ID'].eq(wk)] \
            .loc[:, 'x':'color'].set_index(['color']).sort_index()

        # Median of everyone else's points per color, then distance to this
        # worker. Scalars in assign broadcast to every row, so this works for
        # any number of colors (the original hard-coded 2).
        dist_df = fr_df[~fr_df['worker ID'].eq(wk)] \
            .loc[:, 'x':'color'] \
            .groupby(['color']) \
            .median().sort_index() \
            .assign(
                distance=lambda df_:
                np.linalg.norm(df_[cols] - df_worker[cols], axis=1)) \
            .loc[:, 'distance'].reset_index() \
            .assign(worker=wk, frame=frame_num)
        all_wkrs_dist = pd.concat([dist_df, all_wkrs_dist], ignore_index=True)

    return all_wkrs_dist

def detect_bad_workers(project, manifest_file, perc_dist=0.75, per_del=0.15, thr_dist=None):
    """Identify AWS annotation workers whose labels disagree with consensus.

    Parses the SageMaker output manifest, computes each worker's hold-one-out
    distance from the other workers' median nose position (black and white
    mouse), and flags workers for whom more than `per_del` of annotations fall
    beyond the distance threshold.

    Parameters
    ----------
    project : str
        Project root; the manifest is read from <project>/annotation_data/.
    manifest_file : str
        Name of the output manifest file.
    perc_dist : float
        Quantile of all hold-one-out distances used as the outlier threshold.
        Ignored when thr_dist is provided.
    per_del : float
        Fraction of outlier annotations above which a worker is flagged.
    thr_dist : float or None
        Absolute distance threshold; when given, overrides perc_dist.

    Returns
    -------
    list of worker IDs considered unreliable.

    Raises
    ------
    json.JSONDecodeError if the patched manifest still fails to parse.
    """
    with open(os.path.join(project, 'annotation_data', manifest_file), 'r') as f:
        st = f.read()

    # The manifest is newline-delimited JSON with the annotation payload
    # escaped as a string; patch it into one parseable JSON array.
    st = "[\n" + st + "\n]"
    st = st.replace("\\", "").replace("\"{", "{") \
        .replace("}\"}", "}}").replace("{\"source", ",{\"source") \
        .replace(',', '', 1)
    # Parse once and let a failure propagate. The previous implementation
    # retried json.loads in a `while True` loop, printing the same error
    # forever on malformed input.
    output_manifest = json.loads(st)

    frames_dist_df = pd.DataFrame()
    for fr_id, data in enumerate(output_manifest):

        frame = data['annotatedResult']['annotationsFromAllWorkers']
        fr_df = pd.DataFrame()

        for worker_content in frame:
            worker = worker_content['annotationData']['content']['annotatedResult']
            k_pts = worker['keypoints']
            pos_blk_nose = [(p['x'], p['y']) for p in k_pts if p['label'] == 'black mouse nose']
            pos_wht_nose = [(p['x'], p['y']) for p in k_pts if p['label'] == 'white mouse nose']

            if not pos_blk_nose or not pos_wht_nose:
                continue  # worker didn't label both noses; skip instead of crashing

            wht_info = pd.DataFrame([{'x': pos_wht_nose[0][0],
                                      'y': pos_wht_nose[0][1],
                                      'color': 'white',
                                      'worker ID': worker_content['workerId']}])
            blk_info = pd.DataFrame([{'x': pos_blk_nose[0][0],
                                      'y': pos_blk_nose[0][1],
                                      'color': 'black',
                                      'worker ID': worker_content['workerId']}])

            fr_df = pd.concat([fr_df, wht_info], ignore_index=True)
            fr_df = pd.concat([fr_df, blk_info], ignore_index=True)

        frames_dist_df = pd.concat([frames_dist_df, hold_one_out_dist(fr_df, fr_id)],
                                   ignore_index=True)

    workers_total_annotations = frames_dist_df.worker.value_counts()
    if thr_dist is None:
        thr_dist = frames_dist_df['distance'].quantile(perc_dist)
    # Flag workers whose share of beyond-threshold annotations exceeds per_del.
    bad_workers_list = frames_dist_df[frames_dist_df['distance'] > thr_dist] \
        .groupby(['worker']) \
        .size() \
        .div(workers_total_annotations) \
        .loc[lambda s: s > per_del] \
        .index.tolist()

    return bad_workers_list



def manifest_to_dict(project):
"""
Expand Down Expand Up @@ -70,6 +161,10 @@ def manifest_to_dict(project):
if verbose:
print('Processing manifest file...')
print(len(data))

print('Detecting bad workers')
bad_workers_list = detect_bad_workers(project, manifest_file, thr_dist = 300, per_del = 0.15)

for f, sample in enumerate(data):
if f and not f % 1000 and verbose:
print(' frame '+str(f))
Expand Down Expand Up @@ -107,11 +202,13 @@ def manifest_to_dict(project):

rawPts[animal][part, 0, w] = pt['x']/im.shape[1]
rawPts[animal][part, 1, w] = pt['y']/im.shape[0]


mask_bad_wks = create_mask_bad_wks(sample, bad_workers_list, nKpts, nWorkers[f])
for animal in animal_names:

rawPts[animal] = np.ma.array(rawPts[animal], mask = mask_bad_wks).filled(np.nan)

if check_pairs: # adjust L/R assignments to try to find better median estimates.
meds = np.median(rawPts[animal], axis=2)
meds = np.nanmedian(rawPts[animal], axis=2)
for pair in check_pairs:
rawPts[animal], meds = apply_flip_correction(rawPts[animal], meds, keypoint_names, pair)

Expand Down Expand Up @@ -220,14 +317,14 @@ def make_animal_dict(pts, im_shape):
X = pts[:, 0, :].T
Y = pts[:, 1, :].T

mX = np.median(X, axis=0)
mY = np.median(Y, axis=0)
mX = np.nanmedian(X, axis=0)
mY = np.nanmedian(Y, axis=0)

muX = np.mean(X, axis=0)
muY = np.mean(Y, axis=0)
muX = np.nanmean(X, axis=0)
muY = np.nanmean(Y, axis=0)

stdX = np.std(Y, axis=0)
stdY = np.std(Y, axis=0)
stdX = np.nanstd(Y, axis=0)
stdY = np.nanstd(Y, axis=0)

# bounding box
Bxmin = min(mX)
Expand Down
2 changes: 2 additions & 0 deletions pose_annotation_tools/restore_images_from_tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

def restore(tfrecords_filenames, output_path):
# this is a script to extract images from tfrecord files, for sanity-checking.
if not isinstance(tfrecords_filenames, list):
tfrecords_filenames = [tfrecords_filenames]
f = tfrecords_filenames
totalFiles = 0

Expand Down