def save(self, filename) -> None:
    """
    Save the whole result object using pickle.

    Parameters
    ----------
    filename : str
        Path of the file to write. The file is opened in binary write mode.
    """
    with open(filename, "wb") as f:
        pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)


@staticmethod
def load(filename: str) -> "Result":
    """
    Load a result previously written with ``save``.

    Parameters
    ----------
    filename : str
        Path of the pickle file to read.

    Returns
    -------
    Result
        The unpickled object.
    """
    # SECURITY: unpickling can execute arbitrary code, so only load files that
    # come from a trusted source.
    with open(filename, "rb") as f:
        return pickle.load(f)


@staticmethod
def load_data(filename: str, file_format: str = "pickle") -> "Result":
    """
    Rebuild a result from the data stored in a file.

    Restores data saved with ``save_data``. One single-run ``Result`` is built
    per stored run and the runs are merged with ``Result.combine``.

    Parameters
    ----------
    filename : str
        The name of the file containing the data.
    file_format : str, optional
        The format the data was saved to (default: 'pickle').

    Returns
    -------
    Result
        The combined result reconstructed from the stored data.
    """
    # Local import: only the light-weight stand-in objects below need it.
    import types

    data = load_data_dict(
        filename,
        file_format=file_format,
        # "minimising" (no leading underscore) is the key written by
        # data_dict() and read back below.
        # NOTE(review): presumably these lists tell the loader which entries
        # are scalars / flat lists; confirm against load_data_dict.
        data_keys_0d=["minimising", "n_runs", "best_run"],
        data_keys_1d=[
            "method_name",
            "best_cost",
            "initial_cost",
            "n_iterations",
            "n_evaluations",
            "message",
            "scipy_result",
            "time",
        ],
    )

    # Stand-in problem: Result.__init__ reads ``problem.minimising``.
    problem = types.SimpleNamespace(minimising=data["minimising"])

    # (attribute on the stand-in logger, key in the stored data)
    logger_fields = (
        ("x_model_best", "x"),
        ("x_model", "x_model"),
        ("x0", "x0"),
        ("cost_best", "best_cost"),
        ("cost_convergence", "cost"),
        ("iteration", "n_iterations"),
        ("iteration_number", "iteration_number"),
        ("evaluations", "n_evaluations"),
    )

    # Build one single-run result per stored run.
    # int() guards against formats that hand back a non-int scalar.
    list_of_results = []
    for i in range(int(data["n_runs"])):
        logger = types.SimpleNamespace(
            **{attr: data[key][i] for attr, key in logger_fields}
        )
        logger.cost = [data["initial_cost"][i]]

        # data_dict() stores 0 in place of a missing SciPy result.
        scipy_result = data["scipy_result"][i]

        list_of_results.append(
            Result(
                problem=problem,
                logger=logger,
                time=data["time"][i],
                # NOTE(review): passed whole rather than per run, as in the
                # submitted change; confirm this is intended for combined data.
                method_name=data["method_name"],
                message=data["message"][i],
                scipy_result=scipy_result if scipy_result != 0 else None,
            )
        )

    return Result.combine(results=list_of_results)
from typing import TYPE_CHECKING

from pybop.plot import PlotlyManager

if TYPE_CHECKING:
    # NOTE(review): the methods that delegate to this module are changed in
    # pybop/samplers/base_sampler.py; confirm this is the right import path.
    from pybop.samplers.base_pints_sampler import SamplingResult


def trace(result: "SamplingResult", **kwargs):
    """
    Plot trace plots for the posterior samples.

    One figure is created and shown per parameter, with one line per chain.

    Parameters
    ----------
    result : SamplingResult
        Provides ``num_parameters`` and ``chains``; each chain is indexed as
        ``chain[:, i]`` to select the samples of parameter ``i``.
    **kwargs
        Forwarded to ``fig.update_layout`` of every figure.

    Returns
    -------
    list
        The figures, in parameter order.
    """
    # Import plotly only when needed
    go = PlotlyManager().go

    figures = []
    for i in range(result.num_parameters):
        fig = go.Figure()

        for j, chain in enumerate(result.chains):
            fig.add_trace(go.Scatter(y=chain[:, i], mode="lines", name=f"Chain {j}"))

        fig.update_layout(
            title=f"Parameter {i} Trace Plot",
            xaxis_title="Sample Index",
            yaxis_title="Value",
        )
        # User options are applied last so they override the defaults above.
        fig.update_layout(**kwargs)
        fig.show()
        figures.append(fig)

    return figures


def chains(result: "SamplingResult", **kwargs):
    """
    Plot posterior distributions for each chain.

    Parameters
    ----------
    result : SamplingResult
        Provides ``chains``, ``mean`` and ``max``.
    **kwargs
        Forwarded to ``fig.update_layout``.

    Returns
    -------
    The figure holding one overlaid histogram per chain and parameter.
    """
    # Import plotly only when needed
    go = PlotlyManager().go

    fig = go.Figure()

    for i, chain in enumerate(result.chains):
        for j in range(chain.shape[1]):
            fig.add_trace(
                go.Histogram(
                    x=chain[:, j],
                    name=f"Chain {i} - Parameter {j}",
                    opacity=0.75,
                )
            )

            # Dashed vertical marker at the mean of parameter j.
            # NOTE(review): the marker height is taken from result.max[j];
            # confirm this is the intended upper end of the line.
            fig.add_shape(
                type="line",
                x0=result.mean[j],
                y0=0,
                x1=result.mean[j],
                y1=result.max[j],
                name=f"Mean - Parameter {j}",
                line=dict(color="Black", width=1.5, dash="dash"),
            )

    fig.update_layout(
        barmode="overlay",
        title="Posterior Distribution",
        xaxis_title="Value",
        yaxis_title="Density",
    )
    fig.update_layout(**kwargs)
    fig.show()
    return fig


def posterior(result: "SamplingResult", **kwargs):
    """
    Plot the summed posterior distribution across chains.

    Parameters
    ----------
    result : SamplingResult
        Provides ``all_samples`` (indexed as ``all_samples[:, j]``) and ``mean``.
    **kwargs
        Forwarded to ``fig.update_layout``.

    Returns
    -------
    The figure holding one overlaid histogram per parameter.
    """
    # Import plotly only when needed
    go = PlotlyManager().go

    fig = go.Figure()

    for j in range(result.all_samples.shape[1]):
        histogram = go.Histogram(
            x=result.all_samples[:, j],
            name=f"Parameter {j}",
            opacity=0.75,
        )
        fig.add_trace(histogram)
        fig.add_vline(
            x=result.mean[j], line_width=3, line_dash="dash", line_color="black"
        )

    fig.update_layout(
        barmode="overlay",
        title="Posterior Distribution",
        xaxis_title="Value",
        yaxis_title="Density",
    )
    fig.update_layout(**kwargs)
    fig.show()
    return fig


def summary_table(result: "SamplingResult", **kwargs):
    """
    Display summary statistics in a table.

    Parameters
    ----------
    result : SamplingResult
        Provides ``get_summary_statistics()``, whose "mean", "median", "std",
        "ci_lower" and "ci_upper" entries are shown.
    **kwargs
        Forwarded to ``fig.update_layout``. Accepted so that callers which
        forward layout options, as the other plots do, work here too.

    Returns
    -------
    The figure holding the table.
    """
    # Import plotly only when needed
    go = PlotlyManager().go

    summary_stats = result.get_summary_statistics()

    header = ["Statistic", "Value"]
    values = [
        ["Mean", summary_stats["mean"]],
        ["Median", summary_stats["median"]],
        ["Standard Deviation", summary_stats["std"]],
        ["95% CI Lower", summary_stats["ci_lower"]],
        ["95% CI Upper", summary_stats["ci_upper"]],
    ]

    fig = go.Figure(
        data=[
            go.Table(
                header=dict(values=header),
                cells=dict(
                    values=[[row[0] for row in values], [row[1] for row in values]]
                ),
            )
        ]
    )

    fig.update_layout(title="Summary Statistics")
    fig.update_layout(**kwargs)
    fig.show()
    return fig
def plot_trace(self, **kwargs):
    """
    Plot trace plots for the posterior samples.

    Thin wrapper: passes this result and any keyword arguments to
    ``plot.trace`` and returns whatever that function returns.
    """
    return plot.trace(self, **kwargs)


def plot_chains(self, **kwargs):
    """
    Plot posterior distributions for each chain.

    Thin wrapper: passes this result and any keyword arguments to
    ``plot.chains`` and returns whatever that function returns.
    """
    return plot.chains(self, **kwargs)


def plot_posterior(self, **kwargs):
    """
    Plot the summed posterior distribution across chains.

    Thin wrapper: passes this result and any keyword arguments to
    ``plot.posterior`` and returns whatever that function returns.
    """
    return plot.posterior(self, **kwargs)


def summary_table(self, **kwargs):
    """
    Display summary statistics in a table.

    Thin wrapper: passes this result and any keyword arguments to
    ``plot.summary_table`` and returns whatever that function returns.
    """
    return plot.summary_table(self, **kwargs)
- filename : str - The name of the file containing the data. - file_format : str, optional - The format the data was save to. Options are: - - 'pickle' (default) - - 'matlab' - - 'csv' - - 'json' - - Returns - ------- - result : - result object containing the data from the given file. - """ - - # read data file - data = Result.load_data_dict(filename, file_format) - - # dummy logger for initialising result - logger = sampler.logger - dummy_logger = Logger(sampler.log_pdf.minimising) - dummy_logger.extend_log( - x_search=[np.asarray([1e-3])], x_model=[np.asarray([1e-3])], cost=[0.1] - ) - - sampler._logger = dummy_logger # noqa: SLF001 - - # initialise result - time = 0.0 - chains = None - if "chains" in data.keys(): - chains = np.asarray(data["chains"].copy()) - del data["chains"] - - method_name = data["method_name"] if "method_name" in data.keys() else None - message = data["message"] if "message" in data.keys() else None - - result = SamplingResult( - sampler, - time, - chains, - method_name=method_name, - message=message, - ) - - # set result data - if "n_runs" in data.keys(): - result.n_runs = data["n_runs"] - del data["n_runs"] - for key, value in data.items(): - setattr(result, f"_{key}", list(value)) - result._x0 = [x_model[0] for x_model in result._x_model] - if len(result._scipy_result) == 0: - result._scipy_result = [None for _ in range(max(1, result.n_runs))] - - # restore orginal logger in sampler - sampler._logger = logger # noqa: SLF001 - - return result diff --git a/tests/unit/test_optimisation.py b/tests/unit/test_optimisation.py index f100ac059..6cc65e97d 100644 --- a/tests/unit/test_optimisation.py +++ b/tests/unit/test_optimisation.py @@ -1,5 +1,4 @@ import io -import pickle import re import sys @@ -795,25 +794,22 @@ def test_save_result_data(self, result, problem, to_format, tmp_path): # Test save result result.save_data(filename, to_format=to_format) - result_load = OptimisationResult.load_result( - problem, filename, file_format=to_format - ) 
def test_save_result(self, result, tmp_path):
    """Round-trip a whole result through ``save`` and ``load`` and compare."""
    # Same path as before: "<tmp_path>/test" with a ".pickle" suffix appended.
    filename = f"{tmp_path / 'test'}.pickle"

    # Save the whole result, then read it back through the public loader.
    result.save(filename)
    result_load = OptimisationResult.load(filename)

    self.compare_result_data(result, result_load)
    # The full pickle also carries the problem, unlike the data-only formats.
    assert result.problem.parameters.names == result_load.problem.parameters.names
file_format=to_format - ) + result_load = SamplingResult.load_data(filename, file_format=to_format) self.compare_result_data(result, result_load) assert sampler2.logger is None # test save whole result - result.save(f"{test_stub}.pickle") - with open(f"{test_stub}.pickle", "rb") as f: - result_load = pickle.load(f) + filename = f"{test_stub}.pickle" + result.save(filename) + result_load = SamplingResult.load(filename) self.compare_result_data(result, result_load) + assert result.problem.parameters.names == result_load.problem.parameters.names + np.testing.assert_array_equal(result.chains, result_load.chains)