Added port validation

pull/12233/head
Pablo Caderno 2021-08-11 17:43:38 +10:00
parent 8173cab7ef
commit 1f83aa3ac2
2 changed files with 63 additions and 0 deletions

View File

@ -1118,6 +1118,14 @@ func validateFlags(cmd *cobra.Command, drvName string) {
viper.Set(imageRepository, validateImageRepository(viper.GetString(imageRepository)))
}
if cmd.Flags().Changed(ports) {
err := validatePorts(viper.GetStringSlice(ports))
if err != nil {
exit.Message(reason.Usage, "{{.err}}", out.V{"err": err})
}
}
if cmd.Flags().Changed(containerRuntime) {
runtime := strings.ToLower(viper.GetString(containerRuntime))
@ -1314,6 +1322,25 @@ func validateListenAddress(listenAddr string) {
}
}
// This function validates that the --ports are not below 1024 for the host and not outside range
func validatePorts(ports []string) error {
for _, portDuplet := range ports {
for i, port := range strings.Split(portDuplet, ":") {
p, err := strconv.Atoi(port)
if err != nil {
return errors.Errorf("Sorry, one of the ports provided with --ports flag is not valid %s", ports)
}
if p > 65535 || p < 1 {
return errors.Errorf("Sorry, one of the ports provided with --ports flag is outside range %s", ports)
}
if p < 1024 && i == 0 {
return errors.Errorf("Sorry, you cannot use privileged ports on the host (below 1024) %s", ports)
}
}
}
return nil
}
// This function validates that the --insecure-registry follows one of the following formats:
// "<ip>[:<port>]" "<hostname>[:<port>]" "<network>/<netmask>"
func validateInsecureRegistry() {

View File

@ -363,3 +363,39 @@ func TestValidateImageRepository(t *testing.T) {
}
}
func TestValidatePorts(t *testing.T) {
var tests = []struct {
ports []string
errorMsg string
}{
{
ports: []string{"test:80"},
errorMsg: "Sorry, one of the ports provided with --ports flag is not valid [test:80]",
},
{
ports: []string{"0:80"},
errorMsg: "Sorry, one of the ports provided with --ports flag is outside range [0:80]",
},
{
ports: []string{"80:80"},
errorMsg: "Sorry, you cannot use privileged ports on the host (below 1024) [80:80]",
},
{
ports: []string{"8080:80", "6443:443"},
errorMsg: "",
},
}
for _, test := range tests {
t.Run(strings.Join(test.ports, ","), func(t *testing.T) {
gotError := ""
got := validatePorts(test.ports)
if got != nil {
gotError = got.Error()
}
if gotError != test.errorMsg {
t.Errorf("validatePorts(ports=%v): got %v, expected %v", test.ports, got, test.errorMsg)
}
})
}
}