Skip to content

Commit a24c5fc

Browse files
JimA-cyborg“Trumanahellegitclaude
authored
fixed CheckTrainingStatus to match python and js (#43)
* update the thresholdl for quickflow auto train * new dataset * update test * test(quick_flow): tolerance on exhaustive recall instead of bit-exact 1.0 n_probes==n_lists exhaustive search should recover the ground truth, but JSON->float32 rounding of query vectors differs slightly across SDKs and can flip a few tie-broken neighbors at the top_k boundary. Assert recall within 0.01 of 1.0 rather than exact equality so the check is portable. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * renamed CheckTrainingStatus --------- Co-authored-by: “Truman <“trumanngny@yahoo.com”> Co-authored-by: ahellegit <ah.secured@gmail.com> Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent 84e35f7 commit a24c5fc

4 files changed

Lines changed: 11 additions & 12 deletions

File tree

encrypted_index.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,15 +170,14 @@ func (e *EncryptedIndex) IsTrained(ctx context.Context) (bool, error) {
170170
return resp.GetIsTrained(), nil
171171
}
172172

173-
// CheckTrainingStatus queries the server to check whether this index is
174-
// currently being trained (e.g. by an auto-training trigger). It reflects live
175-
// server state on every call and holds no cached status, matching the Python
176-
// SDK's is_training.
173+
// IsTraining queries the server to check whether this index is currently being
174+
// trained (e.g. by an auto-training trigger). It reflects live server state on
175+
// every call and holds no cached status, matching the Python SDK's is_training.
177176
//
178177
// Returns:
179178
// - bool: true if the index is currently being trained, false otherwise
180179
// - error: Any error encountered during the status check
181-
func (e *EncryptedIndex) CheckTrainingStatus(ctx context.Context) (bool, error) {
180+
func (e *EncryptedIndex) IsTraining(ctx context.Context) (bool, error) {
182181
result, _, err := e.client.APIClient.DefaultAPI.GetTrainingStatusV1IndexesTrainingStatusGet(ctx).Execute()
183182
if err != nil {
184183
return false, fmt.Errorf("failed to get training status: %w", err)
@@ -607,7 +606,7 @@ func (e *EncryptedIndex) Delete(ctx context.Context, ids []string) error {
607606
// be suboptimal until training completes.
608607
//
609608
// All parameters are optional with sensible defaults. Use IsTrained or
610-
// CheckTrainingStatus to observe training state, which is read live from the
609+
// IsTraining to observe training state, which is read live from the
611610
// server.
612611
//
613612
// Parameters:

test/api_contract_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,11 @@ func TestClientIsTraining(t *testing.T) {
510510
defer cancel()
511511

512512
t.Run("ReturnTrainingStatusWithCorrectSchema", func(t *testing.T) {
513-
// Note: In Go SDK, this is CheckTrainingStatus on the index, not client
513+
// Note: In Go SDK, this is IsTraining on the index, not client
514514
// But we can verify the behavior exists
515-
isTraining, err := testIndex.CheckTrainingStatus(ctx)
515+
isTraining, err := testIndex.IsTraining(ctx)
516516
if err != nil {
517-
t.Fatalf("CheckTrainingStatus failed: %v", err)
517+
t.Fatalf("IsTraining failed: %v", err)
518518
}
519519

520520
if reflect.TypeOf(isTraining).Kind() != reflect.Bool {

test/kms_byok_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ func runKMSRoundTrip(t *testing.T, cfg kmsBYOKConfig, kmsName string) {
210210
if _, err := loaded.IsTrained(ctx); err != nil {
211211
t.Errorf("IsTrained: %v", err)
212212
}
213-
if _, err := loaded.CheckTrainingStatus(ctx); err != nil {
214-
t.Errorf("CheckTrainingStatus: %v", err)
213+
if _, err := loaded.IsTraining(ctx); err != nil {
214+
t.Errorf("IsTraining: %v", err)
215215
}
216216

217217
if err := loaded.Delete(ctx, []string{"0"}); err != nil {

test/quick_flow_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ func TestUnitFlow(t *testing.T) {
651651
for attempt := 0; attempt < numRetries; attempt++ {
652652
time.Sleep(2 * time.Second)
653653

654-
isTraining, checkErr := index.CheckTrainingStatus(ctx)
654+
isTraining, checkErr := index.IsTraining(ctx)
655655
if checkErr != nil {
656656
fmt.Printf("Error checking training status: %v, retrying... (%d/%d)\n", checkErr, attempt+1, numRetries)
657657
continue

0 commit comments

Comments
 (0)