diff --git a/.gitignore b/.gitignore
index 6d2b7c50..2c9017c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,4 +50,8 @@ yalc.lock
 .idea
-package/.eggs
\ No newline at end of file
+package/.eggs
+
+psyneulinkviewer-darwin-x64/
+package/psyneulinkviewer/__pycache__
+package/__pycache__
\ No newline at end of file
diff --git a/public/appState.js b/public/appState.js
index 8452fbdb..91ec39cf 100644
--- a/public/appState.js
+++ b/public/appState.js
@@ -1,5 +1,6 @@
 const appStates = require('../src/nodeConstants').appStates;
 const stateTransitions = require('../src/nodeConstants').stateTransitions;
+const environments = require('../src/nodeConstants').environments;
 
 const appStateFactory = (function(){
   function AppState() {
@@ -19,7 +20,11 @@ const appStateFactory = (function(){
       if (this.checkAppState(states.VIEWER_DEP_INSTALLED)) {
         await psyneulinkHandler.runServer();
       } else {
-        psyneulinkHandler.stopServer();
+        if (psyneulinkHandler.environment === environments.DEV) {
+          console.warn('The server should be restarted after the dependencies are installed.');
+        } else {
+          psyneulinkHandler.stopServer();
+        }
       }
     };
diff --git a/public/electron.js b/public/electron.js
index 9c47151a..b464bd64 100644
--- a/public/electron.js
+++ b/public/electron.js
@@ -301,6 +301,7 @@ app.whenReady().then(() => {
 app.on("window-all-closed", () => {
   if (process.platform !== "darwin") {
     app.quit();
+    psyneulinkHandler.stopServer();
   }
 });
diff --git a/src/client/components/views/editView/leftSidebar/nodeFactory.js b/src/client/components/views/editView/leftSidebar/nodeFactory.js
index 889a4947..0c893a68 100644
--- a/src/client/components/views/editView/leftSidebar/nodeFactory.js
+++ b/src/client/components/views/editView/leftSidebar/nodeFactory.js
@@ -4,7 +4,7 @@ import { PNLMechanisms } from "../../../../../constants";
 import QueryService from "../../../../services/queryService";
 import { PNLClasses, PNLLoggables } from "../../../../../constants";
 import MechanismNode from "../../../../model/nodes/mechanism/MechanismNode";
-// import CompositionNode from "../../../../model/nodes/composition/CompositionNode";
+import CompositionNode from "../../../../model/nodes/composition/CompositionNode";
 
 export class NodeFactory {
   static createNode(nodeType, name, extra, engine) {
@@ -16,9 +16,12 @@ export class NodeFactory {
     switch (nodeType) {
       case PNLClasses.COMPOSITION:
-        // TODO: return a composition node
-        // return new CompositionNode(name, nodeType, undefined, undefined, extra);
-        return new MechanismNode(name, nodeType, undefined, QueryService.getPortsNewNode(name, nodeType), extra);
+        return new CompositionNode(name, nodeType, undefined, {}, extra);
+      case PNLClasses.AUTODIFF_COMPOSITION:
+        return new CompositionNode(name, nodeType, undefined, {}, extra);
+      case PNLClasses.EM_COMPOSITION:
+        console.log("EM COMPOSITION")
+        return new CompositionNode(name, nodeType, undefined, {}, extra);
       case PNLClasses.PROJECTION:
         const selectedNodes = engine.getModel().getSelectedEntities();
diff --git a/src/client/components/views/editView/mechanisms/shared/FunctionInput.js b/src/client/components/views/editView/mechanisms/shared/FunctionInput.js
index 0cd69b85..d2de22be 100644
--- a/src/client/components/views/editView/mechanisms/shared/FunctionInput.js
+++ b/src/client/components/views/editView/mechanisms/shared/FunctionInput.js
@@ -173,6 +173,7 @@ export const CustomCheckInput = ({ label, ...props }) => {
       className="block"
       sx={{
         minWidth: '100%',
+        ...props.sx
       }}
     >
@@ -242,7 +243,7 @@ export const CustomValueInput = ({
   label,
   minWidth,
...props }) => { const classes = useStyles(); return ( - + {label} @@ -295,7 +296,7 @@ export const MatrixInput = ({ renderValue={(value) => ( {value.charAt(0).toUpperCase() + - value.slice(1).replace('-', ' ')} + value.slice(1).replaceAll('-', ' ')} )} > @@ -418,7 +419,7 @@ const FunctionInput = ({ label, ...props }) => { } return ( - + { renderValue={(value) => ( {value.charAt(0).toUpperCase() + - value.slice(1).replace('-', ' ')} + value.slice(1).replaceAll('-', ' ')} )} > diff --git a/src/client/components/views/editView/mechanisms/shared/PortsList.js b/src/client/components/views/editView/mechanisms/shared/PortsList.js index 97e54dce..21a72c3a 100644 --- a/src/client/components/views/editView/mechanisms/shared/PortsList.js +++ b/src/client/components/views/editView/mechanisms/shared/PortsList.js @@ -54,7 +54,7 @@ const PortsList = ({ async function addPorts() { const newPorts = [...ports]; - const filteredPorts = ports.filter((port) => port.type === portType); + const filteredPorts = ports?.filter((port) => port.type === portType); const allPortsNames = filteredPorts.map((port) => port.name); let currentIndex = filteredPorts.length; while (allPortsNames.includes(portType + '_' + currentIndex)) { @@ -78,9 +78,9 @@ const PortsList = ({ } function removePort(port, portId) { - const inputPorts = ports.filter((port) => port.type === portType); + const inputPorts = ports?.filter((port) => port.type === portType); if (inputPorts.length > 1) { - const filteredPorts = ports.filter((port) => port.id !== portId); + const filteredPorts = ports?.filter((port) => port.id !== portId); model.removePort(port); // remove target port in target node if (handleValueChange) { diff --git a/src/client/components/views/editView/mechanisms/shared/subclass/ModulatoryMechForm.js b/src/client/components/views/editView/mechanisms/shared/subclass/ModulatoryMechForm.js index 2bdfd40f..04d9fc0a 100644 --- a/src/client/components/views/editView/mechanisms/shared/subclass/ModulatoryMechForm.js +++ b/src/client/components/views/editView/mechanisms/shared/subclass/ModulatoryMechForm.js @@ -14,6 +14,7 @@ function ModulatoryMechForm(props) { model={model} label={optionKeys.modulation} value={optionsValue.modulation} + sx={{ borderTopRightRadius: '0.625rem !important' }} onChange={(e) => handleOptionChange( { diff --git a/src/client/components/views/editView/mechanisms/shared/subclass/ProcessingMechForm.js b/src/client/components/views/editView/mechanisms/shared/subclass/ProcessingMechForm.js index 787617bb..d3ae4d81 100644 --- a/src/client/components/views/editView/mechanisms/shared/subclass/ProcessingMechForm.js +++ b/src/client/components/views/editView/mechanisms/shared/subclass/ProcessingMechForm.js @@ -14,6 +14,7 @@ function ProcessingMechForm(props) { model={model} label={optionKeys.function} value={optionsValue.function} + sx={{ borderTopRightRadius: '0.625rem !important' }} onChange={(e) => handleOptionChange( { diff --git a/src/client/components/views/editView/mechanisms/shared/subclass/TransferMechForm.js b/src/client/components/views/editView/mechanisms/shared/subclass/TransferMechForm.js index 86090cd2..f2700b50 100644 --- a/src/client/components/views/editView/mechanisms/shared/subclass/TransferMechForm.js +++ b/src/client/components/views/editView/mechanisms/shared/subclass/TransferMechForm.js @@ -17,6 +17,7 @@ function TransferMechForm(props) { model={model} label={optionKeys.noise} value={optionsValue.noise} + sx={{ width: 'calc((100% - 0.125rem) / 2) !important' }} onChange={(e) => handleOptionChange( { @@ 
-32,6 +33,7 @@ function TransferMechForm(props) { model={model} label={optionKeys.clip} value={optionsValue.clip} + sx={{ width: 'calc((100% - 0.125rem) / 2) !important', borderTopRightRadius: '0.625rem !important' }} onChange={(e) => handleOptionChange( { @@ -43,24 +45,10 @@ function TransferMechForm(props) { ) } /> - - handleOptionChange( - { - key: optionKeys.integration_rate, - value: e.target.value, - }, - updateOptions, - updateModelOption - ) - } - /> handleOptionChange( { @@ -72,7 +60,7 @@ function TransferMechForm(props) { ) } /> - - */} + {/* */} + + handleOptionChange( + { + key: optionKeys.integration_rate, + value: e.target.value, + }, + updateOptions, + updateModelOption + ) + } /> handleOptionChange( { @@ -118,7 +123,7 @@ function TransferMechForm(props) { ) } /> - + /> */} div': { height: '100%', + width: '100%', }, }, iconButton: { diff --git a/src/client/components/views/visualiseView/main/DroppableChart.js b/src/client/components/views/visualiseView/main/DroppableChart.js index 28d42362..d76c3bb1 100644 --- a/src/client/components/views/visualiseView/main/DroppableChart.js +++ b/src/client/components/views/visualiseView/main/DroppableChart.js @@ -177,7 +177,7 @@ export const DroppableChart = ({ id, model, accept = 'element' }) => { return ( {chart} @@ -190,7 +190,7 @@ export const DroppableChart = ({ id, model, accept = 'element' }) => { {renderChartIcon(value)} {value.charAt(0).toUpperCase() + - value.slice(1).replace('-', ' ')} + value.slice(1).replaceAll('-', ' ')} )} diff --git a/src/client/grpc/grpcClient.js b/src/client/grpc/grpcClient.js index 7c81c5d2..e81cc007 100644 --- a/src/client/grpc/grpcClient.js +++ b/src/client/grpc/grpcClient.js @@ -127,6 +127,26 @@ const grpcClientFactory = (function(){ }); } + this.stopServer = (callback, errorCallback) => { + const request = new messages.NullArgument() + this._client.stopServer(request, (err, response) => { + if (err) { + if (errorCallback) { + errorCallback(err); + } else { + console.log(err); + } + } else { + if (callback) { + callback(response); + } else { + console.log(response); + console.log(response.getResponse()); + console.log(response.getMessage()); + } + } + }); + } } var instance; diff --git a/src/client/grpc/psyneulink_grpc_pb.js b/src/client/grpc/psyneulink_grpc_pb.js index ad1df3fb..93232f06 100644 --- a/src/client/grpc/psyneulink_grpc_pb.js +++ b/src/client/grpc/psyneulink_grpc_pb.js @@ -48,6 +48,17 @@ function deserialize_psyneulinkviewer_ModelPath(buffer_arg) { return psyneulink_pb.ModelPath.deserializeBinary(new Uint8Array(buffer_arg)); } +function serialize_psyneulinkviewer_NullArgument(arg) { + if (!(arg instanceof psyneulink_pb.NullArgument)) { + throw new Error('Expected argument of type psyneulinkviewer.NullArgument'); + } + return Buffer.from(arg.serializeBinary()); +} + +function deserialize_psyneulinkviewer_NullArgument(buffer_arg) { + return psyneulink_pb.NullArgument.deserializeBinary(new Uint8Array(buffer_arg)); +} + function serialize_psyneulinkviewer_PNLJson(arg) { if (!(arg instanceof psyneulink_pb.PNLJson)) { throw new Error('Expected argument of type psyneulinkviewer.PNLJson'); @@ -127,6 +138,17 @@ var ServeGraphService = exports.ServeGraphService = { responseSerialize: serialize_psyneulinkviewer_Response, responseDeserialize: deserialize_psyneulinkviewer_Response, }, + stopServer: { + path: '/psyneulinkviewer.ServeGraph/StopServer', + requestStream: false, + responseStream: false, + requestType: psyneulink_pb.NullArgument, + responseType: psyneulink_pb.Response, + requestSerialize: 
serialize_psyneulinkviewer_NullArgument, + requestDeserialize: deserialize_psyneulinkviewer_NullArgument, + responseSerialize: serialize_psyneulinkviewer_Response, + responseDeserialize: deserialize_psyneulinkviewer_Response, + }, }; exports.ServeGraphClient = grpc.makeGenericClientConstructor(ServeGraphService); diff --git a/src/client/interfaces/psyneulinkHandler.js b/src/client/interfaces/psyneulinkHandler.js index 14a5176e..393bef8c 100644 --- a/src/client/interfaces/psyneulinkHandler.js +++ b/src/client/interfaces/psyneulinkHandler.js @@ -7,6 +7,8 @@ const spawnCommand = require("./utils").spawnCommand; const executeCommand = require("./utils").executeCommand; const executeSyncCommand = require("./utils").executeSyncCommand; const parseArguments = require("./utils").parseArguments; +const grpcClient = require("../grpc/grpcClient").grpcClientFactory.getInstance(); +const rpcAPIMessageTypes = require("../../nodeConstants").rpcAPIMessageTypes; const environments = require("../../nodeConstants").environments; @@ -95,7 +97,6 @@ const psyneulinkHandlerFactory = (function(){ this.runServer = async () => { try { - // TODO - remove this when we have a proper server if (this.environment === environments.DEV) { this.serverProc = 'DEVELOPMENT MODE'; logOutput(Date.now() + " START: Starting Python RPC server IN DEVELOPMENT MODE\n", true); @@ -140,17 +141,11 @@ const psyneulinkHandlerFactory = (function(){ this.stopServer = async () => { try { - if (this.environment === environments.DEV) { - this.serverProc = 'DEVELOPMENT MODE'; - logOutput(Date.now() + " STOP: Simulation of the development server STOPPED\n", true); - return true; - } - - if (this.serverProc !== null && this.serverProc !== undefined) { - killProcess(this.serverProc.pid); - logOutput(Date.now() + " STOP: Server STOPPED with pid " + this.serverProc.pid + "\n", true); - this.serverProc = null; - } + grpcClient.stopServer(undefined, () => { + console.log('Server closed'); + }, (error) => { + console.error(error); + }); return true; } catch (error) { logOutput(Date.now() + " ERROR: " + error + "\n", true); diff --git a/src/client/model/Interpreter.ts b/src/client/model/Interpreter.ts index 221c976a..bf83b4d1 100644 --- a/src/client/model/Interpreter.ts +++ b/src/client/model/Interpreter.ts @@ -130,7 +130,7 @@ export default class ModelInterpreter { const result = QueryService.getPorts(name); if (result !== '') { - const parsedPorts = result.replace('[', '').replace(']', '').split(', '); + const parsedPorts = result.replaceAll(/(\[|\])/g, '').split(', '); parsedPorts.forEach(element => { const elementData = element.slice(1, -1).split(' '); switch(elementData[0]) { @@ -171,13 +171,15 @@ export default class ModelInterpreter { boundingBox.ury = parseFloat(_vertices[3]); } extra['boundingBox'] = boundingBox; + extra['width'] = Math.abs(boundingBox.llx - boundingBox.urx); + extra['height'] = Math.abs(boundingBox.ury - boundingBox.lly); extra['position'] = { x: boundingBox.llx, y: boundingBox.lly } extra['isExpanded'] = false; extra[PNLLoggables] = this.loggables[item?.label]; - newNode = new CompositionNode(item?.label, parent, ports, extra); + newNode = new CompositionNode(item?.label, ModelSingleton.getNodeType(item?.name), parent, ports, extra); modelMap[PNLClasses.COMPOSITION].set(newNode.getName(), newNode); // temp array to host all the nested compositions let childrenCompositions: Array = []; @@ -245,13 +247,15 @@ export default class ModelInterpreter { boundingBox.ury = parseFloat(_vertices[3]); } extra['boundingBox'] = boundingBox; 
+ extra['width'] = Math.abs(boundingBox.llx - boundingBox.urx); + extra['height'] = Math.abs(boundingBox.ury - boundingBox.lly); extra['position'] = { x: boundingBox.llx, y: boundingBox.lly } extra['isExpanded'] = false; extra[PNLLoggables] = this.loggables[item?.label]; - newNode = new CompositionNode(item?.label, parent, ports, extra); + newNode = new CompositionNode(item?.label, ModelSingleton.getNodeType(item?.name), parent, ports, extra); modelMap[PNLClasses.COMPOSITION].set(newNode.getName(), newNode); // Iterates nodes of the nested composition to fill the children map/array @@ -289,7 +293,7 @@ export default class ModelInterpreter { modelMap[newNode.getType()].set(newNode.getName(), newNode); this.pnlModel[newNode.getType()].push(newNode); } else { - throw new Error('Unknown node type, class ' + newNode.getType() + ' not found in modelMap'); + //throw new Error('Unknown node type, class ' + newNode.getType() + ' not found in modelMap'); } return newNode; } diff --git a/src/client/model/ModelSingleton.ts b/src/client/model/ModelSingleton.ts index 5dbc7801..bfbb6fce 100644 --- a/src/client/model/ModelSingleton.ts +++ b/src/client/model/ModelSingleton.ts @@ -71,6 +71,8 @@ export default class ModelSingleton { ModelSingleton.metaRef = React.createRef(); ModelSingleton.componentsMap = new ComponentsMap(new Map(), new Map()); ModelSingleton.componentsMap.nodes.set(PNLClasses.COMPOSITION, Composition); + ModelSingleton.componentsMap.nodes.set(PNLClasses.EM_COMPOSITION, Composition); + ModelSingleton.componentsMap.nodes.set(PNLClasses.AUTODIFF_COMPOSITION, Composition); // TODO: the PNLMechanisms.MECHANISM is not used anymore since we are defininig the classes. ModelSingleton.componentsMap.nodes.set(PNLMechanisms.DDM, DDM); ModelSingleton.componentsMap.nodes.set(PNLMechanisms.LCA_MECH, LCAMechanism); @@ -109,6 +111,8 @@ export default class ModelSingleton { ModelSingleton.metaGraph = generateMetaGraph([ ...ModelSingleton.interpreter.getMetaModel()[PNLClasses.COMPOSITION], + ...ModelSingleton.interpreter.getMetaModel()[PNLClasses.EM_COMPOSITION], + ...ModelSingleton.interpreter.getMetaModel()[PNLClasses.AUTODIFF_COMPOSITION], ...ModelSingleton.interpreter.getMetaModel()[PNLMechanisms.MECHANISM], ...ModelSingleton.interpreter.getMetaModel()[PNLMechanisms.PROCESSING_MECH], ...ModelSingleton.interpreter.getMetaModel()[PNLMechanisms.DEFAULT_PROCESSING_MECH], @@ -155,7 +159,7 @@ export default class ModelSingleton { static getNodeType(nodeName: string) { if (ModelSingleton.summaries[nodeName]) { // Note, the replace below is required due to a transformation done by the library PSNL itself - return ModelSingleton.summaries[nodeName][nodeName.replace('-', '_')].metadata.type; + return ModelSingleton.summaries[nodeName][nodeName.replaceAll(/(\[|\]|\s)/g, '_')]?.metadata?.type; } return 'unknown'; } @@ -163,6 +167,8 @@ export default class ModelSingleton { static flushModel(model: any, summaries: any, loggables: any) { ModelSingleton.componentsMap = new ComponentsMap(new Map(), new Map()); ModelSingleton.componentsMap.nodes.set(PNLClasses.COMPOSITION, Composition); + ModelSingleton.componentsMap.nodes.set(PNLClasses.EM_COMPOSITION, Composition); + ModelSingleton.componentsMap.nodes.set(PNLClasses.AUTODIFF_COMPOSITION, Composition); // TODO: the PNLMechanisms.MECHANISM is not used anymore since we are defininig the classes. 
ModelSingleton.componentsMap.nodes.set(PNLMechanisms.DDM, DDM); ModelSingleton.componentsMap.nodes.set(PNLMechanisms.LCA_MECH, LCAMechanism); @@ -203,6 +209,8 @@ export default class ModelSingleton { ModelSingleton.metaGraph = generateMetaGraph([ ...ModelSingleton.interpreter.getMetaModel()[PNLClasses.COMPOSITION], + ...ModelSingleton.interpreter.getMetaModel()[PNLClasses.EM_COMPOSITION], + ...ModelSingleton.interpreter.getMetaModel()[PNLClasses.AUTODIFF_COMPOSITION], ...ModelSingleton.interpreter.getMetaModel()[PNLMechanisms.MECHANISM], ...ModelSingleton.interpreter.getMetaModel()[PNLMechanisms.PROCESSING_MECH], ...ModelSingleton.interpreter.getMetaModel()[PNLMechanisms.DEFAULT_PROCESSING_MECH], diff --git a/src/client/model/graph/MetaGraph.ts b/src/client/model/graph/MetaGraph.ts index 6d73b099..9aeedd54 100644 --- a/src/client/model/graph/MetaGraph.ts +++ b/src/client/model/graph/MetaGraph.ts @@ -99,7 +99,7 @@ export class Graph { getDescendancyLinks(nodes: MetaNodeModel[], links: MetaLinkModel[]): MetaLinkModel[] { const nodesIds = nodes.map(n => n.getID()); - return links.filter(l => nodesIds.includes(l.getSourcePort().getNode().getID()) && nodesIds.includes(l.getTargetPort().getNode().getID())); + return links?.filter(l => nodesIds.includes(l.getSourcePort().getNode().getID()) && nodesIds.includes(l.getTargetPort().getNode().getID())); } /** diff --git a/src/client/model/links/ProjectionLink.ts b/src/client/model/links/ProjectionLink.ts index 0811d477..7bf0188d 100644 --- a/src/client/model/links/ProjectionLink.ts +++ b/src/client/model/links/ProjectionLink.ts @@ -36,19 +36,24 @@ export default class ProjectionLink implements IMetaLinkConverter { extractPort(port: string, node: MechanismNode, portType: string) { let result: string = ''; - const portToSearch = port.replaceAll('-', '_').replace(portType, node.getName()) - node.getPorts()[portType].forEach((outputPort: any) => { - if (portToSearch === outputPort) { - if (result === '') { - result = portToSearch; - } else { - throw Error('There are more than one ports with the same name.'); + if (port !== undefined) { + const portToSearch = port?.replaceAll('-', '_')?.replace(portType, node.getName()) + node?.getPorts()?.[portType].forEach((outputPort: any) => { + if (portToSearch === outputPort) { + if (result === '') { + result = portToSearch; + } else { + throw Error('There are more than one ports with the same name.'); + } } - } - }); + }); + } + if (result === '') { - throw Error('There is no port with that name.'); + console.warn('Port not found. 
Using the first port of the type.'); + result = node.getPorts()[portType][0]; } + return result; } diff --git a/src/client/model/nodes/composition/CompositionNode.ts b/src/client/model/nodes/composition/CompositionNode.ts index f400a723..0e47c5ed 100644 --- a/src/client/model/nodes/composition/CompositionNode.ts +++ b/src/client/model/nodes/composition/CompositionNode.ts @@ -3,7 +3,9 @@ import {Point} from "@projectstorm/geometry"; import MechanismNode from '../mechanism/MechanismNode'; import ProjectionLink from '../../links/ProjectionLink'; import { MetaNode, MetaPort } from '@metacell/meta-diagram'; -import { PNLClasses, PNLMechanisms, PNLLoggables } from '../../../../constants'; +import { PNLClasses, PNLMechanisms, PNLDefaults } from '../../../../constants'; +import ModelSingleton from '../../ModelSingleton'; +import pnlStore from '../../../redux/store'; export default class CompositionNode extends MechanismNode { children: {[key: string]: any}; @@ -12,12 +14,13 @@ export default class CompositionNode extends MechanismNode { constructor( name: string, + type: string, parent: CompositionNode|undefined, ports?: { [key: string]: Array }, extra?: ExtraObject, children?: {[key: string]: any}) { - super(name, PNLClasses.COMPOSITION, parent, ports, extra); + super(name, type, parent, ports, extra); this.childrenMap = new Map(); this.children = {}; @@ -64,19 +67,19 @@ export default class CompositionNode extends MechanismNode { addChild(child: MechanismNode|CompositionNode) { if (!this.childrenMap.has(child.getName())) { - this.childrenMap.set(child.getName(), child); - this.metaChildren.push(child.getMetaNode()); + this.childrenMap?.set(child.getName(), child); + this.metaChildren?.push(child.getMetaNode()); } - this.children[child.getType()].push(child); + this.children[child.getType()]?.push(child); } removeChild(child: MechanismNode|CompositionNode) { if (this.childrenMap.has(child.getName())) { - this.childrenMap.delete(child.getName()); - this.metaChildren = this.metaChildren.filter((item: MetaNode) => item.getId() !== child.getName()); + this.childrenMap?.delete(child.getName()); + this.metaChildren = this.metaChildren?.filter((item: MetaNode) => item.getId() !== child.getName()); } - this.children[child.getType()] = this.children[child.getType()].filter( (item: any) => { + this.children[child.getType()] = this.children[child.getType()]?.filter( (item: any) => { return item.getName() !== child.getName() }); } @@ -101,32 +104,19 @@ export default class CompositionNode extends MechanismNode { } getMetaNode() : any { - // TODO: get position from the graphviz data - // @ts-ignore - const width = Math.abs(parseFloat(this.extra.boundingBox['llx']) - parseFloat(this.extra.boundingBox['urx'])); - // @ts-ignore - const height = Math.abs(parseFloat(this.extra.boundingBox['ury']) - parseFloat(this.extra.boundingBox['lly'])); - let ports: Array = [] + const summaries = ModelSingleton.getSummaries(); + const defaults = JSON.parse(JSON.stringify(pnlStore.getState().general[PNLDefaults][this.innerClass] ?? {})); + const ports: Array = [] return new MetaNode( this.name, this.name, - PNLClasses.COMPOSITION, + this.getType(), this.getPosition(), - 'node-gray', + this.getVariantFromType(), this.metaParent, ports, - this.metaChildren, - new Map(Object.entries({ - name: this.name, - variant: 'node-gray', - pnlClass: PNLClasses.COMPOSITION, - shape: PNLClasses.COMPOSITION, - selected: false, - width: width, - height: height, - [PNLLoggables]: this.extra?.[PNLLoggables] !== undefined ? 
this.extra?.[PNLLoggables] : {} - }) - ) + undefined, + this.getOptionsFromType(summaries, defaults) ); } } diff --git a/src/client/model/nodes/mechanism/MechanismNode.ts b/src/client/model/nodes/mechanism/MechanismNode.ts index b2135e41..566d8fdc 100644 --- a/src/client/model/nodes/mechanism/MechanismNode.ts +++ b/src/client/model/nodes/mechanism/MechanismNode.ts @@ -1,6 +1,6 @@ import pnlStore from "../../../redux/store"; import {Point} from "@projectstorm/geometry"; -import { PNLDefaults, PNLLoggables } from "../../../../constants"; +import { PNLClasses, PNLDefaults, PNLLoggables } from "../../../../constants"; import IMetaDiagramConverter from '../IMetaDiagramConverter'; import CompositionNode from '../composition/CompositionNode'; import { MetaNode, MetaPort, PortTypes } from '@metacell/meta-diagram'; @@ -69,11 +69,11 @@ export default class MechanismNode implements IMetaDiagramConverter { setParent(newParent: CompositionNode) { if (this.parent) { - this.parent.removeChild(this); + this.parent?.removeChild(this); } this.parent = newParent; - this.metaParent = newParent.getMetaNode(); - this.parent.addChild(this); + this.metaParent = newParent?.getMetaNode(); + this.parent?.addChild(this); } getParent(): CompositionNode|undefined { @@ -115,13 +115,16 @@ export default class MechanismNode implements IMetaDiagramConverter { return 'node-gray'; } - getOptionsFromType(summaries: any, defaults: any) : Map { - let classParams = JSON.parse(JSON.stringify(MetaNodeToOptions[this.innerClass])); - if (summaries !== undefined && summaries.hasOwnProperty(this.name)) { + getOptionsFromType(summaries: any, defaults: any): Map { + // Ensure MetaNodeToOptions[this.innerClass] is defined before proceeding + let classParams = MetaNodeToOptions[this.innerClass] ? + JSON.parse(JSON.stringify(MetaNodeToOptions[this.innerClass])) : + {}; // Use an empty object if it's undefined + + if (summaries !== undefined && summaries?.hasOwnProperty(this.name)) { const summary = summaries[this.name]; classParams = extractParams(summary[this.name], classParams, true); - } - else { + } else { classParams = extractParams(defaults, classParams, false); } @@ -131,44 +134,47 @@ export default class MechanismNode implements IMetaDiagramConverter { pnlClass: this.getType(), shape: this.getType(), selected: false, - height: this.extra?.height !== undefined ? this.extra?.height : 100, - width: this.extra?.width !== undefined ? this.extra?.width : 100, + height: this.extra?.height !== undefined ? this.extra?.height : (this.innerClass === PNLClasses.COMPOSITION ? 300 : 100), + width: this.extra?.width !== undefined ? this.extra?.width : (this.innerClass === PNLClasses.COMPOSITION ? 150 : 100), [PNLLoggables]: this.extra?.[PNLLoggables] !== undefined ? 
this.extra?.[PNLLoggables] : {} }; - - if (MechanismToVariant.hasOwnProperty(this.innerClass)) { + + if (MechanismToVariant?.hasOwnProperty(this.innerClass)) { nodeOptions = {...nodeOptions, ...classParams}; } + return new Map(Object.entries(nodeOptions)); } + extractPorts(summaries: any, defaults: any): Array { - let ports: Array = [] + let ports: Array = []; let summary_inputs: any = {}; let summary_outputs: any = {}; - if (summaries !== undefined && summaries.hasOwnProperty(this.name)) { - summary_inputs = summaries[this.name][this.name]['input_ports']; - summary_outputs = summaries[this.name][this.name]['output_ports']; + if (summaries !== undefined && summaries?.hasOwnProperty(this.name)) { + summary_inputs = summaries[this.name]?.[this.name]?.['input_ports']; + summary_outputs = summaries[this.name]?.[this.name]?.['output_ports']; for (const inputPort in summary_inputs) { + const metadata = summary_inputs[inputPort].metadata ? new Map(Object.entries(summary_inputs[inputPort].metadata)) : new Map(); ports.push(new MetaPort( inputPort, inputPort, PortTypes.INPUT_PORT, new Point(0, 0), - new Map(Object.entries(summary_inputs[inputPort].metadata))) - ); + metadata + )); } for (const outputPort in summary_outputs) { + const metadata = summary_outputs[outputPort].metadata ? new Map(Object.entries(summary_outputs[outputPort].metadata)) : new Map(); ports.push(new MetaPort( outputPort, outputPort, PortTypes.OUTPUT_PORT, new Point(0, 0), - new Map(Object.entries(summary_outputs[outputPort].metadata))) - ); + metadata + )); } - } - else { + } else { const default_ports = QueryService.getPortsNewNode(this.name, this.innerClass); default_ports[PortTypes.INPUT_PORT].forEach((inputPort: any) => { ports.push(new MetaPort( @@ -176,8 +182,8 @@ export default class MechanismNode implements IMetaDiagramConverter { inputPort, PortTypes.INPUT_PORT, new Point(0, 0), - new Map()) - ); + new Map() + )); }); default_ports[PortTypes.OUTPUT_PORT].forEach((outputPort: any) => { ports.push(new MetaPort( @@ -185,12 +191,12 @@ export default class MechanismNode implements IMetaDiagramConverter { outputPort, PortTypes.OUTPUT_PORT, new Point(0, 0), - new Map()) - ); + new Map() + )); }); } return ports; - } + } getMetaNode() : MetaNode { const summaries = ModelSingleton.getSummaries(); diff --git a/src/client/model/nodes/mechanism/utils.js b/src/client/model/nodes/mechanism/utils.js index 9941f392..470a292d 100644 --- a/src/client/model/nodes/mechanism/utils.js +++ b/src/client/model/nodes/mechanism/utils.js @@ -1,10 +1,10 @@ import { FunctionsParams, OptionsTypes } from "../utils"; export const extractParams = (base, params, isThisFromSummary) => { - if (base.hasOwnProperty('functions') && base.hasOwnProperty('metadata')) { + if (base?.hasOwnProperty('functions') && base?.hasOwnProperty('metadata')) { for (const key in params) { if (key === 'function') { - params[key] = extractFunction(base.functions[Object.keys(base.functions)[0]], isThisFromSummary) + params[key] = extractFunction(base?.functions[Object.keys(base?.functions)[0]], isThisFromSummary) } else { params[key] = extractByType(key, base, isThisFromSummary); } @@ -38,12 +38,12 @@ const extractByType = (key, functionObj, isThisFromSummary) => { const extractFunction = (functionObj, isThisFromSummary) => { - const functionType = functionObj.metadata.type; + const functionType = functionObj.metadata?.type; const functionParams = FunctionsParams[functionType]; let functionString = 'pnl.' 
+ functionType + '('; for (const funcParam in functionParams) { - if ((isThisFromSummary || functionParams[funcParam].required) && functionObj.args.hasOwnProperty(funcParam)) { - functionString += funcParam + '=' + functionObj.args[funcParam] + ','; + if ((isThisFromSummary || functionParams[funcParam]?.required) && functionObj?.args?.hasOwnProperty(funcParam)) { + functionString += funcParam + '=' + functionObj?.args[funcParam] + ','; } } functionString += ')'; diff --git a/src/client/model/nodes/utils.ts b/src/client/model/nodes/utils.ts index 325eeca7..82fb97ba 100644 --- a/src/client/model/nodes/utils.ts +++ b/src/client/model/nodes/utils.ts @@ -46,7 +46,50 @@ export const MechanismToVariant: any = { }; export const MetaNodeToOptions: any = { - [PNLClasses.COMPOSITION]: {}, + [PNLClasses.COMPOSITION]: { + }, + [PNLClasses.EM_COMPOSITION]: { + memory_template : 1, + memory_fill : '', + memory_capacity : 1, + fields: '', + field_names : '', + field_weights: '', + learn_field_weights : '', + learning_rate: 1, + normalize_field_weights : '', + concatenate_queries : '', + normalize_memories : '', + softmax_gain : '', + softmax_threshold : '', + storage_choice : '', + storage_prob : 1.0, + memory_decay_rate : 1, + purge_by_field_weights: '', + enable_learning : '', + target_fields : '', + seed : '', + }, + [PNLClasses.AUTODIFF_COMPOSITION]: { + pathways: '', + optimizer_type: 'sgd', + loss_spec: '', + learning_rate: '', + weight_decay: 0, + disable_learning: '', + force_no_retain_graph: '', + refresh_losses: '', + synch_projection_matrices_with_torch: 'run', + synch_node_variables_with_torch: 'None', + synch_node_values_with_torch: 'run', + synch_results_with_torch: 'run', + retain_torch_trained_outputs: 'minibatch', + retain_torch_targets: 'minibatch', + retain_torch_losses: 'minibatch', + device: '', + disable_cuda: '', + cuda_index: '' + }, [PNLMechanisms.PROCESSING_MECH]: { function: '', }, @@ -162,12 +205,12 @@ export const MetaNodeToOptions: any = { noise: '', clip: 'Tuple = ()', //Tuple integrator_mode: false, - integrator_function: '', + // integrator_function: '', integration_rate: '', - on_resume_integrator_mode: '', //str + // on_resume_integrator_mode: '', //str termination_measure: '', //function termination_threshold: '', - termination_comparison_op: '', + // termination_comparison_op: '', }, [PNLMechanisms.RECURRENT_TRANSFER_MECH]: { matrix: '', @@ -262,6 +305,50 @@ export const MetaNodeToOptions: any = { export const OptionsTypes: any = { + [PNLClasses.COMPOSITION]: { + }, + [PNLClasses.EM_COMPOSITION]: { + memory_template : {type: 'string'}, + memory_fill : {type: 'string'}, + memory_capacity : {type: 'string'}, + fields: {type: 'string'}, + field_names : {type: 'string'}, + field_weights: {type: 'string'}, + learn_field_weights : {type: 'boolean'}, + learning_rate: {type: 'string'}, + normalize_field_weights : {type: 'boolean'}, + concatenate_queries : {type: 'boolean'}, + normalize_memories : {type: 'boolean'}, + softmax_gain : {type: 'string'}, + softmax_threshold : {type: 'string'}, + storage_choice : {type: 'string'}, + storage_prob : {type: 'string'}, + memory_decay_rate : {type: 'string'}, + purge_by_field_weights: {type: 'boolean'}, + enable_learning : {type: 'boolean'}, + target_fields : {type: 'string'}, + seed : {type: 'string'}, + }, + [PNLClasses.AUTODIFF_COMPOSITION]: { + pathways: {type: 'string'}, + optimizer_type: {type: 'string'}, + loss_spec: {type: 'string'}, + learning_rate: {type: 'string'}, + weight_decay: {type: 'string'}, + disable_learning: {type: 
'boolean'},
+    force_no_retain_graph: {type: 'boolean'},
+    refresh_losses: {type: 'boolean'},
+    synch_projection_matrices_with_torch: {type: 'string'},
+    synch_node_variables_with_torch: {type: 'string'},
+    synch_node_values_with_torch: {type: 'string'},
+    synch_results_with_torch: {type: 'string'},
+    retain_torch_trained_outputs: {type: 'string'},
+    retain_torch_targets: {type: 'string'},
+    retain_torch_losses: {type: 'string'},
+    device: {type: 'string'},
+    disable_cuda: {type: 'boolean'},
+    cuda_index: {type: 'string'},
+  },
   [PNLMechanisms.PROCESSING_MECH]: {
     function: {type: 'function'},
   },
@@ -377,12 +464,12 @@
     noise: {type: 'string'},
     clip: 'Tuple = ()', //Tuple
     integrator_mode: {type: 'boolean'},
-    integrator_function: {type: 'function'},
+    // integrator_function: {type: 'function'},
     integration_rate: {type: 'string'},
-    on_resume_integrator_mode: {type: 'string'},
+    // on_resume_integrator_mode: {type: 'string'},
     termination_measure: {type: 'function'},
     termination_threshold: {type: 'string'},
-    termination_comparison_op: {type: 'string'},
+    // termination_comparison_op: {type: 'string'},
   },
   [PNLMechanisms.RECURRENT_TRANSFER_MECH]: {
     matrix: {type: 'string'},
diff --git a/src/client/model/utils.js b/src/client/model/utils.js
index e09fb44a..368fd031 100644
--- a/src/client/model/utils.js
+++ b/src/client/model/utils.js
@@ -96,6 +96,10 @@ export function findTopLeftCorner(ldraw, pos) {
     }
   });
 
+  if (minX === Infinity && minY === Infinity && maxX === -Infinity && maxY === -Infinity) {
+    return pos.split(',');
+  }
+
   const coordinates = pos.split(',');
   const centerX = parseFloat(coordinates[0]);
   const centerY = parseFloat(coordinates[1]);
diff --git a/src/client/services/queryService.ts b/src/client/services/queryService.ts
index f72f01ba..96a68e2c 100644
--- a/src/client/services/queryService.ts
+++ b/src/client/services/queryService.ts
@@ -28,15 +28,15 @@ export default class QueryService {
   static getPorts(nodeName: string): string {
     const summary: any = ModelSingleton.getSummaries();
     if (summary.hasOwnProperty(nodeName)) {
-      const nodeInfo: any = summary[nodeName][nodeName];
+      const nodeInfo: any = summary[nodeName][nodeName.replaceAll(/(\[|\]|\s)/g, '_')];
       let ports: string = '[';
-      for (const inputPort in nodeInfo.input_ports) {
+      for (const inputPort in nodeInfo?.input_ports) {
         ports += `(InputPort ${inputPort}), `;
       }
-      for (const outputPort in nodeInfo.output_ports) {
+      for (const outputPort in nodeInfo?.output_ports) {
         ports += `(OutputPort ${outputPort}), `;
       }
-      return ports.slice(0, -2) + ']';
+      return ports?.slice(0, -2) + ']';
     }
     return '[]';
   }
diff --git a/src/constants.ts b/src/constants.ts
index ee3dfa55..4b52881d 100644
--- a/src/constants.ts
+++ b/src/constants.ts
@@ -27,6 +27,8 @@ export const PNLLoggables = 'Loggables';
 
 export enum PNLClasses {
   COMPOSITION = 'Composition',
+  EM_COMPOSITION = 'EMComposition',
+  AUTODIFF_COMPOSITION = 'AutodiffComposition',
   PROJECTION = 'Projection',
 }
diff --git a/src/protos/psyneulink.proto b/src/protos/psyneulink.proto
index 91cb2784..59ba1385 100644
--- a/src/protos/psyneulink.proto
+++ b/src/protos/psyneulink.proto
@@ -9,6 +9,7 @@ service ServeGraph {
   rpc RunModel (InputJson) returns (Response) {}
   rpc PNLApi (PNLJson) returns (PNLJson) {}
   rpc SaveModel (ModelData) returns (Response) {}
+  rpc StopServer (NullArgument) returns (Response) {}
 }
diff --git a/src/server/api/psnl_api.py b/src/server/api/psnl_api.py
index 4658a28f..5771f8b5 100644
--- a/src/server/api/psnl_api.py
+++ b/src/server/api/psnl_api.py
@@ -3,6 +3,9 @@ from os.path import expanduser
 from xml.etree.cElementTree import fromstring
 import ast
+import os
+import importlib.util
+import sys
 import json
 import threading
 import numpy as np
@@ -83,15 +86,51 @@ def hashable_pnl_objects(self):
         }
 
     def loadScript(self, filepath):
+        # Expand and set the main filepath
         filepath = pnls_utils.expand_path(filepath)
         self.filepath = filepath
+
+        # Preload modules from the same folder
+        self.preload_dependencies(filepath)
+
+        # Load and parse the main script
         with open(filepath, 'r') as f:
             f.seek(0)
             self.ast = f.read()
+
+        # Parse the main script (without parsing the dependencies)
         self.modelParser.parse_model(self.ast)
+
+        # Generate and return the model
         model = self.modelParser.get_graphviz()
         return model
 
+    def preload_dependencies(self, filepath):
+        folder = os.path.dirname(filepath)
+        # Load the main file's AST and find imports
+        with open(filepath, 'r') as f:
+            script_ast = ast.parse(f.read())
+
+        for node in ast.walk(script_ast):
+            if isinstance(node, ast.ImportFrom) and node.level == 0: # relative import
+                module_name = node.module.lstrip()
+                if module_name and module_name != "psyneulink" : # Only process if the module is specified
+                    self.load_module_from_same_folder(module_name, folder)
+            elif isinstance(node, ast.Import) : # relative import
+                for alias in node.names:
+                    if alias.name.lstrip() != "psyneulink":
+                        self.load_module_from_same_folder(alias.name, folder)
+
+    def load_module_from_same_folder(self, module_name, folder):
+        # Convert module name to file path (relative imports)
+        module_file = os.path.join(folder, module_name.replace('.', '/') + '.py')
+        if os.path.exists(module_file):
+            # Import the module dynamically without parsing it with modelParser
+            spec = importlib.util.spec_from_file_location(module_name, module_file)
+            module = importlib.util.module_from_spec(spec)
+            sys.modules[module_name] = module
+            spec.loader.exec_module(module)
+
     def pnlAPIcall(self, data):
         callData = json.loads(data)
         method = callData["method"] if 'method' in callData else None
diff --git a/src/server/model/codeGenerator.py b/src/server/model/codeGenerator.py
index 014c98d6..ceaa3a6d 100644
--- a/src/server/model/codeGenerator.py
+++ b/src/server/model/codeGenerator.py
@@ -164,11 +164,25 @@ def get_post_src_instructions(self):
     def get_pre_src_instructions(self):
         return self.pre_src_instructions
 
+    def extract_param(self, param):
+        try:
+            # check if a number or a string, otherwise stringify the value
+            if isinstance(param, (bool)):
+                return str(param)
+            elif isinstance(param, (int, float, str)):
+                return param
+            else:
+                return str(param)
+        except:
+            return ''
+
     def generate_params(self):
         params_src = ""
         params = [a for a in self.options if a not in self.params_to_skip]
         for param in params:
-            params_src += param + " = " + self.node["class_inputs"][param] + ", "
+            param_value = self.extract_param(self.node["class_inputs"][param])
+            if param_value != "":
+                params_src += param + " = " + param_value + ", "
         if "name" not in params:
             params_src += "name = '" + self.node["name"] + "', "
         return params_src
diff --git a/src/server/model/parser.py b/src/server/model/parser.py
index fd30cdb7..c3918781 100644
--- a/src/server/model/parser.py
+++ b/src/server/model/parser.py
@@ -6,10 +6,11 @@ import utils as utils
 import psyneulink as pnl
 from io import StringIO
-from utils import PNLTypes, PNLConstants, extract_defaults
+from utils import PNLTypes, PNLConstants, extract_defaults, PNLCompositions
 from redbaron import RedBaron
 from model.modelGraph import ModelGraph
 from model.codeGenerator import CodeGenerator
+import traceback
 
 pnls_utils = utils.PNLUtils()
@@ -32,6 +33,8 @@ def __init__(self, psyneulink_instance):
             "add_required_node_role",
             "set_log_conditions",
         ]
+
+        print("composition psyneulink_instance ",psyneulink_instance)
         self.psyneulink_composition_classes = self.get_class_hierarchy(
             self.psyneulink_instance.Composition
         )
@@ -133,28 +136,51 @@ def parse_model(self, src):
 
         return self.get_graphviz_graph()
 
+    def add_to_type(self, node):
+        if node.componentType in self.psyneulink_composition_classes:
+            self.model_nodes[PNLTypes.COMPOSITIONS.value][str(node.name)] = node
+            comp_type = node.__class__.__name__
+            if comp_type == PNLCompositions.EMComposition.value and PNLCompositions.EMComposition.value in self.psyneulink_composition_classes:
+                pnls_utils.logInfo("\n\n\nError: EMComposition is not supported yet.\n\n\n")
+                # Uncomment the following line when EMComposition support is complete
+                # self.get_em_nodes(node)
+        elif node.componentType in self.psyneulink_mechanism_classes:
+            self.model_nodes[PNLTypes.MECHANISMS.value][str(node.name)] = node
+        elif node.componentType in self.psyneulink_projection_classes:
+            self.model_nodes[PNLTypes.PROJECTIONS.value][str(node.name)] = node
+
+
+    def add_to_summary(self, node):
+        if hasattr(node, "json_summary"):
+            self.graphviz_graph[PNLConstants.SUMMARY.value][str(node.name)] = node.json_summary
+        else:
+            self.graphviz_graph[PNLConstants.SUMMARY.value][str(node.name)] = '{}'
+
+        if hasattr(node, "loggable_items"):
+            self.graphviz_graph[PNLConstants.LOGGABLES.value][str(node.name)] = node.loggable_items
+        else:
+            self.graphviz_graph[PNLConstants.LOGGABLES.value][str(node.name)] = '{}'
+
+
+    def get_em_nodes(self, em):
+        for node in em.nodes:
+            self.add_to_type(node)
+            self.add_to_summary(node)
+        for link in em.projections:
+            self.add_to_type(link)
+            self.add_to_summary(link)
+
+
     def get_model_nodes(self):
         try:
             for node in self.all_assigns:
-                if hasattr(self.localvars[str(node.target)], "componentType"):
-                    node_type = self.localvars[str(node.target)].componentType
-                    if hasattr(self.localvars[str(node.target)], "json_summary"):
-                        self.graphviz_graph[PNLConstants.SUMMARY.value][str(self.localvars[str(node.target)].name)] = self.localvars[str(node.target)].json_summary
-                    else:
-                        self.graphviz_graph[PNLConstants.SUMMARY.value][str(self.localvars[str(node.target)].name)] = {}
-                    if hasattr(self.localvars[str(node.target)], "loggable_items"):
-                        self.graphviz_graph[PNLConstants.LOGGABLES.value][str(self.localvars[str(node.target)].name)] = self.localvars[str(node.target)].loggable_items
-                    else:
-                        self.graphviz_graph[PNLConstants.LOGGABLES.value][str(self.localvars[str(node.target)].name)] = {}
-                    if node_type in self.psyneulink_composition_classes:
-                        self.model_nodes[PNLTypes.COMPOSITIONS.value][str(self.localvars[str(node.target)].name)] = self.localvars[str(node.target)]
-                    elif node_type in self.psyneulink_mechanism_classes:
-                        self.model_nodes[PNLTypes.MECHANISMS.value][str(self.localvars[str(node.target)].name)] = self.localvars[str(node.target)]
-                    elif node_type in self.psyneulink_projection_classes:
-                        self.model_nodes[PNLTypes.PROJECTIONS.value][str(self.localvars[str(node.target)].name)] = self.localvars[str(node.target)]
+                if str(node.target) in self.localvars:
+                    if hasattr(self.localvars[str(node.target)], "componentType"):
+                        self.add_to_type(self.localvars[str(node.target)])
+                        self.add_to_summary(self.localvars[str(node.target)])
         except Exception as e:
-            pnls_utils.logError(str(e))
-            raise Exception("Error in get_model_nodes")
+            print("Error parsing node ", e)
+            print(traceback.format_exc())
 
 
     def compute_model_tree(self):
@@ -197,10 +223,14 @@ def generate_graphviz(self):
             node = self.model_tree.get_graph()[key].get_node()
             if node.componentType in self.psyneulink_composition_classes:
                 gv_node = None
-                # TODO: below commented since breaking on macos
-                # node.show_graph(show_node_structure=pnl.ALL)
+                node._analyze_graph()
                 gv_node = node.show_graph(show_node_structure=pnl.ALL, output_fmt="gv")
-                self.graphviz_graph[PNLTypes.COMPOSITIONS.value].append(gv_node.pipe('json', quiet=True).decode())
+                if gv_node is not None :
+                    try:
+                        self.graphviz_graph[PNLTypes.COMPOSITIONS.value].append(gv_node.pipe('json', quiet=True).decode())
+                    except Exception as e:
+                        print("Error with pipe ", e)
+                # orphan_nodes.node(node.name, gv_node)
             elif node.componentType in self.psyneulink_mechanism_classes:
                 if orphan_nodes is None:
                     orphan_nodes = graphviz.Digraph('mechanisms')
@@ -351,7 +381,8 @@ def update_model(self, file, modelJson):
             self.extract_data_from_model()
             file.write(self.fst.dumps())
         except Exception as e:
-            file.write(oldFST.dumps())
+            if oldFST:
+                file.write(oldFST.dumps())
             file.close()
             raise Exception("Error updating the model\n" + e)
diff --git a/src/server/rpc_server.py b/src/server/rpc_server.py
index 7b1050f6..74770872 100755
--- a/src/server/rpc_server.py
+++ b/src/server/rpc_server.py
@@ -1,19 +1,18 @@
-from collections import defaultdict
 from concurrent import futures
 from queue import Queue
 from xml.etree.cElementTree import fromstring
-import grpc
-import json
-import numpy as np
 import os
-import stubs.psyneulink_pb2 as pnlv_pb2
-import stubs.psyneulink_pb2_grpc as pnlv_pb2_grpc
 import sys
+import grpc
+import json
 import threading
-import api.psnl_api as psnl_api
-import utils as utils
-import multiprocessing as mp
 import traceback
+import utils as utils
+import api.psnl_api as psnl_api
+import stubs.psyneulink_pb2 as pnlv_pb2
+import stubs.psyneulink_pb2_grpc as pnlv_pb2_grpc
+import socket, errno
+
 my_env = os.environ
 sys.path.append(os.getenv('PATH'))
@@ -43,6 +42,7 @@ def innerFunc(*args, **kwargs):
 class PNLVServer(pnlv_pb2_grpc.ServeGraphServicer):
     def __init__(self):
         super().__init__()
+        self.token = None
         self._graph = None
         self._graph_json = None
         self._graph_queue = Queue()
@@ -50,12 +50,12 @@ def __init__(self):
         self._graph_lock = threading.Lock()
         self.modelHandler = psnl_api.APIHandler()
 
-    @errorHandler
     def LoadModel(self, request=None, context=None):
         self.modelHandler = psnl_api.APIHandler()
         model = self.modelHandler.loadScript(request.path)
-        return pnlv_pb2.GraphJson(modelJson=json.dumps(model, indent = 4))
+        graphModel = pnlv_pb2.GraphJson(modelJson=json.dumps(model, indent = 4))
+        return graphModel
 
     @errorHandler
     def PNLApi(self, request=None, context=None):
@@ -87,9 +87,15 @@ def SaveModel(self, request=None, context=None):
         else:
             return pnlv_pb2.Response(response=2, message="Model run failed")
 
+    @errorHandler
+    def StopServer(self, request=None, context=None):
+        server.stop(0)
+        return pnlv_pb2.Response(response=1, message="Server stopped successfully")
+
 
 def startServer():
-    server = grpc.server(futures.ThreadPoolExecutor(
+    global server
+    server= grpc.server(futures.ThreadPoolExecutor(
         max_workers=5,
     ),
     options=(
@@ -103,6 +109,16 @@
     )
     pnlv_pb2_grpc.add_ServeGraphServicer_to_server(PNLVServer(), server)
+    try:
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(("127.0.0.1", 50051))
+    except
socket.error as e: + if e.errno == errno.EADDRINUSE: + print("Port is already in use") + else: + print(e) + exit() + s.close() server.add_insecure_port('[::]:50051') server.start() os.system('echo "### PsyNeuLinkViewer Server UP ###"') diff --git a/src/server/stubs/psyneulink_pb2.py b/src/server/stubs/psyneulink_pb2.py index e537d793..2f16b835 100644 --- a/src/server/stubs/psyneulink_pb2.py +++ b/src/server/stubs/psyneulink_pb2.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: psyneulink.proto +# Protobuf Python Version: 5.26.1 """Generated protocol buffer code.""" -from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection +from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -14,420 +14,29 @@ -DESCRIPTOR = _descriptor.FileDescriptor( - name='psyneulink.proto', - package='psyneulinkviewer', - syntax='proto3', - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\x10psyneulink.proto\x12\x10psyneulinkviewer\"\x0e\n\x0cNullArgument\",\n\tModelData\x12\x0c\n\x04path\x18\x01 \x01(\t\x12\x11\n\tmodelJson\x18\x02 \x01(\t\"P\n\x08Response\x12\x33\n\x08response\x18\x01 \x01(\x0e\x32!.psyneulinkviewer.ResponseMessage\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x19\n\tModelPath\x12\x0c\n\x04path\x18\x01 \x01(\t\"\x1e\n\tGraphJson\x12\x11\n\tmodelJson\x18\x01 \x01(\t\"\x1e\n\x07PNLJson\x12\x13\n\x0bgenericJson\x18\x01 \x01(\t\"6\n\tInputJson\x12\x16\n\x0e\x65xecutableNode\x18\x01 \x01(\t\x12\x11\n\tinputData\x18\x02 \x01(\t*_\n\x0fResponseMessage\x12\x12\n\x0eUP_AND_RUNNING\x10\x00\x12\x0e\n\nMESSAGE_OK\x10\x01\x12\x11\n\rMESSAGE_ERROR\x10\x02\x12\x15\n\x11\x43LOSED_CONNECTION\x10\x03\x32\xf0\x02\n\nServeGraph\x12G\n\tLoadModel\x12\x1b.psyneulinkviewer.ModelPath\x1a\x1b.psyneulinkviewer.GraphJson\"\x00\x12H\n\x0bUpdateModel\x12\x1b.psyneulinkviewer.GraphJson\x1a\x1a.psyneulinkviewer.Response\"\x00\x12\x45\n\x08RunModel\x12\x1b.psyneulinkviewer.InputJson\x1a\x1a.psyneulinkviewer.Response\"\x00\x12@\n\x06PNLApi\x12\x19.psyneulinkviewer.PNLJson\x1a\x19.psyneulinkviewer.PNLJson\"\x00\x12\x46\n\tSaveModel\x12\x1b.psyneulinkviewer.ModelData\x1a\x1a.psyneulinkviewer.Response\"\x00\x62\x06proto3' -) - -_RESPONSEMESSAGE = _descriptor.EnumDescriptor( - name='ResponseMessage', - full_name='psyneulinkviewer.ResponseMessage', - filename=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key, - values=[ - _descriptor.EnumValueDescriptor( - name='UP_AND_RUNNING', index=0, number=0, - serialized_options=None, - type=None, - create_key=_descriptor._internal_create_key), - _descriptor.EnumValueDescriptor( - name='MESSAGE_OK', index=1, number=1, - serialized_options=None, - type=None, - create_key=_descriptor._internal_create_key), - _descriptor.EnumValueDescriptor( - name='MESSAGE_ERROR', index=2, number=2, - serialized_options=None, - type=None, - create_key=_descriptor._internal_create_key), - _descriptor.EnumValueDescriptor( - name='CLOSED_CONNECTION', index=3, number=3, - serialized_options=None, - type=None, - create_key=_descriptor._internal_create_key), - ], - containing_type=None, - serialized_options=None, - serialized_start=329, - 
serialized_end=424, -) -_sym_db.RegisterEnumDescriptor(_RESPONSEMESSAGE) - -ResponseMessage = enum_type_wrapper.EnumTypeWrapper(_RESPONSEMESSAGE) -UP_AND_RUNNING = 0 -MESSAGE_OK = 1 -MESSAGE_ERROR = 2 -CLOSED_CONNECTION = 3 - - - -_NULLARGUMENT = _descriptor.Descriptor( - name='NullArgument', - full_name='psyneulinkviewer.NullArgument', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=38, - serialized_end=52, -) - - -_MODELDATA = _descriptor.Descriptor( - name='ModelData', - full_name='psyneulinkviewer.ModelData', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='path', full_name='psyneulinkviewer.ModelData.path', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='modelJson', full_name='psyneulinkviewer.ModelData.modelJson', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=54, - serialized_end=98, -) - - -_RESPONSE = _descriptor.Descriptor( - name='Response', - full_name='psyneulinkviewer.Response', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='response', full_name='psyneulinkviewer.Response.response', index=0, - number=1, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='message', full_name='psyneulinkviewer.Response.message', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=100, - serialized_end=180, -) - - -_MODELPATH = _descriptor.Descriptor( - name='ModelPath', - full_name='psyneulinkviewer.ModelPath', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='path', full_name='psyneulinkviewer.ModelPath.path', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, 
default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=182, - serialized_end=207, -) - - -_GRAPHJSON = _descriptor.Descriptor( - name='GraphJson', - full_name='psyneulinkviewer.GraphJson', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='modelJson', full_name='psyneulinkviewer.GraphJson.modelJson', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=209, - serialized_end=239, -) - - -_PNLJSON = _descriptor.Descriptor( - name='PNLJson', - full_name='psyneulinkviewer.PNLJson', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='genericJson', full_name='psyneulinkviewer.PNLJson.genericJson', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=241, - serialized_end=271, -) - - -_INPUTJSON = _descriptor.Descriptor( - name='InputJson', - full_name='psyneulinkviewer.InputJson', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='executableNode', full_name='psyneulinkviewer.InputJson.executableNode', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='inputData', full_name='psyneulinkviewer.InputJson.inputData', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=273, - serialized_end=327, -) - -_RESPONSE.fields_by_name['response'].enum_type = _RESPONSEMESSAGE -DESCRIPTOR.message_types_by_name['NullArgument'] = _NULLARGUMENT 
-DESCRIPTOR.message_types_by_name['ModelData'] = _MODELDATA -DESCRIPTOR.message_types_by_name['Response'] = _RESPONSE -DESCRIPTOR.message_types_by_name['ModelPath'] = _MODELPATH -DESCRIPTOR.message_types_by_name['GraphJson'] = _GRAPHJSON -DESCRIPTOR.message_types_by_name['PNLJson'] = _PNLJSON -DESCRIPTOR.message_types_by_name['InputJson'] = _INPUTJSON -DESCRIPTOR.enum_types_by_name['ResponseMessage'] = _RESPONSEMESSAGE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -NullArgument = _reflection.GeneratedProtocolMessageType('NullArgument', (_message.Message,), { - 'DESCRIPTOR' : _NULLARGUMENT, - '__module__' : 'psyneulink_pb2' - # @@protoc_insertion_point(class_scope:psyneulinkviewer.NullArgument) - }) -_sym_db.RegisterMessage(NullArgument) - -ModelData = _reflection.GeneratedProtocolMessageType('ModelData', (_message.Message,), { - 'DESCRIPTOR' : _MODELDATA, - '__module__' : 'psyneulink_pb2' - # @@protoc_insertion_point(class_scope:psyneulinkviewer.ModelData) - }) -_sym_db.RegisterMessage(ModelData) - -Response = _reflection.GeneratedProtocolMessageType('Response', (_message.Message,), { - 'DESCRIPTOR' : _RESPONSE, - '__module__' : 'psyneulink_pb2' - # @@protoc_insertion_point(class_scope:psyneulinkviewer.Response) - }) -_sym_db.RegisterMessage(Response) - -ModelPath = _reflection.GeneratedProtocolMessageType('ModelPath', (_message.Message,), { - 'DESCRIPTOR' : _MODELPATH, - '__module__' : 'psyneulink_pb2' - # @@protoc_insertion_point(class_scope:psyneulinkviewer.ModelPath) - }) -_sym_db.RegisterMessage(ModelPath) - -GraphJson = _reflection.GeneratedProtocolMessageType('GraphJson', (_message.Message,), { - 'DESCRIPTOR' : _GRAPHJSON, - '__module__' : 'psyneulink_pb2' - # @@protoc_insertion_point(class_scope:psyneulinkviewer.GraphJson) - }) -_sym_db.RegisterMessage(GraphJson) - -PNLJson = _reflection.GeneratedProtocolMessageType('PNLJson', (_message.Message,), { - 'DESCRIPTOR' : _PNLJSON, - '__module__' : 'psyneulink_pb2' - # @@protoc_insertion_point(class_scope:psyneulinkviewer.PNLJson) - }) -_sym_db.RegisterMessage(PNLJson) - -InputJson = _reflection.GeneratedProtocolMessageType('InputJson', (_message.Message,), { - 'DESCRIPTOR' : _INPUTJSON, - '__module__' : 'psyneulink_pb2' - # @@protoc_insertion_point(class_scope:psyneulinkviewer.InputJson) - }) -_sym_db.RegisterMessage(InputJson) - - - -_SERVEGRAPH = _descriptor.ServiceDescriptor( - name='ServeGraph', - full_name='psyneulinkviewer.ServeGraph', - file=DESCRIPTOR, - index=0, - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_start=427, - serialized_end=795, - methods=[ - _descriptor.MethodDescriptor( - name='LoadModel', - full_name='psyneulinkviewer.ServeGraph.LoadModel', - index=0, - containing_service=None, - input_type=_MODELPATH, - output_type=_GRAPHJSON, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), - _descriptor.MethodDescriptor( - name='UpdateModel', - full_name='psyneulinkviewer.ServeGraph.UpdateModel', - index=1, - containing_service=None, - input_type=_GRAPHJSON, - output_type=_RESPONSE, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), - _descriptor.MethodDescriptor( - name='RunModel', - full_name='psyneulinkviewer.ServeGraph.RunModel', - index=2, - containing_service=None, - input_type=_INPUTJSON, - output_type=_RESPONSE, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), - _descriptor.MethodDescriptor( - name='PNLApi', - full_name='psyneulinkviewer.ServeGraph.PNLApi', - index=3, - 
containing_service=None, - input_type=_PNLJSON, - output_type=_PNLJSON, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), - _descriptor.MethodDescriptor( - name='SaveModel', - full_name='psyneulinkviewer.ServeGraph.SaveModel', - index=4, - containing_service=None, - input_type=_MODELDATA, - output_type=_RESPONSE, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), -]) -_sym_db.RegisterServiceDescriptor(_SERVEGRAPH) - -DESCRIPTOR.services_by_name['ServeGraph'] = _SERVEGRAPH - +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10psyneulink.proto\x12\x10psyneulinkviewer\"\x0e\n\x0cNullArgument\",\n\tModelData\x12\x0c\n\x04path\x18\x01 \x01(\t\x12\x11\n\tmodelJson\x18\x02 \x01(\t\"P\n\x08Response\x12\x33\n\x08response\x18\x01 \x01(\x0e\x32!.psyneulinkviewer.ResponseMessage\x12\x0f\n\x07message\x18\x02 \x01(\t\"\x19\n\tModelPath\x12\x0c\n\x04path\x18\x01 \x01(\t\"\x1e\n\tGraphJson\x12\x11\n\tmodelJson\x18\x01 \x01(\t\"\x1e\n\x07PNLJson\x12\x13\n\x0bgenericJson\x18\x01 \x01(\t\"6\n\tInputJson\x12\x16\n\x0e\x65xecutableNode\x18\x01 \x01(\t\x12\x11\n\tinputData\x18\x02 \x01(\t*_\n\x0fResponseMessage\x12\x12\n\x0eUP_AND_RUNNING\x10\x00\x12\x0e\n\nMESSAGE_OK\x10\x01\x12\x11\n\rMESSAGE_ERROR\x10\x02\x12\x15\n\x11\x43LOSED_CONNECTION\x10\x03\x32\xbc\x03\n\nServeGraph\x12G\n\tLoadModel\x12\x1b.psyneulinkviewer.ModelPath\x1a\x1b.psyneulinkviewer.GraphJson\"\x00\x12H\n\x0bUpdateModel\x12\x1b.psyneulinkviewer.GraphJson\x1a\x1a.psyneulinkviewer.Response\"\x00\x12\x45\n\x08RunModel\x12\x1b.psyneulinkviewer.InputJson\x1a\x1a.psyneulinkviewer.Response\"\x00\x12@\n\x06PNLApi\x12\x19.psyneulinkviewer.PNLJson\x1a\x19.psyneulinkviewer.PNLJson\"\x00\x12\x46\n\tSaveModel\x12\x1b.psyneulinkviewer.ModelData\x1a\x1a.psyneulinkviewer.Response\"\x00\x12J\n\nStopServer\x12\x1e.psyneulinkviewer.NullArgument\x1a\x1a.psyneulinkviewer.Response\"\x00\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'psyneulink_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_RESPONSEMESSAGE']._serialized_start=329 + _globals['_RESPONSEMESSAGE']._serialized_end=424 + _globals['_NULLARGUMENT']._serialized_start=38 + _globals['_NULLARGUMENT']._serialized_end=52 + _globals['_MODELDATA']._serialized_start=54 + _globals['_MODELDATA']._serialized_end=98 + _globals['_RESPONSE']._serialized_start=100 + _globals['_RESPONSE']._serialized_end=180 + _globals['_MODELPATH']._serialized_start=182 + _globals['_MODELPATH']._serialized_end=207 + _globals['_GRAPHJSON']._serialized_start=209 + _globals['_GRAPHJSON']._serialized_end=239 + _globals['_PNLJSON']._serialized_start=241 + _globals['_PNLJSON']._serialized_end=271 + _globals['_INPUTJSON']._serialized_start=273 + _globals['_INPUTJSON']._serialized_end=327 + _globals['_SERVEGRAPH']._serialized_start=427 + _globals['_SERVEGRAPH']._serialized_end=871 # @@protoc_insertion_point(module_scope) diff --git a/src/server/stubs/psyneulink_pb2_grpc.py b/src/server/stubs/psyneulink_pb2_grpc.py index a4eb7b72..42f98a2f 100644 --- a/src/server/stubs/psyneulink_pb2_grpc.py +++ b/src/server/stubs/psyneulink_pb2_grpc.py @@ -1,9 +1,34 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
"""Client and server classes corresponding to protobuf-defined services.""" import grpc +import warnings import stubs.psyneulink_pb2 as psyneulink__pb2 +GRPC_GENERATED_VERSION = '1.64.1' +GRPC_VERSION = grpc.__version__ +EXPECTED_ERROR_RELEASE = '1.65.0' +SCHEDULED_RELEASE_DATE = 'June 25, 2024' +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + warnings.warn( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in psyneulink_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', + RuntimeWarning + ) + class ServeGraphStub(object): """Missing associated documentation comment in .proto file.""" @@ -18,27 +43,32 @@ def __init__(self, channel): '/psyneulinkviewer.ServeGraph/LoadModel', request_serializer=psyneulink__pb2.ModelPath.SerializeToString, response_deserializer=psyneulink__pb2.GraphJson.FromString, - ) + _registered_method=True) self.UpdateModel = channel.unary_unary( '/psyneulinkviewer.ServeGraph/UpdateModel', request_serializer=psyneulink__pb2.GraphJson.SerializeToString, response_deserializer=psyneulink__pb2.Response.FromString, - ) + _registered_method=True) self.RunModel = channel.unary_unary( '/psyneulinkviewer.ServeGraph/RunModel', request_serializer=psyneulink__pb2.InputJson.SerializeToString, response_deserializer=psyneulink__pb2.Response.FromString, - ) + _registered_method=True) self.PNLApi = channel.unary_unary( '/psyneulinkviewer.ServeGraph/PNLApi', request_serializer=psyneulink__pb2.PNLJson.SerializeToString, response_deserializer=psyneulink__pb2.PNLJson.FromString, - ) + _registered_method=True) self.SaveModel = channel.unary_unary( '/psyneulinkviewer.ServeGraph/SaveModel', request_serializer=psyneulink__pb2.ModelData.SerializeToString, response_deserializer=psyneulink__pb2.Response.FromString, - ) + _registered_method=True) + self.StopServer = channel.unary_unary( + '/psyneulinkviewer.ServeGraph/StopServer', + request_serializer=psyneulink__pb2.NullArgument.SerializeToString, + response_deserializer=psyneulink__pb2.Response.FromString, + _registered_method=True) class ServeGraphServicer(object): @@ -74,6 +104,12 @@ def SaveModel(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def StopServer(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_ServeGraphServicer_to_server(servicer, server): rpc_method_handlers = { @@ -102,10 +138,16 @@ def add_ServeGraphServicer_to_server(servicer, server): request_deserializer=psyneulink__pb2.ModelData.FromString, response_serializer=psyneulink__pb2.Response.SerializeToString, ), + 'StopServer': grpc.unary_unary_rpc_method_handler( + servicer.StopServer, + request_deserializer=psyneulink__pb2.NullArgument.FromString, + response_serializer=psyneulink__pb2.Response.SerializeToString, + ), } generic_handler = 
grpc.method_handlers_generic_handler( 'psyneulinkviewer.ServeGraph', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('psyneulinkviewer.ServeGraph', rpc_method_handlers) # This class is part of an EXPERIMENTAL API. @@ -123,11 +165,21 @@ def LoadModel(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/psyneulinkviewer.ServeGraph/LoadModel', + return grpc.experimental.unary_unary( + request, + target, + '/psyneulinkviewer.ServeGraph/LoadModel', psyneulink__pb2.ModelPath.SerializeToString, psyneulink__pb2.GraphJson.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) @staticmethod def UpdateModel(request, @@ -140,11 +192,21 @@ def UpdateModel(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/psyneulinkviewer.ServeGraph/UpdateModel', + return grpc.experimental.unary_unary( + request, + target, + '/psyneulinkviewer.ServeGraph/UpdateModel', psyneulink__pb2.GraphJson.SerializeToString, psyneulink__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) @staticmethod def RunModel(request, @@ -157,11 +219,21 @@ def RunModel(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/psyneulinkviewer.ServeGraph/RunModel', + return grpc.experimental.unary_unary( + request, + target, + '/psyneulinkviewer.ServeGraph/RunModel', psyneulink__pb2.InputJson.SerializeToString, psyneulink__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) @staticmethod def PNLApi(request, @@ -174,11 +246,21 @@ def PNLApi(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/psyneulinkviewer.ServeGraph/PNLApi', + return grpc.experimental.unary_unary( + request, + target, + '/psyneulinkviewer.ServeGraph/PNLApi', psyneulink__pb2.PNLJson.SerializeToString, psyneulink__pb2.PNLJson.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) @staticmethod def SaveModel(request, @@ -191,8 +273,45 @@ def SaveModel(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/psyneulinkviewer.ServeGraph/SaveModel', + return grpc.experimental.unary_unary( + request, + target, + '/psyneulinkviewer.ServeGraph/SaveModel', psyneulink__pb2.ModelData.SerializeToString, psyneulink__pb2.Response.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + 
wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def StopServer(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/psyneulinkviewer.ServeGraph/StopServer', + psyneulink__pb2.NullArgument.SerializeToString, + psyneulink__pb2.Response.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/server/utils.py b/src/server/utils.py index 3ac48c87..3e4d9110 100644 --- a/src/server/utils.py +++ b/src/server/utils.py @@ -84,3 +84,7 @@ class InputTypes(Enum): RAW = 'raw' FILE = 'file' OBJECT = 'object' + + +class PNLCompositions(Enum): + EMComposition = 'EMComposition' diff --git a/src/theme.js b/src/theme.js index 92d8451e..3a679e15 100644 --- a/src/theme.js +++ b/src/theme.js @@ -524,7 +524,7 @@ const theme = { background: ${nodeGrayTextColor}; } - .primary-node.node-gray .separator { + .primary-node.node-gray .separator { background: ${nodeGrayBorderColor}; } @@ -534,7 +534,7 @@ const theme = { border-radius: 1.25rem; margin: 0.25rem auto; } - + .flexlayout__tabset_tabbar_inner_tab_container_top { border-top: none; } diff --git a/test_models/CSW/DeclanParams.py b/test_models/CSW/DeclanParams.py new file mode 100644 index 00000000..7209121c --- /dev/null +++ b/test_models/CSW/DeclanParams.py @@ -0,0 +1,93 @@ +""" +DECLAN Params: ************************************************************************** +√ episodic_lr = 1 # learning rate for the episodic pathway +√ temperature = 0.1 # temperature for EM retrieval (lower is more argmax-like) +√ n_optimization_steps = 10 # number of update steps +sim_thresh = 0.8 # threshold for discarding bad seeds -- can probably ignore this for now +Filter runs whose context representations are too uniform (i.e. not similar to "checkerboard" foil) + +May need to pad the context reps because there will be 999 reps +def filter_run(run_em, thresh=0.8): + foil = np.zeros([4,4]) + foil[::2, ::2] = 1 + foil[1::2, 1::2] = 1 + run_em = run_em.reshape(200, 5, 11).mean(axis=1) + mat = cosine_similarity(run_em, run_em) + vec = mat[:160, :160].reshape(4, 40, 4, 40).mean(axis=(1, 3)).ravel() + return cosine_similarity(foil.reshape(1, -1), vec.reshape(1, -1))[0][0] + +# Stack the model predictions (should be 999x11), pad with zeros, and reshape into trials for averaging. +em_preds = np.vstack([em_preds, np.zeros([1,11])]).reshape(-1,5,11) + +# Stack the ground truth states (should be 999x11), pad with zeros, and reshape into trials for averaging. 
+ys = np.vstack([data_loader.dataset.ys.cpu().numpy(), np.zeros([1,11])]).reshape(-1,5,11) + +# compute the probability as a performance metric +def calc_prob(em_preds, test_ys): + em_preds, test_ys = em_preds[:, 2:-1, :], test_ys[:, 2:-1, :] + em_probability = (em_preds*test_ys).sum(-1).mean(-1) + trial_probs = (em_preds*test_ys) + return em_probability, trial_probs + +Calculate the retrieval probability of the correct response as a performance metric (probs) +probs, trial_probs = calc_prob(em_preds, test_ys) +""" +from psyneulink.core.llvm import ExecutionMode +from psyneulink.core.globals.keywords import ALL, ADAPTIVE, CONTROL, CPU, Loss, MPS, OPTIMIZATION_STEP, RUN, TRIAL + +model_params = dict( + + # Names: + name = "EGO Model CSW", + state_input_layer_name = "STATE", + previous_state_layer_name = "PREVIOUS STATE", + context_layer_name = 'CONTEXT', + em_name = "EM", + prediction_layer_name = "PREDICTION", + + # Structural + state_d = 11, # length of state vector + previous_state_d = 11, # length of state vector + context_d = 11, # length of context vector + memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims + # memory_init = (0,.0001), # Initialize memory with random values in interval + memory_init = None, # Initialize with zeros + concatenate_queries = False, + # concatenate_queries = True, + + # environment + # curriculum_type = 'Interleaved', + curriculum_type = 'Blocked', + # num_stims = 100, # Integer or ALL + num_stims = ALL, # Integer or ALL + + # Processing + integration_rate = .69, # rate at which state is integrated into new context + # state_weight = 1, # weight of the state used during memory retrieval + # context_weight = 1, # weight of the context used during memory retrieval + state_weight = .5, # weight of the state used during memory retrieval + context_weight = .5, # weight of the context used during memory retrieval + # normalize_field_weights = False, # whether to normalize the field weights during memory retrieval + normalize_field_weights = True, # whether to normalize the field weights during memory retrieval + # softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like + softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like + # softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like + # softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like + # softmax_threshold = None, # threshold used to mask out small values in softmax + softmax_threshold = .001, # threshold used to mask out small values in softmax + enable_learning=[False, False, True], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE + learn_field_weights = False, + loss_spec = Loss.BINARY_CROSS_ENTROPY, + # loss_spec = Loss.MSE, + learning_rate = .5, + # num_optimization_steps = 1, + num_optimization_steps = 10, + synch_weights = RUN, + synch_values = RUN, + synch_results = RUN, + # execution_mode = ExecutionMode.Python, + execution_mode = ExecutionMode.PyTorch, + device = CPU, + # device = MPS, +) +#endregion \ No newline at end of file diff --git a/test_models/CSW/EGO Model - CSW with Simple Integrator.py b/test_models/CSW/EGO Model - CSW with Simple Integrator.py new file mode 100644 index 00000000..aae3ea5c --- /dev/null +++ b/test_models/CSW/EGO Model - CSW with Simple Integrator.py @@ 
-0,0 +1,109 @@ +# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may obtain a copy of the License at: +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +import numpy as np +import graph_scheduler as gs +from importlib import import_module +from enum import IntEnum +import matplotlib.pyplot as plt +import torch +import TestParams +import test_models.CSW.DeclanParams as DeclanParams +import timeit +import psyneulink as pnl +torch.manual_seed(0) +from psyneulink import * +from psyneulink._typing import Union +from psyneulink._typing import Literal +from ScriptControl import (MODEL_PARAMS, CONSTRUCT_MODEL, DISPLAY_MODEL, RUN_MODEL, + REPORT_OUTPUT, REPORT_PROGRESS, PRINT_RESULTS, SAVE_RESULTS, PLOT_RESULTS) +import Environment +import_module(MODEL_PARAMS) +model_params = import_module(MODEL_PARAMS).model_params +dataset = Environment.generate_dataset(condition=model_params['curriculum_type']) + +# TASK ENVIRONMENT +if model_params['num_stims'] is ALL: + INPUTS = dataset.xs.numpy() + TARGETS = dataset.ys.numpy() +else: + INPUTS = dataset.xs.numpy()[:model_params['num_stims']] + TARGETS = dataset.ys.numpy()[:model_params['num_stims']] +TOTAL_NUM_STIMS = len(INPUTS) + +# MODEL +EMFieldsIndex = IntEnum('EMFields', ['STATE', 'CONTEXT', 'PREVIOUS_STATE'], start=0) +state_retrieval_weight = 0 +RANDOM_WEIGHTS_INITIALIZATION = RandomMatrix(center=0.0, range=0.1) + +retrieval_softmax_gain = model_params['softmax_temperature'] +if is_numeric_scalar(model_params['softmax_temperature']): + retrieval_softmax_gain = 1 / model_params['softmax_temperature'] + +memory_capacity = TOTAL_NUM_STIMS +if model_params['memory_capacity'] is ALL: + memory_capacity = TOTAL_NUM_STIMS +elif not isinstance(model_params['memory_capacity'], int): + raise ValueError(f"memory_capacity must be an integer or ALL; got {model_params['memory_capacity']}") + +# Construct the model directly +state_input_layer = ProcessingMechanism(name=model_params['state_input_layer_name'], input_shapes=model_params['state_d']) + +context_layer = TransferMechanism(name=model_params['context_layer_name'], input_shapes=model_params['context_d'], function=Tanh, integrator_mode=True, integration_rate=model_params['integration_rate']) + +previous_state_layer = ProcessingMechanism(name=model_params['previous_state_layer_name'], input_shapes=model_params['previous_state_d']) + +em = EMComposition(name=model_params['em_name'], memory_template=[[0] * model_params['state_d']] * 3, memory_fill=model_params['memory_init'], memory_capacity=TOTAL_NUM_STIMS, memory_decay_rate=0, softmax_gain=model_params['softmax_temperature'], softmax_threshold=model_params['softmax_threshold'], field_names=[model_params['state_input_layer_name'], model_params['previous_state_layer_name'], model_params['context_layer_name']], field_weights=(state_retrieval_weight, model_params['state_weight'], model_params['context_weight']), normalize_field_weights=model_params['normalize_field_weights'], concatenate_queries=model_params['concatenate_queries'], learn_field_weights=model_params['learn_field_weights'], learning_rate=model_params['learning_rate'], enable_learning=model_params['enable_learning'], device=model_params['device']) + +prediction_layer = 
ProcessingMechanism(name=model_params['prediction_layer_name'], input_shapes=model_params['state_d']) + +state_to_previous_state_pathway = [state_input_layer, MappingProjection(matrix=IDENTITY_MATRIX, learnable=False), previous_state_layer] +state_to_context_pathway = [state_input_layer, MappingProjection(matrix=IDENTITY_MATRIX, learnable=False), context_layer] +state_to_em_pathway = [state_input_layer, MappingProjection(sender=state_input_layer, receiver=em.nodes[model_params['state_input_layer_name'] + ' [VALUE]'], matrix=IDENTITY_MATRIX, learnable=False), em] +previous_state_to_em_pathway = [previous_state_layer, MappingProjection(sender=previous_state_layer, receiver=em.nodes[model_params['previous_state_layer_name'] + ' [QUERY]'], matrix=IDENTITY_MATRIX, learnable=False), em] +context_learning_pathway = [context_layer, MappingProjection(sender=context_layer, matrix=IDENTITY_MATRIX, receiver=em.nodes[model_params['context_layer_name'] + ' [QUERY]'], learnable=True), em, MappingProjection(sender=em.nodes[model_params['state_input_layer_name'] + ' [RETRIEVED]'], receiver=prediction_layer, matrix=IDENTITY_MATRIX, learnable=False), prediction_layer] + +EGO_comp = AutodiffComposition([state_to_previous_state_pathway, state_to_context_pathway, state_to_em_pathway, previous_state_to_em_pathway, context_learning_pathway], learning_rate=model_params['learning_rate'], loss_spec=model_params['loss_spec'], name=model_params['name'], device=model_params['device']) + +learning_components = EGO_comp.infer_backpropagation_learning_pathways(ExecutionMode.PyTorch) +EGO_comp.add_projection(MappingProjection(sender=state_input_layer, receiver=learning_components[0], learnable=False)) +EGO_comp.scheduler.add_condition(em, BeforeNodes(previous_state_layer, context_layer)) + +model = EGO_comp + + + +if INPUTS[0][9]: + sequence_context = 'context 1' +else: + sequence_context = 'context 2' +if INPUTS[1][1]: + sequence_state = 'state 1' +else: + sequence_state = 'state 2' + +print(f"Running '{model_params['name']}' with {MODEL_PARAMS} for {model_params['num_stims']} stims " + f"using {model_params['curriculum_type']} training starting with {sequence_context}, {sequence_state}...") +context = model_params['name'] +start_time = timeit.default_timer() + +stop_time = timeit.default_timer() +print(f"Elapsed time: {stop_time - start_time}") +model.show_graph(DISPLAY_MODEL) + +fig, axes = plt.subplots(3, 1, figsize=(5, 12)) +# Weight matrix +axes[0].imshow(model.projections[7].parameters.matrix.get(model.name), interpolation=None) +# L1 of loss +axes[1].plot((1 - np.abs(model.results[1:TOTAL_NUM_STIMS,2]-TARGETS[:TOTAL_NUM_STIMS-1])).sum(-1)) +axes[1].set_xlabel('Stimuli') +axes[1].set_ylabel(model_params['loss_spec']) +# Logit of loss +axes[2].plot( (model.results[1:TOTAL_NUM_STIMS,2]*TARGETS[:TOTAL_NUM_STIMS-1]).sum(-1) ) +axes[2].set_xlabel('Stimuli') +axes[2].set_ylabel('Correct Logit') +plt.suptitle(f"{model_params['curriculum_type']} Training") +plt.show() diff --git a/test_models/CSW/Environment.py b/test_models/CSW/Environment.py new file mode 100644 index 00000000..78aca55b --- /dev/null +++ b/test_models/CSW/Environment.py @@ -0,0 +1,55 @@ +import numpy as np +import torch +from torch.utils.data import dataset +from torch import utils +from random import randint + +def one_hot_encode(labels, num_classes): + """ + One hot encode labels and convert to tensor. 
+ """ + return torch.tensor((np.arange(num_classes) == labels[..., None]).astype(float),dtype=torch.float32) + +class DeterministicCSWDataset(dataset.Dataset): + def __init__(self, n_samples_per_context, contexts_to_load) -> None: + super().__init__() + raw_xs = np.array([ + [[9,1,3,5,7],[9,2,4,6,8]], + [[10,1,4,5,8],[10,2,3,6,7]] + ]) + + item_indices = np.random.choice(raw_xs.shape[1],sum(n_samples_per_context),replace=True) + task_names = [0,1] # Flexible so these can be renamed later + task_indices = [task_names.index(name) for name in contexts_to_load] + + context_indices = np.repeat(np.array(task_indices),n_samples_per_context) + self.xs = one_hot_encode(raw_xs[context_indices,item_indices],11) + + self.xs = self.xs.reshape((-1,11)) + self.ys = torch.cat([self.xs[1:],one_hot_encode(np.array([0]),11)],dim=0) + context_indices = np.repeat(np.array(task_indices),[x*5 for x in n_samples_per_context]) + self.contexts = one_hot_encode(context_indices, len(task_names)) + + # Remove the last transition since there's no next state available + self.xs = self.xs[:-1] + self.ys = self.ys[:-1] + self.contexts = self.contexts[:-1] + + def __len__(self): + return len(self.xs) + + def __getitem__(self, idx): + return self.xs[idx], self.contexts[idx], self.ys[idx] + +def generate_dataset(condition='Blocked'): + # Generate the dataset for either the blocked or interleaved condition + if condition=='Blocked': + contexts_to_load = [0,1,0,1] + [randint(0,1) for _ in range(40)] + n_samples_per_context = [40,40,40,40] + [1]*40 + elif condition == 'Interleaved': + contexts_to_load = [0,1]*80 + [randint(0,1) for _ in range(40)] + n_samples_per_context = [1]*160 + [1]*40 + else: + raise ValueError(f'Unknown dataset condition: {condition}') + + return DeterministicCSWDataset(n_samples_per_context, contexts_to_load) diff --git a/test_models/CSW/ScriptControl.py b/test_models/CSW/ScriptControl.py new file mode 100644 index 00000000..462b438a --- /dev/null +++ b/test_models/CSW/ScriptControl.py @@ -0,0 +1,29 @@ +from psyneulink.core.compositions.report import ReportOutput, ReportProgress + +# Settings for running script: + +MODEL_PARAMS = 'TestParams' +# MODEL_PARAMS = 'DeclanParams' + +CONSTRUCT_MODEL = True # THIS MUST BE SET TO True to run the script +DISPLAY_MODEL = ( # Only one of the following can be uncommented: + None # suppress display of model + # { # show simple visual display of model + # 'show_pytorch': True, # show pytorch graph of model + # 'show_learning': True + # # 'show_projections_not_in_composition': True, + # # 'exclude_from_gradient_calc_style': 'dashed'# show target mechanisms for learning + # # {'show_node_structure': True # show detailed view of node structures and projections + # } +) +# RUN_MODEL = False # False => don't run the model +RUN_MODEL = True, # True => run the model +# REPORT_OUTPUT = ReportOutput.FULL # Sets console output during run [ReportOutput.ON, .TERSE OR .FULL] +REPORT_OUTPUT = ReportOutput.OFF # Sets console output during run [ReportOutput.ON, .TERSE OR .FULL] +REPORT_PROGRESS = ReportProgress.OFF # Sets console progress bar during run +PRINT_RESULTS = False # don't print model.results to console after execution +# PRINT_RESULTS = True # print model.results to console after execution +SAVE_RESULTS = False # save model.results to disk +# PLOT_RESULTS = False # don't plot results (PREDICTIONS) vs. TARGETS +PLOT_RESULTS = True # plot results (PREDICTIONS) vs. 
TARGETS +ANIMATE = False # {UNIT:EXECUTION_SET} # Specifies whether to generate animation of execution diff --git a/test_models/CSW/TestParams.py b/test_models/CSW/TestParams.py new file mode 100644 index 00000000..0ff6939e --- /dev/null +++ b/test_models/CSW/TestParams.py @@ -0,0 +1,61 @@ +from psyneulink.core.llvm import ExecutionMode +from psyneulink.core.globals.keywords import ALL, ADAPTIVE, CONTROL, CPU, Loss, MPS, OPTIMIZATION_STEP, RUN, TRIAL + +model_params = dict( + + # Names: + name = "EGO Model CSW", + state_input_layer_name = "STATE", + previous_state_layer_name = "PREVIOUS_STATE", + context_layer_name = 'CONTEXT', + em_name = "EM", + prediction_layer_name = "PREDICTION", + + # Structural + state_d = 11, # length of state vector + previous_state_d = 11, # length of state vector + context_d = 11, # length of context vector + memory_capacity = ALL, # number of entries in EM memory; ALL=> match to number of stims + memory_init = (0,.0001), # Initialize memory with random values in interval + # memory_init = None, # Initialize with zeros + concatenate_queries = False, + # concatenate_queries = True, + + # environment + # curriculum_type = 'Interleaved', + curriculum_type = 'Blocked', + num_stims = 7, # Integer or ALL + # num_stims = ALL, # Integer or ALL + + # Processing + integration_rate = .69, # rate at which state is integrated into new context + state_weight = 1, # weight of the state used during memory retrieval + context_weight = 1, # weight of the context used during memory retrieval + # normalize_field_weights = False, # whether to normalize the field weights during memory retrieval + normalize_field_weights = True, # whether to normalize the field weights during memory retrieval + # softmax_temperature = None, # temperature of the softmax used during memory retrieval (smaller means more argmax-like + softmax_temperature = .1, # temperature of the softmax used during memory retrieval (smaller means more argmax-like + # softmax_temperature = ADAPTIVE, # temperature of the softmax used during memory retrieval (smaller means more argmax-like + # softmax_temperature = CONTROL, # temperature of the softmax used during memory retrieval (smaller means more argmax-like + # softmax_threshold = None, # threshold used to mask out small values in softmax + softmax_threshold = .001, # threshold used to mask out small values in softmax + enable_learning=[True, False, False], # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE + # enable_learning=[True, True, True] + # enable_learning=True, + # enable_learning=False, + learn_field_weights = True, + # learn_field_weights = False, + loss_spec = Loss.BINARY_CROSS_ENTROPY, + # loss_spec = Loss.CROSS_ENTROPY, + # loss_spec = Loss.MSE, + learning_rate = .5, + num_optimization_steps = 10, + # execution_mode = ExecutionMode.Python, + synch_weights = RUN, + synch_values = RUN, + synch_results = RUN, + execution_mode = ExecutionMode.PyTorch, + device = CPU, + # device = MPS, +) +#endregion \ No newline at end of file
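The ServeGraph service in src/server/stubs/psyneulink_pb2_grpc.py now exposes a StopServer RPC that takes a NullArgument and returns a Response. A minimal client-side sketch of invoking it through the generated stub, assuming the viewer's gRPC server is reachable on localhost:50051 (the address and port are assumptions) and that the script runs from src/server so the stubs package resolves:

import grpc

import stubs.psyneulink_pb2 as psyneulink_pb2
import stubs.psyneulink_pb2_grpc as psyneulink_pb2_grpc

# Open a channel to the running viewer server and request shutdown via the new RPC.
with grpc.insecure_channel('localhost:50051') as channel:
    stub = psyneulink_pb2_grpc.ServeGraphStub(channel)
    reply = stub.StopServer(psyneulink_pb2.NullArgument())
    # Response carries a ResponseMessage enum value plus a free-form message string.
    print(reply.response, reply.message)

The test_models/CSW scripts select their parameter set at run time: ScriptControl.MODEL_PARAMS names either TestParams or DeclanParams, and the EGO script resolves it with importlib before generating the dataset. A minimal sketch of that lookup, assuming it is run from test_models/CSW/ so the local modules are importable:

from importlib import import_module

import Environment
from ScriptControl import MODEL_PARAMS

# Resolve 'TestParams' (or 'DeclanParams') to its model_params dict, as the EGO script does.
model_params = import_module(MODEL_PARAMS).model_params

# Build the CSW dataset for the configured curriculum ('Blocked' or 'Interleaved').
dataset = Environment.generate_dataset(condition=model_params['curriculum_type'])
print(model_params['name'], model_params['curriculum_type'], len(dataset))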