diff --git a/controllers/domain.go b/controllers/domain.go index 82c7580..e6fd7f0 100644 --- a/controllers/domain.go +++ b/controllers/domain.go @@ -25,34 +25,30 @@ func CreateDomain(d *models.Domain) (*models.Domain, error) { return nil, err } - r := &models.RecordWithType[dns.SOARecord]{} + r := &models.Record[dns.SOARecord]{} r.Zone = d.WithDotEnd() r.Name = "@" r.RecordType = models.RecordTypeSOA - r.Content.Ns = d.MainDNS - r.Content.MBox = d.EmailSOAForamt() - r.Content.Refresh = d.RefreshInterval - r.Content.Retry = d.RetryInterval - r.Content.Expire = d.ExpiryPeriod - r.Content.MinTtl = d.NegativeTtl + r.Content = d.GenerateSOA() if err := r.CheckZone(); err != nil { tx.Rollback() return nil, err } - if _, err := (recordsDAO{}).Create(tx, *r.ToRecord()); err != nil { + if _, err := (recordsDAO{}).Create(tx, r); err != nil { tx.Rollback() return nil, err } for i, ns := range nss { - record := &models.RecordWithType[dns.NSRecord]{} - record.Zone = d.WithDotEnd() - record.RecordType = models.RecordTypeNS + record := &models.Record[dns.NSRecord]{ + Zone: d.WithDotEnd(), + RecordType: models.RecordTypeSOA, + Name: fmt.Sprintf("ns%d", i+1), + } record.Content.Host = ns - record.Name = fmt.Sprintf("ns%d", i+1) - if _, err := (recordsDAO{}).Create(tx, *record.ToRecord()); err != nil { + if _, err := (recordsDAO{}).Create(tx, record); err != nil { tx.Rollback() return nil, err } @@ -77,7 +73,7 @@ func UpdateDomain(d *models.Domain) error { return err } - soa, err := (recordsDAO{}).GetOne(tx, models.Record{ + soa, err := (recordsDAO{}).GetOne(tx, &models.Record[models.RecordContentDefault]{ RecordType: models.RecordTypeSOA, Zone: d.WithDotEnd(), }) if err != nil { @@ -85,24 +81,19 @@ func UpdateDomain(d *models.Domain) error { return err } - r := &models.RecordWithType[dns.SOARecord]{} - if err := r.FromRecord(&soa); err != nil { + r := &models.Record[dns.SOARecord]{} + if err := r.FromEntity(soa); err != nil { tx.Rollback() return err } - r.Content.Ns = d.MainDNS - r.Content.MBox = d.EmailSOAForamt() - r.Content.Refresh = d.RefreshInterval - r.Content.Retry = d.RetryInterval - r.Content.Expire = d.ExpiryPeriod - r.Content.MinTtl = d.NegativeTtl + r.Content = d.GenerateSOA() if err := r.CheckZone(); err != nil { tx.Rollback() return err } - if _, err := (recordsDAO{}).Update(tx, *r.ToRecord()); err != nil { + if _, err := (recordsDAO{}).Update(tx, r); err != nil { tx.Rollback() return err } @@ -130,7 +121,7 @@ func DeleteDomain(id string) error { return err } - if err := (recordsDAO{}).Delete(tx, models.Record{Zone: domain.WithDotEnd()}); err != nil { + if err := (recordsDAO{}).Delete(tx, &models.Record[models.RecordContentDefault]{Zone: domain.WithDotEnd()}); err != nil { tx.Rollback() return err } diff --git a/controllers/metrics.go b/controllers/metrics.go index e8b6d64..198e4ed 100644 --- a/controllers/metrics.go +++ b/controllers/metrics.go @@ -41,7 +41,7 @@ func RegisterMetrics() { } } -func RefreshMetrics() error { +func RefreshMetrics() error { domainCounts, err := getDomainCounts() if err != nil { return err diff --git a/controllers/migrate.go b/controllers/migrate.go index 648efc9..eb13b53 100644 --- a/controllers/migrate.go +++ b/controllers/migrate.go @@ -10,7 +10,8 @@ func Migrate() error { return err } - if err := (recordsDAO{}).Migrate(database.Client, models.Record{}); err != nil { + var recordDefiniation models.IRecord = &models.Record[models.RecordContentDefault]{} + if err := (recordsDAO{}).Migrate(database.Client, recordDefiniation); err != nil { return err } diff --git a/controllers/record.go b/controllers/record.go index fa69d2b..d67ae02 100644 --- a/controllers/record.go +++ b/controllers/record.go @@ -10,11 +10,11 @@ import ( ) type recordsDAO struct { - database.BaseDAO[models.Record] + database.BaseDAO[models.IRecord] } -func CreateRecord(r *models.Record) (*models.Record, error) { - if r.RecordType != models.RecordTypeSOA { +func CreateRecord(r models.IRecord) (models.IRecord, error) { + if r.GetType() != models.RecordTypeSOA { _, err := GetDomains(r.WithOutDotTail()) if err != nil { return nil, err @@ -25,11 +25,11 @@ func CreateRecord(r *models.Record) (*models.Record, error) { return nil, err } - res, err := (recordsDAO{}).Create(database.Client, *r) - return &res, err + res, err := (recordsDAO{}).Create(database.Client, r) + return res, err } -func CreateRecords(rs []*models.Record) error { +func CreateRecords(rs []models.IRecord) error { tx := database.Client.Begin() for _, r := range rs { if err := r.CheckZone(); err != nil { @@ -37,7 +37,7 @@ func CreateRecords(rs []*models.Record) error { return err } - if _, err := (recordsDAO{}).Create(tx, *r); err != nil { + if _, err := (recordsDAO{}).Create(tx, r); err != nil { tx.Rollback() return err } @@ -46,16 +46,16 @@ func CreateRecords(rs []*models.Record) error { return nil } -func GetRecords(cond models.Record) ([]models.Record, error) { +func GetRecords(cond models.IRecord) ([]models.IRecord, error) { return (recordsDAO{}).GetAll(database.Client, cond) } -func UpdateRecord(r *models.Record) error { +func UpdateRecord(r models.IRecord) error { if err := r.CheckZone(); err != nil { return err } - if _, err := (recordsDAO{}).Update(database.Client, *r); err != nil { + if _, err := (recordsDAO{}).Update(database.Client, r); err != nil { return err } return nil @@ -68,13 +68,13 @@ func DeleteRecord(domain, id string) error { } tx := database.Client.Begin() - record, err := (recordsDAO{}).GetOne(tx, models.Record{ID: ID, Zone: fmt.Sprintf("%s.", domain)}) + record, err := (recordsDAO{}).GetOne(tx, &models.Record[models.RecordContentDefault]{ID: ID, Zone: fmt.Sprintf("%s.", domain)}) if err != nil { tx.Rollback() return err } - if record.RecordType == models.RecordTypeSOA { + if record.GetType() == models.RecordTypeSOA { tx.Rollback() return gorm.ErrRecordNotFound } @@ -90,14 +90,18 @@ func DeleteRecord(domain, id string) error { // for metrics func getRecordCounts() (map[string]float64, error) { - rows, err := (recordsDAO{}).GetAll(database.Client, models.Record{}) + rows, err := (recordsDAO{}).GetAll(database.Client, &models.Record[models.RecordContentDefault]{}) if err != nil { return nil, err } result := make(map[string]float64) for _, row := range rows { - result[row.Zone] += 1 + record := &models.Record[models.RecordContentDefault]{} + if err := record.FromEntity(row); err != nil { + return nil, err + } + result[record.Zone] += 1 } return result, nil } diff --git a/main.go b/main.go index 1d8ffdd..f1d4463 100644 --- a/main.go +++ b/main.go @@ -26,7 +26,7 @@ func main() { Name: "mysql-dsn", Usage: "mysql dsn", Required: true, - EnvVars: []string{"RECORED_MYSQL_DSN"}, + EnvVars: []string{"RECORED_MYSQL_DSN"}, }, &cli.BoolFlag{ Name: "debug", diff --git a/models/domain.go b/models/domain.go index f110428..87b4b49 100644 --- a/models/domain.go +++ b/models/domain.go @@ -3,6 +3,8 @@ package models import ( "fmt" "strings" + + dns "github.com/cloud66-oss/coredns_mysql" ) type Domain struct { @@ -32,3 +34,14 @@ func (d *Domain) WithDotEnd() string { return fmt.Sprintf("%s.", d.DomainName) } } + +func (d *Domain) GenerateSOA() dns.SOARecord { + return dns.SOARecord{ + Ns: d.MainDNS, + MBox: d.EmailSOAForamt(), + Refresh: d.RefreshInterval, + Retry: d.RetryInterval, + Expire: d.ExpiryPeriod, + MinTtl: d.NegativeTtl, + } +} diff --git a/models/record.go b/models/record.go index 8ea5a8f..d55bec8 100644 --- a/models/record.go +++ b/models/record.go @@ -1,6 +1,7 @@ package models import ( + "encoding/json" "fmt" "strings" @@ -19,50 +20,58 @@ const ( RecordTypeSRV = "SRV" ) -type Record struct { +type RecordContentDefault any + +type recordContentTypes interface { + dns.ARecord | dns.AAAARecord | dns.CNAMERecord | dns.CAARecord | dns.NSRecord | dns.MXRecord | dns.SOARecord | dns.SRVRecord | dns.TXTRecord | RecordContentDefault +} + +type Record[T recordContentTypes] struct { ID int `gorm:"primaryKey" json:"id"` Zone string `gorm:"not null,size:255" json:"zone"` Name string `gorm:"not null,size:255" json:"name"` Ttl int `json:"ttl"` - Content any `gorm:"serializer:json,type:\"text\"" json:"content"` + Content T `gorm:"serializer:json,type:\"text\"" json:"content"` RecordType string `gorm:"not null,size:255" json:"record_type"` } -func (Record) TableName() string { +func (Record[T]) TableName() string { return "coredns_record" } -func (r Record) CheckZone() error { +func (r Record[T]) CheckZone() error { if strings.HasSuffix(r.Zone, ".") { return fmt.Errorf("zone should end with '.'") } return nil } -func (r Record) WithOutDotTail() string { +func (r Record[T]) WithOutDotTail() string { return strings.TrimRight(r.Zone, ".") } -type RecordContentTypes interface { - dns.ARecord | dns.AAAARecord | dns.CNAMERecord | dns.CAARecord | dns.NSRecord | dns.MXRecord | dns.SOARecord | dns.SRVRecord | dns.TXTRecord +func (r Record[T]) ToEntity() IRecord { + return &r } -type RecordWithType[T RecordContentTypes] struct { - Record - Content T `json:"content"` -} - -func (r *RecordWithType[T]) ToRecord() *Record { - r.Record.Content = r.Content - return &r.Record -} - -func (r *RecordWithType[T]) FromRecord(record *Record) error { - r.Record = *record - - var ok bool - if r.Content, ok = record.Content.(T); !ok { - return fmt.Errorf("cannot convert record type") +func (r *Record[T]) FromEntity(entity any) error { + b, err := json.Marshal(entity) + if err != nil { + return err } - return nil + + return json.Unmarshal(b, r) +} + +func (r Record[T]) GetType() string { + return r.RecordType +} + +type IRecord interface { + TableName() string + CheckZone() error + WithOutDotTail() string + ToEntity() IRecord + FromEntity(any) error + GetType() string } diff --git a/server/handlers_domains.go b/server/handlers_domains.go index 394cf89..f91ccf2 100644 --- a/server/handlers_domains.go +++ b/server/handlers_domains.go @@ -8,7 +8,7 @@ import ( "github.com/gin-gonic/gin" ) -func (s *Server) getDomains(c *gin.Context) { +func getDomains(c *gin.Context) { domains, err := controllers.GetDomains("") if err != nil { errorHandler(c, err) @@ -21,7 +21,7 @@ func (s *Server) getDomains(c *gin.Context) { }) } -func (s *Server) createDomain(c *gin.Context) { +func createDomain(c *gin.Context) { domain := &models.Domain{} if err := c.BindJSON(domain); err != nil { @@ -44,7 +44,7 @@ func (s *Server) createDomain(c *gin.Context) { }) } -func (s *Server) updateDomain(c *gin.Context) { +func updateDomain(c *gin.Context) { domain := &models.Domain{} if err := c.BindJSON(domain); err != nil { @@ -65,7 +65,7 @@ func (s *Server) updateDomain(c *gin.Context) { }) } -func (s *Server) deleteDomain(c *gin.Context) { +func deleteDomain(c *gin.Context) { id := c.Param("id") if err := controllers.DeleteDomain(id); err != nil { errorHandler(c, err) diff --git a/server/handlers_records.go b/server/handlers_records.go index d73fdbf..29a43c7 100644 --- a/server/handlers_records.go +++ b/server/handlers_records.go @@ -9,8 +9,8 @@ import ( "github.com/gin-gonic/gin" ) -func (s *Server) getRecords(c *gin.Context) { - query := models.Record{} +func getRecords(c *gin.Context) { + query := models.Record[models.RecordContentDefault]{} if err := c.BindQuery(&query); err != nil { c.JSON(http.StatusBadRequest, Response{ Succeed: false, @@ -21,7 +21,7 @@ func (s *Server) getRecords(c *gin.Context) { domain := c.Param("domain") query.Zone = fmt.Sprintf("%s.", domain) - records, err := controllers.GetRecords(query) + records, err := controllers.GetRecords(&query) if err != nil { errorHandler(c, err) return @@ -33,8 +33,8 @@ func (s *Server) getRecords(c *gin.Context) { }) } -func (s *Server) createRecord(c *gin.Context) { - record := &models.Record{} +func createRecord(c *gin.Context) { + record := &models.Record[models.RecordContentDefault]{} if err := c.BindJSON(record); err != nil { c.JSON(http.StatusBadRequest, Response{ Succeed: false, @@ -44,7 +44,7 @@ func (s *Server) createRecord(c *gin.Context) { } domain := c.Param("domain") - if domain != record.Zone { + if domain != record.WithOutDotTail() { c.JSON(http.StatusBadRequest, Response{ Succeed: false, Message: "request body doesn't match URI", @@ -52,7 +52,7 @@ func (s *Server) createRecord(c *gin.Context) { return } - record, err := controllers.CreateRecord(record) + irecord, err := controllers.CreateRecord(record) if err != nil { errorHandler(c, err) return @@ -60,12 +60,12 @@ func (s *Server) createRecord(c *gin.Context) { c.JSON(http.StatusCreated, Response{ Succeed: true, - Data: record, + Data: irecord, }) } -func (s *Server) createRecords(c *gin.Context) { - var records []*models.Record +func createRecords(c *gin.Context) { + var records []models.IRecord if err := c.BindJSON(&records); err != nil { c.JSON(http.StatusBadRequest, Response{ Succeed: false, @@ -84,8 +84,8 @@ func (s *Server) createRecords(c *gin.Context) { }) } -func (s *Server) updateRecord(c *gin.Context) { - record := &models.Record{} +func updateRecord(c *gin.Context) { + record := &models.Record[models.RecordContentDefault]{} if err := c.BindJSON(record); err != nil { c.JSON(http.StatusBadRequest, Response{ Succeed: false, @@ -113,7 +113,7 @@ func (s *Server) updateRecord(c *gin.Context) { }) } -func (s *Server) deleteRecord(c *gin.Context) { +func deleteRecord(c *gin.Context) { domain := c.Param("domain") id := c.Param("id") diff --git a/server/route.go b/server/route.go index ef59a62..722ff15 100644 --- a/server/route.go +++ b/server/route.go @@ -37,18 +37,18 @@ func (s *Server) setupRoute() { domains := groupV1.Group("/domains") domains. - GET("/", s.getDomains). - POST("/", s.createDomain). - PUT("/", s.updateDomain). - DELETE("/:id", s.deleteDomain) + GET("/", getDomains). + POST("/", createDomain). + PUT("/", updateDomain). + DELETE("/:id", deleteDomain) records := groupV1.Group("/records") records. - GET("/:domain", s.getRecords). - POST("/:domain", s.createRecord). - POST("/:domain/bulk", s.createRecords). - PUT("/:domain", s.updateRecord). - DELETE("/:domain/:id", s.deleteRecord) + GET("/:domain", getRecords). + POST("/:domain", createRecord). + POST("/:domain/bulk", createRecords). + PUT("/:domain", updateRecord). + DELETE("/:domain/:id", deleteRecord) server := s.webServer.Group(s.prefix) server.Use(func(ctx *gin.Context) {