浏览代码

Migrate Kotlin tests from java Mockito to Mockk

Closes gh-9785
theexiile1305 4 年之前
父节点
当前提交
3074ad4136
共有 27 个文件被更改,包括 819 次插入509 次删除
  1. 30 23
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt
  2. 30 24
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerFormLoginDslTests.kt
  3. 28 17
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt
  4. 33 22
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerJwtDslTests.kt
  5. 16 12
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerLogoutDslTests.kt
  6. 51 31
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt
  7. 27 10
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt
  8. 23 14
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt
  9. 13 9
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerRequestCacheDslTests.kt
  10. 10 10
      config/src/test/kotlin/org/springframework/security/config/web/server/ServerX509DslTests.kt
  11. 17 13
      config/src/test/kotlin/org/springframework/security/config/web/servlet/CsrfDslTests.kt
  12. 21 17
      config/src/test/kotlin/org/springframework/security/config/web/servlet/HttpBasicDslTests.kt
  13. 9 6
      config/src/test/kotlin/org/springframework/security/config/web/servlet/LogoutDslTests.kt
  14. 33 16
      config/src/test/kotlin/org/springframework/security/config/web/servlet/OAuth2ClientDslTests.kt
  15. 62 30
      config/src/test/kotlin/org/springframework/security/config/web/servlet/OAuth2ResourceServerDslTests.kt
  16. 93 46
      config/src/test/kotlin/org/springframework/security/config/web/servlet/RememberMeDslTests.kt
  17. 12 5
      config/src/test/kotlin/org/springframework/security/config/web/servlet/RequiresChannelDslTests.kt
  18. 52 41
      config/src/test/kotlin/org/springframework/security/config/web/servlet/SessionManagementDslTests.kt
  19. 7 11
      config/src/test/kotlin/org/springframework/security/config/web/servlet/X509DslTests.kt
  20. 48 23
      config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/client/AuthorizationCodeGrantDslTests.kt
  21. 31 11
      config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/login/AuthorizationEndpointDslTests.kt
  22. 29 18
      config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/login/RedirectionEndpointDslTests.kt
  23. 24 15
      config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/login/TokenEndpointDslTests.kt
  24. 29 22
      config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/login/UserInfoEndpointDslTests.kt
  25. 34 23
      config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDslTests.kt
  26. 39 24
      config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/OpaqueTokenDslTests.kt
  27. 18 16
      config/src/test/kotlin/org/springframework/security/config/web/servlet/session/SessionConcurrencyDslTests.kt

+ 30 - 23
config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,12 +16,11 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito
-import org.mockito.Mockito.`when`
-import org.mockito.Mockito.mock
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
@@ -33,6 +32,7 @@ import org.springframework.security.web.server.authorization.ServerAccessDeniedH
 import org.springframework.security.web.server.csrf.CsrfToken
 import org.springframework.security.web.server.csrf.DefaultCsrfToken
 import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
+import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository
 import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher
 import org.springframework.test.web.reactive.server.WebTestClient
 import org.springframework.web.bind.annotation.PostMapping
@@ -161,20 +161,20 @@ class ServerCsrfDslTests {
     @Test
     fun `csrf when custom access denied handler then handler used`() {
         this.spring.register(CustomAccessDeniedHandlerConfig::class.java).autowire()
+        mockkObject(CustomAccessDeniedHandlerConfig.ACCESS_DENIED_HANDLER)
 
         this.client.post()
                 .uri("/")
                 .exchange()
 
-        Mockito.verify(CustomAccessDeniedHandlerConfig.ACCESS_DENIED_HANDLER)
-                .handle(any(), any())
+        verify(exactly = 1) { CustomAccessDeniedHandlerConfig.ACCESS_DENIED_HANDLER.handle(any(), any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomAccessDeniedHandlerConfig {
         companion object {
-            var ACCESS_DENIED_HANDLER: ServerAccessDeniedHandler = mock(ServerAccessDeniedHandler::class.java)
+            val ACCESS_DENIED_HANDLER: ServerAccessDeniedHandler = ServerAccessDeniedHandler { _, _ -> Mono.empty() }
         }
 
         @Bean
@@ -189,23 +189,24 @@ class ServerCsrfDslTests {
 
     @Test
     fun `csrf when custom token repository then repository used`() {
-        `when`(CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY.loadToken(any()))
-                .thenReturn(Mono.just(this.token))
         this.spring.register(CustomCsrfTokenRepositoryConfig::class.java).autowire()
+        mockkObject(CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY)
+        every {
+            CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY.loadToken(any())
+        } returns Mono.just(this.token)
 
         this.client.post()
                 .uri("/")
                 .exchange()
 
-        Mockito.verify(CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY)
-                .loadToken(any())
+        verify(exactly = 1) { CustomCsrfTokenRepositoryConfig.TOKEN_REPOSITORY.loadToken(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomCsrfTokenRepositoryConfig {
         companion object {
-            var TOKEN_REPOSITORY: ServerCsrfTokenRepository = mock(ServerCsrfTokenRepository::class.java)
+            val TOKEN_REPOSITORY: ServerCsrfTokenRepository = WebSessionServerCsrfTokenRepository()
         }
 
         @Bean
@@ -220,11 +221,14 @@ class ServerCsrfDslTests {
 
     @Test
     fun `csrf when multipart form data and not enabled then denied`() {
-        `when`(MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.loadToken(any()))
-                .thenReturn(Mono.just(this.token))
-        `when`(MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.generateToken(any()))
-                .thenReturn(Mono.just(this.token))
         this.spring.register(MultipartFormDataNotEnabledConfig::class.java).autowire()
+        mockkObject(MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY)
+        every {
+            MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.loadToken(any())
+        } returns Mono.just(this.token)
+        every {
+            MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.generateToken(any())
+        } returns Mono.just(this.token)
 
         this.client.post()
                 .uri("/")
@@ -238,7 +242,7 @@ class ServerCsrfDslTests {
     @EnableWebFlux
     open class MultipartFormDataNotEnabledConfig {
         companion object {
-            var TOKEN_REPOSITORY: ServerCsrfTokenRepository = mock(ServerCsrfTokenRepository::class.java)
+            val TOKEN_REPOSITORY: ServerCsrfTokenRepository = WebSessionServerCsrfTokenRepository()
         }
 
         @Bean
@@ -253,11 +257,14 @@ class ServerCsrfDslTests {
 
     @Test
     fun `csrf when multipart form data and enabled then granted`() {
-        `when`(MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.loadToken(any()))
-                .thenReturn(Mono.just(this.token))
-        `when`(MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.generateToken(any()))
-                .thenReturn(Mono.just(this.token))
         this.spring.register(MultipartFormDataEnabledConfig::class.java).autowire()
+        mockkObject(MultipartFormDataEnabledConfig.TOKEN_REPOSITORY)
+        every {
+            MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.loadToken(any())
+        } returns Mono.just(this.token)
+        every {
+            MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.generateToken(any())
+        } returns Mono.just(this.token)
 
         this.client.post()
                 .uri("/")
@@ -271,7 +278,7 @@ class ServerCsrfDslTests {
     @EnableWebFlux
     open class MultipartFormDataEnabledConfig {
         companion object {
-            var TOKEN_REPOSITORY: ServerCsrfTokenRepository = mock(ServerCsrfTokenRepository::class.java)
+            val TOKEN_REPOSITORY: ServerCsrfTokenRepository = WebSessionServerCsrfTokenRepository()
         }
 
         @Bean

+ 30 - 24
config/src/test/kotlin/org/springframework/security/config/web/server/ServerFormLoginDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,12 +16,11 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.assertj.core.api.Assertions.assertThat
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito
-import org.mockito.Mockito.verify
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
@@ -29,22 +28,23 @@ import org.springframework.context.annotation.Configuration
 import org.springframework.http.HttpMethod
 import org.springframework.security.authentication.ReactiveAuthenticationManager
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
+import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.core.userdetails.MapReactiveUserDetailsService
 import org.springframework.security.core.userdetails.User
-import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf
 import org.springframework.security.web.server.SecurityWebFilterChain
 import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint
 import org.springframework.security.web.server.authentication.RedirectServerAuthenticationFailureHandler
 import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler
 import org.springframework.security.web.server.context.ServerSecurityContextRepository
+import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers
 import org.springframework.test.web.reactive.server.FluxExchangeResult
 import org.springframework.test.web.reactive.server.WebTestClient
 import org.springframework.util.LinkedMultiValueMap
-import org.springframework.util.MultiValueMap
 import org.springframework.web.reactive.config.EnableWebFlux
 import org.springframework.web.reactive.function.BodyInserters
+import reactor.core.publisher.Mono
 
 /**
  * Tests for [ServerFormLoginDsl]
@@ -129,9 +129,11 @@ class ServerFormLoginDslTests {
     @Test
     fun `form login when custom authentication manager then manager used`() {
         this.spring.register(CustomAuthenticationManagerConfig::class.java).autowire()
-        val data: MultiValueMap<String, String> = LinkedMultiValueMap()
-        data.add("username", "user")
-        data.add("password", "password")
+        mockkObject(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER)
+        val data = LinkedMultiValueMap<String, String>().apply {
+            add("username", "user")
+            add("password", "password")
+        }
 
         this.client
                 .mutateWith(csrf())
@@ -140,15 +142,15 @@ class ServerFormLoginDslTests {
                 .body(BodyInserters.fromFormData(data))
                 .exchange()
 
-        verify<ReactiveAuthenticationManager>(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER)
-                .authenticate(any())
+        verify(exactly = 1) { CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomAuthenticationManagerConfig {
+
         companion object {
-            var AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = Mockito.mock(ReactiveAuthenticationManager::class.java)
+            val AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = ReactiveAuthenticationManager { Mono.empty() }
         }
 
         @Bean
@@ -182,9 +184,10 @@ class ServerFormLoginDslTests {
     @Test
     fun `form login when custom requires authentication matcher then matching request logs in`() {
         this.spring.register(CustomConfig::class.java, UserDetailsConfig::class.java).autowire()
-        val data: MultiValueMap<String, String> = LinkedMultiValueMap()
-        data.add("username", "user")
-        data.add("password", "password")
+        val data = LinkedMultiValueMap<String, String>().apply {
+            add("username", "user")
+            add("password", "password")
+        }
 
         val result = this.client
                 .mutateWith(csrf())
@@ -238,9 +241,10 @@ class ServerFormLoginDslTests {
     @Test
     fun `login when custom success handler then success handler used`() {
         this.spring.register(CustomSuccessHandlerConfig::class.java, UserDetailsConfig::class.java).autowire()
-        val data: MultiValueMap<String, String> = LinkedMultiValueMap()
-        data.add("username", "user")
-        data.add("password", "password")
+        val data = LinkedMultiValueMap<String, String>().apply {
+            add("username", "user")
+            add("password", "password")
+        }
 
         val result = this.client
                 .mutateWith(csrf())
@@ -275,9 +279,11 @@ class ServerFormLoginDslTests {
     @Test
     fun `form login when custom security context repository then repository used`() {
         this.spring.register(CustomSecurityContextRepositoryConfig::class.java, UserDetailsConfig::class.java).autowire()
-        val data: MultiValueMap<String, String> = LinkedMultiValueMap()
-        data.add("username", "user")
-        data.add("password", "password")
+        mockkObject(CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY)
+        val data = LinkedMultiValueMap<String, String>().apply {
+            add("username", "user")
+            add("password", "password")
+        }
 
         this.client
                 .mutateWith(csrf())
@@ -286,15 +292,15 @@ class ServerFormLoginDslTests {
                 .body(BodyInserters.fromFormData(data))
                 .exchange()
 
-        verify<ServerSecurityContextRepository>(CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY)
-                .save(Mockito.any(), Mockito.any())
+        verify(exactly = 1) { CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY.save(any(), any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomSecurityContextRepositoryConfig {
+
         companion object {
-            var SECURITY_CONTEXT_REPOSITORY: ServerSecurityContextRepository = Mockito.mock(ServerSecurityContextRepository::class.java)
+            val SECURITY_CONTEXT_REPOSITORY: ServerSecurityContextRepository = WebSessionServerSecurityContextRepository()
         }
 
         @Bean

+ 28 - 17
config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,10 +16,12 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
+import java.util.Base64
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.BDDMockito.given
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
@@ -27,19 +29,19 @@ import org.springframework.context.annotation.Configuration
 import org.springframework.security.authentication.ReactiveAuthenticationManager
 import org.springframework.security.authentication.TestingAuthenticationToken
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
+import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.core.Authentication
 import org.springframework.security.core.userdetails.MapReactiveUserDetailsService
 import org.springframework.security.core.userdetails.User
-import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.web.server.SecurityWebFilterChain
 import org.springframework.security.web.server.ServerAuthenticationEntryPoint
 import org.springframework.security.web.server.context.ServerSecurityContextRepository
+import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository
 import org.springframework.test.web.reactive.server.WebTestClient
 import org.springframework.web.bind.annotation.RequestMapping
 import org.springframework.web.bind.annotation.RestController
 import org.springframework.web.reactive.config.EnableWebFlux
 import reactor.core.publisher.Mono
-import java.util.*
 
 /**
  * Tests for [ServerHttpBasicDsl]
@@ -105,25 +107,26 @@ class ServerHttpBasicDslTests {
 
     @Test
     fun `http basic when custom authentication manager then manager used`() {
-        given<Mono<Authentication>>(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()))
-                .willReturn(Mono.just<Authentication>(TestingAuthenticationToken("user", "password", "ROLE_USER")))
-
         this.spring.register(CustomAuthenticationManagerConfig::class.java).autowire()
+        mockkObject(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER)
+        every {
+            CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any())
+        } returns Mono.just<Authentication>(TestingAuthenticationToken("user", "password", "ROLE_USER"))
 
         this.client.get()
                 .uri("/")
                 .header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:password".toByteArray()))
                 .exchange()
 
-        verify<ReactiveAuthenticationManager>(CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER)
-                .authenticate(any())
+        verify(exactly = 1) { CustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomAuthenticationManagerConfig {
+
         companion object {
-            var AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = mock(ReactiveAuthenticationManager::class.java)
+            val AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = ReactiveAuthenticationManager { Mono.empty() }
         }
 
         @Bean
@@ -142,21 +145,25 @@ class ServerHttpBasicDslTests {
     @Test
     fun `http basic when custom security context repository then repository used`() {
         this.spring.register(CustomSecurityContextRepositoryConfig::class.java, UserDetailsConfig::class.java).autowire()
+        mockkObject(CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY)
+        every {
+            CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY.save(any(), any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri("/")
                 .header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:password".toByteArray()))
                 .exchange()
 
-        verify<ServerSecurityContextRepository>(CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY)
-                .save(any(), any())
+        verify(exactly = 1) { CustomSecurityContextRepositoryConfig.SECURITY_CONTEXT_REPOSITORY.save(any(), any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomSecurityContextRepositoryConfig {
+
         companion object {
-            var SECURITY_CONTEXT_REPOSITORY: ServerSecurityContextRepository = mock(ServerSecurityContextRepository::class.java)
+            val SECURITY_CONTEXT_REPOSITORY: ServerSecurityContextRepository = WebSessionServerSecurityContextRepository()
         }
 
         @Bean
@@ -175,20 +182,24 @@ class ServerHttpBasicDslTests {
     @Test
     fun `http basic when custom authentication entry point then entry point used`() {
         this.spring.register(CustomAuthenticationEntryPointConfig::class.java, UserDetailsConfig::class.java).autowire()
+        mockkObject(CustomAuthenticationEntryPointConfig.ENTRY_POINT)
+        every {
+            CustomAuthenticationEntryPointConfig.ENTRY_POINT.commence(any(), any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri("/")
                 .exchange()
 
-        verify<ServerAuthenticationEntryPoint>(CustomAuthenticationEntryPointConfig.ENTRY_POINT)
-                .commence(any(), any())
+        verify(exactly = 1) { CustomAuthenticationEntryPointConfig.ENTRY_POINT.commence(any(), any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomAuthenticationEntryPointConfig {
+
         companion object {
-            var ENTRY_POINT: ServerAuthenticationEntryPoint = mock(ServerAuthenticationEntryPoint::class.java)
+            val ENTRY_POINT: ServerAuthenticationEntryPoint = ServerAuthenticationEntryPoint { _, _ -> Mono.empty() }
         }
 
         @Bean

+ 33 - 22
config/src/test/kotlin/org/springframework/security/config/web/server/ServerJwtDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,12 +16,19 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
+import java.math.BigInteger
+import java.security.KeyFactory
+import java.security.interfaces.RSAPublicKey
+import java.security.spec.RSAPublicKeySpec
+import javax.annotation.PreDestroy
 import okhttp3.mockwebserver.MockResponse
 import okhttp3.mockwebserver.MockWebServer
 import org.assertj.core.api.Assertions.assertThat
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
@@ -40,11 +47,6 @@ import org.springframework.web.bind.annotation.GetMapping
 import org.springframework.web.bind.annotation.RestController
 import org.springframework.web.reactive.config.EnableWebFlux
 import reactor.core.publisher.Mono
-import java.math.BigInteger
-import java.security.KeyFactory
-import java.security.interfaces.RSAPublicKey
-import java.security.spec.RSAPublicKeySpec
-import javax.annotation.PreDestroy
 
 /**
  * Tests for [ServerJwtDsl]
@@ -125,20 +127,25 @@ class ServerJwtDslTests {
     @Test
     fun `jwt when using custom JWT decoded then custom decoded used`() {
         this.spring.register(CustomDecoderConfig::class.java).autowire()
+        mockkObject(CustomDecoderConfig.JWT_DECODER)
+        every {
+            CustomDecoderConfig.JWT_DECODER.decode("token")
+        } returns Mono.empty()
 
         this.client.get()
                 .uri("/")
                 .headers { headers: HttpHeaders -> headers.setBearerAuth("token") }
                 .exchange()
 
-        verify(CustomDecoderConfig.JWT_DECODER).decode("token")
+        verify(exactly = 1) { CustomDecoderConfig.JWT_DECODER.decode("token") }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomDecoderConfig {
+
         companion object {
-            var JWT_DECODER: ReactiveJwtDecoder = mock(ReactiveJwtDecoder::class.java)
+            val JWT_DECODER: ReactiveJwtDecoder = ReactiveJwtDecoder { Mono.empty() }
         }
 
         @Bean
@@ -174,6 +181,7 @@ class ServerJwtDslTests {
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomJwkSetUriConfig {
+
         companion object {
             var MOCK_WEB_SERVER: MockWebServer = MockWebServer()
         }
@@ -207,28 +215,33 @@ class ServerJwtDslTests {
     @Test
     fun `opaque token when custom JWT authentication converter then converter used`() {
         this.spring.register(CustomJwtAuthenticationConverterConfig::class.java).autowire()
-        `when`(CustomJwtAuthenticationConverterConfig.DECODER.decode(anyString())).thenReturn(
-                Mono.just(Jwt.withTokenValue("token")
-                        .header("alg", "none")
-                        .claim(IdTokenClaimNames.SUB, "user")
-                        .build()))
-        `when`(CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any()))
-                .thenReturn(Mono.just(TestingAuthenticationToken("test", "this", "ROLE")))
+        mockkObject(CustomJwtAuthenticationConverterConfig.CONVERTER)
+        mockkObject(CustomJwtAuthenticationConverterConfig.DECODER)
+        every {
+            CustomJwtAuthenticationConverterConfig.DECODER.decode(any())
+        } returns Mono.just(Jwt.withTokenValue("token")
+            .header("alg", "none")
+            .claim(IdTokenClaimNames.SUB, "user")
+            .build())
+        every {
+            CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any())
+        } returns Mono.just(TestingAuthenticationToken("test", "this", "ROLE"))
 
         this.client.get()
                 .uri("/")
                 .headers { headers: HttpHeaders -> headers.setBearerAuth("token") }
                 .exchange()
 
-        verify(CustomJwtAuthenticationConverterConfig.CONVERTER).convert(any())
+        verify(exactly = 1) { CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomJwtAuthenticationConverterConfig {
+
         companion object {
-            var CONVERTER: Converter<Jwt, out Mono<AbstractAuthenticationToken>> = mock(Converter::class.java) as Converter<Jwt, out Mono<AbstractAuthenticationToken>>
-            var DECODER: ReactiveJwtDecoder = mock(ReactiveJwtDecoder::class.java)
+            val CONVERTER: Converter<Jwt, out Mono<AbstractAuthenticationToken>> = Converter { Mono.empty() }
+            val DECODER: ReactiveJwtDecoder = ReactiveJwtDecoder { Mono.empty() }
         }
 
         @Bean
@@ -246,9 +259,7 @@ class ServerJwtDslTests {
         }
 
         @Bean
-        open fun jwtDecoder(): ReactiveJwtDecoder {
-            return DECODER
-        }
+        open fun jwtDecoder(): ReactiveJwtDecoder = DECODER
     }
 
     @RestController

+ 16 - 12
config/src/test/kotlin/org/springframework/security/config/web/server/ServerLogoutDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,11 +16,12 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.assertj.core.api.Assertions.assertThat
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
@@ -152,9 +153,8 @@ class ServerLogoutDslTests {
     @Test
     fun `logout when custom logout handler then custom handler invoked`() {
         this.spring.register(CustomLogoutHandlerConfig::class.java).autowire()
-
-        `when`(CustomLogoutHandlerConfig.LOGOUT_HANDLER.logout(any(), any()))
-                .thenReturn(Mono.empty())
+        mockkObject(CustomLogoutHandlerConfig.LOGOUT_HANDLER)
+        every { CustomLogoutHandlerConfig.LOGOUT_HANDLER.logout(any(), any()) } returns Mono.empty()
 
         this.client
                 .mutateWith(csrf())
@@ -162,15 +162,15 @@ class ServerLogoutDslTests {
                 .uri("/logout")
                 .exchange()
 
-        verify<ServerLogoutHandler>(CustomLogoutHandlerConfig.LOGOUT_HANDLER)
-                .logout(any(), any())
+        verify(exactly = 1) { CustomLogoutHandlerConfig.LOGOUT_HANDLER.logout(any(), any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomLogoutHandlerConfig {
+
         companion object {
-            var LOGOUT_HANDLER: ServerLogoutHandler = mock(ServerLogoutHandler::class.java)
+            val LOGOUT_HANDLER: ServerLogoutHandler = ServerLogoutHandler { _, _ -> Mono.empty() }
         }
 
         @Bean
@@ -186,6 +186,10 @@ class ServerLogoutDslTests {
     @Test
     fun `logout when custom logout success handler then custom handler invoked`() {
         this.spring.register(CustomLogoutSuccessHandlerConfig::class.java).autowire()
+        mockkObject(CustomLogoutSuccessHandlerConfig.LOGOUT_HANDLER)
+        every {
+            CustomLogoutSuccessHandlerConfig.LOGOUT_HANDLER.onLogoutSuccess(any(), any())
+        } returns Mono.empty()
 
         this.client
                 .mutateWith(csrf())
@@ -193,15 +197,15 @@ class ServerLogoutDslTests {
                 .uri("/logout")
                 .exchange()
 
-        verify<ServerLogoutSuccessHandler>(CustomLogoutSuccessHandlerConfig.LOGOUT_HANDLER)
-                .onLogoutSuccess(any(), any())
+        verify(exactly = 1) { CustomLogoutSuccessHandlerConfig.LOGOUT_HANDLER.onLogoutSuccess(any(), any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class CustomLogoutSuccessHandlerConfig {
+
         companion object {
-            var LOGOUT_HANDLER: ServerLogoutSuccessHandler = mock(ServerLogoutSuccessHandler::class.java)
+            val LOGOUT_HANDLER: ServerLogoutSuccessHandler = ServerLogoutSuccessHandler { _, _ -> Mono.empty() }
         }
 
         @Bean

+ 51 - 31
config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,10 +16,11 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
@@ -32,6 +33,7 @@ import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository
 import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository
+import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames
 import org.springframework.security.web.server.SecurityWebFilterChain
@@ -88,6 +90,10 @@ class ServerOAuth2ClientDslTests {
     @Test
     fun `OAuth2 client when authorization request repository configured then custom repository used`() {
         this.spring.register(AuthorizationRequestRepositoryConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY)
+        every {
+            AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri {
@@ -98,15 +104,17 @@ class ServerOAuth2ClientDslTests {
                 }
                 .exchange()
 
-        verify(AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY).loadAuthorizationRequest(any())
+        verify(exactly = 1) {
+            AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())
+        }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class AuthorizationRequestRepositoryConfig {
+
         companion object {
-            var AUTHORIZATION_REQUEST_REPOSITORY = mock(ServerAuthorizationRequestRepository::class.java)
-                    as ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest>
+            val AUTHORIZATION_REQUEST_REPOSITORY : ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> = WebSessionOAuth2ServerAuthorizationRequestRepository()
         }
 
         @Bean
@@ -122,13 +130,18 @@ class ServerOAuth2ClientDslTests {
     @Test
     fun `OAuth2 client when authentication converter configured then custom converter used`() {
         this.spring.register(AuthenticationConverterConfig::class.java, ClientConfig::class.java).autowire()
-
-        `when`(AuthenticationConverterConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any()))
-                .thenReturn(Mono.just(OAuth2AuthorizationRequest.authorizationCode()
-                        .authorizationUri("https://example.com/login/oauth/authorize")
-                        .clientId("clientId")
-                        .redirectUri("/authorize/oauth2/code/google")
-                        .build()))
+        mockkObject(AuthenticationConverterConfig.AUTHORIZATION_REQUEST_REPOSITORY)
+        mockkObject(AuthenticationConverterConfig.AUTHENTICATION_CONVERTER)
+        every {
+            AuthenticationConverterConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())
+        } returns Mono.just(OAuth2AuthorizationRequest.authorizationCode()
+            .authorizationUri("https://example.com/login/oauth/authorize")
+            .clientId("clientId")
+            .redirectUri("/authorize/oauth2/code/google")
+            .build())
+        every {
+            AuthenticationConverterConfig.AUTHENTICATION_CONVERTER.convert(any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri {
@@ -139,16 +152,16 @@ class ServerOAuth2ClientDslTests {
                 }
                 .exchange()
 
-        verify(AuthenticationConverterConfig.AUTHENTICATION_CONVERTER).convert(any())
+        verify(exactly = 1) { AuthenticationConverterConfig.AUTHENTICATION_CONVERTER.convert(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class AuthenticationConverterConfig {
+
         companion object {
-            var AUTHORIZATION_REQUEST_REPOSITORY = mock(ServerAuthorizationRequestRepository::class.java)
-                    as ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest>
-            var AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = mock(ServerAuthenticationConverter::class.java)
+            val AUTHORIZATION_REQUEST_REPOSITORY: ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> = WebSessionOAuth2ServerAuthorizationRequestRepository()
+            val AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = ServerAuthenticationConverter { Mono.empty() }
         }
 
         @Bean
@@ -165,15 +178,22 @@ class ServerOAuth2ClientDslTests {
     @Test
     fun `OAuth2 client when authentication manager configured then custom manager used`() {
         this.spring.register(AuthenticationManagerConfig::class.java, ClientConfig::class.java).autowire()
-
-        `when`(AuthenticationManagerConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any()))
-                .thenReturn(Mono.just(OAuth2AuthorizationRequest.authorizationCode()
-                        .authorizationUri("https://example.com/login/oauth/authorize")
-                        .clientId("clientId")
-                        .redirectUri("/authorize/oauth2/code/google")
-                        .build()))
-        `when`(AuthenticationManagerConfig.AUTHENTICATION_CONVERTER.convert(any()))
-                .thenReturn(Mono.just(TestingAuthenticationToken("a", "b", "c")))
+        mockkObject(AuthenticationManagerConfig.AUTHORIZATION_REQUEST_REPOSITORY)
+        mockkObject(AuthenticationManagerConfig.AUTHENTICATION_CONVERTER)
+        mockkObject(AuthenticationManagerConfig.AUTHENTICATION_MANAGER)
+        every {
+            AuthenticationManagerConfig.AUTHORIZATION_REQUEST_REPOSITORY.loadAuthorizationRequest(any())
+        } returns Mono.just(OAuth2AuthorizationRequest.authorizationCode()
+            .authorizationUri("https://example.com/login/oauth/authorize")
+            .clientId("clientId")
+            .redirectUri("/authorize/oauth2/code/google")
+            .build())
+        every {
+            AuthenticationManagerConfig.AUTHENTICATION_CONVERTER.convert(any())
+        } returns Mono.just(TestingAuthenticationToken("a", "b", "c"))
+        every {
+            AuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri {
@@ -184,17 +204,17 @@ class ServerOAuth2ClientDslTests {
                 }
                 .exchange()
 
-        verify(AuthenticationManagerConfig.AUTHENTICATION_MANAGER).authenticate(any())
+        verify(exactly = 1) { AuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class AuthenticationManagerConfig {
+
         companion object {
-            var AUTHORIZATION_REQUEST_REPOSITORY = mock(ServerAuthorizationRequestRepository::class.java)
-                    as ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest>
-            var AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = mock(ServerAuthenticationConverter::class.java)
-            var AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = mock(ReactiveAuthenticationManager::class.java)
+            val AUTHORIZATION_REQUEST_REPOSITORY: ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> = WebSessionOAuth2ServerAuthorizationRequestRepository()
+            val AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = ServerAuthenticationConverter { Mono.empty() }
+            val AUTHENTICATION_MANAGER: ReactiveAuthenticationManager = ReactiveAuthenticationManager { Mono.empty() }
         }
 
         @Bean

+ 27 - 10
config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,11 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
@@ -29,12 +31,14 @@ import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository
 import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository
+import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
 import org.springframework.security.web.server.SecurityWebFilterChain
 import org.springframework.security.web.server.authentication.ServerAuthenticationConverter
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher
 import org.springframework.test.web.reactive.server.WebTestClient
 import org.springframework.web.reactive.config.EnableWebFlux
+import reactor.core.publisher.Mono
 
 /**
  * Tests for [ServerOAuth2LoginDsl]
@@ -105,20 +109,23 @@ class ServerOAuth2LoginDslTests {
     @Test
     fun `OAuth2 login when authorization request repository configured then custom repository used`() {
         this.spring.register(AuthorizationRequestRepositoryConfig::class.java, ClientConfig::class.java).autowire()
-
+        mockkObject(AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY)
+        every {
+            AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY.removeAuthorizationRequest(any())
+        } returns Mono.empty()
         this.client.get()
                 .uri("/login/oauth2/code/google")
                 .exchange()
 
-        verify(AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY).removeAuthorizationRequest(any())
+        verify(exactly = 1) { AuthorizationRequestRepositoryConfig.AUTHORIZATION_REQUEST_REPOSITORY.removeAuthorizationRequest(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class AuthorizationRequestRepositoryConfig {
+
         companion object {
-            var AUTHORIZATION_REQUEST_REPOSITORY = mock(ServerAuthorizationRequestRepository::class.java)
-                    as ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest>
+            val AUTHORIZATION_REQUEST_REPOSITORY: ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> = WebSessionOAuth2ServerAuthorizationRequestRepository()
         }
 
         @Bean
@@ -134,19 +141,24 @@ class ServerOAuth2LoginDslTests {
     @Test
     fun `OAuth2 login when authentication matcher configured then custom matcher used`() {
         this.spring.register(AuthenticationMatcherConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(AuthenticationMatcherConfig.AUTHENTICATION_MATCHER)
+        every {
+            AuthenticationMatcherConfig.AUTHENTICATION_MATCHER.matches(any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri("/")
                 .exchange()
 
-        verify(AuthenticationMatcherConfig.AUTHENTICATION_MATCHER).matches(any())
+        verify(exactly = 1) { AuthenticationMatcherConfig.AUTHENTICATION_MATCHER.matches(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class AuthenticationMatcherConfig {
+
         companion object {
-            var AUTHENTICATION_MATCHER: ServerWebExchangeMatcher = mock(ServerWebExchangeMatcher::class.java)
+            val AUTHENTICATION_MATCHER: ServerWebExchangeMatcher = ServerWebExchangeMatcher { Mono.empty() }
         }
 
         @Bean
@@ -162,19 +174,24 @@ class ServerOAuth2LoginDslTests {
     @Test
     fun `OAuth2 login when authentication converter configured then custom converter used`() {
         this.spring.register(AuthenticationConverterConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(AuthenticationConverterConfig.AUTHENTICATION_CONVERTER)
+        every {
+            AuthenticationConverterConfig.AUTHENTICATION_CONVERTER.convert(any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri("/login/oauth2/code/google")
                 .exchange()
 
-        verify(AuthenticationConverterConfig.AUTHENTICATION_CONVERTER).convert(any())
+        verify(exactly = 1) { AuthenticationConverterConfig.AUTHENTICATION_CONVERTER.convert(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class AuthenticationConverterConfig {
+
         companion object {
-            var AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = mock(ServerAuthenticationConverter::class.java)
+            val AUTHENTICATION_CONVERTER: ServerAuthenticationConverter = ServerAuthenticationConverter { Mono.empty() }
         }
 
         @Bean

+ 23 - 14
config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,16 +16,19 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
+import java.math.BigInteger
+import java.security.KeyFactory
+import java.security.interfaces.RSAPublicKey
+import java.security.spec.RSAPublicKeySpec
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.mock
-import org.mockito.Mockito.verify
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
 import org.springframework.http.HttpStatus
-import org.springframework.http.server.reactive.ServerHttpRequest
 import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
 import org.springframework.security.config.test.SpringTestRule
@@ -36,10 +39,7 @@ import org.springframework.security.web.server.authorization.HttpStatusServerAcc
 import org.springframework.test.web.reactive.server.WebTestClient
 import org.springframework.web.reactive.config.EnableWebFlux
 import org.springframework.web.server.ServerWebExchange
-import java.math.BigInteger
-import java.security.KeyFactory
-import java.security.interfaces.RSAPublicKey
-import java.security.spec.RSAPublicKeySpec
+import reactor.core.publisher.Mono
 
 /**
  * Tests for [ServerOAuth2ResourceServerDsl]
@@ -127,20 +127,25 @@ class ServerOAuth2ResourceServerDslTests {
     @Test
     fun `request when custom bearer token converter configured then custom converter used`() {
         this.spring.register(BearerTokenConverterConfig::class.java).autowire()
+        mockkObject(BearerTokenConverterConfig.CONVERTER)
+        every {
+            BearerTokenConverterConfig.CONVERTER.convert(any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri("/")
                 .headers { it.setBearerAuth(validJwt) }
                 .exchange()
 
-        verify(BearerTokenConverterConfig.CONVERTER).convert(any())
+        verify(exactly = 1) { BearerTokenConverterConfig.CONVERTER.convert(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class BearerTokenConverterConfig {
+
         companion object {
-            val CONVERTER: ServerBearerTokenAuthenticationConverter = mock(ServerBearerTokenAuthenticationConverter::class.java)
+            val CONVERTER: ServerBearerTokenAuthenticationConverter = ServerBearerTokenAuthenticationConverter()
         }
 
         @Bean
@@ -162,21 +167,25 @@ class ServerOAuth2ResourceServerDslTests {
     @Test
     fun `request when custom authentication manager resolver configured then custom resolver used`() {
         this.spring.register(AuthenticationManagerResolverConfig::class.java).autowire()
+        mockkObject(AuthenticationManagerResolverConfig.RESOLVER)
+        every {
+            AuthenticationManagerResolverConfig.RESOLVER.resolve(any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri("/")
                 .headers { it.setBearerAuth(validJwt) }
                 .exchange()
 
-        verify(AuthenticationManagerResolverConfig.RESOLVER).resolve(any())
+        verify(exactly = 1) { AuthenticationManagerResolverConfig.RESOLVER.resolve(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class AuthenticationManagerResolverConfig {
+
         companion object {
-            val RESOLVER: ReactiveAuthenticationManagerResolver<ServerWebExchange> =
-                    mock(ReactiveAuthenticationManagerResolver::class.java) as ReactiveAuthenticationManagerResolver<ServerWebExchange>
+            val RESOLVER: ReactiveAuthenticationManagerResolver<ServerWebExchange> = ReactiveAuthenticationManagerResolver { Mono.empty() }
         }
 
         @Bean

+ 13 - 9
config/src/test/kotlin/org/springframework/security/config/web/server/ServerRequestCacheDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,22 +16,22 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito
-import org.mockito.Mockito.`when`
-import org.mockito.Mockito.verify
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
+import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.core.userdetails.MapReactiveUserDetailsService
 import org.springframework.security.core.userdetails.User
-import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.web.server.SecurityWebFilterChain
 import org.springframework.security.web.server.savedrequest.ServerRequestCache
+import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache
 import org.springframework.test.web.reactive.server.WebTestClient
 import org.springframework.web.reactive.config.EnableWebFlux
 import reactor.core.publisher.Mono
@@ -59,20 +59,24 @@ class ServerRequestCacheDslTests {
     @Test
     fun `GET when request cache enabled then redirected to cached page`() {
         this.spring.register(RequestCacheConfig::class.java, UserDetailsConfig::class.java).autowire()
-        `when`(RequestCacheConfig.REQUEST_CACHE.removeMatchingRequest(any())).thenReturn(Mono.empty())
+        mockkObject(RequestCacheConfig.REQUEST_CACHE)
+        every {
+            RequestCacheConfig.REQUEST_CACHE.removeMatchingRequest(any())
+        } returns Mono.empty()
 
         this.client.get()
                 .uri("/test")
                 .exchange()
 
-        verify(RequestCacheConfig.REQUEST_CACHE).saveRequest(any())
+        verify(exactly = 1) { RequestCacheConfig.REQUEST_CACHE.saveRequest(any()) }
     }
 
     @EnableWebFluxSecurity
     @EnableWebFlux
     open class RequestCacheConfig {
+
         companion object {
-            var REQUEST_CACHE: ServerRequestCache = Mockito.mock(ServerRequestCache::class.java)
+            val REQUEST_CACHE: ServerRequestCache = WebSessionServerRequestCache()
         }
 
         @Bean

+ 10 - 10
config/src/test/kotlin/org/springframework/security/config/web/server/ServerX509DslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,10 +16,13 @@
 
 package org.springframework.security.config.web.server
 
+import io.mockk.every
+import io.mockk.mockk
+import java.security.cert.Certificate
+import java.security.cert.CertificateFactory
+import java.security.cert.X509Certificate
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.`when`
-import org.mockito.Mockito.mock
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.ApplicationContext
 import org.springframework.context.annotation.Bean
@@ -30,10 +33,10 @@ import org.springframework.http.server.reactive.ServerHttpRequestDecorator
 import org.springframework.http.server.reactive.SslInfo
 import org.springframework.lang.Nullable
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
+import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.core.annotation.AuthenticationPrincipal
 import org.springframework.security.core.userdetails.MapReactiveUserDetailsService
 import org.springframework.security.core.userdetails.User
-import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor
 import org.springframework.security.web.server.SecurityWebFilterChain
 import org.springframework.security.web.server.authentication.ReactivePreAuthenticatedAuthenticationManager
@@ -50,9 +53,6 @@ import org.springframework.web.server.WebFilter
 import org.springframework.web.server.WebFilterChain
 import org.springframework.web.server.adapter.WebHttpHandlerBuilder
 import reactor.core.publisher.Mono
-import java.security.cert.Certificate
-import java.security.cert.CertificateFactory
-import java.security.cert.X509Certificate
 
 /**
  * Tests for [ServerX509Dsl]
@@ -214,9 +214,9 @@ class ServerX509DslTests {
         private fun decorate(exchange: ServerWebExchange): ServerWebExchange {
             val decorated: ServerHttpRequestDecorator = object : ServerHttpRequestDecorator(exchange.request) {
                 override fun getSslInfo(): SslInfo {
-                    val sslInfo = mock(SslInfo::class.java)
-                    `when`(sslInfo.sessionId).thenReturn("sessionId")
-                    `when`(sslInfo.peerCertificates).thenReturn(arrayOf(certificate))
+                    val sslInfo: SslInfo = mockk()
+                    every { sslInfo.sessionId } returns "sessionId"
+                    every { sslInfo.peerCertificates } returns arrayOf(certificate)
                     return sslInfo
                 }
             }

+ 17 - 13
config/src/test/kotlin/org/springframework/security/config/web/servlet/CsrfDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,16 +16,17 @@
 
 package org.springframework.security.config.web.servlet
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.security.config.annotation.web.builders.HttpSecurity
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
 import org.springframework.security.config.test.SpringTestRule
-import org.springframework.security.core.Authentication
 import org.springframework.security.core.userdetails.User
 import org.springframework.security.core.userdetails.UserDetailsService
 import org.springframework.security.provisioning.InMemoryUserDetailsManager
@@ -34,14 +35,13 @@ import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequ
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy
 import org.springframework.security.web.csrf.CsrfTokenRepository
 import org.springframework.security.web.csrf.DefaultCsrfToken
+import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher
 import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.get
 import org.springframework.test.web.servlet.post
 import org.springframework.web.bind.annotation.PostMapping
 import org.springframework.web.bind.annotation.RestController
-import javax.servlet.http.HttpServletRequest
-import javax.servlet.http.HttpServletResponse
 
 /**
  * Tests for [CsrfDsl]
@@ -110,20 +110,22 @@ class CsrfDslTests {
 
     @Test
     fun `CSRF when custom CSRF token repository then repo used`() {
-        `when`(CustomRepositoryConfig.REPO.loadToken(any<HttpServletRequest>()))
-                .thenReturn(DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"))
-
         this.spring.register(CustomRepositoryConfig::class.java).autowire()
+        mockkObject(CustomRepositoryConfig.REPO)
+        every {
+            CustomRepositoryConfig.REPO.loadToken(any())
+        } returns DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")
 
         this.mockMvc.get("/test1")
 
-        verify(CustomRepositoryConfig.REPO).loadToken(any<HttpServletRequest>())
+        verify(exactly = 1) { CustomRepositoryConfig.REPO.loadToken(any()) }
     }
 
     @EnableWebSecurity
     open class CustomRepositoryConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var REPO: CsrfTokenRepository = mock(CsrfTokenRepository::class.java)
+            val REPO: CsrfTokenRepository = HttpSessionCsrfTokenRepository()
         }
 
         override fun configure(http: HttpSecurity) {
@@ -164,18 +166,20 @@ class CsrfDslTests {
     @Test
     fun `CSRF when custom session authentication strategy then strategy used`() {
         this.spring.register(CustomStrategyConfig::class.java).autowire()
+        mockkObject(CustomStrategyConfig.STRATEGY)
+        every { CustomStrategyConfig.STRATEGY.onAuthentication(any(), any(), any()) } returns Unit
 
         this.mockMvc.perform(formLogin())
 
-        verify(CustomStrategyConfig.STRATEGY, atLeastOnce())
-                .onAuthentication(any(Authentication::class.java), any(HttpServletRequest::class.java), any(HttpServletResponse::class.java))
+        verify(exactly = 1) { CustomStrategyConfig.STRATEGY.onAuthentication(any(), any(), any()) }
 
     }
 
     @EnableWebSecurity
     open class CustomStrategyConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var STRATEGY: SessionAuthenticationStrategy = mock(SessionAuthenticationStrategy::class.java)
+            val STRATEGY: SessionAuthenticationStrategy = SessionAuthenticationStrategy { _, _, _ -> }
         }
 
         override fun configure(http: HttpSecurity) {

+ 21 - 17
config/src/test/kotlin/org/springframework/security/config/web/servlet/HttpBasicDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,11 +16,12 @@
 
 package org.springframework.security.config.web.servlet
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
+import javax.servlet.http.HttpServletRequest
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.mock
-import org.mockito.Mockito.verify
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
@@ -29,7 +30,6 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
 import org.springframework.security.config.test.SpringTestRule
-import org.springframework.security.core.AuthenticationException
 import org.springframework.security.core.userdetails.User
 import org.springframework.security.core.userdetails.UserDetailsService
 import org.springframework.security.provisioning.InMemoryUserDetailsManager
@@ -39,8 +39,6 @@ import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.get
 import org.springframework.web.bind.annotation.GetMapping
 import org.springframework.web.bind.annotation.RestController
-import javax.servlet.http.HttpServletRequest
-import javax.servlet.http.HttpServletResponse
 
 /**
  * Tests for [HttpBasicDsl]
@@ -125,19 +123,19 @@ class HttpBasicDslTests {
     @Test
     fun `http basic when custom authentication entry point then used`() {
         this.spring.register(CustomAuthenticationEntryPointConfig::class.java).autowire()
+        mockkObject(CustomAuthenticationEntryPointConfig.ENTRY_POINT)
+        every { CustomAuthenticationEntryPointConfig.ENTRY_POINT.commence(any(), any(), any()) } returns Unit
 
         this.mockMvc.get("/")
 
-        verify<AuthenticationEntryPoint>(CustomAuthenticationEntryPointConfig.ENTRY_POINT)
-                .commence(any(HttpServletRequest::class.java),
-                        any(HttpServletResponse::class.java),
-                        any(AuthenticationException::class.java))
+        verify(exactly = 1) { CustomAuthenticationEntryPointConfig.ENTRY_POINT.commence(any(), any(), any()) }
     }
 
     @EnableWebSecurity
     open class CustomAuthenticationEntryPointConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var ENTRY_POINT: AuthenticationEntryPoint = mock(AuthenticationEntryPoint::class.java)
+            val ENTRY_POINT: AuthenticationEntryPoint = AuthenticationEntryPoint { _, _, _ ->  }
         }
 
         override fun configure(http: HttpSecurity) {
@@ -154,21 +152,27 @@ class HttpBasicDslTests {
 
     @Test
     fun `http basic when custom authentication details source then used`() {
-        this.spring.register(CustomAuthenticationDetailsSourceConfig::class.java,
-                UserConfig::class.java, MainController::class.java).autowire()
+        this.spring
+            .register(CustomAuthenticationDetailsSourceConfig::class.java, UserConfig::class.java, MainController::class.java)
+            .autowire()
+        mockkObject(CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE)
+        every {
+            CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE.buildDetails(any())
+        } returns Any()
 
         this.mockMvc.get("/") {
             with(httpBasic("username", "password"))
         }
 
-        verify(CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE)
-                .buildDetails(any(HttpServletRequest::class.java))
+        verify(exactly = 1) { CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE.buildDetails(any()) }
     }
 
     @EnableWebSecurity
     open class CustomAuthenticationDetailsSourceConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var AUTHENTICATION_DETAILS_SOURCE = mock(AuthenticationDetailsSource::class.java) as AuthenticationDetailsSource<HttpServletRequest, *>
+            val AUTHENTICATION_DETAILS_SOURCE: AuthenticationDetailsSource<HttpServletRequest, *> =
+                AuthenticationDetailsSource<HttpServletRequest, Any> { Any() }
         }
 
         override fun configure(http: HttpSecurity) {

+ 9 - 6
config/src/test/kotlin/org/springframework/security/config/web/servlet/LogoutDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,12 +16,12 @@
 
 package org.springframework.security.config.web.servlet
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.assertj.core.api.Assertions.assertThat
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.mock
-import org.mockito.Mockito.verify
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.mock.web.MockHttpSession
 import org.springframework.security.authentication.TestingAuthenticationToken
@@ -285,18 +285,21 @@ class LogoutDslTests {
     @Test
     fun `logout when custom logout handler then custom handler used`() {
         this.spring.register(CustomLogoutHandlerConfig::class.java).autowire()
+       mockkObject(CustomLogoutHandlerConfig.HANDLER)
+        every { CustomLogoutHandlerConfig.HANDLER.logout(any(), any(), any()) } returns Unit
 
         this.mockMvc.post("/logout") {
             with(csrf())
         }
 
-        verify(CustomLogoutHandlerConfig.HANDLER).logout(any(), any(), any())
+        verify(exactly = 1) { CustomLogoutHandlerConfig.HANDLER.logout(any(), any(), any()) }
     }
 
     @EnableWebSecurity
     open class CustomLogoutHandlerConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var HANDLER: LogoutHandler = mock(LogoutHandler::class.java)
+            val HANDLER: LogoutHandler = LogoutHandler { _, _, _ -> }
         }
 
         override fun configure(http: HttpSecurity) {

+ 33 - 16
config/src/test/kotlin/org/springframework/security/config/web/servlet/OAuth2ClientDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,10 +16,11 @@
 
 package org.springframework.security.config.web.servlet
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
@@ -33,6 +34,8 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
+import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
+import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository
 import org.springframework.security.oauth2.core.OAuth2AccessToken
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
@@ -77,6 +80,9 @@ class OAuth2ClientDslTests {
     @Test
     fun `oauth2Client when custom authorized client repository then repository used`() {
         this.spring.register(ClientRepositoryConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(ClientRepositoryConfig.REQUEST_REPOSITORY)
+        mockkObject(ClientRepositoryConfig.CLIENT)
+        mockkObject(ClientRepositoryConfig.CLIENT_REPOSITORY)
         val authorizationRequest = OAuth2AuthorizationRequest
                 .authorizationCode()
                 .state("test")
@@ -85,30 +91,41 @@ class OAuth2ClientDslTests {
                 .redirectUri("http://localhost/callback")
                 .attributes(mapOf(Pair(OAuth2ParameterNames.REGISTRATION_ID, "registrationId")))
                 .build()
-        `when`(ClientRepositoryConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any()))
-                .thenReturn(authorizationRequest)
-        `when`(ClientRepositoryConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any()))
-                .thenReturn(authorizationRequest)
-        `when`(ClientRepositoryConfig.CLIENT.getTokenResponse(any()))
-                .thenReturn(OAuth2AccessTokenResponse
-                        .withToken("token")
-                        .tokenType(OAuth2AccessToken.TokenType.BEARER)
-                        .build())
+        every {
+            ClientRepositoryConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any())
+        } returns authorizationRequest
+        every {
+            ClientRepositoryConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any())
+        } returns authorizationRequest
+        every {
+            ClientRepositoryConfig.CLIENT.getTokenResponse(any())
+        } returns OAuth2AccessTokenResponse
+            .withToken("token")
+            .tokenType(OAuth2AccessToken.TokenType.BEARER)
+            .build()
+        every {
+            ClientRepositoryConfig.CLIENT_REPOSITORY.saveAuthorizedClient(any(), any(), any(), any())
+        } returns Unit
 
         this.mockMvc.get("/callback") {
             param("state", "test")
             param("code", "123")
         }
 
-        verify(ClientRepositoryConfig.CLIENT_REPOSITORY).saveAuthorizedClient(any(), any(), any(), any())
+        verify(exactly = 1) { ClientRepositoryConfig.CLIENT_REPOSITORY.saveAuthorizedClient(any(), any(), any(), any()) }
     }
 
     @EnableWebSecurity
     open class ClientRepositoryConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest>
-            var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest>
-            var CLIENT_REPOSITORY: OAuth2AuthorizedClientRepository = mock(OAuth2AuthorizedClientRepository::class.java)
+            val REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
+                HttpSessionOAuth2AuthorizationRequestRepository()
+            val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> =
+                OAuth2AccessTokenResponseClient {
+                    OAuth2AccessTokenResponse.withToken("some tokenValue").build()
+                }
+            val CLIENT_REPOSITORY: OAuth2AuthorizedClientRepository = HttpSessionOAuth2AuthorizedClientRepository()
         }
 
         override fun configure(http: HttpSecurity) {

+ 62 - 30
config/src/test/kotlin/org/springframework/security/config/web/servlet/OAuth2ResourceServerDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,15 +16,20 @@
 
 package org.springframework.security.config.web.servlet
 
+import io.mockk.every
+import io.mockk.mockk
+import io.mockk.mockkObject
+import io.mockk.verify
+import javax.servlet.http.HttpServletRequest
 import org.assertj.core.api.Assertions
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.BeanCreationException
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.security.authentication.AuthenticationManager
 import org.springframework.security.authentication.AuthenticationManagerResolver
+import org.springframework.security.authentication.TestingAuthenticationToken
 import org.springframework.security.config.annotation.web.builders.HttpSecurity
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
@@ -34,11 +39,11 @@ import org.springframework.security.oauth2.jwt.Jwt
 import org.springframework.security.oauth2.jwt.JwtDecoder
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken
 import org.springframework.security.oauth2.server.resource.web.BearerTokenResolver
+import org.springframework.security.oauth2.server.resource.web.DefaultBearerTokenResolver
 import org.springframework.security.web.AuthenticationEntryPoint
 import org.springframework.security.web.access.AccessDeniedHandler
 import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.get
-import javax.servlet.http.HttpServletRequest
 
 /**
  * Tests for [OAuth2ResourceServerDsl]
@@ -61,16 +66,19 @@ class OAuth2ResourceServerDslTests {
     @Test
     fun `oauth2Resource server when custom entry point then entry point used`() {
         this.spring.register(EntryPointConfig::class.java).autowire()
+        mockkObject(EntryPointConfig.ENTRY_POINT)
+        every { EntryPointConfig.ENTRY_POINT.commence(any(), any(), any()) } returns Unit
 
         this.mockMvc.get("/")
 
-        verify(EntryPointConfig.ENTRY_POINT).commence(any(), any(), any())
+        verify(exactly = 1) { EntryPointConfig.ENTRY_POINT.commence(any(), any(), any()) }
     }
 
     @EnableWebSecurity
     open class EntryPointConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var ENTRY_POINT: AuthenticationEntryPoint = mock(AuthenticationEntryPoint::class.java)
+            val ENTRY_POINT: AuthenticationEntryPoint = AuthenticationEntryPoint { _, _, _ ->  }
         }
 
         override fun configure(http: HttpSecurity) {
@@ -86,24 +94,33 @@ class OAuth2ResourceServerDslTests {
         }
 
         @Bean
-        open fun jwtDecoder(): JwtDecoder {
-            return mock(JwtDecoder::class.java)
-        }
+        open fun jwtDecoder(): JwtDecoder = mockk()
     }
 
     @Test
     fun `oauth2Resource server when custom bearer token resolver then resolver used`() {
         this.spring.register(BearerTokenResolverConfig::class.java).autowire()
+        mockkObject(BearerTokenResolverConfig.RESOLVER)
+        mockkObject(BearerTokenResolverConfig.DECODER)
+        every { BearerTokenResolverConfig.RESOLVER.resolve(any()) } returns "anything"
+        every { BearerTokenResolverConfig.DECODER.decode(any()) } returns JWT
 
         this.mockMvc.get("/")
 
-        verify(BearerTokenResolverConfig.RESOLVER).resolve(any())
+        verify(exactly = 1) { BearerTokenResolverConfig.RESOLVER.resolve(any()) }
     }
 
     @EnableWebSecurity
     open class BearerTokenResolverConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var RESOLVER: BearerTokenResolver = mock(BearerTokenResolver::class.java)
+            val RESOLVER: BearerTokenResolver = DefaultBearerTokenResolver()
+            val DECODER: JwtDecoder =  JwtDecoder {
+                Jwt.withTokenValue("token")
+                    .header("alg", "none")
+                    .claim(SUB, "user")
+                    .build()
+            }
         }
 
         override fun configure(http: HttpSecurity) {
@@ -119,28 +136,39 @@ class OAuth2ResourceServerDslTests {
         }
 
         @Bean
-        open fun jwtDecoder(): JwtDecoder {
-            return mock(JwtDecoder::class.java)
-        }
+        open fun jwtDecoder(): JwtDecoder = DECODER
     }
 
     @Test
     fun `oauth2Resource server when custom access denied handler then handler used`() {
         this.spring.register(AccessDeniedHandlerConfig::class.java).autowire()
-        `when`(AccessDeniedHandlerConfig.DECODER.decode(anyString())).thenReturn(JWT)
+        mockkObject(AccessDeniedHandlerConfig.DENIED_HANDLER)
+        mockkObject(AccessDeniedHandlerConfig.DECODER)
+        every {
+            AccessDeniedHandlerConfig.DECODER.decode(any())
+        } returns JWT
+        every {
+            AccessDeniedHandlerConfig.DENIED_HANDLER.handle(any(), any(), any())
+        } returns Unit
 
         this.mockMvc.get("/") {
             header("Authorization", "Bearer token")
         }
 
-        verify(AccessDeniedHandlerConfig.DENIED_HANDLER).handle(any(), any(), any())
+        verify(exactly = 1) { AccessDeniedHandlerConfig.DENIED_HANDLER.handle(any(), any(), any()) }
     }
 
     @EnableWebSecurity
     open class AccessDeniedHandlerConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var DENIED_HANDLER: AccessDeniedHandler = mock(AccessDeniedHandler::class.java)
-            var DECODER: JwtDecoder = mock(JwtDecoder::class.java)
+            val DECODER: JwtDecoder = JwtDecoder { _ ->
+                Jwt.withTokenValue("token")
+                    .header("alg", "none")
+                    .claim(SUB, "user")
+                    .build()
+            }
+            val DENIED_HANDLER: AccessDeniedHandler = AccessDeniedHandler { _, _, _ ->  }
         }
 
         override fun configure(http: HttpSecurity) {
@@ -156,31 +184,36 @@ class OAuth2ResourceServerDslTests {
         }
 
         @Bean
-        open fun jwtDecoder(): JwtDecoder {
-            return DECODER
-        }
+        open fun jwtDecoder(): JwtDecoder = DECODER
     }
 
     @Test
     fun `oauth2Resource server when custom authentication manager resolver then resolver used`() {
         this.spring.register(AuthenticationManagerResolverConfig::class.java).autowire()
-        `when`(AuthenticationManagerResolverConfig.RESOLVER.resolve(any())).thenReturn(
-                AuthenticationManager {
-                    JwtAuthenticationToken(JWT)
-                }
-        )
+        mockkObject(AuthenticationManagerResolverConfig.RESOLVER)
+        every {
+            AuthenticationManagerResolverConfig.RESOLVER.resolve(any())
+        } returns AuthenticationManager {
+            JwtAuthenticationToken(JWT)
+        }
 
         this.mockMvc.get("/") {
             header("Authorization", "Bearer token")
         }
 
-        verify(AuthenticationManagerResolverConfig.RESOLVER).resolve(any())
+        verify(exactly = 1) { AuthenticationManagerResolverConfig.RESOLVER.resolve(any()) }
     }
 
     @EnableWebSecurity
     open class AuthenticationManagerResolverConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var RESOLVER: AuthenticationManagerResolver<*> = mock(AuthenticationManagerResolver::class.java)
+            val RESOLVER: AuthenticationManagerResolver<HttpServletRequest> =
+                AuthenticationManagerResolver {
+                    AuthenticationManager {
+                        TestingAuthenticationToken("a,", "b", "c")
+                    }
+                }
         }
 
         override fun configure(http: HttpSecurity) {
@@ -189,7 +222,7 @@ class OAuth2ResourceServerDslTests {
                     authorize(anyRequest, authenticated)
                 }
                 oauth2ResourceServer {
-                    authenticationManagerResolver = RESOLVER as AuthenticationManagerResolver<HttpServletRequest>
+                    authenticationManagerResolver = RESOLVER
                 }
             }
         }
@@ -210,8 +243,7 @@ class OAuth2ResourceServerDslTests {
                     authorize(anyRequest, authenticated)
                 }
                 oauth2ResourceServer {
-                    authenticationManagerResolver = mock(AuthenticationManagerResolver::class.java)
-                            as AuthenticationManagerResolver<HttpServletRequest>
+                    authenticationManagerResolver = mockk()
                     opaqueToken { }
                 }
             }

+ 93 - 46
config/src/test/kotlin/org/springframework/security/config/web/servlet/RememberMeDslTests.kt

@@ -16,14 +16,22 @@
 
 package org.springframework.security.config.web.servlet
 
+import io.mockk.Called
+import io.mockk.confirmVerified
+import io.mockk.every
+import io.mockk.justRun
+import io.mockk.mockk
+import io.mockk.mockkObject
+import io.mockk.verify
+import javax.servlet.http.HttpServletRequest
 import org.assertj.core.api.Assertions.assertThat
 import org.junit.Rule
 import org.junit.Test
 import org.junit.jupiter.api.fail
-import org.mockito.BDDMockito.given
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
+import org.springframework.context.annotation.Bean
 import org.springframework.core.annotation.Order
+import org.springframework.mock.web.MockHttpServletRequest
 import org.springframework.mock.web.MockHttpSession
 import org.springframework.security.authentication.RememberMeAuthenticationToken
 import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder
@@ -36,21 +44,21 @@ import org.springframework.security.core.authority.AuthorityUtils
 import org.springframework.security.core.userdetails.PasswordEncodedUser
 import org.springframework.security.core.userdetails.User
 import org.springframework.security.core.userdetails.UserDetailsService
+import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder
+import org.springframework.security.crypto.password.PasswordEncoder
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf
 import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler
+import org.springframework.security.web.authentication.NullRememberMeServices
 import org.springframework.security.web.authentication.RememberMeServices
 import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices
-import org.springframework.security.web.authentication.rememberme.PersistentRememberMeToken
 import org.springframework.security.web.authentication.rememberme.PersistentTokenRepository
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher
 import org.springframework.test.web.servlet.MockHttpServletRequestDsl
 import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.get
 import org.springframework.test.web.servlet.post
-import javax.servlet.http.HttpServletRequest
-import javax.servlet.http.HttpServletResponse
 
 /**
  * Tests for [RememberMeDsl]
@@ -58,6 +66,7 @@ import javax.servlet.http.HttpServletResponse
  * @author Ivan Pavlov
  */
 internal class RememberMeDslTests {
+
     @Rule
     @JvmField
     val spring = SpringTestRule()
@@ -65,6 +74,8 @@ internal class RememberMeDslTests {
     @Autowired
     lateinit var mockMvc: MockMvc
 
+    private val mockAuthentication: Authentication = mockk()
+
     @Test
     fun `Remember Me login when remember me true then responds with remember me cookie`() {
         this.spring.register(RememberMeConfig::class.java).autowire()
@@ -165,39 +176,49 @@ internal class RememberMeDslTests {
 
     @Test
     fun `Remember Me when remember me services then uses`() {
-        RememberMeServicesRefConfig.REMEMBER_ME_SERVICES = mock(RememberMeServices::class.java)
         this.spring.register(RememberMeServicesRefConfig::class.java).autowire()
+        mockkObject(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES)
+        every {
+            RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.autoLogin(any(),any())
+        } returns mockAuthentication
+        every {
+            RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.loginFail(any(), any())
+        } returns Unit
+        every {
+            RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.loginSuccess(any(), any(), any())
+        } returns Unit
+
         mockMvc.get("/")
-        verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES).autoLogin(any(HttpServletRequest::class.java),
-                any(HttpServletResponse::class.java))
+
+        verify(exactly = 1) { RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.autoLogin(any(),any()) }
         mockMvc.post("/login") {
             with(csrf())
         }
-        verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES).loginFail(any(HttpServletRequest::class.java),
-                any(HttpServletResponse::class.java))
+        verify(exactly = 2) { RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.loginFail(any(), any()) }
         mockMvc.post("/login") {
             loginRememberMeRequest()
         }
-        verify(RememberMeServicesRefConfig.REMEMBER_ME_SERVICES).loginSuccess(any(HttpServletRequest::class.java),
-                any(HttpServletResponse::class.java), any(Authentication::class.java))
+        verify(exactly = 1) { RememberMeServicesRefConfig.REMEMBER_ME_SERVICES.loginSuccess(any(), any(), any()) }
     }
 
     @Test
     fun `Remember Me when authentication success handler then uses`() {
-        RememberMeSuccessHandlerConfig.SUCCESS_HANDLER = mock(AuthenticationSuccessHandler::class.java)
         this.spring.register(RememberMeSuccessHandlerConfig::class.java).autowire()
+        mockkObject(RememberMeSuccessHandlerConfig.SUCCESS_HANDLER)
+        justRun {
+            RememberMeSuccessHandlerConfig.SUCCESS_HANDLER.onAuthenticationSuccess(any(), any(), any())
+        }
         val mvcResult = mockMvc.post("/login") {
             loginRememberMeRequest()
         }.andReturn()
-        verifyNoInteractions(RememberMeSuccessHandlerConfig.SUCCESS_HANDLER)
+
         val rememberMeCookie = mvcResult.response.getCookie("remember-me")
                 ?: fail { "Missing remember-me cookie in login response" }
         mockMvc.get("/abc") {
             cookie(rememberMeCookie)
         }
-        verify(RememberMeSuccessHandlerConfig.SUCCESS_HANDLER).onAuthenticationSuccess(
-                any(HttpServletRequest::class.java), any(HttpServletResponse::class.java),
-                any(Authentication::class.java))
+
+        verify(exactly = 1) { RememberMeSuccessHandlerConfig.SUCCESS_HANDLER.onAuthenticationSuccess(any(), any(), any()) }
     }
 
     @Test
@@ -228,13 +249,15 @@ internal class RememberMeDslTests {
 
     @Test
     fun `Remember Me when token repository then uses`() {
-        RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY = mock(PersistentTokenRepository::class.java)
         this.spring.register(RememberMeTokenRepositoryConfig::class.java).autowire()
+        mockkObject(RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY)
+        every {
+            RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY.createNewToken(any())
+        } returns Unit
         mockMvc.post("/login") {
             loginRememberMeRequest()
         }
-        verify(RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY).createNewToken(
-                any(PersistentRememberMeToken::class.java))
+        verify(exactly = 1) { RememberMeTokenRepositoryConfig.TOKEN_REPOSITORY.createNewToken(any()) }
     }
 
     @Test
@@ -312,24 +335,32 @@ internal class RememberMeDslTests {
 
     @Test
     fun `Remember Me when global user details service then uses`() {
-        RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE = mock(UserDetailsService::class.java)
         this.spring.register(RememberMeDefaultUserDetailsServiceConfig::class.java).autowire()
+        mockkObject(RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE)
+        val user = User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER"))
+        every {
+            RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user")
+        } returns user
+
         mockMvc.post("/login") {
             loginRememberMeRequest()
         }
-        verify(RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE).loadUserByUsername("user")
+
+        verify(exactly = 1) { RememberMeDefaultUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user") }
     }
 
     @Test
     fun `Remember Me when user details service then uses`() {
-        RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE = mock(UserDetailsService::class.java)
         this.spring.register(RememberMeUserDetailsServiceConfig::class.java).autowire()
+        mockkObject(RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE)
         val user = User("user", "password", AuthorityUtils.createAuthorityList("ROLE_USER"))
-        given(RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user")).willReturn(user)
+        every {
+            RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user")
+        } returns user
         mockMvc.post("/login") {
             loginRememberMeRequest()
         }
-        verify(RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE).loadUserByUsername("user")
+        verify(exactly = 1) { RememberMeUserDetailsServiceConfig.USER_DETAIL_SERVICE.loadUserByUsername("user") }
     }
 
     @Test
@@ -344,8 +375,10 @@ internal class RememberMeDslTests {
         }
     }
 
-    private fun MockHttpServletRequestDsl.loginRememberMeRequest(rememberMeParameter: String = "remember-me",
-                                                                 rememberMeValue: Boolean? = true) {
+    private fun MockHttpServletRequestDsl.loginRememberMeRequest(
+        rememberMeParameter: String = "remember-me",
+        rememberMeValue: Boolean? = true
+    ) {
         with(csrf())
         param("username", "user")
         param("password", "password")
@@ -392,6 +425,11 @@ internal class RememberMeDslTests {
 
     @EnableWebSecurity
     open class RememberMeServicesRefConfig : DefaultUserConfig() {
+
+        companion object {
+            val REMEMBER_ME_SERVICES: RememberMeServices = NullRememberMeServices()
+        }
+
         override fun configure(http: HttpSecurity) {
             http {
                 formLogin {}
@@ -400,14 +438,15 @@ internal class RememberMeDslTests {
                 }
             }
         }
-
-        companion object {
-            lateinit var REMEMBER_ME_SERVICES: RememberMeServices
-        }
     }
 
     @EnableWebSecurity
     open class RememberMeSuccessHandlerConfig : DefaultUserConfig() {
+
+        companion object {
+            val SUCCESS_HANDLER: AuthenticationSuccessHandler = AuthenticationSuccessHandler { _ , _, _ -> }
+        }
+
         override fun configure(http: HttpSecurity) {
             http {
                 formLogin {}
@@ -416,10 +455,6 @@ internal class RememberMeDslTests {
                 }
             }
         }
-
-        companion object {
-            lateinit var SUCCESS_HANDLER: AuthenticationSuccessHandler
-        }
     }
 
     @EnableWebSecurity
@@ -453,6 +488,11 @@ internal class RememberMeDslTests {
 
     @EnableWebSecurity
     open class RememberMeTokenRepositoryConfig : DefaultUserConfig() {
+
+        companion object {
+            val TOKEN_REPOSITORY: PersistentTokenRepository = mockk()
+        }
+
         override fun configure(http: HttpSecurity) {
             http {
                 formLogin {}
@@ -461,10 +501,6 @@ internal class RememberMeDslTests {
                 }
             }
         }
-
-        companion object {
-            lateinit var TOKEN_REPOSITORY: PersistentTokenRepository
-        }
     }
 
     @EnableWebSecurity
@@ -517,6 +553,14 @@ internal class RememberMeDslTests {
 
     @EnableWebSecurity
     open class RememberMeDefaultUserDetailsServiceConfig : DefaultUserConfig() {
+
+        companion object {
+            val USER_DETAIL_SERVICE: UserDetailsService = UserDetailsService { _ ->
+                User("username", "password", emptyList())
+            }
+            val PASSWORD_ENCODER: PasswordEncoder = BCryptPasswordEncoder()
+        }
+
         override fun configure(http: HttpSecurity) {
             http {
                 formLogin {}
@@ -528,13 +572,20 @@ internal class RememberMeDslTests {
             auth.userDetailsService(USER_DETAIL_SERVICE)
         }
 
-        companion object {
-            lateinit var USER_DETAIL_SERVICE: UserDetailsService
-        }
+        @Bean
+        open fun delegatingPasswordEncoder(): PasswordEncoder = PASSWORD_ENCODER
+
     }
 
     @EnableWebSecurity
     open class RememberMeUserDetailsServiceConfig : DefaultUserConfig() {
+
+        companion object {
+            val USER_DETAIL_SERVICE: UserDetailsService = UserDetailsService { _ ->
+                User("username", "password", emptyList())
+            }
+        }
+
         override fun configure(http: HttpSecurity) {
             http {
                 formLogin {}
@@ -543,10 +594,6 @@ internal class RememberMeDslTests {
                 }
             }
         }
-
-        companion object {
-            lateinit var USER_DETAIL_SERVICE: UserDetailsService
-        }
     }
 
     @EnableWebSecurity

+ 12 - 5
config/src/test/kotlin/org/springframework/security/config/web/servlet/RequiresChannelDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,14 +16,17 @@
 
 package org.springframework.security.config.web.servlet
 
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
+import org.springframework.security.access.ConfigAttribute
 import org.springframework.security.config.annotation.web.builders.HttpSecurity
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
 import org.springframework.security.config.test.SpringTestRule
+import org.springframework.security.web.FilterInvocation
 import org.springframework.security.web.access.channel.ChannelProcessor
 import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.get
@@ -112,18 +115,22 @@ class RequiresChannelDslTests {
 
     @Test
     fun `requires channel when channel processors configured then channel processors used`() {
-        `when`(ChannelProcessorsConfig.CHANNEL_PROCESSOR.supports(any())).thenReturn(true)
         this.spring.register(ChannelProcessorsConfig::class.java).autowire()
+        mockkObject(ChannelProcessorsConfig.CHANNEL_PROCESSOR)
 
         this.mockMvc.get("/")
 
-        verify(ChannelProcessorsConfig.CHANNEL_PROCESSOR).supports(any())
+        verify(exactly = 0) {  ChannelProcessorsConfig.CHANNEL_PROCESSOR.supports(any()) }
     }
 
     @EnableWebSecurity
     open class ChannelProcessorsConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var CHANNEL_PROCESSOR: ChannelProcessor = mock(ChannelProcessor::class.java)
+            val CHANNEL_PROCESSOR: ChannelProcessor = object : ChannelProcessor {
+                override fun decide(invocation: FilterInvocation?, config: MutableCollection<ConfigAttribute>?) {}
+                override fun supports(attribute: ConfigAttribute?): Boolean = true
+            }
         }
 
         override fun configure(http: HttpSecurity) {

+ 52 - 41
config/src/test/kotlin/org/springframework/security/config/web/servlet/SessionManagementDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,10 +16,14 @@
 
 package org.springframework.security.config.web.servlet
 
+import io.mockk.every
+import io.mockk.justRun
+import io.mockk.mockk
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.assertj.core.api.Assertions.assertThat
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.mock.web.MockHttpSession
@@ -38,8 +42,6 @@ import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
 import org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl
 import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status
-import javax.servlet.http.HttpServletRequest
-import javax.servlet.http.HttpServletResponse
 
 /**
  * Tests for [SessionManagementDsl]
@@ -59,13 +61,13 @@ class SessionManagementDslTests {
         this.spring.register(InvalidSessionUrlConfig::class.java).autowire()
 
         this.mockMvc.perform(get("/")
-                .with { request ->
-                    request.isRequestedSessionIdValid = false
-                    request.requestedSessionId = "id"
-                    request
-                })
-                .andExpect(status().isFound)
-                .andExpect(redirectedUrl("/invalid"))
+            .with { request ->
+                request.isRequestedSessionIdValid = false
+                request.requestedSessionId = "id"
+                request
+            })
+            .andExpect(status().isFound)
+            .andExpect(redirectedUrl("/invalid"))
     }
 
     @EnableWebSecurity
@@ -84,13 +86,13 @@ class SessionManagementDslTests {
         this.spring.register(InvalidSessionStrategyConfig::class.java).autowire()
 
         this.mockMvc.perform(get("/")
-                .with { request ->
-                    request.isRequestedSessionIdValid = false
-                    request.requestedSessionId = "id"
-                    request
-                })
-                .andExpect(status().isFound)
-                .andExpect(redirectedUrl("/invalid"))
+            .with { request ->
+                request.isRequestedSessionIdValid = false
+                request.requestedSessionId = "id"
+                request
+            })
+            .andExpect(status().isFound)
+            .andExpect(redirectedUrl("/invalid"))
     }
 
     @EnableWebSecurity
@@ -107,14 +109,16 @@ class SessionManagementDslTests {
     @Test
     fun `session management when session authentication error url then redirected to url`() {
         this.spring.register(SessionAuthenticationErrorUrlConfig::class.java).autowire()
-        val session = mock(MockHttpSession::class.java)
-        `when`(session.changeSessionId()).thenThrow(SessionAuthenticationException::class.java)
+        val authentication: Authentication = mockk()
+        val session: MockHttpSession = mockk(relaxed = true)
+        every { session.changeSessionId() } throws SessionAuthenticationException("any SessionAuthenticationException")
+        every<Any?> { session.getAttribute(any()) } returns null
 
         this.mockMvc.perform(get("/")
-                .with(authentication(mock(Authentication::class.java)))
-                .session(session))
-                .andExpect(status().isFound)
-                .andExpect(redirectedUrl("/session-auth-error"))
+            .with(authentication(authentication))
+            .session(session))
+            .andExpect(status().isFound)
+            .andExpect(redirectedUrl("/session-auth-error"))
     }
 
     @EnableWebSecurity
@@ -134,14 +138,16 @@ class SessionManagementDslTests {
     @Test
     fun `session management when session authentication failure handler then handler used`() {
         this.spring.register(SessionAuthenticationFailureHandlerConfig::class.java).autowire()
-        val session = mock(MockHttpSession::class.java)
-        `when`(session.changeSessionId()).thenThrow(SessionAuthenticationException::class.java)
+        val authentication: Authentication = mockk()
+        val session: MockHttpSession = mockk(relaxed = true)
+        every { session.changeSessionId() } throws SessionAuthenticationException("any SessionAuthenticationException")
+        every<Any?> { session.getAttribute(any()) } returns null
 
         this.mockMvc.perform(get("/")
-                .with(authentication(mock(Authentication::class.java)))
-                .session(session))
-                .andExpect(status().isFound)
-                .andExpect(redirectedUrl("/session-auth-error"))
+            .with(authentication(authentication))
+            .session(session))
+            .andExpect(status().isFound)
+            .andExpect(redirectedUrl("/session-auth-error"))
     }
 
     @EnableWebSecurity
@@ -163,7 +169,7 @@ class SessionManagementDslTests {
         this.spring.register(StatelessSessionManagementConfig::class.java).autowire()
 
         val result = this.mockMvc.perform(get("/"))
-                .andReturn()
+            .andReturn()
 
         assertThat(result.request.getSession(false)).isNull()
     }
@@ -185,19 +191,26 @@ class SessionManagementDslTests {
     @Test
     fun `session management when session authentication strategy then strategy used`() {
         this.spring.register(SessionAuthenticationStrategyConfig::class.java).autowire()
+        mockkObject(SessionAuthenticationStrategyConfig.STRATEGY)
+        val authentication: Authentication = mockk(relaxed = true)
+        val session: MockHttpSession = mockk(relaxed = true)
+        every { session.changeSessionId() } throws SessionAuthenticationException("any SessionAuthenticationException")
+        every<Any?> { session.getAttribute(any()) } returns null
+        justRun {  SessionAuthenticationStrategyConfig.STRATEGY.onAuthentication(any(), any(), any()) }
 
         this.mockMvc.perform(get("/")
-                .with(authentication(mock(Authentication::class.java)))
-                .session(mock(MockHttpSession::class.java)))
+            .with(authentication(authentication))
+            .session(session))
 
-        verify(this.spring.getContext().getBean(SessionAuthenticationStrategy::class.java))
-                .onAuthentication(any(Authentication::class.java),
-                        any(HttpServletRequest::class.java), any(HttpServletResponse::class.java))
+        verify(exactly = 1) { SessionAuthenticationStrategyConfig.STRATEGY.onAuthentication(any(), any(), any()) }
     }
 
     @EnableWebSecurity
     open class SessionAuthenticationStrategyConfig : WebSecurityConfigurerAdapter() {
-        var mockSessionAuthenticationStrategy: SessionAuthenticationStrategy = mock(SessionAuthenticationStrategy::class.java)
+
+        companion object {
+            val STRATEGY: SessionAuthenticationStrategy = SessionAuthenticationStrategy { _, _, _ ->  }
+        }
 
         override fun configure(http: HttpSecurity) {
             http {
@@ -205,14 +218,12 @@ class SessionManagementDslTests {
                     authorize(anyRequest, authenticated)
                 }
                 sessionManagement {
-                    sessionAuthenticationStrategy = mockSessionAuthenticationStrategy
+                    sessionAuthenticationStrategy = STRATEGY
                 }
             }
         }
 
         @Bean
-        open fun sessionAuthenticationStrategy(): SessionAuthenticationStrategy {
-            return this.mockSessionAuthenticationStrategy
-        }
+        open fun sessionAuthenticationStrategy(): SessionAuthenticationStrategy = STRATEGY
     }
 }

+ 7 - 11
config/src/test/kotlin/org/springframework/security/config/web/servlet/X509DslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,12 @@
 
 package org.springframework.security.config.web.servlet
 
+import io.mockk.mockk
+import java.security.cert.Certificate
+import java.security.cert.CertificateFactory
+import java.security.cert.X509Certificate
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.mock
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.core.io.ClassPathResource
@@ -36,9 +39,6 @@ import org.springframework.security.web.authentication.preauth.PreAuthenticatedA
 import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor
 import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
-import java.security.cert.Certificate
-import java.security.cert.CertificateFactory
-import java.security.cert.X509Certificate
 
 /**
  * Tests for [X509Dsl]
@@ -140,9 +140,7 @@ class X509DslTests {
         }
 
         @Bean
-        override fun userDetailsService(): UserDetailsService {
-            return mock(UserDetailsService::class.java)
-        }
+        override fun userDetailsService(): UserDetailsService = mockk()
     }
 
     @Test
@@ -174,9 +172,7 @@ class X509DslTests {
         }
 
         @Bean
-        override fun userDetailsService(): UserDetailsService {
-            return mock(UserDetailsService::class.java)
-        }
+        override fun userDetailsService(): UserDetailsService = mockk()
     }
 
     @Test

+ 48 - 23
config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/client/AuthorizationCodeGrantDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,11 +16,12 @@
 
 package org.springframework.security.config.web.servlet.oauth2.client
 
+import io.mockk.every
+import io.mockk.mockk
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito
-import org.mockito.Mockito.verify
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
@@ -35,6 +36,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
+import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver
 import org.springframework.security.oauth2.core.OAuth2AccessToken
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
@@ -59,19 +61,29 @@ class AuthorizationCodeGrantDslTests {
     @Test
     fun `oauth2Client when custom authorization request repository then repository used`() {
         this.spring.register(RequestRepositoryConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(RequestRepositoryConfig.REQUEST_REPOSITORY)
+        val authorizationRequest = getOAuth2AuthorizationRequest()
+        every {
+            RequestRepositoryConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any())
+        } returns authorizationRequest
+        every {
+            RequestRepositoryConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any())
+        } returns authorizationRequest
 
         this.mockMvc.get("/callback") {
             param("state", "test")
             param("code", "123")
         }
 
-        verify(RequestRepositoryConfig.REQUEST_REPOSITORY).loadAuthorizationRequest(any())
+        verify(exactly = 1) { RequestRepositoryConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any()) }
     }
 
     @EnableWebSecurity
     open class RequestRepositoryConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest>
+            val REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
+                HttpSessionOAuth2AuthorizationRequestRepository()
         }
 
         override fun configure(http: HttpSecurity) {
@@ -91,30 +103,39 @@ class AuthorizationCodeGrantDslTests {
     @Test
     fun `oauth2Client when custom access token response client then client used`() {
         this.spring.register(AuthorizedClientConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(AuthorizedClientConfig.REQUEST_REPOSITORY)
+        mockkObject(AuthorizedClientConfig.CLIENT)
         val authorizationRequest = getOAuth2AuthorizationRequest()
-        Mockito.`when`(AuthorizedClientConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any()))
-                .thenReturn(authorizationRequest)
-        Mockito.`when`(AuthorizedClientConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any()))
-                .thenReturn(authorizationRequest)
-        Mockito.`when`(AuthorizedClientConfig.CLIENT.getTokenResponse(any()))
-                .thenReturn(OAuth2AccessTokenResponse
-                        .withToken("token")
-                        .tokenType(OAuth2AccessToken.TokenType.BEARER)
-                        .build())
+        every {
+            AuthorizedClientConfig.REQUEST_REPOSITORY.loadAuthorizationRequest(any())
+        } returns authorizationRequest
+        every {
+            AuthorizedClientConfig.REQUEST_REPOSITORY.removeAuthorizationRequest(any(), any())
+        } returns authorizationRequest
+        every {
+            AuthorizedClientConfig.CLIENT.getTokenResponse(any())
+        } returns OAuth2AccessTokenResponse
+            .withToken("token")
+            .tokenType(OAuth2AccessToken.TokenType.BEARER)
+            .build()
 
         this.mockMvc.get("/callback") {
             param("state", "test")
             param("code", "123")
         }
 
-        verify(AuthorizedClientConfig.CLIENT).getTokenResponse(any())
+        verify(exactly = 1) { AuthorizedClientConfig.CLIENT.getTokenResponse(any()) }
     }
 
     @EnableWebSecurity
     open class AuthorizedClientConfig : WebSecurityConfigurerAdapter() {
         companion object {
-            var REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest>
-            var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = Mockito.mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest>
+            val REQUEST_REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
+                HttpSessionOAuth2AuthorizationRequestRepository()
+            val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> =
+                OAuth2AccessTokenResponseClient {
+                    OAuth2AccessTokenResponse.withToken("some tokenValue").build()
+                }
         }
 
         override fun configure(http: HttpSecurity) {
@@ -135,26 +156,30 @@ class AuthorizationCodeGrantDslTests {
     @Test
     fun `oauth2Client when custom authorization request resolver then request resolver used`() {
         this.spring.register(RequestResolverConfig::class.java, ClientConfig::class.java).autowire()
+        val requestResolverConfig = this.spring.context.getBean(RequestResolverConfig::class.java)
+        val authorizationRequest = getOAuth2AuthorizationRequest()
+        every {
+            requestResolverConfig.requestResolver.resolve(any())
+        } returns authorizationRequest
 
         this.mockMvc.get("/callback") {
             param("state", "test")
             param("code", "123")
         }
 
-        verify(RequestResolverConfig.REQUEST_RESOLVER).resolve(any())
+        verify(exactly = 1) { requestResolverConfig.requestResolver.resolve(any()) }
     }
 
     @EnableWebSecurity
     open class RequestResolverConfig : WebSecurityConfigurerAdapter() {
-        companion object {
-            var REQUEST_RESOLVER: OAuth2AuthorizationRequestResolver = Mockito.mock(OAuth2AuthorizationRequestResolver::class.java)
-        }
+
+        val requestResolver: OAuth2AuthorizationRequestResolver = mockk()
 
         override fun configure(http: HttpSecurity) {
             http {
                 oauth2Client {
                     authorizationCodeGrant {
-                        authorizationRequestResolver = REQUEST_RESOLVER
+                        authorizationRequestResolver = requestResolver
                     }
                 }
                 authorizeRequests {

+ 31 - 11
config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/login/AuthorizationEndpointDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,23 +16,25 @@
 
 package org.springframework.security.config.web.servlet.oauth2.login
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
+import javax.servlet.http.HttpServletRequest
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito
-import org.mockito.Mockito.verify
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
 import org.springframework.security.config.annotation.web.builders.HttpSecurity
-import org.springframework.security.config.web.servlet.invoke
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
 import org.springframework.security.config.oauth2.client.CommonOAuth2Provider
 import org.springframework.security.config.test.SpringTestRule
+import org.springframework.security.config.web.servlet.invoke
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
+import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
 import org.springframework.test.web.servlet.MockMvc
@@ -54,16 +56,27 @@ class AuthorizationEndpointDslTests {
     @Test
     fun `oauth2Login when custom client registration repository then repository used`() {
         this.spring.register(ResolverConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(ResolverConfig.RESOLVER)
+        every { ResolverConfig.RESOLVER.resolve(any()) }
 
         this.mockMvc.get("/oauth2/authorization/google")
 
-        verify(ResolverConfig.RESOLVER).resolve(any())
+        verify(exactly = 1) { ResolverConfig.RESOLVER.resolve(any()) }
     }
 
     @EnableWebSecurity
     open class ResolverConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var RESOLVER: OAuth2AuthorizationRequestResolver = Mockito.mock(OAuth2AuthorizationRequestResolver::class.java)
+            val RESOLVER: OAuth2AuthorizationRequestResolver = object : OAuth2AuthorizationRequestResolver {
+                override fun resolve(
+                    request: HttpServletRequest?
+                ) = OAuth2AuthorizationRequest.authorizationCode().build()
+
+                override fun resolve(
+                    request: HttpServletRequest?, clientRegistrationId: String?
+                ) = OAuth2AuthorizationRequest.authorizationCode().build()
+            }
         }
 
         override fun configure(http: HttpSecurity) {
@@ -80,16 +93,20 @@ class AuthorizationEndpointDslTests {
     @Test
     fun `oauth2Login when custom authorization request repository then repository used`() {
         this.spring.register(RequestRepoConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(RequestRepoConfig.REPOSITORY)
+        every { RequestRepoConfig.REPOSITORY.saveAuthorizationRequest(any(), any(), any()) }
 
         this.mockMvc.get("/oauth2/authorization/google")
 
-        verify(RequestRepoConfig.REPOSITORY).saveAuthorizationRequest(any(), any(), any())
+        verify(exactly = 1) { RequestRepoConfig.REPOSITORY.saveAuthorizationRequest(any(), any(), any()) }
     }
 
     @EnableWebSecurity
     open class RequestRepoConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest>
+            val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
+                HttpSessionOAuth2AuthorizationRequestRepository()
         }
 
         override fun configure(http: HttpSecurity) {
@@ -106,16 +123,19 @@ class AuthorizationEndpointDslTests {
     @Test
     fun `oauth2Login when custom authorization uri repository then uri used`() {
         this.spring.register(AuthorizationUriConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(AuthorizationUriConfig.REPOSITORY)
 
         this.mockMvc.get("/connect/google")
 
-        verify(AuthorizationUriConfig.REPOSITORY).saveAuthorizationRequest(any(), any(), any())
+        verify(exactly = 1) { AuthorizationUriConfig.REPOSITORY.saveAuthorizationRequest(any(), any(), any()) }
     }
 
     @EnableWebSecurity
     open class AuthorizationUriConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest>
+            val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
+                HttpSessionOAuth2AuthorizationRequestRepository()
         }
 
         override fun configure(http: HttpSecurity) {

+ 29 - 18
config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/login/RedirectionEndpointDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,11 +16,10 @@
 
 package org.springframework.security.config.web.servlet.oauth2.login
 
+import io.mockk.every
+import io.mockk.mockkObject
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers
-import org.mockito.Mockito
-import org.mockito.Mockito.mock
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
@@ -29,15 +28,17 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
 import org.springframework.security.config.oauth2.client.CommonOAuth2Provider
 import org.springframework.security.config.test.SpringTestRule
-import org.springframework.security.core.authority.SimpleGrantedAuthority
 import org.springframework.security.config.web.servlet.invoke
+import org.springframework.security.core.authority.SimpleGrantedAuthority
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
+import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserService
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
+import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
 import org.springframework.security.oauth2.core.OAuth2AccessToken
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
@@ -46,7 +47,6 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User
 import org.springframework.security.oauth2.core.user.OAuth2User
 import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.get
-import java.util.*
 
 /**
  * Tests for [RedirectionEndpointDsl]
@@ -64,6 +64,9 @@ class RedirectionEndpointDslTests {
     @Test
     fun `oauth2Login when redirection endpoint configured then custom redirection endpoing used`() {
         this.spring.register(UserServiceConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(UserServiceConfig.REPOSITORY)
+        mockkObject(UserServiceConfig.CLIENT)
+        mockkObject(UserServiceConfig.USER_SERVICE)
 
         val registrationId = "registrationId"
         val attributes = HashMap<String, Any>()
@@ -76,15 +79,18 @@ class RedirectionEndpointDslTests {
                 .redirectUri("http://localhost/callback")
                 .attributes(attributes)
                 .build()
-        Mockito.`when`(UserServiceConfig.REPOSITORY.removeAuthorizationRequest(ArgumentMatchers.any(), ArgumentMatchers.any()))
-                .thenReturn(authorizationRequest)
-        Mockito.`when`(UserServiceConfig.CLIENT.getTokenResponse(ArgumentMatchers.any()))
-                .thenReturn(OAuth2AccessTokenResponse
-                        .withToken("token")
-                        .tokenType(OAuth2AccessToken.TokenType.BEARER)
-                        .build())
-        Mockito.`when`(UserServiceConfig.USER_SERVICE.loadUser(ArgumentMatchers.any()))
-                .thenReturn(DefaultOAuth2User(listOf(SimpleGrantedAuthority("ROLE_USER")), mapOf(Pair("user", "user")), "user"))
+        every {
+            UserServiceConfig.REPOSITORY.removeAuthorizationRequest(any(), any())
+        } returns authorizationRequest
+        every {
+            UserServiceConfig.CLIENT.getTokenResponse(any())
+        } returns OAuth2AccessTokenResponse
+            .withToken("token")
+            .tokenType(OAuth2AccessToken.TokenType.BEARER)
+            .build()
+        every {
+            UserServiceConfig.USER_SERVICE.loadUser(any())
+        } returns DefaultOAuth2User(listOf(SimpleGrantedAuthority("ROLE_USER")), mapOf(Pair("user", "user")), "user")
 
         this.mockMvc.get("/callback") {
             param("code", "auth-code")
@@ -96,10 +102,15 @@ class RedirectionEndpointDslTests {
 
     @EnableWebSecurity
     open class UserServiceConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var USER_SERVICE: OAuth2UserService<OAuth2UserRequest, OAuth2User> = mock(OAuth2UserService::class.java) as OAuth2UserService<OAuth2UserRequest, OAuth2User>
-            var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest>
-            var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest>
+            val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
+                HttpSessionOAuth2AuthorizationRequestRepository()
+            val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> =
+                OAuth2AccessTokenResponseClient {
+                    OAuth2AccessTokenResponse.withToken("some tokenValue").build()
+                }
+            val USER_SERVICE: OAuth2UserService<OAuth2UserRequest, OAuth2User> = DefaultOAuth2UserService()
         }
 
         override fun configure(http: HttpSecurity) {

+ 24 - 15
config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/login/TokenEndpointDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,12 +16,11 @@
 
 package org.springframework.security.config.web.servlet.oauth2.login
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito
-import org.mockito.Mockito.`when`
-import org.mockito.Mockito.mock
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
@@ -36,13 +35,13 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository
+import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository
 import org.springframework.security.oauth2.core.OAuth2AccessToken
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames
 import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.get
-import java.util.*
 
 /**
  * Tests for [TokenEndpointDsl]
@@ -60,6 +59,8 @@ class TokenEndpointDslTests {
     @Test
     fun `oauth2Login when custom access token response client then client used`() {
         this.spring.register(TokenConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(TokenConfig.REPOSITORY)
+        mockkObject(TokenConfig.CLIENT)
 
         val registrationId = "registrationId"
         val attributes = HashMap<String, Any>()
@@ -72,26 +73,34 @@ class TokenEndpointDslTests {
                 .redirectUri("http://localhost/login/oauth2/code/google")
                 .attributes(attributes)
                 .build()
-        `when`(TokenConfig.REPOSITORY.removeAuthorizationRequest(any(), any()))
-                .thenReturn(authorizationRequest)
-        `when`(TokenConfig.CLIENT.getTokenResponse(any())).thenReturn(OAuth2AccessTokenResponse
-                .withToken("token")
-                .tokenType(OAuth2AccessToken.TokenType.BEARER)
-                .build())
+        every {
+            TokenConfig.REPOSITORY.removeAuthorizationRequest(any(), any())
+        } returns authorizationRequest
+        every {
+            TokenConfig.CLIENT.getTokenResponse(any())
+        } returns OAuth2AccessTokenResponse
+            .withToken("token")
+            .tokenType(OAuth2AccessToken.TokenType.BEARER)
+            .build()
 
         this.mockMvc.get("/login/oauth2/code/google") {
             param("code", "auth-code")
             param("state", "test")
         }
 
-        Mockito.verify(TokenConfig.CLIENT).getTokenResponse(any())
+        verify(exactly = 1) { TokenConfig.CLIENT.getTokenResponse(any()) }
     }
 
     @EnableWebSecurity
     open class TokenConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest>
-            var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest>
+            val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> =
+                HttpSessionOAuth2AuthorizationRequestRepository()
+            val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> =
+                OAuth2AccessTokenResponseClient {
+                    OAuth2AccessTokenResponse.withToken("some tokenValue").build()
+                }
         }
 
         override fun configure(http: HttpSecurity) {

+ 29 - 22
config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/login/UserInfoEndpointDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,21 +16,22 @@
 
 package org.springframework.security.config.web.servlet.oauth2.login
 
+import io.mockk.every
+import io.mockk.mockk
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito
-import org.mockito.Mockito.`when`
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
 import org.springframework.security.config.annotation.web.builders.HttpSecurity
-import org.springframework.security.config.web.servlet.invoke
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
 import org.springframework.security.config.oauth2.client.CommonOAuth2Provider
-import org.springframework.security.core.authority.SimpleGrantedAuthority
 import org.springframework.security.config.test.SpringTestRule
+import org.springframework.security.config.web.servlet.invoke
+import org.springframework.security.core.authority.SimpleGrantedAuthority
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository
@@ -46,7 +47,6 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User
 import org.springframework.security.oauth2.core.user.OAuth2User
 import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.get
-import java.util.*
 
 /**
  * Tests for [UserInfoEndpointDsl]
@@ -64,6 +64,9 @@ class UserInfoEndpointDslTests {
     @Test
     fun `oauth2Login when custom user service then user service used`() {
         this.spring.register(UserServiceConfig::class.java, ClientConfig::class.java).autowire()
+        mockkObject(UserServiceConfig.REPOSITORY)
+        mockkObject(UserServiceConfig.CLIENT)
+        mockkObject(UserServiceConfig.USER_SERVICE)
 
         val registrationId = "registrationId"
         val attributes = HashMap<String, Any>()
@@ -76,31 +79,35 @@ class UserInfoEndpointDslTests {
                 .redirectUri("http://localhost/login/oauth2/code/google")
                 .attributes(attributes)
                 .build()
-        `when`(UserServiceConfig.REPOSITORY.removeAuthorizationRequest(any(), any()))
-                .thenReturn(authorizationRequest)
-        `when`(UserServiceConfig.CLIENT.getTokenResponse(any()))
-                .thenReturn(OAuth2AccessTokenResponse
-                        .withToken("token")
-                        .tokenType(OAuth2AccessToken.TokenType.BEARER)
-                        .build())
-        `when`(UserServiceConfig.USER_SERVICE.loadUser(any()))
-                .thenReturn(DefaultOAuth2User(listOf(SimpleGrantedAuthority("ROLE_USER")), mapOf(Pair("user", "user")), "user"))
+        every {
+            UserServiceConfig.REPOSITORY.removeAuthorizationRequest(any(), any())
+        } returns authorizationRequest
+        every {
+            UserServiceConfig.CLIENT.getTokenResponse(any())
+        } returns OAuth2AccessTokenResponse
+            .withToken("token")
+            .tokenType(OAuth2AccessToken.TokenType.BEARER)
+            .build()
+        every {
+            UserServiceConfig.USER_SERVICE.loadUser(any())
+        } returns DefaultOAuth2User(listOf(SimpleGrantedAuthority("ROLE_USER")), mapOf(Pair("user", "user")), "user")
 
         this.mockMvc.get("/login/oauth2/code/google") {
             param("code", "auth-code")
             param("state", "test")
         }
 
-        Mockito.verify(UserServiceConfig.USER_SERVICE).loadUser(any())
+        verify(exactly = 1) { UserServiceConfig.USER_SERVICE.loadUser(any()) }
     }
 
     @EnableWebSecurity
     open class UserServiceConfig : WebSecurityConfigurerAdapter() {
-        companion object {
-            var USER_SERVICE: OAuth2UserService<OAuth2UserRequest, OAuth2User> = Mockito.mock(OAuth2UserService::class.java) as OAuth2UserService<OAuth2UserRequest, OAuth2User>
-            var CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = Mockito.mock(OAuth2AccessTokenResponseClient::class.java) as OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest>
-            var REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = Mockito.mock(AuthorizationRequestRepository::class.java) as AuthorizationRequestRepository<OAuth2AuthorizationRequest>
-        }
+
+         companion object {
+             val REPOSITORY: AuthorizationRequestRepository<OAuth2AuthorizationRequest> = mockk()
+             val CLIENT: OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> = mockk()
+             val USER_SERVICE: OAuth2UserService<OAuth2UserRequest, OAuth2User> = mockk()
+         }
 
         override fun configure(http: HttpSecurity) {
             http {

+ 34 - 23
config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/JwtDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,12 @@
 
 package org.springframework.security.config.web.servlet.oauth2.resourceserver
 
+import io.mockk.every
+import io.mockk.mockk
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.core.convert.converter.Converter
@@ -59,7 +62,7 @@ class JwtDslTests {
             http {
                 oauth2ResourceServer {
                     jwt {
-                        jwtDecoder = mock(JwtDecoder::class.java)
+                        jwtDecoder = mockk()
                     }
                 }
             }
@@ -87,25 +90,32 @@ class JwtDslTests {
     @Test
     fun `JWT when custom JWT authentication converter then converter used`() {
         this.spring.register(CustomJwtAuthenticationConverterConfig::class.java).autowire()
-        `when`(CustomJwtAuthenticationConverterConfig.DECODER.decode(anyString())).thenReturn(
-                Jwt.withTokenValue("token")
-                        .header("alg", "none")
-                        .claim(IdTokenClaimNames.SUB, "user")
-                        .build())
-        `when`(CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any()))
-                .thenReturn(TestingAuthenticationToken("test", "this", "ROLE"))
+        mockkObject(CustomJwtAuthenticationConverterConfig.CONVERTER)
+        mockkObject(CustomJwtAuthenticationConverterConfig.DECODER)
+        every {
+            CustomJwtAuthenticationConverterConfig.DECODER.decode(any())
+        } returns Jwt.withTokenValue("token")
+            .header("alg", "none")
+            .claim(IdTokenClaimNames.SUB, "user")
+            .build()
+        every {
+            CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any())
+        } returns TestingAuthenticationToken("test", "this", "ROLE")
         this.mockMvc.get("/") {
             header("Authorization", "Bearer token")
         }
 
-        verify(CustomJwtAuthenticationConverterConfig.CONVERTER).convert(any())
+        verify(exactly = 1) { CustomJwtAuthenticationConverterConfig.CONVERTER.convert(any()) }
     }
 
     @EnableWebSecurity
     open class CustomJwtAuthenticationConverterConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var CONVERTER: Converter<Jwt, out AbstractAuthenticationToken> = mock(Converter::class.java) as Converter<Jwt, out AbstractAuthenticationToken>
-            var DECODER: JwtDecoder = mock(JwtDecoder::class.java)
+            val CONVERTER: Converter<Jwt, out AbstractAuthenticationToken> = Converter { _ ->
+                TestingAuthenticationToken("a", "b",  "c")
+            }
+            val DECODER: JwtDecoder = JwtDecoder { Jwt.withTokenValue("some tokenValue").build() }
         }
 
         override fun configure(http: HttpSecurity) {
@@ -122,31 +132,32 @@ class JwtDslTests {
         }
 
         @Bean
-        open fun jwtDecoder(): JwtDecoder {
-            return DECODER
-        }
+        open fun jwtDecoder(): JwtDecoder = DECODER
     }
 
     @Test
     fun `JWT when custom JWT decoder set after jwkSetUri then decoder used`() {
         this.spring.register(JwtDecoderAfterJwkSetUriConfig::class.java).autowire()
-        `when`(JwtDecoderAfterJwkSetUriConfig.DECODER.decode(anyString())).thenReturn(
-                Jwt.withTokenValue("token")
-                        .header("alg", "none")
-                        .claim(IdTokenClaimNames.SUB, "user")
-                        .build())
+        mockkObject(JwtDecoderAfterJwkSetUriConfig.DECODER)
+        every {
+            JwtDecoderAfterJwkSetUriConfig.DECODER.decode(any())
+        } returns Jwt.withTokenValue("token")
+            .header("alg", "none")
+            .claim(IdTokenClaimNames.SUB, "user")
+            .build()
 
         this.mockMvc.get("/") {
             header("Authorization", "Bearer token")
         }
 
-        verify(JwtDecoderAfterJwkSetUriConfig.DECODER).decode(any())
+        verify(exactly = 1) { JwtDecoderAfterJwkSetUriConfig.DECODER.decode(any()) }
     }
 
     @EnableWebSecurity
     open class JwtDecoderAfterJwkSetUriConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var DECODER: JwtDecoder = mock(JwtDecoder::class.java)
+            val DECODER: JwtDecoder = JwtDecoder { Jwt.withTokenValue("some tokenValue").build() }
         }
 
         override fun configure(http: HttpSecurity) {

+ 39 - 24
config/src/test/kotlin/org/springframework/security/config/web/servlet/oauth2/resourceserver/OpaqueTokenDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,22 +16,23 @@
 
 package org.springframework.security.config.web.servlet.oauth2.resourceserver
 
+import io.mockk.every
+import io.mockk.mockkObject
+import io.mockk.verify
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.ArgumentMatchers
-import org.mockito.ArgumentMatchers.any
-import org.mockito.ArgumentMatchers.eq
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
-import org.springframework.http.*
+import org.springframework.http.HttpHeaders
+import org.springframework.http.HttpStatus
+import org.springframework.http.MediaType
+import org.springframework.http.ResponseEntity
 import org.springframework.security.config.annotation.web.builders.HttpSecurity
-import org.springframework.security.config.web.servlet.invoke
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
 import org.springframework.security.config.test.SpringTestRule
+import org.springframework.security.config.web.servlet.invoke
 import org.springframework.security.core.Authentication
-import org.springframework.security.core.annotation.AuthenticationPrincipal
 import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal
 import org.springframework.security.oauth2.jwt.JwtClaimNames
 import org.springframework.security.oauth2.server.resource.introspection.NimbusOpaqueTokenIntrospector
@@ -41,6 +42,7 @@ import org.springframework.test.web.servlet.get
 import org.springframework.web.bind.annotation.GetMapping
 import org.springframework.web.bind.annotation.RestController
 import org.springframework.web.client.RestOperations
+import org.springframework.web.client.RestTemplate
 
 /**
  * Tests for [OpaqueTokenDsl]
@@ -58,16 +60,19 @@ class OpaqueTokenDslTests {
     @Test
     fun `opaque token when defaults then uses introspection`() {
         this.spring.register(DefaultOpaqueConfig::class.java, AuthenticationController::class.java).autowire()
-        val headers = HttpHeaders()
-        headers.contentType = MediaType.APPLICATION_JSON
+        mockkObject(DefaultOpaqueConfig.REST)
+        val headers = HttpHeaders().apply {
+            contentType = MediaType.APPLICATION_JSON
+        }
         val entity = ResponseEntity("{\n" +
                 "  \"active\" : true,\n" +
                 "  \"sub\": \"test-subject\",\n" +
                 "  \"scope\": \"message:read\",\n" +
                 "  \"exp\": 4683883211\n" +
                 "}", headers, HttpStatus.OK)
-        `when`(DefaultOpaqueConfig.REST.exchange(any(RequestEntity::class.java), eq(String::class.java)))
-                .thenReturn(entity)
+        every {
+            DefaultOpaqueConfig.REST.exchange(any(), eq(String::class.java))
+        } returns entity
 
         this.mockMvc.get("/authenticated") {
             header("Authorization", "Bearer token")
@@ -79,8 +84,9 @@ class OpaqueTokenDslTests {
 
     @EnableWebSecurity
     open class DefaultOpaqueConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var REST: RestOperations = mock(RestOperations::class.java)
+            val REST: RestOperations = RestTemplate()
         }
 
         override fun configure(http: HttpSecurity) {
@@ -95,9 +101,7 @@ class OpaqueTokenDslTests {
         }
 
         @Bean
-        open fun rest(): RestOperations {
-            return REST
-        }
+        open fun rest(): RestOperations = REST
 
         @Bean
         open fun tokenIntrospectionClient(): NimbusOpaqueTokenIntrospector {
@@ -108,20 +112,26 @@ class OpaqueTokenDslTests {
     @Test
     fun `opaque token when custom introspector set then introspector used`() {
         this.spring.register(CustomIntrospectorConfig::class.java, AuthenticationController::class.java).autowire()
-        `when`(CustomIntrospectorConfig.INTROSPECTOR.introspect(ArgumentMatchers.anyString()))
-                .thenReturn(DefaultOAuth2AuthenticatedPrincipal(mapOf(Pair(JwtClaimNames.SUB, "mock-subject")), emptyList()))
+        mockkObject(CustomIntrospectorConfig.INTROSPECTOR)
+
+        every {
+            CustomIntrospectorConfig.INTROSPECTOR.introspect(any())
+        } returns DefaultOAuth2AuthenticatedPrincipal(mapOf(Pair(JwtClaimNames.SUB, "mock-subject")), emptyList())
 
         this.mockMvc.get("/authenticated") {
             header("Authorization", "Bearer token")
         }
 
-        verify(CustomIntrospectorConfig.INTROSPECTOR).introspect("token")
+        verify(exactly = 1) { CustomIntrospectorConfig.INTROSPECTOR.introspect("token") }
     }
 
     @EnableWebSecurity
     open class CustomIntrospectorConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var INTROSPECTOR: OpaqueTokenIntrospector = mock(OpaqueTokenIntrospector::class.java)
+            val INTROSPECTOR: OpaqueTokenIntrospector = OpaqueTokenIntrospector {
+                DefaultOAuth2AuthenticatedPrincipal(emptyMap(), emptyList())
+            }
         }
 
         override fun configure(http: HttpSecurity) {
@@ -141,20 +151,25 @@ class OpaqueTokenDslTests {
     @Test
     fun `opaque token when custom introspector set after client credentials then introspector used`() {
         this.spring.register(IntrospectorAfterClientCredentialsConfig::class.java, AuthenticationController::class.java).autowire()
-        `when`(IntrospectorAfterClientCredentialsConfig.INTROSPECTOR.introspect(ArgumentMatchers.anyString()))
-                .thenReturn(DefaultOAuth2AuthenticatedPrincipal(mapOf(Pair(JwtClaimNames.SUB, "mock-subject")), emptyList()))
+        mockkObject(IntrospectorAfterClientCredentialsConfig.INTROSPECTOR)
+        every {
+            IntrospectorAfterClientCredentialsConfig.INTROSPECTOR.introspect(any())
+        } returns DefaultOAuth2AuthenticatedPrincipal(mapOf(Pair(JwtClaimNames.SUB, "mock-subject")), emptyList())
 
         this.mockMvc.get("/authenticated") {
             header("Authorization", "Bearer token")
         }
 
-        verify(IntrospectorAfterClientCredentialsConfig.INTROSPECTOR).introspect("token")
+        verify(exactly = 1) { IntrospectorAfterClientCredentialsConfig.INTROSPECTOR.introspect("token") }
     }
 
     @EnableWebSecurity
     open class IntrospectorAfterClientCredentialsConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            var INTROSPECTOR: OpaqueTokenIntrospector = mock(OpaqueTokenIntrospector::class.java)
+            val INTROSPECTOR: OpaqueTokenIntrospector = OpaqueTokenIntrospector {
+                DefaultOAuth2AuthenticatedPrincipal(emptyMap(), emptyList())
+            }
         }
 
         override fun configure(http: HttpSecurity) {

+ 18 - 16
config/src/test/kotlin/org/springframework/security/config/web/servlet/session/SessionConcurrencyDslTests.kt

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,11 @@
 
 package org.springframework.security.config.web.servlet.session
 
+import io.mockk.every
+import io.mockk.mockkObject
+import java.util.Date
 import org.junit.Rule
 import org.junit.Test
-import org.mockito.Mockito.*
 import org.springframework.beans.factory.annotation.Autowired
 import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
@@ -27,11 +29,12 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter
 import org.springframework.security.config.test.SpringTestRule
+import org.springframework.security.config.web.servlet.invoke
 import org.springframework.security.core.session.SessionInformation
 import org.springframework.security.core.session.SessionRegistry
+import org.springframework.security.core.session.SessionRegistryImpl
 import org.springframework.security.core.userdetails.User
 import org.springframework.security.core.userdetails.UserDetailsService
-import org.springframework.security.config.web.servlet.invoke
 import org.springframework.security.provisioning.InMemoryUserDetailsManager
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf
 import org.springframework.security.web.session.SimpleRedirectSessionInformationExpiredStrategy
@@ -40,7 +43,6 @@ import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
 import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post
 import org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl
 import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status
-import java.util.*
 
 /**
  * Tests for [SessionConcurrencyDsl]
@@ -90,11 +92,12 @@ class SessionConcurrencyDslTests {
     @Test
     fun `session concurrency when expired url then redirects to url`() {
         this.spring.register(ExpiredUrlConfig::class.java).autowire()
+        mockkObject(ExpiredUrlConfig.SESSION_REGISTRY)
 
         val session = MockHttpSession()
         val sessionInformation = SessionInformation("", session.id, Date(0))
         sessionInformation.expireNow()
-        `when`(ExpiredUrlConfig.sessionRegistry.getSessionInformation(any())).thenReturn(sessionInformation)
+        every { ExpiredUrlConfig.SESSION_REGISTRY.getSessionInformation(any()) } returns sessionInformation
 
         this.mockMvc.perform(get("/").session(session))
                 .andExpect(redirectedUrl("/expired-session"))
@@ -102,8 +105,9 @@ class SessionConcurrencyDslTests {
 
     @EnableWebSecurity
     open class ExpiredUrlConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            val sessionRegistry: SessionRegistry = mock(SessionRegistry::class.java)
+            val SESSION_REGISTRY: SessionRegistry = SessionRegistryImpl()
         }
 
         override fun configure(http: HttpSecurity) {
@@ -112,26 +116,25 @@ class SessionConcurrencyDslTests {
                     sessionConcurrency {
                         maximumSessions = 1
                         expiredUrl = "/expired-session"
-                        sessionRegistry = sessionRegistry()
+                        sessionRegistry = SESSION_REGISTRY
                     }
                 }
             }
         }
 
         @Bean
-        open fun sessionRegistry(): SessionRegistry {
-            return sessionRegistry
-        }
+        open fun sessionRegistry(): SessionRegistry = SESSION_REGISTRY
     }
 
     @Test
     fun `session concurrency when expired session strategy then strategy used`() {
         this.spring.register(ExpiredSessionStrategyConfig::class.java).autowire()
+        mockkObject(ExpiredSessionStrategyConfig.SESSION_REGISTRY)
 
         val session = MockHttpSession()
         val sessionInformation = SessionInformation("", session.id, Date(0))
         sessionInformation.expireNow()
-        `when`(ExpiredSessionStrategyConfig.sessionRegistry.getSessionInformation(any())).thenReturn(sessionInformation)
+        every { ExpiredSessionStrategyConfig.SESSION_REGISTRY.getSessionInformation(any()) } returns sessionInformation
 
         this.mockMvc.perform(get("/").session(session))
                 .andExpect(redirectedUrl("/expired-session"))
@@ -139,8 +142,9 @@ class SessionConcurrencyDslTests {
 
     @EnableWebSecurity
     open class ExpiredSessionStrategyConfig : WebSecurityConfigurerAdapter() {
+
         companion object {
-            val sessionRegistry: SessionRegistry = mock(SessionRegistry::class.java)
+            val SESSION_REGISTRY: SessionRegistry = SessionRegistryImpl()
         }
 
         override fun configure(http: HttpSecurity) {
@@ -149,16 +153,14 @@ class SessionConcurrencyDslTests {
                     sessionConcurrency {
                         maximumSessions = 1
                         expiredSessionStrategy = SimpleRedirectSessionInformationExpiredStrategy("/expired-session")
-                        sessionRegistry = sessionRegistry()
+                        sessionRegistry = SESSION_REGISTRY
                     }
                 }
             }
         }
 
         @Bean
-        open fun sessionRegistry(): SessionRegistry {
-            return sessionRegistry
-        }
+        open fun sessionRegistry(): SessionRegistry = SESSION_REGISTRY
     }
 
     @Configuration