// nonce_test.go
package application

import (
	"context"
	"encoding/base64"
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestNonceMiddleware_GeneratesNonce(t *testing.T) {
	var capturedNonce string

	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		capturedNonce = NonceFromContext(r.Context())
	})

	mw := NonceMiddleware()
	handler := mw(inner)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest("GET", "/", nil)
	handler.ServeHTTP(rec, req)

	if capturedNonce == "" {
		t.Error("expected nonce to be set in context, got empty string")
	}
}

// TestNonceFromContext_ReturnsNonce checks that NonceFromContext returns the
// exact value stored in the context under the package's nonceKey.
func TestNonceFromContext_ReturnsNonce(t *testing.T) {
	const want = "test-nonce-value"
	ctx := context.WithValue(context.Background(), nonceKey{}, want)

	if got := NonceFromContext(ctx); got != want {
		t.Errorf("expected 'test-nonce-value', got %q", got)
	}
}

// TestNonceFromContext_ReturnsEmptyForMissing checks that NonceFromContext
// yields the empty string when the context carries no nonce at all.
func TestNonceFromContext_ReturnsEmptyForMissing(t *testing.T) {
	if got := NonceFromContext(context.Background()); got != "" {
		t.Errorf("expected empty string for missing nonce, got %q", got)
	}
}

func TestNonce_IsBase64Encoded_CorrectLength(t *testing.T) {
	var capturedNonce string

	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		capturedNonce = NonceFromContext(r.Context())
	})

	mw := NonceMiddleware()
	handler := mw(inner)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest("GET", "/", nil)
	handler.ServeHTTP(rec, req)

	// Nonce should be valid base64
	decoded, err := base64.StdEncoding.DecodeString(capturedNonce)
	if err != nil {
		t.Fatalf("nonce is not valid base64: %v", err)
	}

	// Should be 16 bytes decoded (128 bits)
	if len(decoded) != 16 {
		t.Errorf("expected 16 decoded bytes, got %d", len(decoded))
	}

	// base64 of 16 bytes = 24 characters (with padding)
	if len(capturedNonce) != 24 {
		t.Errorf("expected 24 character base64 string, got %d characters: %q", len(capturedNonce), capturedNonce)
	}
}

// TestNonce_DifferentRequestsGetDifferentNonces checks that NonceMiddleware
// hands a fresh, non-empty nonce to the wrapped handler on every request.
//
// An empty nonce is reported as an error rather than stored in the
// uniqueness set. Otherwise one request with no nonce would still take a
// slot in the set and could make the uniqueness count pass by accident. The
// number of inner-handler invocations is also checked, so a middleware that
// skips the wrapped handler cannot pass silently.
func TestNonce_DifferentRequestsGetDifferentNonces(t *testing.T) {
	const requests = 10

	seen := make(map[string]struct{}, requests)
	calls := 0

	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		calls++
		nonce := NonceFromContext(r.Context())
		if nonce == "" {
			t.Error("expected nonce to be set in context, got empty string")
			return
		}
		seen[nonce] = struct{}{}
	})

	handler := NonceMiddleware()(inner)

	// Issue several independent requests; each should receive its own nonce.
	for i := 0; i < requests; i++ {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		handler.ServeHTTP(rec, req)
	}

	if calls != requests {
		t.Errorf("expected inner handler to run %d times, got %d", requests, calls)
	}
	if len(seen) != requests {
		t.Errorf("expected %d unique nonces, got %d", requests, len(seen))
	}
}