Skip to content

Commit

Permalink
add option to plot probabilities with stacked bars
Browse files Browse the repository at this point in the history
  • Loading branch information
bjarthur committed Sep 30, 2024
1 parent 4edeef7 commit 00d23d6
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 12 deletions.
1 change: 1 addition & 0 deletions configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
gui_spectrogram_low_hz=0
gui_spectrogram_high_hz=1250
gui_spectrogram_clip=[1,99]
gui_probability_style="lines" # either "lines" or "bars"

# neural network architecture to use
architecture_plugin="convolutional"
Expand Down
3 changes: 3 additions & 0 deletions src/gui/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def init(_bokeh_document, _configuration_file, _use_aitch):
global time_scale, freq_scale, context_time_scale, context_freq_scale
global context_waveform_low, context_waveform_high, label_colors
global spectrogram_colormap, spectrogram_clip, spectrogram_window, spectrogram_length_sec, spectrogram_overlap, spectrogram_low_hz, spectrogram_high_hz
global probability_style
global overlapped_prefix
global deterministic
global context_width_sec0, context_offset_sec0
Expand Down Expand Up @@ -376,6 +377,8 @@ def is_local_server_or_cluster(varname, varvalue):
context_freq_units = gui_context_freq_units
context_freq_scale = gui_context_freq_scale

probability_style=gui_probability_style

spectrogram_clip=gui_spectrogram_clip
spectrogram_colormap=gui_spectrogram_colormap
spectrogram_window=gui_spectrogram_window
Expand Down
58 changes: 46 additions & 12 deletions src/gui/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import sys
from bokeh.models.widgets import RadioButtonGroup, TextInput, Button, Div, DateFormatter, TextAreaInput, Select, NumberFormatter, Slider, Toggle, ColorPicker, MultiSelect, Paragraph
from bokeh.models.formatters import FuncTickFormatter
from bokeh.models import ColumnDataSource, TableColumn, DataTable, LayoutDOM, Span
from bokeh.models import ColumnDataSource, TableColumn, DataTable, LayoutDOM, Span, HoverTool
from bokeh.plotting import figure
from bokeh.transform import linear_cmap
from bokeh.transform import linear_cmap, stack
from bokeh.events import Tap, DoubleTap, PanStart, Pan, PanEnd, ButtonClick, MouseWheel
from bokeh.models.callbacks import CustomJS
from bokeh.models.markers import Circle
Expand Down Expand Up @@ -1130,9 +1130,19 @@ def context_update():
if xwav and not np.isnan(xwav[0][-1]):
spectrogram_range_source.data.update(x=[xwav[0][-1]])

probability_source.data.update(xs=xprob, ys=yprob,
colors=[M.label_colors[x] for x in M.used_labels],
labels=list(M.used_labels))
if M.probability_style=="lines":
probability_source.data.update(xs=xprob, ys=yprob,
colors=[M.label_colors[x] for x in M.used_labels],
labels=list(M.used_labels))
else:
if len(xprob)>1:
probability_source.data.update(**{'x'+str(i):(xprob[i] if i<len(xprob) else xprob[0])
for i in range(M.nlabels)},
**{'y'+str(i):(yprob[i] if i<len(yprob) else [0]*len(yprob[0]))
for i in range(M.nlabels)})
if len(xprob[0])>1:
for g in probability_glyphs:
g.glyph.width = xprob[0][1] - xprob[0][0]

if M.context_waveform:
waveform_quad_grey_used.data.update(left=left_used,
Expand Down Expand Up @@ -1248,6 +1258,9 @@ def recordings_update():
recordings.options = []

M.used_labels = set([x['label'] for x in M.used_sounds]) if M.used_sounds else []
if M.probability_style=="bars":
for ilabel,label in enumerate(M.used_labels):
probability_glyphs[ilabel].name = label
M.label_colors = { l:c for l,c in zip(M.used_labels, cycle(label_palette)) }
M.isnippet = -1
M.context_sound = None
Expand Down Expand Up @@ -1519,7 +1532,7 @@ def init(_bokeh_document):
global p_snippets, snippet_palette, snippets_dy, snippets_both, snippets_label_sources_clustered, snippets_label_sources_annotated, snippets_wave_sources, snippets_wave_glyphs, snippets_gram_sources, snippets_gram_glyphs, snippets_quad_grey, snippets_quad_fuchsia
global p_waveform, waveform_span_red, waveform_quad_grey_used, waveform_quad_grey_annotated, waveform_quad_grey_pan, waveform_quad_fuchsia, waveform_source, waveform_glyph, waveform_label_source_used, waveform_label_source_annotated
global p_spectrogram, spectrogram_span_red, spectrogram_quad_grey_used, spectrogram_quad_grey_annotated, spectrogram_quad_grey_pan, spectrogram_quad_fuchsia, spectrogram_source, spectrogram_glyph, spectrogram_label_source_used, spectrogram_label_source_annotated, spectrogram_range_source, spectrogram_length
global p_probability, probability_span_red, probability_source, probability_glyph
global p_probability, probability_span_red, probability_source
global which_layer, which_species, which_word, which_nohyphen, which_kind
global color_picker
global zoom_width, zoom_offset, zoomin, zoomout, reset, panleft, panright, allleft, allout, allright, firstlabel, nextlabel, prevlabel, lastlabel
Expand Down Expand Up @@ -1750,9 +1763,12 @@ def init(_bokeh_document):
text_font_size='6pt', text_align='center', text_baseline='bottom',
text_line_height=0.8, level='underlay', text_color='white')

TOOLTIPS = """
<div><div><span style="color:@colors;">@labels</span></div></div>
"""
if M.probability_style=="lines":
TOOLTIPS = """
<div><div><span style="color:@colors;">@labels,$y</span></div></div>
"""
else:
TOOLTIPS = ""

p_probability = figure(plot_width=M.gui_width_pix, tooltips=TOOLTIPS,
plot_height=M.context_probability_height_pix,
Expand All @@ -1765,9 +1781,27 @@ def init(_bokeh_document):
p_probability.y_range.end = 1
p_probability.xaxis.visible = False

probability_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[], labels=[]))
probability_glyph = p_probability.multi_line(xs='xs', ys='ys',
source=probability_source, color='colors')
if M.probability_style=="lines":
probability_source = ColumnDataSource(data=dict(xs=[], ys=[], colors=[], labels=[]))
global probability_glyph
probability_glyph = p_probability.multi_line(xs='xs', ys='ys',
source=probability_source, color='colors')
else:
xs = {'x'+str(i):[] for i in range(M.nlabels)}
ys = {'y'+str(i):[] for i in range(M.nlabels)}
probability_source = ColumnDataSource(data=xs|ys)
global probability_glyphs
probability_glyphs = []
for i in range(M.nlabels):
probability_glyphs.append(p_probability.vbar(
x='x'+str(i),
bottom=stack(*['y'+str(j) for j in range(i)]),
top=stack(*['y'+str(j) for j in range(i+1)]),
line_width = 0,
fill_color=label_palette[i],
source=probability_source))
p_probability.add_tools(HoverTool(renderers=[probability_glyphs[i]],
tooltips=[("", "$name, @y"+str(i))]))

probability_span_red = Span(location=0, dimension='height', line_color='red')
p_probability.add_layout(probability_span_red)
Expand Down

0 comments on commit 00d23d6

Please sign in to comment.