Skip to content

Commit 83027c0

Browse files
CristianLarafacebook-github-bot
authored andcommitted
Sync convert_ipynb_to_mdx.py with Ax (#2795)
Summary: Pull Request resolved: #2795 This script started to drift from the implementation in Ax. We need to keep these in sync when making changes to either repo. Reviewed By: saitcakmak Differential Revision: D72236261 fbshipit-source-id: dc08c33137bcc66bfb375c91ef16e647696d245f
1 parent 4a1555b commit 83027c0

File tree

2 files changed

+152
-64
lines changed

2 files changed

+152
-64
lines changed

scripts/convert_ipynb_to_mdx.py

+145-42
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import subprocess
1313
import uuid
1414
from pathlib import Path
15-
from typing import Dict, List, Optional, Tuple, Union
15+
from typing import Union
1616

1717
import mdformat
1818
import nbformat
@@ -45,7 +45,7 @@
4545
]
4646

4747

48-
def load_nb_metadata() -> Dict[str, Dict[str, str]]:
48+
def load_nb_metadata() -> dict[str, dict[str, str]]:
4949
"""
5050
Load the metadata and list of notebooks that are to be converted to MDX.
5151
@@ -83,7 +83,7 @@ def load_notebook(path: Path) -> NotebookNode:
8383
return nb
8484

8585

86-
def create_folders(path: Path) -> Tuple[str, Path]:
86+
def create_folders(path: Path) -> tuple[str, Path]:
8787
"""
8888
Create asset folders for the tutorial.
8989
@@ -109,7 +109,7 @@ def create_folders(path: Path) -> Tuple[str, Path]:
109109
return filename, assets_folder
110110

111111

112-
def create_frontmatter(path: Path, nb_metadata: Dict[str, Dict[str, str]]) -> str:
112+
def create_frontmatter(path: Path, nb_metadata: dict[str, dict[str, str]]) -> str:
113113
"""
114114
Create frontmatter for the resulting MDX file.
115115
@@ -154,7 +154,7 @@ def create_imports() -> str:
154154
return f"{imports}\n"
155155

156156

157-
def get_current_git_tag() -> Optional[str]:
157+
def get_current_git_tag() -> str | None:
158158
"""
159159
Retrieve the current Git tag if the current commit is tagged.
160160
@@ -175,7 +175,7 @@ def get_current_git_tag() -> Optional[str]:
175175

176176

177177
def create_buttons(
178-
nb_metadata: Dict[str, Dict[str, str]],
178+
nb_metadata: dict[str, dict[str, str]],
179179
) -> str:
180180
"""
181181
Create buttons that link to Colab and GitHub for the tutorial.
@@ -196,10 +196,55 @@ def create_buttons(
196196
return f'<LinkButtons\n githubUrl="{github_url}"\n colabUrl="{colab_url}"\n/>\n\n'
197197

198198

199-
def handle_images_found_in_markdown(
199+
def handle_image_attachments(
200+
markdown: str,
201+
attachments: dict[str, dict[str, str]],
202+
) -> str:
203+
"""
204+
Image attachments are stored in the notebook cell's "attachments" field in base64
205+
format with their associated mime_type and referenced in the markdown via
206+
attachment name.
207+
208+
The pattern we search for in the Markdown is
209+
`![alt_text](attachment:attachment_name title)` with three groups:
210+
211+
- group 1 = alt_text (optional)
212+
- group 2 = attachment_name
213+
- group 3 = title (optional)
214+
215+
To represent this in MD we replace the attachment reference with the base64 encoded
216+
string as `![{alt_text}](data:{mime_type};base64,{img_as_base64})`
217+
218+
Args:
219+
markdown (str): The markdown content containing image attachments.
220+
attachments (Dict[str, Dict[str, str]]): A dictionary of attachments with their
221+
corresponding MIME types and base64 encoded data.
222+
223+
Returns:
224+
str: The markdown content with images converted to base64 format.
225+
"""
226+
markdown_image_pattern = re.compile(
227+
r"""!\[([^\]]*)\]\(attachment:(.*?)\s*?(\".*\")?\)"""
228+
)
229+
# go through searches in reverse order so that each replacement doesn't affect the
230+
# start/end indices for the next replacements
231+
searches = reversed(list(re.finditer(markdown_image_pattern, markdown)))
232+
for search in searches:
233+
alt_text, attachment_name, _ = search.groups()
234+
mime_type, base64 = next(iter(attachments[attachment_name].items()))
235+
start, end = search.span()
236+
markdown = (
237+
markdown[:start]
238+
+ generate_img_base64_md(base64, mime_type, alt_text)
239+
+ markdown[end:]
240+
)
241+
return markdown
242+
243+
244+
def handle_image_paths_found_in_markdown(
200245
markdown: str,
201246
new_img_dir: Path,
202-
lib_dir: Path,
247+
nb_path: Path,
203248
) -> str:
204249
"""
205250
Update image paths in the Markdown, and copy the image to the docs location.
@@ -210,6 +255,9 @@ def handle_images_found_in_markdown(
210255
- group 1 = path/to/image.png
211256
- group 2 = "title"
212257
258+
We explicitly exclude matching if the path starts with `attachment:` as this
259+
indicates that the image is embedded as a base64 attachment not a file path.
260+
213261
The first group (the path to the image from the original notebook) will be replaced
214262
with ``assets/img/{name}`` where the name is `image.png` from the example above. The
215263
original image will also be copied to the new location
@@ -219,12 +267,15 @@ def handle_images_found_in_markdown(
219267
markdown (str): Markdown where we look for Markdown flavored images.
220268
new_img_dir (Path): Path where images are copied to for display in the
221269
MDX file.
222-
lib_dir (Path): The location for the Bean Machine repo.
270+
lib_dir (Path): The location for the repo.
271+
nb_path (Path): The location for the notebook.
223272
224273
Returns:
225274
str: The original Markdown with new paths for images.
226275
"""
227-
markdown_image_pattern = re.compile(r"""!\[[^\]]*\]\((.*?)(?=\"|\))(\".*\")?\)""")
276+
markdown_image_pattern = re.compile(
277+
r"""!\[[^\]]*\]\((?!attachment:)(.*?)(?=\"|\))(\".*\")?\)"""
278+
)
228279
searches = list(re.finditer(markdown_image_pattern, markdown))
229280

230281
# Return the given Markdown if no images are found.
@@ -250,11 +301,11 @@ def handle_images_found_in_markdown(
250301

251302
# Copy the original image to the new location.
252303
if old_path.exists():
304+
# resolves if an absolute path is used
253305
old_img_path = old_path
254306
else:
255-
# Here we assume the original image exists in the same directory as the
256-
# notebook, which should be in the tutorials folder of the library.
257-
old_img_path = (lib_dir / "tutorials" / old_path).resolve()
307+
# fall back to path relative to the notebook
308+
old_img_path = (nb_path.parent / old_path).resolve()
258309
new_img_path = str(new_img_dir / name)
259310
shutil.copy(str(old_img_path), new_img_path)
260311

@@ -359,7 +410,7 @@ def get_source(cell: NotebookNode) -> str:
359410
def handle_markdown_cell(
360411
cell: NotebookNode,
361412
new_img_dir: Path,
362-
lib_dir: Path,
413+
nb_path: Path,
363414
) -> str:
364415
"""
365416
Handle the given Jupyter Markdown cell and convert it to MDX.
@@ -368,17 +419,17 @@ def handle_markdown_cell(
368419
cell (NotebookNode): Jupyter Markdown cell object.
369420
new_img_dir (Path): Path where images are copied to for display in the
370421
Markdown cell.
371-
lib_dir (Path): The location for the Bean Machine library.
422+
lib_dir (Path): The location for the library.
423+
nb_path (Path): The location for the notebook.
372424
373425
Returns:
374426
str: Transformed Markdown object suitable for inclusion in MDX.
375427
"""
376428
markdown = get_source(cell)
377429

378-
# Update image paths in the Markdown and copy them to the Markdown tutorials folder.
379-
# Skip - Our images are base64 encoded, so we don't need to copy them to the docs
380-
# folder.
381-
# markdown = handle_images_found_in_markdown(markdown, new_img_dir, lib_dir)
430+
# Handle the different ways images are included in the Markdown.
431+
markdown = handle_image_paths_found_in_markdown(markdown, new_img_dir, nb_path)
432+
markdown = handle_image_attachments(markdown, cell.get("attachments", {}))
382433

383434
markdown = sanitize_mdx(markdown)
384435
mdx = mdformat.text(markdown, options={"wrap": 88}, extensions={"myst"})
@@ -411,9 +462,29 @@ def handle_cell_input(cell: NotebookNode, language: str) -> str:
411462
return f"```{language}\n{cell_source}\n```\n\n"
412463

413464

465+
def generate_img_base64_md(
466+
img_as_base64: int | str | NotebookNode,
467+
mime_type: int | str | NotebookNode,
468+
alt_text: str = "",
469+
) -> str:
470+
"""
471+
Generate a markdown image tag from a base64 encoded image.
472+
473+
Args:
474+
img_as_base64 (int | str | NotebookNode): The base64 encoded image data.
475+
mime_type (int | str | NotebookNode): The MIME type of the image.
476+
alt_text (str, optional): The alternative text for the image. Defaults to an
477+
empty string.
478+
479+
Returns:
480+
str: A markdown formatted image tag.
481+
"""
482+
return f"![{alt_text}](data:{mime_type};base64,{img_as_base64})"
483+
484+
414485
def handle_image(
415-
values: List[Dict[str, Union[int, str, NotebookNode]]],
416-
) -> List[Tuple[int, str]]:
486+
values: list[dict[str, int | str | NotebookNode]],
487+
) -> list[tuple[int, str]]:
417488
"""
418489
Convert embedded images to string MDX can consume.
419490
@@ -431,13 +502,13 @@ def handle_image(
431502
index = value["index"]
432503
mime_type = value["mime_type"]
433504
img = value["data"]
434-
output.append((index, f"![](data:image/{mime_type};base64,{img})\n\n"))
505+
output.append((index, f"{generate_img_base64_md(img, mime_type)}\n\n"))
435506
return output
436507

437508

438509
def handle_markdown(
439-
values: List[Dict[str, Union[int, str, NotebookNode]]],
440-
) -> List[Tuple[int, str]]:
510+
values: list[dict[str, int | str | NotebookNode]],
511+
) -> list[tuple[int, str]]:
441512
"""
442513
Convert and format Markdown for MDX.
443514
@@ -460,8 +531,8 @@ def handle_markdown(
460531

461532

462533
def handle_pandas(
463-
values: List[Dict[str, Union[int, str, NotebookNode]]],
464-
) -> List[Tuple[int, str]]:
534+
values: list[dict[str, int | str | NotebookNode]],
535+
) -> list[tuple[int, str]]:
465536
"""
466537
Handle how to display pandas DataFrames.
467538
@@ -503,11 +574,14 @@ def handle_pandas(
503574

504575

505576
def handle_plain(
506-
values: List[Dict[str, Union[int, str, NotebookNode]]],
507-
) -> List[Tuple[int, str]]:
577+
values: list[dict[str, int | str | NotebookNode]],
578+
) -> list[tuple[int, str]]:
508579
"""
509580
Handle how to plain cell output should be displayed in MDX.
510581
582+
Stdout streams are chunked during execution, we merge adjacent streams here into
583+
single cell output blocks.
584+
511585
Args:
512586
values (List[Dict[str, Union[int, str, NotebookNode]]]): Bokeh tagged cell
513587
outputs.
@@ -518,6 +592,28 @@ def handle_plain(
518592
the tuple is the MDX formatted string.
519593
"""
520594
output = []
595+
adjacent_outputs = []
596+
previous_index = -1
597+
598+
def append_to_output() -> None:
599+
if not adjacent_outputs:
600+
return
601+
adjacent_outputs_str = "\n".join(adjacent_outputs)
602+
output.append(
603+
(
604+
previous_index,
605+
"\n".join(
606+
[
607+
"<CellOutput>",
608+
"{",
609+
f"`{adjacent_outputs_str}`",
610+
"}",
611+
"</CellOutput>\n\n",
612+
]
613+
),
614+
),
615+
)
616+
521617
for value in values:
522618
index = int(value["index"])
523619
data = str(value["data"])
@@ -527,16 +623,23 @@ def handle_plain(
527623
data = "\n".join([line for line in str(value["data"]).splitlines() if line])
528624
# Remove backticks to make the text MDX compatible.
529625
data = data.replace("`", "")
530-
output.append(
531-
(index, f"<CellOutput>\n{{\n `{data}`\n}}\n</CellOutput>\n\n"),
532-
)
626+
if previous_index in [-1, index - 1]:
627+
# store in cache until we reach nonconsecutive index
628+
adjacent_outputs.append(data)
629+
else:
630+
# flush cache to output and start a new one
631+
append_to_output()
632+
adjacent_outputs = [data]
633+
previous_index = index
634+
# flush the remaining cache to output
635+
append_to_output()
533636
return output
534637

535638

536639
def handle_plotly(
537-
values: List[Dict[str, Union[int, str, NotebookNode]]],
640+
values: list[dict[str, int | str | NotebookNode]],
538641
plot_data_folder: Path,
539-
) -> List[Tuple[int, str]]:
642+
) -> list[tuple[int, str]]:
540643
"""
541644
Convert Plotly outputs to MDX.
542645
@@ -567,8 +670,8 @@ def handle_plotly(
567670

568671

569672
def handle_tqdm(
570-
values: List[Dict[str, Union[int, str, NotebookNode]]],
571-
) -> List[Tuple[int, str]]:
673+
values: list[dict[str, int | str | NotebookNode]],
674+
) -> list[tuple[int, str]]:
572675
"""
573676
Handle the output of tqdm.
574677
@@ -590,9 +693,9 @@ def handle_tqdm(
590693
return [(index, f"<CellOutput>\n{{\n `{md}`\n}}\n</CellOutput>\n\n")]
591694

592695

593-
CELL_OUTPUTS_TO_PROCESS = Dict[
696+
CELL_OUTPUTS_TO_PROCESS = dict[
594697
str,
595-
List[Dict[str, Union[int, str, NotebookNode]]],
698+
list[dict[str, Union[int, str, NotebookNode]]],
596699
]
597700

598701

@@ -637,8 +740,8 @@ def aggregate_mdx(
637740

638741

639742
def prioritize_dtypes(
640-
cell_outputs: List[NotebookNode],
641-
) -> Tuple[List[List[str]], List[bool]]:
743+
cell_outputs: list[NotebookNode],
744+
) -> tuple[list[list[str]], list[bool]]:
642745
"""
643746
Prioritize cell output data types.
644747
@@ -679,7 +782,7 @@ def aggregate_images_and_plotly(
679782
prioritized_data_dtype: str,
680783
cell_output: NotebookNode,
681784
data: NotebookNode,
682-
plotly_flags: List[bool],
785+
plotly_flags: list[bool],
683786
cell_outputs_to_process: CELL_OUTPUTS_TO_PROCESS,
684787
i: int,
685788
) -> None:
@@ -742,7 +845,7 @@ def aggregate_plain_output(
742845
cell_outputs_to_process["plain"].append({"index": i, "data": data})
743846

744847

745-
def aggregate_output_types(cell_outputs: List[NotebookNode]) -> CELL_OUTPUTS_TO_PROCESS:
848+
def aggregate_output_types(cell_outputs: list[NotebookNode]) -> CELL_OUTPUTS_TO_PROCESS:
746849
"""
747850
Aggregate cell outputs into a dictionary for further processing.
748851
@@ -880,7 +983,7 @@ def transform_notebook(path: Path, nb_metadata: object) -> str:
880983

881984
# Handle a Markdown cell.
882985
if cell_type == "markdown":
883-
mdx += handle_markdown_cell(cell, img_folder, LIB_DIR)
986+
mdx += handle_markdown_cell(cell, img_folder, path)
884987

885988
# Handle a code cell.
886989
if cell_type == "code":

0 commit comments

Comments
 (0)