// Copyright (C) 2014 The Protocol Authors.

package protocol

import (
	"bytes"
	"encoding/base32"
	"encoding/binary"
	"errors"
	"fmt"
	"regexp"
	"strings"

	"github.com/calmh/luhn"
	"github.com/syncthing/syncthing/lib/sha256"
)

// DeviceIDLength is the size in bytes of a raw device ID.
const DeviceIDLength = 32

// DeviceID is the raw binary form of a device identity.
type DeviceID [DeviceIDLength]byte

// ShortID is a 64-bit abbreviation of a DeviceID, taken from its first eight
// bytes (see DeviceID.Short).
type ShortID uint64

var (
	// LocalDeviceID is the device ID with every byte set to 0xff.
	// NOTE(review): by its name it presumably denotes the local device —
	// confirm against callers.
	LocalDeviceID = DeviceID{
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	}

	// EmptyDeviceID is the all-zero device ID; DeviceID.String renders it as "".
	EmptyDeviceID = DeviceID{}
)

// NewDeviceID derives a device ID from the raw bytes of a certificate: the ID
// is the SHA-256 digest of rawCert.
func NewDeviceID(rawCert []byte) DeviceID {
	h := sha256.New()
	// NOTE(review): assumes sha256.New returns a hash.Hash, whose Write never
	// returns an error, so the result is deliberately ignored.
	h.Write(rawCert)

	var id DeviceID
	copy(id[:], h.Sum(nil))
	return id
}

// DeviceIDFromString parses the textual form of a device ID using
// DeviceID.UnmarshalText. If parsing fails, the error is returned together
// with the zero DeviceID.
func DeviceIDFromString(s string) (DeviceID, error) {
	var id DeviceID
	if err := id.UnmarshalText([]byte(s)); err != nil {
		return id, err
	}
	return id, nil
}

// DeviceIDFromBytes copies a raw device ID out of bs. It panics unless bs is
// exactly DeviceIDLength bytes long.
func DeviceIDFromBytes(bs []byte) DeviceID {
	if len(bs) != DeviceIDLength {
		panic("incorrect length of byte slice representing device ID")
	}
	var id DeviceID
	copy(id[:], bs)
	return id
}

// String returns the canonical textual form of the device ID. It is built in
// three steps:
//  1. the raw bytes are base32 encoded and the "=" padding is removed,
//  2. luhnify adds a check character to each of the four groups,
//  3. chunkify separates the result into dash-delimited groups of seven.
//
// The empty device ID renders as the empty string.
func (n DeviceID) String() string {
	if n == EmptyDeviceID {
		return ""
	}

	encoded := strings.Trim(base32.StdEncoding.EncodeToString(n[:]), "=")

	withChecks, err := luhnify(encoded)
	if err != nil {
		// Should never happen: the encoder always yields a 52-character
		// base32 string, which is what luhnify expects.
		panic(err)
	}
	return chunkify(withChecks)
}

// GoString implements fmt.GoStringer so that the %#v verb prints the same
// canonical representation as String, not a byte-array literal.
func (n DeviceID) GoString() string {
	canonical := n.String()
	return canonical
}

// Compare orders device IDs lexicographically by their raw bytes. It returns
// 0 when n equals other, -1 when n sorts first, and +1 when other sorts first.
func (n DeviceID) Compare(other DeviceID) int {
	lhs, rhs := n[:], other[:]
	return bytes.Compare(lhs, rhs)
}

// Equals reports whether n and other hold exactly the same bytes.
func (n DeviceID) Equals(other DeviceID) bool {
	// DeviceID is an array type, so == compares every element.
	return n == other
}

// Short returns the ShortID formed by reading the first eight bytes of the
// device ID as a big-endian integer. These are the ID's most significant 64
// bits.
func (n DeviceID) Short() ShortID {
	prefix := n[:8]
	return ShortID(binary.BigEndian.Uint64(prefix))
}

// MarshalText implements encoding.TextMarshaler. It returns the same
// canonical representation as String.
//
// The value receiver is deliberate. It lets plain DeviceID values satisfy
// encoding.TextMarshaler, not only *DeviceID. For example, encoding/json
// consults TextMarshaler for map keys and for non-addressable values only
// when the value type implements it. With a pointer receiver, such IDs were
// encoded as raw byte arrays, or rejected outright as map keys. Pointers still
// have this method through their method set, so existing callers are
// unaffected.
func (n DeviceID) MarshalText() ([]byte, error) {
	return []byte(n.String()), nil
}

func (s ShortID) String() string {
	if s == 0 {
		return ""
	}
	var bs [8]byte
	binary.BigEndian.PutUint64(bs[:], uint64(s))
	return base32.StdEncoding.EncodeToString(bs[:])[:7]
}

func (n *DeviceID) UnmarshalText(bs []byte) error {
	id := string(bs)
	id = strings.Trim(id, "=")
	id = strings.ToUpper(id)
	id = untypeoify(id)
	id = unchunkify(id)

	var err error
	switch len(id) {
	case 0:
		*n = EmptyDeviceID
		return nil
	case 56:
		// New style, with check digits
		id, err = unluhnify(id)
		if err != nil {
			return err
		}
		fallthrough
	case 52:
		// Old style, no check digits
		dec, err := base32.StdEncoding.DecodeString(id + "====")
		if err != nil {
			return err
		}
		copy(n[:], dec)
		return nil
	default:
		return fmt.Errorf("%q: device ID invalid: incorrect length", bs)
	}
}

// ProtoSize reports the encoded size of a device ID for the protobuf
// marshaller. It is always DeviceIDLength bytes.
func (n *DeviceID) ProtoSize() int {
	// len of a pointer-to-array is a compile-time constant; n is not
	// dereferenced, so this is safe even for a nil receiver.
	return len(n)
}

// MarshalTo writes the raw device ID bytes to the start of bs and returns the
// number of bytes written. It is used by the protobuf marshaller. If bs is
// shorter than DeviceIDLength, nothing is written and an error is returned.
func (n *DeviceID) MarshalTo(bs []byte) (int, error) {
	if len(bs) >= DeviceIDLength {
		return copy(bs, n[:]), nil
	}
	return 0, errors.New("destination too short")
}

// Unmarshal fills the device ID from the first DeviceIDLength bytes of bs. It
// is used by the protobuf marshaller. Any bytes beyond DeviceIDLength are
// ignored. A shorter input is rejected and leaves *n unmodified.
func (n *DeviceID) Unmarshal(bs []byte) error {
	if len(bs) >= DeviceIDLength {
		copy(n[:], bs)
		return nil
	}
	return fmt.Errorf("%q: not enough data", bs)
}

// luhnify adds check characters to a 52-character base32 string. It splits s
// into four 13-character groups and appends to each group the check
// character computed by luhn.Base32.Generate, giving a 56-character result.
// unluhnify reverses this.
//
// A wrong-length input is reported as an error rather than a panic. This
// matches unluhnify and avoids panicking inside library code.
func luhnify(s string) (string, error) {
	if len(s) != 52 {
		return "", fmt.Errorf("%q: unsupported string length %d", s, len(s))
	}

	res := make([]string, 0, 4)
	for i := 0; i < 4; i++ {
		p := s[i*13 : (i+1)*13]
		l, err := luhn.Base32.Generate(p)
		if err != nil {
			return "", err
		}
		res = append(res, fmt.Sprintf("%s%c", p, l))
	}
	return strings.Join(res, ""), nil
}

// unluhnify reverses luhnify. It treats s as four 14-character groups, each
// made of 13 payload characters followed by one check character. Each check
// character is verified against luhn.Base32.Generate. The result is the 52
// payload characters with the check characters removed.
func unluhnify(s string) (string, error) {
	if len(s) != 56 {
		return "", fmt.Errorf("%q: unsupported string length %d", s, len(s))
	}

	stripped := ""
	for start := 0; start < len(s); start += 14 {
		group := s[start : start+14]
		payload := group[:13]

		check, err := luhn.Base32.Generate(payload)
		if err != nil {
			return "", err
		}
		if fmt.Sprintf("%s%c", payload, check) != group {
			return "", fmt.Errorf("%q: check digit incorrect", s)
		}
		stripped += payload
	}
	return stripped, nil
}

// deviceIDChunkPattern matches consecutive runs of seven characters. It is
// compiled once at package initialisation, not on every chunkify call.
var deviceIDChunkPattern = regexp.MustCompile("(.{7})")

// chunkify inserts a dash after every run of seven characters in s, then trims
// leading and trailing dashes. For example, "ABCDEFGHIJKLMN" becomes
// "ABCDEFG-HIJKLMN".
func chunkify(s string) string {
	s = deviceIDChunkPattern.ReplaceAllString(s, "$1-")
	s = strings.Trim(s, "-")
	return s
}

// unchunkify removes every dash and space from s, undoing the grouping
// produced by chunkify and any spaces a user may have typed.
func unchunkify(s string) string {
	separators := strings.NewReplacer("-", "", " ", "")
	return separators.Replace(s)
}

// untypeoify maps the digits 0, 1 and 8 to the letters O, I and B. The base32
// alphabet contains those letters but not those digits, which are easy to
// type by mistake in their place. All other characters pass through unchanged.
func untypeoify(s string) string {
	lookalikes := strings.NewReplacer(
		"0", "O",
		"1", "I",
		"8", "B",
	)
	return lookalikes.Replace(s)
}

// DeviceIDs is a slice of device IDs that implements sort.Interface. It orders
// elements by their raw bytes, as defined by DeviceID.Compare.
type DeviceIDs []DeviceID

// Len returns the number of device IDs in the slice.
func (l DeviceIDs) Len() int { return len(l) }

// Less reports whether the ID at index a sorts before the ID at index b.
func (l DeviceIDs) Less(a, b int) bool {
	// Compare yields -1, 0 or +1, so "< 0" and "== -1" are equivalent.
	return l[a].Compare(l[b]) < 0
}

// Swap exchanges the IDs at indices a and b.
func (l DeviceIDs) Swap(a, b int) { l[a], l[b] = l[b], l[a] }
