Skip to content

Commit 9b2d750

Browse files
committed
add client side schema validation
1 parent 7f3049e commit 9b2d750

13 files changed

Lines changed: 1463 additions & 256 deletions

File tree

client/api.go

Lines changed: 146 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/go-logr/logr"
1010
"github.com/ovn-org/libovsdb/cache"
11+
"github.com/ovn-org/libovsdb/mapper"
1112
"github.com/ovn-org/libovsdb/model"
1213
"github.com/ovn-org/libovsdb/ovsdb"
1314
)
@@ -280,98 +281,172 @@ func (a api) Get(ctx context.Context, m model.Model) error {
280281
// Create is a generic function capable of creating any row in the DB
281282
// A valid Model (pointer to object) must be provided.
282283
func (a api) Create(models ...model.Model) ([]ovsdb.Operation, error) {
284+
if len(models) == 0 {
285+
return nil, nil
286+
}
287+
288+
// Validate models before proceeding
289+
for _, m := range models {
290+
if err := validateModel(m); err != nil {
291+
// Include model details in the error if possible/useful
292+
modelName := reflect.TypeOf(m).String()
293+
return nil, fmt.Errorf("validation failed for model %s: %w", modelName, err)
294+
}
295+
}
296+
283297
var operations []ovsdb.Operation
298+
var tableName string
299+
var err error
284300

285-
for _, model := range models {
286-
var realUUID, namedUUID string
287-
var err error
301+
for _, m := range models {
302+
var namedUUID string // Store named UUID if found
288303

289-
tableName, err := a.getTableFromModel(model)
290-
if err != nil {
291-
return nil, err
304+
if tableName == "" {
305+
tableName, err = a.getTableFromModel(m)
306+
if err != nil {
307+
return nil, err
308+
}
309+
} else {
310+
currentTable, err := a.getTableFromModel(m)
311+
if err != nil {
312+
return nil, err
313+
}
314+
if currentTable != tableName {
315+
return nil, fmt.Errorf("models must belong to the same table for a single Create operation (%s != %s)", currentTable, tableName)
316+
}
292317
}
293318

294-
// Read _uuid field, and use it as named-uuid
295-
info, err := a.cache.DatabaseModel().NewModelInfo(model)
319+
// Use the DatabaseModel associated with the cache to get info
320+
info, err := a.cache.DatabaseModel().NewModelInfo(m)
296321
if err != nil {
297322
return nil, err
298323
}
299-
if uuid, err := info.FieldByColumn("_uuid"); err == nil {
300-
tmpUUID := uuid.(string)
301-
if ovsdb.IsNamedUUID(tmpUUID) {
302-
namedUUID = tmpUUID
303-
} else if ovsdb.IsValidUUID(tmpUUID) {
304-
realUUID = tmpUUID
324+
325+
// Check for named UUID in the _uuid field before creating the row
326+
if uuidField, err := info.FieldByColumn("_uuid"); err == nil {
327+
if uuidStr, ok := uuidField.(string); ok {
328+
if ovsdb.IsNamedUUID(uuidStr) {
329+
namedUUID = uuidStr
330+
}
305331
}
306-
} else {
307-
return nil, err
332+
} else if !errors.Is(err, &mapper.ErrColumnNotFound{}) {
333+
// Error other than column not found when accessing _uuid
334+
return nil, fmt.Errorf("error accessing _uuid field: %w", err)
308335
}
309336

337+
// Use the Mapper associated with the cache to create the row
310338
row, err := a.cache.Mapper().NewRow(info)
311339
if err != nil {
312340
return nil, err
313341
}
314-
// UUID is given in the operation, not the object
315-
delete(row, "_uuid")
316342

317-
operations = append(operations, ovsdb.Operation{
343+
// If a named UUID was found, remove the _uuid field from the row data
344+
// as it's provided via UUIDName in the operation.
345+
// If it was a regular UUID or empty, let it remain in the row.
346+
if namedUUID != "" {
347+
delete(row, "_uuid")
348+
}
349+
350+
op := ovsdb.Operation{
318351
Op: ovsdb.OperationInsert,
319352
Table: tableName,
320353
Row: row,
321-
UUID: realUUID,
322354
UUIDName: namedUUID,
323-
})
355+
}
356+
operations = append(operations, op)
324357
}
325358
return operations, nil
326359
}
327360

328361
// Mutate returns the operations needed to transform the one Model into another one
329362
func (a api) Mutate(model model.Model, mutationObjs ...model.Mutation) ([]ovsdb.Operation, error) {
330-
var mutations []ovsdb.Mutation
331-
var operations []ovsdb.Operation
332-
333363
if len(mutationObjs) < 1 {
334364
return nil, fmt.Errorf("at least one Mutation must be provided")
335365
}
336-
337-
tableName := a.cache.DatabaseModel().FindTable(reflect.ValueOf(model).Type())
338-
if tableName == "" {
339-
return nil, fmt.Errorf("table not found for object")
340-
}
341-
table := a.cache.Mapper().Schema.Table(tableName)
342-
if table == nil {
343-
return nil, fmt.Errorf("schema error: table not found in Database Model for type %s", reflect.TypeOf(model))
366+
if a.cond == nil {
367+
return nil, fmt.Errorf("Mutate requires a condition. Use Where() first")
344368
}
345369

346-
conditions, err := a.cond.Generate()
370+
tableName, err := a.getTableFromModel(model)
347371
if err != nil {
348372
return nil, err
349373
}
350-
374+
tableSchema := a.cache.DatabaseModel().Schema.Table(tableName)
375+
if tableSchema == nil {
376+
return nil, fmt.Errorf("schema not found for table %s", tableName)
377+
}
351378
info, err := a.cache.DatabaseModel().NewModelInfo(model)
352379
if err != nil {
353380
return nil, err
354381
}
355382

356-
for _, mobj := range mutationObjs {
357-
col, err := info.ColumnByPtr(mobj.Field)
383+
// Validate each mutation value against the field's constraints
384+
modelType := reflect.TypeOf(model).Elem() // Get the struct type
385+
for _, mutation := range mutationObjs {
386+
columnName, err := info.ColumnByPtr(mutation.Field)
358387
if err != nil {
359-
return nil, err
388+
return nil, fmt.Errorf("could not get column for mutation field: %w", err)
389+
}
390+
if !tableSchema.Columns[columnName].Mutable() {
391+
return nil, fmt.Errorf("unable to update field %s of table %s as it is not mutable", columnName, tableName)
392+
}
393+
// Find the struct field corresponding to the column name
394+
var structField reflect.StructField
395+
var found bool
396+
for i := 0; i < modelType.NumField(); i++ {
397+
if modelType.Field(i).Tag.Get("ovsdb") == columnName {
398+
structField = modelType.Field(i)
399+
found = true
400+
break
401+
}
402+
}
403+
if !found {
404+
// Should not happen if ColumnByPtr worked
405+
return nil, fmt.Errorf("could not find struct field for column %s", columnName)
360406
}
361407

362-
mutation, err := a.cache.Mapper().NewMutation(info, col, mobj.Mutator, mobj.Value)
408+
// Extract the validate tag
409+
validateTag := structField.Tag.Get("validate")
410+
411+
// Validate the mutation value if a tag exists
412+
if validateTag != "" {
413+
// We assume mutation.Value is the Go-typed value to be mutated
414+
err = validate.Var(mutation.Value, validateTag)
415+
if err != nil {
416+
return nil, fmt.Errorf("validation failed for column '%s': %w", columnName, err)
417+
}
418+
}
419+
}
420+
421+
// Convert model.Mutation to ovsdb.Mutation and store them
422+
var ovsMutations []ovsdb.Mutation
423+
for _, mutation := range mutationObjs {
424+
columnName, err := info.ColumnByPtr(mutation.Field)
363425
if err != nil {
364-
return nil, err
426+
// This error was already checked during validation, but double check
427+
return nil, fmt.Errorf("could not get column for mutation field: %w", err)
428+
}
429+
ovsMutation, err := a.cache.Mapper().NewMutation(info, columnName, mutation.Mutator, mutation.Value)
430+
if err != nil {
431+
return nil, fmt.Errorf("failed to create OVSDB mutation for column '%s': %w", columnName, err)
365432
}
366-
mutations = append(mutations, *mutation)
433+
ovsMutations = append(ovsMutations, *ovsMutation)
367434
}
435+
436+
// Get the conditions based on the conditional API context
437+
conditions, err := a.cond.Generate()
438+
if err != nil {
439+
return nil, err
440+
}
441+
442+
var operations []ovsdb.Operation
368443
for _, condition := range conditions {
369444
operations = append(operations,
370445
ovsdb.Operation{
371446
Op: ovsdb.OperationMutate,
372447
Table: tableName,
373-
Mutations: mutations,
374448
Where: condition,
449+
Mutations: ovsMutations, // Use the generated OVSDB mutations
375450
},
376451
)
377452
}
@@ -383,12 +458,20 @@ func (a api) Mutate(model model.Model, mutationObjs ...model.Mutation) ([]ovsdb.
383458
// Additional fields can be passed (variadic opts) to indicate fields to be updated
384459
// All immutable fields will be ignored
385460
func (a api) Update(model model.Model, fields ...interface{}) ([]ovsdb.Operation, error) {
386-
var operations []ovsdb.Operation
387-
table, err := a.getTableFromModel(model)
461+
if a.cond == nil {
462+
return nil, fmt.Errorf("Update requires a condition. Use Where() first")
463+
}
464+
465+
if err := validateModel(model); err != nil {
466+
return nil, fmt.Errorf("validation failed for model %s used in Update: %w", reflect.TypeOf(model).String(), err)
467+
}
468+
469+
tableName, err := a.getTableFromModel(model)
388470
if err != nil {
389471
return nil, err
390472
}
391-
tableSchema := a.cache.Mapper().Schema.Table(table)
473+
474+
tableSchema := a.cache.DatabaseModel().Schema.Table(tableName)
392475
info, err := a.cache.DatabaseModel().NewModelInfo(model)
393476
if err != nil {
394477
return nil, err
@@ -401,38 +484,48 @@ func (a api) Update(model model.Model, fields ...interface{}) ([]ovsdb.Operation
401484
return nil, err
402485
}
403486
if !tableSchema.Columns[colName].Mutable() {
404-
return nil, fmt.Errorf("unable to update field %s of table %s as it is not mutable", colName, table)
487+
return nil, fmt.Errorf("unable to update field %s of table %s as it is not mutable", colName, tableName)
405488
}
406489
}
407490
}
408491

409-
conditions, err := a.cond.Generate()
410-
if err != nil {
411-
return nil, err
412-
}
413-
492+
// Convert the model to a row, considering only specified fields if provided
414493
row, err := a.cache.Mapper().NewRow(info, fields...)
415494
if err != nil {
416495
return nil, err
417496
}
418497

498+
// Original behavior: Silently remove immutable fields from the row
419499
for colName, column := range tableSchema.Columns {
420500
if !column.Mutable() {
421-
a.logger.V(2).Info("removing immutable field", "name", colName)
422-
delete(row, colName)
501+
// Only delete if the key actually exists in the row map
502+
if _, exists := row[colName]; exists {
503+
a.logger.V(2).Info("removing immutable field from update row", "name", colName)
504+
delete(row, colName)
505+
}
423506
}
424507
}
508+
// Also remove _uuid explicitly if it exists (should generally not be included by NewRow unless named UUID was used incorrectly here)
425509
delete(row, "_uuid")
426510

511+
// Check if the row is empty after removing immutable fields
427512
if len(row) == 0 {
428513
return nil, fmt.Errorf("attempted to update using an empty row. please check that all fields you wish to update are mutable")
429514
}
430515

516+
// Original Update logic uses the same Condition object as Mutate.
517+
// Use Generate() here as well, consistent with Mutate.
518+
conditions, err := a.cond.Generate()
519+
if err != nil {
520+
return nil, err
521+
}
522+
523+
var operations []ovsdb.Operation
431524
for _, condition := range conditions {
432525
operations = append(operations,
433526
ovsdb.Operation{
434527
Op: ovsdb.OperationUpdate,
435-
Table: table,
528+
Table: tableName, // Use tableName obtained earlier
436529
Row: row,
437530
Where: condition,
438531
},

0 commit comments

Comments
 (0)