package stdcopy

import (
	"bytes"
	"errors"
	"io"
	"io/ioutil"
	"strings"
	"testing"
)

func TestNewStdWriter(t *testing.T) {
	writer := NewStdWriter(ioutil.Discard, Stdout)
	if writer == nil {
		t.Fatalf("NewStdWriter with an invalid StdType should not return nil.")
	}
}

func TestWriteWithUninitializedStdWriter(t *testing.T) {
	writer := stdWriter{
		Writer: nil,
		prefix: byte(Stdout),
	}
	n, err := writer.Write([]byte("Something here"))
	if n != 0 || err == nil {
		t.Fatalf("Should fail when given an uncomplete or uninitialized StdWriter")
	}
}

func TestWriteWithNilBytes(t *testing.T) {
	writer := NewStdWriter(ioutil.Discard, Stdout)
	n, err := writer.Write(nil)
	if err != nil {
		t.Fatalf("Shouldn't have fail when given no data")
	}
	if n > 0 {
		t.Fatalf("Write should have written 0 byte, but has written %d", n)
	}
}

func TestWrite(t *testing.T) {
	writer := NewStdWriter(ioutil.Discard, Stdout)
	data := []byte("Test StdWrite.Write")
	n, err := writer.Write(data)
	if err != nil {
		t.Fatalf("Error while writing with StdWrite")
	}
	if n != len(data) {
		t.Fatalf("Write should have written %d byte but wrote %d.", len(data), n)
	}
}

type errWriter struct {
	n   int
	err error
}

func (f *errWriter) Write(buf []byte) (int, error) {
	return f.n, f.err
}

func TestWriteWithWriterError(t *testing.T) {
	expectedError := errors.New("expected")
	expectedReturnedBytes := 10
	writer := NewStdWriter(&errWriter{
		n:   stdWriterPrefixLen + expectedReturnedBytes,
		err: expectedError}, Stdout)
	data := []byte("This won't get written, sigh")
	n, err := writer.Write(data)
	if err != expectedError {
		t.Fatalf("Didn't get expected error.")
	}
	if n != expectedReturnedBytes {
		t.Fatalf("Didn't get expected written bytes %d, got %d.",
			expectedReturnedBytes, n)
	}
}

func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) {
	writer := NewStdWriter(&errWriter{n: -1}, Stdout)
	data := []byte("This won't get written, sigh")
	actual, _ := writer.Write(data)
	if actual != 0 {
		t.Fatalf("Expected returned written bytes equal to 0, got %d", actual)
	}
}

func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (buffer *bytes.Buffer, err error) {
	buffer = new(bytes.Buffer)
	dstOut := NewStdWriter(buffer, Stdout)
	_, err = dstOut.Write(stdOutBytes)
	if err != nil {
		return
	}
	dstErr := NewStdWriter(buffer, Stderr)
	_, err = dstErr.Write(stdErrBytes)
	return
}

func TestStdCopyWriteAndRead(t *testing.T) {
	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
	if err != nil {
		t.Fatal(err)
	}
	written, err := StdCopy(ioutil.Discard, ioutil.Discard, buffer)
	if err != nil {
		t.Fatal(err)
	}
	expectedTotalWritten := len(stdOutBytes) + len(stdErrBytes)
	if written != int64(expectedTotalWritten) {
		t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written)
	}
}

type customReader struct {
	n            int
	err          error
	totalCalls   int
	correctCalls int
	src          *bytes.Buffer
}

func (f *customReader) Read(buf []byte) (int, error) {
	f.totalCalls++
	if f.totalCalls <= f.correctCalls {
		return f.src.Read(buf)
	}
	return f.n, f.err
}

func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
	expectedError := errors.New("error")
	reader := &customReader{
		err: expectedError}
	written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
	if written != 0 {
		t.Fatalf("Expected 0 bytes read, got %d", written)
	}
	if err != expectedError {
		t.Fatalf("Didn't get expected error")
	}
}

func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
	expectedError := errors.New("error")
	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
	if err != nil {
		t.Fatal(err)
	}
	reader := &customReader{
		correctCalls: 1,
		n:            stdWriterPrefixLen + 1,
		err:          expectedError,
		src:          buffer}
	written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
	if written != 0 {
		t.Fatalf("Expected 0 bytes read, got %d", written)
	}
	if err != expectedError {
		t.Fatalf("Didn't get expected error")
	}
}

func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
	if err != nil {
		t.Fatal(err)
	}
	reader := &customReader{
		correctCalls: 1,
		n:            stdWriterPrefixLen + 1,
		err:          io.EOF,
		src:          buffer}
	written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
	if written != startingBufLen {
		t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
	}
	if err != nil {
		t.Fatal("Didn't get nil error")
	}
}

func TestStdCopyWithInvalidInputHeader(t *testing.T) {
	dstOut := NewStdWriter(ioutil.Discard, Stdout)
	dstErr := NewStdWriter(ioutil.Discard, Stderr)
	src := strings.NewReader("Invalid input")
	_, err := StdCopy(dstOut, dstErr, src)
	if err == nil {
		t.Fatal("StdCopy with invalid input header should fail.")
	}
}

func TestStdCopyWithCorruptedPrefix(t *testing.T) {
	data := []byte{0x01, 0x02, 0x03}
	src := bytes.NewReader(data)
	written, err := StdCopy(nil, nil, src)
	if err != nil {
		t.Fatalf("StdCopy should not return an error with corrupted prefix.")
	}
	if written != 0 {
		t.Fatalf("StdCopy should have written 0, but has written %d", written)
	}
}

func TestStdCopyReturnsWriteErrors(t *testing.T) {
	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
	if err != nil {
		t.Fatal(err)
	}
	expectedError := errors.New("expected")

	dstOut := &errWriter{err: expectedError}

	written, err := StdCopy(dstOut, ioutil.Discard, buffer)
	if written != 0 {
		t.Fatalf("StdCopy should have written 0, but has written %d", written)
	}
	if err != expectedError {
		t.Fatalf("Didn't get expected error, got %v", err)
	}
}

func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) {
	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
	if err != nil {
		t.Fatal(err)
	}
	dstOut := &errWriter{n: startingBufLen - 10}

	written, err := StdCopy(dstOut, ioutil.Discard, buffer)
	if written != 0 {
		t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written)
	}
	if err != io.ErrShortWrite {
		t.Fatalf("Didn't get expected io.ErrShortWrite error")
	}
}

// TestStdCopyReturnsErrorFromSystem tests that StdCopy correctly returns an
// error, when that error is muxed into the Systemerr stream.
func TestStdCopyReturnsErrorFromSystem(t *testing.T) {
	// write in the basic messages, just so there's some fluff in there
	stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
	stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
	buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
	if err != nil {
		t.Fatal(err)
	}
	// add in an error message on the Systemerr stream
	systemErrBytes := []byte(strings.Repeat("S", startingBufLen))
	systemWriter := NewStdWriter(buffer, Systemerr)
	_, err = systemWriter.Write(systemErrBytes)
	if err != nil {
		t.Fatal(err)
	}

	// now copy and demux. we should expect an error containing the string we
	// wrote out
	_, err = StdCopy(ioutil.Discard, ioutil.Discard, buffer)
	if err == nil {
		t.Fatal("expected error, got none")
	}
	if !strings.Contains(err.Error(), string(systemErrBytes)) {
		t.Fatal("expected error to contain message")
	}
}

func BenchmarkWrite(b *testing.B) {
	w := NewStdWriter(ioutil.Discard, Stdout)
	data := []byte("Test line for testing stdwriter performance\n")
	data = bytes.Repeat(data, 100)
	b.SetBytes(int64(len(data)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := w.Write(data); err != nil {
			b.Fatal(err)
		}
	}
}