diff --git a/src/main/java/stirling/software/SPDF/config/UserBasedRateLimitingFilter.java b/src/main/java/stirling/software/SPDF/config/UserBasedRateLimitingFilter.java index 441cd080..2695b72e 100644 --- a/src/main/java/stirling/software/SPDF/config/UserBasedRateLimitingFilter.java +++ b/src/main/java/stirling/software/SPDF/config/UserBasedRateLimitingFilter.java @@ -16,6 +16,7 @@ import org.springframework.web.filter.OncePerRequestFilter; import io.github.bucket4j.Bandwidth; import io.github.bucket4j.Bucket; import io.github.bucket4j.Bucket4j; +import io.github.bucket4j.ConsumptionProbe; import io.github.bucket4j.Refill; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; @@ -26,43 +27,52 @@ public class UserBasedRateLimitingFilter extends OncePerRequestFilter { private final Map buckets = new ConcurrentHashMap<>(); + @Autowired + private UserDetailsService userDetailsService; @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - String method = request.getMethod(); - if (!"POST".equalsIgnoreCase(method)) { - filterChain.doFilter(request, response); - return; - } - - String identifier; - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - - if (authentication != null && authentication.isAuthenticated()) { - UserDetails userDetails = (UserDetails) authentication.getPrincipal(); - identifier = userDetails.getUsername(); - } else { - identifier = request.getRemoteAddr(); // Use IP as identifier if not authenticated - } - - Bucket userBucket = buckets.computeIfAbsent(identifier, k -> createUserBucket()); - - if (userBucket.tryConsume(1)) { - filterChain.doFilter(request, response); - } else { - response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value()); - response.getWriter().write("Rate limit exceeded."); - return; - } + String method = request.getMethod(); + + if (!"POST".equalsIgnoreCase(method)) { + // If the request is not a POST, just pass it through without rate limiting + filterChain.doFilter(request, response); + return; } -//https://www.baeldung.com/spring-bucket4j - private Bucket createUserBucket() { - Refill refill = Refill.of(3, Duration.ofDays(1)); - Bandwidth limit = Bandwidth.classic(3, refill).withInitialTokens(3); - return Bucket4j.builder().addLimit(limit).build(); + + String identifier; + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + + if (authentication != null && authentication.isAuthenticated()) { + UserDetails userDetails = (UserDetails) authentication.getPrincipal(); + identifier = userDetails.getUsername(); + } else { + identifier = request.getRemoteAddr(); // Use IP as identifier if not authenticated + } + + Bucket userBucket = buckets.computeIfAbsent(identifier, k -> createUserBucket()); + ConsumptionProbe probe = userBucket.tryConsumeAndReturnRemaining(1); + + if (probe.isConsumed()) { + response.setHeader("X-Rate-Limit-Remaining", Long.toString(probe.getRemainingTokens())); + filterChain.doFilter(request, response); + } else { + long waitForRefill = probe.getNanosToWaitForRefill() / 1_000_000_000; + response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value()); + response.setHeader("X-Rate-Limit-Retry-After-Seconds", String.valueOf(waitForRefill)); + response.getWriter().write("Rate limit exceeded for POST requests."); + return; } } +private Bucket createUserBucket() { + Bandwidth limit = Bandwidth.classic(1000, Refill.intervally(1000, Duration.ofDays(1))); + return Bucket.builder().addLimit(limit).build(); +} + +} + + diff --git a/src/main/java/stirling/software/SPDF/config/security/InitialSetup.java b/src/main/java/stirling/software/SPDF/config/security/InitialSetup.java index 64923dc7..97ae9373 100644 --- a/src/main/java/stirling/software/SPDF/config/security/InitialSetup.java +++ b/src/main/java/stirling/software/SPDF/config/security/InitialSetup.java @@ -18,9 +18,9 @@ public class InitialSetup { String initialUsername = System.getenv("INITIAL_USERNAME"); String initialPassword = System.getenv("INITIAL_PASSWORD"); if(initialUsername != null && initialPassword != null) { - userService.saveUser(initialUsername, initialPassword, Role.ADMIN); + userService.saveUser(initialUsername, initialPassword, Role.ADMIN.getRoleId()); } else { - userService.saveUser("admin", "password", Role.ADMIN); + userService.saveUser("admin", "password", Role.ADMIN.getRoleId()); } } } diff --git a/src/main/java/stirling/software/SPDF/config/security/UserService.java b/src/main/java/stirling/software/SPDF/config/security/UserService.java index 4965f52f..15f7d48c 100644 --- a/src/main/java/stirling/software/SPDF/config/security/UserService.java +++ b/src/main/java/stirling/software/SPDF/config/security/UserService.java @@ -1,10 +1,20 @@ package stirling.software.SPDF.config.security; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import java.util.Map; import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.Collection; import java.util.HashMap; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.stereotype.Service; @@ -21,6 +31,68 @@ public class UserService { @Autowired private PasswordEncoder passwordEncoder; + public Authentication getAuthentication(String apiKey) { + User user = getUserByApiKey(apiKey); + if (user == null) { + throw new UsernameNotFoundException("API key is not valid"); + } + + // Convert the user into an Authentication object + return new UsernamePasswordAuthenticationToken( + user, // principal (typically the user) + null, // credentials (we don't expose the password or API key here) + getAuthorities(user) // user's authorities (roles/permissions) + ); + } + + private Collection getAuthorities(User user) { + // Convert each Authority object into a SimpleGrantedAuthority object. + return user.getAuthorities().stream() + .map((Authority authority) -> new SimpleGrantedAuthority(authority.getAuthority())) + .collect(Collectors.toList()); + + + } + + private String generateApiKey() { + String apiKey; + do { + apiKey = UUID.randomUUID().toString(); + } while (userRepository.findByApiKey(apiKey) != null); // Ensure uniqueness + return apiKey; + } + + public User addApiKeyToUser(String username) { + User user = userRepository.findByUsername(username) + .orElseThrow(() -> new UsernameNotFoundException("User not found")); + + user.setApiKey(generateApiKey()); + return userRepository.save(user); + } + + public User refreshApiKeyForUser(String username) { + return addApiKeyToUser(username); // reuse the add API key method for refreshing + } + + public String getApiKeyForUser(String username) { + User user = userRepository.findByUsername(username) + .orElseThrow(() -> new UsernameNotFoundException("User not found")); + return user.getApiKey(); + } + + public boolean isValidApiKey(String apiKey) { + return userRepository.findByApiKey(apiKey) != null; + } + + public User getUserByApiKey(String apiKey) { + return userRepository.findByApiKey(apiKey); + } + + public boolean validateApiKeyForUser(String username, String apiKey) { + Optional userOpt = userRepository.findByUsername(username); + return userOpt.isPresent() && userOpt.get().getApiKey().equals(apiKey); + } + public void saveUser(String username, String password) { User user = new User(); user.setUsername(username); diff --git a/src/main/java/stirling/software/SPDF/model/Role.java b/src/main/java/stirling/software/SPDF/model/Role.java index 2d28176f..fc0d1e9b 100644 --- a/src/main/java/stirling/software/SPDF/model/Role.java +++ b/src/main/java/stirling/software/SPDF/model/Role.java @@ -1,10 +1,42 @@ package stirling.software.SPDF.model; -public final class Role { +public enum Role { - public static final String ADMIN = "ROLE_ADMIN"; - public static final String USER = "ROLE_USER"; - public static final String LIMITED_API_USER = "ROLE_LIMITED_API_USER"; - public static final String WEB_ONLY_USER = "ROLE_WEB_ONLY_USER"; + // Unlimited access + ADMIN("ROLE_ADMIN", Integer.MAX_VALUE, Integer.MAX_VALUE), + + // Unlimited access + USER("ROLE_USER", Integer.MAX_VALUE, Integer.MAX_VALUE), + + // 40 API calls Per Day, 40 web calls + LIMITED_API_USER("ROLE_LIMITED_API_USER", 40, 40), + + // 20 API calls Per Day, 20 web calls + EXTRA_LIMITED_API_USER("ROLE_EXTRA_LIMITED_API_USER", 20, 20), + + // 0 API calls per day and 20 web calls + WEB_ONLY_USER("ROLE_WEB_ONLY_USER", 0, 20); + + private final String roleId; + private final int apiCallsPerDay; + private final int webCallsPerDay; + + Role(String roleId, int apiCallsPerDay, int webCallsPerDay) { + this.roleId = roleId; + this.apiCallsPerDay = apiCallsPerDay; + this.webCallsPerDay = webCallsPerDay; + } + + public String getRoleId() { + return roleId; + } + + public int getApiCallsPerDay() { + return apiCallsPerDay; + } + + public int getWebCallsPerDay() { + return webCallsPerDay; + } diff --git a/src/main/java/stirling/software/SPDF/model/User.java b/src/main/java/stirling/software/SPDF/model/User.java index 6b085d0c..40d71da1 100644 --- a/src/main/java/stirling/software/SPDF/model/User.java +++ b/src/main/java/stirling/software/SPDF/model/User.java @@ -28,6 +28,9 @@ public class User { @Column(name = "password") private String password; + @Column(name = "apiKey") + private String apiKey; + @Column(name = "enabled") private boolean enabled; @@ -42,6 +45,14 @@ public class User { + public String getApiKey() { + return apiKey; + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + public Map getSettings() { return settings; } diff --git a/src/main/java/stirling/software/SPDF/repository/UserRepository.java b/src/main/java/stirling/software/SPDF/repository/UserRepository.java index 064df341..744953d7 100644 --- a/src/main/java/stirling/software/SPDF/repository/UserRepository.java +++ b/src/main/java/stirling/software/SPDF/repository/UserRepository.java @@ -8,5 +8,6 @@ import stirling.software.SPDF.model.User; public interface UserRepository extends JpaRepository { Optional findByUsername(String username); + User findByApiKey(String apiKey); }