model with generics done

This commit is contained in:
Sense T 2024-04-09 21:16:19 +08:00
parent 7dd3af3707
commit 9752e7d9ae
10 changed files with 109 additions and 91 deletions

View File

@ -25,34 +25,30 @@ func CreateDomain(d *models.Domain) (*models.Domain, error) {
return nil, err return nil, err
} }
r := &models.RecordWithType[dns.SOARecord]{} r := &models.Record[dns.SOARecord]{}
r.Zone = d.WithDotEnd() r.Zone = d.WithDotEnd()
r.Name = "@" r.Name = "@"
r.RecordType = models.RecordTypeSOA r.RecordType = models.RecordTypeSOA
r.Content.Ns = d.MainDNS r.Content = d.GenerateSOA()
r.Content.MBox = d.EmailSOAForamt()
r.Content.Refresh = d.RefreshInterval
r.Content.Retry = d.RetryInterval
r.Content.Expire = d.ExpiryPeriod
r.Content.MinTtl = d.NegativeTtl
if err := r.CheckZone(); err != nil { if err := r.CheckZone(); err != nil {
tx.Rollback() tx.Rollback()
return nil, err return nil, err
} }
if _, err := (recordsDAO{}).Create(tx, *r.ToRecord()); err != nil { if _, err := (recordsDAO{}).Create(tx, r); err != nil {
tx.Rollback() tx.Rollback()
return nil, err return nil, err
} }
for i, ns := range nss { for i, ns := range nss {
record := &models.RecordWithType[dns.NSRecord]{} record := &models.Record[dns.NSRecord]{
record.Zone = d.WithDotEnd() Zone: d.WithDotEnd(),
record.RecordType = models.RecordTypeNS RecordType: models.RecordTypeSOA,
Name: fmt.Sprintf("ns%d", i+1),
}
record.Content.Host = ns record.Content.Host = ns
record.Name = fmt.Sprintf("ns%d", i+1)
if _, err := (recordsDAO{}).Create(tx, *record.ToRecord()); err != nil { if _, err := (recordsDAO{}).Create(tx, record); err != nil {
tx.Rollback() tx.Rollback()
return nil, err return nil, err
} }
@ -77,7 +73,7 @@ func UpdateDomain(d *models.Domain) error {
return err return err
} }
soa, err := (recordsDAO{}).GetOne(tx, models.Record{ soa, err := (recordsDAO{}).GetOne(tx, &models.Record[models.RecordContentDefault]{
RecordType: models.RecordTypeSOA, Zone: d.WithDotEnd(), RecordType: models.RecordTypeSOA, Zone: d.WithDotEnd(),
}) })
if err != nil { if err != nil {
@ -85,24 +81,19 @@ func UpdateDomain(d *models.Domain) error {
return err return err
} }
r := &models.RecordWithType[dns.SOARecord]{} r := &models.Record[dns.SOARecord]{}
if err := r.FromRecord(&soa); err != nil { if err := r.FromEntity(soa); err != nil {
tx.Rollback() tx.Rollback()
return err return err
} }
r.Content.Ns = d.MainDNS r.Content = d.GenerateSOA()
r.Content.MBox = d.EmailSOAForamt()
r.Content.Refresh = d.RefreshInterval
r.Content.Retry = d.RetryInterval
r.Content.Expire = d.ExpiryPeriod
r.Content.MinTtl = d.NegativeTtl
if err := r.CheckZone(); err != nil { if err := r.CheckZone(); err != nil {
tx.Rollback() tx.Rollback()
return err return err
} }
if _, err := (recordsDAO{}).Update(tx, *r.ToRecord()); err != nil { if _, err := (recordsDAO{}).Update(tx, r); err != nil {
tx.Rollback() tx.Rollback()
return err return err
} }
@ -130,7 +121,7 @@ func DeleteDomain(id string) error {
return err return err
} }
if err := (recordsDAO{}).Delete(tx, models.Record{Zone: domain.WithDotEnd()}); err != nil { if err := (recordsDAO{}).Delete(tx, &models.Record[models.RecordContentDefault]{Zone: domain.WithDotEnd()}); err != nil {
tx.Rollback() tx.Rollback()
return err return err
} }

View File

@ -41,7 +41,7 @@ func RegisterMetrics() {
} }
} }
func RefreshMetrics() error { func RefreshMetrics() error {
domainCounts, err := getDomainCounts() domainCounts, err := getDomainCounts()
if err != nil { if err != nil {
return err return err

View File

@ -10,7 +10,8 @@ func Migrate() error {
return err return err
} }
if err := (recordsDAO{}).Migrate(database.Client, models.Record{}); err != nil { var recordDefiniation models.IRecord = &models.Record[models.RecordContentDefault]{}
if err := (recordsDAO{}).Migrate(database.Client, recordDefiniation); err != nil {
return err return err
} }

View File

@ -10,11 +10,11 @@ import (
) )
type recordsDAO struct { type recordsDAO struct {
database.BaseDAO[models.Record] database.BaseDAO[models.IRecord]
} }
func CreateRecord(r *models.Record) (*models.Record, error) { func CreateRecord(r models.IRecord) (models.IRecord, error) {
if r.RecordType != models.RecordTypeSOA { if r.GetType() != models.RecordTypeSOA {
_, err := GetDomains(r.WithOutDotTail()) _, err := GetDomains(r.WithOutDotTail())
if err != nil { if err != nil {
return nil, err return nil, err
@ -25,11 +25,11 @@ func CreateRecord(r *models.Record) (*models.Record, error) {
return nil, err return nil, err
} }
res, err := (recordsDAO{}).Create(database.Client, *r) res, err := (recordsDAO{}).Create(database.Client, r)
return &res, err return res, err
} }
func CreateRecords(rs []*models.Record) error { func CreateRecords(rs []models.IRecord) error {
tx := database.Client.Begin() tx := database.Client.Begin()
for _, r := range rs { for _, r := range rs {
if err := r.CheckZone(); err != nil { if err := r.CheckZone(); err != nil {
@ -37,7 +37,7 @@ func CreateRecords(rs []*models.Record) error {
return err return err
} }
if _, err := (recordsDAO{}).Create(tx, *r); err != nil { if _, err := (recordsDAO{}).Create(tx, r); err != nil {
tx.Rollback() tx.Rollback()
return err return err
} }
@ -46,16 +46,16 @@ func CreateRecords(rs []*models.Record) error {
return nil return nil
} }
func GetRecords(cond models.Record) ([]models.Record, error) { func GetRecords(cond models.IRecord) ([]models.IRecord, error) {
return (recordsDAO{}).GetAll(database.Client, cond) return (recordsDAO{}).GetAll(database.Client, cond)
} }
func UpdateRecord(r *models.Record) error { func UpdateRecord(r models.IRecord) error {
if err := r.CheckZone(); err != nil { if err := r.CheckZone(); err != nil {
return err return err
} }
if _, err := (recordsDAO{}).Update(database.Client, *r); err != nil { if _, err := (recordsDAO{}).Update(database.Client, r); err != nil {
return err return err
} }
return nil return nil
@ -68,13 +68,13 @@ func DeleteRecord(domain, id string) error {
} }
tx := database.Client.Begin() tx := database.Client.Begin()
record, err := (recordsDAO{}).GetOne(tx, models.Record{ID: ID, Zone: fmt.Sprintf("%s.", domain)}) record, err := (recordsDAO{}).GetOne(tx, &models.Record[models.RecordContentDefault]{ID: ID, Zone: fmt.Sprintf("%s.", domain)})
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return err return err
} }
if record.RecordType == models.RecordTypeSOA { if record.GetType() == models.RecordTypeSOA {
tx.Rollback() tx.Rollback()
return gorm.ErrRecordNotFound return gorm.ErrRecordNotFound
} }
@ -90,14 +90,18 @@ func DeleteRecord(domain, id string) error {
// for metrics // for metrics
func getRecordCounts() (map[string]float64, error) { func getRecordCounts() (map[string]float64, error) {
rows, err := (recordsDAO{}).GetAll(database.Client, models.Record{}) rows, err := (recordsDAO{}).GetAll(database.Client, &models.Record[models.RecordContentDefault]{})
if err != nil { if err != nil {
return nil, err return nil, err
} }
result := make(map[string]float64) result := make(map[string]float64)
for _, row := range rows { for _, row := range rows {
result[row.Zone] += 1 record := &models.Record[models.RecordContentDefault]{}
if err := record.FromEntity(row); err != nil {
return nil, err
}
result[record.Zone] += 1
} }
return result, nil return result, nil
} }

View File

@ -26,7 +26,7 @@ func main() {
Name: "mysql-dsn", Name: "mysql-dsn",
Usage: "mysql dsn", Usage: "mysql dsn",
Required: true, Required: true,
EnvVars: []string{"RECORED_MYSQL_DSN"}, EnvVars: []string{"RECORED_MYSQL_DSN"},
}, },
&cli.BoolFlag{ &cli.BoolFlag{
Name: "debug", Name: "debug",

View File

@ -3,6 +3,8 @@ package models
import ( import (
"fmt" "fmt"
"strings" "strings"
dns "github.com/cloud66-oss/coredns_mysql"
) )
type Domain struct { type Domain struct {
@ -32,3 +34,14 @@ func (d *Domain) WithDotEnd() string {
return fmt.Sprintf("%s.", d.DomainName) return fmt.Sprintf("%s.", d.DomainName)
} }
} }
func (d *Domain) GenerateSOA() dns.SOARecord {
return dns.SOARecord{
Ns: d.MainDNS,
MBox: d.EmailSOAForamt(),
Refresh: d.RefreshInterval,
Retry: d.RetryInterval,
Expire: d.ExpiryPeriod,
MinTtl: d.NegativeTtl,
}
}

View File

@ -1,6 +1,7 @@
package models package models
import ( import (
"encoding/json"
"fmt" "fmt"
"strings" "strings"
@ -19,50 +20,58 @@ const (
RecordTypeSRV = "SRV" RecordTypeSRV = "SRV"
) )
type Record struct { type RecordContentDefault any
type recordContentTypes interface {
dns.ARecord | dns.AAAARecord | dns.CNAMERecord | dns.CAARecord | dns.NSRecord | dns.MXRecord | dns.SOARecord | dns.SRVRecord | dns.TXTRecord | RecordContentDefault
}
type Record[T recordContentTypes] struct {
ID int `gorm:"primaryKey" json:"id"` ID int `gorm:"primaryKey" json:"id"`
Zone string `gorm:"not null,size:255" json:"zone"` Zone string `gorm:"not null,size:255" json:"zone"`
Name string `gorm:"not null,size:255" json:"name"` Name string `gorm:"not null,size:255" json:"name"`
Ttl int `json:"ttl"` Ttl int `json:"ttl"`
Content any `gorm:"serializer:json,type:\"text\"" json:"content"` Content T `gorm:"serializer:json,type:\"text\"" json:"content"`
RecordType string `gorm:"not null,size:255" json:"record_type"` RecordType string `gorm:"not null,size:255" json:"record_type"`
} }
func (Record) TableName() string { func (Record[T]) TableName() string {
return "coredns_record" return "coredns_record"
} }
func (r Record) CheckZone() error { func (r Record[T]) CheckZone() error {
if strings.HasSuffix(r.Zone, ".") { if strings.HasSuffix(r.Zone, ".") {
return fmt.Errorf("zone should end with '.'") return fmt.Errorf("zone should end with '.'")
} }
return nil return nil
} }
func (r Record) WithOutDotTail() string { func (r Record[T]) WithOutDotTail() string {
return strings.TrimRight(r.Zone, ".") return strings.TrimRight(r.Zone, ".")
} }
type RecordContentTypes interface { func (r Record[T]) ToEntity() IRecord {
dns.ARecord | dns.AAAARecord | dns.CNAMERecord | dns.CAARecord | dns.NSRecord | dns.MXRecord | dns.SOARecord | dns.SRVRecord | dns.TXTRecord return &r
} }
type RecordWithType[T RecordContentTypes] struct { func (r *Record[T]) FromEntity(entity any) error {
Record b, err := json.Marshal(entity)
Content T `json:"content"` if err != nil {
} return err
func (r *RecordWithType[T]) ToRecord() *Record {
r.Record.Content = r.Content
return &r.Record
}
func (r *RecordWithType[T]) FromRecord(record *Record) error {
r.Record = *record
var ok bool
if r.Content, ok = record.Content.(T); !ok {
return fmt.Errorf("cannot convert record type")
} }
return nil
return json.Unmarshal(b, r)
}
func (r Record[T]) GetType() string {
return r.RecordType
}
type IRecord interface {
TableName() string
CheckZone() error
WithOutDotTail() string
ToEntity() IRecord
FromEntity(any) error
GetType() string
} }

View File

@ -8,7 +8,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func (s *Server) getDomains(c *gin.Context) { func getDomains(c *gin.Context) {
domains, err := controllers.GetDomains("") domains, err := controllers.GetDomains("")
if err != nil { if err != nil {
errorHandler(c, err) errorHandler(c, err)
@ -21,7 +21,7 @@ func (s *Server) getDomains(c *gin.Context) {
}) })
} }
func (s *Server) createDomain(c *gin.Context) { func createDomain(c *gin.Context) {
domain := &models.Domain{} domain := &models.Domain{}
if err := c.BindJSON(domain); err != nil { if err := c.BindJSON(domain); err != nil {
@ -44,7 +44,7 @@ func (s *Server) createDomain(c *gin.Context) {
}) })
} }
func (s *Server) updateDomain(c *gin.Context) { func updateDomain(c *gin.Context) {
domain := &models.Domain{} domain := &models.Domain{}
if err := c.BindJSON(domain); err != nil { if err := c.BindJSON(domain); err != nil {
@ -65,7 +65,7 @@ func (s *Server) updateDomain(c *gin.Context) {
}) })
} }
func (s *Server) deleteDomain(c *gin.Context) { func deleteDomain(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
if err := controllers.DeleteDomain(id); err != nil { if err := controllers.DeleteDomain(id); err != nil {
errorHandler(c, err) errorHandler(c, err)

View File

@ -9,8 +9,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func (s *Server) getRecords(c *gin.Context) { func getRecords(c *gin.Context) {
query := models.Record{} query := models.Record[models.RecordContentDefault]{}
if err := c.BindQuery(&query); err != nil { if err := c.BindQuery(&query); err != nil {
c.JSON(http.StatusBadRequest, Response{ c.JSON(http.StatusBadRequest, Response{
Succeed: false, Succeed: false,
@ -21,7 +21,7 @@ func (s *Server) getRecords(c *gin.Context) {
domain := c.Param("domain") domain := c.Param("domain")
query.Zone = fmt.Sprintf("%s.", domain) query.Zone = fmt.Sprintf("%s.", domain)
records, err := controllers.GetRecords(query) records, err := controllers.GetRecords(&query)
if err != nil { if err != nil {
errorHandler(c, err) errorHandler(c, err)
return return
@ -33,8 +33,8 @@ func (s *Server) getRecords(c *gin.Context) {
}) })
} }
func (s *Server) createRecord(c *gin.Context) { func createRecord(c *gin.Context) {
record := &models.Record{} record := &models.Record[models.RecordContentDefault]{}
if err := c.BindJSON(record); err != nil { if err := c.BindJSON(record); err != nil {
c.JSON(http.StatusBadRequest, Response{ c.JSON(http.StatusBadRequest, Response{
Succeed: false, Succeed: false,
@ -44,7 +44,7 @@ func (s *Server) createRecord(c *gin.Context) {
} }
domain := c.Param("domain") domain := c.Param("domain")
if domain != record.Zone { if domain != record.WithOutDotTail() {
c.JSON(http.StatusBadRequest, Response{ c.JSON(http.StatusBadRequest, Response{
Succeed: false, Succeed: false,
Message: "request body doesn't match URI", Message: "request body doesn't match URI",
@ -52,7 +52,7 @@ func (s *Server) createRecord(c *gin.Context) {
return return
} }
record, err := controllers.CreateRecord(record) irecord, err := controllers.CreateRecord(record)
if err != nil { if err != nil {
errorHandler(c, err) errorHandler(c, err)
return return
@ -60,12 +60,12 @@ func (s *Server) createRecord(c *gin.Context) {
c.JSON(http.StatusCreated, Response{ c.JSON(http.StatusCreated, Response{
Succeed: true, Succeed: true,
Data: record, Data: irecord,
}) })
} }
func (s *Server) createRecords(c *gin.Context) { func createRecords(c *gin.Context) {
var records []*models.Record var records []models.IRecord
if err := c.BindJSON(&records); err != nil { if err := c.BindJSON(&records); err != nil {
c.JSON(http.StatusBadRequest, Response{ c.JSON(http.StatusBadRequest, Response{
Succeed: false, Succeed: false,
@ -84,8 +84,8 @@ func (s *Server) createRecords(c *gin.Context) {
}) })
} }
func (s *Server) updateRecord(c *gin.Context) { func updateRecord(c *gin.Context) {
record := &models.Record{} record := &models.Record[models.RecordContentDefault]{}
if err := c.BindJSON(record); err != nil { if err := c.BindJSON(record); err != nil {
c.JSON(http.StatusBadRequest, Response{ c.JSON(http.StatusBadRequest, Response{
Succeed: false, Succeed: false,
@ -113,7 +113,7 @@ func (s *Server) updateRecord(c *gin.Context) {
}) })
} }
func (s *Server) deleteRecord(c *gin.Context) { func deleteRecord(c *gin.Context) {
domain := c.Param("domain") domain := c.Param("domain")
id := c.Param("id") id := c.Param("id")

View File

@ -37,18 +37,18 @@ func (s *Server) setupRoute() {
domains := groupV1.Group("/domains") domains := groupV1.Group("/domains")
domains. domains.
GET("/", s.getDomains). GET("/", getDomains).
POST("/", s.createDomain). POST("/", createDomain).
PUT("/", s.updateDomain). PUT("/", updateDomain).
DELETE("/:id", s.deleteDomain) DELETE("/:id", deleteDomain)
records := groupV1.Group("/records") records := groupV1.Group("/records")
records. records.
GET("/:domain", s.getRecords). GET("/:domain", getRecords).
POST("/:domain", s.createRecord). POST("/:domain", createRecord).
POST("/:domain/bulk", s.createRecords). POST("/:domain/bulk", createRecords).
PUT("/:domain", s.updateRecord). PUT("/:domain", updateRecord).
DELETE("/:domain/:id", s.deleteRecord) DELETE("/:domain/:id", deleteRecord)
server := s.webServer.Group(s.prefix) server := s.webServer.Group(s.prefix)
server.Use(func(ctx *gin.Context) { server.Use(func(ctx *gin.Context) {