// Page header carried over from the web export: Home | Up | PDF | Prof. Dr. Ingo Claßen

package linearizability

import (
	"context"
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/anishathalye/porcupine"
	"github.com/redis/go-redis/v9"
)

// ---------------------------------------------------------------------------
// Configuration (overridable via environment variables)
// ---------------------------------------------------------------------------

// env returns the value of the environment variable key, or def when the
// variable is unset or set to the empty string.
func env(key, def string) string {
	val := os.Getenv(key)
	if val == "" {
		return def
	}
	return val
}

// envInt returns the environment variable key parsed as a decimal integer.
// It falls back to def when the variable is unset, empty, or not a valid
// integer according to strconv.Atoi.
func envInt(key string, def int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return def
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return def
	}
	return parsed
}

// encode renders the register value v as "<v>|<padding>", where the padding is
// valueSize repetitions of "x". The padding inflates every write so that
// replicating it is expensive; decode recovers the integer from the front.
func encode(v int) string {
	parts := []string{
		strconv.Itoa(v),
		strings.Repeat("x", valueSize),
	}
	return strings.Join(parts, "|")
}

// decode parses the integer that precedes the first '|' in s. A string without
// a '|' is parsed as a whole. The error is whatever strconv.Atoi reports for
// the extracted prefix.
func decode(s string) (int, error) {
	// Cut returns s unchanged as the head when no separator is present.
	head, _, _ := strings.Cut(s, "|")
	return strconv.Atoi(head)
}

// Run configuration. Every value is read from the environment once, at package
// initialization, with the default shown when the variable is unset or empty.
var (
	// PRIMARY_ADDR is the address of the primary. REPLICA_ADDRS is a
	// comma-separated list of replica addresses; it is split on "," only, so
	// surrounding whitespace and empty entries are still present in the slice.
	primaryAddr  = env("PRIMARY_ADDR", "localhost:6379")
	replicaAddrs = strings.Split(env("REPLICA_ADDRS", "localhost:6381,localhost:6382,localhost:6383"), ",")

	// READ_TARGET = "replica" (default) reads from the replicas -> expected to
	// reveal stale reads and FAIL the linearizability check.
	// READ_TARGET = "primary" reads only from the primary -> expected to PASS.
	readTarget = env("READ_TARGET", "replica")

	// The single key (a register) we hammer with reads and writes.
	regKey = env("REG_KEY", "porcupine:register")

	// Size in bytes of the payload appended to each written value. Large values
	// are the trick that makes the staleness observable: the primary acks a
	// SET from memory instantly, but the replicas must stream the whole payload
	// over the link. A tight burst of large writes makes the replicas fall
	// seconds behind, so readers observe counters older than writes that have
	// already RETURNED on the primary -> a real-time linearizability violation.
	valueSize = envInt("VALUE_SIZE", 512*1024) // 512 KiB per write

	// Number of writes in the burst (NUM_WRITES). Writes are issued
	// back-to-back (no gap) to outrun replication.
	numWrites = envInt("NUM_WRITES", 1500)
)

// ---------------------------------------------------------------------------
// Porcupine model: a single integer register.
//   write(v): always OK, sets state := v
//   read():   OK iff the observed value equals the current linearized state
// ---------------------------------------------------------------------------

// regInput is the operation input handed to the register model: it says
// whether the operation is a read or a write and, for a write, which value
// was written.
type regInput struct {
	isRead bool // true for a read, false for a write
	value  int // value for a write; ignored for a read
}

// registerModel is the sequential specification of the register used by the
// checker. The state is the current integer value.
var registerModel = porcupine.Model{
	// The register starts at 0; the test seeds the key to 0 before the run.
	Init: func() interface{} { return 0 },
	// Step applies one operation to the state. A write is always legal and
	// replaces the state; a read is legal only when it returned the state.
	Step: func(state, input, output interface{}) (bool, interface{}) {
		op := input.(regInput)
		current := state.(int)
		if !op.isRead {
			return true, op.value
		}
		observed := output.(int)
		return observed == current, current
	},
	// DescribeOperation renders an operation for the visualization.
	DescribeOperation: func(input, output interface{}) string {
		switch op := input.(regInput); {
		case op.isRead:
			return fmt.Sprintf("read() -> %d", output.(int))
		default:
			return fmt.Sprintf("write(%d)", op.value)
		}
	},
}

// ---------------------------------------------------------------------------
// Test
// ---------------------------------------------------------------------------

// TestRedisLinearizability records a concurrent history against Redis and
// checks it with Porcupine.
//
// One writer goroutine issues numWrites increasing values to the primary
// while one reader goroutine per read client polls the same key. Every
// completed operation is recorded with its call and return time, and the
// resulting history is checked against registerModel. With READ_TARGET set to
// "primary" the reads go to the primary; with any other value they go to the
// replicas listed in REPLICA_ADDRS.
func TestRedisLinearizability(t *testing.T) {
	ctx := context.Background()

	primary := redis.NewClient(&redis.Options{Addr: primaryAddr})
	defer primary.Close()

	if err := primary.Ping(ctx).Err(); err != nil {
		t.Fatalf("cannot reach primary at %s: %v (is docker compose up?)", primaryAddr, err)
	}

	// Build the pool of read clients. The cleanup is registered before the
	// first replica client is created: t.Fatalf unwinds through deferred
	// calls, so a failure part-way through the loop still closes every client
	// opened so far.
	var readClients []*redis.Client
	defer func() {
		if readTarget == "primary" {
			return // the only read client is primary, closed by its own defer
		}
		for _, rc := range readClients {
			rc.Close() // best-effort teardown; a close error cannot change the verdict
		}
	}()
	if readTarget == "primary" {
		readClients = []*redis.Client{primary}
		t.Logf("READ_TARGET=primary -> reading from primary only (expected: LINEARIZABLE)")
	} else {
		for _, a := range replicaAddrs {
			addr := strings.TrimSpace(a)
			if addr == "" {
				// An empty entry (e.g. a trailing comma) must not become a
				// client: go-redis substitutes its default address
				// localhost:6379 for an empty Addr, which would silently read
				// from the default primary instead of a replica.
				continue
			}
			rc := redis.NewClient(&redis.Options{Addr: addr})
			// Track the client before pinging so the deferred cleanup also
			// closes it when the ping fails.
			readClients = append(readClients, rc)
			if err := rc.Ping(ctx).Err(); err != nil {
				t.Fatalf("cannot reach replica at %s: %v", addr, err)
			}
		}
		if len(readClients) == 0 {
			// Without readers the history contains only writes and would
			// pass trivially.
			t.Fatalf("READ_TARGET=%s but REPLICA_ADDRS contains no addresses", readTarget)
		}
		t.Logf("READ_TARGET=replica -> reading from %d replicas (expected: STALE READS / NON-LINEARIZABLE)", len(readClients))
	}

	// Seed the register to the model's initial value and wait for replicas to
	// observe it, so the history starts from a known, consistent state.
	if err := primary.Set(ctx, regKey, encode(0), 0).Err(); err != nil {
		t.Fatalf("seed write failed: %v", err)
	}
	waitForValue(t, ctx, readClients, regKey, 0, 15*time.Second)
	t.Logf("payload size per write: %d bytes; writes: %d (burst, no gap)", valueSize, numWrites)

	// Pause between reads; resolved once instead of on every loop iteration.
	readGap := time.Duration(envInt("READ_GAP_MS", 2)) * time.Millisecond

	// All timestamps are offsets from a single start instant.
	start := time.Now()
	now := func() int64 { return int64(time.Since(start)) }

	var (
		mu     sync.Mutex
		events []porcupine.Operation
	)
	record := func(clientID int, in regInput, out int, call, ret int64) {
		mu.Lock()
		events = append(events, porcupine.Operation{
			ClientId: clientID,
			Input:    in,
			Output:   out,
			Call:     call,
			Return:   ret,
		})
		mu.Unlock()
	}

	var wg sync.WaitGroup
	done := make(chan struct{})

	// Number of writes actually recorded. Only the writer goroutine touches
	// it, and it is read after wg.Wait(), so no extra locking is needed.
	written := 0

	// Writer: client 0, writes strictly increasing values to the primary.
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer close(done) // tells the readers to stop, also on early failure
		for v := 1; v <= numWrites; v++ {
			payload := encode(v) // built before the call timestamp is taken
			call := now()
			err := primary.Set(ctx, regKey, payload, 0).Err()
			ret := now()
			if err != nil {
				t.Errorf("write(%d) failed: %v", v, err)
				return
			}
			record(0, regInput{isRead: false, value: v}, v, call, ret)
			written++
		}
	}()

	// Readers: one goroutine per read client, looping until the writer is done.
	for i, rc := range readClients {
		wg.Add(1)
		clientID := i + 1
		rc := rc
		go func() {
			defer wg.Done()
			for {
				select {
				case <-done:
					return
				default:
				}
				call := now()
				val, err := getInt(ctx, rc, regKey)
				ret := now()
				if err != nil && !errors.Is(err, redis.Nil) {
					// Transient failure: drop the sample, but pause so a
					// failing client is not polled in a tight loop.
					time.Sleep(readGap)
					continue
				}
				// A missing key (redis.Nil) is recorded as a read of 0.
				record(clientID, regInput{isRead: true}, val, call, ret)
				time.Sleep(readGap)
			}
		}()
	}

	wg.Wait()

	// Report the writes that were recorded, which is fewer than numWrites
	// when the writer stopped early.
	t.Logf("collected %d operations (%d writes, rest reads)", len(events), written)

	// Run the linearizability check with visualization info.
	res, info := porcupine.CheckOperationsVerbose(registerModel, events, 10*time.Second)

	switch res {
	case porcupine.Ok:
		t.Logf("RESULT: history is LINEARIZABLE (OK)")
	case porcupine.Illegal:
		dumpViz(t, info)
		t.Errorf("RESULT: history is NON-LINEARIZABLE (stale reads detected)")
	case porcupine.Unknown:
		t.Errorf("RESULT: check timed out (UNKNOWN) — try fewer operations")
	}
}

// dumpViz writes an interactive Porcupine HTML visualization of info to
// "linearizability.html", relative to the current working directory, and logs
// where it ended up. Failures are logged rather than failing the test.
func dumpViz(t *testing.T, info porcupine.LinearizationInfo) {
	t.Helper()
	path := "linearizability.html"
	if err := porcupine.VisualizePath(registerModel, info, path); err != nil {
		t.Logf("could not write visualization: %v", err)
		return
	}
	wd, err := os.Getwd()
	if err != nil {
		// The file was written; only the absolute location is unknown. Log
		// the relative path instead of a misleading "/linearizability.html".
		t.Logf("visualization written to %s (open in a browser)", path)
		return
	}
	t.Logf("visualization written to %s/%s (open in a browser)", wd, path)
}

// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------

// getInt fetches key through client c and decodes the stored payload into the
// register's integer value. An error from the GET (including redis.Nil for a
// missing key) is returned unchanged together with 0.
func getInt(ctx context.Context, c *redis.Client, key string) (int, error) {
	raw, err := c.Get(ctx, key).Result()
	if err == nil {
		return decode(raw)
	}
	return 0, err
}

// waitForValue polls key on each client in turn, every 20ms, until the client
// returns want. The timeout is one deadline shared by all clients, not a
// per-client budget. When the deadline passes first, the test is failed with
// the index of the lagging client and the last value and error it returned.
func waitForValue(t *testing.T, ctx context.Context, clients []*redis.Client, key string, want int, timeout time.Duration) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for i, c := range clients {
		for {
			v, err := getInt(ctx, c, key)
			if err == nil && v == want {
				break
			}
			if time.Now().After(deadline) {
				t.Fatalf("replicas did not converge to %s=%d within %s (client #%d: last value %d, last error: %v)",
					key, want, timeout, i, v, err)
			}
			time.Sleep(20 * time.Millisecond)
		}
	}
}