validate 'em !

This commit is contained in:
Sense T 2024-04-13 10:14:45 +08:00
parent 7b529ad8f6
commit 2c754e7eec
5 changed files with 238 additions and 22 deletions

View File

@ -5,8 +5,6 @@ import (
"reCoreD-UI/database" "reCoreD-UI/database"
"reCoreD-UI/models" "reCoreD-UI/models"
"strconv" "strconv"
dns "github.com/cloud66-oss/coredns_mysql"
) )
type domainsDAO struct { type domainsDAO struct {
@ -25,7 +23,7 @@ func CreateDomain(d *models.Domain) (*models.Domain, error) {
return nil, err return nil, err
} }
r := &models.Record[dns.SOARecord]{} r := &models.Record[models.SOARecord]{}
r.Zone = d.WithDotEnd() r.Zone = d.WithDotEnd()
r.Name = "@" r.Name = "@"
r.RecordType = models.RecordTypeSOA r.RecordType = models.RecordTypeSOA
@ -41,7 +39,7 @@ func CreateDomain(d *models.Domain) (*models.Domain, error) {
} }
for i, ns := range nss { for i, ns := range nss {
record := &models.Record[dns.NSRecord]{ record := &models.Record[models.NSRecord]{
Zone: d.WithDotEnd(), Zone: d.WithDotEnd(),
RecordType: models.RecordTypeSOA, RecordType: models.RecordTypeSOA,
Name: fmt.Sprintf("ns%d", i+1), Name: fmt.Sprintf("ns%d", i+1),
@ -98,7 +96,7 @@ func UpdateDomain(d *models.Domain) error {
return err return err
} }
r := &models.Record[dns.SOARecord]{} r := &models.Record[models.SOARecord]{}
if err := r.FromEntity(soa); err != nil { if err := r.FromEntity(soa); err != nil {
tx.Rollback() tx.Rollback()
return err return err

View File

@ -3,8 +3,6 @@ package models
import ( import (
"fmt" "fmt"
"strings" "strings"
dns "github.com/cloud66-oss/coredns_mysql"
) )
type Domain struct { type Domain struct {
@ -38,25 +36,25 @@ func (d *Domain) WithDotEnd() string {
} }
} }
func (d *Domain) GenerateSOA() dns.SOARecord { func (d *Domain) GenerateSOA() SOARecord {
var ns string var ns string
if !strings.HasSuffix(d.MainDNS, ".") { if !strings.HasSuffix(d.MainDNS, ".") {
ns = fmt.Sprintf("%s.", d.MainDNS) ns = fmt.Sprintf("%s.", d.MainDNS)
} else { } else {
ns = d.MainDNS ns = d.MainDNS
} }
return dns.SOARecord{ r := SOARecord{}
Ns: ns, r.Ns = ns
MBox: d.EmailSOAForamt(), r.MBox = d.EmailSOAForamt()
Refresh: d.RefreshInterval, r.Refresh = d.RefreshInterval
Retry: d.RetryInterval, r.Retry = d.RetryInterval
Expire: d.ExpiryPeriod, r.Expire = d.ExpiryPeriod
MinTtl: d.NegativeTtl, r.MinTtl = d.NegativeTtl
} return r
} }
type IDomain interface { type IDomain interface {
EmailSOAForamt() string EmailSOAForamt() string
WithDotEnd() string WithDotEnd() string
GenerateSOA() dns.SOARecord GenerateSOA() SOARecord
} }

View File

@ -4,8 +4,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"strings" "strings"
dns "github.com/cloud66-oss/coredns_mysql"
) )
var ErrorZoneNotEndWithDot = errors.New("zone should end with '.'") var ErrorZoneNotEndWithDot = errors.New("zone should end with '.'")
@ -22,10 +20,8 @@ const (
RecordTypeSRV = "SRV" RecordTypeSRV = "SRV"
) )
type RecordContentDefault map[string]any
type recordContentTypes interface { type recordContentTypes interface {
dns.ARecord | dns.AAAARecord | dns.CNAMERecord | dns.CAARecord | dns.NSRecord | dns.MXRecord | dns.SOARecord | dns.SRVRecord | dns.TXTRecord | RecordContentDefault ARecord | AAAARecord | CNAMERecord | CAARecord | NSRecord | MXRecord | SOARecord | SRVRecord | TXTRecord | RecordContentDefault
} }
type Record[T recordContentTypes] struct { type Record[T recordContentTypes] struct {

146
models/record_types.go Normal file
View File

@ -0,0 +1,146 @@
package models
import (
"errors"
"regexp"
"strings"
dns "github.com/cloud66-oss/coredns_mysql"
)
var (
ErrInvalidIPv4 = errors.New("not a valid ipv4 address")
ErrInvalidIPv6 = errors.New("not a valid ipv6 address")
ErrEmptyTXT = errors.New("txt record should not empty")
ErrNoDotSuffix = errors.New("should end with dot")
ErrBadEmailFormat = errors.New("email here should have no '@'")
ErrBadCAATag = errors.New("caa tag should not empty")
ErrBadCAAValue = errors.New("caa value should not empty")
ErrInvalidType = errors.New("invalid type")
)
type ARecord struct {
dns.ARecord
}
func (r ARecord) Validate() error {
ok := regexp.MustCompile("^((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])[.]){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])$").MatchString(r.Ip.String())
if !ok {
return ErrInvalidIPv4
}
return nil
}
type AAAARecord struct {
dns.AAAARecord
}
func (r AAAARecord) Validate() error {
ok := regexp.MustCompile("^(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])[.]){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])[.]){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))$").MatchString(r.Ip.String())
if !ok {
return ErrInvalidIPv6
}
return nil
}
type TXTRecord struct {
dns.TXTRecord
}
func (r TXTRecord) Validate() error {
if r.Text == "" {
return ErrEmptyTXT
}
return nil
}
type CNAMERecord struct {
dns.CNAMERecord
}
func (r CNAMERecord) Validate() error {
if strings.HasSuffix(r.Host, ".") {
return ErrNoDotSuffix
}
return nil
}
type NSRecord struct {
dns.NSRecord
}
func (r NSRecord) Validate() error {
if strings.HasSuffix(r.Host, ".") {
return ErrNoDotSuffix
}
return nil
}
type MXRecord struct {
dns.MXRecord
}
func (r MXRecord) Validate() error {
if strings.HasSuffix(r.Host, ".") {
return ErrNoDotSuffix
}
return nil
}
type SRVRecord struct {
dns.SRVRecord
}
func (r SRVRecord) Validate() error {
if strings.HasPrefix(r.Target, ".") {
return ErrNoDotSuffix
}
return nil
}
type SOARecord struct {
dns.SOARecord
}
func (r SOARecord) Validate() error {
if strings.HasPrefix(r.MBox, ".") {
return ErrNoDotSuffix
}
if strings.HasSuffix(r.Ns, ".") {
return ErrNoDotSuffix
}
if strings.Contains(r.MBox, "@") {
return ErrBadEmailFormat
}
return nil
}
type CAARecord struct {
dns.CAARecord
}
func (r CAARecord) Validate() error {
if r.Tag == "" {
return ErrBadCAATag
}
if r.Value == "" {
return ErrBadCAAValue
}
return nil
}
type RecordContentDefault map[string]any
func (r RecordContentDefault) Validate() error {
return ErrInvalidType
}
type IRecordType interface {
Validate() error
}

View File

@ -9,6 +9,68 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func validateRecord(r models.IRecord) error {
switch r.GetType() {
case models.RecordTypeA:
record := &models.Record[models.ARecord]{}
if err := record.FromEntity(r); err != nil {
return err
}
return record.Content.Validate()
case models.RecordTypeAAAA:
record := &models.Record[models.AAAARecord]{}
if err := record.FromEntity(r); err != nil {
return err
}
return record.Content.Validate()
case models.RecordTypeCNAME:
record := &models.Record[models.CNAMERecord]{}
if err := record.FromEntity(r); err != nil {
return err
}
return record.Content.Validate()
case models.RecordTypeCAA:
record := &models.Record[models.CAARecord]{}
if err := record.FromEntity(r); err != nil {
return err
}
return record.Content.Validate()
case models.RecordTypeMX:
record := &models.Record[models.MXRecord]{}
if err := record.FromEntity(r); err != nil {
return err
}
return record.Content.Validate()
case models.RecordTypeNS:
record := &models.Record[models.NSRecord]{}
if err := record.FromEntity(r); err != nil {
return err
}
return record.Content.Validate()
case models.RecordTypeSOA:
record := &models.Record[models.SOARecord]{}
if err := record.FromEntity(r); err != nil {
return err
}
return record.Content.Validate()
case models.RecordTypeSRV:
record := &models.Record[models.SRVRecord]{}
if err := record.FromEntity(r); err != nil {
return err
}
return record.Content.Validate()
case models.RecordTypeTXT:
record := &models.Record[models.TXTRecord]{}
if err := record.FromEntity(r); err != nil {
return err
}
return record.Content.Validate()
default:
return models.ErrInvalidType
}
}
func getRecords(c *gin.Context) { func getRecords(c *gin.Context) {
query := &models.Record[models.RecordContentDefault]{Content: make(models.RecordContentDefault)} query := &models.Record[models.RecordContentDefault]{Content: make(models.RecordContentDefault)}
if err := c.BindQuery(query); err != nil { if err := c.BindQuery(query); err != nil {
@ -52,6 +114,14 @@ func createRecord(c *gin.Context) {
return return
} }
if err := validateRecord(record); err != nil {
c.JSON(http.StatusBadRequest, Response{
Succeed: false,
Message: err.Error(),
})
return
}
irecord, err := controllers.CreateRecord(record) irecord, err := controllers.CreateRecord(record)
if err != nil { if err != nil {
errorHandler(c, err) errorHandler(c, err)
@ -99,6 +169,14 @@ func updateRecord(c *gin.Context) {
return return
} }
if err := validateRecord(record); err != nil {
c.JSON(http.StatusBadRequest, Response{
Succeed: false,
Message: err.Error(),
})
return
}
domain := c.Param("domain") domain := c.Param("domain")
if domain != record.Zone { if domain != record.Zone {
c.JSON(http.StatusBadRequest, Response{ c.JSON(http.StatusBadRequest, Response{