diff --git a/lars/nepho/inference.py b/lars/nepho/inference.py index 5ed1441..feb918e 100644 --- a/lars/nepho/inference.py +++ b/lars/nepho/inference.py @@ -1,8 +1,13 @@ import asyncio -DEFAULT_CATEGORIES = ["NO PRECIPITATION", "STRATIFORM RAIN", "SNOW", "SCATTERED CONVECTION", - "LINEAR CONVECTION", "SUPERCELLS", "UNKNOWN"] +DEFAULT_CATEGORIES = {"No precipitation": "No echoes greater than 10 dBZ present. A circle of echoes near radar site may be present due to ground clutter.", + "Stratiform rain": "Widespread echoes between 0 and 35 dBZ, not present as a circular pattern around the radar site.", + "Scattered Convection": "Present as isolated to scattered cells with reflectivities between 35-65 dBZ", + "Linear convection": "Cells must be organized into a linear structure with reflectivities between 40-60 dBZ", + "Supercells": "Supercells contain the classic hook echo and bounded weak echo region signatures with reflectivities above 55 dBZ", + "Unknown": "If you cannot confidently classify the radar image into one of the above categories"} -async def label_radar_data(radar_df, model, categories=None): +async def label_radar_data(radar_df, model, categories=None, site="Bankhead National Forest", + verbose=True, vmin=-20, vmax=60): """ Label radar data using a given model. @@ -10,17 +15,43 @@ async def label_radar_data(radar_df, model, categories=None): ---------- radar_df (pd.DataFrame): DataFrame containing radar data to be labeled. model: Model used for labeling the radar data. + site: str: Radar site identifier. Returns ------- pd.DataFrame DataFrame containing the labeled radar data. """ + if categories is None: + categories = DEFAULT_CATEGORIES prompt = "This is an image of weather radar base reflectivity data." \ + f" The radar site is the ARM Facility {site} site." \ " Please classify the weather depicted into one of the following categories: " \ - f"{', '.join(categories) if categories else ', '.join(DEFAULT_CATEGORIES)}." + f"{', '.join(categories) if categories else ', '.join(categories)}." + prompt += "Each category is defined as follows: " + for category, description in categories.items(): + prompt += f"{category}: {description}; " + prompt += f"The reflectivity values range from {vmin} dBZ as indicated by the blue colors to {vmax} dBZ as indicated by the red colors." + radar_df["llm_label"] = "" + for fi in radar_df["file_path"].values: - output = await model.chat(prompt, images=[fi]) - print(output) - radar_df.loc[radar_df["file_path"] == fi, "label"] = output + time = radar_df.loc[radar_df["file_path"] == fi, "time"].values[0] + prompt_with_time = prompt + f"Please provide just the category label for the radar image taken at time {time}." + + + output_model = await model.chat(prompt_with_time, images=[fi]) + # Find the category label in the output + output = output_model.strip() + for category in categories.keys(): + if category.lower() in output.lower(): + output = category + break + if verbose: + print("Category assigned:", output) + print("Model output:", output_model) + if output[-1] == ".": + output = output[:-1] + radar_df.loc[radar_df["file_path"] == fi, "llm_label"] = output.strip() + + return radar_df \ No newline at end of file diff --git a/lars/nepho/models/ollama_model.py b/lars/nepho/models/ollama_model.py index c4328f2..8e1f799 100644 --- a/lars/nepho/models/ollama_model.py +++ b/lars/nepho/models/ollama_model.py @@ -63,14 +63,22 @@ async def chat(self, prompt: str, images: Optional[List[str]] = None) -> str: image_data = self.encode_image(image_path) images_data.append(image_data) - + #if self.model_name == "llama4:scout": + # payload = { + # "model": self.model_name, + # "messages": [ + # {"role": "user", "content": prompt, "images": images_data} + # ], + # "stream": False + #} + #else: payload = { - "model": self.model_name, - "prompt": prompt, - "images": images_data, - "stream": False + "model": self.model_name, + "prompt": prompt, + "images": images_data, + "stream": False } - + # Use generate endpoint for vision models url = self.api_url else: @@ -107,7 +115,7 @@ async def chat(self, prompt: str, images: Optional[List[str]] = None) -> str: def supports_vision(self) -> bool: """Check if this model supports vision capabilities.""" - vision_models = ["llava", "bakllava", "moondream", "minicpm-v", "llava-llama2", "llava-llama3"] + vision_models = ["llava", "bakllava", "moondream", "minicpm-v", "llava-llama2", "llava-llama3", "llama4:scout"] return any(vision_model in self.model_name.lower() for vision_model in vision_models) async def list_available_models(self) -> List[str]: diff --git a/lars/preprocessing/radar_preprocessing.py b/lars/preprocessing/radar_preprocessing.py index 57a1976..72e082a 100644 --- a/lars/preprocessing/radar_preprocessing.py +++ b/lars/preprocessing/radar_preprocessing.py @@ -35,7 +35,7 @@ def preprocess_radar_data(file_path, output_path, """ file_list = glob.glob(file_path + '/*.nc') - out_df = pd.DataFrame(columns=['file_path', 'time', 'label']) + out_df = pd.DataFrame(columns=['file_path', 'time', 'label', 'ref_min', 'ref_max']) if not "vmin" in kwargs: kwargs['vmin'] = -20 if not "vmax" in kwargs: @@ -57,16 +57,21 @@ def preprocess_radar_data(file_path, output_path, if sweep["sweep_mode"] == 'ppi' or sweep["sweep_mode"] == 'sector': fig = plt.figure(figsize=(4, 4)) ax = plt.axes() - sweep["corrected_reflectivity"].plot(x="x", y="y", - add_colorbar=False, - ax=ax, - **kwargs) + sweep["corrected_reflectivity"].where( + sweep["corrected_reflectivity"] > min_ref).plot(x="x", y="y", + ax=ax, + **kwargs) + min_ref = sweep["corrected_reflectivity"].where( + sweep["corrected_reflectivity"] > min_ref).values.min() + max_ref = sweep["corrected_reflectivity"].where( + sweep["corrected_reflectivity"] > min_ref).values.max() + ax.set_xlim(x_bounds) ax.set_ylim(y_bounds) - ax.axis('off') - ax.set_title('') - ax.set_xlabel('') - ax.set_ylabel('') + ax.set_xlabel('X [m]') + ax.set_ylabel('Y [m]') + ax.set_xticks([-100000, -50000, 0, 50000, 100000]) + ax.set_yticks([-100000, -50000, 0, 50000, 100000]) fig.tight_layout() file_name = os.path.join(output_path, os.path.basename(file).replace('.nc', '.png')) @@ -74,9 +79,10 @@ def preprocess_radar_data(file_path, output_path, label = "UNKNOWN" # Placeholder for actual label extraction logic fig.savefig(os.path.join(output_path, os.path.basename(file).replace('.nc', '.png')), - dpi=100) + dpi=150) plt.close(fig) - out_df.loc[len(out_df)] = [file_name, time_str, label] + out_df.loc[len(out_df)] = [file_name, time_str, label, min_ref, max_ref] + else: print(f"Sweep mode is not PPI or sector scan in {file}, skipping.") else: