mirror of
				https://codeberg.org/forgejo/forgejo.git
				synced 2025-11-04 08:21:11 +00:00 
			
		
		
		
	- gopkg.in/yaml.v3 is archived and unmaintained. - go.yaml.in/yaml/v3 is a compatible fork under the umbrella of https://yaml.org/ ### Tests There is no need for more tests than those already provided: this is like an upgrade to a minor version, only from a fork. I browsed the changes and there are some bug fixes; they all seem reasonably minimal. It is not one of those forks that went crazy with breaking changes 😁 And there is a non-zero chance that [a bug that matters to Forgejo Actions](https://github.com/yaml/go-yaml/issues/76) is fixed there. It is rare and can wait, but it did happen on Codeberg. Reviewed-on: https://codeberg.org/forgejo/forgejo/pulls/8956 Reviewed-by: oliverpool <oliverpool@noreply.codeberg.org> Co-authored-by: Earl Warren <contact@earl-warren.org> Co-committed-by: Earl Warren <contact@earl-warren.org>
		
			
				
	
	
		
			211 lines
		
	
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			211 lines
		
	
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright 2025 The Forgejo Authors. All rights reserved.
 | 
						|
// SPDX-License-Identifier: GPL-3.0-or-later
 | 
						|
 | 
						|
package unittest
 | 
						|
 | 
						|
import (
 | 
						|
	"database/sql"
 | 
						|
	"encoding/hex"
 | 
						|
	"encoding/json" //nolint:depguard
 | 
						|
	"fmt"
 | 
						|
	"os"
 | 
						|
	"path/filepath"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"forgejo.org/modules/container"
 | 
						|
 | 
						|
	"go.yaml.in/yaml/v3"
 | 
						|
)
 | 
						|
 | 
						|
type insertSQL struct {
 | 
						|
	statement string
 | 
						|
	values    []any
 | 
						|
}
 | 
						|
 | 
						|
type fixtureFile struct {
 | 
						|
	name       string
 | 
						|
	insertSQLs []insertSQL
 | 
						|
}
 | 
						|
 | 
						|
type loader struct {
 | 
						|
	db      *sql.DB
 | 
						|
	dialect string
 | 
						|
 | 
						|
	fixtureFiles []*fixtureFile
 | 
						|
}
 | 
						|
 | 
						|
// newFixtureLoader creates a loader that writes to db using the SQL dialect
// named by dialect ("sqlite3", "mysql" or "postgres").
//
// Each entry of fixturePaths is either a single fixture file or a directory.
// For a directory, every entry directly inside it that is not itself a
// directory is loaded as a fixture file; subdirectories are not descended
// into. The files are parsed here, so a malformed fixture is reported by this
// function rather than by Load.
//
// allTableNames lists the tables that Load must clean. Any of them that no
// fixture file targets gets an empty fixture, so Load still empties it.
// allTableNames itself is left unmodified.
func newFixtureLoader(db *sql.DB, dialect string, fixturePaths []string, allTableNames container.Set[string]) (*loader, error) {
	l := &loader{
		db:           db,
		dialect:      dialect,
		fixtureFiles: []*fixtureFile{},
	}

	// Work on a copy: removing tables from the caller's set would be a
	// surprising side effect.
	tablesWithoutFixture := make(container.Set[string], len(allTableNames))
	for table := range allTableNames {
		tablesWithoutFixture.Add(table)
	}

	// addFixtureFile parses one fixture file and records that its table has
	// fixtures. Files passed directly and files found in a directory both go
	// through here, so they are treated identically.
	addFixtureFile := func(path string) error {
		ff, err := l.buildFixtureFile(path)
		if err != nil {
			return err
		}
		l.fixtureFiles = append(l.fixtureFiles, ff)
		tablesWithoutFixture.Remove(ff.name)
		return nil
	}

	for _, fixturePath := range fixturePaths {
		stat, err := os.Stat(fixturePath)
		if err != nil {
			return nil, err
		}

		if !stat.IsDir() {
			if err := addFixtureFile(fixturePath); err != nil {
				return nil, err
			}
			continue
		}

		// A directory: read its files and use those as fixture files.
		entries, err := os.ReadDir(fixturePath)
		if err != nil {
			return nil, err
		}
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			if err := addFixtureFile(filepath.Join(fixturePath, entry.Name())); err != nil {
				return nil, err
			}
		}
	}

	// Tables without fixtures can still be written to by tests. Give each an
	// empty fixture so that Load cleans them as well.
	for table := range tablesWithoutFixture.Seq() {
		l.fixtureFiles = append(l.fixtureFiles, &fixtureFile{name: table})
	}

	return l, nil
}
 | 
						|
 | 
						|
// quoteKeyword returns keyword quoted as an SQL identifier for l.dialect:
//   - "sqlite3": wrapped in double quotes;
//   - "mysql": wrapped in backticks;
//   - "postgres": each dot-separated part is double-quoted separately, so
//     "schema.table" becomes "schema"."table".
//
// A quote character occurring inside the identifier is doubled, which is how
// it is escaped inside a quoted identifier. For any other dialect the literal
// string "invalid" is returned, so statements built from it fail.
func (l *loader) quoteKeyword(keyword string) string {
	// quote wraps ident in q, doubling any q already inside it.
	quote := func(ident, q string) string {
		return q + strings.ReplaceAll(ident, q, q+q) + q
	}

	switch l.dialect {
	case "sqlite3":
		return quote(keyword, `"`)
	case "mysql":
		return quote(keyword, "`")
	case "postgres":
		parts := strings.Split(keyword, ".")
		for i, p := range parts {
			parts[i] = quote(p, `"`)
		}
		return strings.Join(parts, ".")
	default:
		return "invalid"
	}
}
 | 
						|
 | 
						|
// placeholder returns the bind-parameter marker for argument position index.
// Only the "postgres" dialect numbers its markers ("$1", "$2", ...); every
// other dialect uses a plain "?" and ignores index.
func (l *loader) placeholder(index int) string {
	switch l.dialect {
	case "postgres":
		return fmt.Sprintf("$%d", index)
	default:
		return "?"
	}
}
 | 
						|
 | 
						|
// buildFixtureFile parses the YAML fixture at fixturePath and turns each of
// its records into an INSERT statement.
//
// The file is decoded as a YAML sequence of mappings: one mapping per row,
// keyed by column name. The table name is the file's base name without its
// extension (e.g. "fixtures/user.yml" -> "user"). Every value is bound as a
// statement argument, with two conversions:
//   - a string starting with "0x" is hex-decoded into a []byte;
//   - a YAML sequence is stored as its JSON encoding (a string).
//
// An empty file, or one holding only comments, yields a fixture with no
// statements instead of an error.
func (l *loader) buildFixtureFile(fixturePath string) (*fixtureFile, error) {
	data, err := os.ReadFile(fixturePath)
	if err != nil {
		return nil, err
	}

	// yaml.Unmarshal leaves records nil for an empty document, whereas
	// Decoder.Decode would fail with io.EOF.
	var records []map[string]any
	if err := yaml.Unmarshal(data, &records); err != nil {
		return nil, fmt.Errorf("cannot parse fixture file %s: %w", fixturePath, err)
	}

	tableName := filepath.Base(strings.TrimSuffix(fixturePath, filepath.Ext(fixturePath)))
	fixture := &fixtureFile{
		name:       tableName,
		insertSQLs: make([]insertSQL, 0, len(records)),
	}
	quotedTable := l.quoteKeyword(tableName)

	for _, record := range records {
		columns := make([]string, 0, len(record))
		placeholders := make([]string, 0, len(record))
		values := make([]any, 0, len(record))

		for key, value := range record {
			switch v := value.(type) {
			case string:
				// "0x..." strings hold binary column values.
				if hexDigits, ok := strings.CutPrefix(v, "0x"); ok {
					decoded, err := hex.DecodeString(hexDigits)
					if err != nil {
						return nil, fmt.Errorf("fixture file %s: column %s: %w", fixturePath, key, err)
					}
					value = decoded
				}
			case []any:
				// Sequences are stored as JSON text.
				encoded, err := json.Marshal(v)
				if err != nil {
					return nil, fmt.Errorf("fixture file %s: column %s: %w", fixturePath, key, err)
				}
				value = string(encoded)
			}

			columns = append(columns, l.quoteKeyword(key))
			values = append(values, value)
			// Placeholder positions start at 1 and follow the value order.
			placeholders = append(placeholders, l.placeholder(len(values)))
		}

		fixture.insertSQLs = append(fixture.insertSQLs, insertSQL{
			statement: fmt.Sprintf(
				"INSERT INTO %s (%s) VALUES (%s)",
				quotedTable,
				strings.Join(columns, ", "),
				strings.Join(placeholders, ", "),
			),
			values: values,
		})
	}

	return fixture, nil
}
 | 
						|
 | 
						|
// Load resets the database to the fixture state inside a single transaction.
//
// Fixtures are processed in order. Each fixture's table is emptied with
// DELETE FROM, at most once per table even when several fixtures name the
// same table, and then the fixture's INSERT statements are executed. If any
// statement fails, its error is returned and the transaction is rolled back,
// so nothing is committed.
func (l *loader) Load() error {
	tx, err := l.db.Begin()
	if err != nil {
		return fmt.Errorf("cannot begin transaction: %w", err)
	}
	// Roll back on any early return. After a successful Commit, Rollback is a
	// no-op that returns sql.ErrTxDone, so its error is deliberately ignored.
	defer func() {
		_ = tx.Rollback()
	}()

	tableDeleted := make(container.Set[string])
	for _, fixture := range l.fixtureFiles {
		if !tableDeleted.Contains(fixture.name) {
			if _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", l.quoteKeyword(fixture.name))); err != nil {
				return fmt.Errorf("cannot delete table %s: %w", fixture.name, err)
			}
			tableDeleted.Add(fixture.name)
		}

		for _, stmt := range fixture.insertSQLs {
			if _, err := tx.Exec(stmt.statement, stmt.values...); err != nil {
				// Use %v for the arguments: %q would print integer values as
				// quoted runes (65 -> 'A') and garble the message.
				return fmt.Errorf("cannot insert %q with values %v: %w", stmt.statement, stmt.values, err)
			}
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("cannot commit transaction: %w", err)
	}
	return nil
}
 |