浏览代码

Add SecurityContext to delegating TaskScheduler

Wrap DelegatingSecurityContextTaskScheduler's Runnable tasks in
DelegatingSecurityContextRunnables, allowing to specify a
SecurityContext to use for tasks execution.

- Renamed private variable taskScheduler to delegate
- Removed unused local variable in unit test
- Add SecurityContext tests for delegating TaskScheduler

Closes gh-9514
Giacomo Baso 4 年之前
父节点
当前提交
80743a267c

+ 39 - 13
core/src/main/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskScheduler.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -19,8 +19,12 @@ package org.springframework.security.scheduling;
 import java.util.Date;
 import java.util.Date;
 import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.ScheduledFuture;
 
 
+import org.springframework.core.task.TaskExecutor;
 import org.springframework.scheduling.TaskScheduler;
 import org.springframework.scheduling.TaskScheduler;
 import org.springframework.scheduling.Trigger;
 import org.springframework.scheduling.Trigger;
+import org.springframework.security.concurrent.DelegatingSecurityContextRunnable;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 
 
 /**
 /**
@@ -32,45 +36,67 @@ import org.springframework.util.Assert;
  */
  */
 public class DelegatingSecurityContextTaskScheduler implements TaskScheduler {
 public class DelegatingSecurityContextTaskScheduler implements TaskScheduler {
 
 
-	private final TaskScheduler taskScheduler;
+	private final TaskScheduler delegate;
+
+	private final SecurityContext securityContext;
 
 
 	/**
 	/**
-	 * Creates a new {@link DelegatingSecurityContextTaskScheduler}
-	 * @param taskScheduler the {@link TaskScheduler}
+	 * Creates a new {@link DelegatingSecurityContextTaskScheduler} that uses the
+	 * specified {@link SecurityContext}.
+	 * @param delegateTaskScheduler the {@link TaskScheduler} to delegate to. Cannot be
+	 * null.
+	 * @param securityContext the {@link SecurityContext} to use for each
+	 * {@link DelegatingSecurityContextRunnable} or null to default to the current
+	 * {@link SecurityContext}
 	 */
 	 */
-	public DelegatingSecurityContextTaskScheduler(TaskScheduler taskScheduler) {
-		Assert.notNull(taskScheduler, "Task scheduler must not be null");
-		this.taskScheduler = taskScheduler;
+	public DelegatingSecurityContextTaskScheduler(TaskScheduler delegateTaskScheduler,
+			SecurityContext securityContext) {
+		Assert.notNull(delegateTaskScheduler, "delegateTaskScheduler cannot be null");
+		this.delegate = delegateTaskScheduler;
+		this.securityContext = securityContext;
+	}
+
+	/**
+	 * Creates a new {@link DelegatingSecurityContextTaskScheduler} that uses the current
+	 * {@link SecurityContext} from the {@link SecurityContextHolder}.
+	 * @param delegate the {@link TaskExecutor} to delegate to. Cannot be null.
+	 */
+	public DelegatingSecurityContextTaskScheduler(TaskScheduler delegate) {
+		this(delegate, null);
 	}
 	}
 
 
 	@Override
 	@Override
 	public ScheduledFuture<?> schedule(Runnable task, Trigger trigger) {
 	public ScheduledFuture<?> schedule(Runnable task, Trigger trigger) {
-		return this.taskScheduler.schedule(task, trigger);
+		return this.delegate.schedule(wrap(task), trigger);
 	}
 	}
 
 
 	@Override
 	@Override
 	public ScheduledFuture<?> schedule(Runnable task, Date startTime) {
 	public ScheduledFuture<?> schedule(Runnable task, Date startTime) {
-		return this.taskScheduler.schedule(task, startTime);
+		return this.delegate.schedule(wrap(task), startTime);
 	}
 	}
 
 
 	@Override
 	@Override
 	public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, Date startTime, long period) {
 	public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, Date startTime, long period) {
-		return this.taskScheduler.scheduleAtFixedRate(task, startTime, period);
+		return this.delegate.scheduleAtFixedRate(wrap(task), startTime, period);
 	}
 	}
 
 
 	@Override
 	@Override
 	public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, long period) {
 	public ScheduledFuture<?> scheduleAtFixedRate(Runnable task, long period) {
-		return this.taskScheduler.scheduleAtFixedRate(task, period);
+		return this.delegate.scheduleAtFixedRate(wrap(task), period);
 	}
 	}
 
 
 	@Override
 	@Override
 	public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, Date startTime, long delay) {
 	public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, Date startTime, long delay) {
-		return this.taskScheduler.scheduleWithFixedDelay(task, startTime, delay);
+		return this.delegate.scheduleWithFixedDelay(wrap(task), startTime, delay);
 	}
 	}
 
 
 	@Override
 	@Override
 	public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, long delay) {
 	public ScheduledFuture<?> scheduleWithFixedDelay(Runnable task, long delay) {
-		return this.taskScheduler.scheduleWithFixedDelay(task, delay);
+		return this.delegate.scheduleWithFixedDelay(wrap(task), delay);
+	}
+
+	private Runnable wrap(Runnable delegate) {
+		return DelegatingSecurityContextRunnable.create(delegate, this.securityContext);
 	}
 	}
 
 
 }
 }

+ 46 - 4
core/src/test/java/org/springframework/security/scheduling/DelegatingSecurityContextTaskSchedulerTests.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,9 @@
 
 
 package org.springframework.security.scheduling;
 package org.springframework.security.scheduling;
 
 
-import java.time.Duration;
 import java.time.Instant;
 import java.time.Instant;
 import java.util.Date;
 import java.util.Date;
+import java.util.concurrent.ScheduledFuture;
 
 
 import org.junit.After;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Before;
@@ -28,11 +28,16 @@ import org.mockito.MockitoAnnotations;
 
 
 import org.springframework.scheduling.TaskScheduler;
 import org.springframework.scheduling.TaskScheduler;
 import org.springframework.scheduling.Trigger;
 import org.springframework.scheduling.Trigger;
+import org.springframework.scheduling.concurrent.ConcurrentTaskScheduler;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
 
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.isA;
 import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.BDDMockito.willAnswer;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
 
 
 /**
 /**
@@ -47,22 +52,30 @@ public class DelegatingSecurityContextTaskSchedulerTests {
 	@Mock
 	@Mock
 	private TaskScheduler scheduler;
 	private TaskScheduler scheduler;
 
 
+	@Mock
+	private SecurityContext securityContext;
+
 	@Mock
 	@Mock
 	private Runnable runnable;
 	private Runnable runnable;
 
 
 	@Mock
 	@Mock
 	private Trigger trigger;
 	private Trigger trigger;
 
 
+	private SecurityContext originalSecurityContext;
+
 	private DelegatingSecurityContextTaskScheduler delegatingSecurityContextTaskScheduler;
 	private DelegatingSecurityContextTaskScheduler delegatingSecurityContextTaskScheduler;
 
 
 	@Before
 	@Before
 	public void setup() {
 	public void setup() {
 		MockitoAnnotations.initMocks(this);
 		MockitoAnnotations.initMocks(this);
-		this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(this.scheduler);
+		this.originalSecurityContext = SecurityContextHolder.createEmptyContext();
+		this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(this.scheduler,
+				this.securityContext);
 	}
 	}
 
 
 	@After
 	@After
 	public void cleanup() {
 	public void cleanup() {
+		SecurityContextHolder.clearContext();
 		this.delegatingSecurityContextTaskScheduler = null;
 		this.delegatingSecurityContextTaskScheduler = null;
 	}
 	}
 
 
@@ -71,6 +84,36 @@ public class DelegatingSecurityContextTaskSchedulerTests {
 		assertThatIllegalArgumentException().isThrownBy(() -> new DelegatingSecurityContextTaskScheduler(null));
 		assertThatIllegalArgumentException().isThrownBy(() -> new DelegatingSecurityContextTaskScheduler(null));
 	}
 	}
 
 
+	@Test
+	public void testSchedulerCurrentSecurityContext() throws Exception {
+		willAnswer((invocation) -> {
+			assertThat(SecurityContextHolder.getContext()).isEqualTo(this.originalSecurityContext);
+			return null;
+		}).given(this.runnable).run();
+		TaskScheduler delegateTaskScheduler = new ConcurrentTaskScheduler();
+		this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(delegateTaskScheduler);
+		assertWrapped(this.runnable);
+	}
+
+	@Test
+	public void testSchedulerExplicitSecurityContext() throws Exception {
+		willAnswer((invocation) -> {
+			assertThat(SecurityContextHolder.getContext()).isEqualTo(this.securityContext);
+			return null;
+		}).given(this.runnable).run();
+		TaskScheduler delegateTaskScheduler = new ConcurrentTaskScheduler();
+		this.delegatingSecurityContextTaskScheduler = new DelegatingSecurityContextTaskScheduler(delegateTaskScheduler,
+				this.securityContext);
+		assertWrapped(this.runnable);
+	}
+
+	private void assertWrapped(Runnable runnable) throws Exception {
+		ScheduledFuture<?> schedule = this.delegatingSecurityContextTaskScheduler.schedule(runnable, new Date());
+		schedule.get();
+		verify(this.runnable).run();
+		assertThat(SecurityContextHolder.getContext()).isEqualTo(this.originalSecurityContext);
+	}
+
 	@Test
 	@Test
 	public void testSchedulerWithRunnableAndTrigger() {
 	public void testSchedulerWithRunnableAndTrigger() {
 		this.delegatingSecurityContextTaskScheduler.schedule(this.runnable, this.trigger);
 		this.delegatingSecurityContextTaskScheduler.schedule(this.runnable, this.trigger);
@@ -87,7 +130,6 @@ public class DelegatingSecurityContextTaskSchedulerTests {
 	@Test
 	@Test
 	public void testScheduleAtFixedRateWithRunnableAndDate() {
 	public void testScheduleAtFixedRateWithRunnableAndDate() {
 		Date date = new Date(1544751374L);
 		Date date = new Date(1544751374L);
-		Duration duration = Duration.ofSeconds(4L);
 		this.delegatingSecurityContextTaskScheduler.scheduleAtFixedRate(this.runnable, date, 1000L);
 		this.delegatingSecurityContextTaskScheduler.scheduleAtFixedRate(this.runnable, date, 1000L);
 		verify(this.scheduler).scheduleAtFixedRate(isA(Runnable.class), isA(Date.class), eq(1000L));
 		verify(this.scheduler).scheduleAtFixedRate(isA(Runnable.class), isA(Date.class), eq(1000L));
 	}
 	}