package be.jeffcheasey88.peeratcode.framework; import java.lang.reflect.Method; import java.net.ServerSocket; import java.net.Socket; import java.security.MessageDigest; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.net.ssl.SSLServerSocket; import javax.net.ssl.SSLServerSocketFactory; import org.jose4j.jwk.RsaJsonWebKey; import org.jose4j.jwk.RsaJwkGenerator; import org.jose4j.jws.AlgorithmIdentifiers; import org.jose4j.jws.JsonWebSignature; import org.jose4j.jwt.JwtClaims; import org.jose4j.lang.JoseException; public class Router{ public static void main(String[] args){} private Map> responses; private Map patterns; private Response noFileFound; private RsaJsonWebKey rsaJsonWebKey; private String token_issuer; private int token_expiration; public Router(String token_issuer, int token_expiration) throws Exception{ this.token_issuer = token_issuer; this.token_expiration = token_expiration; this.responses = new HashMap<>(); for(RequestType type : RequestType.values()) this.responses.put(type, new HashMap<>()); this.patterns = new HashMap<>(); this.rsaJsonWebKey = RsaJwkGenerator.generateJwk(2048); } public void listen(int port, boolean ssl) throws Exception{ if (ssl) { // Not needed with the use of a proxy SSLServerSocket server = null; try { SSLServerSocketFactory ssf = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault(); server = (SSLServerSocket) ssf.createServerSocket(port); while (!server.isClosed()) { Socket socket = server.accept(); Client client = new Client(socket, this); client.start(); } } catch (Exception e) { e.printStackTrace(); } finally { if (server != null) { server.close(); } } } else { try (ServerSocket server = new ServerSocket(port)) { while (!server.isClosed()) { Socket socket = server.accept(); Client client = new Client(socket, this); client.start(); } } catch (Exception e) { e.printStackTrace(); } } } public void register(Response response){ try{ Method method = response.getClass().getDeclaredMethod("exec", Response.class.getDeclaredMethods()[0].getParameterTypes()); Route route = method.getAnnotation(Route.class); this.responses.get(route.type()).put(response, route); this.patterns.put(response, Pattern.compile(route.path())); }catch(Exception e){ throw new IllegalArgumentException(e); } } public void setDefault(Response response){ this.noFileFound = response; } public void exec(RequestType type, String path, User user, HttpReader reader, HttpWriter writer) throws Exception{ if(type == null) return; for(Entry routes : this.responses.get(type).entrySet()){ Matcher matcher = this.patterns.get(routes.getKey()).matcher(path); if(matcher.matches()){ if(user == null && routes.getValue().needLogin()){ writer.response(401, "Access-Control-Allow-Origin: *"); return; } if(routes.getValue().websocket()){ switchToWebSocket(reader, writer); reader = new WebSocketReader(reader); writer = new WebSocketWriter(writer); } routes.getKey().exec(matcher, user, reader, writer); return; } } if(noFileFound != null) noFileFound.exec(null, user, reader, writer); } public RsaJsonWebKey getWebKey(){ return this.rsaJsonWebKey; } public String getTokenIssuer(){ return this.token_issuer; } public void configureSSL(String keyStore, String keyStorePassword){ System.setProperty("javax.net.ssl.keyStore", keyStore); System.setProperty("javax.net.ssl.keyStorePassword", keyStorePassword); } public String createAuthUser(int id) throws JoseException{ JwtClaims claims = new JwtClaims(); claims.setIssuer(token_issuer); // who creates the token and signs it claims.setExpirationTimeMinutesInTheFuture(token_expiration); claims.setGeneratedJwtId(); // a unique identifier for the token claims.setIssuedAtToNow(); // when the token was issued/created (now) claims.setNotBeforeMinutesInThePast(2); // time before which the token is not yet valid (2 minutes ago) claims.setClaim("id", id); JsonWebSignature jws = new JsonWebSignature(); jws.setPayload(claims.toJson()); jws.setKey(rsaJsonWebKey.getPrivateKey()); jws.setKeyIdHeaderValue(rsaJsonWebKey.getKeyId()); jws.setAlgorithmHeaderValue(AlgorithmIdentifiers.RSA_USING_SHA256); return jws.getCompactSerialization(); } private void switchToWebSocket(HttpReader reader, HttpWriter writer) throws Exception{ String key = reader.getHeader("Sec-WebSocket-Key"); if (key == null) throw new IllegalArgumentException(); writer.write("HTTP/1.1 101 Switching Protocols\n"); writer.write("Connection: Upgrade\n"); writer.write("Upgrade: websocket\n"); writer.write("Sec-WebSocket-Accept: " + printBase64Binary(MessageDigest.getInstance("SHA-1") .digest((key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").getBytes("UTF-8"))) + "\n"); writer.write("\n"); writer.flush(); } // From javax.xml.bind.DatatypeConverter private String printBase64Binary(byte[] array){ char[] arrayOfChar = new char[(array.length + 2) / 3 * 4]; int i = _printBase64Binary(array, 0, array.length, arrayOfChar, 0); assert i == arrayOfChar.length; return new String(arrayOfChar); } private int _printBase64Binary(byte[] paramArrayOfbyte, int paramInt1, int paramInt2, char[] paramArrayOfchar, int paramInt3){ int i = paramInt2; int j; for (j = paramInt1; i >= 3; j += 3){ paramArrayOfchar[paramInt3++] = encode(paramArrayOfbyte[j] >> 2); paramArrayOfchar[paramInt3++] = encode( (paramArrayOfbyte[j] & 0x3) << 4 | paramArrayOfbyte[j + 1] >> 4 & 0xF); paramArrayOfchar[paramInt3++] = encode( (paramArrayOfbyte[j + 1] & 0xF) << 2 | paramArrayOfbyte[j + 2] >> 6 & 0x3); paramArrayOfchar[paramInt3++] = encode(paramArrayOfbyte[j + 2] & 0x3F); i -= 3; } if (i == 1){ paramArrayOfchar[paramInt3++] = encode(paramArrayOfbyte[j] >> 2); paramArrayOfchar[paramInt3++] = encode((paramArrayOfbyte[j] & 0x3) << 4); paramArrayOfchar[paramInt3++] = '='; paramArrayOfchar[paramInt3++] = '='; } if (i == 2){ paramArrayOfchar[paramInt3++] = encode(paramArrayOfbyte[j] >> 2); paramArrayOfchar[paramInt3++] = encode( (paramArrayOfbyte[j] & 0x3) << 4 | paramArrayOfbyte[j + 1] >> 4 & 0xF); paramArrayOfchar[paramInt3++] = encode((paramArrayOfbyte[j + 1] & 0xF) << 2); paramArrayOfchar[paramInt3++] = '='; } return paramInt3; } private char encode(int paramInt){ return encodeMap[paramInt & 0x3F]; } private static final char[] encodeMap = initEncodeMap(); private static char[] initEncodeMap(){ char[] arrayOfChar = new char[64]; byte b; for (b = 0; b < 26; b++) arrayOfChar[b] = (char) (65 + b); for (b = 26; b < 52; b++) arrayOfChar[b] = (char) (97 + b - 26); for (b = 52; b < 62; b++) arrayOfChar[b] = (char) (48 + b - 52); arrayOfChar[62] = '+'; arrayOfChar[63] = '/'; return arrayOfChar; } }