Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ampyc/plotting/plot_u.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ampyc.typing import Params
from ampyc.utils import Polytope
from .plot_x import plot_constraints

def plot_u(fig_number: int,
u: np.ndarray,
Expand Down Expand Up @@ -49,8 +50,7 @@ def plot_u(fig_number: int,

ax.plot(u, color=params.color, alpha=params.alpha, linewidth=params.linewidth, label=label)
if U is not None:
ax.axline((-1, U.vertices.max()), slope=0, color='k', linewidth=2)
ax.axline((-1, U.vertices.min()), slope=0, color='k', linewidth=2)
plot_constraints(fig, U)
ax.set_xlabel('time')
ax.set_ylabel(axes_labels[0])
ax.set_xlim([0, num_steps])
Expand Down
58 changes: 37 additions & 21 deletions ampyc/plotting/plot_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,24 @@

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from ampyc.typing import Params
from ampyc.utils import Polytope


def plot_constraints(
fig: Figure,
X: Polytope,
):
mins, maxes = X.bounding_box
for i, ax in enumerate(fig.axes):
if np.isfinite(mins[i, 0]):
ax.axline((-1, mins[i, 0]), slope=0, color='k', linewidth=2)
if np.isfinite(maxes[i, 0]):
ax.axline((-1, maxes[i, 0]), slope=0, color='k', linewidth=2)


def plot_x_state_time(
fig_number: int,
x: np.ndarray,
Expand All @@ -21,15 +35,15 @@ def plot_x_state_time(
label: str | None = None,
legend_loc: str = 'upper right',
title: str | None = None,
axes_labels: list[str] = ['x_1', 'x_2'],
axes_labels: list[str] | None = None,
) -> None:
'''
Plots the state trajectory x over time, including the state constraint set X.
This function assumes a 2D state space (n=2) and plots the two state variables against time.
This function plots the state variables against time.

Args:
fig_number (int): The figure number to use for the plot. This allows multiple plots in the same figure.
x (np.ndarray): The state trajectory, shape (N, n=2, T), where N is the number of time steps,
x (np.ndarray): The state trajectory, shape (N, n, T), where N is the number of time steps,
n is the state dimension, and T is the number of trajectories.
X (Polytope | None): The state constraint set.
params (Params): Parameters for plotting, e.g., color, alpha, and linewidth.
Expand All @@ -39,23 +53,34 @@ def plot_x_state_time(
axes_labels (list[str]): Labels for the x and y axes.
'''

n = x.shape[1]
if axes_labels is None:
axes_labels = [f'x_{{{i+1}}}' for i in range(n)]

# check if the figure number is already open
if plt.fignum_exists(fig_number):
fig = plt.figure(fig_number)
ax1 = fig.axes[0]
ax2 = fig.axes[1]
axes = fig.axes
else:
fig, (ax1, ax2) = plt.subplots(2,1, num=fig_number, sharex=True)
fig, axes = plt.subplots(n,1, num=fig_number, sharex=True)

num_steps = x.shape[0]

ax1.plot(x[:,0], color=params.color, alpha=params.alpha, linewidth=params.linewidth, label=label)
for i, ax in enumerate(axes):
kwargs = {}
if i == 0:
kwargs['label'] = label
ax.plot(x[:,i], color=params.color, alpha=params.alpha, linewidth=params.linewidth, **kwargs)
if i == n - 1:
ax.set_xlabel('time')
ax.set_ylabel(axes_labels[i])
ax.set_xlim([0, num_steps])
ax.grid(visible=True)

if X is not None:
ax1.axline((-1, X.vertices[:,0].max()), slope=0, color='k', linewidth=2)
ax1.axline((-1, X.vertices[:,0].min()), slope=0, color='k', linewidth=2)
ax1.set_ylabel(axes_labels[0])
ax1.set_xlim([0, num_steps])
ax1.grid(visible=True)
plot_constraints(fig, X)

ax1 = axes[0]
if title is not None:
ax1.set_title(title)

Expand All @@ -64,15 +89,6 @@ def plot_x_state_time(
handles, labels = ax1.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax1.legend(by_label.values(), by_label.keys(), loc=legend_loc)

ax2.plot(x[:,1], color=params.color, alpha=params.alpha, linewidth=params.linewidth)
if X is not None:
ax2.axline((-1, X.vertices[:,1].max()), slope=0, color='k', linewidth=2)
ax2.axline((-1, X.vertices[:,1].min()), slope=0, color='k', linewidth=2)
ax2.set_xlabel('time')
ax2.set_ylabel(axes_labels[1])
ax2.set_xlim([0, num_steps])
ax2.grid(visible=True)

fig.tight_layout()

Expand Down
2 changes: 1 addition & 1 deletion notebooks/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1518,7 +1518,7 @@
"source": [
"For convenience, the AMPyC package also contains basic plotting functions. Here we use a state-state plot and an input-time plot to visualize the trajectories produced by the MPC.\n",
"\n",
"**Note:** All plotting functions shipped with the AMPyC package assume a 2D state and a 1D input."
"**Note:** Most plotting functions shipped with the AMPyC package assume a 2D state and a 1D input."
]
},
{
Expand Down