Commit 88b024f5 authored by Shyam Prasad N's avatar Shyam Prasad N Committed by Steve French
Browse files

cifs: protect all accesses to chan_* with chan_lock



A spin lock called chan_lock was introduced recently.
But not all accesses were protected. Doing that with
this change.

To make sure that a channel is not freed when in use,
we need to introduce a ref count. But today, we don't
ever free channels.

Signed-off-by: default avatarShyam Prasad N <sprasad@microsoft.com>
Signed-off-by: default avatarSteve French <stfrench@microsoft.com>
parent a05885ce
Loading
Loading
Loading
Loading
+3 −1
Original line number Original line Diff line number Diff line
@@ -1831,7 +1831,6 @@ void cifs_put_smb_ses(struct cifs_ses *ses)


	spin_lock(&ses->chan_lock);
	spin_lock(&ses->chan_lock);
	chan_count = ses->chan_count;
	chan_count = ses->chan_count;
	spin_unlock(&ses->chan_lock);


	/* close any extra channels */
	/* close any extra channels */
	if (chan_count > 1) {
	if (chan_count > 1) {
@@ -1848,6 +1847,7 @@ void cifs_put_smb_ses(struct cifs_ses *ses)
			ses->chans[i].server = NULL;
			ses->chans[i].server = NULL;
		}
		}
	}
	}
	spin_unlock(&ses->chan_lock);


	sesInfoFree(ses);
	sesInfoFree(ses);
	cifs_put_tcp_session(server, 0);
	cifs_put_tcp_session(server, 0);
@@ -2123,8 +2123,10 @@ cifs_get_smb_ses(struct TCP_Server_Info *server, struct smb3_fs_context *ctx)
	mutex_unlock(&ses->session_mutex);
	mutex_unlock(&ses->session_mutex);


	/* each channel uses a different signing key */
	/* each channel uses a different signing key */
	spin_lock(&ses->chan_lock);
	memcpy(ses->chans[0].signkey, ses->smb3signingkey,
	memcpy(ses->chans[0].signkey, ses->smb3signingkey,
	       sizeof(ses->smb3signingkey));
	       sizeof(ses->smb3signingkey));
	spin_unlock(&ses->chan_lock);


	if (rc)
	if (rc)
		goto get_ses_fail;
		goto get_ses_fail;
+7 −3
Original line number Original line Diff line number Diff line
@@ -65,6 +65,8 @@ bool is_ses_using_iface(struct cifs_ses *ses, struct cifs_server_iface *iface)
	return false;
	return false;
}
}


/* channel helper functions. assumed that chan_lock is held by caller. */

unsigned int
unsigned int
cifs_ses_get_chan_index(struct cifs_ses *ses,
cifs_ses_get_chan_index(struct cifs_ses *ses,
			struct TCP_Server_Info *server)
			struct TCP_Server_Info *server)
@@ -134,10 +136,10 @@ int cifs_try_adding_channels(struct cifs_sb_info *cifs_sb, struct cifs_ses *ses)
	left = ses->chan_max - ses->chan_count;
	left = ses->chan_max - ses->chan_count;


	if (left <= 0) {
	if (left <= 0) {
		spin_unlock(&ses->chan_lock);
		cifs_dbg(FYI,
		cifs_dbg(FYI,
			 "ses already at max_channels (%zu), nothing to open\n",
			 "ses already at max_channels (%zu), nothing to open\n",
			 ses->chan_max);
			 ses->chan_max);
		spin_unlock(&ses->chan_lock);
		return 0;
		return 0;
	}
	}


@@ -369,12 +371,14 @@ void cifs_ses_mark_for_reconnect(struct cifs_ses *ses)
{
{
	int i;
	int i;


	for (i = 0; i < ses->chan_count; i++) {
	spin_lock(&cifs_tcp_ses_lock);
	spin_lock(&cifs_tcp_ses_lock);
	spin_lock(&ses->chan_lock);
	for (i = 0; i < ses->chan_count; i++) {
		if (ses->chans[i].server->tcpStatus != CifsExiting)
		if (ses->chans[i].server->tcpStatus != CifsExiting)
			ses->chans[i].server->tcpStatus = CifsNeedReconnect;
			ses->chans[i].server->tcpStatus = CifsNeedReconnect;
		spin_unlock(&cifs_tcp_ses_lock);
	}
	}
	spin_unlock(&ses->chan_lock);
	spin_unlock(&cifs_tcp_ses_lock);
}
}


static __u32 cifs_ssetup_hdr(struct cifs_ses *ses,
static __u32 cifs_ssetup_hdr(struct cifs_ses *ses,
+3 −1
Original line number Original line Diff line number Diff line
@@ -244,10 +244,10 @@ smb2_reconnect(__le16 smb2_command, struct cifs_tcon *tcon,
		spin_unlock(&ses->chan_lock);
		spin_unlock(&ses->chan_lock);
		return 0;
		return 0;
	}
	}
	spin_unlock(&ses->chan_lock);
	cifs_dbg(FYI, "sess reconnect mask: 0x%lx, tcon reconnect: %d",
	cifs_dbg(FYI, "sess reconnect mask: 0x%lx, tcon reconnect: %d",
		 tcon->ses->chans_need_reconnect,
		 tcon->ses->chans_need_reconnect,
		 tcon->need_reconnect);
		 tcon->need_reconnect);
	spin_unlock(&ses->chan_lock);


	nls_codepage = load_nls_default();
	nls_codepage = load_nls_default();


@@ -3835,11 +3835,13 @@ void smb2_reconnect_server(struct work_struct *work)
		 * binding session, but tcon is healthy (some other channel
		 * binding session, but tcon is healthy (some other channel
		 * is active)
		 * is active)
		 */
		 */
		spin_lock(&ses->chan_lock);
		if (!tcon_selected && cifs_chan_needs_reconnect(ses, server)) {
		if (!tcon_selected && cifs_chan_needs_reconnect(ses, server)) {
			list_add_tail(&ses->rlist, &tmp_ses_list);
			list_add_tail(&ses->rlist, &tmp_ses_list);
			ses_selected = ses_exist = true;
			ses_selected = ses_exist = true;
			ses->ses_count++;
			ses->ses_count++;
		}
		}
		spin_unlock(&ses->chan_lock);
	}
	}
	/*
	/*
	 * Get the reference to server struct to be sure that the last call of
	 * Get the reference to server struct to be sure that the last call of
+6 −0
Original line number Original line Diff line number Diff line
@@ -100,6 +100,7 @@ int smb2_get_sign_key(__u64 ses_id, struct TCP_Server_Info *server, u8 *key)
	goto out;
	goto out;


found:
found:
	spin_lock(&ses->chan_lock);
	if (cifs_chan_needs_reconnect(ses, server) &&
	if (cifs_chan_needs_reconnect(ses, server) &&
	    !CIFS_ALL_CHANS_NEED_RECONNECT(ses)) {
	    !CIFS_ALL_CHANS_NEED_RECONNECT(ses)) {
		/*
		/*
@@ -108,6 +109,7 @@ int smb2_get_sign_key(__u64 ses_id, struct TCP_Server_Info *server, u8 *key)
		 * session key
		 * session key
		 */
		 */
		memcpy(key, ses->smb3signingkey, SMB3_SIGN_KEY_SIZE);
		memcpy(key, ses->smb3signingkey, SMB3_SIGN_KEY_SIZE);
		spin_unlock(&ses->chan_lock);
		goto out;
		goto out;
	}
	}


@@ -119,9 +121,11 @@ int smb2_get_sign_key(__u64 ses_id, struct TCP_Server_Info *server, u8 *key)
		chan = ses->chans + i;
		chan = ses->chans + i;
		if (chan->server == server) {
		if (chan->server == server) {
			memcpy(key, chan->signkey, SMB3_SIGN_KEY_SIZE);
			memcpy(key, chan->signkey, SMB3_SIGN_KEY_SIZE);
			spin_unlock(&ses->chan_lock);
			goto out;
			goto out;
		}
		}
	}
	}
	spin_unlock(&ses->chan_lock);


	cifs_dbg(VFS,
	cifs_dbg(VFS,
		 "%s: Could not find channel signing key for session 0x%llx\n",
		 "%s: Could not find channel signing key for session 0x%llx\n",
@@ -430,8 +434,10 @@ generate_smb3signingkey(struct cifs_ses *ses,
			return rc;
			return rc;


		/* safe to access primary channel, since it will never go away */
		/* safe to access primary channel, since it will never go away */
		spin_lock(&ses->chan_lock);
		memcpy(ses->chans[0].signkey, ses->smb3signingkey,
		memcpy(ses->chans[0].signkey, ses->smb3signingkey,
		       SMB3_SIGN_KEY_SIZE);
		       SMB3_SIGN_KEY_SIZE);
		spin_unlock(&ses->chan_lock);


		rc = generate_key(ses, ptriplet->encryption.label,
		rc = generate_key(ses, ptriplet->encryption.label,
				  ptriplet->encryption.context,
				  ptriplet->encryption.context,
+3 −0
Original line number Original line Diff line number Diff line
@@ -1049,7 +1049,10 @@ struct TCP_Server_Info *cifs_pick_channel(struct cifs_ses *ses)


	/* round robin */
	/* round robin */
	index = (uint)atomic_inc_return(&ses->chan_seq);
	index = (uint)atomic_inc_return(&ses->chan_seq);

	spin_lock(&ses->chan_lock);
	index %= ses->chan_count;
	index %= ses->chan_count;
	spin_unlock(&ses->chan_lock);


	return ses->chans[index].server;
	return ses->chans[index].server;
}
}