Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kpath heatmaps with PlotlyJS #24

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 138 additions & 61 deletions ext/BrillouinPlotlyJSExt/dispersion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,77 @@ const DEFAULT_PLOTLY_LAYOUT_DISPERSION = Layout(
plot_bgcolor=TRANSPARENT_COL[], paper_bgcolor=TRANSPARENT_COL[],
)


function plot(kpi::KPathInterpolant, traces::Vector{<:AbstractTrace},
layout::Layout = Layout();
ylims = nothing, ylabel = nothing, title = nothing,
config::PlotConfig = PlotConfig(responsive=true, displaylogo=false))

# merge (and implicitly copy) `layout` (copy ensures we can mutate `layout` without
# corrupting user input)
layout = merge(DEFAULT_PLOTLY_LAYOUT_DISPERSION, layout)

# set default y-limits in layout, if not already set
yaxis = get!(layout.fields, :yaxis, attr())
if isnothing(ylims)
if !haskey(yaxis, :range)
yaxis[:range] = ylims
else
ylims = yaxis[:range] # grab what was already in `layout`
end
else
# overwrite if ylims was provided, regardless of what it is in `layout`
yaxis[:range] = ylims
end
yaxis[:title] = ylabel

# add title, if requested
if !isnothing(title)
if title isa String
layout[:title] = attr(text=title)
else
layout[:title] = title
end
end

# prepare to plot band diagram
Npaths = length(kpi.kpaths)
local_xs = cumdists.(cartesianize(kpi).kpaths)
local_xs_lengths = last.(local_xs)
total_xs_lengths = sum(local_xs_lengths)
spacing = total_xs_lengths / 30
rel_xs_lengths = local_xs_lengths./(total_xs_lengths+spacing*(Npaths-1))
rel_spacing = spacing/(total_xs_lengths+spacing*(Npaths-1))

# plot k-lines/labels
xticks = [Vector{Float64}(undef, length(labels)) for labels in kpi.labels]
xlabs = [Vector{Symbol}(undef, length(labels)) for labels in kpi.labels]
domain_start = 0.0 # subplot domain "start" point
for (path_idx, (local_x, labels)) in enumerate(zip(local_xs, kpi.labels))
# define xticks
for (lab_idx, (x_idx, lab)) in enumerate(labels)
xticks[path_idx][lab_idx] = local_x[x_idx]
xlabs[path_idx][lab_idx] = lab
end

# set subplot sizes and local xticks & xrange
sym_xaxis = Symbol("xaxis$path_idx") # subplot xaxis name

layout[sym_xaxis] = copy(get(layout, :xaxis, attr()))
layout[sym_xaxis][:range] = [extrema(local_x)...]
layout[sym_xaxis][:tickvals] = xticks[path_idx]
layout[sym_xaxis][:ticktext] = xlabs[path_idx]

domain_end = domain_start + rel_xs_lengths[path_idx]
layout[Symbol(sym_xaxis, "_domain")] = [domain_start, domain_end]
domain_start = domain_end + rel_spacing
end
delete!(layout.fields, :xaxis) # get rid of unused xaxis in layout; causes artifacts...

return plot(traces, layout; config=config)
end


"""
plot(kpi::KPathInterpolant, bands, [layout]; kwargs...)

Expand Down Expand Up @@ -63,73 +134,33 @@ Alternatively, some simple settings can be set directly via keyword arguments (s
"""
function plot(kpi::KPathInterpolant, bands,
layout::Layout = Layout();
ylims = nothing, ylabel = "Energy", title = nothing,
ylims = default_dispersion_ylims(bands), ylabel = "Energy",
band_highlights::Union{Dict, Nothing} = nothing,
annotations::Union{Dict, Nothing} = nothing,
config::PlotConfig = PlotConfig(responsive=true, displaylogo=false))
kwargs...)
# check input
N = length(kpi)
if !all(band -> length(band) == N, bands)
throw(DimensionMismatch("mismatched dimensions of `kpi` and `bands`"))
end
# merge (and implicitly copy) `layout` (copy ensures we can mutate `layout` without
# corrupting user input)
layout = merge(DEFAULT_PLOTLY_LAYOUT_DISPERSION, layout)

# set default y-limits in layout, if not already set
yaxis = get!(layout.fields, :yaxis, attr())
if isnothing(ylims)
if !haskey(yaxis, :range)
ylims = default_dispersion_ylims(bands)
yaxis[:range] = ylims
else
ylims = yaxis[:range] # grab what was already in `layout`
end
else
# overwrite if ylims was provided, regardless of what it is in `layout`
yaxis[:range] = ylims
end
yaxis[:title] = ylabel

# add title, if requested
if !isnothing(title)
if title isa String
layout[:title] = attr(text=title)
else
layout[:title] = title
end
end

# prepare to plot band diagram
Npaths = length(kpi.kpaths)
local_xs = cumdists.(cartesianize(kpi).kpaths)
local_xs_lengths = last.(local_xs)
total_xs_lengths = sum(local_xs_lengths)
spacing = total_xs_lengths / 30
rel_xs_lengths = local_xs_lengths./(total_xs_lengths+spacing*(Npaths-1))
rel_spacing = spacing/(total_xs_lengths+spacing*(Npaths-1))

# plot bands and k-lines/labels
tbands = Vector{GenericTrace{Dict{Symbol,Any}}}()
xticks = [Vector{Float64}(undef, length(labels)) for labels in kpi.labels]
xlabs = [Vector{Symbol}(undef, length(labels)) for labels in kpi.labels]
start_idx = 1
domain_start = 0.0 # subplot domain "start" point
for (path_idx, (local_x, labels)) in enumerate(zip(local_xs, kpi.labels))
stop_idx = start_idx+length(local_x)-1
# plot bands
for (i, band) in enumerate(bands)
line = something(_get_value_if_in_ranges(band_highlights, i),
attr(color=BAND_COL[], width=3)) # default
line = something(_get_value_if_in_ranges(band_highlights, i),
attr(color=BAND_COL[], width=3)) # default
push!(tbands,
PlotlyJS.scatter(x=local_x, y=band[start_idx:stop_idx],
hoverinfo="y", mode="lines", line=line, xaxis="x$path_idx", yaxis="y"))
end
# define xticks
for (lab_idx, (x_idx, lab)) in enumerate(labels)
xticks[path_idx][lab_idx] = local_x[x_idx]
xlabs[path_idx][lab_idx] = lab
end

# place any high-symmetry point annotations
if annotations !== nothing
for (lab, positions_and_strs) in annotations
Expand Down Expand Up @@ -157,25 +188,12 @@ function plot(kpi::KPathInterpolant, bands,
end
end
end

# set subplot sizes and local xticks & xrange
sym_xaxis = Symbol("xaxis$path_idx") # subplot xaxis name

layout[sym_xaxis] = copy(get(layout, :xaxis, attr()))
layout[sym_xaxis][:range] = [extrema(local_x)...]
layout[sym_xaxis][:tickvals] = xticks[path_idx]
layout[sym_xaxis][:ticktext] = xlabs[path_idx]

domain_end = domain_start + rel_xs_lengths[path_idx]
layout[Symbol(sym_xaxis, "_domain")] = [domain_start, domain_end]
domain_start = domain_end + rel_spacing

# prepare for next iteration
start_idx = stop_idx + 1
end
delete!(layout.fields, :xaxis) # get rid of unused xaxis in layout; causes artifacts...

return plot(tbands, layout; config=config)
return plot(kpi, tbands, layout; ylims=ylims, ylabel=ylabel, kwargs...)
end
# `bands` can also be supplied as a matrix (w/ distinct bands in distinct columns)
function plot(kpi::KPathInterpolant, bands::AbstractMatrix{<:Real},
Expand All @@ -188,7 +206,7 @@ function plot(kpi::KPathInterpolant, bands::AbstractMatrix{<:Real},
end

function default_dispersion_ylims(bands)
ylims = [mapfoldl(minimum, min, bands, init=Inf),
ylims = [mapfoldl(minimum, min, bands, init=Inf),
mapfoldl(maximum, max, bands, init=-Inf)]
δ = (ylims[2]-ylims[1])/30
if isapprox(ylims[1], 0, atol=1e-6)
Expand All @@ -205,4 +223,63 @@ function _get_value_if_in_ranges(d::Dict, i::Integer)
end
return nothing
end
_get_value_if_in_ranges(::Nothing, ::Integer) = nothing
_get_value_if_in_ranges(::Nothing, ::Integer) = nothing

# ---------------------------------------------------------------------------------------- #

"""
plot(kpi::KPathInterpolant, ωs, fields, [layout]; kwargs...)

Plot a dispersion heatmap for provided `fields` and **k**-path interpolant `kpi`.

`fields` must be an iterable of real matrices (e.g., a `Vector{Matrix{Float64}}`),
with the first iteration running over distinct fields to overlay.
Note that the size of each iterant of `fields` must equal `(length(kpi), length(ωs))`.

## Keyword arguments `kwargs`

- `ylims`: y-axis limits (default: `extrema(ωs)`)

- `ylabel`: y-axis label (default: "Energy")

- `title`: plot title (default: `nothing`); can be a `String` or an `attr` dictionary of
PlotlyJS properties

- `opacity`: transparency of colormap (default: `1`); useful to set a value less than one
when overlaying multiple `fields`

- `colorscale`: an iteratable of PlotlyJS color scales (default: ["YlGnBu"])

- `reversescale`: boolean that reverses color scale (default: false)
"""
function plot(kpi::KPathInterpolant, ωs, fields,
layout::Layout = Layout();
ylims = extrema(ωs), ylabel = "Energy",
opacity=1, colorscale=["YlGnBu"], reversescale=false,
kwargs...)
# check input
N = length(kpi); M = length(ωs)
if !all(field -> size(field) == (N,M), fields)
throw(DimensionMismatch("mismatched dimensions of `kpi` with `ωs` and `fields`"))
end

# prepare to plot band diagram
local_xs = cumdists.(cartesianize(kpi).kpaths)

# plot bands and k-lines/labels
heatmaps = Vector{GenericTrace{Dict{Symbol,Any}}}()
start_idx = 1
for (path_idx, (local_x, labels)) in enumerate(zip(local_xs, kpi.labels))
stop_idx = start_idx+length(local_x)-1
# plot fields
for (i, field) in enumerate(fields)
push!(heatmaps,
PlotlyJS.heatmap(x=local_x, y=ωs, z=transpose(field[start_idx:stop_idx,:]),
xaxis="x$path_idx", yaxis="y", opacity=opacity,
colorscale=colorscale[mod(i-1,length(colorscale))+1], reversescale=reversescale))
end
start_idx = stop_idx + 1
end

plot(kpi, heatmaps, layout; ylims=ylims, ylabel=ylabel, kwargs...)
end