diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..3a277e0e3b40eeeefd42fdde005b1ca63fdc4704 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +scripts/__pycache__/ +data/augmented_data/* +data/real_data/*/*.tif +data/real_test_data/*/*.tif +data/synth_data/*/*.tif +results/*/*.png +results/*/*/*.png \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..3b0e8e197cd2ce62dd8ae4054f0a699a5848fecc --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Valentin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 1f2983c65f8d9bab8ea0266c4f0124bbc1da4c4b..b4f1028c27f4d3a8efef58c0e8076342cefc8eae 100644 --- a/README.md +++ b/README.md @@ -1,93 +1,54 @@ # segment_sem_images +## Code for the segmentation of slurry coatings using machine learning techniques -## Getting started +In the present study, the following quantities are determined: the presence and thickness of the Fe2Al5 and FeAl coating layers, the pores in the Fe2Al5 layer and their concentration in %, the pore line parallel to the surface and its distance to the surface, and the concentration of Cr precipitates. -To make it easy for you to get started with GitLab, here's a list of recommended next steps. - -Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)! +--- +## Installation +--- -## Add your files +Clone the repository to your disk. Then use the Conda package and environment management system for Python. Using the *.yml files, you can create a Python environment on your system and install the required modules. There are two versions: a CPU version (environment_cpu.yml) and a GPU version (environment_gpu.yml). If you have a CUDA-compatible GPU, use the GPU file to install GPU support; it provides faster segmentation of image data.
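+ +After creating an environment, activate it and check, for example, that TensorFlow can see the GPU (the environment name below is the one defined in environment_gpu.yml): + +``` +conda activate tensorflow_gpu_env +python -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))" +```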
-- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files -- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command: +To create a CPU environment, use: ``` -cd existing_repo -git remote add origin https://code.it4i.cz/set0013/segment_sem_images.git -git branch -M main -git push -uf origin main +conda env create -f environment_cpu.yml ``` -## Integrate with your tools - -- [ ] [Set up project integrations](https://code.it4i.cz/set0013/segment_sem_images/-/settings/integrations) -## Collaborate with your team - -- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/) -- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html) -- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically) -- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/) -- [ ] [Set auto-merge](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html) - -## Test and Deploy - -Use the built-in continuous integration in GitLab. - -- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html) - -- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing (SAST)](https://docs.gitlab.com/ee/user/application_security/sast/) -- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html) -- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/) -- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html) - -*** - -# Editing this README - -When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thanks to [makeareadme.com](https://www.makeareadme.com/) for this template. -## Suggestions for a good README +To create a GPU environment, use: -Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information. - -## Name -Choose a self-explaining name for your project. - -## Description -Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors. - -## Badges -On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.
- -## Visuals -Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method. - -## Installation -Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection. - -## Usage -Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README. +``` +conda env create -f environment_gpu.yml ``` -## Support -Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc. +--- +## Segment the images +--- -## Roadmap -If you have ideas for releases in the future, it is a good idea to list them in the README. +To segment the images, use the evaluation.py script. -## Contributing -State if you are open to contributions and what your requirements are for accepting them. +Update the following variables in evaluation.py: +*images_path*, *masks_path*, and *weights_file*. Paths must be absolute, and each value must be given as a string, i.e. in quotes. All results are stored in a newly created sub-folder inside the *results* folder. +If masks are not available, set *masks_path* to an empty string. -For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self. +To segment all 6 labels together, use a weights file with *"_6"* in its name. +To segment only chromium precipitates, use a weights file with *"_crpr"* in its name. -You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser. +To use a model trained only on synthetic data, use a weights file with *"_synth_class"* in its name. +To use a model trained on a mix of synthetic and real data, use a weights file with *"_synth_real_mix_class"* in its name. -## Authors and acknowledgment -Show your appreciation to those who have contributed to the project. +Example: +``` +python evaluation.py +OR +& C:/Users/xyz/AppData/Local/anaconda3/envs/tensorflow_gpu_env/python.exe c:/segment_sem_images_temp/evaluation.py +```
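+ +Before running, set the three variables at the top of evaluation.py. For example (illustrative Windows paths; adjust them to your system): + +``` +images_path = "C:/segment_sem_images/data/real_test_data/real_images_test" +masks_path = "C:/segment_sem_images/data/real_test_data/real_masks_test" +weights_file = "model_weights_exp_synth_class_unet_6.h5" +```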
+--- ## License -For open source projects, say how it is licensed. +--- -## Project status -If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers. +The code is licensed under the [MIT License](LICENSE.txt). \ No newline at end of file diff --git a/data/real_data/.gitkeep b/data/real_data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data/real_test_data/.gitkeep b/data/real_test_data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data/synth_data/.gitkeep b/data/synth_data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/environment_cpu.yml b/environment_cpu.yml new file mode 100644 index 0000000000000000000000000000000000000000..ba8702c1957921336f6ba64e710ddc6dee574010 --- /dev/null +++ b/environment_cpu.yml @@ -0,0 +1,13 @@ +name: tensorflow_cpu_env +channels: + - defaults +dependencies: + - python=3.8 + - pip==23.3.1 + - tensorflow==2.10.0 + - pillow==10.0.1 + - matplotlib==3.7.2 + - scikit-learn==1.3.0 + - pip: + - colormath==3.0.0 + - opencv-python==4.9.0.80 \ No newline at end of file diff --git a/environment_gpu.yml b/environment_gpu.yml new file mode 100644 index 0000000000000000000000000000000000000000..d9b2a38d0321dfa82d523ae947eb8ffc9194e89e --- /dev/null +++ b/environment_gpu.yml @@ -0,0 +1,16 @@ +name: tensorflow_gpu_env +channels: + - defaults + - nvidia +dependencies: + - python=3.8 + - pip==23.3.1 + - cudatoolkit=11.8.0 + - cudnn + - pillow==10.0.1 + - matplotlib==3.7.2 + - scikit-learn==1.3.0 + - pip: + - tensorflow[and-cuda] # the [and-cuda] extras syntax is pip-only, so install TensorFlow via pip + - colormath==3.0.0 + - opencv-python==4.9.0.80 \ No newline at end of file diff --git a/evaluation.py b/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4fd8baf542a3e9a385a739476703c2c0ddb824 --- /dev/null +++ b/evaluation.py @@ -0,0 +1,28 @@ +import subprocess +import time + +# Start measuring script execution time +start_time = time.time() + +# -------------------------------------------------------------------------- +images_path = "C:/Users/set0013/Downloads/Fraunhofer/git/segment_sem_images_temp/data/real_test_data/real_images_test" +masks_path = "C:/Users/set0013/Downloads/Fraunhofer/git/segment_sem_images_temp/data/real_test_data/real_masks_test" +# masks_path = "" +weights_file = "model_weights_exp_synth_class_unet_6.h5" + +# -------------------------------------------------------------------------- +# Derive the experiment name from the weights file name +experiment_name = weights_file.replace("model_weights_", "").replace(".h5", "")
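+# For example, with the weights file above, experiment_name becomes +# "exp_synth_class_unet_6"; note that scripts/test.py parses the trailing +# "6" of this name as the number of classes.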
+ +# Use subprocess to run the test.py script +script_path = "scripts/test.py" +subprocess.run(["python", script_path, "--images_path", images_path, "--masks_path", masks_path, "--experiment_name", experiment_name]) + +# Calculate and print the total time taken +end_time = time.time() +execution_time = end_time - start_time +minutes = int(execution_time / 60) +seconds = int(execution_time % 60) +print(f"Total time taken: {minutes} minutes and {seconds} seconds") + +# -------------------------------------------------------------------------- diff --git a/results/images/.gitkeep b/results/images/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/results/predictions/.gitkeep b/results/predictions/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/FreeMono.ttf b/scripts/FreeMono.ttf new file mode 100644 index 0000000000000000000000000000000000000000..7485f9e4c84d5a372c81e11df2cd9f5e2eb2064a Binary files /dev/null and b/scripts/FreeMono.ttf differ diff --git a/scripts/convert_color_mask_to_number_mask.py b/scripts/convert_color_mask_to_number_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..694f9beff07e24e7bb9da565cc2116d31c939e1e --- /dev/null +++ b/scripts/convert_color_mask_to_number_mask.py @@ -0,0 +1,124 @@ +import os +import cv2 +import numpy as np +from sklearn.cluster import KMeans +from colormath.color_diff import delta_e_cie2000 +from colormath.color_objects import LabColor, sRGBColor +from colormath import color_conversions + +def color_distance(color1, color2): + # is_upscaled=True because the RGB values are in the 0-255 range + lab_color1 = color_conversions.convert_color(sRGBColor(*color1, is_upscaled=True), LabColor) + lab_color2 = color_conversions.convert_color(sRGBColor(*color2, is_upscaled=True), LabColor) + return delta_e_cie2000(lab_color1, lab_color2) + +def closest_color(requested_color, color_dict): + min_distance = float('inf') + closest_idx = None + + for idx, rgb_color in color_dict.items(): + distance = color_distance(requested_color, rgb_color) + + if distance < min_distance: + min_distance = distance + closest_idx = idx + + return closest_idx + +def quantize_colors(image_path, num_colors, color_dict): + # Read the image + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # cv2.imread returns BGR; convert to RGB + + # Reshape the image to a 2D array of pixels + pixels = image.reshape((-1, 3)) + + # Use KMeans clustering (k-means++ seeding) to quantize the colors + kmeans = KMeans(n_clusters=num_colors, init='k-means++') + kmeans.fit(pixels) + # print(np.unique(kmeans.labels_)) + + # Get the centroids (representative colors) of each cluster + colors = kmeans.cluster_centers_.astype(int) + + # Assign a class number (a key of color_dict) to each cluster centroid + assigned_numbers = np.zeros(len(colors), dtype=np.uint8) + assigned_hex_colors = {} + for i, centroid in enumerate(colors): + assigned_numbers[i] = closest_color(centroid, color_dict) + closest_hex_color = color_dict.get(assigned_numbers[i]) + assigned_hex_colors[assigned_numbers[i]] = closest_hex_color + + # Map each pixel's cluster label to its assigned class number (vectorized lookup) + labels_2d = kmeans.labels_.reshape(image.shape[0], image.shape[1]) + assigned_numbers_image = assigned_numbers[labels_2d] + + # Create a new color image from the assigned class numbers + # (assumes the keys of color_dict are contiguous integers starting at 0) + palette = np.array([color_dict[k] for k in sorted(color_dict)], dtype=np.uint8) + assigned_colors_image = palette[assigned_numbers_image] + + return assigned_numbers_image, assigned_colors_image + + +if __name__ == "__main__": + from_blender = True # True, False + from_fraunhofer = False # True, False + num_colors = 6 # 9 # can edit this based on colors in the original mask + + if from_blender: + + input_path = "/mnt/proj3/open-28-64/set0013/sem_image_segmentation/data/base_data/masks_synth/" ############################ + # output_path_assigned_colors = "/mnt/proj3/open-28-64/set0013/sem_image_segmentation/data/base_data/masks_synth_quantized/" + output_path_assigned_numbers = "/mnt/proj3/open-28-64/set0013/sem_image_segmentation/data/base_data/masks_synth_gray/" + + color_dict = {
+ 0: [255, 192, 203], # pink # Bg1 + 1: [255, 0, 255], # magenta # Bg2 + 2: [255, 255, 255], # white # Fe2Al5 + 3: [0, 255, 0], # green # Pores + 4: [0, 0, 255], # blue # Crpr + 5: [255, 255, 0], # yellow # FeAl + 6: [128, 128, 128], # gray # Fe2Al + 7: [255, 0, 0], # red # AIN + 8: [0, 0, 0] # black # Strip + # 9: [0, 255, 255] # cyan # New + } + + if from_fraunhofer: + input_path = "/mnt/proj3/open-28-64/set0013/sem_image_segmentation/data/base_data/masks_real/" + # output_path_assigned_colors = "/mnt/proj3/open-28-64/set0013/sem_image_segmentation/data/base_data/masks_real_quantized/" + output_path_assigned_numbers = "/mnt/proj3/open-28-64/set0013/sem_image_segmentation/data/base_data/masks_real_gray/" + + color_dict = { + 0: [255, 121, 193], # pink # Bg1 + 1: [255, 0, 0], # red # Bg2 + 2: [79, 255, 130], # green # Fe2Al5 + 3: [198, 118, 255], # purple # Pores + 4: [255, 255, 10], # yellow # Crpr + 5: [0, 0, 0], # # FeAl + 6: [84, 226, 255], # blue # Fe2Al + 7: [0, 0, 0], # # AIN + 8: [0, 0, 0] # # Strip + } + + # Create output directories if they don't exist + # os.makedirs(output_path_assigned_colors, exist_ok=True) + os.makedirs(output_path_assigned_numbers, exist_ok=True) + + input_files = [filename for filename in os.listdir(input_path)] + for filename in input_files: + print(filename) + input_image_path = os.path.join(input_path, filename) + assigned_numbers_image, assigned_colors_image = quantize_colors(input_image_path, num_colors, color_dict) + + # output_assigned_colors_path = os.path.join(output_path_assigned_colors, filename) + output_assigned_numbers_path = os.path.join(output_path_assigned_numbers, filename) + # cv2.imwrite(output_assigned_colors_path, cv2.cvtColor(assigned_colors_image, cv2.COLOR_RGB2BGR)) + cv2.imwrite(output_assigned_numbers_path, assigned_numbers_image) # single-channel label image; no color conversion needed + + print("All images processed and saved.") \ No newline at end of file diff --git a/scripts/datasets.py b/scripts/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2a231f4aa8535b890dd885648487b0564b617f --- /dev/null +++ b/scripts/datasets.py @@ -0,0 +1,100 @@ +# -------------------------------------------------- Import Modules +import os +import numpy as np +from PIL import Image +import random +import tifffile # used below to read the SEM pixel size from TIFF metadata +import matplotlib.pyplot as plt + +# -------------------------------------------------- Functions + +def load_data(images_path, masks_path, give_pixel_size = True, preprocess = True): + # NOTE: the preprocess flag is currently unused + images = [] + masks = [] + pixel_sizes = [] + + for filename in os.listdir(images_path): + # load images + image_path = os.path.join(images_path, filename) + img = np.array(Image.open(image_path).convert("L")) + + # Overwrite the bottom 60 rows (the SEM info banner) with the last image row above it + replacement_img_data = img[-61, :] + img[-60:] = replacement_img_data + + # load masks (mask files share the image file names) + mask_path = os.path.join(masks_path, filename) + msk = np.array(Image.open(mask_path).convert("L"), dtype=np.float32) # / num_classes 6 # dtype=np.uint8 + + img = img / 255.0 + img = np.expand_dims(img, axis=-1) + images.append(img) + + msk = np.expand_dims(msk, axis=-1) + masks.append(msk) + + if give_pixel_size: + # Read the pixel size from the Zeiss 'CZ_SEM' TIFF tag + with tifffile.TiffFile(image_path) as tif: + tif_tags = {} + for tag in tif.pages[0].tags.values(): + name, value = tag.name, tag.value + tif_tags[name] = value + # image = tif.pages[0].asarray() + pixel_size = tif_tags['CZ_SEM']['ap_image_pixel_size'][1] + + pixel_sizes.append(pixel_size) + + return images, masks, pixel_sizes
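+# Example (illustrative paths): load images and masks without reading pixel sizes: +# images, masks, _ = load_data("data/real_data/images", "data/real_data/masks", +# give_pixel_size=False, preprocess=False)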
+ + +def split_dataset(images, masks, test_size=0.1, val_size=0.1, random_state=42): + # Combine images and masks into pairs + data = list(zip(images, masks)) + + random.seed(random_state) + random.shuffle(data) + + total_samples = len(data) + test_samples = int(total_samples * test_size) + val_samples = int(total_samples * val_size) + train_samples = total_samples - test_samples - val_samples + + images_train = np.array([item[0] for item in data[:train_samples]]) + images_val = np.array([item[0] for item in data[train_samples:train_samples + val_samples]]) + images_test = np.array([item[0] for item in data[train_samples + val_samples:]]) + + masks_train = np.array([item[1] for item in data[:train_samples]]) + masks_val = np.array([item[1] for item in data[train_samples:train_samples + val_samples]]) + masks_test = np.array([item[1] for item in data[train_samples + val_samples:]]) + + return images_train, images_val, images_test, masks_train, masks_val, masks_test + + +def plot_train_mask(images_train, masks_train): + # Choose the number of images you want to visualize + num_images_to_visualize = 5 + + # Choose the start index for visualization + start_index = 0 + + for i in range(start_index, start_index + num_images_to_visualize): + # Get the i-th image and mask + image = images_train[i] + mask = masks_train[i] + + # Create a figure with two subplots: image and mask + fig, axes = plt.subplots(1, 2, figsize=(10, 5)) + + # Plot the image (drop the trailing channel axis; imshow cannot handle an (H, W, 1) array) + axes[0].imshow(image.squeeze(), cmap='gray') + axes[0].set_title('Image') + axes[0].axis('off') + + # Plot the mask + axes[1].imshow(mask.squeeze(), cmap='gray') # Assuming mask is grayscale + axes[1].set_title('Mask') + axes[1].axis('off') + + # Adjust layout to prevent overlap + plt.tight_layout() + + # Save the figure + plt.savefig(f'image_mask_comparison_{i}.png') diff --git a/scripts/model.py b/scripts/model.py new file mode 100644 index 0000000000000000000000000000000000000000..34c20ab209570edc03ba40177f17ad1c128e00e3 --- /dev/null +++ b/scripts/model.py @@ -0,0 +1,105 @@ +# -------------------------------------------------- Import Modules +import tensorflow as tf +from tensorflow.keras.models import * +from tensorflow.keras.layers import * +from tensorflow.keras.optimizers import * +from tensorflow.keras import backend as K +from tensorflow.keras.layers import UpSampling2D, concatenate, Conv1D, Conv2D, MaxPooling1D, MaxPooling2D, Dense, Flatten, BatchNormalization, Dropout # import from tensorflow.keras, not the standalone keras package, to avoid mixing the two + +tf.random.set_seed(1234) + +# ---------------------------------------------------------------------------- Utility Functions +def dice_coef(y_true, y_pred, smooth=1): + # Dice coefficient per sample, averaged over the batch + intersection = K.sum(y_true * y_pred, axis=[1,2,3]) + union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3]) + return K.mean( (2.
* intersection + smooth) / (union + smooth), axis=0) + + +def dice_coef_loss(y_true, y_pred): + return -dice_coef(y_true, y_pred) + + +def get_crop_shape(target, refer): + """ + To get the width and height which is the difference in shape of conv4, up_conv5 + """ + # width, the 3rd dimension + print(target.shape) + print(refer.shape) + cw = (target.shape[2] - refer.shape[2]) + assert (cw >= 0) + if cw % 2 != 0: + cw1, cw2 = int(cw/2), int(cw/2) + 1 + else: + cw1, cw2 = int(cw/2), int(cw/2) + # height, the 2nd dimension + ch = (target.shape[1] - refer.shape[1]) + assert (ch >= 0) + if ch % 2 != 0: + ch1, ch2 = int(ch/2), int(ch/2) + 1 + else: + ch1, ch2 = int(ch/2), int(ch/2) + + return (ch1, ch2), (cw1, cw2) + + +# ---------------------------------------------------------------------------- Model architecture 1 +def unet(input_size, n_class): + inputs = tf.keras.Input(shape=input_size) + conv1 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs) + conv1 = BatchNormalization()(conv1) + conv1 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1) + conv1 = BatchNormalization()(conv1) + pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) + conv2 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1) + conv2 = BatchNormalization()(conv2) + conv2 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2) + conv2 = BatchNormalization()(conv2) + pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) + conv3 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2) + conv3 = BatchNormalization()(conv3) + conv3 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3) + conv3 = BatchNormalization()(conv3) + pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) + + conv4 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3) + conv4 = BatchNormalization()(conv4) + conv4 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4) + conv4 = BatchNormalization()(conv4) + drop4 = Dropout(0.5)(conv4) + + + conv5 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(drop4) + conv5 = BatchNormalization()(conv5) + conv5 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5) + conv5 = BatchNormalization()(conv5) + drop5 = Dropout(0.5)(conv5) + + + up6 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5)) + # up6 = tf.image.resize(up6, conv3.shape) + merge6 = concatenate([conv3,up6], axis = 3) + conv6 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6) + conv6 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6) + + + up7 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6)) + # up7 = tf.image.resize(up7, conv2.shape) + merge7 = concatenate([conv2,up7], axis = 3) + conv7 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7) + conv7 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7) + + + up8 = Conv2D(32, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7)) + # up8 = 
tf.image.resize(up8, conv1.shape) + merge8 = concatenate([conv1,up8], axis = 3) + conv8 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8) + conv8 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8) + conv9 = Conv2D(n_class, 1, activation = 'softmax')(conv8) + + model = tf.keras.Model(inputs = inputs, outputs = conv9) + + return model + + +# ---------------------------------------------------------------------------- Model architecture 2 diff --git a/scripts/test.py b/scripts/test.py new file mode 100644 index 0000000000000000000000000000000000000000..17d483435cee32828b8c238007a152b3efa8303c --- /dev/null +++ b/scripts/test.py @@ -0,0 +1,462 @@ +# -------------------------------------------------- Import Modules +import os +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import time +import cv2 +import argparse + +import tensorflow as tf + +AUTOTUNE = tf.data.experimental.AUTOTUNE +print(f"Tensorflow ver. {tf.__version__}") + +from datasets import load_data, split_dataset +from model import * +from utils import * + +start_time = time.time() + +# -------------------------------------------------------------------- +parser = argparse.ArgumentParser() +parser.add_argument('--images_path', type=str, help='') +parser.add_argument('--masks_path', type=str, help='') +parser.add_argument('--experiment_name', type=str, help='') +args = parser.parse_args() + +# ---------------------------------------------------------------------------- Configs +current_file_directory = os.path.dirname(os.path.abspath(__file__)) +root_dir = os.path.dirname(current_file_directory) + +experiment_name = args.experiment_name +images_folder_name = args.images_path +masks_folder_name = args.masks_path +file_weights = f'weights/model_weights_{experiment_name}.h5' +file_best_weights = f'weights/best_model_weights_{experiment_name}.h5' + +data_dir = os.path.join(root_dir, "data") +weights_file = os.path.join(root_dir, file_weights) +best_weights_file = os.path.join(root_dir, file_best_weights) + +num_classes = int(experiment_name.split('_')[-1].split('.')[0]) +batch_size = 1 +num_of_epochs = 100 +initial_learning_rate = 0.001 # 0.00001 +lr_decay = 0.001 +image_height = 768 +image_width = 1024 +num_channels = 1 +img_size = (image_height, image_width, num_channels) +w_decay = 1e-5 +dropout_rate = 0.2 + +var_preprocess = False +var_give_pixel_size = False + +pred_dir = os.path.join(root_dir, "results/predictions", experiment_name) +recreate_folder(pred_dir) + +are_masks_available = True +if masks_folder_name=="": + are_masks_available = False + masks_folder_name = images_folder_name +# ---------------------------------------------------------------------------- Load Data +images, masks, pixel_sizes = load_data(os.path.join(data_dir, images_folder_name), os.path.join(data_dir, masks_folder_name), var_give_pixel_size, var_preprocess) +var_data_random_state = 42 +var_data_test_size = 1 # 0, 0.15 +var_data_val_size = 0 # 0.15 +images_train, images_val, images_test, masks_train, masks_val, masks_test = split_dataset(images, masks, test_size=var_data_test_size, val_size=var_data_val_size, random_state=var_data_random_state) + +# # Duplicate the grayscale channel to create pseudo-RGB images +# images_train, images_val, images_test = grayscale_to_rgb(images_train, images_val, images_test) + +# ---------------------------------------------------------------------------- Model and Training +model = unet(img_size, num_classes) +# 
model = DeeplabV3Plus_ResNet50(image_height, image_width, 3, num_classes) + +# Load the trained model weights +model.load_weights(weights_file) + +# ---------------------------------------------------------------------------- Testing +num_test_samples = len(images_test) +predictions = np.zeros((num_test_samples, *img_size[:-1], num_classes)) # Initialize the predictions array +for i in range(0, num_test_samples, batch_size): + batch_images = images_test[i:i + batch_size] + batch_preds = model.predict(batch_images) + predictions[i:i + batch_size] = batch_preds + +predictions = np.array(tf.math.argmax(predictions, axis=-1)[..., None]) + +# ----------------------------------------------------------- Colors and Index +color_dict = { + 0: [255, 121, 193], # pink # Bg1 + 1: [255, 0, 0], # red # Bg2 + 2: [79, 255, 130], # green # Fe2Al5 + 3: [198, 118, 255], # purple # Pores + 4: [255, 255, 10], # yellow # Crpr + # 5: [0, 0, 0], # # FeAl + 5: [84, 226, 255], # blue # Fe2Al + # 6: [84, 226, 255], # blue # Fe2Al + # 7: [0, 0, 0], # # AIN + # 8: [0, 0, 0] # # Strip +} + +Bg1_idx = 0 +Bg2_idx = 1 +Fe2Al5_idx = 2 +Pores_idx = 3 +Crpr_idx = 4 +Fe2Al_idx = 5 + +# ----------------------------------------------------------- Saving predictions for visual comparison +if not are_masks_available: + for image_idx, (pred_mask, original_image) in enumerate(zip(predictions, images_test)): + prediction_mask_colored = np.zeros((*pred_mask.shape[:-1], 3), dtype=np.uint8) + for class_idx, color_hex in color_dict.items(): + color_rgb = color_hex + class_mask = (pred_mask[..., 0] == class_idx) + prediction_mask_colored[class_mask] = np.array(color_rgb, dtype=np.uint8) + + original_image = original_image.squeeze() + original_image = (original_image * 255).astype(np.uint8) + original_image = Image.fromarray(original_image) + prediction_mask_colored = Image.fromarray(prediction_mask_colored.astype('uint8')) + # Saving images + original_image.save(f"{pred_dir}/exp1_{image_idx}_original_image.png") + prediction_mask_colored.save(f"{pred_dir}/exp1_{image_idx}_prediction_image.png") + +# ----------------------------------------------------------- Calculating scores +if are_masks_available: + # Initialize lists to store accuracy and dice scores for each class and each image + image_class_accuracy_array = [[0.0] * num_classes for _ in range(len(predictions))] + image_class_dice_scr_array = [[0.0] * num_classes for _ in range(len(predictions))] + + for i, (pred_masks, original_mask) in enumerate(zip(predictions, masks_test)): + # class_idx_dilate = 1 + # kernel = np.ones((3,3), np.uint8) + # class_mask = (pred_masks == class_idx_dilate) + # dilated_class_mask = cv2.dilate(class_mask.astype(np.uint8), kernel, iterations=1) + # dilated_class_mask = np.expand_dims(dilated_class_mask, axis=-1) + # pred_masks = np.where(dilated_class_mask, class_idx_dilate, pred_masks) + + # mask_to_replace = (original_mask == 1) + # original_mask[mask_to_replace] = 3 + + # Flatten the prediction and original masks to 1D arrays + pred_flat = pred_masks.reshape(-1) + original_flat = original_mask.reshape(-1) + + # Print accuracy and dice scores for each class + for class_idx in range(num_classes): + class_original = (original_flat == class_idx) + class_pred = (pred_flat == class_idx) + image_class_accuracy_array[i][class_idx] = calculate_accuracy(class_original, class_pred) + image_class_dice_scr_array[i][class_idx] = calculate_dice_score(class_original, class_pred) + print(f"Image {i} class {class_idx} accuracy: 
{image_class_accuracy_array[i][class_idx]:.4f}") + print(f"Image {i} class {class_idx} dice score: {image_class_dice_scr_array[i][class_idx]:.4f}") + + # print(class_idx) + # TP, TN, FP, FN = confusion_matrix(class_original, class_pred) + # print(TP, TN, FP, FN) + # print((TP + TN) / (TP + TN + FP + FN)) + # print((2 * TP) / (2 * TP + FP + FN)) + + average_class_accuracy_scores = [0.0] * num_classes + average_class_dice_scores = [0.0] * num_classes + std_deviation_accuracy = [] + std_deviation_dice = [] + num_images = len(predictions) + for class_index in range(num_classes): + class_acc_total = sum(row[class_index] for row in image_class_accuracy_array) + average_class_accuracy_scores[class_index] = class_acc_total / num_images + + class_dice_total = sum(row[class_index] for row in image_class_dice_scr_array) + average_class_dice_scores[class_index] = class_dice_total / num_images + + # Calculate standard deviation for accuracy + accuracy_values = [row[class_index] for row in image_class_accuracy_array] + std_dev_accuracy = np.std(accuracy_values) + std_deviation_accuracy.append(std_dev_accuracy) + + # Calculate standard deviation for dice score + dice_values = [row[class_index] for row in image_class_dice_scr_array] + std_dev_dice = np.std(dice_values) + std_deviation_dice.append(std_dev_dice) + + for class_index, average in enumerate(average_class_accuracy_scores): + print(f"Class {class_index} average accuracy score: {average:.4f}") + for class_index, average in enumerate(average_class_dice_scores): + print(f"Class {class_index} average dice score: {average:.4f}") + for class_index, stdd in enumerate(std_deviation_accuracy): + print(f"Class {class_index} std_deviation_accuracy: {stdd:.4f}") + for class_index, stdd in enumerate(std_deviation_dice): + print(f"Class {class_index} std_deviation_dice: {stdd:.4f}") + +# ----------------------------------------------------------- Measurements Thickness +def remove_outliers(data, threshold=1.5, q1_value=45, q2_value=55): + # Calculate the first quartile (Q1) and third quartile (Q3) + Q1 = np.percentile(data, q1_value) + Q3 = np.percentile(data, q2_value) + + # Calculate the interquartile range (IQR) + IQR = Q3 - Q1 + + # Calculate lower and upper bounds for outliers + lower_bound = Q1 - (threshold * IQR) + upper_bound = Q3 + (threshold * IQR) + + # Create a mask to filter out outliers + outlier_mask = (data < lower_bound) | (data > upper_bound) + + # Apply the mask to the data to remove outliers + filtered_data = data[~outlier_mask] + + return filtered_data + +average_row_list = [] +mean_first_pixel_list = [] +mean_last_pixel_list = [] +dist1_um_list = [] +dist2_um_list = [] +pores_concentration_list = [] +crpr_concentration_list = [] +marked_img_list = [] +result_img_list = [] +prediction_mask_colored_img_list = [] +original_mask_colored_img_list = [] + +if are_masks_available: + # Loop through each image's predicted mask + for image_idx, (pred_mask, original_image, original_mask) in enumerate(zip(predictions, images_test, masks_test)): + # pixel_size = pixel_sizes[image_idx] + pixel_size = 223.3 + pixels_per_um = 1000/pixel_size + + print(f"Image {image_idx}:") + + # class_idx_dilate = 1 + # # kernel = np.ones((1,1), np.uint8) + # kernel = np.ones((2, 2), np.uint8) + # # kernel = np.ones((3,3), np.uint8) + # class_mask = (pred_mask == class_idx_dilate) + # dilated_class_mask = cv2.dilate(class_mask.astype(np.uint8), kernel, iterations=1) + # dilated_class_mask = np.expand_dims(dilated_class_mask, axis=-1) + # pred_mask = 
np.where(dilated_class_mask, class_idx_dilate, pred_mask) + + # save to see change + prediction_mask_colored = np.zeros((*pred_mask.shape[:-1], 3), dtype=np.uint8) + original_mask_colored = np.zeros((*original_mask.shape[:-1], 3), dtype=np.uint8) + + # Assign colors to each class based on the color_dict + for class_idx, color_hex in color_dict.items(): + color_rgb = color_hex + + class_mask = (pred_mask[..., 0] == class_idx) + prediction_mask_colored[class_mask] = np.array(color_rgb, dtype=np.uint8) + + # Assign the same color to the original mask + original_mask_class = (original_mask[..., 0] == class_idx) + original_mask_colored[original_mask_class] = np.array(color_rgb, dtype=np.uint8) + + class_mask = (pred_mask == Fe2Al5_idx).astype(np.uint8) + # Create an RGB image with the same dimensions as the mask + marked_image = np.zeros((class_mask.shape[0], class_mask.shape[1], 3), dtype=np.uint8) + + # Create an empty canvas for filling + filled_image = np.zeros_like(pred_mask, dtype=np.uint8) + + # Find the first and last pixel vertically + first_pixel = np.argmax(class_mask, axis=0) + last_pixel = class_mask.shape[0] - np.argmax(class_mask[::-1], axis=0) - 1 + + # Update min and max values + mean_first_pixel = int(np.mean(remove_outliers(first_pixel))) + mean_last_pixel = int(np.mean(remove_outliers(last_pixel))) + + ################################ + class_mask = (pred_mask == Pores_idx).astype(np.uint8) + + # Define a kernel for morphological operations + kernel = np.ones((3, 3), np.uint8) + + # Apply binary erosion to remove small noise + class_mask = cv2.erode(class_mask, kernel, iterations=1) + + # Apply binary dilation to expand the main area + class_mask = cv2.dilate(class_mask, kernel, iterations=1) + + # Find the first and last pixel vertically for each column + first_pixel = np.argmax(class_mask, axis=0) + last_pixel = class_mask.shape[0] - np.argmax(class_mask[::-1], axis=0) - 1 + + temp_marked_image = np.zeros((class_mask.shape[0], class_mask.shape[1], 3), dtype=np.uint8) + boundary_color = [255, 0, 0] + # Mark the boundary edge of class 3 in the marked image + for col_idx in range(class_mask.shape[1]): + if first_pixel[col_idx] > mean_first_pixel: + temp_marked_image[first_pixel[col_idx], col_idx] = boundary_color + if last_pixel[col_idx] < mean_last_pixel: + temp_marked_image[last_pixel[col_idx], col_idx] = boundary_color + + # Calculate the average row index where the line should pass through + red_channel = temp_marked_image[:, :, 0] + red_mask = red_channel > 250 + nonzero_rows = np.nonzero(red_mask)[0] + if len(nonzero_rows) > 0: + average_row = int(np.mean(nonzero_rows)) + else: + average_row = 0 + + # Add a horizontal line through the average of class 3 + line_color = [0, 0, 255] + marked_image[average_row, :, :] = line_color + marked_image[mean_first_pixel, :, :] = line_color + marked_image[mean_last_pixel, :, :] = line_color + + org_image = original_image.copy() + # Convert the original image to a standard grayscale format + org_image = (org_image * 255).astype(np.uint8) + org_image = np.squeeze(org_image) + org_image = np.stack([org_image] * 3, axis=-1) + + # Apply the mask to the main image to show the marked regions + result = cv2.addWeighted(org_image, 1, marked_image, 0.5, 0) + + # replace pixels of main_image with marked image wherever marked image is non zero + mask = np.any(marked_image != 0, axis=2) + result = np.copy(org_image) + result[mask] = marked_image[mask] + + distance_pore_line_to_Fe2Al5_top_surface = average_row - mean_first_pixel + 
distance_pore_line_to_Fe2Al5_bottom_surface = mean_last_pixel - average_row + + dist1_um = np.round(float(distance_pore_line_to_Fe2Al5_top_surface/pixels_per_um),2) + dist2_um = np.round(float(distance_pore_line_to_Fe2Al5_bottom_surface/pixels_per_um),2) + + # Concentration in % + buffer = 50 + temp_pred_mask = pred_mask.copy() + temp_pred_mask[0:(mean_first_pixel-buffer), :, :] = 0 + temp_pred_mask[(mean_last_pixel+buffer):temp_pred_mask.shape[0], :, :] = 0 + + total_area = np.size(pred_mask[mean_first_pixel:mean_last_pixel, :, :]) + + pred_flat = temp_pred_mask.reshape(-1) # Flatten the buffered copy so only pixels near the Fe2Al5 layer are counted + + pores_concentration = np.round((np.count_nonzero(pred_flat == Pores_idx) / total_area) * 100, 2) + crpr_concentration = np.round((np.count_nonzero(pred_flat == Crpr_idx) / total_area) * 100, 2) + + # ---------------------- + average_row_list.append(average_row) + mean_first_pixel_list.append(mean_first_pixel) + mean_last_pixel_list.append(mean_last_pixel) + + dist1_um_list.append(dist1_um) + dist2_um_list.append(dist2_um) + + pores_concentration_list.append(pores_concentration) + crpr_concentration_list.append(crpr_concentration) + + marked_img_list.append(marked_image) + result_img_list.append(result) + prediction_mask_colored_img_list.append(prediction_mask_colored) + original_mask_colored_img_list.append(original_mask_colored) + + # ---------------------- + + + # ---------------------- drawing and saving results + for image_idx in range(0, len(predictions)): + print(image_idx) + + # ---------------------- + average_row = average_row_list[image_idx] + mean_first_pixel = mean_first_pixel_list[image_idx] + mean_last_pixel = mean_last_pixel_list[image_idx] + distance_pore_line_to_Fe2Al5_top_surface = average_row - mean_first_pixel + distance_pore_line_to_Fe2Al5_bottom_surface = mean_last_pixel - average_row + + marked_image = marked_img_list[image_idx] + result = result_img_list[image_idx] + prediction_mask_colored = prediction_mask_colored_img_list[image_idx] + original_mask_colored = original_mask_colored_img_list[image_idx] + + # ---------------------- + dist1_um_mean = np.round(np.mean(dist1_um_list), 2) + dist1_um_std_dev = np.round(np.std(dist1_um_list), 2) + + dist2_um_mean = np.round(np.mean(dist2_um_list), 2) + dist2_um_std_dev = np.round(np.std(dist2_um_list), 2) + + pores_concentration_mean = np.round(np.mean(pores_concentration_list), 2) + pores_concentration_std_dev = np.round(np.std(pores_concentration_list), 2) + + crpr_concentration_mean = np.round(np.mean(crpr_concentration_list), 2) + crpr_concentration_std_dev = np.round(np.std(crpr_concentration_list), 2) + + # ---------------------- + # drawing on result + font_color = (255, 255, 0) + font_size = 30 + # font = ImageFont.load_default() + font = ImageFont.truetype("scripts/FreeMono.ttf", font_size) + arrow_color = (0, 0, 0) + arrow_thickness = 3 + textpos_x = int(result.shape[1]/2) + + arrow_start = (textpos_x-font_size, mean_first_pixel) + arrow_end = (textpos_x-font_size, average_row) + result = cv2.arrowedLine(result, arrow_start, arrow_end, font_color, 2) + result = cv2.arrowedLine(result, arrow_end, arrow_start, font_color, 2) + + arrow_start = (textpos_x-font_size, average_row) + arrow_end = (textpos_x-font_size, mean_last_pixel) + result = cv2.arrowedLine(result, arrow_start, arrow_end, font_color, 2) + result = cv2.arrowedLine(result, arrow_end, arrow_start, font_color, 2) + + # ---------------------- + marked_image = Image.fromarray(marked_image, mode='RGB') + result = Image.fromarray(result, mode='RGB') +
prediction_mask_colored_img = Image.fromarray(prediction_mask_colored, mode='RGB') + original_mask_colored_img = Image.fromarray(original_mask_colored, mode='RGB') + + # ---------------------- + plus_minus = "±" + dist1_str = str(dist1_um_list[image_idx]) + "um " + plus_minus + " " + str(dist1_um_std_dev) + "um" + dist2_str = str(dist2_um_list[image_idx]) + "um " + plus_minus + " " + str(dist2_um_std_dev) + "um" + + textpos_y = int(mean_first_pixel + distance_pore_line_to_Fe2Al5_top_surface/2 - font_size/2) + draw = ImageDraw.Draw(result) + draw.text((textpos_x, textpos_y), str(dist1_str), fill=font_color, font=font) + + textpos_y = int(mean_last_pixel - distance_pore_line_to_Fe2Al5_top_surface/2 - font_size/2) + draw = ImageDraw.Draw(result) + draw.text((textpos_x, textpos_y), str(dist2_str), fill=font_color, font=font) + + # ---------------------- + pores_str = "Pores in the Fe2Al5 layer: " + str(pores_concentration_list[image_idx]) + "% " + plus_minus + " " + str(pores_concentration_std_dev) + "%" + crpr_str = "Crpr in the Fe2Al5 layer: " + str(crpr_concentration_list[image_idx]) + "% " + plus_minus + " " + str(crpr_concentration_std_dev) + "%" + + font_size = 10 + font = ImageFont.truetype("scripts/FreeMono.ttf", font_size) + textpos_x = 20 + + textpos_y = int(mean_first_pixel + font_size*2) + draw = ImageDraw.Draw(result) + for i in range(0, 5): + draw.text((textpos_x, textpos_y), str(pores_str), fill=font_color, font=font) + + textpos_y = int(textpos_y + font_size*2) + draw = ImageDraw.Draw(result) + for i in range(0, 5): + draw.text((textpos_x, textpos_y), str(crpr_str), fill=font_color, font=font) + + # Saving images + result.save(f"{pred_dir}/exp1_{image_idx}_original_marked_image.png") + prediction_mask_colored_img.save(f"{pred_dir}/exp1_{image_idx}_prediction_replace.png") + original_mask_colored_img.save(f"{pred_dir}/exp1_{image_idx}_original_mask.png") + # ---------------------- + +# ------------------------------------------------- End +print("--- %s seconds ---" % (time.time() - start_time)) +# ------------------------------------------------------------------------------------------ End diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e6541816335d4dbe6dff0bc95d696d94c8b99f3d --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,183 @@ + +# -------------------------------------------------- Import Modules +import os +import numpy as np +import time +import argparse + +import tensorflow as tf +from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint +from tensorflow.keras.preprocessing.image import ImageDataGenerator + +from sklearn.metrics import accuracy_score, f1_score, jaccard_score + +from datasets import * +from model import * +from utils import * + +AUTOTUNE = tf.data.experimental.AUTOTUNE +print(f"Tensorflow ver. 
{tf.__version__}") +# from tensorflow.python.client import device_lib +# print(device_lib.list_local_devices()) + +start_time = time.time() +# -------------------------------------------------------------------- +parser = argparse.ArgumentParser() +parser.add_argument('--synth_images_path', type=str, help='') +parser.add_argument('--synth_masks_path', type=str, help='') +parser.add_argument('--real_images_path', type=str, help='') +parser.add_argument('--real_masks_path', type=str, help='') +parser.add_argument('--experiment_name', type=str, help='') +parser.add_argument('--load_synth_data', type=bool, help='') +parser.add_argument('--load_real_data', type=bool, help='') +args = parser.parse_args() + +# ---------------------------------------------------------------------------- Configs +current_file_directory = os.path.dirname(os.path.abspath(__file__)) +root_dir = os.path.dirname(current_file_directory) + +data_dir = os.path.join(root_dir, "data") + +experiment_name = args.experiment_name +images_folder_dir = args.synth_images_path +masks_folder_dir = args.synth_masks_path +real_images_dir = args.real_images_path +real_masks_dir = args.real_masks_path + +weights_file = os.path.join(root_dir, f'weights/model_weights_{experiment_name}.h5') +best_weights_file = os.path.join(root_dir, f'weights/best_model_weights_{experiment_name}.h5') + +aug_images_dir = os.path.join(data_dir, f'augmented_data/augmented_data_{experiment_name}/images/') +aug_masks_dir = os.path.join(data_dir, f'augmented_data/augmented_data_{experiment_name}/masks/') + +file_plot_data_aug = f'results/images/data_augmentation_{experiment_name}.png' +file_plot_training = f'results/images/training_plot_{experiment_name}.png' + +# ------------------------- +# adjust preprocess in datasets_tf.py +# True, False +load_synth_data = args.load_synth_data +load_real_data = args.load_real_data + +var_synt_preprocess = False +var_real_preprocess = False +var_give_pixel_size = False + +perform_data_aug = True +perform_training = True + +var_data_random_state = 42 +var_synt_data_test_size = 0.0 # 0, 0.15 +var_synt_data_val_size = 0.15 # 0.15 +var_real_data_test_size = 0.0 # +var_real_data_val_size = 0.15 # + +# ------------------------- +num_classes = int(experiment_name.split('_')[-1].split('.')[0]) +num_channels = 1 # 1, 3 +image_height = 768 # 715, 768 +image_width = 1024 +img_size = (image_height, image_width, num_channels) + +num_of_epochs = 100 +batch_size = 2 +initial_learning_rate = 0.001 # 0.00001 +lr_decay = 0.001 +w_decay = 1e-5 +dropout_rate=0.2 + +# ---------------------------------------------------------------------------- Load Data +if load_synth_data: + images, masks, _ = load_data(images_folder_dir, masks_folder_dir, var_give_pixel_size, var_synt_preprocess) + images_train, images_val, images_test, masks_train, masks_val, masks_test = split_dataset(images, masks, test_size=var_synt_data_test_size, val_size=var_synt_data_val_size, random_state=var_data_random_state) + +if load_real_data: + real_images, real_masks, _ = load_data(real_images_dir, real_masks_dir, var_give_pixel_size, var_real_preprocess) + real_images_train, real_images_val, real_images_test, real_masks_train, real_masks_val, real_masks_test = split_dataset(real_images, real_masks, test_size=var_real_data_test_size, val_size=var_real_data_val_size, random_state=var_data_random_state) + + if load_synth_data: + images_train = np.concatenate((images_train, np.array(real_images_train)), axis=0) + images_val = np.concatenate((images_val, 
np.array(real_images_val)), axis=0) + images_test = np.concatenate((images_test, np.array(real_images_test)), axis=0) + masks_train = np.concatenate((masks_train, np.array(real_masks_train)), axis=0) + masks_val = np.concatenate((masks_val, np.array(real_masks_val)), axis=0) + masks_test = np.concatenate((masks_test, np.array(real_masks_test)), axis=0) + else: + images_train = real_images_train + images_val = real_images_val + images_test = real_images_test + masks_train = real_masks_train + masks_val = real_masks_val + masks_test = real_masks_test + +# ---------------------------------------------------------------------------- Data Augmentation and Data Generator +if perform_data_aug: + data_gen_args = dict( + horizontal_flip=True, + fill_mode='reflect' + ) + + image_datagen = ImageDataGenerator(preprocessing_function=adjust_brightness_contrast, **data_gen_args) + mask_datagen = ImageDataGenerator(**data_gen_args) + + seed = 42 + image_data_iterator = image_datagen.flow(images_train, seed=seed, batch_size=batch_size) + mask_data_iterator = mask_datagen.flow(masks_train, seed=seed, batch_size=batch_size) + + plot_data_augmentation(image_data_iterator, mask_data_iterator, file_plot_data_aug, num_samples=5) + + num_augmented_samples = int(len(images_train)/batch_size) # len(images_train) # batch_size * num_augmented_samples + save_data_augmentation(aug_images_dir, aug_masks_dir, num_augmented_samples, batch_size, image_data_iterator, mask_data_iterator) + + aug_images, aug_masks, _ = load_data(aug_images_dir, aug_masks_dir, False, False) + images_train = np.concatenate((images_train, np.array(aug_images)), axis=0) + masks_train = np.concatenate((masks_train, np.array(aug_masks)), axis=0) + +# ---------------------------------------------------------------------------- Final Dataset +print(np.min(images_train[0])) +print(np.max(images_train[0])) +print(np.unique(masks_train[0])) +print(images_train.shape) +# images_train, images_val, images_test = grayscale_to_rgb(images_train, images_val, images_test) + +# ---------------------------------------------------------------------------- Model and Training +if perform_training: + # strategy = tf.distribute.MirroredStrategy() + # with strategy.scope(): + model = unet(img_size, num_classes) # model = DeeplabV3Plus_ResNet50(image_height, image_width, 3, num_classes) + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy() # loss_fn = Custom_CE_Loss() + var_metrics = "accuracy" # var_metrics=[MeanIoU(num_classes)] # tf.keras.metrics.F1Score + optimizer = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate) + model.compile(optimizer=optimizer, loss=loss_fn, metrics=var_metrics) + + reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1, min_lr=1e-7) + early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True) + checkpoint = ModelCheckpoint(best_weights_file, monitor='val_loss', save_best_only=True) + + # train_dataset = tf.data.Dataset.from_tensor_slices((images_train, masks_train)).shuffle(buffer_size=10000).batch(batch_size) + # val_dataset = tf.data.Dataset.from_tensor_slices((images_val, masks_val)).batch(batch_size) + + # # train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset) + # # val_dist_dataset = strategy.experimental_distribute_dataset(val_dataset) + + # # history = model.fit(train_dist_dataset, epochs=num_of_epochs, validation_data=val_dist_dataset, callbacks=[reduce_lr, early_stopping, checkpoint]) + + # history = model.fit(train_dataset, 
epochs=num_of_epochs, validation_data=val_dataset, callbacks=[reduce_lr, early_stopping, checkpoint]) + + history = model.fit( + images_train, + masks_train, + epochs=num_of_epochs, + batch_size=batch_size, + validation_data=(images_val, masks_val), + callbacks=[reduce_lr, early_stopping, checkpoint] + ) + + model.save(filepath=weights_file) # save the final model (architecture + weights) in HDF5 format + print("Training completed.") + + plot_training_history(history, file_plot_training)
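+ +# Example invocation (illustrative paths and experiment name; the flags are defined above): +# python scripts/train.py --synth_images_path data/synth_data/images --synth_masks_path data/synth_data/masks --experiment_name exp_synth_class_unet_6 --load_synth_data true --load_real_data false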
+
+def blackwhite_mask(input_path, output_path, threshold_value, old_label, new_label):
+    """
+    Examples:
+        blackwhite_mask(input_path, output_path, 180, 255, 1)
+        blackwhite_mask(input_path, output_path, 1, 255, 1)
+    """
+    os.makedirs(output_path, exist_ok=True)
+
+    for filename in os.listdir(input_path):
+        input_image_path = os.path.join(input_path, filename)
+        image_mask = cv2.imread(input_image_path)
+
+        # Threshold to a binary image, then map the foreground to the new label
+        _, binary_image = cv2.threshold(image_mask, threshold_value, 255, cv2.THRESH_BINARY)
+        binary_image[binary_image == old_label] = new_label
+        output_image_path = os.path.join(output_path, filename)
+        cv2.imwrite(output_image_path, binary_image)
+        print(np.unique(binary_image))
+    print("All images processed and saved.")
+
+
+def add_mask_to_image(input_path, mask_path, output_path, value):
+    """
+    Example:
+        add_mask_to_image(input_path, mask_path, output_path, 0)
+    """
+    os.makedirs(output_path, exist_ok=True)
+
+    for filename in os.listdir(input_path):
+        input_image_path = os.path.join(input_path, filename)
+        input_mask_path = os.path.join(mask_path, filename)
+
+        image = cv2.imread(input_image_path)
+        mask = cv2.imread(input_mask_path)
+
+        # Overwrite the masked pixels with the given value
+        result_image = image.copy()
+        result_image[mask == 1] = value
+
+        output_image_path = os.path.join(output_path, filename)
+        cv2.imwrite(output_image_path, result_image)
+        print(np.unique(result_image))
+    print("All images processed and saved.")
+
+
+def replace_mask(input_path, output_path, old_label, new_label):
+    """
+    Example:
+        replace_mask(input_path, output_path, 6, 5)
+    """
+    os.makedirs(output_path, exist_ok=True)
+
+    for filename in os.listdir(input_path):
+        input_image_path = os.path.join(input_path, filename)
+        image_mask = cv2.imread(input_image_path)
+        image_mask[image_mask == old_label] = new_label
+        output_image_path = os.path.join(output_path, filename)
+        cv2.imwrite(output_image_path, image_mask)
+        print(np.unique(image_mask))
+    print("All images processed and saved.")
+
+
+def grayscale_to_rgb(images_train, images_val, images_test):
+    # Duplicate the grayscale channel to create pseudo-RGB images
+    images_train = np.repeat(images_train, 3, axis=-1)
+    images_val = np.repeat(images_val, 3, axis=-1)
+    images_test = np.repeat(images_test, 3, axis=-1)
+    return images_train, images_val, images_test
+
+
+def adjust_brightness_contrast(image, brightness_range=(0.5, 1.5), contrast_range=(0.5, 1.5)):  # narrower alternative: (0.8, 1.2)
+    # Adjust brightness by a random factor
+    brightness_factor = np.random.uniform(brightness_range[0], brightness_range[1])
+    image = image * brightness_factor
+
+    # Adjust contrast by stretching the values around the mean
+    contrast_factor = np.random.uniform(contrast_range[0], contrast_range[1])
+    image = (image - np.mean(image)) * contrast_factor + np.mean(image)
+
+    # Clip values to ensure they stay within the range [0, 1]
+    image = np.clip(image, 0, 1)
+    return image
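+
+# Note: adjust_brightness_contrast expects images already scaled to [0, 1]
+# (its output is clipped to that range); scripts/train.py wires it into
+# ImageDataGenerator through the preprocessing_function argument.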
+
+# -----------------------------------------------------------------------------------------------------------------------
+"""
+Data Augmentation
+"""
+
+def plot_data_augmentation(image_data_iterator, mask_data_iterator, file_plot_data_aug, num_samples):
+    plt.figure(figsize=(15, num_samples))
+    for i in range(num_samples):
+        augmented_image = next(image_data_iterator)[0]
+        augmented_mask = next(mask_data_iterator)[0]
+        plt.subplot(2, num_samples, i + 1)
+        plt.imshow(augmented_image, cmap='gray')
+        plt.title('Augmented Image')
+        plt.axis('off')
+        plt.subplot(2, num_samples, num_samples + i + 1)
+        plt.imshow(augmented_mask, cmap='viridis')
+        plt.title('Augmented Mask')
+        plt.axis('off')
+    plt.tight_layout()
+    plt.savefig(file_plot_data_aug)
+    print('Saved augmented plot')
+
+
+def save_data_augmentation(aug_images_dir, aug_masks_dir, num_augmented_samples, batch_size, image_data_iterator, mask_data_iterator):
+    os.makedirs(aug_images_dir, exist_ok=True)
+    os.makedirs(aug_masks_dir, exist_ok=True)
+
+    for i in range(num_augmented_samples):
+        augmented_images = next(image_data_iterator)
+        augmented_masks = next(mask_data_iterator)
+        # Skip a final, smaller batch so image and mask filenames stay paired
+        if augmented_images.shape[0] == batch_size:
+            for j in range(batch_size):
+                image_filename = os.path.join(aug_images_dir, f'augmented_{i * batch_size + j}.png')
+                keras_image.save_img(image_filename, augmented_images[j])
+                mask_filename = os.path.join(aug_masks_dir, f'augmented_{i * batch_size + j}.png')
+                # scale=False keeps the raw label values instead of rescaling to [0, 255]
+                keras_image.save_img(mask_filename, augmented_masks[j], scale=False)
+
+    print(f'Saved {num_augmented_samples * batch_size} augmented image-mask pairs')
+
+# -----------------------------------------------------------------------------------------------------------------------
+"""
+Loss functions
+"""
+
+def focal_loss(gamma=2.0, alpha=0.25):
+    def focal_loss_fixed(y_true, y_pred):
+        epsilon = K.epsilon()
+        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
+        p_t = tf.where(K.equal(y_true, 1), y_pred, 1 - y_pred)
+        alpha_factor = K.ones_like(y_true) * alpha
+        alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor)
+        focal_loss = -K.pow(1 - p_t, gamma) * K.log(p_t)
+        return alpha_t * focal_loss
+
+    return focal_loss_fixed
+
+
+def dice_coef(y_true, y_pred, smooth=100):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    dice = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
+    return dice
+
+
+def dice_coef_loss(y_true, y_pred, smooth=100):
+    return 1 - dice_coef(y_true, y_pred, smooth)
+
+
+def dice_coef_multilabel(y_true, y_pred, M, smooth=100):
+    # Sum of the per-class dice coefficients over the M one-hot channels
+    dice = 0
+    for index in range(M):
+        dice += dice_coef(y_true[:, :, :, index], y_pred[:, :, :, index], smooth)
+    return dice
+
+
+def weighted_dice_loss(class_weights, smooth=1.0):
+    def weighted_dice_loss_fixed(y_true, y_pred):
+        intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
+        union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
+        dice = (2. * intersection + smooth) / (union + smooth)
+
+        # Apply class weights, normalised by the per-class support
+        weights = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
+        weights = class_weights / (weights + K.epsilon())
+        weighted_dice = K.sum(weights * dice)
+
+        return 1.0 - weighted_dice
+
+    return weighted_dice_loss_fixed
+
+
+class MeanIoU(tf.keras.metrics.MeanIoU):
+    # Accepts one-hot / probability maps and converts them to label maps first
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        y_true = tf.argmax(y_true, axis=-1)
+        y_pred = tf.argmax(y_pred, axis=-1)
+        return super(MeanIoU, self).update_state(y_true, y_pred, sample_weight)
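+
+# Illustrative usage (sketch): any of the losses above can stand in for the
+# SparseCategoricalCrossentropy used in scripts/train.py, e.g.
+#   model.compile(optimizer='adam', loss=dice_coef_loss, metrics=[MeanIoU(num_classes)])
+# or, for the focal variant, loss=focal_loss(gamma=2.0, alpha=0.25).
+# Note these implementations expect one-hot (not sparse) targets.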
+
+# @keras_export("keras.losses.DiceLoss")
+# class DiceLoss(LossFunctionWrapper):
+#     def __init__(self, smooth=1.0, reduction=losses_utils.ReductionV2.AUTO, name="dice_loss"):
+#         super().__init__(dice_loss, name=name, reduction=reduction, smooth=smooth)
+
+# def dice_loss(y_true, y_pred, smooth=1.0):
+#     y_true = tf.convert_to_tensor(y_true)
+#     y_pred = tf.convert_to_tensor(y_pred)
+
+#     intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
+#     union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
+
+#     dice = (2. * intersection + smooth) / (union + smooth)
+#     dice_loss = 1.0 - dice
+
+#     return dice_loss
+
+# -----------------------------------------------------------------------------------------------------------------------
+"""
+Training plot
+"""
+
+def plot_training_history(history, file_plot_training):
+    plt.figure(figsize=(12, 12))
+
+    # Plot the training and validation loss
+    plt.subplot(2, 1, 1)
+    plt.plot(history.history['loss'], label='train')
+    plt.plot(history.history['val_loss'], label='val')
+    plt.title('Training and Validation Loss')
+    plt.xlabel('Epoch')
+    plt.ylabel('Loss')
+    plt.legend(loc='upper left')
+
+    # Plot the training and validation accuracy
+    plt.subplot(2, 1, 2)
+    plt.plot(history.history['accuracy'], label='train')
+    plt.plot(history.history['val_accuracy'], label='val')
+    plt.title('Training and Validation Accuracy')
+    plt.xlabel('Epoch')
+    plt.ylabel('Accuracy')
+    plt.legend(loc='upper left')
+
+    plt.tight_layout()
+    plt.savefig(file_plot_training)
+
+# -----------------------------------------------------------------------------------------------------------------------
+"""
+Testing metrics
+"""
+
+def confusion_matrix(actual_labels, predicted_labels):
+    TP = ((actual_labels == 1) & (predicted_labels == 1)).sum()
+    TN = ((actual_labels == 0) & (predicted_labels == 0)).sum()
+    FP = ((actual_labels == 0) & (predicted_labels == 1)).sum()
+    FN = ((actual_labels == 1) & (predicted_labels == 0)).sum()
+
+    # accuracy = (TP + TN) / (TP + TN + FP + FN)
+    # dice_coefficient = (2 * TP) / (2 * TP + FP + FN)
+    return TP, TN, FP, FN
+
+
+def calculate_accuracy(actual_labels, predicted_labels):
+    return accuracy_score(actual_labels, predicted_labels)
+
+
+def calculate_dice_score(class_original, class_pred):
+    # Ensure that the input arrays have the same shape
+    if class_original.shape != class_pred.shape:
+        raise ValueError("Input arrays must have the same shape")
+
+    # Intersection and foreground counts of the two binary masks
+    intersection = np.sum((class_original == 1) & (class_pred == 1))
+    total_original_ones = np.sum(class_original == 1)
+    total_pred_ones = np.sum(class_pred == 1)
+
+    # Dice score; two empty masks count as a perfect match
+    denominator = total_original_ones + total_pred_ones
+    if denominator == 0:
+        return 1.0
+    return (2.0 * intersection) / denominator
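+
+# Illustrative usage (sketch, assuming integer label maps y_true / y_pred):
+#   calculate_dice_score((y_true == c).astype(np.uint8), (y_pred == c).astype(np.uint8))
+#   calculate_accuracy(y_true.flatten(), y_pred.flatten())
+# where c is the class index of interest.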
+
+# -----------------------------------------------------------------------------------------------------------------------
+"""
+Hyperparameter tuning search
+"""
+
+if False:
+    from keras_tuner.tuners import RandomSearch, Hyperband
+    from keras_tuner import HyperParameters
+
+    # NOTE: this block is disabled and assumes a `unet_model` builder exposing
+    # filters/kernel_size/dropout, plus steps_per_epoch defined elsewhere.
+    def build_model(hp):
+        hp_filters = hp.Int('filters', min_value=32, max_value=128, step=32)
+        hp_kernel_size = hp.Int('kernel_size', min_value=3, max_value=5, step=2)
+        hp_dropout = hp.Float('dropout', min_value=0.2, max_value=0.5, step=0.1)
+
+        model = unet_model(input_shape=(768, 1024, 1), num_classes=8, filters=hp_filters,
+                           kernel_size=hp_kernel_size, dropout=hp_dropout)
+        return model
+
+    tuner = RandomSearch(
+        build_model,
+        objective='val_loss',       # Metric to optimize
+        max_trials=10,              # Number of trials to run
+        directory='my_dir',         # Directory to store results
+        project_name='my_project',  # Name of the project
+        overwrite=True
+    )
+
+    # Perform hyperparameter tuning; zip the image and mask iterators so the
+    # tuner receives (image, mask) batches
+    tuner.search(zip(image_data_iterator, mask_data_iterator),
+                 steps_per_epoch=steps_per_epoch,
+                 epochs=num_of_epochs,
+                 validation_data=(images_val, masks_val),
+                 callbacks=[reduce_lr, early_stopping])
+
+    # Retrieve the best hyperparameters
+    best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
+
+    print(f"""
+    The hyperparameter search is complete. The optimal configuration is
+    filters={best_hps.get('filters')}, kernel_size={best_hps.get('kernel_size')}
+    and dropout={best_hps.get('dropout')}.
+    """)
+
+    final_model = tuner.hypermodel.build(best_hps)
+    history = final_model.fit(zip(image_data_iterator, mask_data_iterator),
+                              steps_per_epoch=steps_per_epoch,
+                              epochs=num_of_epochs,
+                              validation_data=(images_val, masks_val),
+                              callbacks=[reduce_lr, early_stopping])
+
+
+# -----------------------------------------------------------------------------------------------------------------------
+"""
+Extra experimental code
+"""
+
+if True:
+    # Blurring variants (median / Gaussian)
+    # kernel_size_2 = 3 # (5, 5), odd number
+    # # kernel_size = (3, 3) # (5, 5), odd number
+    # # sigma = 0 # 2, standard deviation
+    # # img = cv2.GaussianBlur(img, kernel_size, sigma)
+    # img = cv2.medianBlur(img, kernel_size_2)
+
+    # # kernel_size_2 = 1 # (5, 5), odd number
+    # kernel_size = (3, 3) # (5, 5), odd number
+    # sigma = 0.0 # 2, standard deviation
+    # img = cv2.GaussianBlur(img, kernel_size, sigma)
+    # # img = cv2.medianBlur(img, kernel_size_2)
+    # name = "imagess/blurred_image_median/blurred_image_" + str(filename) + ".png"
+    # cv2.imwrite(name, img)
+
+    # Overwrite the bottom strip of the image with the strip above it
+    # replacement_img_data = img[-120:-60]
+    # img[-60:] = replacement_img_data
+
+    # Crop the info bar at the bottom of the micrograph
+    # crop_height = 54 # 53
+    # height, width = img.shape
+    # img = img[0:(height - crop_height), 0:width]
+    # msk = msk[0:(height - crop_height), 0:width]
+
+    # img = img.resize((768, 1024))
+    # msk = np.array(msk.resize((768, 1024), resample=Image.NEAREST))
+
+    # Synthetic noise experiments (Gaussian, uniform, impulse)
+    # gauss_noise = np.zeros((640, 480), dtype=np.uint8)
+    # cv2.randn(gauss_noise, 128, 20)
+    # gauss_noise = (gauss_noise * 0.5).astype(np.uint8)
+
+    # uni_noise = np.zeros(img.shape, dtype=np.float32)
+    # cv2.randu(uni_noise, 0, 255)
+    # uni_noise = (uni_noise * 0.3).astype(np.float32)
+
+    # imp_noise = np.zeros((640, 480), dtype=np.uint8)
+    # cv2.randu(imp_noise, 0, 255)
+    # imp_noise = cv2.threshold(imp_noise, 245, 255, cv2.THRESH_BINARY)[1]
+
+    # img = cv2.add(img, gauss_noise)
+    # img = cv2.add(img, uni_noise)
+    # img = cv2.add(img, imp_noise)
+
+    # name = "images7/noise_image_" + str(filename) + ".png"
+    # cv2.imwrite(name, img)
+
+    # replacement_msk_data = msk[-120:-60]
+    # msk[-60:] = replacement_msk_data
+
+    # Dilate class 4 in the mask by one 3x3 kernel step
+    # kernel = np.ones((3, 3), np.uint8)
+    # class_mask = (msk == 4)
+    # dilated_class_mask = cv2.dilate(class_mask.astype(np.uint8), kernel, iterations=1)
+    # msk = np.where(dilated_class_mask, 4, msk)
+
+    pass
+
+# ----------------------------------------------------------------------------------------------------------------------- End
diff --git a/training.py b/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6567cf083ebc4a44c279e63c5277027370c3b67
--- /dev/null
+++ b/training.py
@@ -0,0 +1,43 @@
+import subprocess
+import time
+import os
+
+# Start measuring script execution time
+start_time = time.time()
+
+# --------------------------------------------------------------------------
+current_file_directory = os.path.dirname(os.path.abspath(__file__))
+data_dir = os.path.join(current_file_directory, "data")
+
+# Pick exactly one weights file; the experiment name is derived from it below
+# weights_file = "model_weights_exp_synth_class_unet_6.h5"
+# weights_file = "model_weights_exp_synth_class_unet_crpr_2.h5"
+# weights_file = "model_weights_exp_synth_class_unet_pores_2.h5"
+# weights_file = "model_weights_exp_synth_real_mix_class_unet_6.h5"
+weights_file = "model_weights_exp_synth_real_mix_class_unet_crpr_2.h5"
+# weights_file = "model_weights_exp_synth_real_mix_class_unet_pores_2.h5"
+
+synth_images_path = os.path.join(data_dir, "synth_data/synth_images")
+# synth_masks_path = os.path.join(data_dir, "synth_data/synth_masks")
+synth_masks_path = os.path.join(data_dir, "synth_data/synth_masks_crpr")
+# synth_masks_path = os.path.join(data_dir, "synth_data/synth_masks_pores")
+real_images_path = os.path.join(data_dir, "real_data/real_images")
+# real_masks_path = os.path.join(data_dir, "real_data/real_masks")
+real_masks_path = os.path.join(data_dir, "real_data/real_masks_crpr")
+# real_masks_path = os.path.join(data_dir, "real_data/real_masks_pores")
+load_synth_data = True
+load_real_data = True
+# --------------------------------------------------------------------------
+experiment_name = weights_file.replace("model_weights_", "").replace(".h5", "")
+
+# Run the training script as a subprocess
+script_path = "scripts/train.py"
+subprocess.run(["python", script_path,
+                "--synth_images_path", synth_images_path,
+                "--synth_masks_path", synth_masks_path,
+                "--real_images_path", real_images_path,
+                "--real_masks_path", real_masks_path,
+                "--experiment_name", experiment_name,
+                "--load_synth_data", str(load_synth_data),
+                "--load_real_data", str(load_real_data)])
+
+# Calculate and print the total time taken
+end_time = time.time()
+execution_time = end_time - start_time
+minutes = int(execution_time / 60)
+seconds = int(execution_time % 60)
+print(f"Total time taken: {minutes} minutes and {seconds} seconds")
+
+# --------------------------------------------------------------------------
diff --git a/weights/model_weights_exp_synth_class_unet_6.h5 b/weights/model_weights_exp_synth_class_unet_6.h5
new file mode 100644
index 0000000000000000000000000000000000000000..032bce69159b8d3ea708bf72135b724ebc32518d
Binary files /dev/null and b/weights/model_weights_exp_synth_class_unet_6.h5 differ
diff --git a/weights/model_weights_exp_synth_class_unet_crpr_2.h5 b/weights/model_weights_exp_synth_class_unet_crpr_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..99f25887ad20e74ca5b7a55f96c77bf80e53a018
Binary files /dev/null and b/weights/model_weights_exp_synth_class_unet_crpr_2.h5 differ
diff --git a/weights/model_weights_exp_synth_class_unet_pores_2.h5 b/weights/model_weights_exp_synth_class_unet_pores_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..47c62501373c0a0c3b5207fa48045441caf6f667
Binary files /dev/null and b/weights/model_weights_exp_synth_class_unet_pores_2.h5 differ
diff --git a/weights/model_weights_exp_synth_real_mix_class_unet_6.h5 b/weights/model_weights_exp_synth_real_mix_class_unet_6.h5
new file mode 100644
index 0000000000000000000000000000000000000000..847858bde2c59bc6f19e8bb5253f4db3a9e17a2e
Binary files /dev/null and b/weights/model_weights_exp_synth_real_mix_class_unet_6.h5 differ
diff --git a/weights/model_weights_exp_synth_real_mix_class_unet_crpr_2.h5 b/weights/model_weights_exp_synth_real_mix_class_unet_crpr_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..e6ccb8c469f6989fa0c6c78ffb518d853290cef7
Binary files /dev/null and b/weights/model_weights_exp_synth_real_mix_class_unet_crpr_2.h5 differ
diff --git a/weights/model_weights_exp_synth_real_mix_class_unet_pores_2.h5 b/weights/model_weights_exp_synth_real_mix_class_unet_pores_2.h5
new file mode 100644
index 0000000000000000000000000000000000000000..60503c521fc6a059daad3fdec1658adfe286f588
Binary files /dev/null and b/weights/model_weights_exp_synth_real_mix_class_unet_pores_2.h5 differ