Skip to content

Commit 0f13f50

Browse files
authored
Merge pull request #471 from booxter/fix-typed-nil-predicate-panic
fix: don't panic on typed nil predicate
2 parents 2ffc49d + 2c668e1 commit 0f13f50

2 files changed

Lines changed: 29 additions & 5 deletions

File tree

client/api.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,14 @@ func (a api) List(ctx context.Context, result any) error {
158158
return err
159159
}
160160

161-
if a.cond != nil && a.cond.Table() != table {
162-
return &ErrWrongType{resultPtr.Type(),
163-
fmt.Sprintf("Table derived from input type (%s) does not match Table from Condition (%s)", table, a.cond.Table())}
161+
if a.cond != nil {
162+
if errCond, ok := a.cond.(*errorConditional); ok {
163+
return errCond.err
164+
}
165+
if a.cond.Table() != table {
166+
return &ErrWrongType{resultPtr.Type(),
167+
fmt.Sprintf("Table derived from input type (%s) does not match Table from Condition (%s)", table, a.cond.Table())}
168+
}
164169
}
165170

166171
tableCache := a.cache.Table(table)
@@ -636,6 +641,9 @@ func (a api) getTableFromFunc(predicate any) (string, error) {
636641
if predType == nil || predType.Kind() != reflect.Func {
637642
return "", &ErrWrongType{predType, "Expected function"}
638643
}
644+
if reflect.ValueOf(predicate).IsNil() {
645+
return "", &ErrWrongType{predType, "Expected non-nil function"}
646+
}
639647
if predType.NumIn() != 1 || predType.NumOut() != 1 || predType.Out(0).Kind() != reflect.Bool {
640648
return "", &ErrWrongType{predType, "Expected func(Model) bool"}
641649
}

client/api_test.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ func TestAPIListPredicate(t *testing.T) {
223223
predicate any
224224
content []model.Model
225225
err bool
226+
errText string
226227
}{
227228
{
228229
name: "none",
@@ -241,8 +242,9 @@ func TestAPIListPredicate(t *testing.T) {
241242
err: false,
242243
},
243244
{
244-
name: "nil function must fail",
245-
err: true,
245+
name: "nil interface must fail",
246+
err: true,
247+
errText: "Expected function",
246248
},
247249
{
248250
name: "arbitrary condition",
@@ -269,13 +271,27 @@ func TestAPIListPredicate(t *testing.T) {
269271
err := cond.List(context.Background(), &result)
270272
if tt.err {
271273
require.Error(t, err)
274+
if tt.errText != "" {
275+
require.ErrorContains(t, err, tt.errText)
276+
}
272277
} else {
273278
require.NoError(t, err)
274279
assert.ElementsMatchf(t, tt.content, result, "Content should match")
275280
}
276281

277282
})
278283
}
284+
285+
t.Run("ApiListPredicate: typed nil function must fail", func(t *testing.T) {
286+
var predicate func(*testLogicalSwitch) bool
287+
var result []*testLogicalSwitch
288+
api := newAPI(tcache, &discardLogger, false)
289+
cond := api.WhereCache(predicate)
290+
291+
err := cond.List(context.Background(), &result)
292+
require.Error(t, err)
293+
require.ErrorContains(t, err, "Expected non-nil function")
294+
})
279295
}
280296

281297
func TestAPIListWhereConditions(t *testing.T) {

0 commit comments

Comments
 (0)