diff --git a/controllers/domain.go b/controllers/domain.go index 8a76147..b45ed2f 100644 --- a/controllers/domain.go +++ b/controllers/domain.go @@ -5,8 +5,6 @@ import ( "reCoreD-UI/database" "reCoreD-UI/models" "strconv" - - dns "github.com/cloud66-oss/coredns_mysql" ) type domainsDAO struct { @@ -25,7 +23,7 @@ func CreateDomain(d *models.Domain) (*models.Domain, error) { return nil, err } - r := &models.Record[dns.SOARecord]{} + r := &models.Record[models.SOARecord]{} r.Zone = d.WithDotEnd() r.Name = "@" r.RecordType = models.RecordTypeSOA @@ -41,7 +39,7 @@ func CreateDomain(d *models.Domain) (*models.Domain, error) { } for i, ns := range nss { - record := &models.Record[dns.NSRecord]{ + record := &models.Record[models.NSRecord]{ Zone: d.WithDotEnd(), RecordType: models.RecordTypeSOA, Name: fmt.Sprintf("ns%d", i+1), @@ -98,7 +96,7 @@ func UpdateDomain(d *models.Domain) error { return err } - r := &models.Record[dns.SOARecord]{} + r := &models.Record[models.SOARecord]{} if err := r.FromEntity(soa); err != nil { tx.Rollback() return err diff --git a/models/domain.go b/models/domain.go index 78a5784..01e6104 100644 --- a/models/domain.go +++ b/models/domain.go @@ -3,8 +3,6 @@ package models import ( "fmt" "strings" - - dns "github.com/cloud66-oss/coredns_mysql" ) type Domain struct { @@ -38,25 +36,25 @@ func (d *Domain) WithDotEnd() string { } } -func (d *Domain) GenerateSOA() dns.SOARecord { +func (d *Domain) GenerateSOA() SOARecord { var ns string if !strings.HasSuffix(d.MainDNS, ".") { ns = fmt.Sprintf("%s.", d.MainDNS) } else { ns = d.MainDNS } - return dns.SOARecord{ - Ns: ns, - MBox: d.EmailSOAForamt(), - Refresh: d.RefreshInterval, - Retry: d.RetryInterval, - Expire: d.ExpiryPeriod, - MinTtl: d.NegativeTtl, - } + r := SOARecord{} + r.Ns = ns + r.MBox = d.EmailSOAForamt() + r.Refresh = d.RefreshInterval + r.Retry = d.RetryInterval + r.Expire = d.ExpiryPeriod + r.MinTtl = d.NegativeTtl + return r } type IDomain interface { EmailSOAForamt() string WithDotEnd() string - GenerateSOA() dns.SOARecord + GenerateSOA() SOARecord } diff --git a/models/record.go b/models/record.go index df4c4de..f677a5e 100644 --- a/models/record.go +++ b/models/record.go @@ -4,8 +4,6 @@ import ( "encoding/json" "errors" "strings" - - dns "github.com/cloud66-oss/coredns_mysql" ) var ErrorZoneNotEndWithDot = errors.New("zone should end with '.'") @@ -22,10 +20,8 @@ const ( RecordTypeSRV = "SRV" ) -type RecordContentDefault map[string]any - type recordContentTypes interface { - dns.ARecord | dns.AAAARecord | dns.CNAMERecord | dns.CAARecord | dns.NSRecord | dns.MXRecord | dns.SOARecord | dns.SRVRecord | dns.TXTRecord | RecordContentDefault + ARecord | AAAARecord | CNAMERecord | CAARecord | NSRecord | MXRecord | SOARecord | SRVRecord | TXTRecord | RecordContentDefault } type Record[T recordContentTypes] struct { diff --git a/models/record_types.go b/models/record_types.go new file mode 100644 index 0000000..0a5a0f5 --- /dev/null +++ b/models/record_types.go @@ -0,0 +1,146 @@ +package models + +import ( + "errors" + "regexp" + "strings" + + dns "github.com/cloud66-oss/coredns_mysql" +) + +var ( + ErrInvalidIPv4 = errors.New("not a valid ipv4 address") + ErrInvalidIPv6 = errors.New("not a valid ipv6 address") + ErrEmptyTXT = errors.New("txt record should not empty") + ErrNoDotSuffix = errors.New("should end with dot") + ErrBadEmailFormat = errors.New("email here should have no '@'") + ErrBadCAATag = errors.New("caa tag should not empty") + ErrBadCAAValue = errors.New("caa value should not empty") + ErrInvalidType = errors.New("invalid type") +) + +type ARecord struct { + dns.ARecord +} + +func (r ARecord) Validate() error { + ok := regexp.MustCompile("^((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])[.]){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])$").MatchString(r.Ip.String()) + if !ok { + return ErrInvalidIPv4 + } + return nil +} + +type AAAARecord struct { + dns.AAAARecord +} + +func (r AAAARecord) Validate() error { + ok := regexp.MustCompile("^(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])[.]){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])[.]){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))$").MatchString(r.Ip.String()) + if !ok { + return ErrInvalidIPv6 + } + return nil +} + +type TXTRecord struct { + dns.TXTRecord +} + +func (r TXTRecord) Validate() error { + if r.Text == "" { + return ErrEmptyTXT + } + return nil +} + +type CNAMERecord struct { + dns.CNAMERecord +} + +func (r CNAMERecord) Validate() error { + if strings.HasSuffix(r.Host, ".") { + return ErrNoDotSuffix + } + return nil +} + +type NSRecord struct { + dns.NSRecord +} + +func (r NSRecord) Validate() error { + if strings.HasSuffix(r.Host, ".") { + return ErrNoDotSuffix + } + return nil +} + +type MXRecord struct { + dns.MXRecord +} + +func (r MXRecord) Validate() error { + if strings.HasSuffix(r.Host, ".") { + return ErrNoDotSuffix + } + return nil +} + +type SRVRecord struct { + dns.SRVRecord +} + +func (r SRVRecord) Validate() error { + if strings.HasPrefix(r.Target, ".") { + return ErrNoDotSuffix + } + + return nil +} + +type SOARecord struct { + dns.SOARecord +} + +func (r SOARecord) Validate() error { + if strings.HasPrefix(r.MBox, ".") { + return ErrNoDotSuffix + } + + if strings.HasSuffix(r.Ns, ".") { + return ErrNoDotSuffix + } + + if strings.Contains(r.MBox, "@") { + return ErrBadEmailFormat + } + + return nil +} + +type CAARecord struct { + dns.CAARecord +} + +func (r CAARecord) Validate() error { + if r.Tag == "" { + return ErrBadCAATag + } + + if r.Value == "" { + return ErrBadCAAValue + } + + return nil +} + +type RecordContentDefault map[string]any + +func (r RecordContentDefault) Validate() error { + return ErrInvalidType +} + +type IRecordType interface { + Validate() error +} diff --git a/server/handlers_records.go b/server/handlers_records.go index 347e9da..0f42b3d 100644 --- a/server/handlers_records.go +++ b/server/handlers_records.go @@ -9,6 +9,68 @@ import ( "github.com/gin-gonic/gin" ) +func validateRecord(r models.IRecord) error { + + switch r.GetType() { + case models.RecordTypeA: + record := &models.Record[models.ARecord]{} + if err := record.FromEntity(r); err != nil { + return err + } + return record.Content.Validate() + case models.RecordTypeAAAA: + record := &models.Record[models.AAAARecord]{} + if err := record.FromEntity(r); err != nil { + return err + } + return record.Content.Validate() + case models.RecordTypeCNAME: + record := &models.Record[models.CNAMERecord]{} + if err := record.FromEntity(r); err != nil { + return err + } + return record.Content.Validate() + case models.RecordTypeCAA: + record := &models.Record[models.CAARecord]{} + if err := record.FromEntity(r); err != nil { + return err + } + return record.Content.Validate() + case models.RecordTypeMX: + record := &models.Record[models.MXRecord]{} + if err := record.FromEntity(r); err != nil { + return err + } + return record.Content.Validate() + case models.RecordTypeNS: + record := &models.Record[models.NSRecord]{} + if err := record.FromEntity(r); err != nil { + return err + } + return record.Content.Validate() + case models.RecordTypeSOA: + record := &models.Record[models.SOARecord]{} + if err := record.FromEntity(r); err != nil { + return err + } + return record.Content.Validate() + case models.RecordTypeSRV: + record := &models.Record[models.SRVRecord]{} + if err := record.FromEntity(r); err != nil { + return err + } + return record.Content.Validate() + case models.RecordTypeTXT: + record := &models.Record[models.TXTRecord]{} + if err := record.FromEntity(r); err != nil { + return err + } + return record.Content.Validate() + default: + return models.ErrInvalidType + } +} + func getRecords(c *gin.Context) { query := &models.Record[models.RecordContentDefault]{Content: make(models.RecordContentDefault)} if err := c.BindQuery(query); err != nil { @@ -52,6 +114,14 @@ func createRecord(c *gin.Context) { return } + if err := validateRecord(record); err != nil { + c.JSON(http.StatusBadRequest, Response{ + Succeed: false, + Message: err.Error(), + }) + return + } + irecord, err := controllers.CreateRecord(record) if err != nil { errorHandler(c, err) @@ -99,6 +169,14 @@ func updateRecord(c *gin.Context) { return } + if err := validateRecord(record); err != nil { + c.JSON(http.StatusBadRequest, Response{ + Succeed: false, + Message: err.Error(), + }) + return + } + domain := c.Param("domain") if domain != record.Zone { c.JSON(http.StatusBadRequest, Response{