diff --git a/models/auth/two_factor_test.go b/models/auth/two_factor_test.go index 36e0404ae2..3787f6ca95 100644 --- a/models/auth/two_factor_test.go +++ b/models/auth/two_factor_test.go @@ -7,6 +7,7 @@ import ( "forgejo.org/models/unittest" + "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -32,3 +33,30 @@ func TestHasTwoFactorByUID(t *testing.T) { assert.True(t, ok) }) } + +func TestNewTwoFactor(t *testing.T) { + require.NoError(t, unittest.PrepareTestDatabase()) + + otpKey, err := totp.Generate(totp.GenerateOpts{ + SecretSize: 40, + Issuer: "forgejo-test", + AccountName: "user2", + }) + require.NoError(t, err) + + t.Run("Transaction failed", func(t *testing.T) { + reset := unittest.SetFaultInjector(2) + require.ErrorIs(t, NewTwoFactor(t.Context(), &TwoFactor{UID: 44}, otpKey.Secret()), unittest.ErrFaultInjected) + reset() + + unittest.AssertExistsIf(t, false, &TwoFactor{UID: 44}) + }) + + t.Run("Normal", func(t *testing.T) { + reset := unittest.SetFaultInjector(4) + require.NoError(t, NewTwoFactor(t.Context(), &TwoFactor{UID: 44}, otpKey.Secret())) + reset() + + unittest.AssertExistsIf(t, true, &TwoFactor{UID: 44}) + }) +} diff --git a/models/unittest/fault_injector.go b/models/unittest/fault_injector.go new file mode 100644 index 0000000000..20a40cce45 --- /dev/null +++ b/models/unittest/fault_injector.go @@ -0,0 +1,50 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later +package unittest + +import ( + "context" + "errors" + + "xorm.io/xorm/contexts" +) + +var ( + faultInjectorCount int64 + faultInjectorNumQueries int64 = -1 + ErrFaultInjected = errors.New("nobody expects a fault injection") +) + +type faultInjectorHook struct{} + +var _ contexts.Hook = &faultInjectorHook{} + +func (faultInjectorHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) { + if faultInjectorNumQueries == -1 { + return c.Ctx, nil + } + + // Always allow ROLLBACK, we always want to allow for transactions to get cancelled. + if faultInjectorCount == faultInjectorNumQueries && c.SQL != "ROLLBACK" { + return c.Ctx, ErrFaultInjected + } + + faultInjectorCount++ + + return c.Ctx, nil +} + +func (faultInjectorHook) AfterProcess(*contexts.ContextHook) error { + return nil +} + +// Allow `numQueries` before all database queries will fail until the +// returning function is executed. +func SetFaultInjector(numQueries int64) func() { + faultInjectorNumQueries = numQueries + + return func() { + faultInjectorNumQueries = -1 + faultInjectorCount = 0 + } +} diff --git a/models/unittest/fault_injector_test.go b/models/unittest/fault_injector_test.go new file mode 100644 index 0000000000..7b1e601d51 --- /dev/null +++ b/models/unittest/fault_injector_test.go @@ -0,0 +1,86 @@ +// Copyright 2025 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: GPL-3.0-or-later +package unittest + +import ( + "testing" + + "github.com/stretchr/testify/require" + "xorm.io/xorm/contexts" +) + +func TestFaultInjector(t *testing.T) { + faultInjector := faultInjectorHook{} + c := &contexts.ContextHook{ + Ctx: t.Context(), + SQL: "Hello, 世界", // We don't check for valid SQL anyway. + } + + t.Run("Should not block", func(t *testing.T) { + // Currently no fault injection is set, so this should go through. + for range 100 { + _, err := faultInjector.BeforeProcess(c) + require.NoError(t, err) + } + }) + + t.Run("Reset", func(t *testing.T) { + // Okay only allow one query to go through. + reset := SetFaultInjector(1) + + // Do the only query. + _, err := faultInjector.BeforeProcess(c) + require.NoError(t, err) + + // Now we reset, we don't check the blocking behavior yet. We first + // must know that we can safely reset. + reset() + + // This should go through. + for range 100 { + _, err := faultInjector.BeforeProcess(c) + require.NoError(t, err) + } + }) + + t.Run("Blocking", func(t *testing.T) { + // Okay only allow one query to go through. + reset := SetFaultInjector(1) + + // Do the only query. + _, err := faultInjector.BeforeProcess(c) + require.NoError(t, err) + + // Any query now will return a error. + for range 100 { + _, err := faultInjector.BeforeProcess(c) + require.ErrorIs(t, err, ErrFaultInjected) + } + + // Ah but there's a exemption for `ROLLBACK`. + _, err = faultInjector.BeforeProcess(&contexts.ContextHook{Ctx: t.Context(), SQL: "ROLLBACK"}) + require.NoError(t, err) + + reset() + }) + + t.Run("Number of queries", func(t *testing.T) { + // For funsies lets test a bunch of max numbers of queries. + for i := range int64(1024) { + // Allow i queries + reset := SetFaultInjector(i) + + // Make i queries. + for range i { + _, err := faultInjector.BeforeProcess(c) + require.NoError(t, err) + } + + // After i'th query it returns a error. + _, err := faultInjector.BeforeProcess(c) + require.ErrorIs(t, err, ErrFaultInjected) + + reset() + } + }) +} diff --git a/models/unittest/testdb.go b/models/unittest/testdb.go index 29ec82c55f..795b1f8719 100644 --- a/models/unittest/testdb.go +++ b/models/unittest/testdb.go @@ -233,6 +233,7 @@ func CreateTestEngine(opts FixturesOptions) error { return err } x.SetMapper(names.GonicMapper{}) + x.AddHook(faultInjectorHook{}) db.SetDefaultEngine(context.Background(), x) if err = db.SyncAllTables(); err != nil {