package redis

import (
	"context"
	"fmt"
	"io"
	"net"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/internal/proto"
)

// ctx is the shared context passed to every stub client call in these
// benchmarks; context.TODO is never cancelled and carries no deadline.
var ctx context.Context = context.TODO()

// ClientStub wraps a Cmdable (a Client built by NewClientStub or a
// ClusterClient built by NewClusterClientStub) together with the canned reply
// bytes that every one of its stub connections replays.
type ClientStub struct {
	Cmdable
	// resp is handed to each ConnStub created by stubConn and is served after
	// that connection's init bytes, looping back to the start when exhausted.
	resp []byte
}

// initHello is served by every stub connection before its per-stub reply:
// a RESP3 map with one entry, "proto" => 3. Presumably this satisfies the
// client's HELLO handshake on connect — confirm against the client's init path.
var initHello = []byte("%1\r\n" + "+proto\r\n" + ":3\r\n")

// NewClientStub returns a single-node Client (pool size 128) whose dialer
// never touches the network: every new connection is a ConnStub that first
// serves initHello and then replays resp indefinitely.
func NewClientStub(resp []byte) *ClientStub {
	stub := &ClientStub{resp: resp}

	dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
		return stub.stubConn(initHello), nil
	}
	stub.Cmdable = NewClient(&Options{
		PoolSize: 128,
		Dialer:   dial,
	})
	return stub
}

// NewClusterClientStub returns a ClusterClient (pool size 128) whose node
// connections are ConnStubs serving initHello and then replaying resp. The
// slot table is fixed: a single node "127.0.0.1:6379" owns slots 0-16383.
//
// NOTE(review): the COMMAND table is fetched once from a real server at
// ":6379" through a throw-away Client. If no server is reachable, the
// installed cache returns whatever error Command produced — confirm the
// cluster benchmarks tolerate that.
func NewClusterClientStub(resp []byte) *ClientStub {
	stub := &ClientStub{resp: resp}

	dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
		return stub.stubConn(initHello), nil
	}
	slots := func(_ context.Context) ([]ClusterSlot, error) {
		owner := ClusterNode{Addr: "127.0.0.1:6379"}
		return []ClusterSlot{{Start: 0, End: 16383, Nodes: []ClusterNode{owner}}}, nil
	}
	cc := NewClusterClient(&ClusterOptions{
		PoolSize:     128,
		Addrs:        []string{":6379"},
		Dialer:       dial,
		ClusterSlots: slots,
	})

	// Snapshot the command table once and serve that snapshot (or its error)
	// from the cache for the lifetime of the stub.
	probe := NewClient(&Options{Addr: ":6379"})
	info, infoErr := probe.Command(ctx).Result()
	_ = probe.Close() // best-effort: the probe client is discarded either way
	cc.cmdsInfoCache = newCmdsInfoCache(func(_ context.Context) (map[string]*CommandInfo, error) {
		return info, infoErr
	})

	stub.Cmdable = cc
	return stub
}

// stubConn builds a fresh ConnStub that serves init once and then replays
// this stub's resp bytes.
func (c *ClientStub) stubConn(init []byte) *ConnStub {
	conn := new(ConnStub)
	conn.init = init
	conn.resp = c.resp
	return conn
}

// ConnStub is an in-memory net.Conn used by the benchmarks. Reads first drain
// init, then cycle through resp forever; writes are discarded.
type ConnStub struct {
	init []byte // served once, before any resp bytes
	resp []byte // served repeatedly, wrapping back to the start
	pos  int    // next read offset into resp
}

// Read copies bytes into p. While init is non-empty only init bytes are
// returned; after that resp is replayed in a loop. An empty resp yields io.EOF.
func (c *ConnStub) Read(p []byte) (int, error) {
	switch {
	case len(c.init) != 0:
		read := copy(p, c.init)
		c.init = c.init[read:]
		return read, nil
	case len(c.resp) == 0:
		return 0, io.EOF
	}

	// Wrap around once the previous read consumed the tail of resp.
	if c.pos >= len(c.resp) {
		c.pos = 0
	}
	read := copy(p, c.resp[c.pos:])
	c.pos += read
	return read, nil
}

// Write discards b and reports it as fully written.
func (c *ConnStub) Write(b []byte) (int, error) {
	return len(b), nil
}

// Close is a no-op.
func (c *ConnStub) Close() error {
	return nil
}

// LocalAddr reports no address.
func (c *ConnStub) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr reports no address.
func (c *ConnStub) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline ignores the deadline.
func (c *ConnStub) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline.
func (c *ConnStub) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline.
func (c *ConnStub) SetWriteDeadline(_ time.Time) error {
	return nil
}

// ClientStubFunc constructs a ClientStub whose connections replay resp;
// both NewClientStub and NewClusterClientStub have this shape.
type ClientStubFunc func(resp []byte) *ClientStub

// BenchmarkDecode measures reply decoding for a single-node Client and for a
// ClusterClient, both backed by in-memory stub connections. Sub-benchmarks are
// named "<Case>-<server|cluster>" for the single-request reply shapes and
// "<Mode>Goroutine-<server|cluster>-pool=<N>" for the concurrent variants.
func BenchmarkDecode(b *testing.B) {
	kinds := []struct {
		name string
		stub ClientStubFunc
	}{
		{"server", NewClientStub},
		{"cluster", NewClusterClientStub},
	}

	// Reply shapes, run in this order for every client kind.
	replyCases := []struct {
		name string
		run  func(*testing.B, ClientStubFunc)
	}{
		{"RespError", respError},
		{"RespStatus", respStatus},
		{"RespInt", respInt},
		{"RespString", respString},
		{"RespArray", respArray},
		{"RespPipeline", respPipeline},
		{"RespTxPipeline", respTxPipeline},
	}

	// Concurrency strategies, each run at every pool size in poolSizes.
	goroutineModes := []struct {
		name string
		run  func(*testing.B, ClientStubFunc, int)
	}{
		{"DynamicGoroutine", dynamicGoroutine},
		{"StaticGoroutine", staticGoroutine},
	}
	poolSizes := []int{5, 20, 50, 100}

	for _, kind := range kinds {
		for _, rc := range replyCases {
			b.Run(fmt.Sprintf("%s-%s", rc.name, kind.name), func(b *testing.B) {
				rc.run(b, kind.stub)
			})
		}
		for _, mode := range goroutineModes {
			for _, size := range poolSizes {
				b.Run(fmt.Sprintf("%s-%s-pool=%d", mode.name, kind.name, size), func(b *testing.B) {
					mode.run(b, kind.stub, size)
				})
			}
		}
	}
}

// respError measures decoding of a RESP error reply ("-ERR ...") returned to
// GET, and checks it surfaces as the matching proto.RedisError value.
func respError(b *testing.B, stub ClientStubFunc) {
	rdb := stub([]byte("-ERR test error\r\n"))
	want := proto.RedisError("ERR test error")

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		err := rdb.Get(ctx, "key").Err()
		if err == want {
			continue
		}
		b.Fatalf("response error, got %q, want %q", err, want)
	}
}

// respStatus measures decoding of a RESP simple-string reply ("+OK")
// returned to SET.
func respStatus(b *testing.B, stub ClientStubFunc) {
	rdb := stub([]byte("+OK\r\n"))

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		got := rdb.Set(ctx, "key", "value", 0).Val()
		if got != "OK" {
			b.Fatalf("response error, got %q, want OK", got)
		}
	}
}

// respInt measures decoding of a RESP integer reply (":10") returned to INCR.
func respInt(b *testing.B, stub ClientStubFunc) {
	rdb := stub([]byte(":10\r\n"))
	var val int64

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if val = rdb.Incr(ctx, "key").Val(); val != 10 {
			// %d, not %q: %q on an integer prints a quoted rune, not the number.
			b.Fatalf("response error, got %d, want 10", val)
		}
	}
}

// respString measures decoding of a RESP bulk-string reply ("$5 hello")
// returned to GET.
func respString(b *testing.B, stub ClientStubFunc) {
	rdb := stub([]byte("$5\r\nhello\r\n"))

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		got := rdb.Get(ctx, "key").Val()
		if got != "hello" {
			b.Fatalf("response error, got %q, want hello", got)
		}
	}
}

// respArray measures decoding of a three-element RESP array (bulk string,
// integer, simple string) returned to MGET; only the element count is checked.
func respArray(b *testing.B, stub ClientStubFunc) {
	rdb := stub([]byte("*3\r\n$5\r\nhello\r\n:10\r\n+OK\r\n"))

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		vals := rdb.MGet(ctx, "key").Val()
		if len(vals) != 3 {
			b.Fatalf("response error, got len(%d), want len(3)", len(vals))
		}
	}
}

// respPipeline measures a three-command pipeline (SET, GET, DEL) whose
// replies arrive back-to-back on one stub connection.
func respPipeline(b *testing.B, stub ClientStubFunc) {
	rdb := stub([]byte("+OK\r\n$5\r\nhello\r\n:1\r\n"))

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		p := rdb.Pipeline()
		setCmd := p.Set(ctx, "key", "value", 0)
		getCmd := p.Get(ctx, "key")
		delCmd := p.Del(ctx, "key")

		if _, err := p.Exec(ctx); err != nil {
			b.Fatalf("response error, got %q, want nil", err)
		}
		allOK := setCmd.Val() == "OK" && getCmd.Val() == "hello" && delCmd.Val() == 1
		if !allOK {
			b.Fatal("response error")
		}
	}
}

// respTxPipeline measures a MULTI/EXEC transaction wrapping SET, GET and DEL.
// The canned stream is the MULTI "+OK", three "+QUEUED" acks, and the EXEC
// array holding the three real replies.
func respTxPipeline(b *testing.B, stub ClientStubFunc) {
	rdb := stub([]byte("+OK\r\n+QUEUED\r\n+QUEUED\r\n+QUEUED\r\n*3\r\n+OK\r\n$5\r\nhello\r\n:1\r\n"))

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		var (
			setCmd *StatusCmd
			getCmd *StringCmd
			delCmd *IntCmd
		)
		queue := func(pipe Pipeliner) error {
			setCmd = pipe.Set(ctx, "key", "value", 0)
			getCmd = pipe.Get(ctx, "key")
			delCmd = pipe.Del(ctx, "key")
			return nil
		}

		if _, err := rdb.TxPipelined(ctx, queue); err != nil {
			b.Fatalf("response error, got %q, want nil", err)
		}
		allOK := setCmd.Val() == "OK" && getCmd.Val() == "hello" && delCmd.Val() == 1
		if !allOK {
			b.Fatal("response error")
		}
	}
}

// dynamicGoroutine measures concurrent GETs where every request runs in a
// freshly spawned goroutine. The buffered channel c acts as a counting
// semaphore: at most concurrency goroutines are in flight at once.
//
// Panics (rather than b.Fatal) on a bad reply because b.Fatal must not be
// called from a goroutine other than the benchmark's own.
func dynamicGoroutine(b *testing.B, stub ClientStubFunc, concurrency int) {
	rdb := stub([]byte("$5\r\nhello\r\n"))
	c := make(chan struct{}, concurrency)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		c <- struct{}{}
		go func() {
			if val := rdb.Get(ctx, "key").Val(); val != "hello" {
				panic(fmt.Sprintf("response error, got %q, want hello", val))
			}
			<-c
		}()
	}

	// Wait for the in-flight tail: once every slot can be acquired, no
	// goroutine still holds one. Without this, up to concurrency requests
	// finish outside the timed region (badly skewing small b.N runs) and the
	// goroutines leak into the next sub-benchmark.
	for i := 0; i < cap(c); i++ {
		c <- struct{}{}
	}
}

// staticGoroutine measures concurrent GETs served by a fixed pool of
// concurrency worker goroutines that pull one request token at a time from c.
//
// Panics (rather than b.Fatal) on a bad reply because b.Fatal must not be
// called from a goroutine other than the benchmark's own.
func staticGoroutine(b *testing.B, stub ClientStubFunc, concurrency int) {
	rdb := stub([]byte("$5\r\nhello\r\n"))
	c := make(chan struct{}, concurrency)
	// done receives one signal from each worker as it exits.
	done := make(chan struct{}, concurrency)

	b.ResetTimer()

	for i := 0; i < concurrency; i++ {
		go func() {
			defer func() { done <- struct{}{} }()
			// Runs until c is closed and drained.
			for range c {
				if val := rdb.Get(ctx, "key").Val(); val != "hello" {
					panic(fmt.Sprintf("response error, got %q, want hello", val))
				}
			}
		}()
	}
	for i := 0; i < b.N; i++ {
		c <- struct{}{}
	}
	close(c)

	// Wait for every worker to finish its last request, so all b.N requests
	// are inside the timed region and no worker outlives the benchmark.
	for i := 0; i < concurrency; i++ {
		<-done
	}
}

Related articles

redis tracing

package redisotel import ( "context" "fmt" "net" "runtime" "strings" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" "go.opentelemetry.io/otel/trace" "github.com/redis/

redis conn_check_dummy

//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos // +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos package pool import "net" func connCheck(conn net.Conn) error { retu

redis rediscmd_test

package rediscmd import ( "testing" . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" ) func TestGinkgo(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "redisext") } var _ = Describe("AppendArg", func() { DescribeTable("...", func(

redis pool_single

package pool import "context" type SingleConnPool struct { pool Pooler cn *Conn stickyErr error } var _ Pooler = (*SingleConnPool)(nil) func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool { return &SingleConnPool{ pool: po

redis iterator

package redis import ( "context" ) // ScanIterator is used to incrementally iterate over a collection of elements. type ScanIterator struct { cmd *ScanCmd pos int } // Err returns the last iterator error, if any. func (it *ScanIterator) Err() error

redis release

#!/bin/bash set -e help() { cat <<- EOF Usage: TAG=tag $0 Updates version in go.mod files and pushes a new brash to GitHub. VARIABLES: TAG git tag, for example, v1.0.0 EOF exit 0 } if [ -z "$TAG" ] then printf "TAG is required\n

redis once

/* Copyright 2014 The Camlistore Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENS

redis reader_test

package proto_test import ( "bytes" "io" "testing" "github.com/redis/go-redis/v9/internal/proto" ) func BenchmarkReader_ParseReply_Status(b *testing.B) { benchmarkParseReply(b, "+OK\r\n", false) } func BenchmarkReader_ParseReply_Int(b *testing.B

redis cluster

package redis import ( "context" "crypto/tls" "fmt" "math" "net" "net/url" "runtime" "sort" "strings" "sync" "sync/atomic" "time" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hashtag" "github.com/redis/g

redis race_test

package redis_test import ( "bytes" "fmt" "net" "strconv" "sync/atomic" "testing" "time" . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" "github.com/redis/go-redis/v9" ) var _ = Describe("races", func() { var client *redis.Client v