diff --git a/src/ap/ap_config.c b/src/ap/ap_config.c index d869ebde5..caf75c4d5 100644 --- a/src/ap/ap_config.c +++ b/src/ap/ap_config.c @@ -1,6 +1,6 @@ /* * hostapd / Configuration helper functions - * Copyright (c) 2003-2012, Jouni Malinen + * Copyright (c) 2003-2013, Jouni Malinen * * This software may be distributed under the terms of the BSD license. * See README for more details. @@ -625,14 +625,31 @@ const char * hostapd_get_vlan_id_ifname(struct hostapd_vlan *vlan, int vlan_id) const u8 * hostapd_get_psk(const struct hostapd_bss_config *conf, - const u8 *addr, const u8 *prev_psk) + const u8 *addr, const u8 *p2p_dev_addr, + const u8 *prev_psk) { struct hostapd_wpa_psk *psk; int next_ok = prev_psk == NULL; + if (p2p_dev_addr) { + wpa_printf(MSG_DEBUG, "Searching a PSK for " MACSTR + " p2p_dev_addr=" MACSTR " prev_psk=%p", + MAC2STR(addr), MAC2STR(p2p_dev_addr), prev_psk); + if (!is_zero_ether_addr(p2p_dev_addr)) + addr = NULL; /* Use P2P Device Address for matching */ + } else { + wpa_printf(MSG_DEBUG, "Searching a PSK for " MACSTR + " prev_psk=%p", + MAC2STR(addr), prev_psk); + } + for (psk = conf->ssid.wpa_psk; psk != NULL; psk = psk->next) { if (next_ok && - (psk->group || os_memcmp(psk->addr, addr, ETH_ALEN) == 0)) + (psk->group || + (addr && os_memcmp(psk->addr, addr, ETH_ALEN) == 0) || + (!addr && p2p_dev_addr && + os_memcmp(psk->p2p_dev_addr, p2p_dev_addr, ETH_ALEN) == + 0))) return psk->psk; if (psk->psk == prev_psk) diff --git a/src/ap/ap_config.h b/src/ap/ap_config.h index 1fd41afd6..c5531faa7 100644 --- a/src/ap/ap_config.h +++ b/src/ap/ap_config.h @@ -552,7 +552,8 @@ int hostapd_rate_found(int *list, int rate); int hostapd_wep_key_cmp(struct hostapd_wep_keys *a, struct hostapd_wep_keys *b); const u8 * hostapd_get_psk(const struct hostapd_bss_config *conf, - const u8 *addr, const u8 *prev_psk); + const u8 *addr, const u8 *p2p_dev_addr, + const u8 *prev_psk); int hostapd_setup_wpa_psk(struct hostapd_bss_config *conf); int hostapd_vlan_id_valid(struct hostapd_vlan *vlan, int vlan_id); const char * hostapd_get_vlan_id_ifname(struct hostapd_vlan *vlan, diff --git a/src/ap/wpa_auth.c b/src/ap/wpa_auth.c index d4d2ee779..0286c5b8c 100644 --- a/src/ap/wpa_auth.c +++ b/src/ap/wpa_auth.c @@ -82,11 +82,14 @@ static inline int wpa_auth_get_eapol(struct wpa_authenticator *wpa_auth, static inline const u8 * wpa_auth_get_psk(struct wpa_authenticator *wpa_auth, - const u8 *addr, const u8 *prev_psk) + const u8 *addr, + const u8 *p2p_dev_addr, + const u8 *prev_psk) { if (wpa_auth->cb.get_psk == NULL) return NULL; - return wpa_auth->cb.get_psk(wpa_auth->cb.ctx, addr, prev_psk); + return wpa_auth->cb.get_psk(wpa_auth->cb.ctx, addr, p2p_dev_addr, + prev_psk); } @@ -1681,7 +1684,7 @@ SM_STATE(WPA_PTK, INITPSK) { const u8 *psk; SM_ENTRY_MA(WPA_PTK, INITPSK, wpa_ptk); - psk = wpa_auth_get_psk(sm->wpa_auth, sm->addr, NULL); + psk = wpa_auth_get_psk(sm->wpa_auth, sm->addr, sm->p2p_dev_addr, NULL); if (psk) { os_memcpy(sm->PMK, psk, PMK_LEN); #ifdef CONFIG_IEEE80211R @@ -1774,7 +1777,8 @@ SM_STATE(WPA_PTK, PTKCALCNEGOTIATING) * the packet */ for (;;) { if (wpa_key_mgmt_wpa_psk(sm->wpa_key_mgmt)) { - pmk = wpa_auth_get_psk(sm->wpa_auth, sm->addr, pmk); + pmk = wpa_auth_get_psk(sm->wpa_auth, sm->addr, + sm->p2p_dev_addr, pmk); if (pmk == NULL) break; } else @@ -2161,7 +2165,8 @@ SM_STEP(WPA_PTK) } break; case WPA_PTK_INITPSK: - if (wpa_auth_get_psk(sm->wpa_auth, sm->addr, NULL)) + if (wpa_auth_get_psk(sm->wpa_auth, sm->addr, sm->p2p_dev_addr, + NULL)) SM_ENTER(WPA_PTK, PTKSTART); else { wpa_auth_logger(sm->wpa_auth, sm->addr, LOGGER_INFO, diff --git a/src/ap/wpa_auth.h b/src/ap/wpa_auth.h index 358fbf250..47503d00c 100644 --- a/src/ap/wpa_auth.h +++ b/src/ap/wpa_auth.h @@ -184,7 +184,8 @@ struct wpa_auth_callbacks { void (*set_eapol)(void *ctx, const u8 *addr, wpa_eapol_variable var, int value); int (*get_eapol)(void *ctx, const u8 *addr, wpa_eapol_variable var); - const u8 * (*get_psk)(void *ctx, const u8 *addr, const u8 *prev_psk); + const u8 * (*get_psk)(void *ctx, const u8 *addr, const u8 *p2p_dev_addr, + const u8 *prev_psk); int (*get_msk)(void *ctx, const u8 *addr, u8 *msk, size_t *len); int (*set_key)(void *ctx, int vlan_id, enum wpa_alg alg, const u8 *addr, int idx, u8 *key, size_t key_len); diff --git a/src/ap/wpa_auth_glue.c b/src/ap/wpa_auth_glue.c index ea5e74fc6..d977b42fe 100644 --- a/src/ap/wpa_auth_glue.c +++ b/src/ap/wpa_auth_glue.c @@ -186,6 +186,7 @@ static int hostapd_wpa_auth_get_eapol(void *ctx, const u8 *addr, static const u8 * hostapd_wpa_auth_get_psk(void *ctx, const u8 *addr, + const u8 *p2p_dev_addr, const u8 *prev_psk) { struct hostapd_data *hapd = ctx; @@ -200,7 +201,7 @@ static const u8 * hostapd_wpa_auth_get_psk(void *ctx, const u8 *addr, } #endif /* CONFIG_SAE */ - psk = hostapd_get_psk(hapd->conf, addr, prev_psk); + psk = hostapd_get_psk(hapd->conf, addr, p2p_dev_addr, prev_psk); /* * This is about to iterate over all psks, prev_psk gives the last * returned psk which should not be returned again. diff --git a/wpa_supplicant/ibss_rsn.c b/wpa_supplicant/ibss_rsn.c index e6a29684f..47ef35ec6 100644 --- a/wpa_supplicant/ibss_rsn.c +++ b/wpa_supplicant/ibss_rsn.c @@ -257,7 +257,8 @@ static void auth_logger(void *ctx, const u8 *addr, logger_level level, } -static const u8 * auth_get_psk(void *ctx, const u8 *addr, const u8 *prev_psk) +static const u8 * auth_get_psk(void *ctx, const u8 *addr, + const u8 *p2p_dev_addr, const u8 *prev_psk) { struct ibss_rsn *ibss_rsn = ctx; wpa_printf(MSG_DEBUG, "AUTH: %s (addr=" MACSTR " prev_psk=%p)",