# Source code for defdap.plotting

# Copyright 2025 Mechanics of Microstructures Group
#    at The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from matplotlib.widgets import Button, TextBox
from matplotlib.collections import LineCollection
from matplotlib_scalebar.scalebar import ScaleBar
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import FuncFormatter

from skimage import morphology as mph

from defdap import defaults
from defdap import quat
from defdap.crystal_utils import project_to_orth, equavlent_indicies, idc_to_string

# TODO: add plot parameter to add to current figure


class Plot(object):
    """ Class used for creating and manipulating plots.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axis to draw on. If not given, a new figure and axis are created.
    ax_params : dict, optional
        Keyword arguments passed to
        :meth:`matplotlib.figure.Figure.add_subplot` when a new axis is
        created.
    fig : matplotlib.figure.Figure, optional
        Figure that `ax` belongs to. If `ax` is given without `fig`, the
        figure of `ax` is used. Not reused when `ax` is not given.
    make_interactive : bool
        If true, widgets and event handlers can be added to the plot.
    title : str, optional
        Title for the figure window.
    kwargs
        Passed to :func:`matplotlib.pyplot.figure` when a new figure is
        created.
    """

    def __init__(self, ax=None, ax_params=None, fig=None,
                 make_interactive=False, title=None, **kwargs):
        if ax_params is None:
            ax_params = {}

        self.interactive = make_interactive

        if ax is None:
            # TODO: flag for new figure (a `fig` passed without an `ax` is
            # not reused; a new figure is always created here)
            self.fig = plt.figure(**kwargs)
            self.ax = self.fig.add_subplot(111, **ax_params)
        else:
            self.ax = ax
            # Fall back to the axis' own figure so that methods using
            # self.fig (draw, set_title, add_axes, ...) work when only an
            # axis is supplied
            self.fig = ax.figure if fig is None else fig

        # Stores for widgets/text and the start/end points used by
        # line_slice. Created for every plot so that add_text also works on
        # non-interactive plots.
        self.btn_store = []
        self.txt_store = []
        self.txt_box_store = []
        self.p1 = []
        self.p2 = []

        self.colour_bar = None
        self.arrow = None

        if title is not None:
            self.set_title(title)

    def set_empty_state(self):
        """Reset plot specific state. Called by :meth:`clear`; the base
        class has no state to reset."""
        pass

    def check_interactive(self):
        """Checks if current plot is interactive.

        Raises
        ------
        Exception
            If plot is not interactive
        """
        if not self.interactive:
            raise Exception("Plot must be interactive")

    def add_event_handler(self, eventName, eventHandler):
        """Connect a handler to an event of the figure canvas.

        Parameters
        ----------
        eventName : str
            Matplotlib event name, e.g. 'button_press_event'.
        eventHandler
            Callable, invoked as ``eventHandler(event, plot)``.

        Returns
        -------
        int
            Connection id returned by ``mpl_connect``.

        Raises
        ------
        Exception
            If plot is not interactive
        """
        self.check_interactive()
        return self.fig.canvas.mpl_connect(
            eventName, lambda e: eventHandler(e, self)
        )

    def add_axes(self, loc, proj='2d'):
        """Add axis to current plot

        Parameters
        ----------
        loc
            Location of axis (left, bottom, width, height).
        proj : str, {2d, 3d}
            2D or 3D projection.

        Returns
        -------
        matplotlib.Axes.axes

        Raises
        ------
        ValueError
            If `proj` is not '2d' or '3d'.
        """
        if proj == '2d':
            return self.fig.add_axes(loc)
        if proj == '3d':
            # Creating Axes3D(fig, ...) directly no longer attaches the axis
            # to the figure in recent matplotlib, so let the figure build it
            return self.fig.add_axes(
                loc, projection='3d', proj_type='ortho', azim=270, elev=90
            )
        raise ValueError(
            f"Unknown projection '{proj}', expected '2d' or '3d'."
        )

    def add_button(self, label, click_handler, loc=(0.8, 0.0, 0.1, 0.07),
                   **kwargs):
        """Add a button to the plot.

        Parameters
        ----------
        label : str
            Label for the button.
        click_handler
            Click handler to assign, invoked as ``click_handler(event, plot)``.
        loc : list(float), len 4
            Left, bottom, width, height.
        kwargs
            All other arguments passed to :class:`matplotlib.widgets.Button`.

        Returns
        -------
        matplotlib.widgets.Button
        """
        self.check_interactive()
        btn_ax = self.fig.add_axes(loc)
        btn = Button(btn_ax, label, **kwargs)
        btn.on_clicked(lambda e: click_handler(e, self))

        # Keep a reference so the widget is not garbage collected
        self.btn_store.append(btn)
        return btn

    def add_text_box(self, label, submit_handler=None, change_handler=None,
                     loc=(0.8, 0.0, 0.1, 0.07), **kwargs):
        """Add a text box to the plot.

        Parameters
        ----------
        label : str
            Label for the text box.
        submit_handler
            Submit handler to assign, invoked as ``submit_handler(text, plot)``.
        change_handler
            Change handler to assign, invoked as ``change_handler(text, plot)``.
        loc : list(float), len 4
            Left, bottom, width, height.
        kwargs
            All other arguments passed to :class:`matplotlib.widgets.TextBox`.

        Returns
        -------
        matplotlib.widgets.TextBox
        """
        self.check_interactive()
        txt_box_ax = self.fig.add_axes(loc)
        txt_box = TextBox(txt_box_ax, label, **kwargs)

        if submit_handler is not None:
            txt_box.on_submit(lambda e: submit_handler(e, self))
        if change_handler is not None:
            txt_box.on_text_change(lambda e: change_handler(e, self))

        # Keep a reference so the widget is not garbage collected
        self.txt_box_store.append(txt_box)
        return txt_box

    def add_text(self, ax, x, y, txt, **kwargs):
        """Add text to the plot.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Matplotlib axis to plot on.
        x : float
            x position.
        y : float
            y position.
        txt : str
            Text to write onto the plot.
        kwargs :
            All other arguments passed to :func:`matplotlib.pyplot.text`.

        Returns
        -------
        matplotlib.text.Text
        """
        text_artist = ax.text(x, y, txt, **kwargs)
        self.txt_store.append(text_artist)
        return text_artist

    def add_arrow(self, start_end, persistent=False, clear_previous=True,
                  label=None):
        """Add arrow to grain plot.

        Parameters
        ----------
        start_end: 4-tuple
            Starting (x, y), Ending (x, y).
        persistent :
            If persistent, the arrow is not tracked and so is never cleared
            by later calls.
        clear_previous :
            Clear the previous non-persistent arrow.
        label
            Label to place near arrow.
        """
        arrow_params = {
            'xy': start_end[0:2],  # Arrow start coordinates
            'xycoords': 'data',
            'xytext': start_end[2:4],  # Arrow end coordinates
            'textcoords': 'data',
            'arrowprops': dict(arrowstyle="<-", connectionstyle="arc3",
                               color='red', alpha=0.7, linewidth=2,
                               shrinkA=0, shrinkB=0)
        }

        if persistent:
            # Add the arrow onto the plot directly, without a reference
            self.ax.annotate("", **arrow_params)
        else:
            # Save a reference so that the arrow can be removed later
            if clear_previous and self.arrow is not None:
                self.arrow.remove()
                # Forget the removed artist so it is never removed twice
                self.arrow = None
            if None not in start_end:
                self.arrow = self.ax.annotate("", **arrow_params)

        # Add a label if specified
        if label is not None:
            self.ax.annotate(label, xy=start_end[2:4], xycoords='data',
                             xytext=(15, 15), textcoords='offset pixels',
                             c='red', fontsize=14, fontweight='bold')

    def set_size(self, size):
        """Set size of plot.

        Parameters
        ----------
        size : float, float
            Width and height in inches.
        """
        self.fig.set_size_inches(size[0], size[1], forward=True)

    def set_title(self, txt):
        """Set the window title of the plot.

        Parameters
        ----------
        txt : str
            Title to set.
        """
        if self.fig.canvas.manager is not None:
            self.fig.canvas.manager.set_window_title(txt)

    def line_slice(self, event, plot, action=None):
        """ Catch click and drag then draw an arrow.

        Parameters
        ----------
        event :
            Click event.
        plot : defdap.plotting.Plot
            Plot to capture clicks from.
        action
            Further action to perform, called as
            ``action(plot=..., start_end=...)``.

        Examples
        ----------
        To use, add a click and release event handler to your plot, pointing
        to this function:

        >>> plot.add_event_handler('button_press_event',lambda e, p: line_slice(e, p))
        >>> plot.add_event_handler('button_release_event', lambda e, p: line_slice(e, p))
        """
        if event.name == 'button_press_event':
            # Start a new slice only when the press is on the map. A press
            # elsewhere discards any earlier start point so that a later
            # release cannot draw an arrow from a stale location.
            if event.inaxes is self.ax:
                self.p1 = (event.xdata, event.ydata)  # save 1st point
            else:
                self.p1 = []
            return

        # Only releases on the map complete a slice
        if event.name != 'button_release_event' or event.inaxes is not self.ax:
            return
        if not self.p1:
            # No valid start point for this drag
            return

        self.p2 = (event.xdata, event.ydata)  # save 2nd point
        start_end = (self.p1[0], self.p1[1], self.p2[0], self.p2[1])
        self.add_arrow(start_end=start_end)
        self.fig.canvas.draw_idle()

        if action is not None:
            action(plot=self, start_end=start_end)

    @property
    def exists(self):
        """bool: True while the figure window is still open."""
        self.check_interactive()
        return plt.fignum_exists(self.fig.number)

    def clear(self):
        """Clear plot. Removes the colour bar, clears the axis, resets the
        plot specific state and redraws.
        """
        self.check_interactive()

        if self.colour_bar is not None:
            self.colour_bar.remove()
            self.colour_bar = None
        self.ax.clear()
        self.set_empty_state()
        self.draw()

    def draw(self):
        """Draw plot """
        self.fig.canvas.draw()
class MapPlot(Plot):
    """ Class for creating a map plot. """

    def __init__(self, calling_map, fig=None, ax=None, ax_params=None,
                 make_interactive=False, **kwargs):
        """Initialise a map plot.

        Parameters
        ----------
        calling_map : Map
            DIC or EBSD map which called this plot.
        fig : matplotlib.figure.Figure
            Matplotlib figure to plot on
        ax : matplotlib.axes.Axes
            Matplotlib axis to plot on
        ax_params : dict, optional
            Passed to defdap.plotting.Plot as ax_params.
        make_interactive : bool, optional
            If true, make interactive
        kwargs
            Other arguments passed to :class:`defdap.plotting.Plot`.
        """
        super(MapPlot, self).__init__(
            ax, ax_params={} if ax_params is None else ax_params, fig=fig,
            make_interactive=make_interactive, **kwargs
        )
        self.calling_map = calling_map
        self.set_empty_state()

    def set_empty_state(self):
        """Reset the layer bookkeeping and hide the axis ticks."""
        self.img_layers = []
        self.highlights_layer_id = None
        self.points_layer_ids = []

        self.ax.set_xticks([])
        self.ax.set_yticks([])

    def add_map(self, map_data, vmin=None, vmax=None, cmap='viridis',
                **kwargs):
        """Add a map to a plot.

        Parameters
        ----------
        map_data : numpy.ndarray
            Map data to plot.
        vmin : float
            Minimum value for the colour scale.
        vmax : float
            Maximum value for the colour scale.
        cmap
            Colour map.
        kwargs
            Other arguments are passed to :func:`matplotlib.pyplot.imshow`.

        Returns
        -------
        matplotlib.image.AxesImage
        """
        img = self.ax.imshow(map_data, vmin=vmin, vmax=vmax,
                             interpolation='None', cmap=cmap, **kwargs)
        self.draw()
        self.img_layers.append(img)
        return img

    def add_colour_bar(self, label, layer=0, **kwargs):
        """Add a colour bar to plot.

        Parameters
        ----------
        label : str
            Label for the colour bar.
        layer : int
            Layer ID.
        kwargs
            Other arguments are passed to :func:`matplotlib.pyplot.colorbar`.
        """
        img = self.img_layers[layer]
        self.colour_bar = plt.colorbar(img, ax=self.ax, label=label, **kwargs)

    def add_scale_bar(self, scale=None):
        """Add scale bar to plot.

        Parameters
        ----------
        scale : float
            Size of a pixel in microns. Defaults to the calling map's scale.
        """
        if scale is None:
            scale = self.calling_map.scale
        # ScaleBar expects the pixel size in metres
        self.ax.add_artist(ScaleBar(scale * 1e-6))

    def add_grain_boundaries(self, kind="pixel", boundaries=None,
                             colour=None, dilate=False, draw=True, **kwargs):
        """Add grain boundaries to the plot.

        Parameters
        ----------
        kind : str, {"pixel", "line"}
            Type of boundaries to plot, either a boundary image or a
            collection of line segments. Any value other than "line" plots
            the boundary image.
        boundaries : various, defdap.ebsd.BoundarySet, optional
            Boundaries to plot. If not provided the boundaries are loaded
            from the calling map. The "pixel" kind uses its ``image`` and
            the "line" kind its ``lines``.
        colour : various
            One of:
              - Colour of all boundaries as a string
              - Colour of all boundaries as RGBA tuple (only line kind)
              - List of values to represent colour of each line relative to
                a `norm` and `cmap` (only line kind)
        dilate : bool
            If true, dilate the grain boundaries (only pixel kind).
        draw : bool
            If true, redraw the figure afterwards.
        kwargs
            If line kind then other arguments are passed to
            :func:`matplotlib.collections.LineCollection`.

        Returns
        -------
        Various :
            matplotlib.image.AxesImage if type is pixel, otherwise
            matplotlib.collections.LineCollection

        Raises
        ------
        ValueError
            If, for the line kind, `colour` is neither a single colour nor
            one value per line.
        """
        if colour is None:
            colour = "white"

        if boundaries is None:
            boundaries = self.calling_map.data.grain_boundaries

        if kind == "line":
            if isinstance(colour, str):
                colour = mpl.colors.to_rgba(colour)

            if len(colour) == len(boundaries.lines):
                # One value per segment, mapped through norm/cmap
                colour_array = colour
                colour_lc = None
            elif len(colour) == 4:
                # A single RGBA colour for all segments
                colour_array = None
                colour_lc = colour
            else:
                raise ValueError('Issue with passed colour')

            lc = LineCollection(boundaries.lines, colors=colour_lc, **kwargs)
            lc.set_array(colour_array)
            img = self.ax.add_collection(lc)

        else:
            boundaries_image = boundaries.image.astype(int)

            if dilate:
                boundaries_image = mph.binary_dilation(boundaries_image)

            # create colourmap for boundaries going from transparent to
            # opaque of the given colour
            boundaries_cmap = mpl.colors.LinearSegmentedColormap.from_list(
                'my_cmap', ['white', colour], 256
            )
            boundaries_cmap._init()
            boundaries_cmap._lut[:, -1] = np.linspace(
                0, 1, boundaries_cmap.N + 3
            )

            img = self.ax.imshow(boundaries_image, cmap=boundaries_cmap,
                                 interpolation='None', vmin=0, vmax=1)

        if draw:
            self.draw()
        self.img_layers.append(img)
        return img

    def add_grain_highlights(self, grain_ids, grain_colours=None, alpha=None,
                             new_layer=False):
        """Highlight grains in the plot.

        Parameters
        ----------
        grain_ids : list
            List of grain IDs to highlight.
        grain_colours : str or list, optional
            Colour(s) to use for grain highlights. The i-th grain uses the
            i-th colour; grains beyond the number of colours use the last
            one. The passed list is not modified.
        alpha : float
            Alpha (transparency) to use for grain highlight.
        new_layer : bool
            If true, make a new layer in img_layers.

        Returns
        -------
        matplotlib.image.AxesImage
        """
        if grain_colours is None:
            grain_colours = ['white']
        elif isinstance(grain_colours, str):
            grain_colours = [grain_colours]
        if alpha is None:
            alpha = self.calling_map.highlight_alpha

        n_colours = len(grain_colours)
        outline = np.zeros(self.calling_map.shape, dtype=int)
        for i, grain_id in enumerate(grain_ids, start=1):
            # Value i selects colour i; extra grains reuse the last colour
            colour_idx = min(i, n_colours)

            # outline of highlighted grain
            grain = self.calling_map.grains[grain_id]
            grain_outline = grain.grain_outline(bg=0, fg=colour_idx)
            x0, y0, xmax, ymax = grain.extreme_coords

            # add to highlight image
            outline[y0:ymax + 1, x0:xmax + 1] += grain_outline

        # Custom colour map where 0 is transparent white for bg and
        # then a patch for each grain colour. A new list is built so the
        # caller's colour list is left untouched.
        highlights_cmap = mpl.colors.ListedColormap(['white', *grain_colours])
        highlights_cmap._init()
        alpha_map = np.full(highlights_cmap.N + 3, alpha)
        alpha_map[0] = 0
        highlights_cmap._lut[:, -1] = alpha_map

        # Fix the colour limits so that value i always maps onto colour i,
        # also when fewer grains than colours are highlighted
        clim_max = n_colours

        if self.highlights_layer_id is None or new_layer:
            img = self.ax.imshow(outline, interpolation='none',
                                 cmap=highlights_cmap, vmin=0, vmax=clim_max)
            if self.highlights_layer_id is None:
                self.highlights_layer_id = len(self.img_layers)
            self.img_layers.append(img)
        else:
            img = self.img_layers[self.highlights_layer_id]
            img.set_data(outline)
            img.set_cmap(highlights_cmap)
            img.set_clim(0, clim_max)

        self.draw()

        return img

    def add_grain_numbers(self, fontsize=10, **kwargs):
        """Add grain numbers to a map.

        Parameters
        ----------
        fontsize : float
            Font size.
        kwargs
            Pass other arguments to :func:`matplotlib.pyplot.text`.
        """
        for grain_id, grain in enumerate(self.calling_map):
            x_centre, y_centre = grain.centre_coords(centre_type="com",
                                                     grain_coords=False)

            self.ax.text(x_centre, y_centre, grain_id,
                         fontsize=fontsize, **kwargs)
        self.draw()

    def add_legend(self, values, labels, layer=0, **kwargs):
        """Add a legend to a map.

        Parameters
        ----------
        values : list
            Values to find colour patched for.
        labels : list
            Labels to assign to values.
        layer : int
            Image layer to generate legend for.
        kwargs
            Pass other arguments to :func:`matplotlib.pyplot.legend`.
        """
        # Find colour values for given values
        img = self.img_layers[layer]
        colors = [img.cmap(img.norm(value)) for value in values]

        # Get colour patches for each phase and make legend
        patches = [
            mpl.patches.Patch(color=color, label=label)
            for color, label in zip(colors, labels)
        ]

        self.ax.legend(handles=patches, **kwargs)

    def add_points(self, x, y, update_layer=None, **kwargs):
        """Add points to plot.

        Parameters
        ----------
        x : float or list of float
            x coordinates
        y : float or list of float
            y coordinates
        update_layer : int, optional
            Index of a previously added points layer to update in place.
            If None, a new points layer is created.
        kwargs
            Other arguments passed to :func:`matplotlib.pyplot.scatter`.

        Returns
        -------
        matplotlib.collections.PathCollection
        """
        x, y = np.array(x), np.array(y)
        if len(self.points_layer_ids) == 0 or update_layer is None:
            points = self.ax.scatter(x, y, **kwargs)
            self.points_layer_ids.append(len(self.img_layers))
            self.img_layers.append(points)
        else:
            points = self.img_layers[self.points_layer_ids[update_layer]]
            # ravel so scalar coordinates are accepted as well as sequences
            points.set_offsets(np.column_stack((x.ravel(), y.ravel())))

        self.draw()

        return points

    @classmethod
    def create(
        cls, calling_map, map_data,
        fig=None, fig_params=None, ax=None, ax_params=None,
        plot=None, make_interactive=False,
        plot_colour_bar=False, vmin=None, vmax=None, cmap=None, clabel="",
        plot_gbs=False, dilate_boundaries=False, boundary_colour=None,
        plot_scale_bar=False, scale=None,
        highlight_grains=None, highlight_colours=None, highlight_alpha=None,
        **kwargs
    ):
        """Create a plot for a map.

        Parameters
        ----------
        calling_map : base.Map
            DIC or EBSD map which called this plot.
        map_data : numpy.ndarray
            Data to be plotted. If None, no map layer is added.
        fig : matplotlib.figure.Figure
            Matplotlib figure to plot on.
        fig_params : dict, optional
            Passed to defdap.plotting.Plot.
        ax : matplotlib.axes.Axes
            Matplotlib axis to plot on.
        ax_params : dict, optional
            Passed to defdap.plotting.Plot as ax_params.
        plot : defdap.plotting.Plot
            Existing plot to add to. If none, a new plot is created.
        make_interactive :
            If true, make plot interactive
        plot_colour_bar : bool
            If true, plot a colour bar next to the map.
        vmin : float, optional
            Minimum value for the colour scale.
        vmax : float, optional
            Maximum value for the colour scale.
        cmap : str
            Colour map.
        clabel : str
            Label for the colour bar.
        plot_gbs : bool or str
            If true, plot the grain boundaries on the map. Passed as `kind`
            to :func:`add_grain_boundaries`, so "line" draws line segments.
        dilate_boundaries : bool
            If true, dilate the grain boundaries.
        boundary_colour : str
            Colour to use for the grain boundaries.
        plot_scale_bar : bool
            If true, plot a scale bar in the map.
        scale : float
            Size of pixel in microns.
        highlight_grains : list(int)
            List of grain IDs to highlight.
        highlight_colours : str or list(str)
            Colour(s) to highlight grains.
        highlight_alpha : float
            Alpha (transparency) by which to highlight grains.
        kwargs :
            All other arguments passed to
            :func:`defdap.plotting.MapPlot.add_map`

        Returns
        -------
        defdap.plotting.MapPlot
        """
        if plot is None:
            plot = cls(calling_map, fig=fig, ax=ax,
                       ax_params={} if ax_params is None else ax_params,
                       make_interactive=make_interactive,
                       **({} if fig_params is None else fig_params))

        if map_data is not None:
            plot.add_map(map_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)

        if plot_colour_bar:
            plot.add_colour_bar(clabel)

        if plot_gbs:
            plot.add_grain_boundaries(
                colour=boundary_colour, dilate=dilate_boundaries,
                kind=plot_gbs
            )

        if highlight_grains is not None:
            plot.add_grain_highlights(
                highlight_grains,
                grain_colours=highlight_colours, alpha=highlight_alpha
            )

        if plot_scale_bar:
            plot.add_scale_bar(scale=scale)

        return plot
class GrainPlot(Plot):
    """ Class for creating a map for a grain.

    Parameters
    ----------
    calling_grain : base.Grain
        DIC or EBSD grain which called this plot.
    fig : matplotlib.figure.Figure, optional
        Matplotlib figure to plot on.
    ax : matplotlib.axes.Axes, optional
        Matplotlib axis to plot on.
    ax_params : dict, optional
        Passed to defdap.plotting.Plot as ax_params.
    make_interactive : bool, optional
        If true, make interactive.
    kwargs
        Other arguments passed to :class:`defdap.plotting.Plot`.
    """

    def __init__(self, calling_grain, fig=None, ax=None, ax_params=None,
                 make_interactive=False, **kwargs):
        super(GrainPlot, self).__init__(
            ax, ax_params={} if ax_params is None else ax_params, fig=fig,
            make_interactive=make_interactive, **kwargs
        )
        self.calling_grain = calling_grain
        self.set_empty_state()

    def set_empty_state(self):
        """Reset the layer bookkeeping and hide the axis ticks."""
        self.img_layers = []

        self.ax.set_xticks([])
        self.ax.set_yticks([])

    def add_map(self, map_data, vmin=None, vmax=None, cmap='viridis',
                **kwargs):
        """Add a map to a grain plot.

        Parameters
        ----------
        map_data : numpy.ndarray
            Grain data to plot
        vmin : float
            Minimum value for the colour scale.
        vmax : float
            Maximum value for the colour scale.
        cmap
            Colour map to use.
        kwargs
            Other arguments are passed to :func:`matplotlib.pyplot.imshow`.

        Returns
        -------
        matplotlib.image.AxesImage
        """
        img = self.ax.imshow(map_data, vmin=vmin, vmax=vmax,
                             interpolation='None', cmap=cmap, **kwargs)
        self.draw()
        self.img_layers.append(img)
        return img

    # Backwards compatible camelCase name for add_map
    addMap = add_map

    def add_colour_bar(self, label, layer=0, **kwargs):
        """Add colour bar to grain plot.

        Parameters
        ----------
        label : str
            Label to add to colour bar.
        layer : int
            Layer on which to add colourbar.
        kwargs
            Other arguments passed to :func:`matplotlib.pyplot.colorbar`.
        """
        img = self.img_layers[layer]
        self.colour_bar = plt.colorbar(img, ax=self.ax, label=label, **kwargs)

    def add_scale_bar(self, scale=None):
        """Add scale bar to grain plot.

        Parameters
        ----------
        scale : float
            Size of pixel in micron. Defaults to the scale of the grain's
            owner map.
        """
        if scale is None:
            scale = self.calling_grain.owner_map.scale
        # ScaleBar expects the pixel size in metres
        self.ax.add_artist(ScaleBar(scale * 1e-6))

    def add_traces(self, angles, colours, top_only=False, pos=None, **kwargs):
        """Add slip trace angles to grain plot. Illustrated by lines
        crossing through central pivot point to create a circle.

        Parameters
        ----------
        angles : float or list
            Angles of slip traces.
        colours : list
            Colours to plot. Traces beyond the number of colours use the
            last colour.
        top_only : bool, optional
            If true, plot only a semicircle instead of a circle.
        pos : tuple
            Position of slip traces. Defaults to the grain centre.
        kwargs
            Other arguments are passed to :func:`matplotlib.pyplot.quiver`
        """
        if pos is None:
            pos = self.calling_grain.centre_coords()
        angles = np.atleast_1d(np.asarray(angles, dtype=float))
        # Row 0 holds x and row 1 holds y components of each trace
        traces = np.array((-np.sin(angles), np.cos(angles)))

        # When plotting top half only, move all 'traces' to +ve y
        # and set the pivot to be in the tail instead of centre
        if top_only:
            pivot = 'tail'
            # A trace and its negation describe the same line
            flip = traces[1] < 0
            traces[:, flip] *= -1
            self.ax.set_ylim(pos[1] - 0.001, pos[1] + 0.1)
            self.ax.set_xlim(pos[0] - 0.1, pos[0] + 0.1)
        else:
            pivot = 'middle'

        last_colour = len(colours) - 1
        for i, trace in enumerate(traces.T):
            colour = colours[min(i, last_colour)]
            self.ax.quiver(
                pos[0], pos[1],
                trace[0], trace[1],
                scale=1, pivot=pivot,
                color=colour, headwidth=1,
                headlength=0, **kwargs
            )
        self.draw()

    def add_slip_traces(self, colours=None, **kwargs):
        """Add slip traces to plot, based on the calling grain's slip systems.

        Parameters
        ----------
        colours : list
            Colours of each trace. Defaults to the phase's slip trace colours.
        kwargs
            Other arguments (e.g. `top_only`) are passed to
            :func:`defdap.plotting.GrainPlot.add_traces`
        """
        if colours is None:
            colours = self.calling_grain.phase.slip_trace_colours
        slip_trace_angles = self.calling_grain.slip_traces

        self.add_traces(slip_trace_angles, colours, **kwargs)

    def add_slip_bands(self, grain_map_data, colour=None, thres=None,
                       min_dist=None, **kwargs):
        """Add lines representing slip bands detected by Radon transform
        in :func:`~defdap.hrdic.grain.calc_slip_bands`.

        Parameters
        ----------
        grain_map_data :
            Map data to pass to :func:`~defdap.hrdic.Grain.calc_slip_bands`.
        colour : str
            Colour of traces. Defaults to black.
        thres : float
            Threshold to use in :func:`~defdap.hrdic.Grain.calc_slip_bands`.
        min_dist :
            Minimum angle between bands in
            :func:`~defdap.hrdic.Grain.calc_slip_bands`.
        kwargs
            Other arguments are passed to
            :func:`defdap.plotting.GrainPlot.add_traces`.
        """
        if colour is None:
            colour = "black"
        slip_band_angles = self.calling_grain.calc_slip_bands(
            grain_map_data, thres=thres, min_dist=min_dist
        )

        self.add_traces(slip_band_angles, [colour], **kwargs)

    @classmethod
    def create(
        cls, calling_grain, map_data,
        fig=None, fig_params=None, ax=None, ax_params=None,
        plot=None, make_interactive=False,
        plot_colour_bar=False, vmin=None, vmax=None, cmap=None, clabel="",
        plot_scale_bar=False, scale=None,
        plot_slip_traces=False, plot_slip_bands=False, **kwargs
    ):
        """Create grain plot.

        Parameters
        ----------
        calling_grain : base.Grain
            DIC or EBSD grain which called this plot.
        map_data :
            Data to be plotted.
        fig : matplotlib.figure.Figure
            Matplotlib figure to plot on.
        fig_params : dict, optional
            Passed to defdap.plotting.Plot.
        ax : matplotlib.axes.Axes
            Matplotlib axis to plot on.
        ax_params : dict, optional
            Passed to defdap.plotting.Plot as ax_params.
        plot : defdap.plotting.Plot
            Existing plot to add to. If none, a new plot is created.
        make_interactive :
            If true, make plot interactive
        plot_colour_bar : bool
            If true, plot a colour bar next to the map.
        vmin : float
            Minimum value for the colour scale.
        vmax : float
            Maximum value for the colour scale.
        cmap :
            Colour map.
        clabel : str
            Label for the colour bar.
        plot_scale_bar : bool
            If true, plot a scale bar in the map.
        scale : float
            Size of pixel in microns.
        plot_slip_traces : bool
            If true, plot slip traces with
            :func:`~defdap.plotting.GrainPlot.add_slip_traces`
        plot_slip_bands : bool
            If true, plot slip bands with
            :func:`~defdap.plotting.GrainPlot.add_slip_bands`
        kwargs :
            All other arguments passed to
            :func:`defdap.plotting.GrainPlot.add_map`

        Returns
        -------
        defdap.plotting.GrainPlot
        """
        if plot is None:
            plot = cls(calling_grain, fig=fig, ax=ax,
                       ax_params={} if ax_params is None else ax_params,
                       make_interactive=make_interactive,
                       **({} if fig_params is None else fig_params))
        plot.add_map(map_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)

        if plot_colour_bar:
            plot.add_colour_bar(clabel)

        if plot_scale_bar:
            plot.add_scale_bar(scale=scale)

        if plot_slip_traces:
            plot.add_slip_traces()

        if plot_slip_bands:
            plot.add_slip_bands(map_data)

        return plot
[docs]class PolePlot(Plot): """ Class for creating an inverse pole figure plot. """ def __init__(self, plot_type, crystal_sym, projection=None, fig=None, ax=None, ax_params={}, make_interactive=False, **kwargs): super(PolePlot, self).__init__( ax, ax_params=ax_params, fig=fig, make_interactive=make_interactive, **kwargs) self.plot_type = plot_type self.crystal_sym = crystal_sym self.projection = self._validateProjection(projection) self.img_layers = [] self.add_axis()
[docs] def add_axis(self): """Draw axes on the IPF based on crystal symmetry. Raises ------- NotImplementedError If a crystal type other than 'cubic' or 'hexagonal' are selected. """ if self.plot_type == "IPF" and self.crystal_sym == "cubic": lines = [ ((0, 0, 1), (0, 1, 1)), ((0, 0, 1), (-1, 1, 1)), ((0, 1, 1), (-1, 1, 1)), ] labels = [ ((0, 0, 1), -0.005, 'top'), ((0, 1, 1), -0.005, 'top'), ((-1, 1, 1), 0.005, 'bottom'), ] elif self.plot_type == "IPF" and self.crystal_sym == "hexagonal": if defaults['ipf_triangle_convention'] == 'down': lines = [ ((0, 0, 0, 1), (-1, 2, -1, 0)), ((0, 0, 0, 1), (0, 1, -1, 0)), ((-1, 2, -1, 0), (0, 1, -1, 0)), ] labels = [ ((0, 0, 0, 1), 0.012, 'bottom'), ((-1, 2, -1, 0), 0.012, 'bottom'), ((0, 1, -1, 0), -0.012, 'top'), ] else: lines = [ ((0, 0, 0, 1), (-1, 2, -1, 0)), ((0, 0, 0, 1), (-1, 1, 0, 0)), ((-1, 2, -1, 0), (-1, 1, 0, 0)), ] labels = [ ((0, 0, 0, 1), -0.012, 'top'), ((-1, 2, -1, 0), -0.012, 'top'), ((-1, 1, 0, 0), 0.012, 'bottom'), ] else: raise NotImplementedError("Only works for cubic and hexagonal.") for line in lines: self.add_line(*line, c='k', lw=2) for label in labels: self.label_point( label[0], pad_y=label[1], va=label[2], ha='center', fontsize=12 ) self.ax.axis('equal') self.ax.axis('off')
[docs] def add_line(self, start_point, end_point, plot_syms=False, res=100, **kwargs): """Draw lines on the IPF plot. Parameters ---------- start_point : tuple Start point in crystal coordinates (i.e. [0,0,1]). end_point : tuple End point in crystal coordinates, (i.e. [1,0,0]). plot_syms : bool, optional If true, plot all symmetrically equivelant points. res : int Number of points within each line to plot. kwargs All other arguments are passed to :func:`matplotlib.pyplot.plot`. """ if self.crystal_sym == 'hexagonal': start_point = project_to_orth(0.8165, dir=start_point, in_type='mb') end_point = project_to_orth(0.8165, dir=end_point, in_type='mb') lines = [(start_point, end_point)] if plot_syms: for symm in quat.Quat.sym_eqv(self.crystal_sym)[1:]: start_point_symm = symm.transform_vector(start_point) end_point_symm = symm.transform_vector(end_point) if start_point_symm[2] < 0: start_point_symm *= -1 if end_point_symm[2] < 0: end_point_symm *= -1 lines.append((start_point_symm, end_point_symm)) line_points = np.zeros((3, res), dtype=float) for line in lines: for i in range(3): if line[0][i] == line[1][i]: line_points[i] = np.full(res, line[0][i]) else: line_points[i] = np.linspace(line[0][i], line[1][i], res) xp, yp = self.projection(line_points[0], line_points[1], line_points[2]) self.ax.plot(xp, yp, **kwargs)
[docs] def label_point(self, point, label=None, plot_syms=False, pad_x=0, pad_y=0, **kwargs): """Place a label near a coordinate in the pole plot. Parameters ---------- point : tuple (x, y) coordinate to place text. label : str, optional Text to use in label. pad_x : int, optional Pad added to x coordinate. pad_y : int, optional Pad added to y coordinate. kwargs Other arguments are passed to :func:`matplotlib.axes.Axes.text`. """ labels = [idc_to_string(point, str_type='tex')] if label is None else [label] point_idc = point if self.crystal_sym == 'hexagonal': point = project_to_orth(0.8165, dir=point, in_type='mb') points = [point] if plot_syms: for symm in quat.Quat.sym_eqv(self.crystal_sym)[1:]: point_symm = symm.transform_vector(point) if point_symm[2] < 0: point_symm *= -1 points.append(point_symm) if label is None: labels = map( partial(idc_to_string, str_type='tex'), equavlent_indicies( self.crystal_sym, quat.Quat.sym_eqv(self.crystal_sym), dir=point_idc, c_over_a=0.8165 ) ) else: labels *= len(quat.Quat.sym_eqv(self.crystal_sym)) for point, label in zip(points, labels): xp, yp = self.projection(*point) self.ax.text(xp + pad_x, yp + pad_y, label, **kwargs)
[docs] def add_points(self, alpha_ang, beta_ang, marker_colour=None, marker_size=None, **kwargs): """Add a point to the pole plot. Parameters ---------- alpha_ang Inclination angle to plot. beta_ang Azimuthal angle (around z axis from x in anticlockwise as per ISO) to plot. marker_colour : str or list(str), optional Colour of marker. If two specified, then the point will have two semicircles of different colour. marker_size : float Size of marker. kwargs Other arguments are passed to :func:`matplotlib.axes.Axes.scatter`. Raises ------- Exception If more than two colours are specified """ # project onto equatorial plane xp, yp = self.projection(alpha_ang, beta_ang) # plot poles # plot markers with 'half-and-half' colour if type(marker_colour) is str: marker_colour = [marker_colour] if marker_colour is None: points = self.ax.scatter(xp, yp, **kwargs) self.img_layers.append(points) elif len(marker_colour) == 2: pos = (xp, yp) r1 = 0.5 r2 = r1 + 0.5 marker_size = np.sqrt(marker_size) x = [0] + np.cos(np.linspace(0, 2 * np.pi * r1, 10)).tolist() y = [0] + np.sin(np.linspace(0, 2 * np.pi * r1, 10)).tolist() xy1 = list(zip(x, y)) x = [0] + np.cos(np.linspace(2 * np.pi * r1, 2 * np.pi * r2, 10)).tolist() y = [0] + np.sin(np.linspace(2 * np.pi * r1, 2 * np.pi * r2, 10)).tolist() xy2 = list(zip(x, y)) points = self.ax.scatter( pos[0], pos[1], marker=(xy1, 0), s=marker_size, c=marker_colour[0], **kwargs ) self.img_layers.append(points) points = self.ax.scatter( pos[0], pos[1], marker=(xy2, 0), s=marker_size, c=marker_colour[1], **kwargs ) self.img_layers.append(points) else: raise Exception("specify one colour for solid markers or list two for 'half and half'")
[docs] def add_colour_bar(self, label, layer=0, **kwargs): """Add a colour bar to the pole plot. Parameters ---------- label : str Label to place next to colour bar. layer : int Layer number to add the colour bar to. kwargs Other argument are passed to :func:`matplotlib.pyplot.colorbar`. """ img = self.img_layers[layer] self.colour_bar = plt.colorbar(img, ax=self.ax, label=label, **kwargs)
[docs] def add_legend( self, label='Grain area (μm$^2$)', number=6, layer=0, scaling=1, **kwargs ): """Add a marker size legend to the pole plot. Parameters ---------- label : str Label to place next to legend. number : Number of markers to plot in legend. layer : int Layer number to add the colour bar to. scaling : float Scaling applied to the data. kwargs Other argument are passed to :func:`matplotlib.pyplot.legend`. """ img = self.img_layers[layer] self.legend = plt.legend( *img.legend_elements("sizes", num=number, func=lambda s: s / scaling), title=label, **kwargs )
@staticmethod
def _validateProjection(projection_in, validate_default=False):
    """Turn a projection specifier into a projection function.

    Parameters
    ----------
    projection_in : str, callable or None
        Projection to use. None selects the fallback projection. A string
        is matched, with spaces removed and case ignored, against
        'lambert'/'equalarea' and 'stereographic'/'stereo'/'equalangle'.
        Any other callable is returned unchanged.
    validate_default : bool, optional
        If True there is no fallback; this mode is used when resolving the
        configured default itself. If False, the fallback is the projection
        given by ``defaults['pole_projection']``, which is resolved first.

    Returns
    -------
    callable
        The resolved projection function.

    Raises
    ------
    ValueError
        If no projection could be resolved. This covers an unusable input
        with no valid fallback, and an invalid ``defaults['pole_projection']``.
    """
    # The fallback is resolved before looking at the input, so an invalid
    # configured default raises even when `projection_in` itself is valid.
    if validate_default:
        fallback = None
    else:
        fallback = PolePlot._validateProjection(
            defaults['pole_projection'], validate_default=True
        )

    # Accepted names (spaces removed, lower case) -> PolePlot attribute.
    named_projections = {
        "lambert": "lambert_project",
        "equalarea": "lambert_project",
        "stereographic": "stereo_project",
        "stereo": "stereo_project",
        "equalangle": "stereo_project",
    }

    if projection_in is None:
        resolved = fallback
    elif type(projection_in) is str:
        attr_name = named_projections.get(projection_in.replace(" ", "").lower())
        if attr_name is None:
            print("Unknown projection name, using default")
            resolved = fallback
        else:
            resolved = getattr(PolePlot, attr_name)
    elif callable(projection_in):
        resolved = projection_in
    else:
        print("Unknown projection, using default")
        resolved = fallback

    if resolved is None:
        raise ValueError("Problem with default projection.")
    return resolved
@staticmethod
def stereo_project(*args):
    """Project a pole onto the equatorial plane with a stereographic projection.

    Parameters
    ----------
    args : numpy.ndarray, len 2 or 3
        Either the two polar angles (inclination, azimuth) or the three
        components of the pole direction.

    Returns
    -------
    float, float
        x coordinate, y coordinate

    Raises
    -------
    Exception
        If neither 2 nor 3 arguments are given.
    """
    n_args = len(args)
    if n_args not in (2, 3):
        raise Exception("3 arguments for pole directions and 2 for polar angles.")

    if n_args == 3:
        alpha, beta = quat.Quat.polar_angles(*args)
    else:
        alpha, beta = args

    # Radial distance from the centre, then rotate the azimuth by -90 deg.
    radius = np.tan(alpha / 2)
    azimuth = beta - np.pi/2
    return radius * np.cos(azimuth), radius * np.sin(azimuth)
@staticmethod
def lambert_project(*args):
    """Project a pole onto the equatorial plane with a Lambert (equal-area) projection.

    Parameters
    ----------
    args : numpy.ndarray, len 2 or 3
        Either the two polar angles (inclination, azimuth) or the three
        components of the pole direction.

    Returns
    -------
    float, float
        x coordinate, y coordinate

    Raises
    -------
    Exception
        If neither 2 nor 3 arguments are given.
    """
    n_args = len(args)
    if n_args not in (2, 3):
        raise Exception("3 arguments for pole directions and 2 for polar angles.")

    if n_args == 3:
        alpha, beta = quat.Quat.polar_angles(*args)
    else:
        alpha, beta = args

    # Radial distance from the centre, then rotate the azimuth by -90 deg.
    radius = np.sqrt(2 * (1 - np.cos(alpha)))
    azimuth = beta - np.pi/2
    return radius * np.cos(azimuth), radius * np.sin(azimuth)
class HistPlot(Plot):
    """ Class for creating a histogram.
    """

    def __init__(self, plot_type="scatter", axes_type="linear", density=True, fig=None,
                 ax=None, ax_params=None, make_interactive=False, **kwargs):
        """Initialise a histogram plot

        Parameters
        ----------
        plot_type: str, {'scatter', 'bar', 'barstacked', 'step', 'stepfilled'}
            Type of plot to use. 'scatter' draws the bin values at the bin
            centres with :meth:`matplotlib.axes.Axes.plot`. The other values
            are passed to :meth:`matplotlib.axes.Axes.hist` as ``histtype``.
        axes_type : str, {'linear', 'logx', 'logy', 'loglog'}, optional
            Axis scaling (case-insensitive). 'logx' and 'logy' make the x or
            y axis logarithmic; 'loglog' makes both logarithmic.
        density : bool
            If true, histogram is normalised such that the integral sums to 1.
        fig : matplotlib.figure.Figure
            Matplotlib figure to plot on.
        ax : matplotlib.axes.Axes
            Matplotlib axis to plot on.
        ax_params : dict, optional
            Passed to defdap.plotting.Plot as ax_params.
        make_interactive : bool
            If true, make the plot interactive.
        kwargs
            Other arguments are passed to :class:`defdap.plotting.Plot`

        Raises
        ------
        ValueError
            If `axes_type` or `plot_type` is not one of the accepted values.
        """
        # Validate before the base class creates a figure, so invalid
        # arguments do not leave an empty figure behind.
        valid_axes_types = ("linear", "logx", "logy", "loglog")
        if not isinstance(axes_type, str) or axes_type.lower() not in valid_axes_types:
            raise ValueError("axes_type must be one of linear, logx, logy or loglog.")
        self.axes_type = axes_type.lower()

        valid_plot_types = ('scatter', 'bar', 'barstacked', 'step', 'stepfilled')
        if plot_type not in valid_plot_types:
            raise ValueError(
                "plot_type must be scatter, bar, barstacked, step or stepfilled."
            )
        self.plot_type = plot_type

        self.density = bool(density)

        super().__init__(
            ax, ax_params={} if ax_params is None else ax_params, fig=fig,
            make_interactive=make_interactive, **kwargs
        )

        # set y-axis label
        self.ax.set_ylabel("Normalised frequency" if self.density else "Frequency")

        # Set log scales where requested, with plain-number tick labels
        # instead of scientific notation.
        if self.axes_type in ('logx', 'loglog'):
            self.ax.set_xscale("log")
            self.ax.xaxis.set_major_formatter(FuncFormatter(lambda y, _: '{:.5g}'.format(y)))
        if self.axes_type in ('logy', 'loglog'):
            self.ax.set_yscale("log")
            self.ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: '{:.5g}'.format(y)))

    def add_hist(self, hist_data, bins=100, range=None, line='o', label=None, **kwargs):
        """Add a histogram to the current plot

        Parameters
        ----------
        hist_data : array_like
            Data to be used in the histogram. It is flattened before binning.
        bins : int
            Number of bins to use for histogram.
        range : tuple or None, optional
            The lower and upper range of the bins. If None, the range of the
            finite data is used. On a logarithmic x axis only the positive
            data is used, because other values cannot be shown.
        line : str, optional
            Marker or line format used when `plot_type` is 'scatter'.
        label : str, optional
            Label to use for data (used for legend).
        kwargs
            Other arguments are passed to :meth:`matplotlib.axes.Axes.plot`
            for 'scatter' plots, or :meth:`matplotlib.axes.Axes.hist` otherwise.

        Raises
        ------
        ValueError
            If `range` is reversed or not finite (for a log x axis, if a
            range limit is not positive).
        """
        data = np.asarray(hist_data).ravel()
        # Only finite values may define an automatic bin range.
        finite_data = data[np.isfinite(data)]

        # histogram_bin_edges gives bins + 1 edges, i.e. exactly `bins` bins.
        if self.axes_type in ('logx', 'loglog'):
            # Bins are spaced evenly in log10 space.
            log_data = np.log10(finite_data[finite_data > 0])
            log_range = None if range is None else (np.log10(range[0]), np.log10(range[1]))
            bin_edges = 10.0 ** np.histogram_bin_edges(log_data, bins=bins, range=log_range)
        else:
            bin_edges = np.histogram_bin_edges(finite_data, bins=bins, range=range)

        if self.plot_type == 'scatter':
            # Generate the histogram data and plot it at the bin centres.
            counts, edges = np.histogram(data, bins=bin_edges, density=self.density)
            bin_centres = 0.5 * (edges[1:] + edges[:-1])
            self.ax.plot(bin_centres, counts, line, label=label, **kwargs)
        else:
            # Plot as a matplotlib histogram
            self.ax.hist(data, bins=bin_edges, histtype=self.plot_type,
                         density=self.density, label=label, **kwargs)

    def add_legend(self, **kwargs):
        """Add legend to histogram.

        Parameters
        ----------
        kwargs
            All arguments passed to :func:`matplotlib.axes.Axes.legend`.
        """
        self.ax.legend(**kwargs)

    @classmethod
    def create(
        cls,
        hist_data,
        fig=None,
        fig_params=None,
        ax=None,
        ax_params=None,
        plot=None,
        make_interactive=False,
        plot_type="scatter",
        axes_type="linear",
        density=True,
        bins=10,
        range=None,
        line='o',
        label=None,
        **kwargs
    ):
        """Create a histogram plot.

        Parameters
        ----------
        hist_data : array_like
            Data to be used in the histogram.
        fig : matplotlib.figure.Figure
            Matplotlib figure to plot on.
        fig_params : dict, optional
            Passed to defdap.plotting.Plot.
        ax : matplotlib.axes.Axes
            Matplotlib axis to plot on.
        ax_params : dict, optional
            Passed to defdap.plotting.Plot as ax_params.
        plot : defdap.plotting.HistPlot
            Plot where the histogram is created. If None, a new plot is created.
        make_interactive : bool, optional
            If true, make plot interactive.
        plot_type: str, {'scatter', 'bar', 'barstacked', 'step', 'stepfilled'}
            Type of plot to use.
        axes_type : str, {'linear', 'logx', 'logy', 'loglog'}, optional
            Axis scaling; see :class:`HistPlot`.
        density : bool
            If true, histogram is normalised such that the integral sums to 1.
        bins : int
            Number of bins to use for histogram.
        range : tuple or None, optional
            The lower and upper range of the bins. If None, it is taken from the data.
        line : str, optional
            Marker or line type to be used.
        label : str, optional
            Label to use for data (is used for legend).
        kwargs
            Other arguments are passed to :func:`defdap.plotting.HistPlot.add_hist`

        Returns
        -------
        defdap.plotting.HistPlot
        """
        if plot is None:
            plot = cls(
                axes_type=axes_type, plot_type=plot_type, density=density,
                fig=fig, ax=ax, ax_params={} if ax_params is None else ax_params,
                make_interactive=make_interactive,
                **({} if fig_params is None else fig_params)
            )
        plot.add_hist(hist_data, bins=bins, range=range, line=line, label=label, **kwargs)

        return plot
class CrystalPlot(Plot):
    """Plot on 3D axes, used for drawing unit cells.
    """

    def __init__(self, fig=None, ax=None, ax_params={}, make_interactive=False, **kwargs):
        """Set up a 3D plot.

        Parameters
        ----------
        fig : matplotlib.pyplot.Figure
            Figure to plot to.
        ax : matplotlib.pyplot.Axis
            Axis to plot to.
        ax_params
            Extra axes options. They are merged over the defaults
            ``projection='3d'`` and ``proj_type='ortho'`` and passed to
            :class:`defdap.plotting.Plot` as ax_params. The dict passed in
            is not modified.
        make_interactive : bool, optional
            If true, make plot interactive.
        kwargs
            Other arguments are passed to :class:`defdap.plotting.Plot`.
            ``figsize`` defaults to (6, 6).
        """
        # Figure options: any caller value replaces the default size.
        figure_options = {'figsize': (6, 6), **kwargs}

        # Axes options: orthographic 3D axes unless the caller says otherwise.
        axes_options = {'projection': '3d', 'proj_type': 'ortho'}
        axes_options.update(ax_params)

        super().__init__(
            ax,
            ax_params=axes_options,
            fig=fig,
            make_interactive=make_interactive,
            **figure_options
        )

    def add_verts(self, verts, **kwargs):
        """Draw planes, each defined by a set of vertices, on the 3D axes.

        Parameters
        ----------
        verts : list
            List of vertices, passed to ``Poly3DCollection``.
        kwargs
            Other arguments are passed to
            :class:`matplotlib.collections.PolyCollection`. They override the
            defaults: alpha 0.6, facecolor '0.8', linewidths 3, edgecolor 'k'.
        """
        style = {
            'alpha': 0.6,
            'facecolor': '0.8',
            'linewidths': 3,
            'edgecolor': 'k',
            **kwargs,
        }
        self.ax.add_collection3d(Poly3DCollection(verts, **style))