Kaynağa Gözat

Predictor, server, test, cells_1 and cells_2 models check in

Lars 2 yıl önce
işleme
1512ad6da0
8 değiştirilmiş dosya ile 313 ekleme ve 0 silme
  1. 3 0
      .gitignore
  2. 3 0
      README.md
  3. 38 0
      requirements.txt
  4. BIN
      src/cells_1.pth
  5. BIN
      src/cells_2.pth
  6. 109 0
      src/predictor.py
  7. 138 0
      src/server.py
  8. 22 0
      src/test.py

+ 3 - 0
.gitignore

@@ -0,0 +1,3 @@
+.venv/*
+__pycache__
+test/*

+ 3 - 0
README.md

@@ -0,0 +1,3 @@
+
+## Cell Classifier
+

+ 38 - 0
requirements.txt

@@ -0,0 +1,38 @@
+bottle==0.12.25
+certifi==2023.7.22
+charset-normalizer==3.3.2
+filelock==3.13.1
+fsspec==2023.10.0
+gunicorn==21.2.0
+idna==3.4
+imageio==2.31.6
+Jinja2==3.1.2
+lazy_loader==0.3
+MarkupSafe==2.1.3
+mpmath==1.3.0
+networkx==3.2.1
+numpy==1.26.1
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.18.1
+nvidia-nvjitlink-cu12==12.3.52
+nvidia-nvtx-cu12==12.1.105
+packaging==23.2
+Pillow==10.0.1
+requests==2.31.0
+scikit-image==0.22.0
+scipy==1.11.3
+sympy==1.12
+tifffile==2023.9.26
+torch==2.1.0
+torchvision==0.16.0
+triton==2.1.0
+typing_extensions==4.8.0
+urllib3==2.0.7

BIN
src/cells_1.pth


BIN
src/cells_2.pth


+ 109 - 0
src/predictor.py

@@ -0,0 +1,109 @@
+
+
+import numpy as np
+
+import torch
+import torchvision.transforms as transforms
+import torch.nn.init
+
+from PIL import Image
+from skimage import transform, util
+from typing import Tuple
+
+keep_prob = 0.9 
+n_classes = 8  # multi hot encoded
+MHE = (2,4)  # dvs 2 output values med  4 classes i hver
+
+# Convolutional network model
+class CNN(torch.nn.Module):
+    def __init__(self):
+        super(CNN, self).__init__()
+
+        self.layer1 = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # input image size 44,  1,32,3,1,1
+            torch.nn.ReLU(),        
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),   
+            torch.nn.Dropout(p=1 - keep_prob)) 
+
+        self.layer2 = torch.nn.Sequential(
+            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), # 32,64,3,1,1
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),
+            torch.nn.Dropout(p=1 - keep_prob))
+
+        self.layer3 = torch.nn.Sequential(
+            torch.nn.Conv2d(64, 100, kernel_size=3, stride=1, padding=1), # 64,100,3,1,1
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
+            torch.nn.Dropout(p=1 - keep_prob))
+
+        self.fc1 = torch.nn.Linear(6 * 6 * 100, 400, bias=True)  # 6 * 6 * 100, 400, fully connected layer 1
+        torch.nn.init.xavier_uniform_(self.fc1.weight)  # _ ???
+        
+        self.layer4 = torch.nn.Sequential(   
+            self.fc1,
+            torch.nn.ReLU(),
+            torch.nn.Dropout(p=1 - keep_prob))
+        
+        self.fc2 = torch.nn.Linear(400, n_classes, bias=True)   # 400 - > 8 classes, , fully connected layer 2
+        torch.nn.init.xavier_uniform_(self.fc2.weight) # initialize weigts parameters
+
+    def forward(self, x):
+        out = self.layer1(x)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = out.view(out.size(0), -1)   # Flatten them for fully connected layers
+        out = self.layer4(out)  # self.fc1(out)
+        out = self.fc2(out)
+        return torch.nn.Sigmoid()(out)  # out sigmoid range 0..1  -------------------------------
+
+# ----
+
+class Prediction():
+    def __init__(self, model_name:str, growth:int, cells:int) -> str:
+        self.model = model_name
+        self.growth = growth
+        self.cells = cells
+
+    def __str__(self):
+        growth_level = {1:"early", 2: "mature", 3: "overgrown"}.get(self.growth, "?")
+        cells = {0: "0", 1: "1", 2: "2", 3: "3 or more"}.get(self.cells, "?")
+        return f"Growth level: {growth_level}. Cells: {cells}. Model: {self.model}"
+
+class Predictor():
+    def __init__(self, pth_filename:str):
+        self.cnn = CNN()  
+        self.modelName = pth_filename.split('.')[0]
+
+        # load trained model from file
+        self.cnn.load_state_dict(torch.load(pth_filename))
+
+        # set model in evaluation mode
+        self.cnn.eval()
+
+    def predict(self, input_img:Image) -> Prediction:
+        im_gray = np.array(input_img)
+
+        # resizing to 44 x 44 pixels. Model is trained on 44 x 44 pixel images
+        resize_im = (44, 44)
+        im_gray = util.img_as_ubyte(transform.resize(im_gray, resize_im, order = 1, anti_aliasing = False))
+
+        # convert uint8 to some PIL image format, values range 0..1, float32 (float64 vil give error), kan sikkert også gøres på andre måder
+        image = Image.fromarray(im_gray)  # np.shape(image) = (44,44)
+
+        im_tensor = transforms.ToTensor()(image).unsqueeze(0)  # unsqueeze(0) giver [1,44,44] -> [1,1,44,44]
+
+        # print(np.shape(im_tensor)) # = [1,1,44,44] det er det format som modellen tager som input.  [n_batch, channel, imx, imy]
+        # Hvis batch, n af flere billeder på en gang skal det pakkes i formatet [n,1,44,44]
+
+        # run the model prediction on one image
+        prediction = self.cnn(im_tensor)
+
+        # output is something like : [2.7335e-13, 1.0000e+00, 1.6826e-11, 1.3905e-11, 9.9984e-01, 2.6207e-05, 4.9733e-08, 1.2173e-10]
+        # hvilket indeholder de 8 neuroner. Skal decodes til [A, B], ud fra [A1, A2, A3, A4, B1, B2, B3, B4]
+        # bruge argmax og reshape til at få final [A,B] 
+
+        output_classes = list(torch.argmax(torch.reshape(prediction[0].detach(), MHE), dim = -1).numpy())
+        # .detach() bruges for at undgå "grad_fn=<SigmoidBackward0>" dvs undgå at den aktivt tracker gradienten for hver operation...
+
+        return Prediction(self.modelName, int(output_classes[0]), int(output_classes[1]))

+ 138 - 0
src/server.py

@@ -0,0 +1,138 @@
+
+import gunicorn.app.base
+from predictor import Predictor, Prediction
+
+HOST_PORT = "localhost:8001"
+MODEL = "cells_2.pth"
+
+class CustomUnicornApp(gunicorn.app.base.BaseApplication):
+    """ 
+    This gunicorn app class provides create and exit callbacks for workers,     
+    and runs gunicorn with a single worker and multiple gthreads
+    """
+    def __init__(self, create_app_callback, exit_app_callback, host_port):
+        self._configBind = host_port
+        self._createAppCallback = create_app_callback
+        self._exitAppCallback = exit_app_callback
+        super().__init__()
+
+    @staticmethod
+    def exitWorker(arbiter, worker):
+        # worker.app provides us with a reference to "self", and we can call the 
+        # exit callback with the object created by the createAppCallback:
+        self = worker.app
+        self._exitAppCallback(self._createdApp)
+
+    def load_config(self):
+        self.cfg.set("bind", self._configBind)
+        self.cfg.set("worker_class", "gthread")
+        self.cfg.set("workers", 1)
+        self.cfg.set("threads", 4)
+        self.cfg.set("worker_exit", CustomUnicornApp.exitWorker)
+        # Try to uncomment and make 10 requests, to test correct restart of worker:
+        # self.cfg.set("max_requests", 10) 
+
+    def load(self):
+        # This function is invoked when a worker is booted
+        self._createdApp = self._createAppCallback()
+        return self._createdApp
+    
+
+# --- index.html contents
+
+html = """<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="UTF-8">
+<title>Cell Growth Classifier</title>
+<script>
+function uploadImage() {
+    var formData = new FormData();
+    var imageFile = document.getElementById("imageInput").files[0];
+    formData.append("image", imageFile);
+
+    fetch('/upload', {
+        method: 'POST',
+        body: formData
+    })
+    .then(response => response.json())
+    .then(data => {
+        document.getElementById("results").textContent = `${data.results}`;
+        var image = document.createElement("img");
+        image.src = URL.createObjectURL(imageFile);
+        image.onload = function() {
+            URL.revokeObjectURL(image.src) // Free up memory
+        }
+        document.getElementById("imageDisplay").innerHTML = '';
+        document.getElementById("imageDisplay").appendChild(image);
+    })
+    .catch(error => {
+        console.error('Error:', error);
+        document.getElementById("results").textContent = "An error occurred while uploading the image.";
+    });
+}
+</script>
+</head>
+<body>
+
+<h2>Cell growth classifier demo</h2>
+<input type="file" id="imageInput" accept="image/*"> <br><br>
+<button onclick="uploadImage()">Upload</button>
+
+<h3>Image:</h3>
+<div id="imageDisplay"></div>
+
+<h3>Results:</h3>
+<div id="results"></div>
+
+</body>
+</html>
+"""
+
+# ---- 
+
+from PIL import Image
+from bottle import Bottle, request, response
+import threading
+import io
+
+def startServer(index_html:str, predictor:Predictor):
+    def create():
+        app = Bottle()
+        lock = threading.Lock()
+    
+        @app.route("/")
+        def getIndex():
+            # Serve static content, no lock protection necessary:
+            return index_html
+
+        @app.route('/upload', method='POST')
+        def upload_image():
+            upload = request.files.get('image')
+            if upload is None:
+                return "No image uploaded"
+            
+            # Read the image file in bytes and open it with Pillow
+            image_bytes = io.BytesIO(upload.file.read())
+            try:
+                with Image.open(image_bytes) as img:
+                    pred:Prediction = app.predictor.predict(img)
+                    return {"results": str(pred)}
+            except IOError:
+                response.status = 400
+                return "Invalid image file"
+  
+        # Store predictor in the bottle app object
+        app.predictor = predictor
+        return app
+
+    def exit(app):
+        # Get the service through the app object and save state
+        pass
+
+    CustomUnicornApp(create, exit, HOST_PORT).run()
+
+if __name__ == "__main__":
+    p = Predictor(MODEL)
+    startServer(html, p)
+

+ 22 - 0
src/test.py

@@ -0,0 +1,22 @@
+#!/usr/bin/env python
+
+import sys
+
+def main(model_filename, img_filename):
+    from predictor import Predictor
+    from PIL import Image
+    p = Predictor(model_filename)
+    result = p.predict(Image.open(img_filename))
+    print(str(result))
+
+def usage():
+    print("""Usage:
+
+test.py modelfile imagefile""")
+
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        usage()
+    else:
+        main(sys.argv[1], sys.argv[2])
+