import os
import sys
from typing import Type
sys.path.append('.')
sys.path.append('..')
sys.path.append('../..')
import warnings
import copy
import matplotlib.cm as cm
import matplotlib.gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Slider
from MJOLNIR import _tools
from MJOLNIR._interactiveSettings import Viewer3DSettings, States, cut1DHolder, cancel
import functools
# Major and minor version numbers of the running interpreter, taken from
# sys.version_info; Viewer3D.__init__ reads them before adding the colour sliders.
pythonVersion = sys.version_info[0]
pythonSubVersion = sys.version_info[1]
# Interactive viewer for a binned 3D data volume; all set-up happens in __init__ below.
[docs]class Viewer3D(object):
[docs] @_tools.KwargChecker(include=[_tools.MPLKwargs])
def __init__(self,Data,bins,axis=2, log=False ,ax = None, grid = False, adjustable=True, outputFunction=print,
cmap=None, CurratAxeBraggList=None, plotCurratAxe=False,Ei=None,EfLimits=None, dataset = None, cut1DFunctionRectangle=None,\
cut1DFunctionCircle = None, cut1DFunctionRectanglePerp=None,cut1DFunctionRectangleHorizontal=None,cut1DFunctionRectangleVertical=None,
backgroundSubtraction=False, **kwargs):#pragma: no cover
"""3 dimensional viewing object generating an interactive Matplotlib figure.
Stores the data and bin coordinates, builds the main axis, the slice slider and the colour bar, and connects the key, scroll and click handlers that allow the user to change between slicing modes and to scroll through the data interactively.
Args:
- Data (3D array or sequence of arrays): Intensity array in three dimensions. Assumed to have Qx, Qy, and E along the first, second, and third directions respectively. If len(Data) is 4 the intensity is computed as Data[0]*Data[3]/(Data[1]*Data[2]); if len(Data) is 2 it is computed as Data[0]/Data[1]; otherwise Data is used as is.
- bins (List of 1D arrays): Coordinates of the three directions as returned by the BinData3D functionality of DataSet.
Kwargs:
- axis (int): Axis along which the interactive plot slices the data (default 2).
- log (bool): If true, the log 10 of the intensity (plus 1e-20) is plotted (default False).
- ax (matplotlib axis or list of 3 axes): Axis into which the data is plotted. A single Axes is used as the RLU axis; three axes are taken in the order QxE, QyE, QxQy. If None a new figure is created (default None).
- grid (bool/int): If int, grid will be plotted with zorder=int, if True, grid is plotted at zorder=-10 (default False).
- adjustable (bool): If true, two vertical sliders for the colour limits are added, provided the interpreter version check passes (default True).
- outputFunction (function): Function called on output string (default print).
- cmap (str or colormap): Colour map passed to pcolormesh; None selects 'viridis' (default None).
- CurratAxeBraggList (list): List of Bragg peaks for which Currat-Axe spurions are to be calculated, [[H1,K1,L1],[H2,K2,L2],..] (default None).
- plotCurratAxe (bool): Flag to determine whether or not to plot the Currat-Axe spurions (default False).
- Ei (list): List of incoming energies (default None).
- EfLimits (float,float): Minimal and maximal Ef for current instrument (default None).
- dataset (DataSet): Reference to the data set used - needed for interactive cutting (default None).
- cut1DFunctionRectangle (function): Function to be called when performing an interactive rectangle cut; called as f(viewer=self, dr=dr) (default None).
- cut1DFunctionCircle (function): Function to be called when performing an interactive circle cut; called as f(viewer=self, dr=dr) (default None).
- cut1DFunctionRectanglePerp (function): Attached to the QxE and QyE axes as cut1DFunctionRectanglePerpendicular; called as f(viewer=self, dr=dr) (default None).
- cut1DFunctionRectangleHorizontal (function): Attached to the QxE and QyE axes under the same name; called as f(viewer=self, dr=dr) (default None).
- cut1DFunctionRectangleVertical (function): Attached to the QxE and QyE axes under the same name; called as f(viewer=self, dr=dr) (default None).
- backgroundSubtraction (bool): Signify if a subtraction has been performed such that interactive cuts work; stored on every axis in self._axes (default False).
- kwargs: Not read in the body of this method; presumably validated by the KwargChecker decorator - TODO confirm.
Raises:
- AttributeError: If ax is neither None, a single Axes, nor a sequence of three axes, or if the shading mode is not understood.
For an example, see the `quick plotting tutorial <../Tutorials/Quick/QuickView3D.html>`_ under scripting tutorials.
"""
if cmap is None:
cmap = 'viridis'
self.Ei = Ei
self.EfLimits = EfLimits
self.outputFunction = outputFunction
self.currentData = None
self.ds = dataset
self.plotCurratAxe = plotCurratAxe # Set to false but will change to True when correct list is created
self._CurratAxeBraggList = None
# Reduce the input to a single intensity array in self.Data.
# NOTE(review): the branch is selected by len(Data) only, so a plain 3D array whose first
# dimension has length 4 or 2 would also take one of the first two branches - confirm.
if len(Data)==4: # If data is provided as I, norm, mon, normcount
with warnings.catch_warnings() as w:
warnings.simplefilter("ignore")
self.Data = np.divide(Data[0]*Data[3],Data[1]*Data[2])
# NOTE(review): the unpacking order below names Data[1] 'Monitor' and Data[2] 'Normalization',
# which differs from the order given in the comment on the branch above - confirm which is right.
self.Counts,self.Monitor,self.Normalization,self.NormCounts = Data
self.allData = True
elif len(Data) == 2: # From bin3D with no normalization or monitor
with warnings.catch_warnings() as w:
warnings.simplefilter("ignore")
self.Data = np.divide(Data[0],Data[1])
self.allData = False
else:
self.Data = Data
self.allData = False
if log:
self.Data = np.log10(self.Data+1e-20)
self.bins = bins
# NOTE(review): the limits are computed from the raw Data argument, not from self.Data - confirm intent.
self.dataLimits = [np.nanmin(Data),np.nanmax(Data)]
gs = matplotlib.gridspec.GridSpec(1, 2, width_ratios=[4, 1])
# Translate the grid argument into a flag (self.grid) and a zorder (self.gridZOrder).
# NOTE(review): a grid value that is neither bool nor int/float sets neither attribute.
if not grid == False:
gridArg = grid
if isinstance(gridArg,(bool)):
self.grid = True
self.gridZOrder=-10
elif isinstance(gridArg,(int,float)):
self.grid = True
self.gridZOrder=gridArg
else:
self.grid = False
self.gridZOrder = 0
# Case 1: no axis given - create a figure with a plain Qx/Qy/E axis (self.rlu = False).
if ax is None: # TODO: REDO with actual correct axes!
self.figure = plt.figure()
self.ax = plt.subplot(gs[0])#self.figure.add_subplot(111)
self.ax.isActive = lambda : True
self.ax.drawState = States.INACTIVE
self.xlabel = r'Qx [$A^{-1}$]'
self.ylabel = r'Qy [$A^{-1}$]'
self.zlabel = r'E [meV]'
self.rlu = False
self._axes = [self.ax]
else:
# Case 2: a single Axes is given - it is used for the QxQy view (index 2 of self._axes),
# while a newly created axis serves both of the other two views.
if isinstance(ax,plt.Axes): # Assuming only RLU - energy plot is provided
self.axRLU = ax
self.figure = ax.get_figure() # Get the correct figure
self.axNorm,ax2 = self.figure.subplots(1,2,gridspec_kw={'width_ratios':[4, 1]}) # Create figure on top of the other
ax2.remove() # Remove the excess figure
self.axRLU.set_position(self.axNorm.get_position()) # Update RLU to correct position
self._axes = [self.axNorm,self.axNorm,self.axRLU]
self._axes[0].set_xlabel(r'Qx [$A^{-1}$]')
self._axes[0].set_ylabel(r'E [meV]')
self._axes[1].set_xlabel(r'Qy [$A^{-1}$]')
self._axes[1].set_ylabel(r'E [meV]')
self.ax = self.axNorm
self.xlabel = r'Qx [$A^{-1}$]'
self.ylabel = r'Qy [$A^{-1}$]'
self.zlabel = 'E [meV]'
self.rlu = True
# Case 3: three axes are given; all are moved to the position of a newly created axis.
elif len(ax)==3: # All axes provided in order QxE,QyE,QxQy
self.axQxE = ax[0]
self.axQyE = ax[1]
self.axRLU = ax[2]
self.figure = self.axQyE.get_figure() # Get the correct figure
self.axNorm,ax2 = self.figure.subplots(1,2,gridspec_kw={'width_ratios':[4, 1]}) # Create figure on top of the other
ax2.remove() # Remove the excess figure
self.axQxE.set_position(self.axNorm.get_position()) # Update axQxE to correct position
self.axQyE.set_position(self.axNorm.get_position()) # Update axQyE to correct position
self.axRLU.set_position(self.axNorm.get_position()) # Update axRLU to correct position
self._axes = [self.axQxE,self.axQyE,self.axRLU]
self.ax = self.axNorm
hkl = ['H','K','L']
# The slider labels are built from the comma separated numbers found in the x labels of the
# QE axes; note that self.xlabel is read from axQyE and self.ylabel from axQxE.
xlabelSplit = self.axQyE.get_xlabel().replace(' [RLU]','').split(',')
ylabelSplit = self.axQxE.get_xlabel().replace(' [RLU]','').split(',')
self.xlabel = '\n'.join(['{}: '.format(hkl[i])+'{:+.3f}'.format(float(xlabelSplit[i])) for i in range(len(xlabelSplit))])
self.ylabel = '\n'.join(['{}: '.format(hkl[i])+'{:+.3f}'.format(float(ylabelSplit[i])) for i in range(len(ylabelSplit))])
self.zlabel = 'E [meV]'
self.rlu = True
self.EnergySliderTransform=[self.axQxE._length,self.axQyE._length,1.0] # Factor to divide the Energy slider value with (only applicable for QE axes)
else:
raise AttributeError('Number of provided axes is {} but only 1 or 3 is accepted.'.format(len(ax)))
# NOTE(review): the loop variable below rebinds the 'ax' parameter of this method.
for ax in self._axes:
ax.backgroundSubtraction = backgroundSubtraction
self.figure.set_size_inches(11,7)
self.value = 0
self.figure.subplots_adjust(bottom=0.25)
self.cmap = cmap # Update to accommodate deprecation warning
self.value = 0
# Start with slicing along axis 2; the requested axis is applied near the end of this method.
self.setAxis(2)
# Set up interactive generation of 1DCuts
viewAxis = axis
axis_color='white'
self.figure.canvas.mpl_connect('key_press_event',lambda event: onkeypress(event, self) )
self.figure.canvas.mpl_connect('scroll_event',lambda event: onscroll(event, self))
# Index of the bin centre along the slicing direction that lies closest to zero.
zeroPoint = np.argmin(np.abs(0.5*(self.Z[0,0][1:]+self.Z[0,0][:-1])))
self.Energy_slider_ax = self.figure.add_axes([0.15, 0.1, 0.65, 0.03])
self.Energy_slider = Slider(self.Energy_slider_ax, label=self.label, valmin=self.lowerLim, valmax=self.upperLim, valinit=zeroPoint)
self.Energy_slider.valtext.set_visible(False)
self.Energy_slider.on_changed(lambda val: sliders_on_changed(self,val))
if self.rlu:
self.units = ['','',' meV']
else:
self.units = [' 1/AA',' 1/AA',' meV']
# The current slice position is shown as figure text just right of the slider axis.
textposition = [self.Energy_slider_ax.get_position().p1[0]+0.005,self.Energy_slider_ax.get_position().p0[1]+0.005]
self.text = self.figure.text(textposition[0], textposition[1],s=self.stringValue())
self.shading = 'flat'
#self.imcbaxes = self.figure.add_axes([0.0, 0.2, 0.2, 0.7])
#self.im = self.ax.imshow(self.masked_array[:,:,self.value].T,cmap=self.cmap,extent=[self.X[0],self.X[-1],self.Y[0],self.Y[-1]],origin='lower')
# Draw the initial slice; self.shading is hard-coded to 'flat' just above.
if self.shading=='flat':
self.ax.grid(False)
self.im = self.ax.pcolormesh(self.X[:,:,0].T,self.Y[:,:,0].T,self.masked_array[:,:,self.value].T,zorder=10,shading=self.shading,cmap=cmap)
elif self.shading=='gouraud': # pragma: no cover
XX = 0.5*(self.X[:-1,:-1,self.value]+self.X[1:,1:,self.value]).T
YY = 0.5*(self.Y[:-1,:-1,self.value]+self.Y[1:,1:,self.value]).T
self.ax.grid(False)
self.im = self.ax.pcolormesh(XX,YY,self.masked_array[:,:,self.value].T,zorder=10,shading=self.shading,cmap=cmap) # ,vmin=1e-6,vmax=6e-6
else:
raise AttributeError('Did not understand shading {}.'.format(self.shading))
if self.grid:
self.ax.grid(self.grid,zorder=self.gridZOrder)
else:
self.ax.grid(self.grid)
self._caxis = self.im.get_clim()
self.figpos = [0.125,0.25,0.63,0.63]#self.ax.get_position()
self.cbaxes = self.figure.add_axes([0.8, 0.2, 0.03, 0.7])
self.colorbar = self.figure.colorbar(self.im,cax = self.cbaxes)
# NOTE(review): the two simplefilter calls below change the process-wide warnings filter and
# enclose only a commented-out line.
warnings.simplefilter("ignore")
#self.figure.tight_layout(rect=[0,0.1,0.9,0.9])
warnings.simplefilter("once")
self.text.set_text(self.stringValue())
xlim = self.ax.get_xlim()
ylim = self.ax.get_ylim()
#if self.axis == 2:
# self.ax.set_xlim(np.min([xlim[0],ylim[0]]),np.max([xlim[1],ylim[1]]))
# Bin widths along the three directions, taken from the first two bin coordinates of each.
dQx = np.diff(self.bins[0][:2,0,0])[0]
dQy = np.diff(self.bins[1][0,:2,0])[0]
ddE = np.diff(self.bins[2][0,0,:2])[0]
self.dQE = [dQx,dQy,ddE]
self.resolution = [0.05,0.05]
self.Energy_slider.set_val(self.value)
# Click handling: with RLU axes every axis gets its own onClick callback and cut functions,
# otherwise a single callback is stored on the viewer itself.
if self.rlu:
self.axRLU.onClick = lambda event: eventdecorator(onclick,self,self.axRLU,event,outputFunction=outputFunction)#,extra=' axRLU')
self.axRLU._button_press_event = self.figure.canvas.mpl_connect('button_press_event', self.axRLU.onClick)
def updateLimitsAxRLU(self,lower,upper):
self.EMin = lower
self.EMax = upper
self.axRLU.updateLimits = lambda lower,upper: updateLimitsAxRLU(self.axRLU,lower,upper)
# NOTE(review): axQxE and axQyE are only assigned in the three-axes branch above; when a single
# Axes was provided the following lines would raise AttributeError - confirm.
self.axQxE.onClick = lambda event: eventdecorator(onclick,self,self.axQxE,event,outputFunction=outputFunction)#,extra=' axQxE')
self.axQxE._button_press_event = self.figure.canvas.mpl_connect('button_press_event', self.axQxE.onClick)
def updateLimitsAxQ(self,lower,upper):
self.QPoints = [self.sample.calculateHKLToQxQy(*x) for x in np.array([self.calculateRLU(-1,0)[0],self.calculateRLU(1,0)[0]])]
self.axQxE.updateLimits = lambda lower,upper: updateLimitsAxQ(self.axQxE,lower,upper)
self.axQyE.updateLimits = lambda lower,upper: updateLimitsAxQ(self.axQyE,lower,upper)
self.axQyE.onClick = lambda event: eventdecorator(onclick,self,self.axQyE,event,outputFunction=outputFunction)#,extra=' axQyE')
self.axQyE._button_press_event = self.figure.canvas.mpl_connect('button_press_event', self.axQyE.onClick)
self.axRLU.rlu = self.axQxE.rlu = self.axQyE.rlu = True
# Set up initial active axes
self.axRLU.isActive = True
self.axQxE.isActive = False
self.axQyE.isActive = False
# Each provided cut function is wrapped so that the axis only has to pass the draggable object dr.
if not cut1DFunctionRectangle is None:
self.axRLU.cut1DFunctionRectangle = lambda dr: cut1DFunctionRectangle(viewer=self,dr=dr)
if not cut1DFunctionCircle is None:
self.axRLU.cut1DFunctionCircle = lambda dr: cut1DFunctionCircle(viewer=self,dr=dr)
if not cut1DFunctionRectanglePerp is None:
self.axQxE.cut1DFunctionRectanglePerpendicular = lambda dr: cut1DFunctionRectanglePerp(viewer=self,dr=dr)
self.axQyE.cut1DFunctionRectanglePerpendicular = lambda dr: cut1DFunctionRectanglePerp(viewer=self,dr=dr)
if not cut1DFunctionRectangleHorizontal is None:
self.axQxE.cut1DFunctionRectangleHorizontal = lambda dr: cut1DFunctionRectangleHorizontal(viewer=self,dr=dr)
self.axQyE.cut1DFunctionRectangleHorizontal = lambda dr: cut1DFunctionRectangleHorizontal(viewer=self,dr=dr)
if not cut1DFunctionRectangleVertical is None:
self.axQxE.cut1DFunctionRectangleVertical = lambda dr: cut1DFunctionRectangleVertical(viewer=self,dr=dr)
self.axQyE.cut1DFunctionRectangleVertical = lambda dr: cut1DFunctionRectangleVertical(viewer=self,dr=dr)
else:
# NOTE(review): 'key_press_event' was already connected to onkeypress further up, so this
# registers the same handler a second time (stored under the name _button_press_event).
self.ax._button_press_event = self.figure.canvas.mpl_connect('key_press_event',lambda event: onkeypress(event, self) )
self.onClick = lambda event: eventdecorator(onclick,self,self.ax,event,outputFunction=outputFunction)
self.button_press_event = self.figure.canvas.mpl_connect('button_press_event', self.onClick)
self.ax.rlu = False
# Initial colour limits span the finite values of the data.
# NOTE(review): only the nanmax call is guarded against ValueError; the nanmin call below is not.
try:
maxVal = np.nanmax(self.masked_array[np.isfinite(self.masked_array)])
except ValueError:
maxVal = 1
self.caxis = [np.nanmin(self.masked_array[np.isfinite(self.masked_array)]),maxVal]
if self.grid:
self.ax.grid(self.grid,zorder=self.gridZOrder)
else:
self.ax.grid(self.grid)
self.setAxis(viewAxis) # Set view plane to correct
## Hack for this to look nice as just changing direction does not render correctly
self.setPlane(1)
self.setPlane(0)
if not self.EfLimits is None:
self.Ef = np.arange(*self.EfLimits,self.dQE[2])
# Assigned through the property setter, which reads self.Ef when RLU axes are in use and a list is given.
self.CurratAxeBraggList = CurratAxeBraggList
# The colour-limit sliders are only added when the major version is above 2 and the minor version above 5.
if adjustable and pythonVersion>2 and pythonSubVersion>5:
ax_cmin = plt.axes([0.87, 0.1, 0.05, 0.7])
ax_cmax = plt.axes([0.93, 0.1, 0.05, 0.7])
addColorbarSliders(self,c_min=self.caxis[0],c_max=self.caxis[1],c_minval=self.caxis[0],\
c_maxval=self.caxis[1],ax_cmin=ax_cmin,ax_cmax=ax_cmax,log=False)
@property
def caxis(self):
    """Current colour limits as a pair (cmin, cmax)."""
    return self._caxis


@caxis.setter
def caxis(self, caxis):
    """Set the colour limits of the image, the colour bar and, if present, the sliders.

    Args:

        - caxis (list, tuple or array): Two values, lower and upper colour limit.

    Raises:

        - AttributeError: If caxis is not a list/tuple/array of exactly two values.
    """
    ErrMsg = 'Provided caxis is not of correct format. Expected 2 values but recieved "{}" of type {}'
    if not isinstance(caxis, (list, np.ndarray, tuple)) or len(list(caxis)) != 2:
        raise AttributeError(ErrMsg.format(caxis, type(caxis)))

    self._caxis = caxis
    self.im.set_clim(caxis)
    cmin, cmax = caxis

    fig = self.ax.get_figure()
    if hasattr(fig, 's_cmin'):  # The method addColorbarSliders has been called
        # Rescale both sliders to the new range before moving their handles
        for s in (fig.s_cmin, fig.s_cmax):
            if s._log:
                s.valmin = np.log10(cmin)
                s.valmax = np.log10(cmax)
            else:
                s.valmin = cmin
                s.valmax = cmax
            s.ax.set_ylim(s.valmin, s.valmax)
        fig.s_cmin.set_val(cmin)
        fig.s_cmax.set_val(cmax)

    self.colorbar.update_normal(self.im)
def set_clim(self, cmin, cmax=None):
    """Set the colour limits through the caxis property.

    Args:

        - cmin (float or pair): Lower limit, or both limits as a pair when cmax is None.

    Kwargs:

        - cmax (float): Upper limit (default None)
    """
    if cmax is None:
        self.caxis = cmin
    else:
        self.caxis = (cmin, cmax)
def setAxis(self, axis):
    """Select the direction along which the data is sliced and re-arrange data and coordinates.

    Args:

        - axis (int): 0, 1, or 2; the slider afterwards runs along this direction of self.Data.

    Raises:

        - AttributeError: If axis is not 0, 1, or 2 (raised before any state is changed).
    """
    # For each axis: transpose order (in-plane x, in-plane y, sliced direction) and the names
    # of the attributes holding the horizontal, vertical and slider labels.
    layouts = (
        (2, (0, 1, 2), 'xlabel', 'ylabel', 'zlabel'),
        (1, (0, 2, 1), 'xlabel', 'zlabel', 'ylabel'),
        (0, (1, 2, 0), 'ylabel', 'zlabel', 'xlabel'),
    )
    for candidate, axes, horizontal, vertical, sliderLabel in layouts:
        if axis == candidate:
            break
    else:
        raise AttributeError('Axis provided not recognized. Should be 0, 1, or 2 but got {}'.format(axis))

    if hasattr(self, 'im'):  # this function is also called before any plot has been performed
        self.im.set_array(self.emptyData)  # Set data in current im to a fully masked data set

    if self.rlu:
        # Swap the displayed axis object: unhook the old one, then show and hook up the new one
        if hasattr(self.ax, '_button_press_event'):
            self.ax.get_figure().canvas.mpl_disconnect(self.ax._button_press_event)
        self.ax.isActive = False
        self.figure.delaxes(self.ax)
        self.ax = self.figure.add_axes(self._axes[axis])
        self.ax.isActive = True
        if hasattr(self.ax, 'onClick'):
            self.ax._button_press_event = self.ax.get_figure().canvas.mpl_connect('button_press_event', self.ax.onClick)
    else:
        self.ax.set_xlabel(getattr(self, horizontal))
        self.ax.set_ylabel(getattr(self, vertical))
    label = getattr(self, sliderLabel)

    # NOTE(review): calculateValue reads self.Z and self.axis, which at this point still hold
    # the values of the previous projection - confirm that this ordering is intended.
    if hasattr(self.ax, '_step'):
        self.ax._step = self.calculateValue()

    X = self.bins[axes[0]].transpose(axes)
    Y = self.bins[axes[1]].transpose(axes)
    Z = self.bins[axes[2]].transpose(axes)
    masked_array = np.ma.array(self.Data, mask=np.isnan(self.Data)).transpose(axes)

    # Fully masked array with the size of one slice, used to blank the image between projections
    self.emptyData = masked_array[:, :, 0].T.flatten().copy()
    self.emptyData.mask = np.ones_like(self.emptyData, dtype=bool)
    self._axesChanged = True

    self.label = label
    self.X = X
    self.Y = Y
    self.Z = Z
    self.masked_array = masked_array
    self.axes = axes
    self.upperLim = self.Data.shape[axis] - 1
    self.lowerLim = 0
    self.axis = axis
def calculateValue(self, index=None):
    """Return the centre position of a slice along the current slicing direction.

    Kwargs:

        - index (int): Slice index; if None, int(self.value) is used (default None)

    Returns:

        - Mid point between self.Z[0,0,index] and self.Z[0,0,index+1]. If index+1 is outside
          of self.Z, the next edge is extrapolated linearly from the two last edges. When the
          attribute EnergySliderTransform exists the result is divided by its entry for self.axis.
    """
    if index is None:
        index = int(self.value)

    edges = self.Z[0, 0]
    try:
        val = 0.5 * (edges[index + 1] + edges[index])
    except IndexError:
        # Next edge extrapolated as 2*edges[index]-edges[index-1]; mid point of it and edges[index]
        val = 0.5 * (3 * edges[index] - edges[index - 1])

    if hasattr(self, 'EnergySliderTransform'):
        val = val / self.EnergySliderTransform[self.axis]
    return val
def stringValue(self):
    """Return the current slice position rounded to 2 decimals, followed by the unit of the sliced axis."""
    position = self.calculateValue()
    return str(np.round(position, 2)) + self.units[self.axis]
def setProjection(self, value):
    """Change projection between Qx,Qy, and E, or along principal, orthogonal Q direction, or E if plotting in RLU.

    The change is performed by dispatching a key press event with key str(value) to the
    handlers connected to the figure canvas.

    Args:

        - value (int or str): Key to be emitted.
    """
    # FigureCanvasBase.key_press_event was deprecated in Matplotlib 3.6 and later removed;
    # building the event and sending it through the callback registry works on old and new versions.
    from matplotlib.backend_bases import KeyEvent

    canvas = self.figure.canvas
    event = KeyEvent('key_press_event', canvas, str(value))
    canvas.callbacks.process('key_press_event', event)
def setPlane(self, value):
    """Change plotting plane to a new one along the same axis by moving the slider to value."""
    self.Energy_slider.set_val(value)
@property
def CurratAxeBraggList(self):
    """List of Bragg peaks used for the Currat-Axe spurion markers (None if not set)."""
    return self._CurratAxeBraggList


@CurratAxeBraggList.setter
def CurratAxeBraggList(self, peaks):
    """Store the Bragg peaks and, for RLU axes, pre-compute spurion positions and bin indices.

    Plotting of spurions is switched off when peaks is None or empty, or when the viewer is
    not in RLU mode. Otherwise positions are calculated through self.axRLU.sample.CurratAxe
    from self.Ei and self.Ef, and self.plotCurratAxe is set True once a redraw succeeds.

    Args:

        - peaks (list): Bragg peaks, [[H1,K1,L1],[H2,K2,L2],..], or None
    """
    self._CurratAxeBraggList = peaks
    # The emptiness test on plotCurratAxe is kept for backwards compatibility; testing peaks
    # itself makes an empty list switch plotting off instead of reaching the sample calculation.
    if peaks is None or np.size(peaks) == 0 or np.array(self.plotCurratAxe).size == 0:
        self.plotCurratAxe = False
        return
    if not self.rlu:
        self.plotCurratAxe = False
        return

    sample = self.axRLU.sample
    self.monoPoints = sample.CurratAxe(Ei=self.Ei, Ef=self.Ef, Bragg=peaks, spurionType='Monochromator', HKL=False)[:, :, :, :2]
    self.anaPoints = sample.CurratAxe(Ei=self.Ei, Ef=self.Ef, Bragg=peaks, spurionType='Analyser', HKL=False)[:, :, :, :2]

    # Energy transfer for every (peak, Ei, Ef) combination; Ei may be given as a plain list
    Ei = np.atleast_1d(self.Ei)
    self.dE = np.repeat((Ei[:, np.newaxis] - self.Ef[np.newaxis])[np.newaxis], len(peaks), axis=0)

    # Convert positions into (fractional, rounded) bin indices along the three directions
    origin = [self.bins[i][0, 0, 0] for i in range(3)]
    dEIndex = np.round((self.dE - origin[2]) / self.dQE[2])
    dQxIndexMono = np.round((self.monoPoints[:, :, :, 0] - origin[0]) / self.dQE[0])
    dQyIndexMono = np.round((self.monoPoints[:, :, :, 1] - origin[1]) / self.dQE[1])
    dQxIndexAna = np.round((self.anaPoints[:, :, :, 0] - origin[0]) / self.dQE[0])
    dQyIndexAna = np.round((self.anaPoints[:, :, :, 1] - origin[1]) / self.dQE[1])
    self.CurratAxeIndicesMono = np.array([dQxIndexMono, dQyIndexMono, dEIndex])
    self.CurratAxeIndicesAna = np.array([dQxIndexAna, dQyIndexAna, dEIndex])

    try:
        self.plot()
        self.plotCurratAxe = True
    except Exception as error:
        # Best effort: a failing redraw leaves plotCurratAxe unchanged, but is no longer silent
        import warnings
        warnings.warn('Currat-Axe spurions could not be plotted: {}'.format(error))
@property
def plotCurratAxe(self):
    """Flag telling whether Currat-Axe spurion markers are to be drawn."""
    return self._plotCurratAxe


@plotCurratAxe.setter
def plotCurratAxe(self, value):
    """Store the flag and redraw if it is switched on after the axes have been created."""
    self._plotCurratAxe = value
    # self._axes does not exist while __init__ is still setting up the figure
    if hasattr(self, '_axes') and self._plotCurratAxe:
        self.plot()
def plot(self):
"""Refresh the position text, the image data of the current slice and, if enabled, the Currat-Axe markers."""
self.text.set_text(self.stringValue())
# The try block below only contains 'pass' (its former content is commented out), so it has no effect.
try:
#self.im.set_array(self.emptyData)
pass
except TypeError:
pass
# On the first call after the projection changed, the image is filled with a fully masked
# array of its current size instead of the slice data; the flag is then cleared.
if self._axesChanged:
tempData = np.ma.array(self.im.get_array().T)
tempData.mask = np.ones_like(tempData,dtype=bool)
self.im.set_array(tempData)
self._axesChanged = False
else:
self.im.set_array(self.masked_array[:,:,int(self.value)].T.flatten())
# Spurion markers are only handled for RLU axes with both a Bragg list and Ei available.
if not self.CurratAxeBraggList is None and not self.Ei is None and self.plotCurratAxe is True and self.rlu is True:
# Select the spurion points whose bin index along the sliced direction equals the slider value.
insidePointsMono = self.monoPoints[np.isclose(self.CurratAxeIndicesMono[self.axis],self.Energy_slider.val)]
insidePointsAna = self.anaPoints[np.isclose(self.CurratAxeIndicesAna[self.axis],self.Energy_slider.val)]
# The marker lines are created once per axis object and emptied on every redraw.
if not hasattr(self.ax,'BraggScatterMono'):
self.ax.BraggScatterMono = self.ax.plot([],[],zorder=20, linestyle="", marker="o", mfc='none', color='r',label='Currat-Axe Monochromator')[0]
if not hasattr(self.ax,'BraggScatterAna'):
self.ax.BraggScatterAna = self.ax.plot([],[],zorder=20, linestyle="", marker="o", mfc='none', color='k',label='Currat-Axe Analyzer')[0]
self.ax.BraggScatterMono.set_data([],[])
self.ax.BraggScatterAna.set_data([],[])
# NOTE(review): for axis 2 the analyser points appear to be drawn only inside the branch
# guarded by the monochromator points being non-empty - confirm against the indented source.
if len(insidePointsMono) > 0: # Only if there are any points to plot, do it
insidePointsMono = insidePointsMono.T
if self.axis == 2:
insidePointsAna = insidePointsAna.T
self.ax.BraggScatterMono.set_data(*insidePointsMono)
self.ax.BraggScatterAna.set_data(*insidePointsAna)
else:
# In the QE projections the markers are drawn as (remaining Q coordinate, energy transfer).
energy = self.dE[np.isclose(self.CurratAxeIndicesMono[self.axis],self.Energy_slider.val)]
plotPointsMoni = np.array([insidePointsMono[1-self.axis],energy])
self.ax.BraggScatterMono.set_data(*plotPointsMoni)
if len(insidePointsAna) > 0 and self.axis != 2: # Only if there are any points to plot, do it
insidePointsAna = insidePointsAna.T
energy = self.dE[np.isclose(self.CurratAxeIndicesAna[self.axis],self.Energy_slider.val)]
plotPointsAna = np.array([insidePointsAna[1-self.axis],energy])
self.ax.BraggScatterAna.set_data(*plotPointsAna)
#self.ax.
self.im.set_clim(self.caxis)
self.ax.set_position(self.figpos)
# xlim and ylim are read but not used further in this method; the axis == 2 branch is empty.
xlim = self.ax.get_xlim()
ylim = self.ax.get_ylim()
if self.axis == 2:
pass
if self.grid:
self.ax.grid(self.grid,zorder=self.gridZOrder)
else:
self.ax.grid(self.grid)
def set_title(self, title):
    """Set the title of the currently shown axis."""
    self.ax.set_title(title)
def saveToFile(self, folder, extension, gui=None):
    """Save all planes to files

    All planes along E are saved first, then along the first and second data direction. Files
    are named by a prefix (E, X or Y) and the bin coordinate with 2 decimals, where '.' is
    replaced by 'p' and '-' by 'm'. The projection and plane shown before the call are
    restored afterwards, also if saving fails.

    Args:

        - folder (path): Folder in which to save files

        - extension (str): name of wanted extension. Must be accepted by matplotlib

    Kwargs:

        - gui (PyQt5 MainApplication): Used for MJOLNIR Gui; its progress bar is updated for every saved file (default None)
    """
    XLength, YLength, ZLength = self.Data.shape
    startProjection = self.axis
    startPosition = self.Energy_slider.val

    # (projection key, file prefix, number of planes, coordinate of plane i)
    passes = (
        (2, 'E', ZLength, lambda i: self.bins[2][0, 0, i]),
        (0, 'X', XLength, lambda i: self.bins[0][i, 0, 0]),
        (1, 'Y', YLength, lambda i: self.bins[1][0, i, 0]),
    )

    progress = 0
    if gui is not None:
        gui.setProgressBarMaximum(XLength + YLength + ZLength)
        gui.setProgressBarValue(progress)

    try:
        for projection, prefix, planes, coordinate in passes:
            self.setProjection(projection)
            for idx in range(planes):
                self.setPlane(idx)
                position = '{:.2f}'.format(coordinate(idx)).replace('.', 'p').replace('-', 'm')
                fileName = '{}{}.{}'.format(prefix, position, extension)
                self.figure.savefig(os.path.join(folder, fileName))
                if gui is not None:
                    progress += 1
                    gui.setProgressBarValue(progress)
    finally:
        # Return to the view the user had before saving
        self.setProjection(startProjection)
        self.setPlane(startPosition)
def eventdecorator(function, self, ax, event, *args, **kwargs):  # pragma: no cover
    """Forward a mouse event to function if it happened inside an active axis.

    Args:

        - function (function): Called as function(self, event.xdata, event.ydata, *args, **kwargs)

        - self (Viewer3D): Viewer passed on to function

        - ax (matplotlib axis): Axis the handler belongs to

        - event (matplotlib event): Event to be checked

    Returns:

        - Return value of function, or None if the event is ignored.
    """
    if event.xdata is None or not ax.in_axes(event):
        return None

    # isActive is stored as a plain bool on some axes and as a callable on others
    active = ax.isActive() if callable(ax.isActive) else ax.isActive
    if not active:
        return None

    try:
        cursorShape = ax.get_figure().canvas.cursor().shape()  # Only works for pyQt5 backend
    except Exception:
        pass  # Other backends: no cursor information, the event is always forwarded
    else:
        # Ignore clicks while a non-default cursor is shown or while something is being drawn
        if cursorShape != 0 or ax.drawState != States.INACTIVE:
            return None
    return function(self, event.xdata, event.ydata, *args, **kwargs)
def onclick(self, x, y, returnText=False, outputFunction=print, extra=None):  # pragma: no cover
    """Report position and intensity of the bin closest to a clicked point.

    Args:

        - self (Viewer3D): Current viewer

        - x, y (float): Position in data coordinates of the shown plane

    Kwargs:

        - returnText (bool): If true, the text is returned instead of being sent to outputFunction (default False)

        - outputFunction (function): Function called with the text (default print)

        - extra (str): Text appended to the output (default None)
    """
    # self.value may hold a float after slider interaction; array indices must be integers
    idz = int(self.value)

    # Bin centres of the shown plane, from the four corners of every bin
    XX, YY = self.X[:, :, idz], self.Y[:, :, idz]
    XX = 0.25 * (XX[:-1, :-1] + XX[1:, :-1] + XX[:-1, 1:] + XX[1:, 1:])
    YY = 0.25 * (YY[:-1, :-1] + YY[1:, :-1] + YY[:-1, 1:] + YY[1:, 1:])

    idx = np.unravel_index(np.argmin(np.abs(XX - x)), XX.shape)[0]
    idy = np.unravel_index(np.argmin(np.abs(YY - y)), YY.shape)[1]
    I = self.masked_array[idx, idy, idz]

    masked = np.ma.is_masked(I)
    if masked:
        I = np.nan
    printString = self.ax.format_coord(x, y) + ', ' + 'I = {:.4E}'.format(I)

    if self.allData is True and not masked:
        # Map indices of the transposed (plotted) arrays back onto the un-transposed data
        flipper = {0: [2, 0, 1], 1: [0, 2, 1]}.get(self.axis, [0, 1, 2])
        ID = np.array([idx, idy, idz])[flipper]
        cts = self.Counts[ID[0], ID[1], ID[2]]
        Norm = self.Normalization[ID[0], ID[1], ID[2]]
        Mon = self.Monitor[ID[0], ID[1], ID[2]]
        NC = self.NormCounts[ID[0], ID[1], ID[2]]
        printString += ', Cts = {:}, Norm = {:.3f}, Mon = {:d}, NormCount = {:d}'.format(cts, Norm, int(Mon), int(NC))

    if extra is not None:
        printString += extra

    if returnText:
        return printString
    # An axis may silence the output through its suppressPrint attribute
    if not getattr(self.ax, 'suppressPrint', False):
        outputFunction(printString)
def _redrawProjection(self, axis):  # pragma: no cover
    """Switch the viewer to slicing along axis and draw the first plane of the new projection."""
    reloadslider(self, axis)
    del self.im
    self.ax.grid(False)

    plane = int(self.value)
    intensity = self.masked_array[:, :, plane].T
    if self.shading == 'flat':
        self.im = self.ax.pcolormesh(self.X[:, :, 0].T, self.Y[:, :, 0].T, intensity, zorder=10, shading=self.shading, cmap=self.cmap)
    elif self.shading == 'gouraud':
        # Gouraud shading needs one coordinate per data point: use the mid point of opposite bin corners
        XX = 0.5 * (self.X[:-1, :-1, plane] + self.X[1:, 1:, plane]).T
        YY = 0.5 * (self.Y[:-1, :-1, plane] + self.Y[1:, 1:, plane]).T
        self.im = self.ax.pcolormesh(XX, YY, intensity, zorder=10, shading=self.shading, cmap=self.cmap)
    else:
        raise AttributeError('Did not understand shading {}.'.format(self.shading))

    if self.grid:
        self.ax.grid(self.grid, zorder=self.gridZOrder)
    else:
        self.ax.grid(self.grid)

    self.im.set_clim(self.caxis)
    self.Energy_slider.set_val(0)
    self.plot()
    self.ax.set_xlim([np.min(self.X), np.max(self.X)])
    self.ax.set_ylim([np.min(self.Y), np.max(self.Y)])


def onkeypress(event, self):  # pragma: no cover
    """Handle key presses: step through planes, jump to first/last plane, or change projection.

    The keys belonging to each action are looked up in Viewer3DSettings. A projection key only
    has an effect when the viewer is not already showing that projection.
    """
    if event.key in Viewer3DSettings['upwards']:
        increaseAxis(event, self)
    elif event.key in Viewer3DSettings['downwards']:
        decreaseAxis(event, self)
    elif event.key in Viewer3DSettings['home']:
        self.Energy_slider.set_val(self.Energy_slider.valmin)
    elif event.key in Viewer3DSettings['end']:
        self.Energy_slider.set_val(self.Energy_slider.valmax)
    else:
        # First matching projection wins, in the order QxE, QyE, QxQy
        for setting, axis in (('QxE', 0), ('QyE', 1), ('QxQy', 2)):
            if event.key in Viewer3DSettings[setting]:
                if self.axis != axis:
                    _redrawProjection(self, axis)
                break
def reloadslider(self, axis):
    """Change the slicing axis and replace the plane slider by one matching the new axis.

    Args:

        - self (Viewer3D): Current viewer

        - axis (int): New slicing axis, passed on to self.setAxis
    """
    try:
        cancel(self.ax, axis)
    except Exception:
        pass  # Best effort: failures of cancel are deliberately ignored

    self.setAxis(axis)

    # Retire the old slider; its axis (self.Energy_slider_ax) is reused by the new one
    oldSlider = self.Energy_slider
    oldSlider.set_val(0)
    oldSlider.label.remove()
    oldSlider.disconnect_events()
    oldSlider.vline.set_visible(False)
    del oldSlider
    del self.Energy_slider

    # Index of the bin centre along the new slicing direction lying closest to zero
    zeroPoint = np.argmin(np.abs(0.5 * (self.Z[0, 0][1:] + self.Z[0, 0][:-1])))
    self.Energy_slider = Slider(self.Energy_slider_ax, label=self.label, valmin=self.lowerLim, valmax=self.upperLim, valinit=zeroPoint)
    self.Energy_slider.valtext.set_visible(False)
    self.Energy_slider.on_changed(lambda val: sliders_on_changed(self, val))

    self.value = 0
    self.im.remove()
def onscroll(event, self):  # pragma: no cover
    """Step one plane up or down depending on the scroll direction given in Viewer3DSettings."""
    if event.button in Viewer3DSettings['upwardsScroll']:
        increaseAxis(event, self)
    elif event.button in Viewer3DSettings['downwardsScroll']:
        decreaseAxis(event, self)
def increaseAxis(event, self):  # pragma: no cover
    """Move the plane slider one step up; event is not used."""
    slider = self.Energy_slider
    slider.set_val(slider.val + 1)
def decreaseAxis(event, self):  # pragma: no cover
    """Move the plane slider one step down; event is not used."""
    slider = self.Energy_slider
    slider.set_val(slider.val - 1)
def sliders_on_changed(self, val):  # pragma: no cover
    """Callback of the plane slider: snap to the nearest integer plane and redraw.

    Values outside of the slider range move the slider to the nearest limit. Inside the range
    the integer plane index is stored in self.value, the slider is snapped onto it if needed,
    and the plot is refreshed.

    Args:

        - self (Viewer3D): Current viewer

        - val (float): New slider value
    """
    value = int(np.round(val))
    slider = self.Energy_slider

    if value > slider.valmax:
        slider.set_val(slider.valmax)
    elif value < slider.valmin:
        slider.set_val(slider.valmin)
    else:
        if hasattr(self.ax, 'updateLimits'):
            self.ax.updateLimits(self.Z[0, 0, value], self.Z[0, 0, value + 1])
        # Store the integer index: self.value is used for array indexing elsewhere
        self.value = value
        if value != val:
            slider.set_val(value)  # Snap the handle onto the plane
        self.plot()

    if hasattr(self.ax, '_step'):
        self.ax._step = self.calculateValue()
def addColorbarSliders(self, c_min, c_max, c_minval, c_maxval, ax_cmin, ax_cmax, log=True):
    """Add two sliders controlling the colour axis

    The sliders are stored on the figure as s_cmin and s_cmax, and the savefig method of the
    figure is wrapped such that the slider axes are hidden while saving.

    args:
        self (Viewer3D object): Current object
        c_min (float): Minimal color value
        c_max (float): Maximal color value
        c_minval (float): Starting value of lower bound (currently not used, c_min is the starting value)
        c_maxval (float): Starting value of upper bound (currently not used, c_max is the starting value)
        ax_cmin (mpl axes): Axis in which lower bound slider is to be shown
        ax_cmax (mpl axes): Axis in which upper bound slider is to be shown
    Kwargs:
        log (bool): If true, sliders are logarithmic
    """
    fig = self.ax.get_figure()
    if log == True:
        lower, upper, valfmt = np.log10(c_min + 1e-20), np.log10(c_max), '%2.1f'
    else:
        lower, upper, valfmt = c_min, c_max, '%2.1e'
    fig.s_cmin = Slider(ax_cmin, 'min', lower, upper, valinit=lower, orientation='vertical', valfmt=valfmt)
    fig.s_cmax = Slider(ax_cmax, 'max', lower, upper, valinit=upper, orientation='vertical', valfmt=valfmt)

    for s in [fig.s_cmin, fig.s_cmax]:
        s._log = log
        s.label.set_fontsize(12)
        s.valtext.set_fontsize(12)

    def update(fig, val, bar=None, s=None, log=log):
        """Apply the slider values as colour limits, keeping the lower bound below the upper."""
        _cmin = fig.s_cmin.val
        _cmax = fig.s_cmax.val
        if _cmin > _cmax:
            # NOTE(review): in log mode the slider range is in log10 units while the value
            # pushed below is 10**value - confirm the intended units.
            if bar == 'min':  # If lower bar is changed push max upwards
                _cmax = _cmin
                fig.s_cmax.set_val(np.power(10, _cmax) if log else _cmax)
            else:
                _cmin = _cmax
                fig.s_cmin.set_val(np.power(10, _cmin) if log else _cmin)
        if log:
            limits = (np.power(10, _cmin), np.power(10, _cmax))
            # Show the log10 values next to their own slider
            fig.s_cmin.valtext.set_text(_cmin)
            fig.s_cmax.valtext.set_text(_cmax)
        else:
            limits = (_cmin, _cmax)
        # The image is updated directly (not through self.caxis) to avoid re-triggering the sliders
        self.im.set_clim(list(limits))
        self._caxis = limits
        plt.draw()
        self.colorbar.update_normal(self.im)

    fig.s_cmin.on_changed(lambda val, *arg, **kwargs: update(fig, val, *arg, bar='min', **kwargs))
    fig.s_cmax.on_changed(lambda val, *arg, **kwargs: update(fig, val, *arg, bar='max', **kwargs))

    fig._savefig = fig.savefig

    @functools.wraps(fig.savefig)
    def savefig(fname, hide=[ax_cmin, ax_cmax], fig=fig, **kwargs):
        """Save the figure with the objects in hide made invisible during saving."""
        unhide = []
        for obj in hide:
            if obj.get_visible():
                obj.set_visible(False)
                unhide.append(obj)
        try:
            fig._savefig(fname, **kwargs)
        finally:
            # Make the hidden objects visible again, also if saving failed
            for obj in unhide:
                obj.set_visible(True)
            fig.canvas.draw()

    fig.savefig = savefig
def cut1DFunctionDefault(self, dr):  # pragma: no cover
    """Perform a 1D cut for a drawn rectangle within the currently shown plane.

    The cut parameters are extracted from dr.rect and the sample of the shown axis, the cut
    is limited to the bin of the current slider position, and the result of
    self.ds.plotCut1D is appended (wrapped in a list) to cut1DHolder.

    Args:

        - self (Viewer3D): Current viewer

        - dr: Object holding the drawn rectangle in its rect attribute
    """
    parameters = extractCut1DProperties(dr.rect, self.ax.sample)

    # Lower edge of the current plane along the slicing direction and its width
    step = self.dQE[self.axis]
    pos = np.zeros(3, dtype=int)
    pos[self.axis] = self.Energy_slider.val
    EMin = self.bins[self.axis][pos[0], pos[1], pos[2]]
    EMax = EMin + step

    cut1DHolder.append([self.ds.plotCut1D(**parameters, Emin=EMin, Emax=EMax)])