1- from behave import given , register_type , when
1+ import json
2+
3+ import pandas as pd
4+ from behave import given , register_type , then , when
25from behave .runner import Context
36from pcse .base import ParameterProvider
47
58from wofostat import (
69 end_of_season_sensitivity_func ,
710 get_parameter_spec ,
11+ objective_func ,
12+ run_optimisation ,
813 run_sensitivity_analysis ,
914 snake_case_string ,
1015)
@@ -23,6 +28,22 @@ def specify_calibration(context: Context) -> None:
2328 context .calibration_spec = {row ["name" ]: row ["value" ] for row in context .table }
2429
2530
31+ @given ('we are using "{distance_metric}" as our error metric' )
32+ def set_distance_metric (context : Context , distance_metric : str ) -> None :
33+ context .distance_metric = distance_metric
34+
35+
36+ @given ('we are using observed data from the "{fpath}" file' )
37+ def get_observed_data (context : Context , fpath : str ) -> None :
38+ context .observed_data = pd .read_csv (fpath )
39+
40+
41+ @given ('we are using ground truth data from the "{fpath}" file' )
42+ def get_ground_truth (context : Context , fpath : str ) -> None :
43+ with open (fpath ) as f :
44+ context .ground_truth = json .load (f )
45+
46+
2647def _get_params (context : Context ) -> ParameterProvider :
2748 params = WOFOST .get_params (
2849 cropd = context .cropd , sited = context .sited , soild = context .soild
@@ -52,3 +73,56 @@ def execute_sensitivity(
5273 engine = engine ,
5374 ** context .calibration_spec ,
5475 )
76+
77+
78+ @then (
79+ 'the "{position}" highest "{order}" sensitivity index for "{state_var}" '
80+ 'should be "{param_name}"'
81+ )
82+ def check_sensitivity_index (
83+ context : Context , position : str , order : str , state_var : str , param_name : str
84+ ) -> None :
85+ index = "" .join (c for c in position if c .isdigit ())
86+ index = int (index ) - 1
87+
88+ if order == "total order" :
89+ order = "ST"
90+ else :
91+ order = "S1"
92+
93+ sensitivity_param = context .sp_df [state_var ][order ].iloc [index ].name
94+ if sensitivity_param != param_name :
95+ raise RuntimeWarning (f"Parameter is { sensitivity_param } " )
96+
97+ assert sensitivity_param == param_name
98+
99+
100+ @when (
101+ 'we execute an optimisation procedure using the "{method:SnakeCaseString}" method '
102+ 'and the "{engine:SnakeCaseString}" library with "{n_iterations:d}" iterations'
103+ )
104+ def execute_optimisation (
105+ context : Context , method : str , engine : str , n_iterations : int
106+ ) -> None :
107+ params = _get_params (context )
108+
109+ (
110+ context .calibrator ,
111+ context .param_importances ,
112+ context .trials_df ,
113+ context .parameter_estimates ,
114+ ) = run_optimisation (
115+ parameter_spec = context .parameter_spec ,
116+ n_iterations = n_iterations ,
117+ wdp = context .wdp ,
118+ agro = context .agro ,
119+ state_vars = context .state_vars ,
120+ calibration_func = objective_func ,
121+ params = params ,
122+ method = method ,
123+ engine = engine ,
124+ ground_truth = context .ground_truth ,
125+ observed_data = context .observed_data ,
126+ distance_metric = context .distance_metric ,
127+ ** context .calibration_spec ,
128+ )
0 commit comments