model with generics done

This commit is contained in:
Sense T 2024-04-09 21:16:19 +08:00
parent 7dd3af3707
commit 9752e7d9ae
10 changed files with 109 additions and 91 deletions

View File

@ -25,34 +25,30 @@ func CreateDomain(d *models.Domain) (*models.Domain, error) {
return nil, err
}
r := &models.RecordWithType[dns.SOARecord]{}
r := &models.Record[dns.SOARecord]{}
r.Zone = d.WithDotEnd()
r.Name = "@"
r.RecordType = models.RecordTypeSOA
r.Content.Ns = d.MainDNS
r.Content.MBox = d.EmailSOAForamt()
r.Content.Refresh = d.RefreshInterval
r.Content.Retry = d.RetryInterval
r.Content.Expire = d.ExpiryPeriod
r.Content.MinTtl = d.NegativeTtl
r.Content = d.GenerateSOA()
if err := r.CheckZone(); err != nil {
tx.Rollback()
return nil, err
}
if _, err := (recordsDAO{}).Create(tx, *r.ToRecord()); err != nil {
if _, err := (recordsDAO{}).Create(tx, r); err != nil {
tx.Rollback()
return nil, err
}
for i, ns := range nss {
record := &models.RecordWithType[dns.NSRecord]{}
record.Zone = d.WithDotEnd()
record.RecordType = models.RecordTypeNS
record := &models.Record[dns.NSRecord]{
Zone: d.WithDotEnd(),
RecordType: models.RecordTypeSOA,
Name: fmt.Sprintf("ns%d", i+1),
}
record.Content.Host = ns
record.Name = fmt.Sprintf("ns%d", i+1)
if _, err := (recordsDAO{}).Create(tx, *record.ToRecord()); err != nil {
if _, err := (recordsDAO{}).Create(tx, record); err != nil {
tx.Rollback()
return nil, err
}
@ -77,7 +73,7 @@ func UpdateDomain(d *models.Domain) error {
return err
}
soa, err := (recordsDAO{}).GetOne(tx, models.Record{
soa, err := (recordsDAO{}).GetOne(tx, &models.Record[models.RecordContentDefault]{
RecordType: models.RecordTypeSOA, Zone: d.WithDotEnd(),
})
if err != nil {
@ -85,24 +81,19 @@ func UpdateDomain(d *models.Domain) error {
return err
}
r := &models.RecordWithType[dns.SOARecord]{}
if err := r.FromRecord(&soa); err != nil {
r := &models.Record[dns.SOARecord]{}
if err := r.FromEntity(soa); err != nil {
tx.Rollback()
return err
}
r.Content.Ns = d.MainDNS
r.Content.MBox = d.EmailSOAForamt()
r.Content.Refresh = d.RefreshInterval
r.Content.Retry = d.RetryInterval
r.Content.Expire = d.ExpiryPeriod
r.Content.MinTtl = d.NegativeTtl
r.Content = d.GenerateSOA()
if err := r.CheckZone(); err != nil {
tx.Rollback()
return err
}
if _, err := (recordsDAO{}).Update(tx, *r.ToRecord()); err != nil {
if _, err := (recordsDAO{}).Update(tx, r); err != nil {
tx.Rollback()
return err
}
@ -130,7 +121,7 @@ func DeleteDomain(id string) error {
return err
}
if err := (recordsDAO{}).Delete(tx, models.Record{Zone: domain.WithDotEnd()}); err != nil {
if err := (recordsDAO{}).Delete(tx, &models.Record[models.RecordContentDefault]{Zone: domain.WithDotEnd()}); err != nil {
tx.Rollback()
return err
}

View File

@ -41,7 +41,7 @@ func RegisterMetrics() {
}
}
func RefreshMetrics() error {
func RefreshMetrics() error {
domainCounts, err := getDomainCounts()
if err != nil {
return err

View File

@ -10,7 +10,8 @@ func Migrate() error {
return err
}
if err := (recordsDAO{}).Migrate(database.Client, models.Record{}); err != nil {
var recordDefiniation models.IRecord = &models.Record[models.RecordContentDefault]{}
if err := (recordsDAO{}).Migrate(database.Client, recordDefiniation); err != nil {
return err
}

View File

@ -10,11 +10,11 @@ import (
)
type recordsDAO struct {
database.BaseDAO[models.Record]
database.BaseDAO[models.IRecord]
}
func CreateRecord(r *models.Record) (*models.Record, error) {
if r.RecordType != models.RecordTypeSOA {
func CreateRecord(r models.IRecord) (models.IRecord, error) {
if r.GetType() != models.RecordTypeSOA {
_, err := GetDomains(r.WithOutDotTail())
if err != nil {
return nil, err
@ -25,11 +25,11 @@ func CreateRecord(r *models.Record) (*models.Record, error) {
return nil, err
}
res, err := (recordsDAO{}).Create(database.Client, *r)
return &res, err
res, err := (recordsDAO{}).Create(database.Client, r)
return res, err
}
func CreateRecords(rs []*models.Record) error {
func CreateRecords(rs []models.IRecord) error {
tx := database.Client.Begin()
for _, r := range rs {
if err := r.CheckZone(); err != nil {
@ -37,7 +37,7 @@ func CreateRecords(rs []*models.Record) error {
return err
}
if _, err := (recordsDAO{}).Create(tx, *r); err != nil {
if _, err := (recordsDAO{}).Create(tx, r); err != nil {
tx.Rollback()
return err
}
@ -46,16 +46,16 @@ func CreateRecords(rs []*models.Record) error {
return nil
}
func GetRecords(cond models.Record) ([]models.Record, error) {
func GetRecords(cond models.IRecord) ([]models.IRecord, error) {
return (recordsDAO{}).GetAll(database.Client, cond)
}
func UpdateRecord(r *models.Record) error {
func UpdateRecord(r models.IRecord) error {
if err := r.CheckZone(); err != nil {
return err
}
if _, err := (recordsDAO{}).Update(database.Client, *r); err != nil {
if _, err := (recordsDAO{}).Update(database.Client, r); err != nil {
return err
}
return nil
@ -68,13 +68,13 @@ func DeleteRecord(domain, id string) error {
}
tx := database.Client.Begin()
record, err := (recordsDAO{}).GetOne(tx, models.Record{ID: ID, Zone: fmt.Sprintf("%s.", domain)})
record, err := (recordsDAO{}).GetOne(tx, &models.Record[models.RecordContentDefault]{ID: ID, Zone: fmt.Sprintf("%s.", domain)})
if err != nil {
tx.Rollback()
return err
}
if record.RecordType == models.RecordTypeSOA {
if record.GetType() == models.RecordTypeSOA {
tx.Rollback()
return gorm.ErrRecordNotFound
}
@ -90,14 +90,18 @@ func DeleteRecord(domain, id string) error {
// for metrics
func getRecordCounts() (map[string]float64, error) {
rows, err := (recordsDAO{}).GetAll(database.Client, models.Record{})
rows, err := (recordsDAO{}).GetAll(database.Client, &models.Record[models.RecordContentDefault]{})
if err != nil {
return nil, err
}
result := make(map[string]float64)
for _, row := range rows {
result[row.Zone] += 1
record := &models.Record[models.RecordContentDefault]{}
if err := record.FromEntity(row); err != nil {
return nil, err
}
result[record.Zone] += 1
}
return result, nil
}

View File

@ -26,7 +26,7 @@ func main() {
Name: "mysql-dsn",
Usage: "mysql dsn",
Required: true,
EnvVars: []string{"RECORED_MYSQL_DSN"},
EnvVars: []string{"RECORED_MYSQL_DSN"},
},
&cli.BoolFlag{
Name: "debug",

View File

@ -3,6 +3,8 @@ package models
import (
"fmt"
"strings"
dns "github.com/cloud66-oss/coredns_mysql"
)
type Domain struct {
@ -32,3 +34,14 @@ func (d *Domain) WithDotEnd() string {
return fmt.Sprintf("%s.", d.DomainName)
}
}
func (d *Domain) GenerateSOA() dns.SOARecord {
return dns.SOARecord{
Ns: d.MainDNS,
MBox: d.EmailSOAForamt(),
Refresh: d.RefreshInterval,
Retry: d.RetryInterval,
Expire: d.ExpiryPeriod,
MinTtl: d.NegativeTtl,
}
}

View File

@ -1,6 +1,7 @@
package models
import (
"encoding/json"
"fmt"
"strings"
@ -19,50 +20,58 @@ const (
RecordTypeSRV = "SRV"
)
type Record struct {
type RecordContentDefault any
type recordContentTypes interface {
dns.ARecord | dns.AAAARecord | dns.CNAMERecord | dns.CAARecord | dns.NSRecord | dns.MXRecord | dns.SOARecord | dns.SRVRecord | dns.TXTRecord | RecordContentDefault
}
type Record[T recordContentTypes] struct {
ID int `gorm:"primaryKey" json:"id"`
Zone string `gorm:"not null,size:255" json:"zone"`
Name string `gorm:"not null,size:255" json:"name"`
Ttl int `json:"ttl"`
Content any `gorm:"serializer:json,type:\"text\"" json:"content"`
Content T `gorm:"serializer:json,type:\"text\"" json:"content"`
RecordType string `gorm:"not null,size:255" json:"record_type"`
}
func (Record) TableName() string {
func (Record[T]) TableName() string {
return "coredns_record"
}
func (r Record) CheckZone() error {
func (r Record[T]) CheckZone() error {
if strings.HasSuffix(r.Zone, ".") {
return fmt.Errorf("zone should end with '.'")
}
return nil
}
func (r Record) WithOutDotTail() string {
func (r Record[T]) WithOutDotTail() string {
return strings.TrimRight(r.Zone, ".")
}
type RecordContentTypes interface {
dns.ARecord | dns.AAAARecord | dns.CNAMERecord | dns.CAARecord | dns.NSRecord | dns.MXRecord | dns.SOARecord | dns.SRVRecord | dns.TXTRecord
func (r Record[T]) ToEntity() IRecord {
return &r
}
type RecordWithType[T RecordContentTypes] struct {
Record
Content T `json:"content"`
}
func (r *RecordWithType[T]) ToRecord() *Record {
r.Record.Content = r.Content
return &r.Record
}
func (r *RecordWithType[T]) FromRecord(record *Record) error {
r.Record = *record
var ok bool
if r.Content, ok = record.Content.(T); !ok {
return fmt.Errorf("cannot convert record type")
func (r *Record[T]) FromEntity(entity any) error {
b, err := json.Marshal(entity)
if err != nil {
return err
}
return nil
return json.Unmarshal(b, r)
}
func (r Record[T]) GetType() string {
return r.RecordType
}
type IRecord interface {
TableName() string
CheckZone() error
WithOutDotTail() string
ToEntity() IRecord
FromEntity(any) error
GetType() string
}

View File

@ -8,7 +8,7 @@ import (
"github.com/gin-gonic/gin"
)
func (s *Server) getDomains(c *gin.Context) {
func getDomains(c *gin.Context) {
domains, err := controllers.GetDomains("")
if err != nil {
errorHandler(c, err)
@ -21,7 +21,7 @@ func (s *Server) getDomains(c *gin.Context) {
})
}
func (s *Server) createDomain(c *gin.Context) {
func createDomain(c *gin.Context) {
domain := &models.Domain{}
if err := c.BindJSON(domain); err != nil {
@ -44,7 +44,7 @@ func (s *Server) createDomain(c *gin.Context) {
})
}
func (s *Server) updateDomain(c *gin.Context) {
func updateDomain(c *gin.Context) {
domain := &models.Domain{}
if err := c.BindJSON(domain); err != nil {
@ -65,7 +65,7 @@ func (s *Server) updateDomain(c *gin.Context) {
})
}
func (s *Server) deleteDomain(c *gin.Context) {
func deleteDomain(c *gin.Context) {
id := c.Param("id")
if err := controllers.DeleteDomain(id); err != nil {
errorHandler(c, err)

View File

@ -9,8 +9,8 @@ import (
"github.com/gin-gonic/gin"
)
func (s *Server) getRecords(c *gin.Context) {
query := models.Record{}
func getRecords(c *gin.Context) {
query := models.Record[models.RecordContentDefault]{}
if err := c.BindQuery(&query); err != nil {
c.JSON(http.StatusBadRequest, Response{
Succeed: false,
@ -21,7 +21,7 @@ func (s *Server) getRecords(c *gin.Context) {
domain := c.Param("domain")
query.Zone = fmt.Sprintf("%s.", domain)
records, err := controllers.GetRecords(query)
records, err := controllers.GetRecords(&query)
if err != nil {
errorHandler(c, err)
return
@ -33,8 +33,8 @@ func (s *Server) getRecords(c *gin.Context) {
})
}
func (s *Server) createRecord(c *gin.Context) {
record := &models.Record{}
func createRecord(c *gin.Context) {
record := &models.Record[models.RecordContentDefault]{}
if err := c.BindJSON(record); err != nil {
c.JSON(http.StatusBadRequest, Response{
Succeed: false,
@ -44,7 +44,7 @@ func (s *Server) createRecord(c *gin.Context) {
}
domain := c.Param("domain")
if domain != record.Zone {
if domain != record.WithOutDotTail() {
c.JSON(http.StatusBadRequest, Response{
Succeed: false,
Message: "request body doesn't match URI",
@ -52,7 +52,7 @@ func (s *Server) createRecord(c *gin.Context) {
return
}
record, err := controllers.CreateRecord(record)
irecord, err := controllers.CreateRecord(record)
if err != nil {
errorHandler(c, err)
return
@ -60,12 +60,12 @@ func (s *Server) createRecord(c *gin.Context) {
c.JSON(http.StatusCreated, Response{
Succeed: true,
Data: record,
Data: irecord,
})
}
func (s *Server) createRecords(c *gin.Context) {
var records []*models.Record
func createRecords(c *gin.Context) {
var records []models.IRecord
if err := c.BindJSON(&records); err != nil {
c.JSON(http.StatusBadRequest, Response{
Succeed: false,
@ -84,8 +84,8 @@ func (s *Server) createRecords(c *gin.Context) {
})
}
func (s *Server) updateRecord(c *gin.Context) {
record := &models.Record{}
func updateRecord(c *gin.Context) {
record := &models.Record[models.RecordContentDefault]{}
if err := c.BindJSON(record); err != nil {
c.JSON(http.StatusBadRequest, Response{
Succeed: false,
@ -113,7 +113,7 @@ func (s *Server) updateRecord(c *gin.Context) {
})
}
func (s *Server) deleteRecord(c *gin.Context) {
func deleteRecord(c *gin.Context) {
domain := c.Param("domain")
id := c.Param("id")

View File

@ -37,18 +37,18 @@ func (s *Server) setupRoute() {
domains := groupV1.Group("/domains")
domains.
GET("/", s.getDomains).
POST("/", s.createDomain).
PUT("/", s.updateDomain).
DELETE("/:id", s.deleteDomain)
GET("/", getDomains).
POST("/", createDomain).
PUT("/", updateDomain).
DELETE("/:id", deleteDomain)
records := groupV1.Group("/records")
records.
GET("/:domain", s.getRecords).
POST("/:domain", s.createRecord).
POST("/:domain/bulk", s.createRecords).
PUT("/:domain", s.updateRecord).
DELETE("/:domain/:id", s.deleteRecord)
GET("/:domain", getRecords).
POST("/:domain", createRecord).
POST("/:domain/bulk", createRecords).
PUT("/:domain", updateRecord).
DELETE("/:domain/:id", deleteRecord)
server := s.webServer.Group(s.prefix)
server.Use(func(ctx *gin.Context) {