Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 54 additions & 66 deletions pybop/_result.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
import types

import numpy as np

Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(
self._problem = problem
self._minimising = problem.minimising
self.method_name = method_name
self.n_runs = 0
self.n_runs = 1
self._best_run = None
self._x = [logger.x_model_best]
self._x_model = [logger.x_model]
Expand Down Expand Up @@ -322,28 +323,35 @@ def plot_contour(self, **kwargs):

def save(self, filename: str) -> None:
    """Serialise this whole Result object to *filename* using pickle.

    The pickle includes everything the instance references (problem,
    logged arrays, ...), so :meth:`load` restores a fully usable Result.
    """

    # HIGHEST_PROTOCOL: most compact/efficient protocol this interpreter supports.
    with open(filename, "wb") as f:
        pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)

@staticmethod
def load(filename: str) -> "Result":
    """Load a Result previously written by :meth:`save`.

    NOTE(review): this uses ``pickle.load``, which can execute arbitrary
    code during deserialisation — only load files from trusted sources.
    """
    with open(filename, "rb") as f:
        result = pickle.load(f)
    return result

def data_dict(self) -> dict:
    """Return the result data as a plain dictionary for saving to file.

    ``None`` entries in ``scipy_result`` (runs from non-SciPy optimisers)
    are encoded as ``0`` so the list survives formats with no null type.

    Returns
    -------
    dict
        Mapping of field name to the logged per-run data.
    """
    return {
        "minimising": self._minimising,
        "method_name": self.method_name,
        "n_runs": self.n_runs,
        "best_run": self._best_run,
        "x": self._x,
        "x_model": self._x_model,
        "x0": self._x0,
        "best_cost": self._best_cost,
        "cost": self._cost,
        "initial_cost": self._initial_cost,
        "n_iterations": self._n_iterations,
        "iteration_number": self._iteration_number,
        "n_evaluations": self._n_evaluations,
        "message": self._message,
        # Fix: the dict previously carried a duplicate "scipy_result" key —
        # the first (dead) entry also indexed self._scipy_result[0], which
        # raises IndexError on an empty list. Keep only the element-wise form.
        "scipy_result": [0 if x is None else x for x in self._scipy_result],
        "time": self._time,
    }

Expand Down Expand Up @@ -379,7 +387,7 @@ def save_data(
return save_data_dict(data, filename=filename, to_format=to_format)

@staticmethod
def load_data_dict(filename: str, file_format: str = "pickle") -> dict:
def load_data(filename: str, file_format: str = "pickle") -> dict:
"""
Load results data as dictionary from a given file. Restores data saved with
save_data.
Expand All @@ -403,76 +411,56 @@ def load_data_dict(filename: str, file_format: str = "pickle") -> dict:
data_dict :
python dictionary containing the data in the file.
"""
return load_data_dict(
data = load_data_dict(
filename,
file_format=file_format,
data_keys_0d=["n_runs"],
data_keys_0d=["_minimising", "n_runs", "best_run"],
data_keys_1d=[
"method_name",
"message",
"n_evaluations",
"best_cost",
"initial_cost",
"n_iterations",
"n_evaluations",
"message",
"scipy_result",
"time",
],
)

@staticmethod
def load_result(
problem: Problem, filename: str, file_format: str = "pickle"
) -> "Result":
"""
Reconstructs result object based on the underlying problem and
the result data stored in a file

Parameters
----------
problem: The underlying problem used to obtain the result before saving.
filename : str
The name of the file containing the data.
file_format : str, optional
The format the data was save to. Options are:
- 'pickle' (default)
- 'matlab'
- 'csv'
- 'json'

Returns
-------
result :
result object containing the data from the given file.
"""

# read data file
data = Result.load_data_dict(filename, file_format)

# dummy logger for initialising result
logger = Logger(minimising=problem.minimising)
logger.extend_log(
x_search=[np.asarray([1e-3])], x_model=[np.asarray([1e-3])], cost=[0.1]
)
method_name = data["method_name"] if "method_name" in data.keys() else None
message = data["message"] if "message" in data.keys() else None

# create result instance
result = Result(
problem,
logger,
time=0.0,
method_name=method_name,
message=message,
)

# set result data
if "n_runs" in data.keys():
result.n_runs = data["n_runs"]
del data["n_runs"]
for key, value in data.items():
setattr(result, f"_{key}", list(value))
result._x0 = [x_model[0] for x_model in result._x_model]
if len(result._scipy_result) == 0:
result._scipy_result = [None for _ in range(max(1, result.n_runs))]
# Create a dummy problem
problem = types.SimpleNamespace()
problem.minimising = data["minimising"]

# Create one logging result for each run
n_runs = data["n_runs"]
list_of_results = []
for i in range(n_runs):
# Create a dummy logger
logger = types.SimpleNamespace()
for logger_key, result_key in [
("x_model_best", "x"),
("x_model", "x_model"),
("x0", "x0"),
("cost_best", "best_cost"),
("cost_convergence", "cost"),
("iteration", "n_iterations"),
("iteration_number", "iteration_number"),
("evaluations", "n_evaluations"),
]:
setattr(logger, logger_key, data[result_key][i])
logger.cost = [data["initial_cost"][i]]

list_of_results.append(
Result(
problem=problem,
logger=logger,
time=data["time"][i],
method_name=data["method_name"],
message=data["message"][i],
scipy_result=data["scipy_result"][i]
if data["scipy_result"][i] != 0
else None,
)
)

return result
return Result.combine(results=list_of_results)
1 change: 1 addition & 0 deletions pybop/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .problem import problem
from .nyquist import nyquist
from .voronoi import surface
from .samples import trace, chains, posterior, summary_table
131 changes: 131 additions & 0 deletions pybop/plot/samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import TYPE_CHECKING

from pybop.plot import PlotlyManager

if TYPE_CHECKING:
from pybop.samplers.base_pints_sampler import SamplingResult


def trace(result: "SamplingResult", **kwargs):
    """
    Plot a trace (sample value vs. sample index) for each parameter.

    One figure is created per parameter, with one line per chain. Keyword
    arguments are forwarded to ``fig.update_layout``.

    Returns
    -------
    list
        The created figures, for further customisation — consistent with
        ``posterior``, which returns its figure.
    """
    # Import plotly only when needed
    go = PlotlyManager().go

    figures = []
    for i in range(result.num_parameters):
        fig = go.Figure()

        for j, chain in enumerate(result.chains):
            fig.add_trace(go.Scatter(y=chain[:, i], mode="lines", name=f"Chain {j}"))

        fig.update_layout(
            title=f"Parameter {i} Trace Plot",
            xaxis_title="Sample Index",
            yaxis_title="Value",
        )
        fig.update_layout(**kwargs)
        fig.show()
        figures.append(fig)

    return figures


def chains(result: "SamplingResult", **kwargs):
    """
    Plot per-chain posterior histograms for every parameter on one figure.

    A dashed vertical line marks each parameter's mean. Keyword arguments
    are forwarded to ``fig.update_layout``. Returns the figure.
    """
    # Import plotly only when needed
    go = PlotlyManager().go

    fig = go.Figure()

    n_params = 0
    for i, chain in enumerate(result.chains):
        n_params = max(n_params, chain.shape[1])
        for j in range(chain.shape[1]):
            fig.add_trace(
                go.Histogram(
                    x=chain[:, j],
                    name=f"Chain {i} - Parameter {j}",
                    opacity=0.75,
                )
            )

    # Fix: the mean line was previously drawn inside the chain loop using the
    # leftover inner-loop index, so only the LAST parameter's mean was marked
    # (and it was duplicated once per chain). Draw one line per parameter.
    for j in range(n_params):
        fig.add_shape(
            type="line",
            x0=result.mean[j],
            y0=0,
            x1=result.mean[j],
            # NOTE(review): y1 uses result.max[j] as the line height — assumes
            # this is comparable to the histogram count axis; confirm intent.
            y1=result.max[j],
            name=f"Mean - Parameter {j}",
            line=dict(color="Black", width=1.5, dash="dash"),
        )

    fig.update_layout(
        barmode="overlay",
        title="Posterior Distribution",
        xaxis_title="Value",
        yaxis_title="Density",
    )
    fig.update_layout(**kwargs)
    fig.show()
    return fig


def posterior(result: "SamplingResult", **kwargs):
    """
    Plot the pooled (all-chain) posterior distribution of each parameter.

    Each parameter gets an overlaid histogram of ``result.all_samples`` plus
    a dashed vertical line at its mean. Keyword arguments are forwarded to
    ``fig.update_layout``. Returns the figure.
    """
    # Deferred plotly import
    go = PlotlyManager().go

    fig = go.Figure()

    n_params = result.all_samples.shape[1]
    for idx in range(n_params):
        fig.add_trace(
            go.Histogram(
                x=result.all_samples[:, idx],
                name=f"Parameter {idx}",
                opacity=0.75,
            )
        )
        fig.add_vline(
            x=result.mean[idx], line_width=3, line_dash="dash", line_color="black"
        )

    layout = dict(
        barmode="overlay",
        title="Posterior Distribution",
        xaxis_title="Value",
        yaxis_title="Density",
    )
    fig.update_layout(**layout)
    fig.update_layout(**kwargs)
    fig.show()
    return fig


def summary_table(result: "SamplingResult"):
    """
    Display the sampler's summary statistics as a plotly table.

    Reads ``result.get_summary_statistics()`` (mean, median, std and the
    95% credible-interval bounds) and renders them in a two-column table.

    Returns
    -------
    The table figure — added for consistency with ``posterior``, which
    returns its figure; previously this function returned ``None``.
    """
    # Import plotly only when needed
    go = PlotlyManager().go

    summary_stats = result.get_summary_statistics()

    header = ["Statistic", "Value"]
    rows = [
        ("Mean", summary_stats["mean"]),
        ("Median", summary_stats["median"]),
        ("Standard Deviation", summary_stats["std"]),
        ("95% CI Lower", summary_stats["ci_lower"]),
        ("95% CI Upper", summary_stats["ci_upper"]),
    ]

    fig = go.Figure(
        data=[
            go.Table(
                header=dict(values=header),
                # Table cells are column-major: one list per column.
                cells=dict(
                    values=[[row[0] for row in rows], [row[1] for row in rows]]
                ),
            )
        ]
    )

    fig.update_layout(title="Summary Statistics")
    fig.show()
    return fig
Loading