diff --git a/src/common/dpp.c b/src/common/dpp.c index bfc0fde27..5bc021c82 100644 --- a/src/common/dpp.c +++ b/src/common/dpp.c @@ -61,6 +61,69 @@ static const struct dpp_curve_params dpp_curves[] = { }; +static int dpp_hash_vector(const struct dpp_curve_params *curve, + size_t num_elem, const u8 *addr[], const size_t *len, + u8 *mac) +{ + if (curve->hash_len == 32) + return sha256_vector(num_elem, addr, len, mac); + if (curve->hash_len == 48) + return sha384_vector(num_elem, addr, len, mac); + if (curve->hash_len == 64) + return sha512_vector(num_elem, addr, len, mac); + return -1; +} + + +static int dpp_hkdf_expand(size_t hash_len, const u8 *secret, size_t secret_len, + const char *label, u8 *out, size_t outlen) +{ + if (hash_len == 32) + return hmac_sha256_kdf(secret, secret_len, NULL, + (const u8 *) label, os_strlen(label), + out, outlen); + if (hash_len == 48) + return hmac_sha384_kdf(secret, secret_len, NULL, + (const u8 *) label, os_strlen(label), + out, outlen); + if (hash_len == 64) + return hmac_sha512_kdf(secret, secret_len, NULL, + (const u8 *) label, os_strlen(label), + out, outlen); + return -1; +} + + +static int dpp_hmac_vector(size_t hash_len, const u8 *key, size_t key_len, + size_t num_elem, const u8 *addr[], + const size_t *len, u8 *mac) +{ + if (hash_len == 32) + return hmac_sha256_vector(key, key_len, num_elem, addr, len, + mac); + if (hash_len == 48) + return hmac_sha384_vector(key, key_len, num_elem, addr, len, + mac); + if (hash_len == 64) + return hmac_sha512_vector(key, key_len, num_elem, addr, len, + mac); + return -1; +} + + +static int dpp_hmac(size_t hash_len, const u8 *key, size_t key_len, + const u8 *data, size_t data_len, u8 *mac) +{ + if (hash_len == 32) + return hmac_sha256(key, key_len, data, data_len, mac); + if (hash_len == 48) + return hmac_sha384(key, key_len, data, data_len, mac); + if (hash_len == 64) + return hmac_sha512(key, key_len, data, data_len, mac); + return -1; +} + + static struct wpabuf * dpp_get_pubkey_point(EVP_PKEY *pkey, int prefix) { int len, res; @@ -903,40 +966,19 @@ static int dpp_derive_k1(const u8 *Mx, size_t Mx_len, u8 *k1, { u8 salt[DPP_MAX_HASH_LEN], prk[DPP_MAX_HASH_LEN]; const char *info = "first intermediate key"; - int res = -1; + int res; /* k1 = HKDF(<>, "first intermediate key", M.x) */ /* HKDF-Extract(<>, M.x) */ os_memset(salt, 0, hash_len); - if (hash_len == 32) { - if (hmac_sha256(salt, SHA256_MAC_LEN, Mx, Mx_len, prk) < 0) - return -1; - } else if (hash_len == 48) { - if (hmac_sha384(salt, SHA384_MAC_LEN, Mx, Mx_len, prk) < 0) - return -1; - } else if (hash_len == 64) { - if (hmac_sha512(salt, SHA512_MAC_LEN, Mx, Mx_len, prk) < 0) - return -1; - } else { + if (dpp_hmac(hash_len, salt, hash_len, Mx, Mx_len, prk) < 0) return -1; - } wpa_hexdump_key(MSG_DEBUG, "DPP: PRK = HKDF-Extract(<>, IKM=M.x)", prk, hash_len); /* HKDF-Expand(PRK, info, L) */ - if (hash_len == 32) - res = hmac_sha256_kdf(prk, SHA256_MAC_LEN, NULL, - (const u8 *) info, os_strlen(info), - k1, SHA256_MAC_LEN); - else if (hash_len == 48) - res = hmac_sha384_kdf(prk, SHA384_MAC_LEN, NULL, - (const u8 *) info, os_strlen(info), - k1, SHA384_MAC_LEN); - else if (hash_len == 64) - res = hmac_sha512_kdf(prk, SHA512_MAC_LEN, NULL, - (const u8 *) info, os_strlen(info), - k1, SHA512_MAC_LEN); + res = dpp_hkdf_expand(hash_len, prk, hash_len, info, k1, hash_len); os_memset(prk, 0, hash_len); if (res < 0) return -1; @@ -958,32 +1000,14 @@ static int dpp_derive_k2(const u8 *Nx, size_t Nx_len, u8 *k2, /* HKDF-Extract(<>, N.x) */ os_memset(salt, 0, hash_len); - if (hash_len == 32) - res = hmac_sha256(salt, SHA256_MAC_LEN, Nx, Nx_len, prk); - else if (hash_len == 48) - res = hmac_sha384(salt, SHA384_MAC_LEN, Nx, Nx_len, prk); - else if (hash_len == 64) - res = hmac_sha512(salt, SHA512_MAC_LEN, Nx, Nx_len, prk); - else - res = -1; + res = dpp_hmac(hash_len, salt, hash_len, Nx, Nx_len, prk); if (res < 0) return -1; wpa_hexdump_key(MSG_DEBUG, "DPP: PRK = HKDF-Extract(<>, IKM=N.x)", prk, hash_len); /* HKDF-Expand(PRK, info, L) */ - if (hash_len == 32) - res = hmac_sha256_kdf(prk, SHA256_MAC_LEN, NULL, - (const u8 *) info, os_strlen(info), - k2, SHA256_MAC_LEN); - else if (hash_len == 48) - res = hmac_sha384_kdf(prk, SHA384_MAC_LEN, NULL, - (const u8 *) info, os_strlen(info), - k2, SHA384_MAC_LEN); - else if (hash_len == 64) - res = hmac_sha512_kdf(prk, SHA512_MAC_LEN, NULL, - (const u8 *) info, os_strlen(info), - k2, SHA512_MAC_LEN); + res = dpp_hkdf_expand(hash_len, prk, hash_len, info, k2, hash_len); os_memset(prk, 0, hash_len); if (res < 0) return -1; @@ -1023,35 +1047,15 @@ static int dpp_derive_ke(struct dpp_authentication *auth, u8 *ke, len[num_elem] = auth->secret_len; num_elem++; } - if (hash_len == 32) - res = hmac_sha256_vector(nonces, 2 * nonce_len, - num_elem, addr, len, prk); - else if (hash_len == 48) - res = hmac_sha384_vector(nonces, 2 * nonce_len, - num_elem, addr, len, prk); - else if (hash_len == 64) - res = hmac_sha512_vector(nonces, 2 * nonce_len, - num_elem, addr, len, prk); - else - res = -1; + res = dpp_hmac_vector(hash_len, nonces, 2 * nonce_len, + num_elem, addr, len, prk); if (res < 0) return -1; wpa_hexdump_key(MSG_DEBUG, "DPP: PRK = HKDF-Extract(<>, IKM)", prk, hash_len); /* HKDF-Expand(PRK, info, L) */ - if (hash_len == 32) - res = hmac_sha256_kdf(prk, SHA256_MAC_LEN, NULL, - (const u8 *) info_ke, os_strlen(info_ke), - ke, SHA256_MAC_LEN); - else if (hash_len == 48) - res = hmac_sha384_kdf(prk, SHA384_MAC_LEN, NULL, - (const u8 *) info_ke, os_strlen(info_ke), - ke, SHA384_MAC_LEN); - else if (hash_len == 64) - res = hmac_sha512_kdf(prk, SHA512_MAC_LEN, NULL, - (const u8 *) info_ke, os_strlen(info_ke), - ke, SHA512_MAC_LEN); + res = dpp_hkdf_expand(hash_len, prk, hash_len, info_ke, ke, hash_len); os_memset(prk, 0, hash_len); if (res < 0) return -1; @@ -1341,14 +1345,7 @@ static int dpp_gen_r_auth(struct dpp_authentication *auth, u8 *r_auth) wpa_printf(MSG_DEBUG, "DPP: R-auth hash components"); for (i = 0; i < num_elem; i++) wpa_hexdump(MSG_DEBUG, "DPP: hash component", addr[i], len[i]); - if (auth->curve->hash_len == 32) - res = sha256_vector(num_elem, addr, len, r_auth); - else if (auth->curve->hash_len == 48) - res = sha384_vector(num_elem, addr, len, r_auth); - else if (auth->curve->hash_len == 64) - res = sha512_vector(num_elem, addr, len, r_auth); - else - res = -1; + res = dpp_hash_vector(auth->curve, num_elem, addr, len, r_auth); if (res == 0) wpa_hexdump(MSG_DEBUG, "DPP: R-auth", r_auth, auth->curve->hash_len); @@ -1431,14 +1428,7 @@ static int dpp_gen_i_auth(struct dpp_authentication *auth, u8 *i_auth) wpa_printf(MSG_DEBUG, "DPP: I-auth hash components"); for (i = 0; i < num_elem; i++) wpa_hexdump(MSG_DEBUG, "DPP: hash component", addr[i], len[i]); - if (auth->curve->hash_len == 32) - res = sha256_vector(num_elem, addr, len, i_auth); - else if (auth->curve->hash_len == 48) - res = sha384_vector(num_elem, addr, len, i_auth); - else if (auth->curve->hash_len == 64) - res = sha512_vector(num_elem, addr, len, i_auth); - else - res = -1; + res = dpp_hash_vector(auth->curve, num_elem, addr, len, i_auth); if (res == 0) wpa_hexdump(MSG_DEBUG, "DPP: I-auth", i_auth, auth->curve->hash_len); @@ -4338,40 +4328,19 @@ static int dpp_derive_pmk(const u8 *Nx, size_t Nx_len, u8 *pmk, { u8 salt[DPP_MAX_HASH_LEN], prk[DPP_MAX_HASH_LEN]; const char *info = "DPP PMK"; - int res = -1; + int res; /* PMK = HKDF(<>, "DPP PMK", N.x) */ /* HKDF-Extract(<>, N.x) */ os_memset(salt, 0, hash_len); - if (hash_len == 32) { - if (hmac_sha256(salt, SHA256_MAC_LEN, Nx, Nx_len, prk) < 0) - return -1; - } else if (hash_len == 48) { - if (hmac_sha384(salt, SHA384_MAC_LEN, Nx, Nx_len, prk) < 0) - return -1; - } else if (hash_len == 64) { - if (hmac_sha512(salt, SHA512_MAC_LEN, Nx, Nx_len, prk) < 0) - return -1; - } else { + if (dpp_hmac(hash_len, salt, hash_len, Nx, Nx_len, prk) < 0) return -1; - } wpa_hexdump_key(MSG_DEBUG, "DPP: PRK = HKDF-Extract(<>, IKM=N.x)", prk, hash_len); /* HKDF-Expand(PRK, info, L) */ - if (hash_len == 32) - res = hmac_sha256_kdf(prk, SHA256_MAC_LEN, NULL, - (const u8 *) info, os_strlen(info), - pmk, SHA256_MAC_LEN); - else if (hash_len == 48) - res = hmac_sha384_kdf(prk, SHA384_MAC_LEN, NULL, - (const u8 *) info, os_strlen(info), - pmk, SHA384_MAC_LEN); - else if (hash_len == 64) - res = hmac_sha512_kdf(prk, SHA512_MAC_LEN, NULL, - (const u8 *) info, os_strlen(info), - pmk, SHA512_MAC_LEN); + res = dpp_hkdf_expand(hash_len, prk, hash_len, info, pmk, hash_len); os_memset(prk, 0, hash_len); if (res < 0) return -1; @@ -4410,14 +4379,7 @@ static int dpp_derive_pmkid(const struct dpp_curve_params *curve, (unsigned int) curve->hash_len * 8); wpa_hexdump(MSG_DEBUG, "DPP: PMKID hash payload 1", addr[0], len[0]); wpa_hexdump(MSG_DEBUG, "DPP: PMKID hash payload 2", addr[1], len[1]); - if (curve->hash_len == 32) - res = sha256_vector(2, addr, len, hash); - else if (curve->hash_len == 48) - res = sha384_vector(2, addr, len, hash); - else if (curve->hash_len == 64) - res = sha512_vector(2, addr, len, hash); - else - res = -1; + res = dpp_hash_vector(curve, 2, addr, len, hash); if (res < 0) goto fail; wpa_hexdump(MSG_DEBUG, "DPP: PMKID hash output",