Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate notebooks once per fetch or save #724

Merged
merged 3 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions jupyter_server/services/contents/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,13 @@ def _get_os_path(self, path):
raise HTTPError(404, "%s is outside root contents directory" % path)
return os_path

def _read_notebook(self, os_path, as_version=4):
def _read_notebook(self, os_path, as_version=4, capture_validation_error=None):
"""Read a notebook from an os path."""
with self.open(os_path, "r", encoding="utf-8") as f:
try:
return nbformat.read(f, as_version=as_version)
return nbformat.read(
f, as_version=as_version, capture_validation_error=capture_validation_error
)
except Exception as e:
e_orig = e

Expand All @@ -284,12 +286,19 @@ def _read_notebook(self, os_path, as_version=4):
invalid_file = path_to_invalid(os_path)
replace_file(os_path, invalid_file)
replace_file(tmp_path, os_path)
return self._read_notebook(os_path, as_version)
return self._read_notebook(
os_path, as_version, capture_validation_error=capture_validation_error
)

def _save_notebook(self, os_path, nb, capture_validation_error=None):
    """Save a notebook to an os_path.

    Parameters
    ----------
    os_path
        Filesystem path the notebook is written to, via
        ``self.atomic_writing``.
    nb
        The notebook object handed to ``nbformat.write``.
    capture_validation_error
        Optional dictionary forwarded unchanged to ``nbformat.write``.
        Callers pass an empty dict here and later give the same dict to
        ``validate_notebook_model``, which looks for a "ValidationError"
        entry instead of validating the notebook a second time.
    """
    with self.atomic_writing(os_path, encoding="utf-8") as f:
        # NO_CONVERT: write the notebook in the version it already has.
        nbformat.write(
            nb,
            f,
            version=nbformat.NO_CONVERT,
            capture_validation_error=capture_validation_error,
        )

def _read_file(self, os_path, format):
"""Read a non-notebook file.
Expand Down Expand Up @@ -352,11 +361,18 @@ async def _copy(self, src, dest):
"""
await async_copy2_safe(src, dest, log=self.log)

async def _read_notebook(self, os_path, as_version=4):
async def _read_notebook(self, os_path, as_version=4, capture_validation_error=None):
"""Read a notebook from an os path."""
with self.open(os_path, "r", encoding="utf-8") as f:
try:
return await run_sync(partial(nbformat.read, as_version=as_version), f)
return await run_sync(
partial(
nbformat.read,
as_version=as_version,
capture_validation_error=capture_validation_error,
),
f,
)
except Exception as e:
e_orig = e

Expand All @@ -375,12 +391,22 @@ async def _read_notebook(self, os_path, as_version=4):
invalid_file = path_to_invalid(os_path)
await async_replace_file(os_path, invalid_file)
await async_replace_file(tmp_path, os_path)
return await self._read_notebook(os_path, as_version)
return await self._read_notebook(
os_path, as_version, capture_validation_error=capture_validation_error
)

async def _save_notebook(self, os_path, nb, capture_validation_error=None):
    """Save a notebook to an os_path.

    Async counterpart of the synchronous ``_save_notebook``: the
    ``nbformat.write`` call is dispatched through ``run_sync``.

    Parameters
    ----------
    os_path
        Filesystem path the notebook is written to, via
        ``self.atomic_writing``.
    nb
        The notebook object handed to ``nbformat.write``.
    capture_validation_error
        Optional dictionary forwarded unchanged to ``nbformat.write``.
        Callers pass an empty dict here and later give the same dict to
        ``validate_notebook_model``, which looks for a "ValidationError"
        entry instead of validating the notebook a second time.
    """
    with self.atomic_writing(os_path, encoding="utf-8") as f:
        # Keyword arguments are bound with partial(); the notebook and the
        # file handle are passed positionally through run_sync.
        await run_sync(
            partial(
                nbformat.write,
                version=nbformat.NO_CONVERT,
                capture_validation_error=capture_validation_error,
            ),
            nb,
            f,
        )

async def _read_file(self, os_path, format):
"""Read a non-notebook file.
Expand Down
24 changes: 16 additions & 8 deletions jupyter_server/services/contents/filemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,14 @@ def _notebook_model(self, path, content=True):
os_path = self._get_os_path(path)

if content:
nb = self._read_notebook(os_path, as_version=4)
validation_error = {}
nb = self._read_notebook(
os_path, as_version=4, capture_validation_error=validation_error
)
self.mark_trusted_cells(nb, path)
model["content"] = nb
model["format"] = "json"
self.validate_notebook_model(model)
self.validate_notebook_model(model, validation_error)

return model

Expand Down Expand Up @@ -461,11 +464,12 @@ def save(self, model, path=""):
os_path = self._get_os_path(path)
self.log.debug("Saving %s", os_path)

validation_error = {}
try:
if model["type"] == "notebook":
nb = nbformat.from_dict(model["content"])
self.check_and_sign(nb, path)
self._save_notebook(os_path, nb)
self._save_notebook(os_path, nb, capture_validation_error=validation_error)
# One checkpoint should always exist for notebooks.
if not self.checkpoints.list_checkpoints(path):
self.create_checkpoint(path)
Expand All @@ -484,7 +488,7 @@ def save(self, model, path=""):

validation_message = None
if model["type"] == "notebook":
self.validate_notebook_model(model)
self.validate_notebook_model(model, validation_error=validation_error)
validation_message = model.get("message", None)

model = self.get(path, content=False)
Expand Down Expand Up @@ -707,11 +711,14 @@ async def _notebook_model(self, path, content=True):
os_path = self._get_os_path(path)

if content:
nb = await self._read_notebook(os_path, as_version=4)
validation_error = {}
nb = await self._read_notebook(
os_path, as_version=4, capture_validation_error=validation_error
)
self.mark_trusted_cells(nb, path)
model["content"] = nb
model["format"] = "json"
self.validate_notebook_model(model)
self.validate_notebook_model(model, validation_error)

return model

Expand Down Expand Up @@ -785,11 +792,12 @@ async def save(self, model, path=""):
os_path = self._get_os_path(path)
self.log.debug("Saving %s", os_path)

validation_error = {}
try:
if model["type"] == "notebook":
nb = nbformat.from_dict(model["content"])
self.check_and_sign(nb, path)
await self._save_notebook(os_path, nb)
await self._save_notebook(os_path, nb, capture_validation_error=validation_error)
# One checkpoint should always exist for notebooks.
if not (await self.checkpoints.list_checkpoints(path)):
await self.create_checkpoint(path)
Expand All @@ -808,7 +816,7 @@ async def save(self, model, path=""):

validation_message = None
if model["type"] == "notebook":
self.validate_notebook_model(model)
self.validate_notebook_model(model, validation_error=validation_error)
validation_message = model.get("message", None)

model = await self.get(path, content=False)
Expand Down
14 changes: 12 additions & 2 deletions jupyter_server/services/contents/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,20 @@ def increment_filename(self, filename, path="", insert=""):
break
return name

def validate_notebook_model(self, model):
def validate_notebook_model(self, model, validation_error=None):
"""Add failed-validation message to model"""
try:
validate_nb(model["content"])
# If we're given a validation_error dictionary, extract the exception
# from it and raise the exception, else call nbformat's validate method
# to determine if the notebook is valid. This 'else' condition may
# pertain to server extensions that do not use the server's notebook
# read/write functions.
if validation_error is not None:
e = validation_error.get("ValidationError")
if isinstance(e, ValidationError):
raise e
else:
validate_nb(model["content"])
except ValidationError as e:
model["message"] = "Notebook validation failed: {}:\n{}".format(
e.message,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ install_requires =
traitlets>=5
jupyter_core>=4.6.0
jupyter_client>=6.1.1
nbformat
nbformat>=5.2.0
nbconvert
Send2Trash
terminado>=0.8.3
Expand Down
109 changes: 107 additions & 2 deletions tests/services/contents/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
import sys
import time
from itertools import combinations
from typing import Dict
from typing import Optional
from typing import Tuple
from unittest.mock import patch

import pytest
from nbformat import v4 as nbformat
from nbformat import ValidationError
from tornado.web import HTTPError
from traitlets import TraitError

Expand Down Expand Up @@ -63,7 +68,16 @@ def add_code_cell(notebook):
notebook.cells.append(cell)


async def new_notebook(jp_contents_manager):
def add_invalid_cell(notebook):
    """Append a deliberately invalid code cell to *notebook*.

    A code cell with a single ``display_data`` output is created and its
    ``source`` entry is then deleted before the cell is appended.
    """
    js_output = nbformat.new_output("display_data", {"application/javascript": "alert('hi');"})
    broken_cell = nbformat.new_code_cell("print('hi')", outputs=[js_output])
    # Remove source to invalidate the cell
    del broken_cell["source"]
    notebook.cells.append(broken_cell)


async def prepare_notebook(
jp_contents_manager, make_invalid: Optional[bool] = False
) -> Tuple[Dict, str]:
cm = jp_contents_manager
model = await ensure_async(cm.new_untitled(type="notebook"))
name = model["name"]
Expand All @@ -72,8 +86,19 @@ async def new_notebook(jp_contents_manager):
full_model = await ensure_async(cm.get(path))
nb = full_model["content"]
nb["metadata"]["counter"] = int(1e6 * time.time())
add_code_cell(nb)
if make_invalid:
add_invalid_cell(nb)
else:
add_code_cell(nb)
return full_model, path


async def new_notebook(jp_contents_manager):
    """Prepare a notebook, save it, and return ``(nb, name, path)``.

    The notebook comes from ``prepare_notebook`` with its default
    arguments and is saved through the contents manager before returning.
    """
    # The path returned by prepare_notebook is superseded by the model's
    # own "path" entry, which is the one used for saving and returned.
    full_model, _ = await prepare_notebook(jp_contents_manager)
    nb, name, path = full_model["content"], full_model["name"], full_model["path"]
    await ensure_async(jp_contents_manager.save(full_model, path))
    return nb, name, path

Expand Down Expand Up @@ -667,3 +692,83 @@ async def test_check_and_sign(jp_contents_manager):
cm.mark_trusted_cells(nb, path)
cm.check_and_sign(nb, path)
assert cm.notary.check_signature(nb)


async def test_nb_validation(jp_contents_manager):
    """Validation runs exactly once per notebook save and once per fetch."""
    model, path = await prepare_notebook(jp_contents_manager, make_invalid=False)
    cm = jp_contents_manager

    # "nbformat.validate" is patched to count calls made on the read/write
    # path; the manager's "validate_nb" alias is patched to prove it is never
    # reached. Patching also disturbs the validation-error case, so the
    # call-count checks are limited to the valid-notebook portion below.
    nbformat_validate = patch("nbformat.validate")
    manager_validate = patch("jupyter_server.services.contents.manager.validate_nb")
    with nbformat_validate as mock_validate, manager_validate as mock_validate_nb:

        def assert_validated_once():
            # One call through nbformat, none through the alias, then reset
            # both counters for the next step.
            assert mock_validate.call_count == 1
            assert mock_validate_nb.call_count == 0
            mock_validate.reset_mock()
            mock_validate_nb.reset_mock()

        # Saving a valid notebook produces no message
        model = await ensure_async(cm.save(model, path))
        assert "message" not in model
        assert_validated_once()

        # Fetching it back produces no message either
        model = await ensure_async(cm.get(path))
        assert "message" not in model
        assert_validated_once()

    # With the patches removed, make the notebook invalid: both save and get
    # must now report the validation failure.
    add_invalid_cell(model["content"])

    saved = await ensure_async(cm.save(model, path))
    assert "message" in saved
    assert "Notebook validation failed:" in saved["message"]

    fetched = await ensure_async(cm.get(path))
    assert "message" in fetched
    assert "Notebook validation failed:" in fetched["message"]


async def test_validate_notebook_model(jp_contents_manager):
    """Check when ``validate_notebook_model`` calls the validator.

    With a validation_error dictionary supplied, ``validate_nb`` must not be
    called; with the parameter omitted (None), it must be called once.
    """
    model, path = await prepare_notebook(jp_contents_manager, make_invalid=False)
    cm = jp_contents_manager

    with patch("jupyter_server.services.contents.manager.validate_nb") as mock_validate_nb:

        def expect_validate_calls(expected):
            # Check the call count accumulated so far, then start from zero.
            assert mock_validate_nb.call_count == expected
            mock_validate_nb.reset_mock()

        # Valid notebook with a (non-None) dictionary: no validate call
        validation_error = {}
        cm.validate_notebook_model(model, validation_error)
        expect_validate_calls(0)

        # Valid notebook without the extra parameter: one validate call
        cm.validate_notebook_model(model)
        expect_validate_calls(1)

        # Repeat both cases after invalidating the model
        add_invalid_cell(model["content"])

        validation_error["ValidationError"] = ValidationError("not a real validation error")
        cm.validate_notebook_model(model, validation_error)
        assert "Notebook validation failed" in model["message"]
        expect_validate_calls(0)
        del model["message"]

        # Without the extra parameter a validate call is expected. The patch
        # replaces the real validator, so the message field is not inspected.
        cm.validate_notebook_model(model)
        expect_validate_calls(1)