Files
hpke/hpke_test.go

417 lines
12 KiB
Go

// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package hpke
import (
"bytes"
"crypto/ecdh"
"crypto/sha3"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"os"
"testing"
)
// Example shows a one-shot HPKE round trip: the recipient generates a key
// pair and publishes the encoded public key, the sender seals a message to
// that public key, and the recipient opens the resulting ciphertext.
func Example() {
	// Ciphersuite: MLKEM768-X25519 as the KEM, HKDF-SHA256 as the KDF, and
	// AES-256-GCM as the AEAD.
	kem, kdf, aead := MLKEM768X25519(), HKDFSHA256(), AES256GCM()

	// Recipient side: create a key pair and serialize the public half, which
	// is the only thing handed to the sender.
	recipientPrivateKey, err := kem.GenerateKey()
	if err != nil {
		panic(err)
	}
	publicKeyBytes := recipientPrivateKey.PublicKey().Bytes()

	// Sender side: parse the recipient's public key and encrypt one message
	// to it using the one-shot API.
	recipientPublicKey, err := kem.NewPublicKey(publicKeyBytes)
	if err != nil {
		panic(err)
	}
	ciphertext, err := Seal(recipientPublicKey, kdf, aead, []byte("example"), []byte("|-()-|"))
	if err != nil {
		panic(err)
	}

	// Recipient side: decrypt with the private key, passing the same info
	// value the sender used.
	plaintext, err := Open(recipientPrivateKey, kdf, aead, []byte("example"), ciphertext)
	if err != nil {
		panic(err)
	}
	fmt.Printf("Decrypted message: %s\n", plaintext)

	// Output:
	// Decrypted message: |-()-|
}
// mustDecodeHex returns the bytes encoded by the hexadecimal string in. If in
// is not valid hex, the calling test is failed immediately via t.Fatal.
func mustDecodeHex(t *testing.T, in string) []byte {
	t.Helper()
	decoded, decodeErr := hex.DecodeString(in)
	if decodeErr == nil {
		return decoded
	}
	t.Fatal(decodeErr)
	return nil // not reached: t.Fatal stops the calling goroutine
}
// TestVectors runs testVectors once per vector suite, each in a subtest named
// after the suite.
func TestVectors(t *testing.T) {
	for _, suite := range []string{"rfc9180", "hpke-pq"} {
		t.Run(suite, func(t *testing.T) {
			testVectors(t, suite)
		})
	}
}
// testVectors runs every test vector found in testdata/<name>.json.
//
// Each vector becomes a subtest named after its mode, KEM, KDF and AEAD IDs.
// Unsupported combinations are skipped. For the rest, the subtest checks the
// algorithm constructors, public and private key parsing and re-serialization,
// key derivation from seed material, the encapsulated key produced with the
// vector's fixed randomness, and finally the encryption and export outputs of
// a sender/recipient pair set up from the vector.
func testVectors(t *testing.T, name string) {
vectorsJSON, err := os.ReadFile("testdata/" + name + ".json")
if err != nil {
t.Fatal(err)
}
// Shape of one entry in the vector file. The string fields consumed below
// are all hex-encoded and are decoded with mustDecodeHex before use.
var vectors []struct {
Mode uint16 `json:"mode"`
KEM uint16 `json:"kem_id"`
KDF uint16 `json:"kdf_id"`
AEAD uint16 `json:"aead_id"`
Info string `json:"info"`
IkmE string `json:"ikmE"`
IkmR string `json:"ikmR"`
SkRm string `json:"skRm"`
PkRm string `json:"pkRm"`
Enc string `json:"enc"`
// Nonce is parsed but not checked by this test.
Encryptions []struct {
Aad string `json:"aad"`
Ct string `json:"ct"`
Nonce string `json:"nonce"`
Pt string `json:"pt"`
} `json:"encryptions"`
Exports []struct {
Context string `json:"exporter_context"`
L int `json:"L"`
Value string `json:"exported_value"`
} `json:"exports"`
// Instead of checking in a very large rfc9180.json, we computed
// alternative accumulated values. When one of these fields is non-empty,
// the matching per-item list above is ignored and a 16-byte SHAKE128
// digest over 1000 generated operations is compared instead.
AccEncryptions string `json:"encryptions_accumulated"`
AccExports string `json:"exports_accumulated"`
}
if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
t.Fatal(err)
}
for _, vector := range vectors {
name := fmt.Sprintf("mode %04x kem %04x kdf %04x aead %04x",
vector.Mode, vector.KEM, vector.KDF, vector.AEAD)
t.Run(name, func(t *testing.T) {
// Skip vectors for modes and algorithms this test does not cover.
if vector.Mode != 0 {
t.Skip("only mode 0 (base) is supported")
}
if vector.KEM == 0x0021 {
t.Skip("KEM 0x0021 (DHKEM(X448)) not supported")
}
if vector.KEM == 0x0040 {
t.Skip("KEM 0x0040 (ML-KEM-512) not supported")
}
if vector.KDF == 0x0012 || vector.KDF == 0x0013 {
t.Skipf("TurboSHAKE KDF not supported")
}
// Look up each algorithm by its numeric ID and check that the returned
// value reports the same ID back.
kdf, err := NewKDF(vector.KDF)
if err != nil {
t.Fatal(err)
}
if kdf.ID() != vector.KDF {
t.Errorf("unexpected KDF ID: got %04x, want %04x", kdf.ID(), vector.KDF)
}
aead, err := NewAEAD(vector.AEAD)
if err != nil {
t.Fatal(err)
}
if aead.ID() != vector.AEAD {
t.Errorf("unexpected AEAD ID: got %04x, want %04x", aead.ID(), vector.AEAD)
}
kem, err := NewKEM(vector.KEM)
if err != nil {
t.Fatal(err)
}
if kem.ID() != vector.KEM {
t.Errorf("unexpected KEM ID: got %04x, want %04x", kem.ID(), vector.KEM)
}
// Recipient public key: it must parse, report the KEM it was parsed with,
// and serialize back to the same bytes.
pubKeyBytes := mustDecodeHex(t, vector.PkRm)
kemSender, err := kem.NewPublicKey(pubKeyBytes)
if err != nil {
t.Fatal(err)
}
if kemSender.KEM() != kem {
t.Errorf("unexpected KEM from sender: got %04x, want %04x", kemSender.KEM().ID(), kem.ID())
}
if !bytes.Equal(kemSender.Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM bytes: got %x, want %x", kemSender.Bytes(), pubKeyBytes)
}
// Build the sender with the vector's ikmE as encapsulation randomness so
// the encapsulated key can be compared against the expected value.
ikmE := mustDecodeHex(t, vector.IkmE)
info := mustDecodeHex(t, vector.Info)
encap, sender, err := NewSenderWithTestingRandomness(kemSender, ikmE, kdf, aead, info)
if err != nil {
t.Fatal(err)
}
if len(encap) != kem.encSize() {
t.Errorf("unexpected encapsulated key size: got %d, want %d", len(encap), kem.encSize())
}
expectedEncap := mustDecodeHex(t, vector.Enc)
if !bytes.Equal(encap, expectedEncap) {
t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
}
// Recipient private key: it must parse, report the KEM it was parsed with,
// and yield the vector's public key.
privKeyBytes := mustDecodeHex(t, vector.SkRm)
kemRecipient, err := kem.NewPrivateKey(privKeyBytes)
if err != nil {
t.Fatal(err)
}
if kemRecipient.KEM() != kem {
t.Errorf("unexpected KEM from recipient: got %04x, want %04x", kemRecipient.KEM().ID(), kem.ID())
}
kemRecipientBytes, err := kemRecipient.Bytes()
if err != nil {
t.Fatal(err)
}
// X25519 serialized keys must be clamped, so the bytes might not match.
if !bytes.Equal(kemRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
t.Errorf("unexpected KEM bytes: got %x, want %x", kemRecipientBytes, privKeyBytes)
}
// For X25519, check instead that parsing the serialized form again gives
// the same serialization and still maps to the vector's public key.
if vector.KEM == DHKEM(ecdh.X25519()).ID() {
kem2, err := kem.NewPrivateKey(kemRecipientBytes)
if err != nil {
t.Fatal(err)
}
kemRecipientBytes2, err := kem2.Bytes()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(kemRecipientBytes2, kemRecipientBytes) {
t.Errorf("X25519 re-serialized key differs: got %x, want %x", kemRecipientBytes2, kemRecipientBytes)
}
if !bytes.Equal(kem2.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("X25519 re-derived public key differs: got %x, want %x", kem2.PublicKey().Bytes(), pubKeyBytes)
}
}
if !bytes.Equal(kemRecipient.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes: got %x, want %x", kemRecipient.PublicKey().Bytes(), pubKeyBytes)
}
// Deriving a key pair from the vector's ikmR must produce the same private
// key (X25519 is exempt from the byte comparison, as above) and the same
// public key.
ikm := mustDecodeHex(t, vector.IkmR)
derivRecipient, err := kem.DeriveKeyPair(ikm)
if err != nil {
t.Fatal(err)
}
derivRecipientBytes, err := derivRecipient.Bytes()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(derivRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
t.Errorf("unexpected KEM bytes from seed: got %x, want %x", derivRecipientBytes, privKeyBytes)
}
if !bytes.Equal(derivRecipient.PublicKey().Bytes(), pubKeyBytes) {
t.Errorf("unexpected KEM sender bytes from seed: got %x, want %x", derivRecipient.PublicKey().Bytes(), pubKeyBytes)
}
// The recipient is built from the encapsulated key the sender produced.
recipient, err := NewRecipient(encap, kemRecipient, kdf, aead, info)
if err != nil {
t.Fatal(err)
}
// Encryption checks come in three flavors: accumulated digest, explicit
// per-message vectors, or (for the export-only AEAD) expected failures.
if aead != ExportOnly() && len(vector.AccEncryptions) != 0 {
// Nothing is written to source before it is read, so it yields the same
// byte stream on every run; sink absorbs every ciphertext produced.
source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
for range 1000 {
aad, plaintext := drawRandomInput(t, source), drawRandomInput(t, source)
ciphertext, err := sender.Seal(aad, plaintext)
if err != nil {
t.Fatal(err)
}
sink.Write(ciphertext)
got, err := recipient.Open(aad, ciphertext)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, plaintext) {
t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
}
}
// The first 16 bytes of sink's output stand in for all 1000 ciphertexts.
encryptions := make([]byte, 16)
sink.Read(encryptions)
expectedEncryptions := mustDecodeHex(t, vector.AccEncryptions)
if !bytes.Equal(encryptions, expectedEncryptions) {
t.Errorf("unexpected accumulated encryptions, got: %x, want %x", encryptions, expectedEncryptions)
}
} else if aead != ExportOnly() {
// Messages are sealed and opened in the order listed in the vector.
for _, enc := range vector.Encryptions {
aad := mustDecodeHex(t, enc.Aad)
plaintext := mustDecodeHex(t, enc.Pt)
expectedCiphertext := mustDecodeHex(t, enc.Ct)
ciphertext, err := sender.Seal(aad, plaintext)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(ciphertext, expectedCiphertext) {
t.Errorf("unexpected ciphertext, got: %x, want %x", ciphertext, expectedCiphertext)
}
got, err := recipient.Open(aad, ciphertext)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, plaintext) {
t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
}
}
} else {
// With the export-only AEAD, Seal and Open are expected to return errors.
if _, err := sender.Seal(nil, nil); err == nil {
t.Error("expected error from Seal with export-only AEAD")
}
if _, err := recipient.Open(nil, nil); err == nil {
t.Error("expected error from Open with export-only AEAD")
}
}
// Export checks: accumulated digest or explicit per-export vectors.
if len(vector.AccExports) != 0 {
source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
// l is both the iteration counter and the requested export length, so
// lengths 0 through 999 are each requested once.
for l := range 1000 {
context := string(drawRandomInput(t, source))
value, err := sender.Export(context, l)
if err != nil {
t.Fatal(err)
}
sink.Write(value)
got, err := recipient.Export(context, l)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, value) {
t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
}
}
// The first 16 bytes of sink's output stand in for all exported values.
exports := make([]byte, 16)
sink.Read(exports)
expectedExports := mustDecodeHex(t, vector.AccExports)
if !bytes.Equal(exports, expectedExports) {
t.Errorf("unexpected accumulated exports, got: %x, want %x", exports, expectedExports)
}
} else {
// The sender's export is checked against the vector; the recipient's
// export is checked against the sender's.
for _, exp := range vector.Exports {
context := string(mustDecodeHex(t, exp.Context))
expectedValue := mustDecodeHex(t, exp.Value)
value, err := sender.Export(context, exp.L)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(value, expectedValue) {
t.Errorf("unexpected exported value, got: %x, want %x", value, expectedValue)
}
got, err := recipient.Export(context, exp.L)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, value) {
t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
}
}
}
})
}
}
// drawRandomInput reads one length-prefixed byte string from r: a single
// length byte n (0-255) followed by n payload bytes, which are returned.
//
// Both reads use io.ReadFull, because an io.Reader is allowed to return fewer
// bytes than requested without an error. A bare Read could therefore leave
// the tail of the result zero-filled. If r cannot supply all the bytes, the
// calling test is failed via t.Fatal.
func drawRandomInput(t *testing.T, r io.Reader) []byte {
	t.Helper()
	var length [1]byte
	if _, err := io.ReadFull(r, length[:]); err != nil {
		t.Fatal(err)
	}
	// For a zero length, io.ReadFull returns immediately without touching r,
	// and the result is an empty, non-nil slice.
	payload := make([]byte, int(length[0]))
	if _, err := io.ReadFull(r, payload); err != nil {
		t.Fatal(err)
	}
	return payload
}
// TestSingletons checks that each algorithm constructor returns a value that
// compares equal to the value returned by another call to the same
// constructor.
func TestSingletons(t *testing.T) {
	// expr is the constructor expression as it appears in the failure
	// message; same records whether two calls compared equal.
	checks := []struct {
		expr string
		same bool
	}{
		{"HKDFSHA256()", HKDFSHA256() == HKDFSHA256()},
		{"HKDFSHA384()", HKDFSHA384() == HKDFSHA384()},
		{"HKDFSHA512()", HKDFSHA512() == HKDFSHA512()},
		{"AES128GCM()", AES128GCM() == AES128GCM()},
		{"AES256GCM()", AES256GCM() == AES256GCM()},
		{"ChaCha20Poly1305()", ChaCha20Poly1305() == ChaCha20Poly1305()},
		{"ExportOnly()", ExportOnly() == ExportOnly()},
		{"DHKEM(P-256)", DHKEM(ecdh.P256()) == DHKEM(ecdh.P256())},
		{"DHKEM(P-384)", DHKEM(ecdh.P384()) == DHKEM(ecdh.P384())},
		{"DHKEM(P-521)", DHKEM(ecdh.P521()) == DHKEM(ecdh.P521())},
		{"DHKEM(X25519)", DHKEM(ecdh.X25519()) == DHKEM(ecdh.X25519())},
		{"MLKEM768()", MLKEM768() == MLKEM768()},
		{"MLKEM1024()", MLKEM1024() == MLKEM1024()},
		{"MLKEM768X25519()", MLKEM768X25519() == MLKEM768X25519()},
		{"MLKEM768P256()", MLKEM768P256() == MLKEM768P256()},
		{"MLKEM1024P384()", MLKEM1024P384() == MLKEM1024P384()},
	}
	for _, c := range checks {
		if !c.same {
			t.Error(c.expr + " != " + c.expr)
		}
	}
}