diff --git a/ampyc/plotting/plot_u.py b/ampyc/plotting/plot_u.py index 2af695b..63b0f7e 100644 --- a/ampyc/plotting/plot_u.py +++ b/ampyc/plotting/plot_u.py @@ -12,6 +12,7 @@ from ampyc.typing import Params from ampyc.utils import Polytope +from .plot_x import plot_constraints def plot_u(fig_number: int, u: np.ndarray, @@ -49,8 +50,7 @@ def plot_u(fig_number: int, ax.plot(u, color=params.color, alpha=params.alpha, linewidth=params.linewidth, label=label) if U is not None: - ax.axline((-1, U.vertices.max()), slope=0, color='k', linewidth=2) - ax.axline((-1, U.vertices.min()), slope=0, color='k', linewidth=2) + plot_constraints(fig, U) ax.set_xlabel('time') ax.set_ylabel(axes_labels[0]) ax.set_xlim([0, num_steps]) diff --git a/ampyc/plotting/plot_x.py b/ampyc/plotting/plot_x.py index d5e60f9..27625e0 100644 --- a/ampyc/plotting/plot_x.py +++ b/ampyc/plotting/plot_x.py @@ -9,10 +9,24 @@ import numpy as np import matplotlib.pyplot as plt +from matplotlib.figure import Figure from ampyc.typing import Params from ampyc.utils import Polytope + +def plot_constraints( + fig: Figure, + X: Polytope, +): + mins, maxes = X.bounding_box + for i, ax in enumerate(fig.axes): + if np.isfinite(mins[i, 0]): + ax.axline((-1, mins[i, 0]), slope=0, color='k', linewidth=2) + if np.isfinite(maxes[i, 0]): + ax.axline((-1, maxes[i, 0]), slope=0, color='k', linewidth=2) + + def plot_x_state_time( fig_number: int, x: np.ndarray, @@ -21,15 +35,15 @@ def plot_x_state_time( label: str | None = None, legend_loc: str = 'upper right', title: str | None = None, - axes_labels: list[str] = ['x_1', 'x_2'], + axes_labels: list[str] | None = None, ) -> None: ''' Plots the state trajectory x over time, including the state constraint set X. - This function assumes a 2D state space (n=2) and plots the two state variables against time. + This function plots the state variables against time. Args: fig_number (int): The figure number to use for the plot. This allows multiple plots in the same figure. - x (np.ndarray): The state trajectory, shape (N, n=2, T), where N is the number of time steps, + x (np.ndarray): The state trajectory, shape (N, n, T), where N is the number of time steps, n is the state dimension, and T is the number of trajectories. X (Polytope | None): The state constraint set. params (Params): Parameters for plotting, e.g., color, alpha, and linewidth. @@ -39,23 +53,34 @@ def plot_x_state_time( axes_labels (list[str]): Labels for the x and y axes. ''' + n = x.shape[1] + if axes_labels is None: + axes_labels = [f'x_{{{i+1}}}' for i in range(n)] + # check if the figure number is already open if plt.fignum_exists(fig_number): fig = plt.figure(fig_number) - ax1 = fig.axes[0] - ax2 = fig.axes[1] + axes = fig.axes else: - fig, (ax1, ax2) = plt.subplots(2,1, num=fig_number, sharex=True) + fig, axes = plt.subplots(n,1, num=fig_number, sharex=True) num_steps = x.shape[0] - ax1.plot(x[:,0], color=params.color, alpha=params.alpha, linewidth=params.linewidth, label=label) + for i, ax in enumerate(axes): + kwargs = {} + if i == 0: + kwargs['label'] = label + ax.plot(x[:,i], color=params.color, alpha=params.alpha, linewidth=params.linewidth, **kwargs) + if i == n - 1: + ax.set_xlabel('time') + ax.set_ylabel(axes_labels[i]) + ax.set_xlim([0, num_steps]) + ax.grid(visible=True) + if X is not None: - ax1.axline((-1, X.vertices[:,0].max()), slope=0, color='k', linewidth=2) - ax1.axline((-1, X.vertices[:,0].min()), slope=0, color='k', linewidth=2) - ax1.set_ylabel(axes_labels[0]) - ax1.set_xlim([0, num_steps]) - ax1.grid(visible=True) + plot_constraints(fig, X) + + ax1 = axes[0] if title is not None: ax1.set_title(title) @@ -64,15 +89,6 @@ def plot_x_state_time( handles, labels = ax1.get_legend_handles_labels() by_label = dict(zip(labels, handles)) ax1.legend(by_label.values(), by_label.keys(), loc=legend_loc) - - ax2.plot(x[:,1], color=params.color, alpha=params.alpha, linewidth=params.linewidth) - if X is not None: - ax2.axline((-1, X.vertices[:,1].max()), slope=0, color='k', linewidth=2) - ax2.axline((-1, X.vertices[:,1].min()), slope=0, color='k', linewidth=2) - ax2.set_xlabel('time') - ax2.set_ylabel(axes_labels[1]) - ax2.set_xlim([0, num_steps]) - ax2.grid(visible=True) fig.tight_layout() diff --git a/notebooks/tutorial.ipynb b/notebooks/tutorial.ipynb index 4d06d41..216a009 100644 --- a/notebooks/tutorial.ipynb +++ b/notebooks/tutorial.ipynb @@ -1518,7 +1518,7 @@ "source": [ "For convenience, the AMPyC package also contains basic plotting functions. Here we use a state-state plot and an input-time plot to visualize the trajectories produced by the MPC.\n", "\n", - "**Note:** All plotting functions shipped with the AMPyC package assume a 2D state and a 1D input." + "**Note:** Most plotting functions shipped with the AMPyC package assume a 2D state and a 1D input." ] }, {