diff --git a/ext/BrillouinPlotlyJSExt/dispersion.jl b/ext/BrillouinPlotlyJSExt/dispersion.jl index dbb1e0a..400a810 100644 --- a/ext/BrillouinPlotlyJSExt/dispersion.jl +++ b/ext/BrillouinPlotlyJSExt/dispersion.jl @@ -18,6 +18,77 @@ const DEFAULT_PLOTLY_LAYOUT_DISPERSION = Layout( plot_bgcolor=TRANSPARENT_COL[], paper_bgcolor=TRANSPARENT_COL[], ) + +function plot(kpi::KPathInterpolant, traces::Vector{<:AbstractTrace}, + layout::Layout = Layout(); + ylims = nothing, ylabel = nothing, title = nothing, + config::PlotConfig = PlotConfig(responsive=true, displaylogo=false)) + + # merge (and implicitly copy) `layout` (copy ensures we can mutate `layout` without + # corrupting user input) + layout = merge(DEFAULT_PLOTLY_LAYOUT_DISPERSION, layout) + + # set default y-limits in layout, if not already set + yaxis = get!(layout.fields, :yaxis, attr()) + if isnothing(ylims) + if !haskey(yaxis, :range) + yaxis[:range] = ylims + else + ylims = yaxis[:range] # grab what was already in `layout` + end + else + # overwrite if ylims was provided, regardless of what it is in `layout` + yaxis[:range] = ylims + end + yaxis[:title] = ylabel + + # add title, if requested + if !isnothing(title) + if title isa String + layout[:title] = attr(text=title) + else + layout[:title] = title + end + end + + # prepare to plot band diagram + Npaths = length(kpi.kpaths) + local_xs = cumdists.(cartesianize(kpi).kpaths) + local_xs_lengths = last.(local_xs) + total_xs_lengths = sum(local_xs_lengths) + spacing = total_xs_lengths / 30 + rel_xs_lengths = local_xs_lengths./(total_xs_lengths+spacing*(Npaths-1)) + rel_spacing = spacing/(total_xs_lengths+spacing*(Npaths-1)) + + # plot k-lines/labels + xticks = [Vector{Float64}(undef, length(labels)) for labels in kpi.labels] + xlabs = [Vector{Symbol}(undef, length(labels)) for labels in kpi.labels] + domain_start = 0.0 # subplot domain "start" point + for (path_idx, (local_x, labels)) in enumerate(zip(local_xs, kpi.labels)) + # define xticks + for (lab_idx, (x_idx, lab)) in enumerate(labels) + xticks[path_idx][lab_idx] = local_x[x_idx] + xlabs[path_idx][lab_idx] = lab + end + + # set subplot sizes and local xticks & xrange + sym_xaxis = Symbol("xaxis$path_idx") # subplot xaxis name + + layout[sym_xaxis] = copy(get(layout, :xaxis, attr())) + layout[sym_xaxis][:range] = [extrema(local_x)...] + layout[sym_xaxis][:tickvals] = xticks[path_idx] + layout[sym_xaxis][:ticktext] = xlabs[path_idx] + + domain_end = domain_start + rel_xs_lengths[path_idx] + layout[Symbol(sym_xaxis, "_domain")] = [domain_start, domain_end] + domain_start = domain_end + rel_spacing + end + delete!(layout.fields, :xaxis) # get rid of unused xaxis in layout; causes artifacts... + + return plot(traces, layout; config=config) +end + + """ plot(kpi::KPathInterpolant, bands, [layout]; kwargs...) @@ -63,73 +134,33 @@ Alternatively, some simple settings can be set directly via keyword arguments (s """ function plot(kpi::KPathInterpolant, bands, layout::Layout = Layout(); - ylims = nothing, ylabel = "Energy", title = nothing, + ylims = default_dispersion_ylims(bands), ylabel = "Energy", band_highlights::Union{Dict, Nothing} = nothing, annotations::Union{Dict, Nothing} = nothing, - config::PlotConfig = PlotConfig(responsive=true, displaylogo=false)) + kwargs...) # check input N = length(kpi) if !all(band -> length(band) == N, bands) throw(DimensionMismatch("mismatched dimensions of `kpi` and `bands`")) end - # merge (and implicitly copy) `layout` (copy ensures we can mutate `layout` without - # corrupting user input) - layout = merge(DEFAULT_PLOTLY_LAYOUT_DISPERSION, layout) - - # set default y-limits in layout, if not already set - yaxis = get!(layout.fields, :yaxis, attr()) - if isnothing(ylims) - if !haskey(yaxis, :range) - ylims = default_dispersion_ylims(bands) - yaxis[:range] = ylims - else - ylims = yaxis[:range] # grab what was already in `layout` - end - else - # overwrite if ylims was provided, regardless of what it is in `layout` - yaxis[:range] = ylims - end - yaxis[:title] = ylabel - - # add title, if requested - if !isnothing(title) - if title isa String - layout[:title] = attr(text=title) - else - layout[:title] = title - end - end # prepare to plot band diagram - Npaths = length(kpi.kpaths) local_xs = cumdists.(cartesianize(kpi).kpaths) - local_xs_lengths = last.(local_xs) - total_xs_lengths = sum(local_xs_lengths) - spacing = total_xs_lengths / 30 - rel_xs_lengths = local_xs_lengths./(total_xs_lengths+spacing*(Npaths-1)) - rel_spacing = spacing/(total_xs_lengths+spacing*(Npaths-1)) # plot bands and k-lines/labels tbands = Vector{GenericTrace{Dict{Symbol,Any}}}() - xticks = [Vector{Float64}(undef, length(labels)) for labels in kpi.labels] - xlabs = [Vector{Symbol}(undef, length(labels)) for labels in kpi.labels] start_idx = 1 - domain_start = 0.0 # subplot domain "start" point for (path_idx, (local_x, labels)) in enumerate(zip(local_xs, kpi.labels)) stop_idx = start_idx+length(local_x)-1 # plot bands for (i, band) in enumerate(bands) - line = something(_get_value_if_in_ranges(band_highlights, i), - attr(color=BAND_COL[], width=3)) # default + line = something(_get_value_if_in_ranges(band_highlights, i), + attr(color=BAND_COL[], width=3)) # default push!(tbands, PlotlyJS.scatter(x=local_x, y=band[start_idx:stop_idx], hoverinfo="y", mode="lines", line=line, xaxis="x$path_idx", yaxis="y")) end - # define xticks - for (lab_idx, (x_idx, lab)) in enumerate(labels) - xticks[path_idx][lab_idx] = local_x[x_idx] - xlabs[path_idx][lab_idx] = lab - end + # place any high-symmetry point annotations if annotations !== nothing for (lab, positions_and_strs) in annotations @@ -157,25 +188,12 @@ function plot(kpi::KPathInterpolant, bands, end end end - - # set subplot sizes and local xticks & xrange - sym_xaxis = Symbol("xaxis$path_idx") # subplot xaxis name - - layout[sym_xaxis] = copy(get(layout, :xaxis, attr())) - layout[sym_xaxis][:range] = [extrema(local_x)...] - layout[sym_xaxis][:tickvals] = xticks[path_idx] - layout[sym_xaxis][:ticktext] = xlabs[path_idx] - - domain_end = domain_start + rel_xs_lengths[path_idx] - layout[Symbol(sym_xaxis, "_domain")] = [domain_start, domain_end] - domain_start = domain_end + rel_spacing # prepare for next iteration start_idx = stop_idx + 1 end - delete!(layout.fields, :xaxis) # get rid of unused xaxis in layout; causes artifacts... - return plot(tbands, layout; config=config) + return plot(kpi, tbands, layout; ylims=ylims, ylabel=ylabel, kwargs...) end # `bands` can also be supplied as a matrix (w/ distinct bands in distinct columns) function plot(kpi::KPathInterpolant, bands::AbstractMatrix{<:Real}, @@ -188,7 +206,7 @@ function plot(kpi::KPathInterpolant, bands::AbstractMatrix{<:Real}, end function default_dispersion_ylims(bands) - ylims = [mapfoldl(minimum, min, bands, init=Inf), + ylims = [mapfoldl(minimum, min, bands, init=Inf), mapfoldl(maximum, max, bands, init=-Inf)] δ = (ylims[2]-ylims[1])/30 if isapprox(ylims[1], 0, atol=1e-6) @@ -205,4 +223,63 @@ function _get_value_if_in_ranges(d::Dict, i::Integer) end return nothing end -_get_value_if_in_ranges(::Nothing, ::Integer) = nothing \ No newline at end of file +_get_value_if_in_ranges(::Nothing, ::Integer) = nothing + +# ---------------------------------------------------------------------------------------- # + +""" + plot(kpi::KPathInterpolant, ωs, fields, [layout]; kwargs...) + +Plot a dispersion heatmap for provided `fields` and **k**-path interpolant `kpi`. + +`fields` must be an iterable of real matrices (e.g., a `Vector{Matrix{Float64}}`), +with the first iteration running over distinct fields to overlay. +Note that the size of each iterant of `fields` must equal `(length(kpi), length(ωs))`. + +## Keyword arguments `kwargs` + +- `ylims`: y-axis limits (default: `extrema(ωs)`) + +- `ylabel`: y-axis label (default: "Energy") + +- `title`: plot title (default: `nothing`); can be a `String` or an `attr` dictionary of + PlotlyJS properties + +- `opacity`: transparency of colormap (default: `1`); useful to set a value less than one + when overlaying multiple `fields` + +- `colorscale`: an iteratable of PlotlyJS color scales (default: ["YlGnBu"]) + +- `reversescale`: boolean that reverses color scale (default: false) +""" +function plot(kpi::KPathInterpolant, ωs, fields, + layout::Layout = Layout(); + ylims = extrema(ωs), ylabel = "Energy", + opacity=1, colorscale=["YlGnBu"], reversescale=false, + kwargs...) + # check input + N = length(kpi); M = length(ωs) + if !all(field -> size(field) == (N,M), fields) + throw(DimensionMismatch("mismatched dimensions of `kpi` with `ωs` and `fields`")) + end + + # prepare to plot band diagram + local_xs = cumdists.(cartesianize(kpi).kpaths) + + # plot bands and k-lines/labels + heatmaps = Vector{GenericTrace{Dict{Symbol,Any}}}() + start_idx = 1 + for (path_idx, (local_x, labels)) in enumerate(zip(local_xs, kpi.labels)) + stop_idx = start_idx+length(local_x)-1 + # plot fields + for (i, field) in enumerate(fields) + push!(heatmaps, + PlotlyJS.heatmap(x=local_x, y=ωs, z=transpose(field[start_idx:stop_idx,:]), + xaxis="x$path_idx", yaxis="y", opacity=opacity, + colorscale=colorscale[mod(i-1,length(colorscale))+1], reversescale=reversescale)) + end + start_idx = stop_idx + 1 + end + + plot(kpi, heatmaps, layout; ylims=ylims, ylabel=ylabel, kwargs...) +end