# Source code for structuregraph_helpers.plotting

"""Plotting helpers."""
import plotly.graph_objs as go

# Public API: the only name exported on a star-import of this module.
__all__ = ("plotly_plot_structure_graph",)


def plotly_plot_structure_graph(
    structure_graph, show_edges: bool = True, show_nodes: bool = True
) -> "go.Figure":
    """Plot a StructureGraph using Plotly.

    Sites are drawn as 3D markers at their fractional coordinates and are
    colored by atomic number. Edges are drawn as black line segments from the
    start site to the end site shifted by the edge's ``to_jimage`` offset.

    Args:
        structure_graph: Object exposing ``structure`` (providing
            ``frac_coords`` and index access to sites with a
            ``specie.number`` attribute) and ``graph`` (providing
            ``edges(data=True)`` that yields ``(start, end, attributes)``
            with a ``"to_jimage"`` entry in ``attributes``). Presumably a
            pymatgen ``StructureGraph`` -- confirm against callers.
        show_edges: If True, add the edge (line) trace to the figure.
        show_nodes: If True, add the node (marker) trace to the figure.

    Returns:
        A Plotly figure containing the requested traces (nodes first, then
        edges) on a scene with all axis decorations hidden.
    """
    coords = structure_graph.structure.frac_coords
    traces = []

    # Each trace is only assembled when it is actually going to be shown.
    if show_nodes:
        node_x = []
        node_y = []
        node_z = []
        atom_number = []
        for index, site_coords in enumerate(coords):
            node_x.append(site_coords[0])
            node_y.append(site_coords[1])
            node_z.append(site_coords[2])
            atom_number.append(structure_graph.structure[index].specie.number)

        node_trace = go.Scatter3d(
            x=node_x,
            y=node_y,
            z=node_z,
            mode="markers",
            hoverinfo="none",
            marker=dict(
                symbol="circle",
                size=6,
                color=atom_number,
                colorscale="Viridis",
                line=dict(color="rgb(50,50,50)", width=0.5),
            ),
        )
        traces.append(node_trace)

    if show_edges:
        edge_x = []
        edge_y = []
        edge_z = []
        for start, end, attributes in structure_graph.graph.edges(data=True):
            start_c = coords[start]
            # Shift the end point by the image offset stored on the edge so
            # the segment is drawn towards that image of the end site.
            end_c = coords[end] + attributes["to_jimage"]
            # The trailing None separates this segment from the next one, so
            # all edges can share a single line trace without being joined.
            edge_x += [start_c[0], end_c[0], None]
            edge_y += [start_c[1], end_c[1], None]
            edge_z += [start_c[2], end_c[2], None]

        edge_trace = go.Scatter3d(
            x=edge_x,
            y=edge_y,
            z=edge_z,
            mode="lines",
            hoverinfo="none",
            line=dict(color="black", width=2),
        )
        traces.append(edge_trace)

    # Shared settings that strip every visual decoration from an axis.
    axis = dict(
        showbackground=False,
        showline=False,
        zeroline=False,
        showgrid=False,
        showticklabels=False,
        title="",
    )

    layout = go.Layout(
        showlegend=False,
        scene=dict(
            # dict(axis) gives each axis its own copy of the settings.
            xaxis=dict(axis),
            yaxis=dict(axis),
            zaxis=dict(axis),
        ),
        margin=dict(t=100),
        hovermode="closest",
    )

    return go.Figure(data=traces, layout=layout)