浏览代码

Server now takes command line parameters

Lars 2 年之前
父节点
当前提交
e3e8625bbb
共有 2 个文件被更改,包括 19 次插入8 次删除
  1. 2 1
      src/predictor.py
  2. 17 7
      src/server.py

+ 2 - 1
src/predictor.py

@@ -76,8 +76,9 @@ class Predictor():
         self.modelName = pth_filename.split('.')[0]
         self.modelName = pth_filename.split('.')[0]
 
 
         # load trained model from file
         # load trained model from file
+        print("Loading model from: " + pth_filename)
         self.cnn.load_state_dict(torch.load(pth_filename))
         self.cnn.load_state_dict(torch.load(pth_filename))
-
+        
         # set model in evaluation mode
         # set model in evaluation mode
         self.cnn.eval()
         self.cnn.eval()
 
 

+ 17 - 7
src/server.py

@@ -1,10 +1,9 @@
+#!/usr/bin/env python
 
 
 import gunicorn.app.base
 import gunicorn.app.base
+import sys
 from predictor import Predictor, Prediction
 from predictor import Predictor, Prediction
 
 
-HOST_PORT = "localhost:8001"
-MODEL = "cells_2.pth"
-
 class CustomUnicornApp(gunicorn.app.base.BaseApplication):
 class CustomUnicornApp(gunicorn.app.base.BaseApplication):
     """ 
     """ 
     This gunicorn app class provides create and exit callbacks for workers,     
     This gunicorn app class provides create and exit callbacks for workers,     
@@ -96,7 +95,7 @@ from bottle import Bottle, request, response
 import threading
 import threading
 import io
 import io
 
 
-def startServer(index_html:str, predictor:Predictor):
+def startServer(index_html:str, predictor:Predictor, host_port:str):
     def create():
     def create():
         app = Bottle()
         app = Bottle()
         lock = threading.Lock()
         lock = threading.Lock()
@@ -130,9 +129,20 @@ def startServer(index_html:str, predictor:Predictor):
         # Get the service through the app object and save state
         # Get the service through the app object and save state
         pass
         pass
 
 
-    CustomUnicornApp(create, exit, HOST_PORT).run()
+    CustomUnicornApp(create, exit, host_port).run()
+
+def usage():
+    print("""Usage:
+    server.py modelfile host_port
+
+Example:      
+    ./server.py cells_2.pth localhost:8001
+""")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    p = Predictor(MODEL)
-    startServer(html, p)
+    if len(sys.argv) != 3:
+        usage()
+    else:
+        p = Predictor(sys.argv[1])
+        startServer(html, p, sys.argv[2])