mirror of
				https://codeberg.org/forgejo/forgejo.git
				synced 2025-10-20 17:12:25 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			548 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			548 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
 | |
| //
 | |
| // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
 | |
| //
 | |
| // This Source Code Form is subject to the terms of the Mozilla Public
 | |
| // License, v. 2.0. If a copy of the MPL was not distributed with this file,
 | |
| // You can obtain one at http://mozilla.org/MPL/2.0/.
 | |
| 
 | |
| package mysql
 | |
| 
 | |
import (
	"bytes"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"
)
 | |
| 
 | |
// Sentinel errors returned by ParseDSN when a DSN string is malformed.
var (
	// errInvalidDSNUnescaped: an address group was opened with '(' and the
	// text before the database slash contains a ')' but does not end with
	// one, which usually means a parameter value was left unescaped.
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")

	// errInvalidDSNAddr: an address group was opened with '(' but never
	// closed with ')' before the database slash.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")

	// errInvalidDSNNoSlash: a non-empty DSN contains no '/' at all.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")

	// errInvalidDSNUnsafeCollation: interpolateParams was requested together
	// with a collation flagged as unsafe.
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
 | |
| 
 | |
// Config is a configuration parsed from a DSN string.
//
// ParseDSN produces a Config; FormatDSN renders one back into a DSN string.
type Config struct {
	// Credentials, connection target and database name. Passwd requires
	// User, and Addr (the network address) requires Net (the network type).
	User, Passwd, Net, Addr, DBName string

	// Params holds connection parameters that have no dedicated field.
	Params map[string]string

	// Collation is the connection collation.
	Collation string

	// Loc is the location used for time.Time values.
	Loc *time.Location

	// MaxAllowedPacket is the maximum packet size allowed.
	MaxAllowedPacket int

	// TLSConfig is the name of the TLS configuration; tls is the
	// configuration itself.
	TLSConfig string
	tls       *tls.Config

	// Timeout is the dial timeout; ReadTimeout and WriteTimeout are the I/O
	// read and write timeouts.
	Timeout, ReadTimeout, WriteTimeout time.Duration

	AllowAllFiles           bool // permit any file for LOAD DATA LOCAL INFILE
	AllowCleartextPasswords bool // permit the cleartext client-side plugin
	AllowNativePasswords    bool // permit the native password authentication method
	AllowOldPasswords       bool // permit the old, insecure password method
	ClientFoundRows         bool // report matching rows instead of changed rows
	ColumnsWithAlias        bool // prefix column names with their table alias
	InterpolateParams       bool // substitute placeholders into the query string
	MultiStatements         bool // allow several statements in one query
	ParseTime               bool // decode time values into time.Time
	Strict                  bool // report warnings as errors
}
 | |
| 
 | |
| // FormatDSN formats the given Config into a DSN string which can be passed to
 | |
| // the driver.
 | |
| func (cfg *Config) FormatDSN() string {
 | |
| 	var buf bytes.Buffer
 | |
| 
 | |
| 	// [username[:password]@]
 | |
| 	if len(cfg.User) > 0 {
 | |
| 		buf.WriteString(cfg.User)
 | |
| 		if len(cfg.Passwd) > 0 {
 | |
| 			buf.WriteByte(':')
 | |
| 			buf.WriteString(cfg.Passwd)
 | |
| 		}
 | |
| 		buf.WriteByte('@')
 | |
| 	}
 | |
| 
 | |
| 	// [protocol[(address)]]
 | |
| 	if len(cfg.Net) > 0 {
 | |
| 		buf.WriteString(cfg.Net)
 | |
| 		if len(cfg.Addr) > 0 {
 | |
| 			buf.WriteByte('(')
 | |
| 			buf.WriteString(cfg.Addr)
 | |
| 			buf.WriteByte(')')
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// /dbname
 | |
| 	buf.WriteByte('/')
 | |
| 	buf.WriteString(cfg.DBName)
 | |
| 
 | |
| 	// [?param1=value1&...¶mN=valueN]
 | |
| 	hasParam := false
 | |
| 
 | |
| 	if cfg.AllowAllFiles {
 | |
| 		hasParam = true
 | |
| 		buf.WriteString("?allowAllFiles=true")
 | |
| 	}
 | |
| 
 | |
| 	if cfg.AllowCleartextPasswords {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&allowCleartextPasswords=true")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?allowCleartextPasswords=true")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if cfg.AllowNativePasswords {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&allowNativePasswords=true")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?allowNativePasswords=true")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if cfg.AllowOldPasswords {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&allowOldPasswords=true")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?allowOldPasswords=true")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if cfg.ClientFoundRows {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&clientFoundRows=true")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?clientFoundRows=true")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&collation=")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?collation=")
 | |
| 		}
 | |
| 		buf.WriteString(col)
 | |
| 	}
 | |
| 
 | |
| 	if cfg.ColumnsWithAlias {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&columnsWithAlias=true")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?columnsWithAlias=true")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if cfg.InterpolateParams {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&interpolateParams=true")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?interpolateParams=true")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if cfg.Loc != time.UTC && cfg.Loc != nil {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&loc=")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?loc=")
 | |
| 		}
 | |
| 		buf.WriteString(url.QueryEscape(cfg.Loc.String()))
 | |
| 	}
 | |
| 
 | |
| 	if cfg.MultiStatements {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&multiStatements=true")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?multiStatements=true")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if cfg.ParseTime {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&parseTime=true")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?parseTime=true")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if cfg.ReadTimeout > 0 {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&readTimeout=")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?readTimeout=")
 | |
| 		}
 | |
| 		buf.WriteString(cfg.ReadTimeout.String())
 | |
| 	}
 | |
| 
 | |
| 	if cfg.Strict {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&strict=true")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?strict=true")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if cfg.Timeout > 0 {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&timeout=")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?timeout=")
 | |
| 		}
 | |
| 		buf.WriteString(cfg.Timeout.String())
 | |
| 	}
 | |
| 
 | |
| 	if len(cfg.TLSConfig) > 0 {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&tls=")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?tls=")
 | |
| 		}
 | |
| 		buf.WriteString(url.QueryEscape(cfg.TLSConfig))
 | |
| 	}
 | |
| 
 | |
| 	if cfg.WriteTimeout > 0 {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&writeTimeout=")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?writeTimeout=")
 | |
| 		}
 | |
| 		buf.WriteString(cfg.WriteTimeout.String())
 | |
| 	}
 | |
| 
 | |
| 	if cfg.MaxAllowedPacket > 0 {
 | |
| 		if hasParam {
 | |
| 			buf.WriteString("&maxAllowedPacket=")
 | |
| 		} else {
 | |
| 			hasParam = true
 | |
| 			buf.WriteString("?maxAllowedPacket=")
 | |
| 		}
 | |
| 		buf.WriteString(strconv.Itoa(cfg.MaxAllowedPacket))
 | |
| 
 | |
| 	}
 | |
| 
 | |
| 	// other params
 | |
| 	if cfg.Params != nil {
 | |
| 		for param, value := range cfg.Params {
 | |
| 			if hasParam {
 | |
| 				buf.WriteByte('&')
 | |
| 			} else {
 | |
| 				hasParam = true
 | |
| 				buf.WriteByte('?')
 | |
| 			}
 | |
| 
 | |
| 			buf.WriteString(param)
 | |
| 			buf.WriteByte('=')
 | |
| 			buf.WriteString(url.QueryEscape(value))
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return buf.String()
 | |
| }
 | |
| 
 | |
| // ParseDSN parses the DSN string to a Config
 | |
| func ParseDSN(dsn string) (cfg *Config, err error) {
 | |
| 	// New config with some default values
 | |
| 	cfg = &Config{
 | |
| 		Loc:       time.UTC,
 | |
| 		Collation: defaultCollation,
 | |
| 	}
 | |
| 
 | |
| 	// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
 | |
| 	// Find the last '/' (since the password or the net addr might contain a '/')
 | |
| 	foundSlash := false
 | |
| 	for i := len(dsn) - 1; i >= 0; i-- {
 | |
| 		if dsn[i] == '/' {
 | |
| 			foundSlash = true
 | |
| 			var j, k int
 | |
| 
 | |
| 			// left part is empty if i <= 0
 | |
| 			if i > 0 {
 | |
| 				// [username[:password]@][protocol[(address)]]
 | |
| 				// Find the last '@' in dsn[:i]
 | |
| 				for j = i; j >= 0; j-- {
 | |
| 					if dsn[j] == '@' {
 | |
| 						// username[:password]
 | |
| 						// Find the first ':' in dsn[:j]
 | |
| 						for k = 0; k < j; k++ {
 | |
| 							if dsn[k] == ':' {
 | |
| 								cfg.Passwd = dsn[k+1 : j]
 | |
| 								break
 | |
| 							}
 | |
| 						}
 | |
| 						cfg.User = dsn[:k]
 | |
| 
 | |
| 						break
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				// [protocol[(address)]]
 | |
| 				// Find the first '(' in dsn[j+1:i]
 | |
| 				for k = j + 1; k < i; k++ {
 | |
| 					if dsn[k] == '(' {
 | |
| 						// dsn[i-1] must be == ')' if an address is specified
 | |
| 						if dsn[i-1] != ')' {
 | |
| 							if strings.ContainsRune(dsn[k+1:i], ')') {
 | |
| 								return nil, errInvalidDSNUnescaped
 | |
| 							}
 | |
| 							return nil, errInvalidDSNAddr
 | |
| 						}
 | |
| 						cfg.Addr = dsn[k+1 : i-1]
 | |
| 						break
 | |
| 					}
 | |
| 				}
 | |
| 				cfg.Net = dsn[j+1 : k]
 | |
| 			}
 | |
| 
 | |
| 			// dbname[?param1=value1&...¶mN=valueN]
 | |
| 			// Find the first '?' in dsn[i+1:]
 | |
| 			for j = i + 1; j < len(dsn); j++ {
 | |
| 				if dsn[j] == '?' {
 | |
| 					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
 | |
| 						return
 | |
| 					}
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 			cfg.DBName = dsn[i+1 : j]
 | |
| 
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if !foundSlash && len(dsn) > 0 {
 | |
| 		return nil, errInvalidDSNNoSlash
 | |
| 	}
 | |
| 
 | |
| 	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
 | |
| 		return nil, errInvalidDSNUnsafeCollation
 | |
| 	}
 | |
| 
 | |
| 	// Set default network if empty
 | |
| 	if cfg.Net == "" {
 | |
| 		cfg.Net = "tcp"
 | |
| 	}
 | |
| 
 | |
| 	// Set default address if empty
 | |
| 	if cfg.Addr == "" {
 | |
| 		switch cfg.Net {
 | |
| 		case "tcp":
 | |
| 			cfg.Addr = "127.0.0.1:3306"
 | |
| 		case "unix":
 | |
| 			cfg.Addr = "/tmp/mysql.sock"
 | |
| 		default:
 | |
| 			return nil, errors.New("default addr for network '" + cfg.Net + "' unknown")
 | |
| 		}
 | |
| 
 | |
| 	}
 | |
| 
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // parseDSNParams parses the DSN "query string"
 | |
| // Values must be url.QueryEscape'ed
 | |
| func parseDSNParams(cfg *Config, params string) (err error) {
 | |
| 	for _, v := range strings.Split(params, "&") {
 | |
| 		param := strings.SplitN(v, "=", 2)
 | |
| 		if len(param) != 2 {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		// cfg params
 | |
| 		switch value := param[1]; param[0] {
 | |
| 
 | |
| 		// Disable INFILE whitelist / enable all files
 | |
| 		case "allowAllFiles":
 | |
| 			var isBool bool
 | |
| 			cfg.AllowAllFiles, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// Use cleartext authentication mode (MySQL 5.5.10+)
 | |
| 		case "allowCleartextPasswords":
 | |
| 			var isBool bool
 | |
| 			cfg.AllowCleartextPasswords, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// Use native password authentication
 | |
| 		case "allowNativePasswords":
 | |
| 			var isBool bool
 | |
| 			cfg.AllowNativePasswords, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// Use old authentication mode (pre MySQL 4.1)
 | |
| 		case "allowOldPasswords":
 | |
| 			var isBool bool
 | |
| 			cfg.AllowOldPasswords, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// Switch "rowsAffected" mode
 | |
| 		case "clientFoundRows":
 | |
| 			var isBool bool
 | |
| 			cfg.ClientFoundRows, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// Collation
 | |
| 		case "collation":
 | |
| 			cfg.Collation = value
 | |
| 			break
 | |
| 
 | |
| 		case "columnsWithAlias":
 | |
| 			var isBool bool
 | |
| 			cfg.ColumnsWithAlias, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// Compression
 | |
| 		case "compress":
 | |
| 			return errors.New("compression not implemented yet")
 | |
| 
 | |
| 		// Enable client side placeholder substitution
 | |
| 		case "interpolateParams":
 | |
| 			var isBool bool
 | |
| 			cfg.InterpolateParams, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// Time Location
 | |
| 		case "loc":
 | |
| 			if value, err = url.QueryUnescape(value); err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 			cfg.Loc, err = time.LoadLocation(value)
 | |
| 			if err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 		// multiple statements in one query
 | |
| 		case "multiStatements":
 | |
| 			var isBool bool
 | |
| 			cfg.MultiStatements, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// time.Time parsing
 | |
| 		case "parseTime":
 | |
| 			var isBool bool
 | |
| 			cfg.ParseTime, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// I/O read Timeout
 | |
| 		case "readTimeout":
 | |
| 			cfg.ReadTimeout, err = time.ParseDuration(value)
 | |
| 			if err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 		// Strict mode
 | |
| 		case "strict":
 | |
| 			var isBool bool
 | |
| 			cfg.Strict, isBool = readBool(value)
 | |
| 			if !isBool {
 | |
| 				return errors.New("invalid bool value: " + value)
 | |
| 			}
 | |
| 
 | |
| 		// Dial Timeout
 | |
| 		case "timeout":
 | |
| 			cfg.Timeout, err = time.ParseDuration(value)
 | |
| 			if err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 		// TLS-Encryption
 | |
| 		case "tls":
 | |
| 			boolValue, isBool := readBool(value)
 | |
| 			if isBool {
 | |
| 				if boolValue {
 | |
| 					cfg.TLSConfig = "true"
 | |
| 					cfg.tls = &tls.Config{}
 | |
| 				} else {
 | |
| 					cfg.TLSConfig = "false"
 | |
| 				}
 | |
| 			} else if vl := strings.ToLower(value); vl == "skip-verify" {
 | |
| 				cfg.TLSConfig = vl
 | |
| 				cfg.tls = &tls.Config{InsecureSkipVerify: true}
 | |
| 			} else {
 | |
| 				name, err := url.QueryUnescape(value)
 | |
| 				if err != nil {
 | |
| 					return fmt.Errorf("invalid value for TLS config name: %v", err)
 | |
| 				}
 | |
| 
 | |
| 				if tlsConfig, ok := tlsConfigRegister[name]; ok {
 | |
| 					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
 | |
| 						host, _, err := net.SplitHostPort(cfg.Addr)
 | |
| 						if err == nil {
 | |
| 							tlsConfig.ServerName = host
 | |
| 						}
 | |
| 					}
 | |
| 
 | |
| 					cfg.TLSConfig = name
 | |
| 					cfg.tls = tlsConfig
 | |
| 				} else {
 | |
| 					return errors.New("invalid value / unknown config name: " + name)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 		// I/O write Timeout
 | |
| 		case "writeTimeout":
 | |
| 			cfg.WriteTimeout, err = time.ParseDuration(value)
 | |
| 			if err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 		case "maxAllowedPacket":
 | |
| 			cfg.MaxAllowedPacket, err = strconv.Atoi(value)
 | |
| 			if err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 		default:
 | |
| 			// lazy init
 | |
| 			if cfg.Params == nil {
 | |
| 				cfg.Params = make(map[string]string)
 | |
| 			}
 | |
| 
 | |
| 			if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return
 | |
| }
 |