@@ -4,6 +4,7 @@ use serde_json::{json, Value};
44
55pub const LAB_VERSION : & str = "gatl-ad-lab/v0.1" ;
66pub const REPORT_FORMAT : & str = "gatl-ad-lab/v0.1-ad-check-report" ;
7+ pub const DERIVATIVE_GOLDEN_FORMAT : & str = "gatl-ad-lab/v0.1-derivative-golden-fixture" ;
78
89#[ derive( Clone , Debug , PartialEq ) ]
910pub struct LabError {
@@ -194,6 +195,46 @@ pub fn gradient_check_report(
194195 } ) )
195196}
196197
198+ pub fn derivative_golden_fixture (
199+ root : & Value ,
200+ options : & EvaluationOptions ,
201+ ) -> Result < Value , Vec < LabError > > {
202+ let model = select_model ( root, options) ?;
203+ let result = gradient_check_model ( & model, options) ?;
204+ let input_cases = options
205+ . inputs
206+ . iter ( )
207+ . map ( |( name, value) | {
208+ json ! ( {
209+ "name" : name,
210+ "value" : value
211+ } )
212+ } )
213+ . collect :: < Vec < _ > > ( ) ;
214+ Ok ( json ! ( {
215+ "apiVersion" : "gatl.dev/v0.1" ,
216+ "kind" : "GATLDerivativeGoldenFixture" ,
217+ "version" : LAB_VERSION ,
218+ "format" : DERIVATIVE_GOLDEN_FORMAT ,
219+ "modelId" : result. model_id,
220+ "assumptions" : assumptions_json( options) ,
221+ "inputCases" : input_cases,
222+ "expectedValue" : result. value,
223+ "expectedReverseGradients" : result. reverse_gradients,
224+ "comparisonSource" : {
225+ "finiteDifferenceEpsilon" : options. epsilon,
226+ "tolerance" : options. tolerance,
227+ "maxAbsError" : result. max_abs_error,
228+ "passed" : result. passed
229+ } ,
230+ "proofBackendBoundary" : {
231+ "status" : "external-proof-backend-required" ,
232+ "consumes" : "typedExpression + derivativeProgram + expectedReverseGradients" ,
233+ "executesProofBackend" : false
234+ }
235+ } ) )
236+ }
237+
197238fn gradient_check_model (
198239 model : & RiskModel ,
199240 options : & EvaluationOptions ,
@@ -885,6 +926,19 @@ mod tests {
885926 assert ! ( ( result. reverse_gradients[ "y" ] - 2.0 ) . abs( ) < 1e-9 ) ;
886927 }
887928
929+ #[ test]
930+ fn exports_derivative_golden_fixture ( ) {
931+ let fixture = derivative_golden_fixture ( & report ( ) , & options ( ) ) . expect ( "golden" ) ;
932+ assert_eq ! (
933+ fixture. pointer( "/format" ) ,
934+ Some ( & json!( DERIVATIVE_GOLDEN_FORMAT ) )
935+ ) ;
936+ assert_eq ! (
937+ fixture. pointer( "/expectedReverseGradients/x" ) ,
938+ Some ( & json!( 3.0 ) )
939+ ) ;
940+ }
941+
888942 #[ test]
889943 fn rejects_missing_input ( ) {
890944 let mut options = options ( ) ;
0 commit comments