Skip to content

Commit f5555e1

Browse files
authored
Merge pull request #1639 from Anselmoo/feature/notebook-refactor
refactor: ♻️ Refactor DataFramePlot methods to improve clarity
2 parents 9b112a4 + b8402d1 commit f5555e1

File tree

1 file changed

+152
-97
lines changed

1 file changed

+152
-97
lines changed

spectrafit/plugins/notebook.py

+152-97
Original file line numberDiff line numberDiff line change
@@ -139,84 +139,137 @@ def markdown_display(df: pd.DataFrame) -> None:
139139

140140

141141
class DataFramePlot:
142-
"""Class to plot a dataframe."""
142+
"""Class to plot a data frame."""
143143

144144
def plot_2dataframes(
145145
self,
146146
args_plot: PlotAPI,
147147
df_1: pd.DataFrame,
148148
df_2: Optional[pd.DataFrame] = None,
149149
) -> None:
150-
"""Plot of two dataframes.
150+
"""Plot two data frames.
151151
152152
!!! info "About the plot"
153153
154-
The plot is a combination of two plots. The first plot is the
155-
can be the residual plot of a fit or the _modified_ data. The second
156-
plot can be the fit or the original data.
154+
The plot is a combination of two plots. The first plot can be the
155+
residual plot of a fit or the _modified_ data. The second plot can be the
156+
fit or the original data.
157157
158158
!!! missing "`line_dash_map`"
159159
160160
Currently, the `line_dash_map` is not working, and the dash is not
161-
plotted. Most likely, this is related to the fact that the columns
162-
are not labeled in the dataframe.
161+
plotted. This is likely due to the columns not being labeled in the
162+
data frame.
163163
164164
Args:
165165
args_plot (PlotAPI): PlotAPI object for the settings of the plot.
166-
df_1 (pd.DataFrame): First dataframe to plot, which will generate
167-
automatically a fit plot with residual plot. The ratio is 70% to 20%
168-
with 10% space in between.
169-
df_2 (Optional[pd.DataFrame], optional): Second optional dataframe to
170-
plot for comparsion. In this case, the ratio will between first
171-
and second plot will be same. Defaults to None.
166+
df_1 (pd.DataFrame): First data frame to plot, which will generate
167+
a fit plot with residual plot. The ratio is 70% to 20% with
168+
10% space in between.
169+
df_2 (Optional[pd.DataFrame], optional): Second optional data frame to
170+
plot for comparison. In this case, the ratio between the first
171+
and second plot will be the same. Defaults to None.
172172
"""
173173
if df_2 is None:
174-
_fig1 = px.line(
175-
df_1,
176-
x=ColumnNamesAPI().energy,
177-
y=ColumnNamesAPI().residual,
178-
color_discrete_sequence=[args_plot.color.residual],
179-
)
180-
_y = df_1.columns.drop([ColumnNamesAPI().energy, ColumnNamesAPI().residual])
181-
_fig2 = px.line(
182-
df_1,
183-
x=ColumnNamesAPI().energy,
184-
y=_y,
185-
color_discrete_map={
186-
ColumnNamesAPI().intensity: args_plot.color.intensity,
187-
ColumnNamesAPI().fit: args_plot.color.fit,
188-
**{
189-
key: args_plot.color.components
190-
for key in _y.drop(
191-
[ColumnNamesAPI().intensity, ColumnNamesAPI().fit]
192-
)
193-
},
194-
},
195-
line_dash_map={
196-
ColumnNamesAPI().intensity: "solid",
197-
ColumnNamesAPI().fit: "longdash",
198-
**{
199-
key: "dash"
200-
for key in _y.drop(
201-
[ColumnNamesAPI().intensity, ColumnNamesAPI().fit]
202-
)
203-
},
204-
},
205-
)
174+
fig = self._plot_single_dataframe(args_plot, df_1)
206175
else:
207-
_fig1 = px.line(df_1, x=args_plot.x, y=args_plot.y)
208-
_fig2 = px.line(df_2, x=args_plot.x, y=args_plot.y)
176+
fig = self._plot_two_dataframes(args_plot, df_1, df_2)
177+
178+
fig.show(
179+
config={
180+
"toImageButtonOptions": {
181+
"format": "png",
182+
"filename": "plot_of_2_dataframes",
183+
"scale": 4,
184+
}
185+
}
186+
)
187+
188+
def _plot_single_dataframe(self, args_plot: PlotAPI, df: pd.DataFrame) -> Figure:
189+
"""Plot a single data frame with residuals."""
190+
fig = make_subplots(
191+
rows=2, cols=1, shared_xaxes=True, shared_yaxes=True, vertical_spacing=0.05
192+
)
209193

194+
residual_fig = self._create_residual_plot(df, args_plot)
195+
fit_fig = self._create_fit_plot(df, args_plot)
196+
197+
for trace in residual_fig["data"]:
198+
fig.add_trace(trace, row=1, col=1)
199+
for trace in fit_fig["data"]:
200+
fig.add_trace(trace, row=2, col=1)
201+
202+
self._update_plot_layout(fig, args_plot, df_2_provided=False)
203+
return fig
204+
205+
def _plot_two_dataframes(
206+
self, args_plot: PlotAPI, df_1: pd.DataFrame, df_2: pd.DataFrame
207+
) -> Figure:
208+
"""Plot two data frames for comparison."""
210209
fig = make_subplots(
211210
rows=2, cols=1, shared_xaxes=True, shared_yaxes=True, vertical_spacing=0.05
212211
)
213212

214-
for _spec_1 in _fig1["data"]:
215-
fig.append_trace(_spec_1, row=1, col=1)
216-
for _spec_2 in _fig2["data"]:
217-
fig.append_trace(_spec_2, row=2, col=1)
213+
fig1 = px.line(df_1, x=args_plot.x, y=args_plot.y)
214+
fig2 = px.line(df_2, x=args_plot.x, y=args_plot.y)
215+
216+
for trace in fig1["data"]:
217+
fig.add_trace(trace, row=1, col=1)
218+
for trace in fig2["data"]:
219+
fig.add_trace(trace, row=2, col=1)
220+
221+
self._update_plot_layout(fig, args_plot, df_2_provided=True)
222+
return fig
223+
224+
def _create_residual_plot(self, df: pd.DataFrame, args_plot: PlotAPI) -> Figure:
225+
"""Create the residual plot."""
226+
return px.line(
227+
df,
228+
x=ColumnNamesAPI().energy,
229+
y=ColumnNamesAPI().residual,
230+
color_discrete_sequence=[args_plot.color.residual],
231+
)
232+
233+
def _create_fit_plot(self, df: pd.DataFrame, args_plot: PlotAPI) -> Figure:
234+
"""Create the fit plot."""
235+
y_columns = df.columns.drop(
236+
[ColumnNamesAPI().energy, ColumnNamesAPI().residual]
237+
)
238+
color_map = {
239+
ColumnNamesAPI().intensity: args_plot.color.intensity,
240+
ColumnNamesAPI().fit: args_plot.color.fit,
241+
**{
242+
key: args_plot.color.components
243+
for key in y_columns.drop(
244+
[ColumnNamesAPI().intensity, ColumnNamesAPI().fit]
245+
)
246+
},
247+
}
248+
line_dash_map = {
249+
ColumnNamesAPI().intensity: "solid",
250+
ColumnNamesAPI().fit: "longdash",
251+
**{
252+
key: "dash"
253+
for key in y_columns.drop(
254+
[ColumnNamesAPI().intensity, ColumnNamesAPI().fit]
255+
)
256+
},
257+
}
258+
return px.line(
259+
df,
260+
x=ColumnNamesAPI().energy,
261+
y=y_columns,
262+
color_discrete_map=color_map,
263+
line_dash_map=line_dash_map,
264+
)
265+
266+
def _update_plot_layout(
267+
self, fig: Figure, args_plot: PlotAPI, df_2_provided: bool
268+
) -> None:
269+
"""Update the plot layout."""
218270
height = args_plot.size[1][0]
219271
self.update_layout_axes(fig, args_plot, height)
272+
220273
xaxis_title = self.title_text(
221274
name=args_plot.xaxis_title.name, unit=args_plot.xaxis_title.unit
222275
)
@@ -226,7 +279,8 @@ def plot_2dataframes(
226279

227280
fig.update_xaxes(title_text=xaxis_title, row=1, col=1)
228281
fig.update_xaxes(title_text=xaxis_title, row=2, col=1)
229-
if df_2 is None:
282+
283+
if not df_2_provided:
230284
residual_title = self.title_text(
231285
name=args_plot.residual_title.name, unit=args_plot.residual_title.unit
232286
)
@@ -235,25 +289,20 @@ def plot_2dataframes(
235289
fig.update_yaxes(title_text=residual_title, row=1, col=1)
236290
else:
237291
fig.update_yaxes(title_text=yaxis_title, row=1, col=1)
292+
238293
fig.update_yaxes(title_text=yaxis_title, row=2, col=1)
239-
fig.show(
240-
config={
241-
"toImageButtonOptions": dict(
242-
format="png", filename="plot_of_2_dataframes", scale=4
243-
)
244-
}
245-
)
246294

247295
def plot_dataframe(self, args_plot: PlotAPI, df: pd.DataFrame) -> None:
248-
"""Plot the dataframe according to the PlotAPI arguments.
296+
"""Plot the data frame according to the PlotAPI arguments.
249297
250298
Args:
251299
args_plot (PlotAPI): PlotAPI object for the settings of the plot.
252-
df (pd.DataFrame): Dataframe to plot.
300+
df (pd.DataFrame): Data frame to plot.
253301
"""
254302
fig = px.line(df, x=args_plot.x, y=args_plot.y)
255303
height = args_plot.size[1][0]
256304
self.update_layout_axes(fig, args_plot, height)
305+
257306
fig.update_xaxes(
258307
title_text=self.title_text(
259308
name=args_plot.xaxis_title.name, unit=args_plot.xaxis_title.unit
@@ -266,34 +315,33 @@ def plot_dataframe(self, args_plot: PlotAPI, df: pd.DataFrame) -> None:
266315
)
267316
fig.show(
268317
config={
269-
"toImageButtonOptions": dict(
270-
format="png", filename="plot_dataframe", scale=4
271-
)
318+
"toImageButtonOptions": {
319+
"format": "png",
320+
"filename": "plot_dataframe",
321+
"scale": 4,
322+
}
272323
}
273324
)
274325

275326
def plot_global_fit(self, args_plot: PlotAPI, df: pd.DataFrame) -> None:
276-
"""Plot the global dataframe according to the PlotAPI arguments.
327+
"""Plot the global data frame according to the PlotAPI arguments.
277328
278329
Args:
279330
args_plot (PlotAPI): PlotAPI object for the settings of the plot.
280-
df (pd.DataFrame): Dataframe to plot.
331+
df (pd.DataFrame): Data frame to plot.
281332
"""
282-
for i in range(
283-
1,
284-
sum(bool(_col.startswith(ColumnNamesAPI().fit)) for _col in df.columns) + 1,
285-
):
286-
_col = [col for col in df.columns if col.endswith(str(i))]
287-
_col.append(ColumnNamesAPI().energy)
288-
_df = df[_col]
289-
_df = _df.rename(
333+
num_fits = df.columns.str.startswith(ColumnNamesAPI().fit).sum()
334+
for i in range(1, num_fits + 1):
335+
cols = [col for col in df.columns if col.endswith(f"_{i}")]
336+
cols.append(ColumnNamesAPI().energy)
337+
df_subset = df[cols].rename(
290338
columns={
291339
f"{ColumnNamesAPI().intensity}_{i}": ColumnNamesAPI().intensity,
292340
f"{ColumnNamesAPI().fit}_{i}": ColumnNamesAPI().fit,
293341
f"{ColumnNamesAPI().residual}_{i}": ColumnNamesAPI().residual,
294342
}
295343
)
296-
self.plot_2dataframes(args_plot, _df)
344+
self.plot_2dataframes(args_plot, df_subset)
297345

298346
def plot_metric(
299347
self,
@@ -306,28 +354,32 @@ def plot_metric(
306354
307355
Args:
308356
args_plot (PlotAPI): PlotAPI object for the settings of the plot.
309-
df_metric (pd.DataFrame): Metric dataframe to plot.
310-
bar_criteria (Union[str, List[str]]): String or list of criteria to plot as
311-
bars.
312-
line_criteria (Union[str, List[str]]): String or l of criteria to plot as
313-
lines.
357+
df_metric (pd.DataFrame): Metric data frame to plot.
358+
bar_criteria (Union[str, List[str]]): Criteria to plot as bars.
359+
line_criteria (Union[str, List[str]]): Criteria to plot as lines.
314360
"""
315361
fig = make_subplots(specs=[[{"secondary_y": True}]])
316-
_fig_bar = px.bar(
362+
fig_bar = px.bar(
317363
df_metric,
318364
y=bar_criteria,
319365
color_discrete_sequence=args_plot.color.bars,
320366
)
321-
_fig_line = px.line(
367+
fig_line = px.line(
322368
df_metric,
323369
y=line_criteria,
324370
color_discrete_sequence=args_plot.color.lines,
325371
)
326-
_fig_line.update_traces(mode="lines+markers", yaxis="y2")
327-
fig.add_traces(_fig_bar.data + _fig_line.data)
372+
fig_line.update_traces(mode="lines+markers", yaxis="y2")
373+
374+
for trace in fig_bar.data:
375+
fig.add_trace(trace)
376+
for trace in fig_line.data:
377+
fig.add_trace(trace)
378+
328379
fig.update_layout(xaxis_type="category")
329380
height = args_plot.size[1][1]
330381
self.update_layout_axes(fig, args_plot, height)
382+
331383
fig.update_xaxes(
332384
title_text=self.title_text(
333385
name=args_plot.run_title.name, unit=args_plot.run_title.unit
@@ -347,9 +399,11 @@ def plot_metric(
347399
)
348400
fig.show(
349401
config={
350-
"toImageButtonOptions": dict(
351-
format="png", filename="plot_metric", scale=4
352-
)
402+
"toImageButtonOptions": {
403+
"format": "png",
404+
"filename": "plot_metric",
405+
"scale": 4,
406+
}
353407
}
354408
)
355409

@@ -378,16 +432,17 @@ def update_layout_axes(
378432
plot_bgcolor=args_plot.color.plot,
379433
)
380434

435+
minor_ticks = self.get_minor(args_plot)
436+
381437
fig.update_xaxes(
382-
minor=self.get_minor(args_plot=args_plot),
438+
minor=minor_ticks,
383439
gridcolor=args_plot.color.grid,
384440
linecolor=args_plot.color.line,
385441
zerolinecolor=args_plot.color.zero_line,
386442
color=args_plot.color.color,
387443
)
388-
389444
fig.update_yaxes(
390-
minor=self.get_minor(args_plot=args_plot),
445+
minor=minor_ticks,
391446
gridcolor=args_plot.color.grid,
392447
linecolor=args_plot.color.line,
393448
zerolinecolor=args_plot.color.zero_line,
@@ -406,7 +461,7 @@ def title_text(name: str, unit: Optional[str] = None) -> str:
406461
Returns:
407462
str: Title text.
408463
"""
409-
return name if unit is None else f"{name} [{unit}]"
464+
return f"{name} [{unit}]" if unit else name
410465

411466
def get_minor(self, args_plot: PlotAPI) -> Dict[str, Union[str, bool]]:
412467
"""Get the minor axis arguments.
@@ -417,12 +472,12 @@ def get_minor(self, args_plot: PlotAPI) -> Dict[str, Union[str, bool]]:
417472
Returns:
418473
Dict[str, Union[str, bool]]: Dictionary with the minor axis arguments.
419474
"""
420-
return dict(
421-
tickcolor=args_plot.color.ticks,
422-
showgrid=args_plot.grid.show,
423-
ticks=args_plot.grid.ticks,
424-
griddash=args_plot.grid.dash,
425-
)
475+
return {
476+
"tickcolor": args_plot.color.ticks,
477+
"showgrid": args_plot.grid.show,
478+
"ticks": args_plot.grid.ticks,
479+
"griddash": args_plot.grid.dash,
480+
}
426481

427482

428483
class ExportResults:

0 commit comments

Comments
 (0)