diff --git a/controllers/domain.go b/controllers/domain.go index e6fd7f0..bd89b9d 100644 --- a/controllers/domain.go +++ b/controllers/domain.go @@ -10,7 +10,7 @@ import ( ) type domainsDAO struct { - database.BaseDAO[models.Domain] + database.BaseDAO[models.IDomain] } func CreateDomain(d *models.Domain) (*models.Domain, error) { @@ -20,7 +20,7 @@ func CreateDomain(d *models.Domain) (*models.Domain, error) { } tx := database.Client.Begin() - if _, err := (domainsDAO{}).Create(tx, *d); err != nil { + if _, err := (domainsDAO{}).Create(tx, d); err != nil { tx.Rollback() return nil, err } @@ -54,21 +54,30 @@ func CreateDomain(d *models.Domain) (*models.Domain, error) { } } - tx.Commit() - return d, err + return d, tx.Commit().Error } func GetDomains(domain string) ([]models.Domain, error) { if domain != "" { - return (domainsDAO{}).GetAll(database.Client, models.Domain{DomainName: domain}) + r, err := (domainsDAO{}).GetAll(database.Client, models.Domain{DomainName: domain}) + n := make([]models.Domain, 0) + for _, e := range r { + n = append(n, e.(models.Domain)) + } + return n, err } else { - return (domainsDAO{}).GetAll(database.Client, models.Domain{}) + r, err := (domainsDAO{}).GetAll(database.Client, models.Domain{}) + n := make([]models.Domain, 0) + for _, e := range r { + n = append(n, e.(models.Domain)) + } + return n, err } } func UpdateDomain(d *models.Domain) error { tx := database.Client.Begin() - if _, err := (domainsDAO{}).Update(tx, *d); err != nil { + if _, err := (domainsDAO{}).Update(tx, d); err != nil { tx.Rollback() return err } @@ -98,9 +107,7 @@ func UpdateDomain(d *models.Domain) error { return err } - tx.Commit() - return nil - + return tx.Commit().Error } func DeleteDomain(id string) error { @@ -126,8 +133,7 @@ func DeleteDomain(id string) error { return err } - tx.Commit() - return nil + return tx.Commit().Error } // for metrics diff --git a/controllers/migrate.go b/controllers/migrate.go index eb13b53..a1d32ee 100644 --- a/controllers/migrate.go +++ b/controllers/migrate.go @@ -6,16 +6,15 @@ import ( ) func Migrate() error { - if err := (domainsDAO{}).Migrate(database.Client, models.Domain{}); err != nil { + if err := (domainsDAO{}).Migrate(database.Client, &models.Domain{}); err != nil { return err } - var recordDefiniation models.IRecord = &models.Record[models.RecordContentDefault]{} - if err := (recordsDAO{}).Migrate(database.Client, recordDefiniation); err != nil { + if err := (recordsDAO{}).Migrate(database.Client, &models.Record[models.RecordContentDefault]{Content: make(models.RecordContentDefault)}); err != nil { return err } - if err := (settingsDAO{}).Migrate(database.Client, models.Settings{}); err != nil { + if err := (settingsDAO{}).Migrate(database.Client, &models.Settings{}); err != nil { return err } diff --git a/controllers/record.go b/controllers/record.go index d67ae02..07b3725 100644 --- a/controllers/record.go +++ b/controllers/record.go @@ -42,8 +42,8 @@ func CreateRecords(rs []models.IRecord) error { return err } } - tx.Commit() - return nil + + return tx.Commit().Error } func GetRecords(cond models.IRecord) ([]models.IRecord, error) { @@ -84,8 +84,7 @@ func DeleteRecord(domain, id string) error { return err } - tx.Commit() - return nil + return tx.Commit().Error } // for metrics diff --git a/controllers/settings.go b/controllers/settings.go index d04be81..e04ab65 100644 --- a/controllers/settings.go +++ b/controllers/settings.go @@ -4,18 +4,20 @@ import ( "reCoreD-UI/database" "reCoreD-UI/models" "strings" + + "github.com/sirupsen/logrus" ) const dnsSep = "," type settingsDAO struct { - database.BaseDAO[models.Settings] + database.BaseDAO[models.ISettings] } func SetupDNS(dns ...string) error { settings := models.Settings{Key: models.SettingsKeyDNSServer, Value: strings.Join(dns, dnsSep)} - if _, err := (settingsDAO{}).UpdateOrCreate(database.Client, settings); err != nil { + if _, err := (settingsDAO{}).UpdateOrCreate(database.Client, &settings, models.Settings{Key: models.SettingsKeyDNSServer}); err != nil { return err } @@ -28,10 +30,12 @@ func GetDNS() ([]string, error) { return nil, err } - return strings.Split(settings.Value, dnsSep), nil + return strings.Split(settings.(models.Settings).Value, dnsSep), nil } func SetupAdmin(username, password string) error { + logrus.Debugf("got %s:%s", username, password) + settingUsername := models.Settings{ Key: models.SettingsKeyAdminUsername, Value: username, @@ -42,18 +46,17 @@ func SetupAdmin(username, password string) error { } tx := database.Client.Begin() - if _, err := (settingsDAO{}).UpdateOrCreate(tx, settingUsername); err != nil { + if _, err := (settingsDAO{}).UpdateOrCreate(tx, &settingUsername, models.Settings{Key: models.SettingsKeyAdminUsername}); err != nil { tx.Rollback() return err } - if _, err := (settingsDAO{}).UpdateOrCreate(tx, settingPassword); err != nil { + if _, err := (settingsDAO{}).UpdateOrCreate(tx, &settingPassword, models.Settings{Key: models.SettingsKeyAdminPassword}); err != nil { tx.Rollback() return err } - tx.Commit() - return nil + return tx.Commit().Error } func GetAdmin() (string, string, error) { @@ -61,13 +64,13 @@ func GetAdmin() (string, string, error) { if err != nil { return "", "", err } - username := settings.Value + username := settings.(models.Settings).Value settings, err = (settingsDAO{}).GetOne(database.Client, models.Settings{Key: models.SettingsKeyAdminPassword}) if err != nil { return "", "", err } - password := settings.Value + password := settings.(models.Settings).Value return username, password, nil } diff --git a/database/basedao.go b/database/basedao.go index 6a16a5a..63d908d 100644 --- a/database/basedao.go +++ b/database/basedao.go @@ -3,6 +3,7 @@ package database import ( "errors" + "github.com/sirupsen/logrus" "gorm.io/gorm" ) @@ -12,32 +13,47 @@ func (b BaseDAO[T]) Migrate(db *gorm.DB, e T) error { return db.Set("gorm:table_options", "CHARSET=utf8mb4").AutoMigrate(e) } -func (BaseDAO[T]) GetAll(db *gorm.DB, e T) ([]T, error) { +func (BaseDAO[T]) GetAll(db *gorm.DB, e T, cond ...T) ([]T, error) { var r []T - if err := db.Find(&r, e).Error; err != nil { + tx := db + for _, c := range cond { + tx = tx.Where(c) + } + + if err := tx.Find(&r, e).Error; err != nil { return nil, err } return r, nil } -func (BaseDAO[T]) GetOne(db *gorm.DB, e T) (T, error) { +func (BaseDAO[T]) GetOne(db *gorm.DB, e T, cond ...T) (T, error) { var r T - if err := db.First(&r, e).Error; err != nil { + tx := db + for _, c := range cond { + tx = tx.Where(c) + } + + if err := tx.First(&r, e).Error; err != nil { return r, err } return r, nil } -func (BaseDAO[T]) GetSome(db *gorm.DB, e T, limit, offset int) ([]T, error) { +func (BaseDAO[T]) GetSome(db *gorm.DB, e T, limit, offset int, cond ...T) ([]T, error) { var r []T - if err := db.Find(&r, e).Limit(limit).Offset(offset).Error; err != nil { + tx := db + for _, c := range cond { + tx = tx.Where(c) + } + + if err := tx.Find(&r, e).Limit(limit).Offset(offset).Error; err != nil { return nil, err } return r, nil } func (BaseDAO[T]) Create(db *gorm.DB, e T) (T, error) { - if err := db.Create(&e).Error; err != nil { + if err := db.Create(e).Error; err != nil { return e, err } return e, nil @@ -50,18 +66,26 @@ func (BaseDAO[T]) FirstOrCreate(db *gorm.DB, e T) (T, error) { return e, nil } -func (BaseDAO[T]) Update(db *gorm.DB, e T) (T, error) { - if err := db.Updates(&e).Error; err != nil { +func (BaseDAO[T]) Update(db *gorm.DB, e T, cond ...T) (T, error) { + tx := db.Model(e) + for _, c := range cond { + tx = tx.Where(c) + } + + if err := tx.Updates(&e).Error; err != nil { return e, err } return e, nil } -func (b BaseDAO[T]) UpdateOrCreate(db *gorm.DB, e T) (T, error) { - e, err := b.Update(db, e) +func (b BaseDAO[T]) UpdateOrCreate(db *gorm.DB, e T, cond ...T) (T, error) { + logrus.Debugf("got %v %v %v", db, e, cond) + e, err := b.Update(db, e, cond...) if errors.Is(err, gorm.ErrRecordNotFound) { + logrus.Debug("will create it") return b.Create(db, e) } + logrus.Debugf("return %v %v", e, err) return e, err } diff --git a/database/database.go b/database/database.go index c50d8d9..91bac4d 100644 --- a/database/database.go +++ b/database/database.go @@ -9,7 +9,9 @@ var Client *gorm.DB func Connect(DSN string) error { var err error - Client, err = gorm.Open(mysql.Open(DSN), &gorm.Config{}) + Client, err = gorm.Open(mysql.Open(DSN), &gorm.Config{ + SkipDefaultTransaction: true, + }) if err != nil { return err } diff --git a/main.go b/main.go index 8d726c1..f97bfc4 100644 --- a/main.go +++ b/main.go @@ -43,7 +43,6 @@ func main() { app := &cli.App{ Name: "reCoreD-UI", Usage: "Web UI for CoreDNS", - UseShortOptionHandling: true, Before: altsrc.InitInputSourceWithContext( flags, altsrc.NewYamlSourceFromFlagFunc("config"), ), diff --git a/models/domain.go b/models/domain.go index 64b3b0d..05aec83 100644 --- a/models/domain.go +++ b/models/domain.go @@ -12,13 +12,13 @@ type Domain struct { DomainName string `gorm:"unique,not null,size:255" json:"domain_name"` //SOA Info - MainDNS string `gorm:"not null,size:255" json:"main_dns"` - AdminEmail string `gorm:"not null,size:255" json:"admin_email"` - SerialNumber int64 `gorm:"not null,default:1" json:"serial_number"` - RefreshInterval uint32 `gorm:"not null,size:255,default:\"86400\"" json:"refresh_interval"` - RetryInterval uint32 `gorm:"not null,size:255,default:\"7200\"" json:"retry_interval"` - ExpiryPeriod uint32 `gorm:"not null,size:255,default:\"3600000\"" json:"expiry_period"` - NegativeTtl uint32 `gorm:"not null,size:255,default:\"86400\"" json:"negative_ttl"` + MainDNS string `gorm:"not null;size:255" json:"main_dns"` + AdminEmail string `gorm:"not null;size:255" json:"admin_email"` + SerialNumber int64 `gorm:"not null;default:1" json:"serial_number"` + RefreshInterval uint32 `gorm:"not null;size:255,default:\"86400\"" json:"refresh_interval"` + RetryInterval uint32 `gorm:"not null;size:255,default:\"7200\"" json:"retry_interval"` + ExpiryPeriod uint32 `gorm:"not null;size:255,default:\"3600000\"" json:"expiry_period"` + NegativeTtl uint32 `gorm:"not null;size:255,default:\"86400\"" json:"negative_ttl"` } func (d Domain) EmailSOAForamt() string { @@ -38,7 +38,7 @@ func (d Domain) WithDotEnd() string { } } -func (d *Domain) GenerateSOA() dns.SOARecord { +func (d Domain) GenerateSOA() dns.SOARecord { var ns string if !strings.HasSuffix(d.MainDNS, ".") { ns = fmt.Sprintf("%s.", d.MainDNS) @@ -54,3 +54,9 @@ func (d *Domain) GenerateSOA() dns.SOARecord { MinTtl: d.NegativeTtl, } } + +type IDomain interface { + EmailSOAForamt() string + WithDotEnd() string + GenerateSOA() dns.SOARecord +} diff --git a/models/record.go b/models/record.go index f78b27c..b972e59 100644 --- a/models/record.go +++ b/models/record.go @@ -22,7 +22,7 @@ const ( RecordTypeSRV = "SRV" ) -type RecordContentDefault any +type RecordContentDefault map[string]any type recordContentTypes interface { dns.ARecord | dns.AAAARecord | dns.CNAMERecord | dns.CAARecord | dns.NSRecord | dns.MXRecord | dns.SOARecord | dns.SRVRecord | dns.TXTRecord | RecordContentDefault @@ -30,11 +30,11 @@ type recordContentTypes interface { type Record[T recordContentTypes] struct { ID int `gorm:"primaryKey" json:"id"` - Zone string `gorm:"not null,size:255" json:"zone"` - Name string `gorm:"not null,size:255" json:"name"` + Zone string `gorm:"not null;size:255" json:"zone"` + Name string `gorm:"not null;size:255" json:"name"` Ttl int `json:"ttl"` - Content T `gorm:"serializer:json,type:\"text\"" json:"content"` - RecordType string `gorm:"not null,size:255" json:"record_type"` + Content T `gorm:"serializer:json;type:text" json:"content"` + RecordType string `gorm:"not null;size:255" json:"record_type"` } func (Record[T]) TableName() string { diff --git a/models/settings.go b/models/settings.go index 240a44d..b7cd999 100644 --- a/models/settings.go +++ b/models/settings.go @@ -1,6 +1,10 @@ package models -import "gorm.io/gorm" +import ( + "fmt" + + "gorm.io/gorm" +) const ( SettingsKeyAdminUsername = "admin.username" @@ -10,6 +14,14 @@ const ( type Settings struct { gorm.Model - Key string `gorm:"unique,not null,size:255"` - Value string `gorm:"not null,size:255"` + Key string `gorm:"unique;not null;size:255"` + Value string `gorm:"not null;size:255"` +} + +func (s Settings) String() string { + return fmt.Sprintf("%s: %s", s.Key, s.Value) +} + +type ISettings interface { + String() string }