feat: implement key generation, encryption, and decryption functions with tests
This commit is contained in:
90
hsp/crypt.go
Normal file
90
hsp/crypt.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package hsp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/curve25519"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyPair struct {
|
||||||
|
Private [32]byte
|
||||||
|
Public [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKeyPair(publicKey, privateKey [32]byte) *KeyPair {
|
||||||
|
return &KeyPair{
|
||||||
|
Public: publicKey,
|
||||||
|
Private: privateKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateKeyPair() (pair *KeyPair, err error) {
|
||||||
|
privateKey := make([]byte, 32)
|
||||||
|
publicKey := make([]byte, 32)
|
||||||
|
_, err = rand.Read(privateKey[:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey[0] &= 248
|
||||||
|
privateKey[31] &= 127
|
||||||
|
privateKey[31] |= 64
|
||||||
|
|
||||||
|
curve25519.ScalarBaseMult((*[32]byte)(publicKey), (*[32]byte)(privateKey))
|
||||||
|
return NewKeyPair([32]byte(publicKey), [32]byte(privateKey)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeriveSharedKey(privateKey, peerPublicKey [32]byte) (sharedKey [32]byte, err error) {
|
||||||
|
generated, err := curve25519.X25519(privateKey[:], peerPublicKey[:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedKey = [32]byte(generated)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func Encrypt(key []byte, data []byte) (encrypted []byte, nonce []byte, err error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce = make([]byte, aesGCM.NonceSize())
|
||||||
|
_, err = io.ReadFull(rand.Reader, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted = aesGCM.Seal(nil, nonce, data, nil)
|
||||||
|
|
||||||
|
return encrypted, nonce, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Decrypt(key []byte, nonce []byte, encrypted []byte) (data []byte, err error) {
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
aesGCM, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err = aesGCM.Open(nil, nonce, encrypted, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
93
hsp/crypt_test.go
Normal file
93
hsp/crypt_test.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package hsp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateAndDerive(t *testing.T) {
|
||||||
|
clientKeys, err := GenerateKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("ERR: Failed to generate client keys:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverKeys, err := GenerateKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("ERR: Failed to generate server keys:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientShared, err := DeriveSharedKey(clientKeys.Private, serverKeys.Public)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("ERR: Failed to generate client shared key:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverShared, err := DeriveSharedKey(serverKeys.Private, clientKeys.Public)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("ERR: Failed to generate server shared key:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientHash := sha256.Sum256(clientShared[:])
|
||||||
|
serverHash := sha256.Sum256(serverShared[:])
|
||||||
|
|
||||||
|
t.Logf("Client shared key: %x\n", clientHash)
|
||||||
|
t.Logf("Server shared key: %x\n", serverHash)
|
||||||
|
|
||||||
|
// Check they match
|
||||||
|
if clientHash == serverHash {
|
||||||
|
t.Log("🎉 Secure shared key established! 🎉")
|
||||||
|
} else {
|
||||||
|
t.Log("❌ Something went wrong ❌")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptDecrypt(t *testing.T) {
|
||||||
|
clientKeys, err := GenerateKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("ERR: Failed to generate client keys:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverKeys, err := GenerateKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("ERR: Failed to generate server keys:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientShared, err := DeriveSharedKey(clientKeys.Private, serverKeys.Public)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("ERR: Failed to generate client shared key:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serverShared, err := DeriveSharedKey(serverKeys.Private, clientKeys.Public)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("ERR: Failed to generate server shared key:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientHash := sha256.Sum256(clientShared[:])
|
||||||
|
serverHash := sha256.Sum256(serverShared[:])
|
||||||
|
|
||||||
|
msg := []byte("Hello, World!")
|
||||||
|
|
||||||
|
data, nonce, err := Encrypt(clientHash[:], msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("ERR: Failed to encrypt data:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, err := Decrypt(serverHash[:], nonce, data)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("ERR: Failed to decrypt data:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(msg, decrypted) {
|
||||||
|
t.Error("Plain data doesn't match decrypted one")
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user