# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
"""
This module implements plotters for phonon DOS and band structures, and for
thermodynamic properties derived from a phonon DOS.
"""
import logging
from collections import OrderedDict, namedtuple
import numpy as np
import scipy.constants as const
from monty.json import jsanitize
from pymatgen.electronic_structure.plotter import plot_brillouin_zone
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
from pymatgen.util.plotting import pretty_plot, add_fig_kwargs, get_ax_fig_plt
# Module-level logger named after this module; used for debug output when
# building band structure ticks.
logger = logging.getLogger(__name__)
# Pair of (multiplicative factor from THz, label used on plot axes).
FreqUnits = namedtuple("FreqUnits", ["factor", "label"])


def freq_units(units):
    """
    Look up the conversion from THz to the requested frequency units.

    Args:
        units: str, accepted values: thz, ev, mev, ha, cm-1, cm^-1. The lookup
            is case-insensitive and ignores leading/trailing whitespace.

    Returns:
        A FreqUnits namedtuple with the conversion factor from THz to the
        required units (``factor``) and the label to show in plots (``label``).

    Raises:
        KeyError: if ``units`` is not one of the accepted values.
    """
    # const.tera rescales the per-hertz CODATA relationships to per-terahertz.
    thz_to_ev = const.value("hertz-electron volt relationship") * const.tera
    thz_to_ha = const.value("hertz-hartree relationship") * const.tera
    # const.centi turns inverse metres into inverse centimetres.
    thz_to_inv_cm = const.value("hertz-inverse meter relationship") * const.tera * const.centi

    # Both spellings of inverse centimetres map to the very same entry.
    inv_cm = FreqUnits(thz_to_inv_cm, "cm^{-1}")

    d = {
        "thz": FreqUnits(1, "THz"),
        "ev": FreqUnits(thz_to_ev, "eV"),
        "mev": FreqUnits(thz_to_ev / const.milli, "meV"),
        "ha": FreqUnits(thz_to_ha, "Ha"),
        "cm-1": inv_cm,
        "cm^-1": inv_cm,
    }
    try:
        return d[units.lower().strip()]
    except KeyError:
        # The failed dict lookup carries no extra information for the caller,
        # so it is dropped from the traceback.
        raise KeyError('Value for units `{}` unknown\nPossible values are:\n {}'.format(
            units, list(d.keys()))) from None
class PhononDosPlotter:
    """
    Class for plotting phonon DOSs. Note that the interface is extremely flexible
    given that there are many different ways in which people want to view
    DOS. The typical usage is::

        # Initializes plotter with some optional args. Defaults are usually
        # fine,
        plotter = PhononDosPlotter()

        # Adds a DOS with a label.
        plotter.add_dos("Total DOS", dos)

        # Alternatively, you can add a dict of DOSs. This is the typical
        # form returned by CompletePhononDos.get_element_dos().
        plotter.add_dos_dict(complete_dos.get_element_dos())
    """

    def __init__(self, stack=False, sigma=None):
        """
        Args:
            stack: Whether to plot the DOS as a stacked area graph
            sigma: A float specifying a standard deviation for Gaussian smearing
                the DOS for nicer looking plots. Defaults to None for no
                smearing.
        """
        self.stack = stack
        self.sigma = sigma
        # Insertion order is the plotting (and stacking) order.
        self._doses = OrderedDict()

    def add_dos(self, label, dos):
        """
        Adds a dos for plotting.

        Args:
            label:
                label for the DOS. Must be unique; adding a second DOS under an
                existing label replaces the first.
            dos:
                PhononDos object
        """
        # A falsy sigma (None or 0) means "no smearing".
        if self.sigma:
            densities = dos.get_smeared_densities(self.sigma)
        else:
            densities = dos.densities
        self._doses[label] = {'frequencies': dos.frequencies, 'densities': densities}

    def add_dos_dict(self, dos_dict, key_sort_func=None):
        """
        Add a dictionary of doses, with an optional sorting function for the
        keys.

        Args:
            dos_dict: dict of {label: Dos}
            key_sort_func: function used to sort the dos_dict keys. If omitted,
                the DOSs are added in the dict's own iteration order.
        """
        if key_sort_func:
            keys = sorted(dos_dict.keys(), key=key_sort_func)
        else:
            keys = dos_dict.keys()
        for label in keys:
            self.add_dos(label, dos_dict[label])

    def get_dos_dict(self):
        """
        Returns the added doses as a json-serializable dict. Note that if you
        have specified smearing for the DOS plot, the densities returned will
        be the smeared densities, not the original densities.

        Returns:
            Dict of dos data. Generally of the form, {label: {'frequencies':..,
            'densities': ...}}
        """
        return jsanitize(self._doses)

    def get_plot(self, xlim=None, ylim=None, units="thz"):
        """
        Get a matplotlib plot showing the DOS.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits. If None, the limits are taken
                from the densities that fall strictly inside the x-range.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.

        Returns:
            The matplotlib pyplot module with the DOS figure as current figure.
        """
        u = freq_units(units)

        # The palette has 9 entries; colours cycle over at least 3 of them.
        ncolors = max(3, len(self._doses))
        ncolors = min(9, ncolors)

        import palettable

        colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

        y = None
        alldensities = []
        allfrequencies = []
        plt = pretty_plot(12, 8)

        # Note that this complicated processing of frequencies is to allow for
        # stacked plots in matplotlib.
        for dos in self._doses.values():
            frequencies = dos['frequencies'] * u.factor
            densities = dos['densities']
            if y is None:
                y = np.zeros(frequencies.shape)
            if self.stack:
                # Running sum: each curve sits on top of all previous ones.
                y += densities
                newdens = y.copy()
            else:
                newdens = densities
            allfrequencies.append(frequencies)
            alldensities.append(newdens)

        # Draw in reverse insertion order so that, when stacking, the larger
        # cumulative areas are painted first and the smaller ones on top.
        keys = list(self._doses.keys())
        keys.reverse()
        alldensities.reverse()
        allfrequencies.reverse()

        allpts = []
        for i, (key, frequencies, densities) in enumerate(zip(keys, allfrequencies, alldensities)):
            allpts.extend(zip(frequencies, densities))
            color = colors[i % ncolors]
            if self.stack:
                plt.fill(frequencies, densities, color=color, label=str(key))
            else:
                plt.plot(frequencies, densities, color=color, label=str(key), linewidth=3)

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        else:
            xlim = plt.xlim()
            relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]]
            # No point inside the x-range (or no DOS added at all): keep the
            # automatic y-limits instead of failing on min()/max() of nothing.
            if relevanty:
                plt.ylim((min(relevanty), max(relevanty)))

        # Dashed vertical marker at zero frequency spanning the final y-range.
        ylim = plt.ylim()
        plt.plot([0, 0], ylim, 'k--', linewidth=2)

        plt.xlabel(r'$\mathrm{{Frequencies\ ({})}}$'.format(u.label))
        plt.ylabel(r'$\mathrm{Density\ of\ states}$')

        plt.legend()
        leg = plt.gca().get_legend()
        if leg is not None:
            ltext = leg.get_texts()  # all the text.Text instance in the legend
            plt.setp(ltext, fontsize=30)
        plt.tight_layout()
        return plt

    def save_plot(self, filename, img_format="eps", xlim=None, ylim=None, units="thz"):
        """
        Save matplotlib plot to a file.

        Args:
            filename: Filename to write to.
            img_format: Image format to use. Defaults to EPS.
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1
        """
        plt = self.get_plot(xlim, ylim, units=units)
        plt.savefig(filename, format=img_format)
        plt.close()

    def show(self, xlim=None, ylim=None, units="thz"):
        """
        Show the plot using matplotlib.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
        """
        plt = self.get_plot(xlim, ylim, units=units)
        plt.show()
class PhononBSPlotter:
    """
    Class to plot or get data to facilitate the plot of band structure objects.
    """

    def __init__(self, bs):
        """
        Args:
            bs: A PhononBandStructureSymmLine object.

        Raises:
            ValueError: if ``bs`` is not a PhononBandStructureSymmLine.
        """
        if not isinstance(bs, PhononBandStructureSymmLine):
            raise ValueError(
                "PhononBSPlotter only works with PhononBandStructureSymmLine objects. "
                "A PhononBandStructure object (on a uniform grid for instance and "
                "not along symmetry lines won't work)")
        self._bs = bs
        self._nb_bands = self._bs.nb_bands

    @staticmethod
    def _format_label(label):
        """
        Wrap a q-point label in ``$...$`` when it starts with a backslash or
        contains an underscore; return any other label unchanged.
        """
        if label.startswith("\\") or label.find("_") != -1:
            return "$" + label + "$"
        return label

    def _maketicks(self, plt):
        """
        utility private method to add ticks to a band structure

        Sets the x tick positions and labels on the current axes (consecutive
        ticks carrying the same label are collapsed to the first one) and draws
        a vertical line at each of those ticks.

        Args:
            plt: the pyplot-like object to draw on.

        Returns:
            The same ``plt`` object.
        """
        ticks = self.get_ticks()
        distances = ticks['distance']
        labels = ticks['label']

        # Sanitize only plot the uniq values
        uniq_d = []
        uniq_l = []
        for i, (dist, label) in enumerate(zip(distances, labels)):
            if i > 0 and label == labels[i - 1]:
                logger.debug("Skipping label %s", label)
                continue
            logger.debug("Adding label %s at %s", label, dist)
            uniq_d.append(dist)
            uniq_l.append(label)

        logger.debug("Unique labels are %s", list(zip(uniq_d, uniq_l)))
        ax = plt.gca()
        ax.set_xticks(uniq_d)
        ax.set_xticklabels(uniq_l)

        for i, (dist, label) in enumerate(zip(distances, labels)):
            if label is None:
                continue
            # don't print the same label twice
            if i != 0 and label == labels[i - 1]:
                logger.debug("already print label... skipping label %s", label)
                continue
            logger.debug("Adding a line at %s for label %s", dist, label)
            plt.axvline(dist, color='k')
        return plt

    def bs_plot_data(self):
        """
        Get the data nicely formatted for a plot

        Returns:
            A dict of the following format:
            ticks: A dict with the 'distance' at which there is a labelled
                qpoint (the x axis) and the matching 'label', as returned by
                get_ticks().
            distances: A list (one element for each branch) with the distances
                of the qpoints belonging to that branch.
            frequency: A list (one element for each branch) of frequencies for
                each band: [branch][band][qpoint]. The data is
                stored by branch to facilitate the plotting
            lattice: The reciprocal lattice, as a dict.
        """
        distance = []
        frequency = []

        ticks = self.get_ticks()

        for b in self._bs.branches:
            # 'end_index' is inclusive.
            indices = range(b['start_index'], b['end_index'] + 1)
            distance.append([self._bs.distance[j] for j in indices])
            frequency.append([[self._bs.bands[i][j] for j in indices]
                              for i in range(self._nb_bands)])

        return {'ticks': ticks, 'distances': distance, 'frequency': frequency,
                'lattice': self._bs.lattice_rec.as_dict()}

    def get_plot(self, ylim=None, units="thz"):
        """
        Get a matplotlib object for the bandstructure plot.

        Args:
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.

        Returns:
            The matplotlib pyplot module with the band structure as current figure.
        """
        u = freq_units(units)

        plt = pretty_plot(12, 8)

        band_linewidth = 1

        data = self.bs_plot_data()
        for d, branch_distances in enumerate(data['distances']):
            npoints = len(branch_distances)
            for i in range(self._nb_bands):
                plt.plot(branch_distances,
                         [data['frequency'][d][i][j] * u.factor for j in range(npoints)],
                         'b-', linewidth=band_linewidth)

        self._maketicks(plt)

        # plot y=0 line
        plt.axhline(0, linewidth=1, color='k')

        # Main X and Y Labels
        plt.xlabel(r'$\mathrm{Wave\ Vector}$', fontsize=30)
        ylabel = r'$\mathrm{{Frequencies\ ({})}}$'.format(u.label)
        plt.ylabel(ylabel, fontsize=30)

        # X range (K)
        # last distance point
        x_max = data['distances'][-1][-1]
        plt.xlim(0, x_max)

        if ylim is not None:
            plt.ylim(ylim)

        plt.tight_layout()

        return plt

    def show(self, ylim=None, units="thz"):
        """
        Show the plot using matplotlib.

        Args:
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
        """
        plt = self.get_plot(ylim, units=units)
        plt.show()

    def save_plot(self, filename, img_format="eps", ylim=None, units="thz"):
        """
        Save matplotlib plot to a file.

        Args:
            filename: Filename to write to.
            img_format: Image format to use. Defaults to EPS.
            ylim: Specifies the y-axis limits.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
        """
        plt = self.get_plot(ylim=ylim, units=units)
        plt.savefig(filename, format=img_format)
        plt.close()

    def get_ticks(self):
        """
        Get all ticks and labels for a band structure plot.

        Only labelled qpoints produce a tick. When a labelled qpoint differs
        from the previous labelled one both in label and in branch, the two
        are merged into a single "previous|current" tick placed at the
        previous tick's distance.

        Returns:
            A dict with 'distance': a list of distance at which ticks should
            be set and 'label': a list of label for each of those ticks.
        """
        tick_distance = []
        tick_labels = []
        previous_label = self._bs.qpoints[0].label
        previous_branch = self._bs.branches[0]['name']
        for i, c in enumerate(self._bs.qpoints):
            if c.label is None:
                continue
            tick_distance.append(self._bs.distance[i])
            # Name of the first branch containing this qpoint, if any.
            this_branch = next((b['name'] for b in self._bs.branches
                                if b['start_index'] <= i <= b['end_index']), None)
            if c.label != previous_label and previous_branch != this_branch:
                label1 = self._format_label(c.label)
                label0 = self._format_label(previous_label)
                # Replace the previous tick by the merged one and discard the
                # distance that was just appended for the current qpoint.
                tick_labels.pop()
                tick_distance.pop()
                tick_labels.append(label0 + "$\\mid$" + label1)
            else:
                tick_labels.append(self._format_label(c.label))
            previous_label = c.label
            previous_branch = this_branch
        return {'distance': tick_distance, 'label': tick_labels}

    def plot_compare(self, other_plotter, units="thz"):
        """
        plot two band structure for comparison. One is in red the other in blue.
        The two band structures need to be defined on the same symmetry lines!
        and the distance between symmetry lines is
        the one of the band structure used to build the PhononBSPlotter

        Args:
            other_plotter: another PhononBSPlotter object defined along the
                same symmetry lines. Its bands are drawn in red on top of the
                blue bands of this plotter.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.

        Returns:
            a matplotlib object with both band structures

        Raises:
            ValueError: if the two plotters do not have the same number of
                branches.
        """
        u = freq_units(units)

        data_orig = self.bs_plot_data()
        data = other_plotter.bs_plot_data()

        if len(data_orig['distances']) != len(data['distances']):
            raise ValueError('The two objects are not compatible.')

        plt = self.get_plot(units=units)
        band_linewidth = 1
        for i in range(other_plotter._nb_bands):
            for d, branch_distances in enumerate(data_orig['distances']):
                # The x values always come from this plotter; only the
                # frequencies are taken from the other one.
                plt.plot(branch_distances,
                         [data['frequency'][d][i][j] * u.factor
                          for j in range(len(branch_distances))],
                         'r-', linewidth=band_linewidth)

        return plt

    def plot_brillouin(self):
        """
        plot the Brillouin zone
        """
        # get labels and lines
        labels = {q.label: q.frac_coords for q in self._bs.qpoints if q.label}

        # One segment per branch, from its first to its last qpoint.
        lines = [[self._bs.qpoints[b['start_index']].frac_coords,
                  self._bs.qpoints[b['end_index']].frac_coords]
                 for b in self._bs.branches]

        plot_brillouin_zone(self._bs.lattice_rec, lines=lines, labels=labels)
class ThermoPlotter:
    """
    Plotter for thermodynamic properties obtained from phonon DOS.
    If the structure corresponding to the DOS is provided, it is passed to the
    thermodynamic functions and the plots are labelled in units of mol
    instead of mol-cell (mol-c).
    """

    def __init__(self, dos, structure=None):
        """
        Args:
            dos: A PhononDos object.
            structure: A Structure object corresponding to the structure used for the calculation.
        """
        self.dos = dos
        self.structure = structure

    def _plot_thermo(self, func, temperatures, factor=1, ax=None, ylabel=None, label=None, ylim=None, **kwargs):
        """
        Plots a thermodynamic property for a generic function from a PhononDos instance.

        Args:
            func: the thermodynamic function to be used to calculate the property.
                It is called as ``func(t, structure=self.structure)`` for each
                temperature.
            temperatures: a list of temperatures
            factor: a multiplicative factor applied to the thermodynamic property calculated. Used to change
                the units.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            ylabel: label for the y axis
            label: label of the plot
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        ax, fig, _ = get_ax_fig_plt(ax)

        values = [func(t, structure=self.structure) * factor for t in temperatures]

        ax.plot(temperatures, values, label=label, **kwargs)

        if ylim:
            ax.set_ylim(ylim)

        ax.set_xlim((np.min(temperatures), np.max(temperatures)))

        # Work on the explicit axes rather than on pyplot's "current" axes, so
        # that a caller-supplied ``ax`` is the one inspected and drawn on.
        y_low, y_high = ax.get_ylim()
        if y_low < 0 < y_high:
            # Horizontal reference line at zero, only when zero is in view.
            ax.plot(ax.get_xlim(), [0, 0], 'k-', linewidth=1)

        ax.set_xlabel(r"$T$ (K)")
        if ylabel:
            ax.set_ylabel(ylabel)

        return fig

    @add_fig_kwargs
    def plot_cv(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots the constant volume specific heat C_v in a temperature range.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        if self.structure:
            ylabel = r"$C_v$ (J/K/mol)"
        else:
            ylabel = r"$C_v$ (J/K/mol-c)"

        fig = self._plot_thermo(self.dos.cv, temperatures, ylabel=ylabel, ylim=ylim, **kwargs)

        return fig

    @add_fig_kwargs
    def plot_entropy(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots the vibrational entropy in a temperature range.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        if self.structure:
            ylabel = r"$S$ (J/K/mol)"
        else:
            ylabel = r"$S$ (J/K/mol-c)"

        fig = self._plot_thermo(self.dos.entropy, temperatures, ylabel=ylabel, ylim=ylim, **kwargs)

        return fig

    @add_fig_kwargs
    def plot_internal_energy(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots the vibrational internal energy in a temperature range.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        if self.structure:
            ylabel = r"$\Delta E$ (kJ/mol)"
        else:
            ylabel = r"$\Delta E$ (kJ/mol-c)"

        # factor=1e-3 rescales the computed values to match the kJ label.
        fig = self._plot_thermo(self.dos.internal_energy, temperatures, ylabel=ylabel, ylim=ylim,
                                factor=1e-3, **kwargs)

        return fig

    @add_fig_kwargs
    def plot_helmholtz_free_energy(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots the vibrational contribution to the Helmholtz free energy in a temperature range.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        if self.structure:
            ylabel = r"$\Delta F$ (kJ/mol)"
        else:
            ylabel = r"$\Delta F$ (kJ/mol-c)"

        # factor=1e-3 rescales the computed values to match the kJ label.
        fig = self._plot_thermo(self.dos.helmholtz_free_energy, temperatures, ylabel=ylabel, ylim=ylim,
                                factor=1e-3, **kwargs)

        return fig

    @add_fig_kwargs
    def plot_thermodynamic_properties(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots all the thermodynamic properties in a temperature range.

        The four curves (C_v, S, Delta E, Delta F) share a single set of axes;
        the unit of each one is given in the legend.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        mol = "" if self.structure else "-c"

        # The first call creates the figure; the others draw on its axes.
        fig = self._plot_thermo(self.dos.cv, temperatures, ylabel="Thermodynamic properties", ylim=ylim,
                                label=r"$C_v$ (J/K/mol{})".format(mol), **kwargs)
        shared_ax = fig.axes[0]
        self._plot_thermo(self.dos.entropy, temperatures, ylim=ylim, ax=shared_ax,
                          label=r"$S$ (J/K/mol{})".format(mol), **kwargs)
        self._plot_thermo(self.dos.internal_energy, temperatures, ylim=ylim, ax=shared_ax, factor=1e-3,
                          label=r"$\Delta E$ (kJ/mol{})".format(mol), **kwargs)
        self._plot_thermo(self.dos.helmholtz_free_energy, temperatures, ylim=ylim, ax=shared_ax, factor=1e-3,
                          label=r"$\Delta F$ (kJ/mol{})".format(mol), **kwargs)

        shared_ax.legend(loc="best")

        return fig