# Source code for pymor.discretizers.builtin.gui.matplotlib

# This file is part of the pyMOR project (http://www.pymor.org).
# Copyright 2013-2020 pyMOR developers and contributors. All rights reserved.
# License: BSD 2-Clause License (http://opensource.org/licenses/BSD-2-Clause)

""" This module provides widgets for displaying plots of
scalar data assigned to one- and two-dimensional grids using
:mod:`matplotlib`. These widgets are not intended to be used directly.
"""

import numpy as np
from IPython.core.display import display, HTML
from matplotlib import animation, pyplot
from pymor.core.base import abstractmethod

from pymor.core.config import config
from pymor.discretizers.builtin.grids.constructions import flatten_grid
from pymor.discretizers.builtin.grids.referenceelements import triangle, square


class MatplotlibAxesBase:
    """Shared set-up for helpers that draw grid data into a matplotlib figure.

    The constructor creates the axes, lets the subclass build empty plot
    objects via :meth:`_plot_init` and then either wraps the data in a
    :class:`matplotlib.animation.FuncAnimation` (single data set on a single
    axes object) or pushes the data into the plot once via :meth:`set`.

    NOTE(review): the class declares no ABC metaclass, so the
    ``abstractmethod`` markers below are presumably not enforced when
    instantiating; `MatplotlibPatchAxes` provides `animate` only as an
    instance attribute.

    Parameters
    ----------
    figure
        The matplotlib figure to draw into.
    sync_timer
        Passed to `FuncAnimation` as `event_source`.
    grid
        The grid the data lives on; stored as `self.grid`.
    U
        The data to display. Must support `len` when `separate_axes` is
        `True`. In the animated case it is handed to `FuncAnimation` as
        `frames`.
    vmin, vmax
        Value range, stored for use by the subclass.
    codim
        Codimension of the entities the data is attached to.
    separate_axes
        If `True`, one subplot per entry of `U` is created.
    columns
        Number of subplot columns when `separate_axes` is `True`.
    """

    def __init__(self, figure, sync_timer, grid, U=None, vmin=None, vmax=None, codim=2, separate_axes=False,
                 columns=2):
        self.vmin, self.vmax = vmin, vmax
        self.codim = codim
        self.grid = grid

        if not separate_axes:
            self.ax = (figure.gca(),)
        else:
            if len(U) == 1:
                columns = 1  # otherwise we get a sep axes object with 0 data
            n_rows = int(np.ceil(len(U) / columns))
            self.ax = figure.subplots(n_rows, columns, squeeze=False).flatten()

        self.figure = figure
        self.separate_axes = separate_axes
        if separate_axes or isinstance(U, tuple):
            self.count = len(U)
        else:
            self.count = 1

        self._plot_init()
        # assignment delayed to ensure _plot_init works w/o data
        self.U = U

        if separate_axes or self.count != 1:
            # static display: just load the data once
            self.set(self.U)
            return

        # Rest is only needed with animation
        assert len(self.ax) == 1
        frame_interval_ms = 200
        self.anim = animation.FuncAnimation(figure, self.animate, frames=U, interval=frame_interval_ms,
                                            blit=True, event_source=sync_timer)
        # generating the HTML instance outside this class causes the plot display to fail
        self.html = HTML(self.anim.to_jshtml())

    @abstractmethod
    def _plot_init(self):
        """Create the plot objects in `self.ax` without any real data.

        Called from `__init__` before `self.U` is assigned.
        """
        pass

    @abstractmethod
    def set(self, U):
        """Replace the data shown by the already existing plot objects with `U`."""
        pass

    @abstractmethod
    def animate(self, u):
        """Show the single animation frame `u` in the existing plot objects.

        Used as the frame callback of the `FuncAnimation` created in
        `__init__`, which is set up with `blit=True`.
        """
        pass
class MatplotlibPatchAxes(MatplotlibAxesBase):
    """Colour plot of data on a two-dimensional triangle or square grid.

    The grid is flattened with `flatten_grid`; for square grids every cell is
    split into two triangles so that `tripcolor` can be used in all cases.

    Parameters
    ----------
    figure
        The matplotlib figure to draw into.
    grid
        Two-dimensional grid whose reference element is `triangle` or
        `square`.
    bounding_box
        Accepted for interface compatibility; not used by the code in this
        class.
    U
        The data to display (forwarded to the base class).
    vmin, vmax
        Colour range.
    codim
        `2` for vertex data (Gouraud shading) or `0` for cell data (flat
        shading).
    columns
        Forwarded to the base class.
    colorbar
        If `True`, a colorbar is attached to the axes.
    sync_timer
        Forwarded to the base class.
    """

    def __init__(self, figure, grid, bounding_box=None, U=None, vmin=None, vmax=None, codim=2, columns=2,
                 colorbar=True, sync_timer=None):
        assert grid.reference_element in (triangle, square)
        assert grid.dim == 2
        assert codim in (0, 2)

        subentities, coordinates, entity_map = flatten_grid(grid)
        if grid.reference_element is not triangle:
            # cut every square into the two triangles (0, 1, 2) and (2, 3, 0)
            subentities = np.vstack((subentities[:, 0:3], subentities[:, [2, 3, 0]]))
        self.subentities = subentities
        self.coordinates = coordinates
        self.entity_map = entity_map
        self.reference_element = grid.reference_element
        self.colorbar = colorbar
        # a single frame is displayed exactly like a full data update
        self.animate = self.set

        super().__init__(U=U, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim, columns=columns,
                         sync_timer=sync_timer)

    def _plot_init(self):
        """Create the `tripcolor` plot (and optional colorbar) filled with zeros."""
        axes = self.ax[0]
        xs = self.coordinates[:, 0]
        ys = self.coordinates[:, 1]
        if self.codim == 2:
            # vertex data: one value per coordinate
            self.p = axes.tripcolor(xs, ys, self.subentities, np.zeros(len(self.coordinates)),
                                    vmin=self.vmin, vmax=self.vmax, shading='gouraud')
        else:
            # cell data: one value per triangle
            self.p = axes.tripcolor(xs, ys, self.subentities, facecolors=np.zeros(len(self.subentities)),
                                    vmin=self.vmin, vmax=self.vmax, shading='flat')
        if self.colorbar:
            self.figure.colorbar(self.p, ax=axes)

    def set(self, U, vmin=None, vmax=None):
        """Show the data `U` and update the colour range.

        `vmin` / `vmax` replace the stored limits only when they are not
        `None`. Returns a one-element tuple containing the plot object.
        """
        if vmin is not None:
            self.vmin = vmin
        if vmax is not None:
            self.vmax = vmax

        # cell data on square grids has to be duplicated, as each square is drawn as two triangles
        duplicate = self.codim != 2 and self.reference_element is not triangle
        self.p.set_array(np.tile(U, 2) if duplicate else U)
        self.p.set_clim(self.vmin, self.vmax)
        return (self.p,)
class Matplotlib1DAxes(MatplotlibAxesBase):
    """Line plots of data on a one-dimensional grid.

    Parameters
    ----------
    U
        The data to display (forwarded to the base class).
    figure
        The matplotlib figure to draw into.
    grid
        The `OnedGrid` the data lives on.
    vmin, vmax
        Value range used for the y-axis limits. Despite the `None` defaults
        both must be numbers, since the limits are computed from their
        difference.
    codim
        `1` for vertex data or `0` for cell data.
    separate_axes
        Forwarded to the base class.
    sync_timer
        Forwarded to the base class.
    columns
        Forwarded to the base class.
    """

    def __init__(self, U, figure, grid, vmin=None, vmax=None, codim=1, separate_axes=False, sync_timer=None,
                 columns=2):
        # Imported locally: at module level OnedGrid is only imported when the Qt
        # widgets are available, so using the global name fails without Qt.
        from pymor.discretizers.builtin.grids.oned import OnedGrid

        assert isinstance(grid, OnedGrid)
        assert codim in (0, 1)
        super().__init__(U=U, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim, columns=columns,
                         sync_timer=sync_timer, separate_axes=separate_axes)

    def _plot_init(self):
        """Create one zero-valued line per data set and apply the y-limits."""
        centers = self.grid.centers(1)
        if self.grid.identify_left_right:
            # append the right boundary point, which the periodic grid identifies with the left one
            centers = np.concatenate((centers, [[self.grid.domain[1]]]), axis=0)
            self.periodic = True
        else:
            self.periodic = False

        if self.codim == 1:
            xs = centers
        else:
            # cell data is drawn piecewise constant: every inner vertex appears twice
            xs = np.repeat(centers, 2)[1:-1]

        if self.separate_axes:
            self.lines = [ax.plot(xs, np.zeros_like(xs))[0] for ax in self.ax]
        else:
            self.lines = [self.ax[0].plot(xs, np.zeros_like(xs))[0] for _ in range(self.count)]

        self._update_ylim()

    def _update_ylim(self):
        """Set the y-limits of all axes to `[vmin, vmax]` enlarged by 10% of the range on each side."""
        pad = (self.vmax - self.vmin) * 0.1
        for ax in self.ax:
            ax.set_ylim(self.vmin - pad, self.vmax + pad)

    def _set(self, u, i):
        """Load the data vector `u` into line number `i`."""
        if self.codim == 1:
            if self.periodic:
                # Close the curve with the first value of the vector being drawn
                # (previously `self.U[0]`, i.e. the first stored data vector, was appended).
                ydata = np.concatenate((u, [u[0]]))
            else:
                ydata = u
        else:
            ydata = np.repeat(u, 2)
        self.lines[i].set_ydata(ydata)

    def animate(self, u):
        """Show the frame `u` in the line of every axes object and return the lines."""
        for i, _ in enumerate(self.ax):
            self._set(u, i)
        return self.lines

    def set(self, U, vmin=None, vmax=None):
        """Load new data into the existing lines and refresh the y-limits.

        If `U` is a tuple, its i-th entry is loaded into line `i`. Otherwise
        `U` is iterated together with the axes, so at most `len(self.ax)`
        entries are drawn. `vmin` / `vmax` replace the stored limits only
        when they are not `None`.
        """
        self.vmin = self.vmin if vmin is None else vmin
        self.vmax = self.vmax if vmax is None else vmax

        if isinstance(U, tuple):
            for i, u in enumerate(U):
                self._set(u, i)
        else:
            for i, (u, _) in enumerate(zip(U, self.ax)):
                self._set(u, i)

        self._update_ylim()
if config.HAVE_QT and config.HAVE_MATPLOTLIB:
    from Qt.QtWidgets import QSizePolicy

    import Qt
    # pick the canvas implementation matching the major version of the Qt binding
    _qt_major = Qt.__qt_version__[0]
    if _qt_major == '4':
        from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas
    elif _qt_major == '5':
        from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
    else:
        raise NotImplementedError

    from matplotlib.figure import Figure
    from pymor.discretizers.builtin.grids.oned import OnedGrid

    # noinspection PyShadowingNames
    class Matplotlib1DWidget(FigureCanvas):
        """Qt canvas showing `count` line plots of data on an `OnedGrid`.

        `vmin` and `vmax` are indexed per plot and must therefore be
        sequences with at least `count` entries. `U` is accepted by the
        constructor but not read there; data is loaded with :meth:`set`.
        """

        def __init__(self, U, parent, grid, count, vmin=None, vmax=None, legend=None, codim=1,
                     separate_plots=False, dpi=100):
            assert isinstance(grid, OnedGrid)
            assert codim in (0, 1)

            figure = Figure(dpi=dpi)
            if not separate_plots:
                axes = figure.gca()
            self.codim = codim

            centers = grid.centers(1)
            if grid.identify_left_right:
                # append the right boundary point for periodic grids
                centers = np.concatenate((centers, [[grid.domain[1]]]), axis=0)
                self.periodic = True
            else:
                self.periodic = False
            # cell data (codim 0) is drawn piecewise constant, so inner vertices are doubled
            xs = centers if codim == 1 else np.repeat(centers, 2)[1:-1]

            collected = []
            for i in range(count):
                if separate_plots:
                    # two columns, enough rows for all plots
                    figure.add_subplot(count // 2 + count % 2, 2, i + 1)
                    axes = figure.gca()
                    pad = (vmax[i] - vmin[i]) * 0.1
                    axes.set_ylim(vmin[i] - pad, vmax[i] + pad)
                line, = axes.plot(xs, np.zeros_like(xs))
                collected.append(line)
                if legend and separate_plots:
                    axes.legend([legend[i]])

            if not separate_plots:
                # shared axes: use the overall value range
                lower, upper = min(vmin), max(vmax)
                pad = (upper - lower) * 0.1
                axes.set_ylim(lower - pad, upper + pad)
                if legend:
                    axes.legend(legend)
            self.lines = tuple(collected)

            super().__init__(figure)
            self.setParent(parent)
            self.setMinimumSize(300, 300)
            self.setSizePolicy(QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding))

        def set(self, U, ind):
            """Show entry `ind` of every data set in `U` and redraw the canvas."""
            for line, u in zip(self.lines, U):
                values = u[ind]
                if self.codim != 1:
                    ydata = np.repeat(values, 2)
                elif self.periodic:
                    # repeat the first value at the right boundary
                    ydata = np.concatenate((values, [values[0]]))
                else:
                    ydata = values
                line.set_ydata(ydata)
            self.draw()

    class MatplotlibPatchWidget(FigureCanvas):
        """Qt canvas embedding a `MatplotlibPatchAxes` plot of two-dimensional grid data."""

        def __init__(self, parent, grid, bounding_box=None, vmin=None, vmax=None, codim=2, dpi=100):
            assert grid.reference_element in (triangle, square)
            assert grid.dim == 2
            assert codim in (0, 2)

            self.figure = Figure(dpi=dpi)
            super().__init__(self.figure)

            self.setParent(parent)
            self.setMinimumSize(300, 300)
            expanding = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
            self.setSizePolicy(expanding)

            self.patch_axes = MatplotlibPatchAxes(figure=self.figure, grid=grid, bounding_box=bounding_box,
                                                  vmin=vmin, vmax=vmax, codim=codim)

        def set(self, U, vmin=None, vmax=None):
            """Forward the data and limits to the patch axes and redraw the canvas."""
            self.patch_axes.set(U, vmin, vmax)
            self.draw()

else:

    # Empty stand-ins defined when `config.HAVE_QT and config.HAVE_MATPLOTLIB` is false,
    # so that the class names exist in this module either way.
    class Matplotlib1DWidget:
        pass

    class MatplotlibPatchWidget:
        pass