summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/linux/mempolicy.h4
-rw-r--r--kernel/exit.c7
-rw-r--r--mm/mempolicy.c17
3 files changed, 22 insertions, 6 deletions
diff --git a/include/linux/mempolicy.h b/include/linux/mempolicy.h
index 4429d255c8ab..5e5b2969d931 100644
--- a/include/linux/mempolicy.h
+++ b/include/linux/mempolicy.h
@@ -195,6 +195,7 @@ static inline bool vma_migratable(struct vm_area_struct *vma)
}
extern int mpol_misplaced(struct page *, struct vm_area_struct *, unsigned long);
+extern void mpol_put_task_policy(struct task_struct *);
#else
@@ -297,5 +298,8 @@ static inline int mpol_misplaced(struct page *page, struct vm_area_struct *vma,
return -1; /* no node preference */
}
+static inline void mpol_put_task_policy(struct task_struct *task)
+{
+}
#endif /* CONFIG_NUMA */
#endif
diff --git a/kernel/exit.c b/kernel/exit.c
index 2f974ae042a6..091a78be3b09 100644
--- a/kernel/exit.c
+++ b/kernel/exit.c
@@ -848,12 +848,7 @@ void do_exit(long code)
TASKS_RCU(preempt_enable());
exit_notify(tsk, group_dead);
proc_exit_connector(tsk);
-#ifdef CONFIG_NUMA
- task_lock(tsk);
- mpol_put(tsk->mempolicy);
- tsk->mempolicy = NULL;
- task_unlock(tsk);
-#endif
+ mpol_put_task_policy(tsk);
#ifdef CONFIG_FUTEX
if (unlikely(current->pi_state_cache))
kfree(current->pi_state_cache);
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index d8c4e38fb5f4..2da72a5b6ecc 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -2336,6 +2336,23 @@ out:
return ret;
}
+/*
+ * Drop the (possibly final) reference to task->mempolicy. It needs to be
+ * dropped after task->mempolicy is set to NULL so that any allocation done as
+ * part of its kmem_cache_free(), such as by KASAN, doesn't reference a freed
+ * policy.
+ */
+void mpol_put_task_policy(struct task_struct *task)
+{
+ struct mempolicy *pol;
+
+ task_lock(task);
+ pol = task->mempolicy;
+ task->mempolicy = NULL;
+ task_unlock(task);
+ mpol_put(pol);
+}
+
static void sp_delete(struct shared_policy *sp, struct sp_node *n)
{
pr_debug("deleting %lx-l%lx\n", n->start, n->end);