Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend.ai-client-tester/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ dependencies {
compile project(':backend.ai-client')
compile group: 'commons-cli', name: 'commons-cli', version:'1.3.1'
compile 'commons-io:commons-io:2.6'
compile 'com.google.code.gson:gson:2.8.2'
implementation 'com.squareup.okhttp3:okhttp:3.12.1'
}
jar {
manifest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@

import ai.backend.client.ClientConfig;
import ai.backend.client.Kernel;
import ai.backend.client.StreamExecutionHandler;
import ai.backend.client.StreamExecutionlistener;
import ai.backend.client.exceptions.AuthorizationFailureException;
import ai.backend.client.exceptions.ConfigurationException;
import ai.backend.client.exceptions.NetworkFailureException;
import ai.backend.client.values.ExecutionMode;
import ai.backend.client.values.ExecutionResult;
import ai.backend.client.values.RunStatus;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import okhttp3.OkHttpClient;
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import okio.ByteString;
import org.apache.commons.cli.*;
import org.apache.commons.io.FilenameUtils;

Expand All @@ -20,6 +29,7 @@
import java.util.logging.Logger;

import static java.util.logging.Level.SEVERE;
import com.google.gson.*;

public class Main {

Expand Down Expand Up @@ -112,6 +122,7 @@ public static void main(String[] args) {
kernel = createKernel(cmd);
} catch (ConfigurationException e) {
System.err.println("Bad ClientConfig");
System.err.println(e.getMessage());
return;
}
LOGGER.info(String.format("Kernel is ready : %s", kernel.getId()));
Expand All @@ -120,10 +131,7 @@ public static void main(String[] args) {

String buildCmd = cmd.getOptionValue("b", "*");
String execCmd = cmd.getOptionValue("e", "*");
runCode(kernel, buildCmd, execCmd);

finish(kernel);

runStreamCode(kernel, buildCmd, execCmd);
}

private static void uploadFiles(Kernel kernel, HashMap<String, String> files) {
Expand Down Expand Up @@ -179,6 +187,7 @@ public static void runCode(Kernel kernel, String buildCmd, String execCmd) {
System.out.print(result.getStdout());
System.err.print(result.getStderr());
if (result.isFinished()) {
LOGGER.info(String.format("Finished: ", kernel.getId()));
break;
}
if (result.getStatus() == RunStatus.WAITING_INPUT) {
Expand All @@ -195,6 +204,18 @@ public static void runCode(Kernel kernel, String buildCmd, String execCmd) {
}
}

public static void runStreamCode(Kernel kernel, String buildCmd, String execCmd) {
ExecutionMode mode = ExecutionMode.BATCH;
String runId = Kernel.generateRunId();
String code = "";
JsonObject opts = new JsonObject();
opts.addProperty("build", buildCmd);
opts.addProperty("exec", execCmd);

XXListener listener = new XXListener();
StreamExecutionHandler ws = kernel.streamExecute(mode, runId, code, opts, listener);
}

private static void finish(Kernel kernel) {
kernel.destroy();
}
Expand Down Expand Up @@ -224,4 +245,55 @@ protected static String getUnixRelativePath(String base, String path) throws IOE
rp = FilenameUtils.separatorsToUnix(rp);
return rp;
}

private static final class XXListener extends StreamExecutionlistener {
protected static Gson GSON;
private static final int NORMAL_CLOSURE_STATUS = 1000;
BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));
ExecutionMode mode = ExecutionMode.BATCH;
String code = "";

@Override
public void onMessage(WebSocket webSocket, String text) {
JsonElement je = new JsonParser().parse(String.format("{\"result\": %s }",text));
ExecutionResult result = new ExecutionResult(je.getAsJsonObject());
System.out.print(result.getStdout());
System.err.print(result.getStderr());
if (result.isFinished()) {
webSocket.close(NORMAL_CLOSURE_STATUS, null);
}
if (result.getStatus() == RunStatus.WAITING_INPUT) {
try {
code = stdin.readLine();
} catch (IOException e) {
code = "<user-input error>";
}
mode = ExecutionMode.INPUT;
this.send_code(webSocket, mode, code, null);

}
}

private void send_code(WebSocket webSocket, ExecutionMode mode, String code, JsonObject opts) {
JsonObject jsonObject = new JsonObject();
jsonObject.addProperty("mode", mode.getValue());
jsonObject.addProperty("code", code);
if (opts != null) {
jsonObject.add("options", opts);
}
String requestBody = GSON.toJson(jsonObject);
webSocket.send(requestBody);
}

@Override
public void onClosing(WebSocket webSocket, int code, String reason) {
webSocket.close(NORMAL_CLOSURE_STATUS, null);
}

@Override
public void onClosed(WebSocket webSocket, int code, String reason) {
super.onClosed(webSocket, code, reason);
LOGGER.info(String.format("Finished"));
}
}
}
2 changes: 1 addition & 1 deletion backend.ai-client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ description = 'Backend.AI Client Library for Java'

dependencies {
compile 'com.google.code.gson:gson:2.8.2'
compile 'com.squareup.okhttp3:okhttp:3.9.1'
implementation 'com.squareup.okhttp3:okhttp:3.12.1'
testCompile 'org.junit.jupiter:junit-jupiter-api:5.0.1'
testCompile 'junit:junit:4.12'
}
57 changes: 33 additions & 24 deletions backend.ai-client/src/main/java/ai/backend/client/APIFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class APIFunction {

private final Auth auth;
private static SimpleDateFormat DATEFORMAT;
private final OkHttpClient restClient = new OkHttpClient();
protected final OkHttpClient restClient = new OkHttpClient();

static {
DATEFORMAT = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss");
Expand Down Expand Up @@ -47,29 +47,7 @@ public ClientConfig getClientConfig() {

protected Response makeRequest(String method, String queryString, RequestBody requestBody, String authBaseString)
throws IOException, BackendClientException {
Date now = new Date();
if (!queryString.startsWith("/")) {
throw new InvalidParametersException("queryString must start with a slash.");
}
queryString = "/" + this.config.getApiVersionMajor() + queryString;
String dateString = String.format("%s%s", APIFunction.DATEFORMAT.format(now), "+00:00");
String sig = this.auth.getCredentialString(
method,
queryString,
now,
String.format("%s/%s",requestBody.contentType().type(), requestBody.contentType().subtype()),
authBaseString);
String auth = String.format("BackendAI signMethod=HMAC-SHA256, credential=%s" ,sig);
Request request = new Request.Builder()
.url(String.format("%s%s", this.config.getEndPoint(), queryString))
.method(method, requestBody)
.addHeader("Content-Type", requestBody.contentType().toString())
.addHeader("Content-Length", String.format("%d", requestBody.contentLength()))
.addHeader("X-BackendAI-Version", this.config.getApiVersion())
.addHeader("Date", dateString)
.addHeader("User-Agent", this.config.getUserAgent())
.addHeader("Authorization", auth)
.build();
Request request = getRequest(method, queryString, requestBody, authBaseString);
Response response = this.restClient.newCall(request).execute();
if (!response.isSuccessful()) {
int code = response.code();
Expand Down Expand Up @@ -98,6 +76,37 @@ protected Response makeRequest(String method, String queryString, RequestBody re
return response;
}

protected Request getRequest(String method, String queryString, RequestBody requestBody, String authBaseString) throws IOException {
Date now = new Date();
if (!queryString.startsWith("/")) {
throw new InvalidParametersException("queryString must start with a slash.");
}
String dateString = String.format("%s%s", APIFunction.DATEFORMAT.format(now), "+00:00");
String sig = this.auth.getCredentialString(
method,
queryString,
now,
String.format("%s/%s", requestBody.contentType().type(), requestBody.contentType().subtype()),
authBaseString);
String auth = String.format("BackendAI signMethod=HMAC-SHA256, credential=%s", sig);
RequestBody bdy;
if (method.equals("GET")) {
bdy = null;
} else {
bdy = requestBody;
}
return new Request.Builder()
.url(String.format("%s%s", this.config.getEndPoint(), queryString))
.method(method, bdy)
.addHeader("Content-Type", requestBody.contentType().toString())
.addHeader("Content-Length", String.format("%d", requestBody.contentLength()))
.addHeader("X-BackendAI-Version", this.config.getApiVersion())
.addHeader("Date", dateString)
.addHeader("User-Agent", this.config.getUserAgent())
.addHeader("Authorization", auth)
.build();
}

protected Response makeRequest(String method, String queryString, String requestBody)
throws IOException, BackendClientException {
RequestBody formBody = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ public ClientConfig build() throws ConfigurationException{
}
try {
String url = String.format("%s/%s", endPoint, apiVersionMajor);
new URL(url);
URL uri = new URL(url);
hostname = uri.getHost();
if (uri.getPort() != -1){
hostname = String.format("%s:%d", uri.getHost(), uri.getPort());
}
} catch (MalformedURLException e) {
throw new ConfigurationException("Malformed endpoint URL");
}
Expand Down
35 changes: 31 additions & 4 deletions backend.ai-client/src/main/java/ai/backend/client/Kernel.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import ai.backend.client.values.ExecutionMode;
import ai.backend.client.values.ExecutionResult;
import com.google.gson.JsonObject;
import okhttp3.MediaType;
import okhttp3.MultipartBody;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.*;
import okio.BufferedSink;
import okio.ByteString;

import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -213,4 +212,32 @@ public static String generateRunId() {
public static String generateSessionToken() {
return UUID.randomUUID().toString().replaceAll("-", "");
}

public StreamExecutionHandler streamExecute(ExecutionMode mode, String runId, String code, JsonObject opts, StreamExecutionlistener listener) throws BackendClientException {
JsonObject jsonObject = new JsonObject();
jsonObject.addProperty("mode", mode.getValue());
jsonObject.addProperty("code", code);
if (opts != null) {
jsonObject.add("options", opts);
}
jsonObject.addProperty("runId", runId);
String requestBody = GSON.toJson(jsonObject);

WebSocket ws;
StreamExecutionHandler handler;
try {
RequestBody x = RequestBody.create(MediaType.parse("application/json"), new byte[0]);

Request request = getRequest("GET", String.format("/stream/kernel/%s/execute", this.sessionToken), x, "");
ws = this.restClient.newWebSocket(request, listener);
listener.setClient(this.restClient);
ws.send(requestBody);
handler = new StreamExecutionHandler(ws);
//this.restClient.dispatcher().executorService().shutdown();
} catch (IOException e) {
throw new BackendClientException("Request/response failed", e);
}
return handler;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package ai.backend.client;

import okhttp3.WebSocket;

public class StreamExecutionHandler{
private WebSocket ws;

public StreamExecutionHandler(WebSocket ws) {
this.ws = ws;
}

void send(String str) {
ws.send(str);
}
void send(String code, String mode, String options) {
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package ai.backend.client;

import okhttp3.OkHttpClient;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;

public class StreamExecutionlistener extends WebSocketListener {
private OkHttpClient client;

public void setClient(OkHttpClient client) {
this.client = client;
}

public OkHttpClient getClient() {
return client;
}

@Override
public void onClosed(WebSocket webSocket, int code, String reason) {
super.onClosed(webSocket, code, reason);
client.dispatcher().executorService().shutdown();
}
}