|
38 | 38 |
|
39 | 39 | def plot_forecast( |
40 | 40 | df: pd.DataFrame, |
41 | | - label_df: pd.DataFrame | pd.Series, |
| 41 | + label_df: pd.DataFrame | pd.Series | None = None, |
42 | 42 | title: str | None = None, |
43 | 43 | fig_width: float = 12, |
44 | 44 | fig_height: float = 3, |
@@ -96,7 +96,12 @@ def plot_forecast( |
96 | 96 | } |
97 | 97 | ) |
98 | 98 |
|
99 | | - if label_df is not None: |
| 99 | + # Ensure df have pd.DatetimeIndex |
| 100 | + if not isinstance(df.index, pd.DatetimeIndex): |
| 101 | + if label_df is None or len(label_df) == 0: |
| 102 | + raise ValueError( |
| 103 | + "label_df is needed since concensus dataframe does not have pd.DatetimeIndex." |
| 104 | + ) |
100 | 105 | df = set_datetime_index(label_df, df) |
101 | 106 |
|
102 | 107 | # Maintain backward compatibility |
@@ -253,7 +258,7 @@ def plot_forecast( |
253 | 258 |
|
254 | 259 | def plot_forecast_from_file( |
255 | 260 | consensus_file: str, |
256 | | - label_file: str, |
| 261 | + label_file: str | None = None, |
257 | 262 | title: str | None = None, |
258 | 263 | fig_width: float = 12, |
259 | 264 | fig_height: float = 3, |
@@ -285,10 +290,12 @@ def plot_forecast_from_file( |
285 | 290 | Returns: |
286 | 291 | plt.Figure: Matplotlib figure object with three vertically stacked subplots. |
287 | 292 | """ |
288 | | - df = pd.read_csv(consensus_file, index_col="id") |
289 | | - label_df = pd.read_csv(label_file, index_col=0, parse_dates=True) |
290 | | - |
291 | | - df = set_datetime_index(label_df, df) |
| 293 | + df = pd.read_csv(consensus_file, index_col=0, parse_dates=True) |
| 294 | + label_df = ( |
| 295 | + pd.read_csv(label_file, index_col=0, parse_dates=True) |
| 296 | + if label_file |
| 297 | + else pd.DataFrame() |
| 298 | + ) |
292 | 299 |
|
293 | 300 | return plot_forecast( |
294 | 301 | df, |
|
0 commit comments