Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions partinet/process_utils/guided_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def transform(image: np.ndarray) -> np.ndarray:
"""
i_min = image.min()
i_max = image.max()
if i_max == i_min:
# avoid division by zero; return a zero array when input is constant
return np.zeros_like(image, dtype=np.uint8)
image = ((image - i_min) / (i_max - i_min)) * 255
return image.astype(np.uint8)

Expand All @@ -28,12 +31,20 @@ def standard_scaler(image: np.ndarray) -> np.ndarray:
"""
Apply Gaussian blur and standardize the image to have zero mean and unit variance.

The input is cast to ``float32`` before any OpenCV operations to avoid the
``CV_16F`` kernel-type error that occurs when processing 16‑bit micrographs
(see issue #41). After blurring and normalization we transform the result to
eight‑bit for downstream filters.

Args:
image (np.ndarray): Input image array.

Returns:
np.ndarray: Scaled and transformed image.
"""
# convert to a supported floating point type for OpenCV kernels
image = image.astype(np.float32)

kernel_size = 9
image = cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
mu = np.mean(image)
Expand Down Expand Up @@ -150,6 +161,12 @@ def denoise(image_path: str) -> np.ndarray:
"""
kernel = gaussian_kernel(kernel_size=9)
image = mrcfile.read(image_path)

# some MRCs are stored as 16‑bit integers; ensure we work in float32 so that
# subsequent OpenCV calls (GaussianBlur, etc.) don't raise the ktype error
# described in https://github.com/WEHI-ResearchComputing/PartiNet/issues/41
image = image.astype(np.float32)

image = image.T
image = np.rot90(image)
normalized_image = standard_scaler(np.array(image))
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"tensorboard",
"scikit-learn",
"mrcfile",
"pytest",
]

[project.scripts]
Expand Down
51 changes: 51 additions & 0 deletions tests/test_guided_denoiser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import tempfile

import numpy as np
import mrcfile
import pytest

from partinet.process_utils.guided_denoiser import denoise, transform, standard_scaler


def _write_temp_mrc(array: np.ndarray, dtype: np.dtype) -> str:
"""Write ``array`` to a temporary .mrc file and return its path."""
fd, path = tempfile.mkstemp(suffix=".mrc")
os.close(fd)
with mrcfile.new(path, overwrite=True) as mrc:
mrc.set_data(array.astype(dtype))
return path


@ pytest.mark.parametrize("dtype", [np.uint16, np.int16, np.float32])
def test_denoise_handles_various_dtypes(dtype):
"""The ``denoise`` pipeline should accept 16-bit and 32-bit inputs without
throwing OpenCV kernel-type errors (issue #41).
"""
data = (np.random.rand(32, 32) * 255).astype(dtype)
path = _write_temp_mrc(data, dtype)

try:
out = denoise(path)
assert isinstance(out, np.ndarray)
# our pipeline always returns 8-bit data
assert out.dtype == np.uint8
assert out.shape == data.shape
finally:
os.unlink(path)


def test_transform_stable_when_constant():
"""``transform`` should not divide by zero if image has no contrast."""
arr = np.full((4, 4), 100, dtype=np.float32)
out = transform(arr)
assert out.dtype == np.uint8
assert np.all(out == 0)


def test_standard_scaler_normalises():
arr = np.arange(25, dtype=np.float32).reshape(5, 5)
scaled = standard_scaler(arr)
assert scaled.dtype == np.uint8
# scaled values should not all be equal
assert scaled.std() > 0