Skip to content

Commit

Permalink
Results, removing y-ticks
Browse files Browse the repository at this point in the history
  • Loading branch information
Anirudh Lakra committed Mar 31, 2024
1 parent 18f85e3 commit cdd9540
Show file tree
Hide file tree
Showing 5 changed files with 2,673 additions and 17 deletions.
2,615 changes: 2,615 additions & 0 deletions kymata/ippm/Results.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion kymata/ippm/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def build_graph(self,
# While doing this, if we encounter empty pairings for f or any parent, we skip them.

if f in inputs:
# input node default size is 10.
# input node default size is 100.
hier.pop(f)
graph[f] = Node(100, (0, 1 - n_partitions * part_idx), [])

Expand Down
38 changes: 34 additions & 4 deletions kymata/ippm/data_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def stem_plot(

plt.show()

def causality_violation_score(denoised_hexels: Dict[str, IPPMHexel], hierarchy: Dict[str, List[str]], hemi: str):
def causality_violation_score(denoised_hexels: Dict[str, IPPMHexel], hierarchy: Dict[str, List[str]], hemi: str, inputs: List[str]) -> Tuple[float, int, int]:
"""
Assumption: hexels are denoised. Otherwise, it doesn't really make sense to check the min/max latency of noisy hexels.
Expand Down Expand Up @@ -249,14 +249,44 @@ def get_latency(func_hexels: IPPMHexel, mini: bool):
for func, inc_edges in hierarchy.items():
# essentially: if max(parent_spikes_latency) > min(child_spikes_latency), there will be a backwards arrow in time.
# arrows go from latest inc_edge spike to the earliest func spike
child_latency = get_latency(denoised_hexels[func], mini=True)

if func in inputs:
continue

if hemi == 'leftHemisphere':
if len(denoised_hexels[func].left_best_pairings) == 0:
continue
else:
if len(denoised_hexels[func].right_best_pairings) == 0:
continue

child_latency = get_latency(denoised_hexels[func], mini=True)[0]
for inc_edge in inc_edges:
parent_latency = get_latency(denoised_hexels[inc_edge], mini=False)
if inc_edge in inputs:
# input node, so parent latency is 0
parent_latency = 0
if child_latency < parent_latency:
causality_violations += 1
total_arrows += 1
continue

# We need to ensure the function has significant spikes
if hemi == 'leftHemisphere':
if len(denoised_hexels[inc_edge].left_best_pairings) == 0:
continue
else:
if len(denoised_hexels[inc_edge].right_best_pairings) == 0:
continue

parent_latency = get_latency(denoised_hexels[inc_edge], mini=False)[0]
if child_latency < parent_latency:
causality_violations += 1
total_arrows += 1

return causality_violations / total_arrows if total_arrows != 0 else 0
return (
causality_violations / total_arrows if total_arrows != 0 else 0,
causality_violations,
total_arrows)

def function_recall(noisy_hexels: Dict[str, IPPMHexel], funcs: List[str], ippm_dict: Dict[str, Node], hemi: str) -> Tuple[float]:
"""
Expand Down
7 changes: 5 additions & 2 deletions kymata/ippm/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from sklearn.cluster import DBSCAN as DBSCAN_, MeanShift as MeanShift_
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import normalize
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning

#import multiprocessing

Expand Down Expand Up @@ -679,7 +681,8 @@ def __init__(self, **kwargs):

if invalid:
raise ValueError


@ignore_warnings(category=ConvergenceWarning)
def cluster(
self, hexels: Dict[str, IPPMHexel], hemi: str, normalise: bool = False,
cluster_latency: bool = False, posterior_pooling: bool = False
Expand Down Expand Up @@ -708,7 +711,7 @@ def cluster(

if len(df) == 1:
# no point clustering, just return the single data point.
ret = [(df.iloc[0, 'Latency'], df.iloc[0, 'Mag'])]
ret = [(df.iloc[0, 0], df.iloc[0, 1])]
hexels = super()._update_pairings(hexels, func, ret, hemi)
continue

Expand Down
28 changes: 18 additions & 10 deletions kymata/ippm/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def draw(self,
width
"""
# first lets aggregate all of the information.
# TODO: refactor to generate BSplines in the first loop, so we dont have to loop again.
hexel_x = [_ for _ in range(len(graph.keys()))] # x coordinates for nodes e.g., (x, y) = (hexel_x[i], hexel_y[i])
hexel_y = [_ for _ in range(len(graph.keys()))] # y coordinates for nodes
node_colors = [_ for _ in range(len(graph.keys()))] # color for nodes
Expand Down Expand Up @@ -61,25 +60,34 @@ def draw(self,
bsplines += self._make_bspline_paths(pairs)

fig, ax = plt.subplots()
fig.set_figheight(figheight)
fig.set_figwidth(figwidth)
plt.axis('on')

for path, color in zip(bsplines, edge_colors):
ax.plot(path[0], path[1], color=color, linewidth='3', zorder=-1)
ax.scatter(x=hexel_x, y=hexel_y, c=node_colors, s=node_sizes, zorder=1)
ax.tick_params(bottom=True, labelbottom=True, left=False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.set_xlabel('Latency (ms)')

ax.scatter(x=hexel_x, y=hexel_y, c=node_colors, s=node_sizes, zorder=1)

legend = []
for f in colors.keys():
legend.append(Line2D([0], [0], marker='o', color='w', label=f, markerfacecolor=colors[f], markersize=15))

plt.legend(handles=legend, loc='upper left')
plt.title(title)

# for some unknown reason, I can set the x-axis ticks to be off but if I try with
# the y-axis it ignores the changes :D
ax.set_ylim(min(hexel_y) - 0.1, max(hexel_y) + 0.1)
ax.set_yticklabels([])
ax.yaxis.set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.set_xlabel('Latency (ms)')

fig.set_figheight(figheight)
fig.set_figwidth(figwidth)

plt.show()

def _make_bspline_paths(self, hexel_coordinate_pairs: List[List[Tuple[float, float]]]) -> List[List[np.array]]:
"""
Given a list of hexel positions pairs, return a list of
Expand Down

0 comments on commit cdd9540

Please sign in to comment.